aboutsummaryrefslogtreecommitdiff
path: root/circuitpython/extmod/ulab/code/ndarray_operators.h
diff options
context:
space:
mode:
Diffstat (limited to 'circuitpython/extmod/ulab/code/ndarray_operators.h')
-rw-r--r--circuitpython/extmod/ulab/code/ndarray_operators.h277
1 files changed, 277 insertions, 0 deletions
diff --git a/circuitpython/extmod/ulab/code/ndarray_operators.h b/circuitpython/extmod/ulab/code/ndarray_operators.h
new file mode 100644
index 0000000..7849e03
--- /dev/null
+++ b/circuitpython/extmod/ulab/code/ndarray_operators.h
@@ -0,0 +1,277 @@
+/*
+ * This file is part of the micropython-ulab project,
+ *
+ * https://github.com/v923z/micropython-ulab
+ *
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2020-2021 Zoltán Vörös
+*/
+
+#include "ndarray.h"
+
+mp_obj_t ndarray_binary_equality(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *, mp_binary_op_t );
+mp_obj_t ndarray_binary_add(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
+mp_obj_t ndarray_binary_multiply(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
+mp_obj_t ndarray_binary_more(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *, mp_binary_op_t );
+mp_obj_t ndarray_binary_power(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
+mp_obj_t ndarray_binary_subtract(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
+mp_obj_t ndarray_binary_true_divide(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
+
+mp_obj_t ndarray_inplace_ams(ndarray_obj_t *, ndarray_obj_t *, int32_t *, uint8_t );
+mp_obj_t ndarray_inplace_power(ndarray_obj_t *, ndarray_obj_t *, int32_t *);
+mp_obj_t ndarray_inplace_divide(ndarray_obj_t *, ndarray_obj_t *, int32_t *);
+
+#define UNWRAP_INPLACE_OPERATOR(lhs, larray, rarray, rstrides, OPERATOR)\
+({\
+ if((lhs)->dtype == NDARRAY_UINT8) {\
+ if((rhs)->dtype == NDARRAY_UINT8) {\
+ INPLACE_LOOP((lhs), uint8_t, uint8_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else if(rhs->dtype == NDARRAY_INT8) {\
+ INPLACE_LOOP((lhs), uint8_t, int8_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else if(rhs->dtype == NDARRAY_UINT16) {\
+ INPLACE_LOOP((lhs), uint8_t, uint16_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else {\
+ INPLACE_LOOP((lhs), uint8_t, int16_t, (larray), (rarray), (rstrides), OPERATOR);\
+ }\
+ } else if(lhs->dtype == NDARRAY_INT8) {\
+ if(rhs->dtype == NDARRAY_UINT8) {\
+ INPLACE_LOOP((lhs), int8_t, uint8_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else if(rhs->dtype == NDARRAY_INT8) {\
+ INPLACE_LOOP((lhs), int8_t, int8_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else if(rhs->dtype == NDARRAY_UINT16) {\
+ INPLACE_LOOP((lhs), int8_t, uint16_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else {\
+ INPLACE_LOOP((lhs), int8_t, int16_t, (larray), (rarray), (rstrides), OPERATOR);\
+ }\
+ } else if(lhs->dtype == NDARRAY_UINT16) {\
+ if(rhs->dtype == NDARRAY_UINT8) {\
+ INPLACE_LOOP((lhs), uint16_t, uint8_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else if(rhs->dtype == NDARRAY_INT8) {\
+ INPLACE_LOOP((lhs), uint16_t, int8_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else if(rhs->dtype == NDARRAY_UINT16) {\
+ INPLACE_LOOP((lhs), uint16_t, uint16_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else {\
+ INPLACE_LOOP((lhs), uint16_t, int16_t, (larray), (rarray), (rstrides), OPERATOR);\
+ }\
+ } else if(lhs->dtype == NDARRAY_INT16) {\
+ if(rhs->dtype == NDARRAY_UINT8) {\
+ INPLACE_LOOP((lhs), int16_t, uint8_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else if(rhs->dtype == NDARRAY_INT8) {\
+ INPLACE_LOOP((lhs), int16_t, int8_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else if(rhs->dtype == NDARRAY_UINT16) {\
+ INPLACE_LOOP((lhs), int16_t, uint16_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else {\
+ INPLACE_LOOP((lhs), int16_t, int16_t, (larray), (rarray), (rstrides), OPERATOR);\
+ }\
+ } else if(lhs->dtype == NDARRAY_FLOAT) {\
+ if(rhs->dtype == NDARRAY_UINT8) {\
+ INPLACE_LOOP((lhs), mp_float_t, uint8_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else if(rhs->dtype == NDARRAY_INT8) {\
+ INPLACE_LOOP((lhs), mp_float_t, int8_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else if(rhs->dtype == NDARRAY_UINT16) {\
+ INPLACE_LOOP((lhs), mp_float_t, uint16_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else if(rhs->dtype == NDARRAY_INT16) {\
+ INPLACE_LOOP((lhs), mp_float_t, int16_t, (larray), (rarray), (rstrides), OPERATOR);\
+ } else {\
+ INPLACE_LOOP((lhs), mp_float_t, mp_float_t, (larray), (rarray), (rstrides), OPERATOR);\
+ }\
+ }\
+})
+
+#if ULAB_MAX_DIMS == 1
+#define INPLACE_POWER(results, type_left, type_right, larray, rarray, rstrides)\
+({ size_t l = 0;\
+ do {\
+ *((type_left *)(larray)) = MICROPY_FLOAT_C_FUN(pow)(*((type_left *)(larray)), *((type_right *)(rarray)));\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
+ l++;\
+ } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
+})
+
+#define FUNC_POINTER_LOOP(results, array, get_lhs, get_rhs, larray, lstrides, rarray, rstrides, OPERATION)\
+({ size_t l = 0;\
+ do {\
+ mp_float_t lvalue = (get_lhs)((larray));\
+ mp_float_t rvalue = (get_rhs)((rarray));\
+ (set_result)((array), OPERATION);\
+ (array) += (results)->itemsize;\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
+ l++;\
+ } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
+})
+#endif /* ULAB_MAX_DIMS == 1 */
+
+#if ULAB_MAX_DIMS == 2
+#define INPLACE_POWER(results, type_left, type_right, larray, rarray, rstrides)\
+({ size_t k = 0;\
+ do {\
+ size_t l = 0;\
+ do {\
+ *((type_left *)(larray)) = MICROPY_FLOAT_C_FUN(pow)(*((type_left *)(larray)), *((type_right *)(rarray)));\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
+ l++;\
+ } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
+ (larray) -= (results)->strides[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 2];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
+ k++;\
+ } while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
+})
+
+#define FUNC_POINTER_LOOP(results, array, get_lhs, get_rhs, larray, lstrides, rarray, rstrides, OPERATION)\
+({ size_t k = 0;\
+ do {\
+ size_t l = 0;\
+ do {\
+ mp_float_t lvalue = (get_lhs)((larray));\
+ mp_float_t rvalue = (get_rhs)((rarray));\
+ (set_result)((array), OPERATION);\
+ (array) += (results)->itemsize;\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
+ l++;\
+ } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
+ (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
+ k++;\
+ } while(k < results->shape[ULAB_MAX_DIMS - 2]);\
+})
+#endif /* ULAB_MAX_DIMS == 2 */
+
+#if ULAB_MAX_DIMS == 3
+#define INPLACE_POWER(results, type_left, type_right, larray, rarray, rstrides)\
+({ size_t j = 0;\
+ do {\
+ size_t k = 0;\
+ do {\
+ size_t l = 0;\
+ do {\
+ *((type_left *)(larray)) = MICROPY_FLOAT_C_FUN(pow)(*((type_left *)(larray)), *((type_right *)(rarray)));\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
+ l++;\
+ } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
+ (larray) -= (results)->strides[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 2];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
+ k++;\
+ } while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
+ (larray) -= (results)->strides[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 3];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
+ j++;\
+ } while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
+})
+
+
+#define FUNC_POINTER_LOOP(results, array, get_lhs, get_rhs, larray, lstrides, rarray, rstrides, OPERATION)\
+({ size_t j = 0;\
+ do {\
+ size_t k = 0;\
+ do {\
+ size_t l = 0;\
+ do {\
+ mp_float_t lvalue = (get_lhs)((larray));\
+ mp_float_t rvalue = (get_rhs)((rarray));\
+ (set_result)((array), OPERATION);\
+ (array) += (results)->itemsize;\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
+ l++;\
+ } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
+ (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
+ k++;\
+ } while(k < results->shape[ULAB_MAX_DIMS - 2]);\
+ (larray) -= (results)->strides[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 3];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
+ j++;\
+ } while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
+})
+#endif /* ULAB_MAX_DIMS == 3 */
+
+#if ULAB_MAX_DIMS == 4
+#define INPLACE_POWER(results, type_left, type_right, larray, rarray, rstrides)\
+({ size_t i = 0;\
+ do {\
+ size_t j = 0;\
+ do {\
+ size_t k = 0;\
+ do {\
+ size_t l = 0;\
+ do {\
+ *((type_left *)(larray)) = MICROPY_FLOAT_C_FUN(pow)(*((type_left *)(larray)), *((type_right *)(rarray)));\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
+ l++;\
+ } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
+ (larray) -= (results)->strides[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 2];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
+ k++;\
+ } while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
+ (larray) -= (results)->strides[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 3];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
+ j++;\
+ } while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
+ (larray) -= (results)->strides[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 4];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 4];\
+ i++;\
+ } while(i < (results)->shape[ULAB_MAX_DIMS - 4]);\
+})
+
+#define FUNC_POINTER_LOOP(results, array, get_lhs, get_rhs, larray, lstrides, rarray, rstrides, OPERATION)\
+({ size_t i = 0;\
+ do {\
+ size_t j = 0;\
+ do {\
+ size_t k = 0;\
+ do {\
+ size_t l = 0;\
+ do {\
+ mp_float_t lvalue = (get_lhs)((larray));\
+ mp_float_t rvalue = (get_rhs)((rarray));\
+ (set_result)((array), OPERATION);\
+ (array) += (results)->itemsize;\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
+ l++;\
+ } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
+ (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
+ k++;\
+ } while(k < results->shape[ULAB_MAX_DIMS - 2]);\
+ (larray) -= (results)->strides[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 3];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
+ j++;\
+ } while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
+ (larray) -= (results)->strides[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
+ (larray) += (results)->strides[ULAB_MAX_DIMS - 4];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 4];\
+ i++;\
+ } while(i < (results)->shape[ULAB_MAX_DIMS - 4]);\
+})
+#endif /* ULAB_MAX_DIMS == 4 */