aboutsummaryrefslogtreecommitdiff
path: root/example.py
blob: 4c00845f72b65342fa739586167fa056de2e2b6e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#!/usr/bin/env python

from src.graph import Graph
from src.scalar import Scalar

# Step size for the finite-difference check of dL/dd (see numerical_dL_dd).
h = 0.0001


def build_expression(d_nudge=0.0):
    """Build the expression graph L = (a * b + c) * f.

    Returns the nodes ``(a, b, c, f, d, e, L)`` in that order.

    If ``d_nudge`` is non-zero it is added to ``d.data`` before ``e`` and
    ``L`` are computed, which lets callers estimate dL/dd numerically.
    """
    a = Scalar(2, label='a')
    b = Scalar(-3, label='b')
    c = Scalar(10, label='c')
    f = Scalar(-2, label='f')

    d = a * b; d.label = 'd'
    if d_nudge:
        # Skip the no-op add so the un-nudged graph is left exactly as built.
        d.data += d_nudge
    e = d + c; e.label = 'e'
    L = e * f; L.label = 'L'

    return a, b, c, f, d, e, L


def numerical_dL_dd(step=h):
    """Return a forward finite-difference estimate of dL/dd.

    Compare the result against the hand-derived ``d.grad`` below.
    """
    nudged = build_expression(step)[-1].data
    base = build_expression()[-1].data
    return (nudged - base) / step


a, b, c, f, d, e, L = build_expression()

# Manual backpropagation via the chain rule. The gradients are derived from
# the node values rather than hard-coded, so they stay correct if the inputs
# above change.
# Local derivatives: for z = x + y, dz/dx = dz/dy = 1;
#                    for z = x * y, dz/dx = y and dz/dy = x.
L.grad = 1.0
e.grad = f.data * L.grad   # L = e * f
f.grad = e.data * L.grad
d.grad = e.grad            # e = d + c
c.grad = e.grad
a.grad = b.data * d.grad   # d = a * b
b.grad = a.data * d.grad

# Sanity check (should print a value close to d.grad):
# print(numerical_dL_dd())

Graph(L).show()