aboutsummaryrefslogtreecommitdiff
path: root/src/graph.py
blob: 2067071cb0338de7dfa5906382b139a024d0dd76 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import io

from graphviz import Digraph
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from .value import Value

class Graph:
    """Graphviz rendering of the computation graph that ends at a ``Value``.

    Constructing a ``Graph`` walks the graph reachable from ``root`` through
    each node's ``_prev`` collection, records what it finds in ``self.nodes``
    and ``self.edges``, and populates ``self.dot`` (a left-to-right
    ``Digraph``) with the matching nodes and edges. Call :meth:`show` to
    render and display the result.
    """

    def __init__(self, root: Value) -> None:
        """Collect the graph reachable from ``root`` and fill ``self.dot``.

        Args:
            root: The node whose ancestry (via ``_prev``) should be drawn.
        """
        self.dot = Digraph(format='png', graph_attr={'rankdir': 'LR'})

        # Every node reached from the root, and every (child, parent) pair.
        self.nodes = set()
        self.edges = set()

        self.build(root)
        self.draw()

    def build(self, x: Value) -> None:
        """Add ``x`` and everything reachable from it to ``nodes``/``edges``.

        The traversal is iterative and expands each node only once, so a node
        shared by several parents is not re-walked for every path that leads
        to it, and a long chain of operations cannot exceed Python's
        recursion limit.

        Args:
            x: The node to start from; its ``_prev`` entries are followed
                transitively.
        """
        # Nodes are tracked by identity, matching how ``draw`` names them.
        # Every tracked object stays referenced by ``self.nodes`` for the
        # whole walk, so its id cannot be reused while we are still running.
        expanded = set()
        pending = [x]

        while pending:
            node = pending.pop()
            if id(node) in expanded:
                continue
            expanded.add(id(node))

            self.nodes.add(node)
            for child in node._prev:
                # Record an edge from the child to the node that consumes it.
                self.edges.add((child, node))
                pending.append(child)

    def draw(self) -> None:
        """Emit the collected nodes and edges into ``self.dot``.

        Each node becomes a ``record``-shaped box showing its label, data and
        gradient. A node with a truthy ``_op`` also gets a separate operation
        node that points at it; incoming edges are attached to that operation
        node rather than to the value box itself.
        """
        for node in self.nodes:
            # ``id`` gives a name that is unique per live object.
            uid = str(id(node))

            self.dot.node(
                name=uid,
                label=f"{node.label} | {node.data} | grad: {node.grad}",
                shape='record'
            )

            if node._op:
                # The operation node's name is the value's uid plus the
                # operation text, which is also how the edge loop below
                # addresses it.
                op_uid = uid + node._op
                self.dot.node(name=op_uid, label=node._op)
                self.dot.edge(op_uid, uid)

        for node1, node2 in self.edges:
            # Connect the input value to the operation node of its consumer.
            self.dot.edge(str(id(node1)), str(id(node2)) + node2._op)

    def show(self) -> None:
        """Render the graph to an in-memory JPEG and display it with pyplot."""
        # The image never touches disk: Graphviz output is piped straight
        # into a buffer that matplotlib decodes.
        with io.BytesIO(self.dot.pipe(format='jpeg')) as fp:
            img = mpimg.imread(fp, format='jpeg')
        plt.imshow(img)
        plt.show()