From 2cf0bc1f295518c91915804a1185037b733e0d2b Mon Sep 17 00:00:00 2001
From: Jon Nordby
Date: Tue, 8 Jul 2025 01:15:10 +0200
Subject: [PATCH 1/6] extratrees: Initial code for extremely randomized trees
 learner

Designed to be as embedded-friendly as possible
---
 Makefile                                |   6 +
 src/emlearn_extratrees/Makefile         |  37 ++
 src/emlearn_extratrees/eml_extratrees.c | 449 ++++++++++++++++++++++++
 src/emlearn_extratrees/extratrees.c     | 325 +++++++++++++++++
 4 files changed, 817 insertions(+)
 create mode 100644 src/emlearn_extratrees/Makefile
 create mode 100644 src/emlearn_extratrees/eml_extratrees.c
 create mode 100644 src/emlearn_extratrees/extratrees.c

diff --git a/Makefile b/Makefile
index 436a486..47b645e 100644
--- a/Makefile
+++ b/Makefile
@@ -58,6 +58,9 @@ $(MODULES_PATH)/emlearn_arrayutils.mpy:
 $(MODULES_PATH)/emlearn_linreg.mpy:
 	make -C src/emlearn_linreg/ ARCH=$(ARCH) MPY_DIR=$(MPY_DIR_ABS) V=1 clean dist
 
+$(MODULES_PATH)/emlearn_extratrees.mpy:
+	make -C src/emlearn_extratrees/ ARCH=$(ARCH) MPY_DIR=$(MPY_DIR_ABS) V=1 clean dist
+
 emlearn_trees.results: $(MODULES_PATH)/emlearn_trees.mpy
 	MICROPYPATH=$(MODULES_PATH) $(MICROPYTHON_BIN) tests/test_trees.py
 
@@ -85,6 +88,9 @@ emlearn_arrayutils.results: $(MODULES_PATH)/emlearn_arrayutils.mpy
 emlearn_linreg.results: $(MODULES_PATH)/emlearn_linreg.mpy
 	MICROPYPATH=$(MODULES_PATH) $(MICROPYTHON_BIN) tests/test_linreg.py
 
+emlearn_extratrees.results: $(MODULES_PATH)/emlearn_extratrees.mpy
+	MICROPYPATH=$(MODULES_PATH) $(MICROPYTHON_BIN) tests/test_extratrees_xor.py
+
 $(PORT_DIR):
 	mkdir -p $@
 
diff --git a/src/emlearn_extratrees/Makefile b/src/emlearn_extratrees/Makefile
new file mode 100644
index 0000000..815624b
--- /dev/null
+++ b/src/emlearn_extratrees/Makefile
@@ -0,0 +1,37 @@
+# Location of top-level MicroPython directory
+MPY_DIR = ../../micropython
+
+# Architecture to build for (x86, x64, armv6m, armv7m, xtensa, xtensawin)
+ARCH = x64
+
+# The ABI version for .mpy files
+MPY_ABI_VERSION := 6.3
+
+# Location of emlearn library
+EMLEARN_DIR := $(shell python3 -c "import emlearn; print(emlearn.includedir)")
+
+# enable linking of libm etc
+LINK_RUNTIME=1
+
+DIST_DIR := ../../dist/$(ARCH)_$(MPY_ABI_VERSION)
+
+# Name of module
+MOD = emlearn_extratrees
+
+# Source files (.c or .py)
+SRC = extratrees.c
+
+# Include to get the rules for compiling and linking the module
+include $(MPY_DIR)/py/dynruntime.mk
+
+# Releases
+DIST_FILE = $(DIST_DIR)/$(MOD).mpy
+$(DIST_DIR):
+	mkdir -p $@
+
+$(DIST_FILE): $(MOD).mpy $(DIST_DIR)
+	cp $< $@
+
+CFLAGS += -I$(EMLEARN_DIR) -Wno-unused-function
+
+dist: $(DIST_FILE)
diff --git a/src/emlearn_extratrees/eml_extratrees.c b/src/emlearn_extratrees/eml_extratrees.c
new file mode 100644
index 0000000..7acf2ee
--- /dev/null
+++ b/src/emlearn_extratrees/eml_extratrees.c
@@ -0,0 +1,449 @@
+#include <stdint.h>
+#include <stdbool.h>
+#include <string.h>
+
+typedef struct _EmlTreesNode {
+    int8_t feature;   // -1 for leaf nodes
+    int16_t value;    // threshold or class label
+    int16_t left;     // left child index
+    int16_t right;    // right child index
+} EmlTreesNode;
+
+typedef struct _NodeState {
+    int16_t node_idx;  // current node being processed
+    int16_t start;     // sample range start
+    int16_t end;       // sample range end
+    int16_t depth;     // current depth
+} NodeState;
+
+typedef struct _EmlTreesConfig {
+    int16_t max_depth;
+    int16_t min_samples_leaf;
+    int16_t n_thresholds;
+    float subsample_ratio;          // subsample ratio as float (0.0 to 1.0)
+    float feature_subsample_ratio;  // feature subsample ratio as float (0.0 to 1.0)
+    uint32_t rng_seed;
+} EmlTreesConfig;
+
+typedef 
struct _EmlTreesModel { + EmlTreesNode *nodes; // Pre-allocated node array + int16_t *tree_starts; // Start index for each tree + int16_t max_nodes; // Maximum nodes available + int16_t n_nodes_used; // Current nodes used + int16_t n_features; // Number of features + int16_t n_classes; // Number of classes + int16_t n_trees; // Number of trees + EmlTreesConfig config; +} EmlTreesModel; + +typedef struct _EmlTreesWorkspace { + int16_t *sample_indices; // Sample indices for current tree + int16_t *feature_indices; // Feature indices for current tree + int16_t *min_vals; // Min values per feature [n_features] + int16_t *max_vals; // Max values per feature [n_features] + NodeState *node_stack; // Stack for tree building + uint32_t rng_state; // Simple RNG state + int16_t n_samples; // Number of samples +} EmlTreesWorkspace; + +// Simple linear congruential generator +static uint32_t eml_rand(uint32_t *state) { + *state = *state * 1103515245 + 12345; + return *state; +} + +// Fisher-Yates shuffle for subsampling +static void shuffle_indices(int16_t *indices, int16_t n, uint32_t *rng_state) { + for (int16_t i = 0; i < n - 1; i++) { + int16_t j = i + (eml_rand(rng_state) % (n - i)); + int16_t temp = indices[i]; + indices[i] = indices[j]; + indices[j] = temp; + } +} + +// Calculate Gini impurity from class counts +static float calculate_gini_from_counts(const int16_t *counts, int16_t total, int16_t n_classes) { + if (total == 0) return 0.0f; + + float gini = 1.0f; + for (int16_t i = 0; i < n_classes; i++) { + if (counts[i] > 0) { + float prob = (float)counts[i] / (float)total; + gini -= prob * prob; + } + } + + return gini; +} + +// Find best split for a node using count-based Gini (no temp_indices needed) +static int16_t find_best_split(const int16_t *features, const int16_t *labels, + EmlTreesModel *model, EmlTreesWorkspace *workspace, + int16_t start, int16_t end, int16_t n_features_subset, + int8_t *best_feature, int16_t *best_threshold) { + + float best_improvement = -1.0f; // Best information gain found + *best_feature = -1; + *best_threshold = 0; + + int16_t total_samples = end - start; + if (total_samples < 2) return -1; + + // Calculate parent impurity + int16_t parent_counts[256] = {0}; + for (int16_t i = start; i < end; i++) { + parent_counts[labels[workspace->sample_indices[i]]]++; + } + float parent_gini = calculate_gini_from_counts(parent_counts, total_samples, model->n_classes); + + // Try each feature in the subset + for (int16_t f = 0; f < n_features_subset; f++) { + int16_t feature_idx = workspace->feature_indices[f]; + int16_t min_val = workspace->min_vals[feature_idx]; + int16_t max_val = workspace->max_vals[feature_idx]; + + if (min_val >= max_val) continue; + + // Try random thresholds + for (int16_t t = 0; t < model->config.n_thresholds; t++) { + int16_t threshold = min_val + (eml_rand(&workspace->rng_state) % (max_val - min_val + 1)); + + // Count class distributions in left/right partitions without moving data + int16_t left_counts[256] = {0}; // Assume max 256 classes + int16_t right_counts[256] = {0}; + int16_t left_total = 0, right_total = 0; + + for (int16_t i = start; i < end; i++) { + int16_t sample_idx = workspace->sample_indices[i]; + int16_t feature_val = features[sample_idx * model->n_features + feature_idx]; + int16_t label = labels[sample_idx]; + + if (feature_val <= threshold) { + left_counts[label]++; + left_total++; + } else { + right_counts[label]++; + right_total++; + } + } + + if (left_total == 0 || right_total == 0) continue; + if (left_total < 
model->config.min_samples_leaf || right_total < model->config.min_samples_leaf) continue; + + // Calculate Gini for left and right partitions from counts + float left_gini = calculate_gini_from_counts(left_counts, left_total, model->n_classes); + float right_gini = calculate_gini_from_counts(right_counts, right_total, model->n_classes); + + // Calculate weighted Gini impurity + float weighted_gini = ((float)left_total * left_gini + (float)right_total * right_gini) / (float)total_samples; + + // Calculate information gain + float improvement = parent_gini - weighted_gini; + + if (improvement > best_improvement) { + best_improvement = improvement; + *best_feature = feature_idx; + *best_threshold = threshold; + } + } + } + + return (best_improvement > 0.0f) ? 0 : -1; // Return 0 if good split found, -1 otherwise +} + +// Partition samples based on feature threshold +static int16_t partition_samples(const int16_t *features, EmlTreesModel *model, + EmlTreesWorkspace *workspace, int16_t start, int16_t end, + int8_t feature, int16_t threshold) { + int16_t left = start; + int16_t right = end - 1; + + while (left <= right) { + // Find element on left that should be on right + while (left <= right && features[workspace->sample_indices[left] * model->n_features + feature] <= threshold) { + left++; + } + + // Find element on right that should be on left + while (left <= right && features[workspace->sample_indices[right] * model->n_features + feature] > threshold) { + right--; + } + + // Swap if needed + if (left < right) { + int16_t temp = workspace->sample_indices[left]; + workspace->sample_indices[left] = workspace->sample_indices[right]; + workspace->sample_indices[right] = temp; + left++; + right--; + } + } + + return left; // Split point +} + +// Get majority class in a range +static int16_t get_majority_class(const int16_t *labels, const int16_t *indices, + int16_t start, int16_t end, int16_t n_classes) { + int16_t counts[256] = {0}; // Assume max 256 classes + int16_t max_count = 0; + int16_t majority_class = 0; + + for (int16_t i = start; i < end; i++) { + counts[labels[indices[i]]]++; + } + + for (int16_t i = 0; i < n_classes; i++) { + if (counts[i] > max_count) { + max_count = counts[i]; + majority_class = i; + } + } + + return majority_class; +} + +// Build a single tree +static int16_t build_tree(EmlTreesModel *model, EmlTreesWorkspace *workspace, + const int16_t *features, const int16_t *labels) { + + int16_t tree_start = model->n_nodes_used; + + // Subsample features + int16_t n_features_subset = (int16_t)((float)model->n_features * model->config.feature_subsample_ratio); + if (n_features_subset < 1) n_features_subset = 1; + if (n_features_subset > model->n_features) n_features_subset = model->n_features; + + for (int16_t i = 0; i < model->n_features; i++) { + workspace->feature_indices[i] = i; + } + shuffle_indices(workspace->feature_indices, model->n_features, &workspace->rng_state); + + // Initialize root node state + int16_t stack_size = 1; + workspace->node_stack[0].node_idx = tree_start; + workspace->node_stack[0].start = 0; + workspace->node_stack[0].end = workspace->n_samples; + workspace->node_stack[0].depth = 0; + + // Initialize min/max values + for (int16_t f = 0; f < model->n_features; f++) { + workspace->min_vals[f] = 32767; + workspace->max_vals[f] = -32768; + } + + // Calculate initial min/max + for (int16_t i = 0; i < workspace->n_samples; i++) { + for (int16_t f = 0; f < model->n_features; f++) { + int16_t val = features[workspace->sample_indices[i] * model->n_features 
+ f]; + if (val < workspace->min_vals[f]) workspace->min_vals[f] = val; + if (val > workspace->max_vals[f]) workspace->max_vals[f] = val; + } + } + + // Process stack + while (stack_size > 0) { + NodeState current = workspace->node_stack[--stack_size]; + int16_t node_idx = current.node_idx; + + if (node_idx >= model->max_nodes) { + return -1; // Out of nodes + } + + // Check stopping criteria + int16_t n_samples_node = current.end - current.start; + if (current.depth >= model->config.max_depth || + n_samples_node < model->config.min_samples_leaf * 2 || + n_samples_node <= 0) { + + // Create leaf node + model->nodes[node_idx].feature = -1; + model->nodes[node_idx].value = get_majority_class(labels, workspace->sample_indices, + current.start, current.end, model->n_classes); + model->nodes[node_idx].left = -1; + model->nodes[node_idx].right = -1; + if (node_idx >= model->n_nodes_used) { + model->n_nodes_used = node_idx + 1; + } + continue; + } + + // Find best split + int8_t best_feature; + int16_t best_threshold; + int16_t split_result = find_best_split(features, labels, model, workspace, + current.start, current.end, + n_features_subset, &best_feature, &best_threshold); + + if (split_result != 0 || best_feature == -1) { + // No valid split found, create leaf + model->nodes[node_idx].feature = -1; + model->nodes[node_idx].value = get_majority_class(labels, workspace->sample_indices, + current.start, current.end, model->n_classes); + model->nodes[node_idx].left = -1; + model->nodes[node_idx].right = -1; + if (node_idx >= model->n_nodes_used) { + model->n_nodes_used = node_idx + 1; + } + continue; + } + + // Partition samples + int16_t split_point = partition_samples(features, model, workspace, current.start, current.end, + best_feature, best_threshold); + + // Check if partition was successful + if (split_point <= current.start || split_point >= current.end) { + // Partition failed, create leaf + model->nodes[node_idx].feature = -1; + model->nodes[node_idx].value = get_majority_class(labels, workspace->sample_indices, + current.start, current.end, model->n_classes); + model->nodes[node_idx].left = -1; + model->nodes[node_idx].right = -1; + if (node_idx >= model->n_nodes_used) { + model->n_nodes_used = node_idx + 1; + } + continue; + } + + // Calculate next available node indices + int16_t next_node = model->n_nodes_used; + if (next_node + 1 >= model->max_nodes) { + // Not enough space for children, create leaf + model->nodes[node_idx].feature = -1; + model->nodes[node_idx].value = get_majority_class(labels, workspace->sample_indices, + current.start, current.end, model->n_classes); + model->nodes[node_idx].left = -1; + model->nodes[node_idx].right = -1; + if (node_idx >= model->n_nodes_used) { + model->n_nodes_used = node_idx + 1; + } + continue; + } + + // Create internal node + model->nodes[node_idx].feature = best_feature; + model->nodes[node_idx].value = best_threshold; + model->nodes[node_idx].left = next_node; + model->nodes[node_idx].right = next_node + 1; + + // Update n_nodes_used to reserve space for children + model->n_nodes_used = next_node + 2; + if (node_idx >= model->n_nodes_used - 2) { + model->n_nodes_used = node_idx + 1; + } + + // Add children to stack (right first, then left for correct processing order) + if (stack_size < 100) { // Reasonable stack limit + // Right child + workspace->node_stack[stack_size].node_idx = model->nodes[node_idx].right; + workspace->node_stack[stack_size].start = split_point; + workspace->node_stack[stack_size].end = current.end; + 
workspace->node_stack[stack_size].depth = current.depth + 1;
+            stack_size++;
+
+            // Left child
+            workspace->node_stack[stack_size].node_idx = model->nodes[node_idx].left;
+            workspace->node_stack[stack_size].start = current.start;
+            workspace->node_stack[stack_size].end = split_point;
+            workspace->node_stack[stack_size].depth = current.depth + 1;
+            stack_size++;
+        }
+    }
+
+    return 0;
+}
+
+// Main training function
+int16_t eml_trees_train(EmlTreesModel *model, EmlTreesWorkspace *workspace,
+                        const int16_t *features, const int16_t *labels) {
+
+    model->n_nodes_used = 0;
+    workspace->rng_state = model->config.rng_seed;
+
+    // Calculate subsample size
+    int16_t subsample_size = (int16_t)((float)workspace->n_samples * model->config.subsample_ratio);
+    if (subsample_size < 1) subsample_size = 1;
+    if (subsample_size > workspace->n_samples) subsample_size = workspace->n_samples;
+
+    // Initialize sample indices
+    for (int16_t i = 0; i < workspace->n_samples; i++) {
+        workspace->sample_indices[i] = i;
+    }
+
+    // Build each tree
+    for (int16_t tree = 0; tree < model->n_trees; tree++) {
+        // Store tree start index
+        model->tree_starts[tree] = model->n_nodes_used;
+
+        // Subsample without replacement
+        shuffle_indices(workspace->sample_indices, workspace->n_samples, &workspace->rng_state);
+
+        // Temporarily set n_samples to subsample size for tree building
+        int16_t original_n_samples = workspace->n_samples;
+        workspace->n_samples = subsample_size;
+
+        // Build tree with subsampled data
+        int16_t result = build_tree(model, workspace, features, labels);
+
+        // Restore original n_samples
+        workspace->n_samples = original_n_samples;
+
+        if (result != 0) {
+            return result;
+        }
+    }
+
+    return 0;
+}
+
+// Prediction function that returns probabilities
+int16_t eml_trees_predict_proba(const EmlTreesModel *model, const int16_t *features,
+                                float *probabilities, int16_t *votes) {
+
+    // Initialize vote counts using model's n_classes
+    for (int16_t i = 0; i < model->n_classes; i++) {
+        votes[i] = 0;
+    }
+
+    // Get prediction from each tree
+    for (int16_t tree = 0; tree < model->n_trees; tree++) {
+        int16_t node_idx = model->tree_starts[tree];
+
+        // Traverse tree
+        while (model->nodes[node_idx].feature != -1) {
+            int8_t feature = model->nodes[node_idx].feature;
+            int16_t threshold = model->nodes[node_idx].value;
+
+            if (features[feature] <= threshold) {
+                node_idx = model->nodes[node_idx].left;
+            } else {
+                node_idx = model->nodes[node_idx].right;
+            }
+        }
+
+        // Add leaf prediction to votes
+        int16_t predicted_class = model->nodes[node_idx].value;
+        if (predicted_class >= 0 && predicted_class < model->n_classes) {
+            votes[predicted_class]++;
+        }
+    }
+
+    // Convert votes to probabilities
+    for (int16_t i = 0; i < model->n_classes; i++) {
+        probabilities[i] = (float)votes[i] / (float)model->n_trees;
+    }
+
+    // Find majority class
+    int16_t max_votes = 0;
+    int16_t predicted_class = 0;
+    for (int16_t i = 0; i < model->n_classes; i++) {
+        if (votes[i] > max_votes) {
+            max_votes = votes[i];
+            predicted_class = i;
+        }
+    }
+
+    return predicted_class;
+}
diff --git a/src/emlearn_extratrees/extratrees.c b/src/emlearn_extratrees/extratrees.c
new file mode 100644
index 0000000..92d184c
--- /dev/null
+++ b/src/emlearn_extratrees/extratrees.c
@@ -0,0 +1,325 @@
+// Include the header file to get access to the MicroPython API
+#include "py/dynruntime.h"
+
+#include <string.h>
+
+#include "eml_extratrees.c"
+
+// memset/memcpy for compatibility
+#if !defined(__linux__)
+void *memcpy(void *dst, const void *src, size_t n) {
+    return 
mp_fun_table.memmove_(dst, src, n); +} +void *memset(void *s, int c, size_t n) { + return mp_fun_table.memset_(s, c, n); +} +#endif + +// MicroPython type for ExtraTrees model +typedef struct _mp_obj_extratrees_model_t { + mp_obj_base_t base; + EmlTreesModel model; + EmlTreesWorkspace workspace; + int16_t *features_buffer; // Buffer to store features during training + int16_t *labels_buffer; // Buffer to store labels during training +} mp_obj_extratrees_model_t; + +mp_obj_full_type_t extratrees_model_type; + +// Create a new instance +static mp_obj_t extratrees_model_new(size_t n_args, const mp_obj_t *args) { + // Args: n_features, n_classes, [n_trees], [max_depth], [min_samples_leaf], [n_thresholds], + // [subsample_ratio], [feature_subsample_ratio], [max_nodes], [max_samples], [rng_seed] + if (n_args < 2 || n_args > 11) { + mp_raise_ValueError(MP_ERROR_TEXT("Expected 2-11 arguments: n_features, n_classes, [n_trees=10], [max_depth=10], [min_samples_leaf=1], [n_thresholds=10], [subsample_ratio=1.0], [feature_subsample_ratio=1.0], [max_nodes=1000], [max_samples=1000], [rng_seed=42]")); + } + + mp_int_t n_features = mp_obj_get_int(args[0]); + mp_int_t n_classes = mp_obj_get_int(args[1]); + mp_int_t n_trees = (n_args > 2) ? mp_obj_get_int(args[2]) : 10; + mp_int_t max_depth = (n_args > 3) ? mp_obj_get_int(args[3]) : 10; + mp_int_t min_samples_leaf = (n_args > 4) ? mp_obj_get_int(args[4]) : 1; + mp_int_t n_thresholds = (n_args > 5) ? mp_obj_get_int(args[5]) : 10; + float subsample_ratio = (n_args > 6) ? mp_obj_get_float(args[6]) : 1.0f; + float feature_subsample_ratio = (n_args > 7) ? mp_obj_get_float(args[7]) : 1.0f; + mp_int_t max_nodes = (n_args > 8) ? mp_obj_get_int(args[8]) : 1000; + mp_int_t max_samples = (n_args > 9) ? mp_obj_get_int(args[9]) : 1000; + mp_int_t rng_seed = (n_args > 10) ? 
mp_obj_get_int(args[10]) : 42; + + // Allocate space + mp_obj_extratrees_model_t *o = \ + mp_obj_malloc(mp_obj_extratrees_model_t, (mp_obj_type_t *)&extratrees_model_type); + + EmlTreesModel *model = &o->model; + EmlTreesWorkspace *workspace = &o->workspace; + memset(model, 0, sizeof(EmlTreesModel)); + memset(workspace, 0, sizeof(EmlTreesWorkspace)); + + // Configure model + model->n_features = n_features; + model->n_classes = n_classes; + model->n_trees = n_trees; + model->max_nodes = max_nodes; + model->n_nodes_used = 0; + + // Configure model config + model->config.max_depth = max_depth; + model->config.min_samples_leaf = min_samples_leaf; + model->config.n_thresholds = n_thresholds; + model->config.subsample_ratio = subsample_ratio; + model->config.feature_subsample_ratio = feature_subsample_ratio; + model->config.rng_seed = rng_seed; + + // Allocate model buffers + model->nodes = (EmlTreesNode *)m_malloc(sizeof(EmlTreesNode) * max_nodes); + model->tree_starts = (int16_t *)m_malloc(sizeof(int16_t) * n_trees); + + // Allocate workspace buffers + workspace->sample_indices = (int16_t *)m_malloc(sizeof(int16_t) * max_samples); + workspace->feature_indices = (int16_t *)m_malloc(sizeof(int16_t) * n_features); + workspace->min_vals = (int16_t *)m_malloc(sizeof(int16_t) * n_features); + workspace->max_vals = (int16_t *)m_malloc(sizeof(int16_t) * n_features); + workspace->node_stack = (NodeState *)m_malloc(sizeof(NodeState) * 100); // Stack limit + workspace->n_samples = 0; // Will be set during training + workspace->rng_state = rng_seed; + + // Allocate training data buffers + o->features_buffer = (int16_t *)m_malloc(sizeof(int16_t) * max_samples * n_features); + o->labels_buffer = (int16_t *)m_malloc(sizeof(int16_t) * max_samples); + + // Initialize nodes and tree starts + memset(model->nodes, 0, sizeof(EmlTreesNode) * max_nodes); + memset(model->tree_starts, 0, sizeof(int16_t) * n_trees); + + return MP_OBJ_FROM_PTR(o); +} +// Define a Python reference to the function above +static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(extratrees_model_new_obj, 2, 11, extratrees_model_new); + +// Delete an instance +static mp_obj_t extratrees_model_del(mp_obj_t self_obj) { + mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(self_obj); + EmlTreesModel *model = &o->model; + EmlTreesWorkspace *workspace = &o->workspace; + + // Free allocated memory + m_free(model->nodes); + m_free(model->tree_starts); + m_free(workspace->sample_indices); + m_free(workspace->feature_indices); + m_free(workspace->min_vals); + m_free(workspace->max_vals); + m_free(workspace->node_stack); + m_free(o->features_buffer); + m_free(o->labels_buffer); + + return mp_const_none; +} +// Define a Python reference to the function above +static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_del_obj, extratrees_model_del); + +// Train the model +static mp_obj_t extratrees_model_train(size_t n_args, const mp_obj_t *args) { + // Args: self, X, y + if (n_args != 3) { + mp_raise_ValueError(MP_ERROR_TEXT("Expected 3 arguments: self, X, y")); + } + + mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(args[0]); + EmlTreesModel *model = &o->model; + EmlTreesWorkspace *workspace = &o->workspace; + + // Extract X buffer + mp_buffer_info_t X_bufinfo; + mp_get_buffer_raise(args[1], &X_bufinfo, MP_BUFFER_READ); + if (X_bufinfo.typecode != 'h') { // int16_t + mp_raise_ValueError(MP_ERROR_TEXT("X expecting int16 array")); + } + const int16_t *X = X_bufinfo.buf; + const int X_len = X_bufinfo.len / sizeof(int16_t); + + // Extract y buffer + mp_buffer_info_t y_bufinfo; + 
mp_get_buffer_raise(args[2], &y_bufinfo, MP_BUFFER_READ); + if (y_bufinfo.typecode != 'h') { // int16_t + mp_raise_ValueError(MP_ERROR_TEXT("y expecting int16 array")); + } + const int16_t *y = y_bufinfo.buf; + const int y_len = y_bufinfo.len / sizeof(int16_t); + + // Validate dimensions + if (X_len != y_len * model->n_features) { + mp_raise_ValueError(MP_ERROR_TEXT("X and y dimensions don't match")); + } + + const int16_t n_samples = y_len; + workspace->n_samples = n_samples; + + // Copy data to internal buffers (eml_trees expects non-const pointers) + memcpy(o->features_buffer, X, X_len * sizeof(int16_t)); + memcpy(o->labels_buffer, y, y_len * sizeof(int16_t)); + + // Perform training + int16_t result = eml_trees_train(model, workspace, o->features_buffer, o->labels_buffer); + + if (result != 0) { + mp_raise_ValueError(MP_ERROR_TEXT("Training failed")); + } + + return mp_const_none; +} +static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(extratrees_model_train_obj, 3, 3, extratrees_model_train); + +// Predict using the model (returns class probabilities) +static mp_obj_t extratrees_model_predict_proba(mp_obj_fun_bc_t *self_obj, + size_t n_args, size_t n_kw, mp_obj_t *args) { + // Check number of arguments is valid + mp_arg_check_num(n_args, n_kw, 3, 3, false); + + mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(args[0]); + EmlTreesModel *model = &o->model; + + // Extract features buffer pointer and verify typecode + mp_buffer_info_t features_bufinfo; + mp_get_buffer_raise(args[1], &features_bufinfo, MP_BUFFER_READ); + if (features_bufinfo.typecode != 'h') { // int16_t + mp_raise_ValueError(MP_ERROR_TEXT("features expecting int16 array")); + } + const int16_t *features = features_bufinfo.buf; + const int n_features = features_bufinfo.len / sizeof(int16_t); + + if (n_features != model->n_features) { + mp_raise_ValueError(MP_ERROR_TEXT("Feature count mismatch")); + } + + // Extract probabilities output buffer + mp_buffer_info_t proba_bufinfo; + mp_get_buffer_raise(args[2], &proba_bufinfo, MP_BUFFER_WRITE); + if (proba_bufinfo.typecode != 'f') { // float + mp_raise_ValueError(MP_ERROR_TEXT("probabilities expecting float32 array")); + } + float *probabilities = proba_bufinfo.buf; + const int proba_len = proba_bufinfo.len / sizeof(float); + + if (proba_len != model->n_classes) { + mp_raise_ValueError(MP_ERROR_TEXT("Probabilities buffer size mismatch")); + } + + // Allocate temporary votes buffer + int16_t *votes = (int16_t *)m_malloc(sizeof(int16_t) * model->n_classes); + + // Make prediction + int16_t predicted_class = eml_trees_predict_proba(model, features, probabilities, votes); + + // Free temporary buffer + m_free(votes); + + return mp_obj_new_int(predicted_class); +} + +// Predict using the model (returns only class label) +static mp_obj_t extratrees_model_predict(mp_obj_fun_bc_t *self_obj, + size_t n_args, size_t n_kw, mp_obj_t *args) { + // Check number of arguments is valid + mp_arg_check_num(n_args, n_kw, 2, 2, false); + + mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(args[0]); + EmlTreesModel *model = &o->model; + + // Extract features buffer pointer and verify typecode + mp_buffer_info_t features_bufinfo; + mp_get_buffer_raise(args[1], &features_bufinfo, MP_BUFFER_READ); + if (features_bufinfo.typecode != 'h') { // int16_t + mp_raise_ValueError(MP_ERROR_TEXT("features expecting int16 array")); + } + const int16_t *features = features_bufinfo.buf; + const int n_features = features_bufinfo.len / sizeof(int16_t); + + if (n_features != model->n_features) { + 
mp_raise_ValueError(MP_ERROR_TEXT("Feature count mismatch")); + } + + // Allocate temporary buffers + float *probabilities = (float *)m_malloc(sizeof(float) * model->n_classes); + int16_t *votes = (int16_t *)m_malloc(sizeof(int16_t) * model->n_classes); + + // Make prediction + int16_t predicted_class = eml_trees_predict_proba(model, features, probabilities, votes); + + // Free temporary buffers + m_free(probabilities); + m_free(votes); + + return mp_obj_new_int(predicted_class); +} + +// Get number of features +static mp_obj_t extratrees_model_get_n_features(mp_obj_t self_obj) { + mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(self_obj); + EmlTreesModel *model = &o->model; + + return mp_obj_new_int(model->n_features); +} +// Define a Python reference to the function above +static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_get_n_features_obj, extratrees_model_get_n_features); + +// Get number of classes +static mp_obj_t extratrees_model_get_n_classes(mp_obj_t self_obj) { + mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(self_obj); + EmlTreesModel *model = &o->model; + + return mp_obj_new_int(model->n_classes); +} +// Define a Python reference to the function above +static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_get_n_classes_obj, extratrees_model_get_n_classes); + +// Get number of trees +static mp_obj_t extratrees_model_get_n_trees(mp_obj_t self_obj) { + mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(self_obj); + EmlTreesModel *model = &o->model; + + return mp_obj_new_int(model->n_trees); +} +// Define a Python reference to the function above +static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_get_n_trees_obj, extratrees_model_get_n_trees); + +// Get number of nodes used +static mp_obj_t extratrees_model_get_n_nodes_used(mp_obj_t self_obj) { + mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(self_obj); + EmlTreesModel *model = &o->model; + + return mp_obj_new_int(model->n_nodes_used); +} +// Define a Python reference to the function above +static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_get_n_nodes_used_obj, extratrees_model_get_n_nodes_used); + +// Module setup +mp_map_elem_t extratrees_model_locals_dict_table[10]; +static MP_DEFINE_CONST_DICT(extratrees_model_locals_dict, extratrees_model_locals_dict_table); + +// Module setup entrypoint +mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *args) { + // This must be first, it sets up the globals dict and other things + MP_DYNRUNTIME_INIT_ENTRY + + mp_store_global(MP_QSTR_new, MP_OBJ_FROM_PTR(&extratrees_model_new_obj)); + + extratrees_model_type.base.type = (void*)&mp_fun_table.type_type; + extratrees_model_type.flags = MP_TYPE_FLAG_ITER_IS_CUSTOM; + extratrees_model_type.name = MP_QSTR_extratrees; + + // methods + extratrees_model_locals_dict_table[0] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict), MP_DYNRUNTIME_MAKE_FUNCTION(extratrees_model_predict) }; + extratrees_model_locals_dict_table[1] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict_proba), MP_DYNRUNTIME_MAKE_FUNCTION(extratrees_model_predict_proba) }; + extratrees_model_locals_dict_table[2] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train), MP_OBJ_FROM_PTR(&extratrees_model_train_obj) }; + extratrees_model_locals_dict_table[3] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&extratrees_model_del_obj) }; + extratrees_model_locals_dict_table[4] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_features), MP_OBJ_FROM_PTR(&extratrees_model_get_n_features_obj) }; + extratrees_model_locals_dict_table[5] = (mp_map_elem_t){ 
From e9ca3eec4fc05e6da61049999a1f14e06fffe706 Mon Sep 17 00:00:00 2001
From: Jon Nordby
Date: Tue, 8 Jul 2025 22:06:54 +0200
Subject: [PATCH 2/6] extratrees: Most bugs fixed

Must accept splits with no improvement in Gini, to reach later nodes
---
 src/emlearn_extratrees/eml_extratrees.c | 484 ++++++++++++++++--------
 1 file changed, 316 insertions(+), 168 deletions(-)

diff --git a/src/emlearn_extratrees/eml_extratrees.c b/src/emlearn_extratrees/eml_extratrees.c
index 7acf2ee..8f85f3d 100644
--- a/src/emlearn_extratrees/eml_extratrees.c
+++ b/src/emlearn_extratrees/eml_extratrees.c
@@ -1,6 +1,9 @@
 #include <stdint.h>
 #include <stdbool.h>
 #include <string.h>
+#include <stdio.h>
+
+#define printf(fmt, ...) mp_printf(&mp_plat_print, fmt, ##__VA_ARGS__)
 
 typedef struct _EmlTreesNode {
     int8_t feature;   // -1 for leaf nodes
@@ -77,40 +80,238 @@ static float calculate_gini_from_counts(const int16_t *counts, int16_t total, in
     return gini;
 }
 
-// Find best split for a node using count-based Gini (no temp_indices needed)
+
+
+
+// Partition samples based on feature threshold
+static int16_t partition_samples(const int16_t *features, EmlTreesModel *model,
+                                 EmlTreesWorkspace *workspace, int16_t start, int16_t end,
+                                 int8_t feature, int16_t threshold) {
+    int16_t left = start;
+    int16_t right = end - 1;
+
+    while (left <= right) {
+        // Find element on left that should be on right
+        while (left <= right && features[workspace->sample_indices[left] * model->n_features + feature] <= threshold) {
+            left++;
+        }
+
+        // Find element on right that should be on left
+        while (left <= right && features[workspace->sample_indices[right] * model->n_features + feature] > threshold) {
+            right--;
+        }
+
+        // Swap if needed
+        if (left < right) {
+            int16_t temp = workspace->sample_indices[left];
+            workspace->sample_indices[left] = workspace->sample_indices[right];
+            workspace->sample_indices[right] = temp;
+            left++;
+            right--;
+        }
+    }
+
+    return left;  // Split point
+}
+
+
+
+
+
+// Debug-instrumented version of eml_trees_predict_proba
+int16_t eml_trees_predict_proba(const EmlTreesModel *model, const int16_t *features,
+                                float *probabilities, int16_t *votes) {
+
+    // Initialize vote counts
+    for (int16_t i = 0; i < model->n_classes; i++) {
+        votes[i] = 0;
+    }
+
+    //printf("Prediction debug: features=[%d,%d]\n", features[0], features[1]);
+
+    // Get prediction from each tree
+    for (int16_t tree = 0; tree < model->n_trees; tree++) {
+        int16_t node_idx = model->tree_starts[tree];
+        //printf("  Tree %d: starting at node %d\n", tree, node_idx);
+
+        // Traverse tree
+        int16_t steps = 0;
+        while (node_idx >= 0 && node_idx < model->n_nodes_used &&
+               model->nodes[node_idx].feature != -1 && steps < 20) {
+
+            int8_t feature = model->nodes[node_idx].feature;
+            int16_t threshold = model->nodes[node_idx].value;
+            int16_t left = model->nodes[node_idx].left;
+            int16_t right = model->nodes[node_idx].right;
+
+            printf("    Node %d: feature=%d, threshold=%d, feature_val=%d\n",
+                   node_idx, feature, threshold, 
features[feature]); + + if (features[feature] <= threshold) { + //printf(" Going LEFT to node %d\n", left); + node_idx = left; + } else { + //printf(" Going RIGHT to node %d\n", right); + node_idx = right; + } + steps++; + } + + // Check leaf node + if (node_idx >= 0 && node_idx < model->n_nodes_used) { + int16_t predicted_class = model->nodes[node_idx].value; + //printf(" Tree %d: reached leaf node %d, class=%d\n", tree, node_idx, predicted_class); + + if (predicted_class >= 0 && predicted_class < model->n_classes) { + votes[predicted_class]++; + } + } else { + //printf(" Tree %d: invalid leaf node %d\n", tree, node_idx); + } + } + + //printf("Final votes: [%d,%d]\n", votes[0], votes[1]); + + // Rest of function unchanged... + for (int16_t i = 0; i < model->n_classes; i++) { + probabilities[i] = (float)votes[i] / (float)model->n_trees; + } + + int16_t max_votes = 0; + int16_t predicted_class = 0; + for (int16_t i = 0; i < model->n_classes; i++) { + if (votes[i] > max_votes) { + max_votes = votes[i]; + predicted_class = i; + } + } + + return predicted_class; +} + + + + +// ALSO: Make sure get_majority_class is working correctly +static int16_t get_majority_class(const int16_t *labels, const int16_t *indices, + int16_t start, int16_t end, int16_t n_classes) { + int16_t counts[256] = {0}; + int16_t max_count = 0; + int16_t majority_class = 0; + + printf("get_majority_class: samples %d to %d\n", start, end-1); + + if (start >= end) { + printf(" No samples, returning class 0\n"); + return 0; + } + + // Count occurrences + for (int16_t i = start; i < end; i++) { + int16_t sample_idx = indices[i]; + int16_t label = labels[sample_idx]; + + if (label >= 0 && label < n_classes) { + counts[label]++; + printf(" Sample %d: index=%d, label=%d\n", i, sample_idx, label); + } + } + + // Find majority + for (int16_t i = 0; i < n_classes; i++) { + if (counts[i] > max_count) { + max_count = counts[i]; + majority_class = i; + } + } + + printf(" Counts: [%d,%d], majority class: %d\n", counts[0], counts[1], majority_class); + + return majority_class; +} + + +// CRITICAL FIX: Accept splits with zero improvement +// For complex patterns like XOR, we need to allow splits that don't immediately improve Gini +// but will lead to better splits at deeper levels + static int16_t find_best_split(const int16_t *features, const int16_t *labels, EmlTreesModel *model, EmlTreesWorkspace *workspace, int16_t start, int16_t end, int16_t n_features_subset, int8_t *best_feature, int16_t *best_threshold) { - float best_improvement = -1.0f; // Best information gain found + float best_improvement = -1.0f; *best_feature = -1; *best_threshold = 0; int16_t total_samples = end - start; - if (total_samples < 2) return -1; - // Calculate parent impurity + if (total_samples < 2) { + return -1; + } + + // Calculate parent class distribution int16_t parent_counts[256] = {0}; for (int16_t i = start; i < end; i++) { - parent_counts[labels[workspace->sample_indices[i]]]++; + int16_t sample_idx = workspace->sample_indices[i]; + int16_t label = labels[sample_idx]; + parent_counts[label]++; } + float parent_gini = calculate_gini_from_counts(parent_counts, total_samples, model->n_classes); - // Try each feature in the subset + // If already pure, no split needed + if (parent_gini == 0.0f) { + return -1; + } + + // Try each feature for (int16_t f = 0; f < n_features_subset; f++) { int16_t feature_idx = workspace->feature_indices[f]; - int16_t min_val = workspace->min_vals[feature_idx]; - int16_t max_val = workspace->max_vals[feature_idx]; - if 
(min_val >= max_val) continue; + // Collect unique values for this feature in this node + int16_t unique_vals[50]; + int16_t n_unique = 0; - // Try random thresholds - for (int16_t t = 0; t < model->config.n_thresholds; t++) { - int16_t threshold = min_val + (eml_rand(&workspace->rng_state) % (max_val - min_val + 1)); + for (int16_t i = start; i < end; i++) { + int16_t sample_idx = workspace->sample_indices[i]; + int16_t val = features[sample_idx * model->n_features + feature_idx]; - // Count class distributions in left/right partitions without moving data - int16_t left_counts[256] = {0}; // Assume max 256 classes + // Check if already in unique_vals + bool already_present = false; + for (int16_t u = 0; u < n_unique; u++) { + if (unique_vals[u] == val) { + already_present = true; + break; + } + } + if (!already_present && n_unique < 50) { + unique_vals[n_unique++] = val; + } + } + + if (n_unique < 2) { + continue; // Need at least 2 unique values to split + } + + // Sort unique values to try thresholds between them + for (int16_t i = 0; i < n_unique - 1; i++) { + for (int16_t j = i + 1; j < n_unique; j++) { + if (unique_vals[i] > unique_vals[j]) { + int16_t temp = unique_vals[i]; + unique_vals[i] = unique_vals[j]; + unique_vals[j] = temp; + } + } + } + + // Try thresholds between consecutive unique values + for (int16_t u = 0; u < n_unique - 1; u++) { + // Use threshold between unique_vals[u] and unique_vals[u+1] + int16_t threshold = unique_vals[u]; + + // Count left/right distributions + int16_t left_counts[256] = {0}; int16_t right_counts[256] = {0}; int16_t left_total = 0, right_total = 0; @@ -128,20 +329,25 @@ static int16_t find_best_split(const int16_t *features, const int16_t *labels, } } - if (left_total == 0 || right_total == 0) continue; - if (left_total < model->config.min_samples_leaf || right_total < model->config.min_samples_leaf) continue; + // Check if split creates non-empty partitions + if (left_total == 0 || right_total == 0) { + continue; + } + + // Check min_samples_leaf constraint + if (left_total < model->config.min_samples_leaf || right_total < model->config.min_samples_leaf) { + continue; + } - // Calculate Gini for left and right partitions from counts + // Calculate improvement float left_gini = calculate_gini_from_counts(left_counts, left_total, model->n_classes); float right_gini = calculate_gini_from_counts(right_counts, right_total, model->n_classes); - - // Calculate weighted Gini impurity float weighted_gini = ((float)left_total * left_gini + (float)right_total * right_gini) / (float)total_samples; - - // Calculate information gain float improvement = parent_gini - weighted_gini; - if (improvement > best_improvement) { + // CRITICAL FIX: Accept splits with improvement >= 0.0 (not just > 0.0) + // This allows splits that don't immediately improve but may lead to better deeper splits + if (improvement >= best_improvement) { best_improvement = improvement; *best_feature = feature_idx; *best_threshold = threshold; @@ -149,62 +355,12 @@ static int16_t find_best_split(const int16_t *features, const int16_t *labels, } } - return (best_improvement > 0.0f) ? 
0 : -1; // Return 0 if good split found, -1 otherwise -} - -// Partition samples based on feature threshold -static int16_t partition_samples(const int16_t *features, EmlTreesModel *model, - EmlTreesWorkspace *workspace, int16_t start, int16_t end, - int8_t feature, int16_t threshold) { - int16_t left = start; - int16_t right = end - 1; - - while (left <= right) { - // Find element on left that should be on right - while (left <= right && features[workspace->sample_indices[left] * model->n_features + feature] <= threshold) { - left++; - } - - // Find element on right that should be on left - while (left <= right && features[workspace->sample_indices[right] * model->n_features + feature] > threshold) { - right--; - } - - // Swap if needed - if (left < right) { - int16_t temp = workspace->sample_indices[left]; - workspace->sample_indices[left] = workspace->sample_indices[right]; - workspace->sample_indices[right] = temp; - left++; - right--; - } - } - - return left; // Split point -} - -// Get majority class in a range -static int16_t get_majority_class(const int16_t *labels, const int16_t *indices, - int16_t start, int16_t end, int16_t n_classes) { - int16_t counts[256] = {0}; // Assume max 256 classes - int16_t max_count = 0; - int16_t majority_class = 0; - - for (int16_t i = start; i < end; i++) { - counts[labels[indices[i]]]++; - } - - for (int16_t i = 0; i < n_classes; i++) { - if (counts[i] > max_count) { - max_count = counts[i]; - majority_class = i; - } - } - - return majority_class; + // CRITICAL FIX: Accept any valid split, even with zero improvement + // Change the return condition to accept improvement >= 0.0 + return (*best_feature != -1) ? 0 : -1; } -// Build a single tree +// ALSO: Ensure stopping criteria allow deep enough trees for XOR static int16_t build_tree(EmlTreesModel *model, EmlTreesWorkspace *workspace, const int16_t *features, const int16_t *labels) { @@ -227,42 +383,57 @@ static int16_t build_tree(EmlTreesModel *model, EmlTreesWorkspace *workspace, workspace->node_stack[0].end = workspace->n_samples; workspace->node_stack[0].depth = 0; - // Initialize min/max values - for (int16_t f = 0; f < model->n_features; f++) { - workspace->min_vals[f] = 32767; - workspace->max_vals[f] = -32768; - } - - // Calculate initial min/max - for (int16_t i = 0; i < workspace->n_samples; i++) { - for (int16_t f = 0; f < model->n_features; f++) { - int16_t val = features[workspace->sample_indices[i] * model->n_features + f]; - if (val < workspace->min_vals[f]) workspace->min_vals[f] = val; - if (val > workspace->max_vals[f]) workspace->max_vals[f] = val; - } - } - // Process stack while (stack_size > 0) { NodeState current = workspace->node_stack[--stack_size]; int16_t node_idx = current.node_idx; if (node_idx >= model->max_nodes) { - return -1; // Out of nodes + return -1; } - // Check stopping criteria + // Check stopping criteria - MODIFIED for XOR int16_t n_samples_node = current.end - current.start; - if (current.depth >= model->config.max_depth || - n_samples_node < model->config.min_samples_leaf * 2 || - n_samples_node <= 0) { - + + // Create leaf if: + // 1. Reached max depth, OR + // 2. Too few samples for further splitting, OR + // 3. 
Node is already pure + bool should_stop = false; + + if (current.depth >= model->config.max_depth) { + should_stop = true; + } else if (n_samples_node < 2 * model->config.min_samples_leaf) { + should_stop = true; + } else { + // Check if node is pure + int16_t first_label = -1; + bool is_pure = true; + for (int16_t i = current.start; i < current.end; i++) { + int16_t sample_idx = workspace->sample_indices[i]; + int16_t label = labels[sample_idx]; + if (first_label == -1) { + first_label = label; + } else if (label != first_label) { + is_pure = false; + break; + } + } + if (is_pure) { + should_stop = true; + } + } + + if (should_stop) { // Create leaf node + int16_t majority = get_majority_class(labels, workspace->sample_indices, + current.start, current.end, model->n_classes); + model->nodes[node_idx].feature = -1; - model->nodes[node_idx].value = get_majority_class(labels, workspace->sample_indices, - current.start, current.end, model->n_classes); + model->nodes[node_idx].value = majority; model->nodes[node_idx].left = -1; model->nodes[node_idx].right = -1; + if (node_idx >= model->n_nodes_used) { model->n_nodes_used = node_idx + 1; } @@ -278,11 +449,14 @@ static int16_t build_tree(EmlTreesModel *model, EmlTreesWorkspace *workspace, if (split_result != 0 || best_feature == -1) { // No valid split found, create leaf + int16_t majority = get_majority_class(labels, workspace->sample_indices, + current.start, current.end, model->n_classes); + model->nodes[node_idx].feature = -1; - model->nodes[node_idx].value = get_majority_class(labels, workspace->sample_indices, - current.start, current.end, model->n_classes); + model->nodes[node_idx].value = majority; model->nodes[node_idx].left = -1; model->nodes[node_idx].right = -1; + if (node_idx >= model->n_nodes_used) { model->n_nodes_used = node_idx + 1; } @@ -293,14 +467,16 @@ static int16_t build_tree(EmlTreesModel *model, EmlTreesWorkspace *workspace, int16_t split_point = partition_samples(features, model, workspace, current.start, current.end, best_feature, best_threshold); - // Check if partition was successful if (split_point <= current.start || split_point >= current.end) { // Partition failed, create leaf + int16_t majority = get_majority_class(labels, workspace->sample_indices, + current.start, current.end, model->n_classes); + model->nodes[node_idx].feature = -1; - model->nodes[node_idx].value = get_majority_class(labels, workspace->sample_indices, - current.start, current.end, model->n_classes); + model->nodes[node_idx].value = majority; model->nodes[node_idx].left = -1; model->nodes[node_idx].right = -1; + if (node_idx >= model->n_nodes_used) { model->n_nodes_used = node_idx + 1; } @@ -309,13 +485,20 @@ static int16_t build_tree(EmlTreesModel *model, EmlTreesWorkspace *workspace, // Calculate next available node indices int16_t next_node = model->n_nodes_used; + if (next_node <= node_idx) { + next_node = node_idx + 1; + } + if (next_node + 1 >= model->max_nodes) { - // Not enough space for children, create leaf + // Not enough space, create leaf + int16_t majority = get_majority_class(labels, workspace->sample_indices, + current.start, current.end, model->n_classes); + model->nodes[node_idx].feature = -1; - model->nodes[node_idx].value = get_majority_class(labels, workspace->sample_indices, - current.start, current.end, model->n_classes); + model->nodes[node_idx].value = majority; model->nodes[node_idx].left = -1; model->nodes[node_idx].right = -1; + if (node_idx >= model->n_nodes_used) { model->n_nodes_used = node_idx + 1; } @@ -328,14 
+511,11 @@ static int16_t build_tree(EmlTreesModel *model, EmlTreesWorkspace *workspace,
     model->nodes[node_idx].left = next_node;
     model->nodes[node_idx].right = next_node + 1;
 
-    // Update n_nodes_used to reserve space for children
+    // Update n_nodes_used
     model->n_nodes_used = next_node + 2;
-    if (node_idx >= model->n_nodes_used - 2) {
-        model->n_nodes_used = node_idx + 1;
-    }
 
-    // Add children to stack (right first, then left for correct processing order)
-    if (stack_size < 100) { // Reasonable stack limit
+    // Add children to stack
+    if (stack_size < 98) {
         // Right child
         workspace->node_stack[stack_size].node_idx = model->nodes[node_idx].right;
         workspace->node_stack[stack_size].start = split_point;
@@ -355,18 +535,25 @@ static int16_t build_tree(EmlTreesModel *model, EmlTreesWorkspace *workspace,
     return 0;
 }
 
-// Main training function
+
+// Note: also check that subsampling does not accidentally filter out all samples;
+// eml_trees_train prints the subsample size below so this can be verified:
+
 int16_t eml_trees_train(EmlTreesModel *model, EmlTreesWorkspace *workspace,
                         const int16_t *features, const int16_t *labels) {
 
     model->n_nodes_used = 0;
     workspace->rng_state = model->config.rng_seed;
 
+    printf("Training: %d trees, %d total samples\n", model->n_trees, workspace->n_samples);
+
     // Calculate subsample size
     int16_t subsample_size = (int16_t)((float)workspace->n_samples * model->config.subsample_ratio);
     if (subsample_size < 1) subsample_size = 1;
     if (subsample_size > workspace->n_samples) subsample_size = workspace->n_samples;
 
+    printf("Subsample size: %d (ratio=%.2f)\n", subsample_size, model->config.subsample_ratio);
+
     // Initialize sample indices
     for (int16_t i = 0; i < workspace->n_samples; i++) {
         workspace->sample_indices[i] = i;
@@ -374,20 +561,29 @@ int16_t eml_trees_train(EmlTreesModel *model, EmlTreesWorkspace *workspace,
 
     // Build each tree
     for (int16_t tree = 0; tree < model->n_trees; tree++) {
-        // Store tree start index
+        printf("\n=== Building tree %d ===\n", tree);
+
         model->tree_starts[tree] = model->n_nodes_used;
 
-        // Subsample without replacement
+        // Subsample
         shuffle_indices(workspace->sample_indices, workspace->n_samples, &workspace->rng_state);
 
-        // Temporarily set n_samples to subsample size for tree building
+        printf("After shuffle, first few indices: ");
+        for (int16_t i = 0; i < (workspace->n_samples < 8 ? workspace->n_samples : 8); i++) {
+            printf("%d ", workspace->sample_indices[i]);
+        }
+        printf("\n");
+
+        // Set subsample size
         int16_t original_n_samples = workspace->n_samples;
         workspace->n_samples = subsample_size;
 
-        // Build tree with subsampled data
+        printf("Using %d samples for this tree\n", workspace->n_samples);
+
+        // Build tree
         int16_t result = build_tree(model, workspace, features, labels);
 
-        // Restore original n_samples
+        // Restore original sample count
         workspace->n_samples = original_n_samples;
 
         if (result != 0) {
@@ -395,55 +591,7 @@ int16_t eml_trees_train(EmlTreesModel *model, EmlTreesWorkspace *workspace,
         }
     }
 
+    printf("\nTraining completed: %d nodes total\n", model->n_nodes_used);
     return 0;
 }
 
-// Prediction function that returns probabilities
-int16_t eml_trees_predict_proba(const EmlTreesModel *model, const int16_t *features,
-                                float *probabilities, int16_t *votes) {
-
-    // Initialize vote counts using model's n_classes
-    for (int16_t i = 0; i < model->n_classes; i++) {
-        votes[i] = 0;
-    }
-
-    // Get prediction from each tree
-    for (int16_t tree = 0; tree < model->n_trees; tree++) {
-        int16_t node_idx = model->tree_starts[tree];
-
-        // Traverse tree
-        while (model->nodes[node_idx].feature != -1) {
-            int8_t feature = model->nodes[node_idx].feature;
-            int16_t threshold = model->nodes[node_idx].value;
-
-            if (features[feature] <= threshold) {
-                node_idx = model->nodes[node_idx].left;
-            } else {
-                node_idx = model->nodes[node_idx].right;
-            }
-        }
-
-        // Add leaf prediction to votes
-        int16_t predicted_class = model->nodes[node_idx].value;
-        if (predicted_class >= 0 && predicted_class < model->n_classes) {
-            votes[predicted_class]++;
-        }
-    }
-
-    // Convert votes to probabilities
-    for (int16_t i = 0; i < model->n_classes; i++) {
-        probabilities[i] = (float)votes[i] / (float)model->n_trees;
-    }
-
-    // Find majority class
-    int16_t max_votes = 0;
-    int16_t predicted_class = 0;
-    for (int16_t i = 0; i < model->n_classes; i++) {
-        if (votes[i] > max_votes) {
-            max_votes = votes[i];
-            predicted_class = i;
-        }
-    }
-
-    return predicted_class;
-}
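To see why patch 2 must accept zero-improvement splits, consider the XOR data used in the tests. A worked example, using the same Gini definition as calculate_gini_from_counts:

    def gini(counts):
        total = sum(counts)
        return 1.0 - sum((c / total) ** 2 for c in counts)

    parent = gini([2, 2])                      # root of 4 XOR points: 0.5
    # any single-feature split puts one sample of each class on each side
    left, right = gini([1, 1]), gini([1, 1])   # 0.5 and 0.5
    weighted = (2 * left + 2 * right) / 4      # 0.5
    gain = parent - weighted                   # exactly 0.0

With patch 1's strict improvement > 0.0 test, every candidate root split is rejected and the tree collapses to a single leaf; accepting gain >= 0 lets the depth-2 splits separate the classes.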
From 425d59807ca07c8d6c6561aa65304cf655b76cb3 Mon Sep 17 00:00:00 2001
From: Jon Nordby
Date: Tue, 8 Jul 2025 23:27:15 +0200
Subject: [PATCH 3/6] extratrees: Add examples for the breast cancer and wine
 datasets

The wine dataset trains in 15 seconds on an ESP32-S3 with these parameters
---
 prepare_cancer.py                       |  64 +++++++++++++
 prepare_wine.py                         |  86 ++++++++++++++++++
 src/emlearn_extratrees/eml_extratrees.c |   6 ++
 tests/test_extratrees_cancer.py         |  85 ++++++++++++++++++
 tests/test_extratrees_wine.py           | 115 ++++++++++++++++++++++++
 5 files changed, 356 insertions(+)
 create mode 100644 prepare_cancer.py
 create mode 100644 prepare_wine.py
 create mode 100644 tests/test_extratrees_cancer.py
 create mode 100644 tests/test_extratrees_wine.py

diff --git a/prepare_cancer.py b/prepare_cancer.py
new file mode 100644
index 0000000..66c4143
--- /dev/null
+++ b/prepare_cancer.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+import numpy as np
+from sklearn.datasets import load_breast_cancer
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+from sklearn.ensemble import ExtraTreesClassifier
+from sklearn.metrics import accuracy_score, classification_report
+
+# Load dataset
+data = load_breast_cancer()
+X, y = data.data, data.target
+
+print(f"Dataset: {X.shape[0]} samples, {X.shape[1]} features")
+print(f"Classes: {np.unique(y)} (0=malignant, 1=benign)")
+print(f"Class distribution: {np.bincount(y)}")
+
+# Train/test split
+X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42, stratify=y +) + +# Scale features +scaler = StandardScaler() +X_train_scaled = scaler.fit_transform(X_train) +X_test_scaled = scaler.transform(X_test) + +# Convert to int16 range (0-1000) for MicroPython compatibility +X_train_int = ((X_train_scaled + 3) / 6 * 1000).astype(np.int16) +X_test_int = ((X_test_scaled + 3) / 6 * 1000).astype(np.int16) + +# Clip to valid range +X_train_int = np.clip(X_train_int, 0, 1000) +X_test_int = np.clip(X_test_int, 0, 1000) + +print(f"Train: {X_train_int.shape}, Test: {X_test_int.shape}") +print(f"Feature range: [{X_train_int.min()}, {X_train_int.max()}]") + +# Sklearn baseline +clf = ExtraTreesClassifier( + n_estimators=20, + max_depth=10, + min_samples_leaf=2, + random_state=42 +) + +clf.fit(X_train_scaled, y_train) +y_pred = clf.predict(X_test_scaled) +baseline_acc = accuracy_score(y_test, y_pred) + +print(f"\nSklearn ExtraTrees baseline: {baseline_acc:.3f}") +print(classification_report(y_test, y_pred, target_names=['malignant', 'benign'])) + +# Save data for MicroPython +np.save('X_train.npy', X_train_int) +np.save('y_train.npy', y_train.astype(np.int16)) +np.save('X_test.npy', X_test_int) +np.save('y_test.npy', y_test.astype(np.int16)) + +print(f"\nSaved files:") +print(f"X_train.npy: {X_train_int.shape} int16") +print(f"y_train.npy: {y_train.shape} int16") +print(f"X_test.npy: {X_test_int.shape} int16") +print(f"y_test.npy: {y_test.shape} int16") +print(f"\nTarget accuracy: {baseline_acc:.3f}") diff --git a/prepare_wine.py b/prepare_wine.py new file mode 100644 index 0000000..f5870eb --- /dev/null +++ b/prepare_wine.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.ensemble import ExtraTreesClassifier +from sklearn.metrics import accuracy_score, classification_report + +# Load both red and white wine datasets +red_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv" +white_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv" + +red = pd.read_csv(red_url, sep=';') +white = pd.read_csv(white_url, sep=';') + +# Add wine type feature (0=red, 1=white) +red['wine_type'] = 0 +white['wine_type'] = 1 + +# Combine datasets +data = pd.concat([red, white], ignore_index=True) + +# Convert quality to binary classification (good wine: quality >= 6) +X = data.drop('quality', axis=1).values +y = (data['quality'] >= 6).astype(int).values + +print(f"Dataset: {X.shape[0]} samples, {X.shape[1]} features") +print(f"Classes: 0=poor wine (<6), 1=good wine (>=6)") +print(f"Class distribution: {np.bincount(y)}") +print(f"Good wine ratio: {y.mean():.2f}") + +# Train/test split +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42, stratify=y +) + +# Scale features +scaler = StandardScaler() +X_train_scaled = scaler.fit_transform(X_train) +X_test_scaled = scaler.transform(X_test) + +# Convert to int16 range (0-1000) for MicroPython +X_train_int = ((X_train_scaled + 3) / 6 * 1000).astype(np.int16) +X_test_int = ((X_test_scaled + 3) / 6 * 1000).astype(np.int16) + +X_train_int = np.clip(X_train_int, 0, 1000) +X_test_int = np.clip(X_test_int, 0, 1000) + +print(f"Train: {X_train_int.shape}, Test: {X_test_int.shape}") +print(f"Feature range: [{X_train_int.min()}, {X_train_int.max()}]") + +# Sklearn baseline 
+clf = ExtraTreesClassifier(
+    n_estimators=30,
+    max_depth=12,
+    min_samples_leaf=2,
+    random_state=42
+)
+
+clf.fit(X_train_scaled, y_train)
+y_pred = clf.predict(X_test_scaled)
+baseline_acc = accuracy_score(y_test, y_pred)
+
+print(f"\nSklearn ExtraTrees baseline: {baseline_acc:.3f}")
+print(classification_report(y_test, y_pred, target_names=['poor_wine', 'good_wine']))
+
+# Feature importance
+feature_names = [c for c in data.columns if c != 'quality']  # exclude the 'quality' target
+importances = clf.feature_importances_
+top_features = sorted(zip(feature_names, importances), key=lambda x: x[1], reverse=True)
+print(f"\nTop 5 features:")
+for name, imp in top_features[:5]:
+    print(f"  {name}: {imp:.3f}")
+
+# Save data for MicroPython
+np.save('X_train.npy', X_train_int)
+np.save('y_train.npy', y_train.astype(np.int16))
+np.save('X_test.npy', X_test_int)
+np.save('y_test.npy', y_test.astype(np.int16))
+
+print(f"\nSaved files:")
+print(f"X_train.npy: {X_train_int.shape} int16")
+print(f"y_train.npy: {y_train.shape} int16")
+print(f"X_test.npy: {X_test_int.shape} int16")
+print(f"y_test.npy: {y_test.shape} int16")
+print(f"\nTarget accuracy: {baseline_acc:.3f}")
diff --git a/src/emlearn_extratrees/eml_extratrees.c b/src/emlearn_extratrees/eml_extratrees.c
index 8f85f3d..e79844a 100644
--- a/src/emlearn_extratrees/eml_extratrees.c
+++ b/src/emlearn_extratrees/eml_extratrees.c
@@ -3,7 +3,13 @@
 #include <string.h>
 #include <stdio.h>
 
+#define DEBUG 0
+
+#if DEBUG
 #define printf(fmt, ...) mp_printf(&mp_plat_print, fmt, ##__VA_ARGS__)
+#else
+#define printf(fmt, ...) ((void)0)
+#endif
 
 typedef struct _EmlTreesNode {
     int8_t feature;   // -1 for leaf nodes
diff --git a/tests/test_extratrees_cancer.py b/tests/test_extratrees_cancer.py
new file mode 100644
index 0000000..c03779a
--- /dev/null
+++ b/tests/test_extratrees_cancer.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+# MicroPython test script (run with MicroPython after dataset prep)
+import array
+import gc
+import npyfile
+
+def load_npy_int16(filename):
+    """Load .npy file and convert to int16 array"""
+    shape, data = npyfile.load(filename)
+    return array.array('h', data)
+
+def test_real_dataset():
+    print("=== REAL DATASET TEST ===")
+
+    X_train_flat = load_npy_int16('X_train.npy')
+    y_train = load_npy_int16('y_train.npy')
+    X_test_flat = load_npy_int16('X_test.npy')
+    y_test = load_npy_int16('y_test.npy')
+
+
+    n_features = 30
+    n_train = len(y_train)
+    n_test = len(y_test)
+
+    print(f"Loaded: {n_train} train, {n_test} test samples")
+    print(f"Features: {n_features}")
+
+    # Import after data loading to save memory
+    import emlearn_extratrees
+
+    # Create model
+    model = emlearn_extratrees.new(
+        30,    # n_features
+        2,     # n_classes
+        20,    # n_trees
+        12,    # max_depth
+        2,     # min_samples_leaf
+        15,    # n_thresholds
+        0.8,   # subsample_ratio
+        0.7,   # feature_subsample_ratio
+        3000,  # max_nodes
+        500,   # max_samples
+        42     # rng_seed
+    )
+
+    print("Training...")
+    model.train(X_train_flat, y_train)
+    print(f"Trained: {model.get_n_nodes_used()} nodes")
+
+    # Test
+    correct = 0
+    probabilities = array.array('f', [0.0, 0.0])
+
+    for i in range(n_test):
+        start_idx = i * n_features
+        end_idx = start_idx + n_features
+        features = array.array('h', X_test_flat[start_idx:end_idx])
+
+        predicted = model.predict_proba(features, probabilities)
+        actual = y_test[i]
+
+        if predicted == actual:
+            correct += 1
+
+        if i < 5:  # Show first few predictions
+            conf = max(probabilities[0], probabilities[1])
+            print(f"Sample {i}: pred={predicted}, actual={actual}, conf={conf:.3f}")
+
+    accuracy = correct / n_test
+    print(f"\nAccuracy: {accuracy:.3f} ({correct}/{n_test})")
+    print(f"Target (sklearn): ~0.965")
+
+    if accuracy >= 0.90:
+        print("✅ EXCELLENT: Matches professional ML performance!")
+    elif accuracy >= 0.85:
+        print("✅ VERY GOOD: Strong real-world performance!")
+    elif accuracy >= 0.80:
+        print("✅ GOOD: Solid performance on real data!")
+    elif accuracy >= 0.70:
+        print("⚠️ FAIR: Working but needs tuning")
+    else:
+        print("❌ POOR: Significant issues")
+
+if __name__ == "__main__":
+    test_real_dataset()
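Both dataset tests consume the int16 features produced by the prep scripts above. For reference, a sketch of that mapping and its approximate inverse (mirroring the scaling and clipping in prepare_cancer.py and prepare_wine.py):

    def quantize(x_scaled):
        # Map a standardized value in roughly [-3, 3] to [0, 1000], clipped
        q = int((x_scaled + 3) / 6 * 1000)
        return max(0, min(1000, q))

    def dequantize(q):
        # Approximate inverse of the mapping above
        return q / 1000 * 6 - 3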
print(f"\nAccuracy: {accuracy:.3f} ({correct}/{n_test})") + print(f"Target (sklearn): ~0.965") + + if accuracy >= 0.90: + print("✅ EXCELLENT: Matches professional ML performance!") + elif accuracy >= 0.85: + print("✅ VERY GOOD: Strong real-world performance!") + elif accuracy >= 0.80: + print("✅ GOOD: Solid performance on real data!") + elif accuracy >= 0.70: + print("⚠️ FAIR: Working but needs tuning") + else: + print("❌ POOR: Significant issues") + +if __name__ == "__main__": + test_real_dataset() diff --git a/tests/test_extratrees_wine.py b/tests/test_extratrees_wine.py new file mode 100644 index 0000000..d599369 --- /dev/null +++ b/tests/test_extratrees_wine.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +# Wine Quality test for MicroPython + +import array +import gc +import time +import npyfile +import emlearn_extratrees + +def load_npy_int16(filename): + """Load .npy file and convert to int16 array""" + shape, data = npyfile.load(filename) + return array.array('h', data) + +def test_wine_quality(): + print("=== WINE QUALITY DATASET TEST ===") + + # Load preprocessed data + try: + X_train_flat = load_npy_int16('X_train.npy') + y_train = load_npy_int16('y_train.npy') + X_test_flat = load_npy_int16('X_test.npy') + y_test = load_npy_int16('y_test.npy') + except: + print("Error: Run wine_quality_prep.py first") + return + + n_features = 12 # 11 wine features + wine_type + n_train = len(y_train) + n_test = len(y_test) + + print(f"Loaded: {n_train} train, {n_test} test samples") + print(f"Features: {n_features} (alcohol, acidity, etc. + wine_type)") + print("Task: Predict good wine (quality >= 6) vs poor wine") + + # Create model - adjusted for large dataset constraints + model = emlearn_extratrees.new( + 12, # n_features + 2, # n_classes + 5, # n_trees + 10, # max_depth + 3, # min_samples_leaf + 20, # n_thresholds + 0.20, # subsample_ratio (much smaller: 15% of 5197 = ~780 samples) + 0.8, # feature_subsample_ratio + 2000, # max_nodes + 10000, # max_samples (matches subsample size) + 42 # rng_seed + ) + + train_start = time.ticks_ms() + print("Training...") + model.train(X_train_flat, y_train) + print(f"Trained: {model.get_n_nodes_used()} nodes") + train_duration = time.ticks_diff(time.ticks_ms(), train_start) + print('Time (ms)', train_duration) + + # Test + correct = 0 + tp, tn, fp, fn = 0, 0, 0, 0 + probabilities = array.array('f', [0.0, 0.0]) + + for i in range(n_test): + start_idx = i * n_features + end_idx = start_idx + n_features + features = array.array('h', X_test_flat[start_idx:end_idx]) + + predicted = model.predict_proba(features, probabilities) + actual = y_test[i] + + if predicted == actual: + correct += 1 + + # Confusion matrix + if predicted == 1 and actual == 1: + tp += 1 + elif predicted == 0 and actual == 0: + tn += 1 + elif predicted == 1 and actual == 0: + fp += 1 + elif predicted == 0 and actual == 1: + fn += 1 + + if i < 5: + conf = max(probabilities[0], probabilities[1]) + wine_quality = "good" if actual == 1 else "poor" + pred_quality = "good" if predicted == 1 else "poor" + print(f"Sample {i}: pred={pred_quality}, actual={wine_quality}, conf={conf:.3f}") + + accuracy = correct / n_test + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 + + print(f"\nResults:") + print(f"Accuracy: {accuracy:.3f} ({correct}/{n_test})") + print(f"Precision: {precision:.3f}") + print(f"Recall: {recall:.3f}") + print(f"F1-Score: {f1:.3f}") + 
print(f"Confusion: TP={tp}, TN={tn}, FP={fp}, FN={fn}") + print(f"Target (sklearn): ~0.80") + + if accuracy >= 0.78: + print("✅ EXCELLENT: Great performance on wine quality!") + elif accuracy >= 0.75: + print("✅ VERY GOOD: Strong wine classification!") + elif accuracy >= 0.70: + print("✅ GOOD: Solid wine quality prediction!") + elif accuracy >= 0.65: + print("⚠️ FAIR: Working but could improve") + else: + print("❌ POOR: Needs significant improvement") + +if __name__ == "__main__": + test_wine_quality() From 453f350f362fdcdcb793cac444fb18e3fd92e7f1 Mon Sep 17 00:00:00 2001 From: Jon Nordby Date: Wed, 9 Jul 2025 00:53:59 +0200 Subject: [PATCH 4/6] Include extratrees in default build --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 47b645e..8f3171f 100644 --- a/Makefile +++ b/Makefile @@ -136,8 +136,8 @@ release: zip -r $(RELEASE_NAME).zip $(RELEASE_NAME) #cp $(RELEASE_NAME).zip emlearn-micropython-latest.zip -check: emlearn_trees.results emlearn_neighbors.results emlearn_iir.results emlearn_iir_q15.results emlearn_fft.results emlearn_kmeans.results emlearn_arrayutils.results emlearn_cnn.results emlearn_linreg.results +check: emlearn_trees.results emlearn_neighbors.results emlearn_iir.results emlearn_iir_q15.results emlearn_fft.results emlearn_kmeans.results emlearn_arrayutils.results emlearn_cnn.results emlearn_linreg.results emlearn_extratrees.results -dist: $(MODULES_PATH)/emlearn_trees.mpy $(MODULES_PATH)/emlearn_neighbors.mpy $(MODULES_PATH)/emlearn_iir.mpy $(MODULES_PATH)/emlearn_iir_q15.mpy $(MODULES_PATH)/emlearn_fft.mpy $(MODULES_PATH)/emlearn_kmeans.mpy $(MODULES_PATH)/emlearn_arrayutils.mpy $(MODULES_PATH)/emlearn_cnn_int8.mpy $(MODULES_PATH)/emlearn_cnn_fp32.mpy $(MODULES_PATH)/emlearn_linreg.mpy +dist: $(MODULES_PATH)/emlearn_trees.mpy $(MODULES_PATH)/emlearn_neighbors.mpy $(MODULES_PATH)/emlearn_iir.mpy $(MODULES_PATH)/emlearn_iir_q15.mpy $(MODULES_PATH)/emlearn_fft.mpy $(MODULES_PATH)/emlearn_kmeans.mpy $(MODULES_PATH)/emlearn_arrayutils.mpy $(MODULES_PATH)/emlearn_cnn_int8.mpy $(MODULES_PATH)/emlearn_cnn_fp32.mpy $(MODULES_PATH)/emlearn_linreg.mpy $(MODULES_PATH)/emlearn_extratrees.mpy From bbc604a8236d025be947a4e14c6b5485cd39b69d Mon Sep 17 00:00:00 2001 From: Jon Nordby Date: Wed, 9 Jul 2025 01:06:00 +0200 Subject: [PATCH 5/6] extratrees: Add a missing test file --- tests/test_extratrees_xor.py | 239 +++++++++++++++++++++++++++++++++++ 1 file changed, 239 insertions(+) create mode 100644 tests/test_extratrees_xor.py diff --git a/tests/test_extratrees_xor.py b/tests/test_extratrees_xor.py new file mode 100644 index 0000000..a49456b --- /dev/null +++ b/tests/test_extratrees_xor.py @@ -0,0 +1,239 @@ +# Final comprehensive XOR test suite +import array +import emlearn_extratrees + +def test_xor_comprehensive(): + """Comprehensive XOR test with the fixed algorithm""" + print("=== Comprehensive XOR Test ===") + + # XOR training data - repeated for better training + base_pattern = [ + (0, 0, 0), # XOR: (0,0) -> 0 + (0, 100, 1), # XOR: (0,1) -> 1 + (100, 0, 1), # XOR: (1,0) -> 1 + (100, 100, 0), # XOR: (1,1) -> 0 + ] + + # Repeat pattern multiple times to give ensemble more training data + X_data = [] + y_data = [] + for _ in range(8): # 32 samples total + for x1, x2, y in base_pattern: + X_data.extend([x1, x2]) + y_data.append(y) + + X = array.array('h', X_data) + y = array.array('h', y_data) + + print(f"Training data: {len(y_data)} samples (8x XOR pattern)") + + # Test with ensemble of trees (now that individual trees 
work)
+    model = emlearn_extratrees.new(
+        2,    # n_features
+        2,    # n_classes
+        10,   # n_trees (ensemble)
+        8,    # max_depth
+        1,    # min_samples_leaf
+        10,   # n_thresholds
+        0.8,  # subsample_ratio (80% for diversity)
+        1.0,  # feature_subsample_ratio (use both features)
+        500,  # max_nodes
+        100,  # max_samples
+        42    # rng_seed
+    )
+
+    model.train(X, y)
+
+    print(f"Model: {model.get_n_trees()} trees, {model.get_n_nodes_used()} nodes total")
+
+    # Test core XOR patterns
+    test_cases = [
+        ([0, 0], 0),
+        ([0, 100], 1),
+        ([100, 0], 1),
+        ([100, 100], 0),
+    ]
+
+    print("\nCore XOR Results:")
+    correct = 0
+    probabilities = array.array('f', [0.0, 0.0])
+
+    for features, expected in test_cases:
+        test_features = array.array('h', features)
+        predicted = model.predict_proba(test_features, probabilities)
+        is_correct = predicted == expected
+        if is_correct:
+            correct += 1
+
+        confidence = max(probabilities[0], probabilities[1])
+        print(f"  {features} -> pred={predicted}, exp={expected}, conf={confidence:.2f} {'✓' if is_correct else '✗'}")
+
+    core_accuracy = 100.0 * correct / 4
+    print(f"Core XOR Accuracy: {core_accuracy:.0f}%")
+
+    # Test interpolation (intermediate values)
+    print("\nInterpolation Test:")
+    interpolation_cases = [
+        ([25, 25], "?"),  # Between (0,0) and (100,100), both class 0 - not scored
+        ([25, 75], "?"),  # Between (0,100) and (100,0), both class 1 - not scored
+        ([10, 90], 1),    # Closer to (0,100) -> should be 1
+        ([90, 10], 1),    # Closer to (100,0) -> should be 1
+        ([90, 90], 0),    # Closer to (100,100) -> should be 0
+        ([10, 10], 0),    # Closer to (0,0) -> should be 0
+    ]
+
+    for features, expected in interpolation_cases:
+        test_features = array.array('h', features)
+        predicted = model.predict_proba(test_features, probabilities)
+        confidence = max(probabilities[0], probabilities[1])
+
+        if expected == "?":
+            marker = "?"
+ else: + marker = "✓" if predicted == expected else "✗" + + print(f" {features} -> pred={predicted}, exp={expected}, conf={confidence:.2f} {marker}") + + return core_accuracy >= 100 + +def test_xor_robustness(): + """Test XOR robustness with different parameters""" + print("\n=== XOR Robustness Test ===") + + # XOR data + X_data = [0, 0, 0, 100, 100, 0, 100, 100] * 6 # 24 samples + y_data = [0, 1, 1, 0] * 6 + + X = array.array('h', X_data) + y = array.array('h', y_data) + + configs = [ + (5, 6, "5 trees, depth 6"), + (15, 10, "15 trees, depth 10"), + (20, 12, "20 trees, depth 12"), + ] + + results = [] + + for n_trees, max_depth, desc in configs: + print(f"\nTesting {desc}:") + + model = emlearn_extratrees.new(2, 2, n_trees, max_depth, 1, 8, 0.9, 1.0, 1000, 100, 123) + model.train(X, y) + + # Test all XOR cases + correct = 0 + probabilities = array.array('f', [0.0, 0.0]) + test_cases = [([0, 0], 0), ([0, 100], 1), ([100, 0], 1), ([100, 100], 0)] + + for features, expected in test_cases: + test_features = array.array('h', features) + predicted = model.predict_proba(test_features, probabilities) + if predicted == expected: + correct += 1 + + accuracy = 100.0 * correct / 4 + results.append(accuracy) + print(f" Accuracy: {accuracy:.0f}% ({correct}/4 correct)") + + avg_accuracy = sum(results) / len(results) + print(f"\nAverage accuracy across configs: {avg_accuracy:.0f}%") + + return avg_accuracy >= 75 + +def test_xor_different_values(): + """Test XOR with different value ranges""" + print("\n=== XOR with Different Value Ranges ===") + + # Test with different value ranges to ensure generalization + test_ranges = [ + ([0, 1], "Binary"), + ([0, 10], "0-10"), + ([0, 1000], "0-1000"), + ([-50, 50], "-50 to 50"), + ] + + results = [] + + for value_range, desc in test_ranges: + print(f"\nTesting {desc} range:") + + low, high = value_range + X_data = [ + low, low, # (low,low) -> 0 + low, high, # (low,high) -> 1 + high, low, # (high,low) -> 1 + high, high, # (high,high) -> 0 + ] * 8 # 32 samples + y_data = [0, 1, 1, 0] * 8 + + X = array.array('h', X_data) + y = array.array('h', y_data) + + model = emlearn_extratrees.new(2, 2, 12, 10, 1, 10, 0.8, 1.0, 800, 100, 456) + model.train(X, y) + + # Test + test_cases = [ + ([low, low], 0), + ([low, high], 1), + ([high, low], 1), + ([high, high], 0), + ] + + correct = 0 + probabilities = array.array('f', [0.0, 0.0]) + + for features, expected in test_cases: + test_features = array.array('h', features) + predicted = model.predict_proba(test_features, probabilities) + if predicted == expected: + correct += 1 + + accuracy = 100.0 * correct / 4 + results.append(accuracy) + print(f" Accuracy: {accuracy:.0f}%") + + avg_accuracy = sum(results) / len(results) + print(f"\nAverage across value ranges: {avg_accuracy:.0f}%") + + return avg_accuracy >= 75 + +if __name__ == "__main__": + print("🔥 FIXED XOR TEST SUITE 🔥") + print("=" * 60) + + try: + # Test 1: Comprehensive XOR + success1 = test_xor_comprehensive() + + if success1: + print("\n✅ COMPREHENSIVE XOR TEST PASSED!") + + # Test 2: Robustness + success2 = test_xor_robustness() + + # Test 3: Different value ranges + success3 = test_xor_different_values() + + if success2: + print("\n✅ ROBUSTNESS TEST PASSED!") + if success3: + print("\n✅ VALUE RANGE TEST PASSED!") + + if success1 and success2 and success3: + print("\n🎉🎉🎉 ALL XOR TESTS PASSED! 
🎉🎉🎉") + print("Your Extra Trees implementation is WORKING PERFECTLY!") + print("The algorithm can now learn complex non-linear patterns like XOR!") + else: + print("\n🔥 Core XOR works! Some edge cases may need fine-tuning.") + + else: + print("\n❌ Something is still wrong with the core algorithm") + + except Exception as e: + print(f"❌ Error: {e}") + import sys + sys.print_exception(e) + + print("\n" + "="*60) From b651831fbf8b44bc7cc7a0705cfa42fd4f8c2724 Mon Sep 17 00:00:00 2001 From: Jon Nordby Date: Thu, 10 Jul 2025 00:44:32 +0200 Subject: [PATCH 6/6] Use dedicated directory for dataset generation script --- prepare_california.py => tests/datasets/prepare_california.py | 0 prepare_cancer.py => tests/datasets/prepare_cancer.py | 0 prepare_wine.py => tests/datasets/prepare_wine.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename prepare_california.py => tests/datasets/prepare_california.py (100%) rename prepare_cancer.py => tests/datasets/prepare_cancer.py (100%) rename prepare_wine.py => tests/datasets/prepare_wine.py (100%) diff --git a/prepare_california.py b/tests/datasets/prepare_california.py similarity index 100% rename from prepare_california.py rename to tests/datasets/prepare_california.py diff --git a/prepare_cancer.py b/tests/datasets/prepare_cancer.py similarity index 100% rename from prepare_cancer.py rename to tests/datasets/prepare_cancer.py diff --git a/prepare_wine.py b/tests/datasets/prepare_wine.py similarity index 100% rename from prepare_wine.py rename to tests/datasets/prepare_wine.py