From 22f6c7e5b3fd45480f2ef7827474fdb9d6c6f82c Mon Sep 17 00:00:00 2001 From: Raghuram Subramani Date: Wed, 29 May 2024 18:11:05 +0530 Subject: Value -> Scalar --- src/graph.py | 6 +++--- src/scalar.py | 20 ++++++++++++++++++++ src/value.py | 20 -------------------- 3 files changed, 23 insertions(+), 23 deletions(-) create mode 100644 src/scalar.py delete mode 100644 src/value.py (limited to 'src') diff --git a/src/graph.py b/src/graph.py index 2067071..9aecde1 100644 --- a/src/graph.py +++ b/src/graph.py @@ -4,10 +4,10 @@ from graphviz import Digraph import matplotlib.pyplot as plt import matplotlib.image as mpimg -from .value import Value +from .scalar import Scalar class Graph: - def __init__(self, root: Value) -> None: + def __init__(self, root: Scalar) -> None: self.dot = Digraph(format='png', graph_attr={ 'rankdir': 'LR' }) self.nodes = set() @@ -16,7 +16,7 @@ class Graph: self.build(root) self.draw() - def build(self, x: Value): + def build(self, x: Scalar): self.nodes.add(x) for child in x._prev: # Add a line from child to x diff --git a/src/scalar.py b/src/scalar.py new file mode 100644 index 0000000..a67b7ae --- /dev/null +++ b/src/scalar.py @@ -0,0 +1,20 @@ +class Scalar: + def __init__(self, data, _children=(), _op='', label='') -> None: + self.label = label + + self.data = float(data) + self.grad = 0.0 + + self._prev = set(_children) + self._op = _op + + def __repr__(self) -> str: + return f'Scalar({self.data})' + + def __add__(self, y): + result = self.data + y.data + return Scalar(result, (self, y), _op='+') + + def __mul__(self, y): + result = self.data * y.data + return Scalar(result, (self, y), _op='*') diff --git a/src/value.py b/src/value.py deleted file mode 100644 index af0ab35..0000000 --- a/src/value.py +++ /dev/null @@ -1,20 +0,0 @@ -class Value: - def __init__(self, data, _children=(), _op='', label='') -> None: - self.label = label - - self.data = float(data) - self.grad = 0.0 - - self._prev = set(_children) - self._op = _op - - def __repr__(self) -> str: - return f'Value({self.data})' - - def __add__(self, y): - result = self.data + y.data - return Value(result, (self, y), _op='+') - - def __mul__(self, y): - result = self.data * y.data - return Value(result, (self, y), _op='*') -- cgit v1.2.3