blob: 9aecde1ff7dbd16aedab6bff7d300295505e63f0 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
|
import io
from graphviz import Digraph
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from .scalar import Scalar
class Graph:
    """Graphviz rendering of the graph reachable from a root ``Scalar``.

    Nodes are expected to expose ``_prev`` (an iterable of input nodes),
    ``_op`` (a string naming the operation that produced the node, falsy
    when there is none), and ``label`` / ``data`` / ``grad`` for display.

    Constructing a ``Graph`` immediately collects the nodes and edges and
    populates ``self.dot``; call :meth:`show` to display the result.
    """

    def __init__(self, root: "Scalar") -> None:
        """Collect the graph rooted at *root* and fill in the Digraph."""
        self.dot = Digraph(format='png', graph_attr={'rankdir': 'LR'})
        self.nodes = set()
        self.edges = set()
        self.build(root)
        self.draw()

    def build(self, x: "Scalar"):
        """Add *x* and everything reachable through ``_prev`` to the graph.

        Every reachable node is added to ``self.nodes`` and every
        ``(child, parent)`` pair to ``self.edges``.

        The traversal is iterative and visits each node once. A naive
        recursion re-walks shared sub-graphs once per path (exponential
        when nodes are reused) and hits the interpreter recursion limit
        on long chains.
        """
        # Track visited nodes by identity so the bookkeeping does not depend
        # on how the node type defines equality or hashing.
        visited = set()
        pending = [x]
        while pending:
            node = pending.pop()
            if id(node) in visited:
                continue
            visited.add(id(node))
            self.nodes.add(node)
            for child in node._prev:
                # Add a line from child to node
                self.edges.add((child, node))
                pending.append(child)

    def draw(self):
        """Emit one record node per value, plus an operation node if any.

        A node with a truthy ``_op`` gets an extra Graphviz node named
        ``<uid><op>`` that points at the value node; edges from inputs are
        attached to that operation node rather than to the value itself.
        """
        for node in self.nodes:
            # UID of the node
            uid = str(id(node))
            # Create a node in the graph
            self.dot.node(
                name=uid,
                label=f"{node.label} | {node.data} | grad: {node.grad}",
                shape='record'
            )
            if node._op:
                op_uid = uid + node._op
                # Create a node for the operation
                self.dot.node(name=op_uid, label=node._op)
                # Add a line from the operation node to the current node
                self.dot.edge(op_uid, uid)
        for node1, node2 in self.edges:
            # Add a line from node1 to node2's operation node
            self.dot.edge(str(id(node1)), str(id(node2)) + node2._op)

    def show(self):
        """Render the graph to an in-memory JPEG and display it with pyplot."""
        with io.BytesIO(self.dot.pipe(format='jpeg')) as fp:
            img = mpimg.imread(fp, format='jpeg')
        plt.imshow(img)
        plt.show()
|