diff options
Diffstat (limited to 'circuitpython/extmod/ulab/tests/1d/numpy/argminmax.py')
| -rw-r--r-- | circuitpython/extmod/ulab/tests/1d/numpy/argminmax.py | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/circuitpython/extmod/ulab/tests/1d/numpy/argminmax.py b/circuitpython/extmod/ulab/tests/1d/numpy/argminmax.py new file mode 100644 index 0000000..e2aa0bc --- /dev/null +++ b/circuitpython/extmod/ulab/tests/1d/numpy/argminmax.py @@ -0,0 +1,62 @@ +from ulab import numpy as np + +# Adapted from https://docs.python.org/3.8/library/itertools.html#itertools.permutations +def permutations(iterable, r=None): + # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC + # permutations(range(3)) --> 012 021 102 120 201 210 + pool = tuple(iterable) + n = len(pool) + r = n if r is None else r + if r > n: + return + indices = list(range(n)) + cycles = list(range(n, n-r, -1)) + yield tuple(pool[i] for i in indices[:r]) + while n: + for i in reversed(range(r)): + cycles[i] -= 1 + if cycles[i] == 0: + indices[i:] = indices[i+1:] + indices[i:i+1] + cycles[i] = n - i + else: + j = cycles[i] + indices[i], indices[-j] = indices[-j], indices[i] + yield tuple(pool[i] for i in indices[:r]) + break + else: + return + +# Combinations expected to throw +try: + print(np.argmin([])) +except ValueError: + print("ValueError") + +try: + print(np.argmax([])) +except ValueError: + print("ValueError") + +# Combinations expected to succeed +print(np.argmin([1])) +print(np.argmax([1])) +print(np.argmin(np.array([1]))) +print(np.argmax(np.array([1]))) + +print() +print("max tests") +for p in permutations((100,200,300)): + m1 = np.argmax(p) + m2 = np.argmax(np.array(p)) + print(p, m1, m2) + if m1 != m2 or p[m1] != max(p): + print("FAIL", p, m1, m2, max(p)) + +print() +print("min tests") +for p in permutations((100,200,300)): + m1 = np.argmin(p) + m2 = np.argmin(np.array(p)) + print(p, m1, m2) + if m1 != m2 or p[m1] != min(p): + print("FAIL", p, m1, m2, min(p)) |
