aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rwxr-xr-xexample.py3
-rw-r--r--src/scalar.py24
2 files changed, 14 insertions, 13 deletions
diff --git a/example.py b/example.py
index aa6f787..66df486 100755
--- a/example.py
+++ b/example.py
@@ -48,7 +48,6 @@ x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1w1 + x2w2'
L = x1w1x2w2 + b; L.label = 'L'
o = L.tanh(); o.label = 'o'
-print(o)
o.zero_grad()
o.backward()
@@ -59,9 +58,7 @@ e = 2 * L
f = e.exp()
a = f - 1
b = f + 1
-print(a, b)
o = a / b
-print(o)
o.zero_grad()
o.backward()
diff --git a/src/scalar.py b/src/scalar.py
index 37495f8..42a58ef 100644
--- a/src/scalar.py
+++ b/src/scalar.py
@@ -20,7 +20,7 @@ class Scalar:
self.grad += result.grad
y.grad += result.grad
- self._backward = _backward
+ result._backward = _backward
return result
@@ -32,7 +32,7 @@ class Scalar:
self.grad += y.data * result.grad
y.grad += self.data * result.grad
- self._backward = _backward
+ result._backward = _backward
return result
@@ -43,7 +43,7 @@ class Scalar:
def _backward():
self.grad += (y * self.data ** (y - 1)) * result.grad
- self._backward = _backward
+ result._backward = _backward
return result
@@ -55,7 +55,7 @@ class Scalar:
def _backward():
self.grad += result.data * result.grad
- self._backward = _backward
+ result._backward = _backward
return result
@@ -67,24 +67,29 @@ class Scalar:
def _backward():
self.grad += (1 - t ** 2) * result.grad
- self._backward = _backward
+ result._backward = _backward
return result
def build_children(self):
result = []
+ visited = set()
- result.append(self)
- for child in self._prev:
- result += child.build_children()
+ def build(v):
+ if v not in visited:
+ visited.add(v)
+ for child in v._prev:
+ build(child)
+ result.append(v)
+ build(self)
return result
def backward(self):
self.grad = 1.0
children = self.build_children()
- for child in children:
+ for child in reversed(children):
child._backward()
def zero_grad(self):
@@ -117,4 +122,3 @@ class Scalar:
def __repr__(self) -> str:
return f'Scalar({self.data})'
-