aboutsummaryrefslogtreecommitdiff
path: root/src/graph.py
diff options
context:
space:
mode:
authorRaghuram Subramani <raghus2247@gmail.com>2024-05-29 18:05:29 +0530
committerRaghuram Subramani <raghus2247@gmail.com>2024-05-29 18:08:53 +0530
commita4c99c97b66c6aed0737430aa9bdeb8ec64e3d9f (patch)
treea90c2b674a46d30b8c5bba59cd4000584089c7cb /src/graph.py
parent3174341025787088358f5742bfc3e1e4e46fb9b8 (diff)
The real initial commit
Diffstat (limited to '')
-rw-r--r--src/graph.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/src/graph.py b/src/graph.py
new file mode 100644
index 0000000..2067071
--- /dev/null
+++ b/src/graph.py
@@ -0,0 +1,56 @@
+import io
+
+from graphviz import Digraph
+import matplotlib.pyplot as plt
+import matplotlib.image as mpimg
+
+from .value import Value
+
+class Graph:
+ def __init__(self, root: Value) -> None:
+ self.dot = Digraph(format='png', graph_attr={ 'rankdir': 'LR' })
+
+ self.nodes = set()
+ self.edges = set()
+
+ self.build(root)
+ self.draw()
+
+ def build(self, x: Value):
+ self.nodes.add(x)
+ for child in x._prev:
+ # Add a line from child to x
+ self.edges.add((child, x))
+
+ # Recursively build for all children
+ self.build(child)
+
+ def draw(self):
+ for node in self.nodes:
+ # UID of the node
+ uid = str(id(node))
+
+ # Create a node in the graph
+ self.dot.node(
+ name=uid,
+ label=f"{node.label} | {node.data} | grad: {node.grad}",
+ shape='record'
+ )
+
+ if node._op:
+ # Create a node for the operation
+ self.dot.node(name=uid + node._op, label=node._op)
+
+ # Add a line from the operation node to the current node
+ self.dot.edge(uid + node._op, uid)
+
+ for node1, node2 in self.edges:
+ # Add a line from node1 to node2's operation node
+ self.dot.edge(str(id(node1)), str(id(node2)) + node2._op)
+
+ def show(self):
+ fp = io.BytesIO(self.dot.pipe(format='jpeg'))
+ with fp:
+ img = mpimg.imread(fp, format='jpeg')
+ plt.imshow(img)
+ plt.show()