aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rwxr-xr-xexample.py19
-rw-r--r--flake.lock27
-rw-r--r--flake.nix27
-rw-r--r--src/graph.py56
-rw-r--r--src/value.py20
5 files changed, 149 insertions, 0 deletions
diff --git a/example.py b/example.py
new file mode 100755
index 0000000..218adc3
--- /dev/null
+++ b/example.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python
+
+from src.value import Value
+from src.graph import Graph
+
+a = Value(2, label='a')
+b = Value(-3, label='b')
+c = Value(10, label='c')
+f = Value(-2, label='f')
+
+d = a * b; d.label = 'd'
+e = d + c; e.label = 'e'
+L = e * f; L.label = 'L'
+
+L.grad = 1.0
+e.grad = -2.0
+f.grad = 4.0
+
+Graph(L).show()
diff --git a/flake.lock b/flake.lock
new file mode 100644
index 0000000..089a9ba
--- /dev/null
+++ b/flake.lock
@@ -0,0 +1,27 @@
+{
+ "nodes": {
+ "nixpkgs": {
+ "locked": {
+ "lastModified": 1716509168,
+ "narHash": "sha256-4zSIhSRRIoEBwjbPm3YiGtbd8HDWzFxJjw5DYSDy1n8=",
+ "owner": "nixos",
+ "repo": "nixpkgs",
+ "rev": "bfb7a882678e518398ce9a31a881538679f6f092",
+ "type": "github"
+ },
+ "original": {
+ "owner": "nixos",
+ "ref": "nixos-unstable",
+ "repo": "nixpkgs",
+ "type": "github"
+ }
+ },
+ "root": {
+ "inputs": {
+ "nixpkgs": "nixpkgs"
+ }
+ }
+ },
+ "root": "root",
+ "version": 7
+}
diff --git a/flake.nix b/flake.nix
new file mode 100644
index 0000000..915e9f5
--- /dev/null
+++ b/flake.nix
@@ -0,0 +1,27 @@
+{
+ inputs = {
+ nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
+ };
+
+ outputs = { self, nixpkgs, ... }:
+ let
+ pkgs = import nixpkgs { system = "x86_64-linux"; };
+ in {
+ devShells.x86_64-linux.default = pkgs.mkShell {
+ buildInputs = with pkgs; [
+ (python312.withPackages (python-pkgs: [
+ python-pkgs.numpy
+ python-pkgs.matplotlib
+ python-pkgs.graphviz
+ ]))
+
+ graphviz
+ ];
+
+ shellHook = ''
+ tmux -L autograd
+ exit
+ '';
+ };
+ };
+}
diff --git a/src/graph.py b/src/graph.py
new file mode 100644
index 0000000..2067071
--- /dev/null
+++ b/src/graph.py
@@ -0,0 +1,56 @@
+import io
+
+from graphviz import Digraph
+import matplotlib.pyplot as plt
+import matplotlib.image as mpimg
+
+from .value import Value
+
+class Graph:
+ def __init__(self, root: Value) -> None:
+ self.dot = Digraph(format='png', graph_attr={ 'rankdir': 'LR' })
+
+ self.nodes = set()
+ self.edges = set()
+
+ self.build(root)
+ self.draw()
+
+ def build(self, x: Value):
+ self.nodes.add(x)
+ for child in x._prev:
+ # Add a line from child to x
+ self.edges.add((child, x))
+
+ # Recursively build for all children
+ self.build(child)
+
+ def draw(self):
+ for node in self.nodes:
+ # UID of the node
+ uid = str(id(node))
+
+ # Create a node in the graph
+ self.dot.node(
+ name=uid,
+ label=f"{node.label} | {node.data} | grad: {node.grad}",
+ shape='record'
+ )
+
+ if node._op:
+ # Create a node for the operation
+ self.dot.node(name=uid + node._op, label=node._op)
+
+ # Add a line from the operation node to the current node
+ self.dot.edge(uid + node._op, uid)
+
+ for node1, node2 in self.edges:
+ # Add a line from node1 to node2's operation node
+ self.dot.edge(str(id(node1)), str(id(node2)) + node2._op)
+
+ def show(self):
+ fp = io.BytesIO(self.dot.pipe(format='jpeg'))
+ with fp:
+ img = mpimg.imread(fp, format='jpeg')
+ plt.imshow(img)
+ plt.show()
diff --git a/src/value.py b/src/value.py
new file mode 100644
index 0000000..af0ab35
--- /dev/null
+++ b/src/value.py
@@ -0,0 +1,20 @@
+class Value:
+ def __init__(self, data, _children=(), _op='', label='') -> None:
+ self.label = label
+
+ self.data = float(data)
+ self.grad = 0.0
+
+ self._prev = set(_children)
+ self._op = _op
+
+ def __repr__(self) -> str:
+ return f'Value({self.data})'
+
+ def __add__(self, y):
+ result = self.data + y.data
+ return Value(result, (self, y), _op='+')
+
+ def __mul__(self, y):
+ result = self.data * y.data
+ return Value(result, (self, y), _op='*')