aboutsummaryrefslogtreecommitdiff
path: root/circuitpython/extmod/ulab/snippets/rclass.py
blob: cb95021a1eaf937428845558ce7460619635d09f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import List, Tuple, Union  # upip.install("pycopy-typing")
from ulab import numpy as np

_DType = int
_RClassKeyType = Union[slice, int, float, list, tuple, np.ndarray]

# A stripped-down port of numpy's RClass, the class behind np.r_[...].
# Unlike numpy's version, it does not accept a string directive as the
# first index element.
class RClass:
    """Concatenate index elements into one array, in the style of ``numpy.r_``.

    Index an instance (``r_[...]``) with one element or a comma-separated
    tuple of elements. Each element is turned into an array:

    * an ``np.ndarray`` is used as-is;
    * a slice ``start:stop:step`` becomes ``np.arange(start, stop, step)``
      (``start`` defaults to 0 and ``step`` to 1);
    * a slice with a complex step, ``start:stop:Nj``, becomes
      ``np.linspace(start, stop, num=int(abs(Nj)))``;
    * an int/float/bool becomes a one-element array;
    * anything else is passed to ``np.array``.

    The arrays are cast to a common dtype and concatenated along axis 0.
    String directives as the first index element are not implemented.
    """

    def __getitem__(self, key: Union[_RClassKeyType, Tuple[_RClassKeyType, ...]]):
        """Convert every element of *key* to an array and concatenate them.

        Raises:
            TypeError: an element of *key* cannot be converted to an array.
            ValueError: *key* is an empty tuple, so there is nothing to
                concatenate.
        """
        if not isinstance(key, tuple):
            key = (key,)

        # Always axis 0: string directives, which could select another axis,
        # are not supported in this stripped-down version.
        axis = 0

        objs: List[np.ndarray] = []
        arraytypes: List[_DType] = []
        scalartypes: List[_DType] = []

        for item in key:
            newobj, scalar = self._to_array(item)
            objs.append(newobj)
            # Scalar and array dtypes are tracked separately (the intent is
            # that scalars should not up-cast the result unless warranted);
            # both lists are merged when picking the final dtype below.
            if scalar:
                scalartypes.append(self._dtype_of(newobj))
            elif isinstance(newobj, np.ndarray):
                arraytypes.append(self._dtype_of(newobj))

        if not objs:
            raise ValueError("need at least one array to concatenate")

        # The smallest dtype value wins and every array is cast to it.
        # NOTE(review): min() over dtype codes is not numpy's promotion rule;
        # depending on how the codes are numbered it can choose a narrower
        # type than some inputs. Confirm this is the intended behaviour.
        final_dtype = min(arraytypes + scalartypes)
        for idx, obj in enumerate(objs):
            if self._dtype_of(obj) != final_dtype:
                objs[idx] = np.array(obj, dtype=final_dtype)

        return np.concatenate(tuple(objs), axis=axis)

    def _to_array(self, item):
        """Return ``(array, is_scalar)`` for a single index element.

        Raises TypeError with an ``r_[]``-specific message when the
        conversion itself raises TypeError.
        """
        try:
            if isinstance(item, np.ndarray):
                return item, False
            if isinstance(item, slice):
                return self._slice_to_array(item), False
            if isinstance(item, (int, float, bool)):
                return np.array([item]), True
            return np.array(item), False
        except TypeError:
            raise TypeError("index object %s of type %s is not supported by r_[]" % (
                str(item), type(item)))

    @staticmethod
    def _slice_to_array(item):
        """Expand a slice into ``np.arange``, or ``np.linspace`` for a complex step."""
        start = 0 if item.start is None else item.start
        step = 1 if item.step is None else item.step
        if isinstance(step, complex):
            # A complex step Nj requests int(abs(Nj)) points from start to stop.
            return np.linspace(start, item.stop, num=int(abs(step)))
        return np.arange(start, item.stop, step)

    @staticmethod
    def _dtype_of(arr):
        """Return the dtype of *arr* as a comparable value.

        ``dtype`` is read as an attribute (numpy exposes it as a property).
        If it turns out to be a bound method instead, it is called.
        NOTE(review): confirm which form the targeted ulab build uses.
        Routing every dtype read through here keeps the comparisons in
        ``__getitem__`` consistent.
        """
        dtype = arr.dtype
        return dtype() if callable(dtype) else dtype

    def __len__(self):
        # Always reports a length of zero. NOTE(review): presumably mirrors
        # numpy's AxisConcatenator.__len__, which also returns 0; the reason
        # is not visible in this file.
        return 0
        
# Module-level instance; index it like numpy's r_, e.g. r_[0:3, 5].
r_: RClass = RClass()