aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/scalar.py70
1 files changed, 62 insertions, 8 deletions
diff --git a/src/scalar.py b/src/scalar.py
index c8c7601..37495f8 100644
--- a/src/scalar.py
+++ b/src/scalar.py
@@ -12,26 +12,48 @@ class Scalar:
self._backward = lambda: None
- def __repr__(self) -> str:
- return f'Scalar({self.label}: {self.data})'
-
def __add__(self, y):
+ y = y if isinstance(y, Scalar) else Scalar(y)
result = Scalar(self.data + y.data, (self, y), _op='+')
def _backward():
- self.grad = result.grad
- y.grad = result.grad
+ self.grad += result.grad
+ y.grad += result.grad
self._backward = _backward
return result
def __mul__(self, y):
+ y = y if isinstance(y, Scalar) else Scalar(y)
result = Scalar(self.data * y.data, (self, y), _op='*')
def _backward():
- self.grad = y.data * result.grad
- y.grad = self.data * result.grad
+ self.grad += y.data * result.grad
+ y.grad += self.data * result.grad
+
+ self._backward = _backward
+
+ return result
+
+ def __pow__(self, y):
+ assert isinstance(y, (int, float))
+ result = Scalar(self.data ** y, (self, ), _op=f'** {y}')
+
+ def _backward():
+ self.grad += (y * self.data ** (y - 1)) * result.grad
+
+ self._backward = _backward
+
+ return result
+
+ def exp(self):
+ x = self.data
+ e = math.exp(x)
+ result = Scalar(e, (self, ), 'exp')
+
+ def _backward():
+ self.grad += result.data * result.grad
self._backward = _backward
@@ -43,7 +65,7 @@ class Scalar:
result = Scalar(t, (self, ), 'tanh')
def _backward():
- self.grad = (1 - (t ** 2)) * result.grad
+ self.grad += (1 - t ** 2) * result.grad
self._backward = _backward
@@ -64,3 +86,35 @@ class Scalar:
for child in children:
child._backward()
+
+ def zero_grad(self):
+ self.grad = 0.0
+ children = self.build_children()
+
+ for child in children:
+ child.grad = 0.0
+
+ def __truediv__(self, y):
+ return self * y ** -1
+
+ def __rtruediv__(self, y):
+ return self * y ** -1
+
+ def __neg__(self):
+ return self * -1
+
+ def __sub__(self, y):
+ return self + (-y)
+
+ def __rsub__(self, y):
+ return self + (-y)
+
+ def __radd__(self, y):
+ return self + y
+
+ def __rmul__(self, y):
+ return self * y
+
+ def __repr__(self) -> str:
+ return f'Scalar({self.data})'
+