diff options
author | Raghuram Subramani <raghus2247@gmail.com> | 2024-05-29 18:05:29 +0530 |
---|---|---|
committer | Raghuram Subramani <raghus2247@gmail.com> | 2024-05-29 18:08:53 +0530 |
commit | a4c99c97b66c6aed0737430aa9bdeb8ec64e3d9f (patch) | |
tree | a90c2b674a46d30b8c5bba59cd4000584089c7cb /src | |
parent | 3174341025787088358f5742bfc3e1e4e46fb9b8 (diff) |
The real initial commit
Diffstat (limited to '')
-rw-r--r-- | src/graph.py | 56 | ||||
-rw-r--r-- | src/value.py | 20 |
2 files changed, 76 insertions, 0 deletions
diff --git a/src/graph.py b/src/graph.py new file mode 100644 index 0000000..2067071 --- /dev/null +++ b/src/graph.py @@ -0,0 +1,56 @@ +import io + +from graphviz import Digraph +import matplotlib.pyplot as plt +import matplotlib.image as mpimg + +from .value import Value + +class Graph: + def __init__(self, root: Value) -> None: + self.dot = Digraph(format='png', graph_attr={ 'rankdir': 'LR' }) + + self.nodes = set() + self.edges = set() + + self.build(root) + self.draw() + + def build(self, x: Value): + self.nodes.add(x) + for child in x._prev: + # Add a line from child to x + self.edges.add((child, x)) + + # Recursively build for all children + self.build(child) + + def draw(self): + for node in self.nodes: + # UID of the node + uid = str(id(node)) + + # Create a node in the graph + self.dot.node( + name=uid, + label=f"{node.label} | {node.data} | grad: {node.grad}", + shape='record' + ) + + if node._op: + # Create a node for the operation + self.dot.node(name=uid + node._op, label=node._op) + + # Add a line from the operation node to the current node + self.dot.edge(uid + node._op, uid) + + for node1, node2 in self.edges: + # Add a line from node1 to node2's operation node + self.dot.edge(str(id(node1)), str(id(node2)) + node2._op) + + def show(self): + fp = io.BytesIO(self.dot.pipe(format='jpeg')) + with fp: + img = mpimg.imread(fp, format='jpeg') + plt.imshow(img) + plt.show() diff --git a/src/value.py b/src/value.py new file mode 100644 index 0000000..af0ab35 --- /dev/null +++ b/src/value.py @@ -0,0 +1,20 @@ +class Value: + def __init__(self, data, _children=(), _op='', label='') -> None: + self.label = label + + self.data = float(data) + self.grad = 0.0 + + self._prev = set(_children) + self._op = _op + + def __repr__(self) -> str: + return f'Value({self.data})' + + def __add__(self, y): + result = self.data + y.data + return Value(result, (self, y), _op='+') + + def __mul__(self, y): + result = self.data * y.data + return Value(result, (self, y), _op='*') |