aboutsummaryrefslogtreecommitdiff
path: root/circuitpython/extmod/ulab/code/ulab_tools.h
blob: 2898ef1f1144a91cc62f38fdea6e8c6e88a2200a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/*
 * This file is part of the micropython-ulab project,
 *
 * https://github.com/v923z/micropython-ulab
 *
 * The MIT License (MIT)
 *
 * Copyright (c) 2020-2022 Zoltán Vörös
*/

#ifndef _TOOLS_
#define _TOOLS_

#include "ndarray.h"

/* Exchange the values of two lvalues `a` and `b`, both of type `t`.
 *
 * The do { ... } while (0) wrapper makes `SWAP(...);` a single statement, so
 * it is safe as the body of an unbraced if/else. The temporary has a
 * distinctive name so it cannot capture a caller variable called `tmp`
 * (the previous `t tmp = a;` turned `SWAP(int, tmp, x)` into `int tmp = tmp;`).
 * `a` and `b` are each evaluated twice, so they must not have side effects
 * (e.g. `SWAP(int, v[i++], w)` is wrong).
 */
#define SWAP(t, a, b) do { t ulab_swap_tmp_ = (a); (a) = (b); (b) = ulab_swap_tmp_; } while (0)

/* Shape/stride descriptor; tools_reduce_axes() returns one by value.
 *
 * NOTE(review): the field meanings below are inferred from their names only.
 * This header does not show who allocates or frees the `shape` and `strides`
 * buffers; confirm against ulab_tools.c before relying on ownership.
 */
struct _shape_strides_t {
    uint8_t increment;  /* presumably a step/flag for iterating the reduced axis - TODO confirm */
    uint8_t ndim;       /* presumably the number of dimensions described */
    size_t *shape;      /* presumably per-dimension lengths */
    int32_t *strides;   /* presumably per-dimension strides; signed, so may be negative */
};
typedef struct _shape_strides_t shape_strides;

/* Per-element-type readers, one per name suffix (uint8, int8, uint16, int16, float).
 * NOTE(review): judging by the names, each presumably reads one element of that
 * C type through `data` and returns it as mp_float_t - confirm in ulab_tools.c. */
mp_float_t ndarray_get_float_uint8(void *data);
mp_float_t ndarray_get_float_int8(void *data);
mp_float_t ndarray_get_float_uint16(void *data);
mp_float_t ndarray_get_float_int16(void *data);
mp_float_t ndarray_get_float_float(void *data);

/* Returns an untyped pointer selected by the dtype code `dtype`; presumably one
 * of the readers above - TODO confirm. If it does return a function pointer as
 * void *, note that ISO C does not guarantee that conversion (POSIX does). */
void *ndarray_get_float_function(uint8_t dtype);

/* Presumably returns the dtype code that results from combining operands of
 * dtypes `ldtype` and `rdtype` (ulab's promotion rules) - confirm in ulab_tools.c. */
uint8_t ndarray_upcast_dtype(uint8_t ldtype, uint8_t rdtype);

/* Returns an untyped pointer selected by the dtype code `dtype`; presumably the
 * store-side counterpart of ndarray_get_float_function() - TODO confirm. */
void *ndarray_set_float_function(uint8_t dtype);

/* Builds a shape_strides descriptor from `ndarray` and the Python object `axis`.
 * NOTE(review): `axis` is presumably an int or None, and ownership of the
 * returned shape/strides buffers is not visible here - confirm in ulab_tools.c. */
shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis);

/* Maps the Python object `axis` to an int8_t axis index for an array with
 * `ndim` dimensions (presumably normalising negative values and raising on
 * out-of-range input - TODO confirm). */
int8_t tools_get_axis(mp_obj_t axis, uint8_t ndim);

/* Returns `obj` as an ndarray_obj_t pointer; presumably checks that it is a
 * square matrix and raises otherwise - TODO confirm. */
ndarray_obj_t *tools_object_is_square(mp_obj_t obj);

/* Presumably the size in bytes of one element of dtype code `dtype` - TODO confirm. */
uint8_t ulab_binary_get_size(uint8_t dtype);

#if ULAB_SUPPORTS_COMPLEX
/* Declared only when ULAB_SUPPORTS_COMPLEX is enabled. Takes a pointer to
 * int32_t strides; judging by the name it rescales them in place - the array
 * length is not visible here (presumably ULAB_MAX_DIMS), TODO confirm. */
void ulab_rescale_float_strides(int32_t *strides);
#endif

#endif