diff options
-rwxr-xr-x | example.py | 3 | ||||
-rw-r--r-- | src/scalar.py | 24 |
2 files changed, 14 insertions, 13 deletions
@@ -48,7 +48,6 @@ x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1w1 + x2w2' L = x1w1x2w2 + b; L.label = 'L' o = L.tanh(); o.label = 'o' -print(o) o.zero_grad() o.backward() @@ -59,9 +58,7 @@ e = 2 * L f = e.exp() a = f - 1 b = f + 1 -print(a, b) o = a / b -print(o) o.zero_grad() o.backward() diff --git a/src/scalar.py b/src/scalar.py index 37495f8..42a58ef 100644 --- a/src/scalar.py +++ b/src/scalar.py @@ -20,7 +20,7 @@ class Scalar: self.grad += result.grad y.grad += result.grad - self._backward = _backward + result._backward = _backward return result @@ -32,7 +32,7 @@ class Scalar: self.grad += y.data * result.grad y.grad += self.data * result.grad - self._backward = _backward + result._backward = _backward return result @@ -43,7 +43,7 @@ class Scalar: def _backward(): self.grad += (y * self.data ** (y - 1)) * result.grad - self._backward = _backward + result._backward = _backward return result @@ -55,7 +55,7 @@ class Scalar: def _backward(): self.grad += result.data * result.grad - self._backward = _backward + result._backward = _backward return result @@ -67,24 +67,29 @@ class Scalar: def _backward(): self.grad += (1 - t ** 2) * result.grad - self._backward = _backward + result._backward = _backward return result def build_children(self): result = [] + visited = set() - result.append(self) - for child in self._prev: - result += child.build_children() + def build(v): + if v not in visited: + visited.add(v) + for child in v._prev: + build(child) + result.append(v) + build(self) return result def backward(self): self.grad = 1.0 children = self.build_children() - for child in children: + for child in reversed(children): child._backward() def zero_grad(self): @@ -117,4 +122,3 @@ class Scalar: def __repr__(self) -> str: return f'Scalar({self.data})' - |