aboutsummaryrefslogtreecommitdiff
path: root/circuitpython/extmod/ulab/code/numpy/compare.h
diff options
context:
space:
mode:
Diffstat (limited to 'circuitpython/extmod/ulab/code/numpy/compare.h')
-rw-r--r--circuitpython/extmod/ulab/code/numpy/compare.h150
1 files changed, 150 insertions, 0 deletions
diff --git a/circuitpython/extmod/ulab/code/numpy/compare.h b/circuitpython/extmod/ulab/code/numpy/compare.h
new file mode 100644
index 0000000..90ceaf7
--- /dev/null
+++ b/circuitpython/extmod/ulab/code/numpy/compare.h
@@ -0,0 +1,150 @@
+
+/*
+ * This file is part of the micropython-ulab project,
+ *
+ * https://github.com/v923z/micropython-ulab
+ *
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2020-2021 Zoltán Vörös
+*/
+
+#ifndef _COMPARE_
+#define _COMPARE_
+
+#include "../ulab.h"
+#include "../ndarray.h"
+
+enum COMPARE_FUNCTION_TYPE {
+ COMPARE_EQUAL,
+ COMPARE_NOT_EQUAL,
+ COMPARE_MINIMUM,
+ COMPARE_MAXIMUM,
+ COMPARE_CLIP,
+};
+
+MP_DECLARE_CONST_FUN_OBJ_3(compare_clip_obj);
+MP_DECLARE_CONST_FUN_OBJ_2(compare_equal_obj);
+MP_DECLARE_CONST_FUN_OBJ_2(compare_isfinite_obj);
+MP_DECLARE_CONST_FUN_OBJ_2(compare_isinf_obj);
+MP_DECLARE_CONST_FUN_OBJ_2(compare_minimum_obj);
+MP_DECLARE_CONST_FUN_OBJ_2(compare_maximum_obj);
+MP_DECLARE_CONST_FUN_OBJ_2(compare_not_equal_obj);
+MP_DECLARE_CONST_FUN_OBJ_3(compare_where_obj);
+
+#if ULAB_MAX_DIMS == 1
+#define COMPARE_LOOP(results, array, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
+ size_t l = 0;\
+ do {\
+ *((type_out *)(array)) = *((type_left *)(larray)) OPERATOR *((type_right *)(rarray)) ? (type_out)(*((type_left *)(larray))) : (type_out)(*((type_right *)(rarray)));\
+ (array) += (results)->strides[ULAB_MAX_DIMS - 1];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
+ l++;\
+ } while(l < results->shape[ULAB_MAX_DIMS - 1]);\
+ return MP_OBJ_FROM_PTR(results);\
+
+#endif // ULAB_MAX_DIMS == 1
+
+#if ULAB_MAX_DIMS == 2
+#define COMPARE_LOOP(results, array, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
+ size_t k = 0;\
+ do {\
+ size_t l = 0;\
+ do {\
+ *((type_out *)(array)) = *((type_left *)(larray)) OPERATOR *((type_right *)(rarray)) ? (type_out)(*((type_left *)(larray))) : (type_out)(*((type_right *)(rarray)));\
+ (array) += (results)->strides[ULAB_MAX_DIMS - 1];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
+ l++;\
+ } while(l < results->shape[ULAB_MAX_DIMS - 1]);\
+ (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * results->shape[ULAB_MAX_DIMS-1];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * results->shape[ULAB_MAX_DIMS-1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
+ k++;\
+ } while(k < results->shape[ULAB_MAX_DIMS - 2]);\
+ return MP_OBJ_FROM_PTR(results);\
+
+#endif // ULAB_MAX_DIMS == 2
+
+#if ULAB_MAX_DIMS == 3
+#define COMPARE_LOOP(results, array, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
+ size_t j = 0;\
+ do {\
+ size_t k = 0;\
+ do {\
+ size_t l = 0;\
+ do {\
+ *((type_out *)(array)) = *((type_left *)(larray)) OPERATOR *((type_right *)(rarray)) ? (type_out)(*((type_left *)(larray))) : (type_out)(*((type_right *)(rarray)));\
+ (array) += (results)->strides[ULAB_MAX_DIMS - 1];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
+ l++;\
+ } while(l < results->shape[ULAB_MAX_DIMS - 1]);\
+ (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * results->shape[ULAB_MAX_DIMS-1];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * results->shape[ULAB_MAX_DIMS-1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
+ k++;\
+ } while(k < results->shape[ULAB_MAX_DIMS - 2]);\
+ (larray) -= (lstrides)[ULAB_MAX_DIMS - 2] * results->shape[ULAB_MAX_DIMS-2];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 3];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * results->shape[ULAB_MAX_DIMS-2];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
+ j++;\
+ } while(j < results->shape[ULAB_MAX_DIMS - 3]);\
+ return MP_OBJ_FROM_PTR(results);\
+
+#endif // ULAB_MAX_DIMS == 3
+
+#if ULAB_MAX_DIMS == 4
+#define COMPARE_LOOP(results, array, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
+ size_t i = 0;\
+ do {\
+ size_t j = 0;\
+ do {\
+ size_t k = 0;\
+ do {\
+ size_t l = 0;\
+ do {\
+ *((type_out *)(array)) = *((type_left *)(larray)) OPERATOR *((type_right *)(rarray)) ? (type_out)(*((type_left *)(larray))) : (type_out)(*((type_right *)(rarray)));\
+ (array) += (results)->strides[ULAB_MAX_DIMS - 1];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
+ l++;\
+ } while(l < results->shape[ULAB_MAX_DIMS - 1]);\
+ (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * results->shape[ULAB_MAX_DIMS-1];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * results->shape[ULAB_MAX_DIMS-1];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
+ k++;\
+ } while(k < results->shape[ULAB_MAX_DIMS - 2]);\
+ (larray) -= (lstrides)[ULAB_MAX_DIMS - 2] * results->shape[ULAB_MAX_DIMS-2];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 3];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * results->shape[ULAB_MAX_DIMS-2];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
+ j++;\
+ } while(j < results->shape[ULAB_MAX_DIMS - 3]);\
+ (larray) -= (lstrides)[ULAB_MAX_DIMS - 3] * results->shape[ULAB_MAX_DIMS-3];\
+ (larray) += (lstrides)[ULAB_MAX_DIMS - 4];\
+ (rarray) -= (rstrides)[ULAB_MAX_DIMS - 3] * results->shape[ULAB_MAX_DIMS-3];\
+ (rarray) += (rstrides)[ULAB_MAX_DIMS - 4];\
+ i++;\
+ } while(i < results->shape[ULAB_MAX_DIMS - 4]);\
+ return MP_OBJ_FROM_PTR(results);\
+
+#endif // ULAB_MAX_DIMS == 4
+
+#define RUN_COMPARE_LOOP(dtype, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, ndim, shape, op) do {\
+ ndarray_obj_t *results = ndarray_new_dense_ndarray((ndim), (shape), (dtype));\
+ uint8_t *array = (uint8_t *)results->array;\
+ if((op) == COMPARE_MINIMUM) {\
+ COMPARE_LOOP(results, array, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, <);\
+ }\
+ if((op) == COMPARE_MAXIMUM) {\
+ COMPARE_LOOP(results, array, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, >);\
+ }\
+} while(0)
+
+#endif