From fde98a1ee5e1ee4f5cc4e93654c4d3c2e5148576 Mon Sep 17 00:00:00 2001 From: Raghuram Subramani Date: Sat, 8 Jun 2024 22:20:34 +0530 Subject: Update --- src/scalar.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 8 deletions(-) (limited to 'src/scalar.py') diff --git a/src/scalar.py b/src/scalar.py index c8c7601..37495f8 100644 --- a/src/scalar.py +++ b/src/scalar.py @@ -12,26 +12,48 @@ class Scalar: self._backward = lambda: None - def __repr__(self) -> str: - return f'Scalar({self.label}: {self.data})' - def __add__(self, y): + y = y if isinstance(y, Scalar) else Scalar(y) result = Scalar(self.data + y.data, (self, y), _op='+') def _backward(): - self.grad = result.grad - y.grad = result.grad + self.grad += result.grad + y.grad += result.grad self._backward = _backward return result def __mul__(self, y): + y = y if isinstance(y, Scalar) else Scalar(y) result = Scalar(self.data * y.data, (self, y), _op='*') def _backward(): - self.grad = y.data * result.grad - y.grad = self.data * result.grad + self.grad += y.data * result.grad + y.grad += self.data * result.grad + + self._backward = _backward + + return result + + def __pow__(self, y): + assert isinstance(y, (int, float)) + result = Scalar(self.data ** y, (self, ), _op=f'** {y}') + + def _backward(): + self.grad += (y * self.data ** (y - 1)) * result.grad + + self._backward = _backward + + return result + + def exp(self): + x = self.data + e = math.exp(x) + result = Scalar(e, (self, ), 'exp') + + def _backward(): + self.grad += result.data * result.grad self._backward = _backward @@ -43,7 +65,7 @@ class Scalar: result = Scalar(t, (self, ), 'tanh') def _backward(): - self.grad = (1 - (t ** 2)) * result.grad + self.grad += (1 - t ** 2) * result.grad self._backward = _backward @@ -64,3 +86,35 @@ class Scalar: for child in children: child._backward() + + def zero_grad(self): + self.grad = 0.0 + children = self.build_children() + + for child in children: + child.grad = 0.0 + + def __truediv__(self, y): + return self * y ** -1 + + def __rtruediv__(self, y): + return self * y ** -1 + + def __neg__(self): + return self * -1 + + def __sub__(self, y): + return self + (-y) + + def __rsub__(self, y): + return self + (-y) + + def __radd__(self, y): + return self + y + + def __rmul__(self, y): + return self * y + + def __repr__(self) -> str: + return f'Scalar({self.data})' + -- cgit v1.2.3