diff options
Diffstat (limited to 'src/nn.py')
-rw-r--r-- | src/nn.py | 13 |
1 files changed, 13 insertions, 0 deletions
@@ -17,6 +17,9 @@ class Neuron: return result.tanh() + def parameters(self): + return self.w + [ self.b ] + class Layer: def __init__(self, n_X, n_y): self.neurons = [ Neuron(n_X) for _ in range(n_y) ] @@ -25,6 +28,9 @@ class Layer: result = [ n(X) for n in self.neurons ] return result[0] if len(result) == 1 else result + def parameters(self): + return [ param for neuron in self.neurons for param in neuron.parameters() ] + class MLP: def __init__(self, n_X, layers): sz = [ n_X ] + layers @@ -35,3 +41,10 @@ class MLP: X = layer(X) return X + + def parameters(self): + return [ param for layer in self.layers for param in layer.parameters() ] + + def optimise(self, lr): + for parameter in self.parameters(): + parameter.data -= lr * parameter.grad |