aboutsummaryrefslogtreecommitdiff
path: root/src/nn.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/nn.py')
-rw-r--r--src/nn.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/src/nn.py b/src/nn.py
index e5d20d4..d549091 100644
--- a/src/nn.py
+++ b/src/nn.py
@@ -17,6 +17,9 @@ class Neuron:
return result.tanh()
+ def parameters(self):
+ return self.w + [ self.b ]
+
class Layer:
def __init__(self, n_X, n_y):
self.neurons = [ Neuron(n_X) for _ in range(n_y) ]
@@ -25,6 +28,9 @@ class Layer:
result = [ n(X) for n in self.neurons ]
return result[0] if len(result) == 1 else result
+ def parameters(self):
+ return [ param for neuron in self.neurons for param in neuron.parameters() ]
+
class MLP:
def __init__(self, n_X, layers):
sz = [ n_X ] + layers
@@ -35,3 +41,10 @@ class MLP:
X = layer(X)
return X
+
+ def parameters(self):
+ return [ param for layer in self.layers for param in layer.parameters() ]
+
+ def optimise(self, lr):
+ for parameter in self.parameters():
+ parameter.data -= lr * parameter.grad