blob: 6c8ab53708ca0eaefa4b13b251933aacb86cd413 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
|
import tkinter as tk
from graphviz import Digraph
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from .scalar import Scalar
class Graph:
    """Visualise the computation graph that produced a ``Scalar``.

    On construction every node reachable from the root through ``_prev`` is
    collected and turned into a left-to-right Graphviz digraph; call
    :meth:`show` to render it as a PNG inside a Tk window.
    """

    def __init__(self, root: "Scalar") -> None:
        """Collect the graph behind *root* and build its Graphviz description.

        Args:
            root: Output node; the graph is traced backwards through each
                node's ``_prev`` collection.
        """
        self.dot = Digraph(format='png', graph_attr={'rankdir': 'LR'})
        self.nodes = set()
        self.edges = set()
        self.build(root)
        self.draw()

    def build(self, x: "Scalar") -> None:
        """Add *x* and every node reachable through ``_prev`` to the graph.

        Populates ``self.nodes`` with the nodes and ``self.edges`` with
        ``(child, parent)`` pairs. An explicit stack is used and nodes
        already in ``self.nodes`` are not expanded again. As a result, a
        sub-expression shared by many parents is visited once rather than
        once per path. Deep graphs also cannot hit Python's recursion limit.

        Args:
            x: Node to start the backwards traversal from.
        """
        stack = [x]
        while stack:
            node = stack.pop()
            if node in self.nodes:
                # Already expanded: all of its edges were recorded then.
                continue
            self.nodes.add(node)
            for child in node._prev:
                # Edge runs from the input (child) to the node it feeds.
                self.edges.add((child, node))
                stack.append(child)

    def draw(self) -> None:
        """Emit Graphviz nodes and edges for everything collected by :meth:`build`.

        Each collected node becomes a ``record`` node showing its label, data
        and gradient. A node with a non-empty ``_op`` also gets a separate
        operation node with an edge into it. Edges from a node's inputs
        target that operation node, or the node itself when ``_op`` is empty.
        """
        for node in self.nodes:
            # id() is unique among the live objects held in self.nodes.
            uid = str(id(node))
            self.dot.node(
                name=uid,
                label=f"{node.label} | {node.data} | grad: {node.grad}",
                shape='record',
            )
            if node._op:
                op_uid = uid + node._op
                self.dot.node(name=op_uid, label=node._op)
                self.dot.edge(op_uid, uid)
        for child, parent in self.edges:
            # With an empty _op this is simply the parent's own uid.
            self.dot.edge(str(id(child)), str(id(parent)) + parent._op)

    def show(self) -> None:
        """Render the graph to PNG and display it in a Tk window.

        Blocks in the Tk main loop until the window is closed.
        """
        window = tk.Tk()
        window.title('float')
        png_bytes = self.dot.pipe(format='png')
        img = tk.PhotoImage(data=png_bytes, format='png')
        panel = tk.Label(window, image=img)
        panel.pack()
        window.mainloop()
|