aboutsummaryrefslogtreecommitdiff
path: root/src/scalar.py
diff options
context:
space:
mode:
authorRaghuram Subramani <raghus2247@gmail.com>2024-06-08 22:29:11 +0530
committerRaghuram Subramani <raghus2247@gmail.com>2024-06-08 22:29:11 +0530
commit40240b0b383abc2d3e81e2bcfe5e4b6d6fdfec2a (patch)
tree61c4b88b000bdc4566a76c090879ea6066343c9b /src/scalar.py
parentfde98a1ee5e1ee4f5cc4e93654c4d3c2e5148576 (diff)
Fix
Diffstat (limited to 'src/scalar.py')
-rw-r--r--src/scalar.py24
1 files changed, 14 insertions, 10 deletions
diff --git a/src/scalar.py b/src/scalar.py
index 37495f8..42a58ef 100644
--- a/src/scalar.py
+++ b/src/scalar.py
@@ -20,7 +20,7 @@ class Scalar:
self.grad += result.grad
y.grad += result.grad
- self._backward = _backward
+ result._backward = _backward
return result
@@ -32,7 +32,7 @@ class Scalar:
self.grad += y.data * result.grad
y.grad += self.data * result.grad
- self._backward = _backward
+ result._backward = _backward
return result
@@ -43,7 +43,7 @@ class Scalar:
def _backward():
self.grad += (y * self.data ** (y - 1)) * result.grad
- self._backward = _backward
+ result._backward = _backward
return result
@@ -55,7 +55,7 @@ class Scalar:
def _backward():
self.grad += result.data * result.grad
- self._backward = _backward
+ result._backward = _backward
return result
@@ -67,24 +67,29 @@ class Scalar:
def _backward():
self.grad += (1 - t ** 2) * result.grad
- self._backward = _backward
+ result._backward = _backward
return result
def build_children(self):
result = []
+ visited = set()
- result.append(self)
- for child in self._prev:
- result += child.build_children()
+ def build(v):
+ if v not in visited:
+ visited.add(v)
+ for child in v._prev:
+ build(child)
+ result.append(v)
+ build(self)
return result
def backward(self):
self.grad = 1.0
children = self.build_children()
- for child in children:
+ for child in reversed(children):
child._backward()
def zero_grad(self):
@@ -117,4 +122,3 @@ class Scalar:
def __repr__(self) -> str:
return f'Scalar({self.data})'
-