aboutsummaryrefslogtreecommitdiff
path: root/circuitpython/extmod/ulab/tests/1d/numpy/argminmax.py
blob: e2aa0bc99c288c391eb2b6649083639b1cd58d30 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from ulab import numpy as np

# Adapted from https://docs.python.org/3.8/library/itertools.html#itertools.permutations
def permutations(iterable, r=None):
    """Yield successive r-length permutations of the elements of *iterable*.

    Pure-Python re-implementation of ``itertools.permutations``, following
    the recipe in the CPython 3.8 itertools documentation. Results come out
    in the same order as ``itertools.permutations``.

    permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
    permutations(range(3)) --> 012 021 102 120 201 210

    :param iterable: elements to permute; consumed once into a tuple.
    :param r: length of each permutation; ``None`` means all elements.
    :raises ValueError: if *r* is negative (raised on the first ``next()``
        because this is a generator).
    """
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r < 0:
        # Without this guard a negative r slips through: indices[:r] slices
        # from the end and a truncated tuple is yielded. itertools raises.
        raise ValueError("r must be non-negative")
    if r > n:
        # Not enough elements to fill even one permutation.
        return
    indices = list(range(n))
    # cycles[i] counts down the remaining choices for output position i.
    cycles = list(range(n, n - r, -1))
    yield tuple(pool[k] for k in indices[:r])
    # For n == 0 only the single empty permutation above is produced.
    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                # Position i exhausted: rotate indices[i:] left by one to
                # restore its original arrangement, reset its counter, and
                # carry over to position i - 1.
                indices[i:] = indices[i + 1:] + indices[i:i + 1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[k] for k in indices[:r])
                break
        else:
            # Every position carried over: all permutations were produced.
            return

# Combinations expected to throw: argmin and argmax of an empty sequence
# must raise ValueError, which is reported by printing "ValueError".
for reducer in (np.argmin, np.argmax):
    try:
        print(reducer([]))
    except ValueError:
        print("ValueError")

# Combinations expected to succeed: a one-element list, then a one-element
# ndarray, each passed to argmin followed by argmax.
for sample in ([1], np.array([1])):
    for reducer in (np.argmin, np.argmax):
        print(reducer(sample))

# For every ordering of three distinct values, the arg-reducer must return
# the same index for a plain tuple and for an ndarray, and that index must
# point at the true extreme. Any mismatch is reported with a FAIL line.
for heading, arg_reducer, reference in (
    ("max tests", np.argmax, max),
    ("min tests", np.argmin, min),
):
    print()
    print(heading)
    for p in permutations((100, 200, 300)):
        idx_seq = arg_reducer(p)
        idx_arr = arg_reducer(np.array(p))
        print(p, idx_seq, idx_arr)
        if idx_seq != idx_arr or p[idx_seq] != reference(p):
            print("FAIL", p, idx_seq, idx_arr, reference(p))