diff options
author | Raghuram Subramani <raghus2247@gmail.com> | 2024-05-29 18:05:29 +0530 |
---|---|---|
committer | Raghuram Subramani <raghus2247@gmail.com> | 2024-05-29 18:08:53 +0530 |
commit | a4c99c97b66c6aed0737430aa9bdeb8ec64e3d9f (patch) | |
tree | a90c2b674a46d30b8c5bba59cd4000584089c7cb | |
parent | 3174341025787088358f5742bfc3e1e4e46fb9b8 (diff) |
The real initial commit
Diffstat (limited to '')
-rwxr-xr-x | example.py | 19 | ||||
-rw-r--r-- | flake.lock | 27 | ||||
-rw-r--r-- | flake.nix | 27 | ||||
-rw-r--r-- | src/graph.py | 56 | ||||
-rw-r--r-- | src/value.py | 20 |
5 files changed, 149 insertions, 0 deletions
diff --git a/example.py b/example.py new file mode 100755 index 0000000..218adc3 --- /dev/null +++ b/example.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python + +from src.value import Value +from src.graph import Graph + +a = Value(2, label='a') +b = Value(-3, label='b') +c = Value(10, label='c') +f = Value(-2, label='f') + +d = a * b; d.label = 'd' +e = d + c; e.label = 'e' +L = e * f; L.label = 'L' + +L.grad = 1.0 +e.grad = -2.0 +f.grad = 4.0 + +Graph(L).show() diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..089a9ba --- /dev/null +++ b/flake.lock @@ -0,0 +1,27 @@ +{ + "nodes": { + "nixpkgs": { + "locked": { + "lastModified": 1716509168, + "narHash": "sha256-4zSIhSRRIoEBwjbPm3YiGtbd8HDWzFxJjw5DYSDy1n8=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "bfb7a882678e518398ce9a31a881538679f6f092", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "nixpkgs": "nixpkgs" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..915e9f5 --- /dev/null +++ b/flake.nix @@ -0,0 +1,27 @@ +{ + inputs = { + nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; + }; + + outputs = { self, nixpkgs, ... }: + let + pkgs = import nixpkgs { system = "x86_64-linux"; }; + in { + devShells.x86_64-linux.default = pkgs.mkShell { + buildInputs = with pkgs; [ + (python312.withPackages (python-pkgs: [ + python-pkgs.numpy + python-pkgs.matplotlib + python-pkgs.graphviz + ])) + + graphviz + ]; + + shellHook = '' + tmux -L autograd + exit + ''; + }; + }; +} diff --git a/src/graph.py b/src/graph.py new file mode 100644 index 0000000..2067071 --- /dev/null +++ b/src/graph.py @@ -0,0 +1,56 @@ +import io + +from graphviz import Digraph +import matplotlib.pyplot as plt +import matplotlib.image as mpimg + +from .value import Value + +class Graph: + def __init__(self, root: Value) -> None: + self.dot = Digraph(format='png', graph_attr={ 'rankdir': 'LR' }) + + self.nodes = set() + self.edges = set() + + self.build(root) + self.draw() + + def build(self, x: Value): + self.nodes.add(x) + for child in x._prev: + # Add a line from child to x + self.edges.add((child, x)) + + # Recursively build for all children + self.build(child) + + def draw(self): + for node in self.nodes: + # UID of the node + uid = str(id(node)) + + # Create a node in the graph + self.dot.node( + name=uid, + label=f"{node.label} | {node.data} | grad: {node.grad}", + shape='record' + ) + + if node._op: + # Create a node for the operation + self.dot.node(name=uid + node._op, label=node._op) + + # Add a line from the operation node to the current node + self.dot.edge(uid + node._op, uid) + + for node1, node2 in self.edges: + # Add a line from node1 to node2's operation node + self.dot.edge(str(id(node1)), str(id(node2)) + node2._op) + + def show(self): + fp = io.BytesIO(self.dot.pipe(format='jpeg')) + with fp: + img = mpimg.imread(fp, format='jpeg') + plt.imshow(img) + plt.show() diff --git a/src/value.py b/src/value.py new file mode 100644 index 0000000..af0ab35 --- /dev/null +++ b/src/value.py @@ -0,0 +1,20 @@ +class Value: + def __init__(self, data, _children=(), _op='', label='') -> None: + self.label = label + + self.data = float(data) + self.grad = 0.0 + + self._prev = set(_children) + self._op = _op + + def __repr__(self) -> str: + return f'Value({self.data})' + + def __add__(self, y): + result = self.data + y.data + return Value(result, (self, y), _op='+') + + def __mul__(self, y): + result = self.data * y.data + return Value(result, (self, y), _op='*') |