From 3bc24933f4128e76ccbd6e37155ff6cccb20a182 Mon Sep 17 00:00:00 2001 From: Raghuram Subramani Date: Sat, 8 Jun 2024 18:45:09 +0530 Subject: Automate backward propagation --- src/graph.py | 5 +---- src/scalar.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 52 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/graph.py b/src/graph.py index 6c8ab53..6bebb06 100644 --- a/src/graph.py +++ b/src/graph.py @@ -1,8 +1,5 @@ import tkinter as tk - from graphviz import Digraph -import matplotlib.pyplot as plt -import matplotlib.image as mpimg from .scalar import Scalar @@ -33,7 +30,7 @@ class Graph: # Create a node in the graph self.dot.node( name=uid, - label=f"{node.label} | {node.data} | grad: {node.grad}", + label=f"{node.label} | {node.data:.4f} | grad: {node.grad:.4f}", shape='record' ) diff --git a/src/scalar.py b/src/scalar.py index a67b7ae..c8c7601 100644 --- a/src/scalar.py +++ b/src/scalar.py @@ -1,3 +1,5 @@ +import math + class Scalar: def __init__(self, data, _children=(), _op='', label='') -> None: self.label = label @@ -7,14 +9,58 @@ class Scalar: self._prev = set(_children) self._op = _op + + self._backward = lambda: None def __repr__(self) -> str: - return f'Scalar({self.data})' + return f'Scalar({self.label}: {self.data})' def __add__(self, y): - result = self.data + y.data - return Scalar(result, (self, y), _op='+') + result = Scalar(self.data + y.data, (self, y), _op='+') + + def _backward(): + self.grad = result.grad + y.grad = result.grad + + self._backward = _backward + + return result def __mul__(self, y): - result = self.data * y.data - return Scalar(result, (self, y), _op='*') + result = Scalar(self.data * y.data, (self, y), _op='*') + + def _backward(): + self.grad = y.data * result.grad + y.grad = self.data * result.grad + + self._backward = _backward + + return result + + def tanh(self): + x = self.data + t = (math.exp(2 * x) - 1) / (math.exp(2 * x) + 1) + result = Scalar(t, (self, ), 'tanh') + + def _backward(): + self.grad = (1 - (t ** 2)) * result.grad + + self._backward = _backward + + return result + + def build_children(self): + result = [] + + result.append(self) + for child in self._prev: + result += child.build_children() + + return result + + def backward(self): + self.grad = 1.0 + children = self.build_children() + + for child in children: + child._backward() -- cgit v1.2.3