aboutsummaryrefslogtreecommitdiff
path: root/src/scalar.py
blob: a67b7ae6e3ac1cc320998af3a35b363e30d0deca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Scalar:
    """A scalar value that records the operation and operands producing it.

    Each ``Scalar`` holds its numeric ``data`` (coerced to ``float``), a
    ``grad`` slot initialised to ``0.0``, the set of operands it was computed
    from (``_prev``) and the symbol of that operation (``_op``). Results of
    ``+`` and ``*`` therefore link back to their inputs, forming an
    expression graph. Plain ``int``/``float`` operands are accepted on
    either side of ``+`` and ``*``.
    """

    def __init__(self, data, _children=(), _op='', label='') -> None:
        # Optional display name; not used by any arithmetic in this class.
        self.label = label

        self.data = float(data)
        # Gradient slot, initialised to 0.0; no method in this class updates it.
        self.grad = 0.0

        # Operands this node was computed from. Stored as a set, so `a + a`
        # records `a` only once.
        self._prev = set(_children)
        # Symbol of the producing operation ('' for leaf values).
        self._op = _op

    def __repr__(self) -> str:
        return f'Scalar({self.data})'

    @staticmethod
    def _wrap(y):
        """Return *y* as an operand, turning a plain int/float into a leaf Scalar.

        Any other object is returned unchanged. Objects that expose a
        ``.data`` attribute therefore keep working exactly as before.
        """
        if isinstance(y, (int, float)):
            return Scalar(y)
        return y

    def __add__(self, y):
        """Return a new Scalar for ``self + y`` with ``_op='+'``."""
        y = Scalar._wrap(y)
        return Scalar(self.data + y.data, (self, y), _op='+')

    def __radd__(self, y):
        """Support ``number + Scalar``; addition is commutative."""
        return self + y

    def __mul__(self, y):
        """Return a new Scalar for ``self * y`` with ``_op='*'``."""
        y = Scalar._wrap(y)
        return Scalar(self.data * y.data, (self, y), _op='*')

    def __rmul__(self, y):
        """Support ``number * Scalar``; multiplication is commutative."""
        return self * y