diff options
author | Raghuram Subramani <raghus2247@gmail.com> | 2024-06-08 22:29:11 +0530 |
---|---|---|
committer | Raghuram Subramani <raghus2247@gmail.com> | 2024-06-08 22:29:11 +0530 |
commit | 40240b0b383abc2d3e81e2bcfe5e4b6d6fdfec2a (patch) | |
tree | 61c4b88b000bdc4566a76c090879ea6066343c9b /src/scalar.py | |
parent | fde98a1ee5e1ee4f5cc4e93654c4d3c2e5148576 (diff) |
Fix
Diffstat (limited to '')
-rw-r--r-- | src/scalar.py | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/src/scalar.py b/src/scalar.py index 37495f8..42a58ef 100644 --- a/src/scalar.py +++ b/src/scalar.py @@ -20,7 +20,7 @@ class Scalar: self.grad += result.grad y.grad += result.grad - self._backward = _backward + result._backward = _backward return result @@ -32,7 +32,7 @@ class Scalar: self.grad += y.data * result.grad y.grad += self.data * result.grad - self._backward = _backward + result._backward = _backward return result @@ -43,7 +43,7 @@ class Scalar: def _backward(): self.grad += (y * self.data ** (y - 1)) * result.grad - self._backward = _backward + result._backward = _backward return result @@ -55,7 +55,7 @@ class Scalar: def _backward(): self.grad += result.data * result.grad - self._backward = _backward + result._backward = _backward return result @@ -67,24 +67,29 @@ class Scalar: def _backward(): self.grad += (1 - t ** 2) * result.grad - self._backward = _backward + result._backward = _backward return result def build_children(self): result = [] + visited = set() - result.append(self) - for child in self._prev: - result += child.build_children() + def build(v): + if v not in visited: + visited.add(v) + for child in v._prev: + build(child) + result.append(v) + build(self) return result def backward(self): self.grad = 1.0 children = self.build_children() - for child in children: + for child in reversed(children): child._backward() def zero_grad(self): @@ -117,4 +122,3 @@ class Scalar: def __repr__(self) -> str: return f'Scalar({self.data})' - |