diff options
Diffstat (limited to '')
-rw-r--r-- | src/scalar.py | 70 |
1 files changed, 62 insertions, 8 deletions
diff --git a/src/scalar.py b/src/scalar.py index c8c7601..37495f8 100644 --- a/src/scalar.py +++ b/src/scalar.py @@ -12,26 +12,48 @@ class Scalar: self._backward = lambda: None - def __repr__(self) -> str: - return f'Scalar({self.label}: {self.data})' - def __add__(self, y): + y = y if isinstance(y, Scalar) else Scalar(y) result = Scalar(self.data + y.data, (self, y), _op='+') def _backward(): - self.grad = result.grad - y.grad = result.grad + self.grad += result.grad + y.grad += result.grad self._backward = _backward return result def __mul__(self, y): + y = y if isinstance(y, Scalar) else Scalar(y) result = Scalar(self.data * y.data, (self, y), _op='*') def _backward(): - self.grad = y.data * result.grad - y.grad = self.data * result.grad + self.grad += y.data * result.grad + y.grad += self.data * result.grad + + self._backward = _backward + + return result + + def __pow__(self, y): + assert isinstance(y, (int, float)) + result = Scalar(self.data ** y, (self, ), _op=f'** {y}') + + def _backward(): + self.grad += (y * self.data ** (y - 1)) * result.grad + + self._backward = _backward + + return result + + def exp(self): + x = self.data + e = math.exp(x) + result = Scalar(e, (self, ), 'exp') + + def _backward(): + self.grad += result.data * result.grad self._backward = _backward @@ -43,7 +65,7 @@ class Scalar: result = Scalar(t, (self, ), 'tanh') def _backward(): - self.grad = (1 - (t ** 2)) * result.grad + self.grad += (1 - t ** 2) * result.grad self._backward = _backward @@ -64,3 +86,35 @@ class Scalar: for child in children: child._backward() + + def zero_grad(self): + self.grad = 0.0 + children = self.build_children() + + for child in children: + child.grad = 0.0 + + def __truediv__(self, y): + return self * y ** -1 + + def __rtruediv__(self, y): + return self * y ** -1 + + def __neg__(self): + return self * -1 + + def __sub__(self, y): + return self + (-y) + + def __rsub__(self, y): + return self + (-y) + + def __radd__(self, y): + return self + y + + def __rmul__(self, y): + return self * y + + def __repr__(self) -> str: + return f'Scalar({self.data})' + |