diff --git a/adaptive_boxes.cu b/adaptive_boxes.cu
index cc97ad3..d0715b6 100644
--- a/adaptive_boxes.cu
+++ b/adaptive_boxes.cu
@@ -6,6 +6,7 @@
 #include <stdio.h>
 #include <stdlib.h>
 #include <math.h>
+#include <algorithm>
 //STL
 #include <vector>
 // cuda call
@@ -18,6 +19,9 @@
 // csv
 #include "./include/csv_tools.h"
 #include "./include/io_tools.h"
+// partition system
+#include "./include/partition_graph.h"
+#include "./include/partition_kernels.h"
 
 int main(int argc, char *argv[]){
 
@@ -92,6 +96,53 @@ int main(int argc, char *argv[]){
 	setup_kernel<<<grid, block>>>(devStates);
 	cudaDeviceSynchronize();
 
+	// Partition system initialization
+	const int partition_size = 32;          // Default 32x32 tiles
+	const float density_threshold = 0.1f;   // 10% filled-pixel threshold
+
+	// Adaptive partitioning threshold - only use partitions for large matrices;
+	// this eliminates the overhead for small datasets while preserving the gains for large ones.
+	bool use_partitions = (m * n > 500000);  // ~707x707 threshold
+
+	partition_t *partitions_d = nullptr;
+	int *adjacency_matrix_d = nullptr;
+	int partition_count = 0;
+
+	// Partition kernel grid/block dimensions (declared here for use in the main loop)
+	dim3 partition_grid, partition_block, connectivity_grid, connectivity_block;
+
+	if (use_partitions) {
+		partition_count = calculate_partition_count(m, n, partition_size);
+		printf("Initializing partition system: %d partitions of size %dx%d\n", partition_count, partition_size, partition_size);
+
+		// Allocate partition device memory
+		CC(cudaMalloc((void**)&partitions_d, partition_count * sizeof(partition_t)));
+		CC(cudaMalloc((void**)&adjacency_matrix_d, partition_count * partition_count * sizeof(int)));
+
+		// Initialize partitions with zero values
+		CC(cudaMemset(partitions_d, 0, partition_count * sizeof(partition_t)));
+		CC(cudaMemset(adjacency_matrix_d, 0, partition_count * partition_count * sizeof(int)));
+
+		// Configure partition kernel grid/block dimensions
+		partition_grid = dim3(partition_count, 1, 1);
+		partition_block = dim3(min(256, partition_size * partition_size), 1, 1);  // Max 256 threads per block
+
+		connectivity_grid = dim3((partition_count + 255) / 256, 1, 1);
+		connectivity_block = dim3(256, 1, 1);
+
+		// Calculate initial density and connectivity
+		compute_partition_density<<<partition_grid, partition_block>>>(data_d, m, n, partitions_d, partition_count, partition_size);
+		cudaDeviceSynchronize();
+
+		build_connectivity_graph<<<connectivity_grid, connectivity_block>>>(partitions_d, partition_count, adjacency_matrix_d, density_threshold, partition_size, n);
+		cudaDeviceSynchronize();
+
+		update_partition_priorities<<<connectivity_grid, connectivity_block>>>(partitions_d, partition_count, adjacency_matrix_d);
+		cudaDeviceSynchronize();
+	} else {
+		printf("Using original random exploration (matrix size %ldx%ld below threshold)\n", m, n);
+	}
+
 	// Loop
 	printf("Working...\n");
 	rectangle_t rec;
@@ -109,7 +160,7 @@ int main(int argc, char *argv[]){
 	int x1,x2,y1,y2;
 	for (int step=0; step<max_steps; step++){
-		find_largest_rectangle<<<grid, block>>>(devStates,m,n,data_d,out_d, areas_d);
+		find_largest_rectangle<<<grid, block>>>(devStates,m,n,data_d,out_d, areas_d, partitions_d, partition_count);
 		cudaDeviceSynchronize();
 
 		thrust::device_vector<int>::iterator iter = thrust::max_element(t_areas_d.begin(), t_areas_d.end());
@@ -148,9 +199,21 @@ int main(int argc, char *argv[]){
 			recs.push_back(rec);
 		}
 
+		last_sum = sum;
+
+		// Update partitions every 10th rectangle removal (only if using partitions)
+		if (use_partitions && step % 10 == 0) {
+			// Update affected partitions after rectangle removal
+			update_affected_partitions<<<partition_grid, partition_block>>>(x1, x2, y1, y2, partitions_d, partition_count, data_d, m, n, partition_size);
+			cudaDeviceSynchronize();
+
+			// Recompute priorities for updated partitions
+			update_partition_priorities<<<connectivity_grid, connectivity_block>>>(partitions_d, partition_count, adjacency_matrix_d);
+			cudaDeviceSynchronize();
+		}
+
 		/*printf("sum = %d\n", sum); */
-		last_sum = sum;
 		if(sum<=0){
 			break;
 		}
@@ -177,6 +240,8 @@ int main(int argc, char *argv[]){
 
 	// Free memory
 	cudaFree(devStates);
+	cudaFree(partitions_d);
+	cudaFree(adjacency_matrix_d);
 	/*delete data;*/
 
 	return 0;
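To make the sizing above concrete, here is a small host-only sketch of the same math (standalone and illustrative; only the 32-pixel tile size and the 500000-cell threshold come from the patch):

    #include <cstdio>

    // Mirrors calculate_partition_count() from include/partition_graph.h.
    static int partition_count_for(long m, long n, int size) {
        long tiles_y = (m + size - 1) / size;   // ceil(m / size)
        long tiles_x = (n + size - 1) / size;   // ceil(n / size)
        return (int)(tiles_x * tiles_y);
    }

    int main() {
        long m = 1000, n = 1000;
        bool use_partitions = (m * n > 500000);     // 1,000,000 > 500,000 -> true
        int count = partition_count_for(m, n, 32);  // ceil(1000/32)^2 = 32 * 32 = 1024
        std::printf("use_partitions=%d partitions=%d\n", (int)use_partitions, count);
        // Adjacency matrix cost at this size: 1024 * 1024 * sizeof(int) = 4 MiB.
        return 0;
    }

Right at the 500000-cell cutoff the tile grid is roughly 23x23, so the one-off setup and the every-10th-step maintenance kernels stay cheap next to the rectangle search itself.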
diff --git a/include/partition_graph.h b/include/partition_graph.h
new file mode 100644
index 0000000..4fc148c
--- /dev/null
+++ b/include/partition_graph.h
@@ -0,0 +1,107 @@
+#ifndef PARTITION_GRAPH_H
+#define PARTITION_GRAPH_H
+
+#include <cuda_runtime.h>
+
+/**
+ * Partition structure representing a spatial tile in the 2D matrix.
+ * Used for spatial partitioning to guide rectangle exploration.
+ */
+struct partition_t {
+    int x_start, x_end;   // Tile boundaries (inclusive)
+    int y_start, y_end;
+    float density;        // Ratio of filled pixels (0.0-1.0)
+    int connectivity;     // Number of connected neighbors
+    float priority;       // Search priority score
+};
+
+/**
+ * Partition graph structure for spatial decomposition.
+ * Manages partitions and their connectivity relationships.
+ */
+struct partition_graph_t {
+    partition_t* partitions;   // Device array of partitions
+    int partition_count;       // Total number of partitions
+    int* adjacency_matrix;     // Dense connectivity matrix (partition_count x partition_count)
+    int* priority_queue;       // Partition indices sorted by priority
+    int partition_size;        // Size of each partition (e.g., 32x32)
+};
+
+/**
+ * Helper functions for partition management
+ */
+
+/**
+ * Calculate the number of partitions needed for the given matrix dimensions
+ *
+ * Args:
+ *   m: Matrix height
+ *   n: Matrix width
+ *   partition_size: Size of each partition tile
+ *
+ * Returns:
+ *   Total number of partitions needed
+ */
+__host__ __device__ inline int calculate_partition_count(int m, int n, int partition_size) {
+    int partitions_y = (m + partition_size - 1) / partition_size;  // Ceiling division
+    int partitions_x = (n + partition_size - 1) / partition_size;
+    return partitions_x * partitions_y;
+}
+
+/**
+ * Get the partition index for given matrix coordinates
+ *
+ * Args:
+ *   row: Matrix row
+ *   col: Matrix column
+ *   n: Matrix width
+ *   partition_size: Size of each partition tile
+ *
+ * Returns:
+ *   Partition index
+ */
+__host__ __device__ inline int get_partition_index(int row, int col, int n, int partition_size) {
+    int partition_row = row / partition_size;
+    int partition_col = col / partition_size;
+    int partitions_x = (n + partition_size - 1) / partition_size;
+    return partition_row * partitions_x + partition_col;
+}
+
+/**
+ * Get a priority-guided partition for this thread (avoids clustering).
+ * Each thread gets a different high-priority partition based on its thread ID.
+ *
+ * Args:
+ *   partitions: Array of partitions
+ *   partition_count: Number of partitions
+ *   thread_id: Unique thread identifier for distribution
+ *
+ * Returns:
+ *   Index of a high-priority partition for this thread
+ */
+__device__ inline int get_priority_guided_partition(partition_t* partitions, int partition_count, int thread_id) {
+    if (partitions == NULL || partition_count <= 0) return 0;
+
+    // Fast sampling approach: avoid an expensive linear search.
+    // Sample a small subset of partitions and pick the best among them.
+    const int sample_size = min(8, partition_count);  // Sample at most 8 partitions
+
+    int best_idx = 0;
+    float best_priority = -1.0f;
+
+    for (int i = 0; i < sample_size; i++) {
+        // Use thread_id and the iteration count to distribute sampling across the partition space
+        int sample_idx = (thread_id + i * 1337) % partition_count;  // Pseudo-random distribution
+
+        float priority = partitions[sample_idx].priority;
+        if (priority > best_priority) {
+            best_priority = priority;
+            best_idx = sample_idx;
+        }
+    }
+
+    return best_idx;
+}
+
+#endif // PARTITION_GRAPH_H
\ No newline at end of file
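Both index helpers are pure functions, so they can be spot-checked on the host (compiled with nvcc so the __host__ __device__ qualifiers resolve). A hand-checked sketch, with the include path assumed relative to the repository root:

    #include <cassert>
    #include "include/partition_graph.h"

    int main() {
        const int m = 100, n = 100, size = 32;  // ceil(100/32) = 4 -> a 4x4 tile grid
        assert(calculate_partition_count(m, n, size) == 16);

        // Pixel (row 40, col 70) lies in tile row 40/32 = 1, tile col 70/32 = 2,
        // giving row-major index 1 * 4 + 2 = 6.
        assert(get_partition_index(40, 70, n, size) == 6);

        // The bottom-right pixel lands in the last, partially filled tile: 3 * 4 + 3 = 15.
        assert(get_partition_index(99, 99, n, size) == 15);
        return 0;
    }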
diff --git a/include/partition_kernels.h b/include/partition_kernels.h
new file mode 100644
index 0000000..102dc42
--- /dev/null
+++ b/include/partition_kernels.h
@@ -0,0 +1,287 @@
+#ifndef PARTITION_KERNELS_H
+#define PARTITION_KERNELS_H
+
+#include <cuda_runtime.h>
+#include "partition_graph.h"
+
+/**
+ * Compute the density of each partition.
+ * Each block processes one partition; its threads process the pixels within that partition.
+ *
+ * Args:
+ *   data_matrix: Input binary matrix (m x n)
+ *   m: Matrix height
+ *   n: Matrix width
+ *   partitions: Array of partitions to update
+ *   partition_count: Number of partitions
+ *   partition_size: Size of each partition tile
+ */
+__global__ void compute_partition_density(int* data_matrix, long m, long n,
+                                          partition_t* partitions, int partition_count, int partition_size) {
+    int partition_id = blockIdx.x;
+    if (partition_id >= partition_count || partitions == NULL || data_matrix == NULL) return;
+
+    // Get partition boundaries
+    partition_t* partition = &partitions[partition_id];
+
+    // Calculate partition coordinates from the ID
+    int partitions_x = (n + partition_size - 1) / partition_size;
+    int partition_row = partition_id / partitions_x;
+    int partition_col = partition_id % partitions_x;
+
+    int x_start = partition_col * partition_size;
+    int y_start = partition_row * partition_size;
+    int x_end = min(x_start + partition_size - 1, (int)n - 1);
+    int y_end = min(y_start + partition_size - 1, (int)m - 1);
+
+    // Update partition boundaries
+    partition->x_start = x_start;
+    partition->x_end = x_end;
+    partition->y_start = y_start;
+    partition->y_end = y_end;
+
+    // Count filled pixels using the block's threads
+    __shared__ int block_sum;
+    if (threadIdx.x == 0) {
+        block_sum = 0;
+    }
+    __syncthreads();
+
+    // Each thread processes multiple pixels
+    int thread_count = 0;
+    int pixels_per_thread = ((y_end - y_start + 1) * (x_end - x_start + 1) + blockDim.x - 1) / blockDim.x;
+
+    for (int i = 0; i < pixels_per_thread; i++) {
+        int pixel_idx = threadIdx.x * pixels_per_thread + i;
+        int total_pixels = (y_end - y_start + 1) * (x_end - x_start + 1);
+
+        if (pixel_idx < total_pixels) {
+            int local_y = pixel_idx / (x_end - x_start + 1);
+            int local_x = pixel_idx % (x_end - x_start + 1);
+            int global_y = y_start + local_y;
+            int global_x = x_start + local_x;
+
+            if (global_y < m && global_x < n) {
+                if (data_matrix[global_y * n + global_x] == 1) {
+                    thread_count++;
+                }
+            }
+        }
+    }
+
+    // Reduce the per-thread counts into the block sum
+    atomicAdd(&block_sum, thread_count);
+    __syncthreads();
+
+    // Calculate density
+    if (threadIdx.x == 0) {
+        int total_pixels = (y_end - y_start + 1) * (x_end - x_start + 1);
+        partition->density = total_pixels > 0 ? (float)block_sum / (float)total_pixels : 0.0f;
+    }
+}
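+
+// Worked example of the chunking above: a full 32x32 tile scanned by a
+// 256-thread block gives total_pixels = 1024 and pixels_per_thread = 4,
+// so thread t visits the contiguous pixel indices 4t..4t+3; edge tiles
+// shrink automatically via the min() clamps on x_end and y_end.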
+
+/**
+ * Build the connectivity graph between adjacent partitions.
+ * Each thread processes one partition and checks its neighbors.
+ *
+ * Args:
+ *   partitions: Array of partitions
+ *   partition_count: Number of partitions
+ *   adjacency_matrix: Output connectivity matrix (partition_count x partition_count)
+ *   density_threshold: Minimum density for connectivity
+ *   partition_size: Size of each partition
+ *   matrix_width: Width of the original matrix
+ */
+__global__ void build_connectivity_graph(partition_t* partitions, int partition_count,
+                                         int* adjacency_matrix, float density_threshold,
+                                         int partition_size, int matrix_width) {
+    int partition_id = threadIdx.x + blockIdx.x * blockDim.x;
+    if (partition_id >= partition_count || partitions == NULL || adjacency_matrix == NULL) return;
+
+    partition_t* current = &partitions[partition_id];
+
+    // Calculate grid dimensions
+    int partitions_x = (matrix_width + partition_size - 1) / partition_size;
+    int partition_row = partition_id / partitions_x;
+    int partition_col = partition_id % partitions_x;
+
+    int connectivity_count = 0;
+
+    // Check the 8 neighbors (including diagonals)
+    for (int dy = -1; dy <= 1; dy++) {
+        for (int dx = -1; dx <= 1; dx++) {
+            if (dx == 0 && dy == 0) continue;  // Skip self
+
+            int neighbor_row = partition_row + dy;
+            int neighbor_col = partition_col + dx;
+
+            // Check bounds
+            if (neighbor_row >= 0 && neighbor_col >= 0 &&
+                neighbor_col < partitions_x &&
+                neighbor_row < (partition_count / partitions_x + (partition_count % partitions_x ? 1 : 0))) {
+
+                int neighbor_id = neighbor_row * partitions_x + neighbor_col;
+                if (neighbor_id < partition_count) {
+                    partition_t* neighbor = &partitions[neighbor_id];
+
+                    // Connect only if both partitions meet the density threshold
+                    if (current->density > density_threshold && neighbor->density > density_threshold) {
+                        adjacency_matrix[partition_id * partition_count + neighbor_id] = 1;
+                        connectivity_count++;
+                    } else {
+                        adjacency_matrix[partition_id * partition_count + neighbor_id] = 0;
+                    }
+                }
+            }
+        }
+    }
+
+    // Update the connectivity count
+    current->connectivity = connectivity_count;
+}
+
+/**
+ * Update partition priorities based on density and connectivity.
+ * Each thread processes one partition.
+ *
+ * Args:
+ *   partitions: Array of partitions
+ *   partition_count: Number of partitions
+ *   adjacency_matrix: Connectivity matrix
+ */
+__global__ void update_partition_priorities(partition_t* partitions, int partition_count,
+                                            int* adjacency_matrix) {
+    int partition_id = threadIdx.x + blockIdx.x * blockDim.x;
+    if (partition_id >= partition_count || partitions == NULL) return;
+
+    partition_t* partition = &partitions[partition_id];
+
+    // Validate partition boundaries before any calculations
+    if (partition->x_end < partition->x_start || partition->y_end < partition->y_start) {
+        partition->priority = 0.0f;  // Invalid partition gets zero priority
+        return;
+    }
+
+    // Calculate area potential (larger partitions get a slight boost)
+    int partition_width = partition->x_end - partition->x_start + 1;
+    int partition_height = partition->y_end - partition->y_start + 1;
+
+    // Safety check for a reasonable partition size
+    if (partition_width <= 0 || partition_height <= 0 ||
+        partition_width > 128 || partition_height > 128) {
+        partition->priority = 0.0f;  // Corrupted partition gets zero priority
+        return;
+    }
+
+    float area_potential = sqrtf((float)(partition_width * partition_height)) / 32.0f;  // Normalize to ~32x32
+
+    // Validate density before use
+    float density = partition->density;
+    if (density < 0.0f) density = 0.0f;
+    if (density > 1.0f) density = 1.0f;
+
+    // Validate connectivity before use
+    int connectivity = partition->connectivity;
+    if (connectivity < 0) connectivity = 0;
+    if (connectivity > 8) connectivity = 8;  // Max 8 neighbors
+
+    // Priority formula: (density + 0.01) x (connectivity + 1) x area_potential;
+    // the small offsets keep priorities nonzero for empty or isolated tiles
+    partition->priority = (density + 0.01f) * (connectivity + 1.0f) * area_potential;
+
+    // Clamp priority to a reasonable range
+    if (partition->priority < 0.0f) partition->priority = 0.01f;
+    if (partition->priority > 1000.0f) partition->priority = 1000.0f;  // Cap at a reasonable max
+}
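+
+// Worked example: a full 32x32 tile has area_potential = sqrtf(1024)/32 = 1.0;
+// at density 0.5 with all 8 neighbors connected it scores
+// (0.5 + 0.01) * (8 + 1) * 1.0 = 4.59, while the same tile with no
+// connected neighbors scores only 0.51 * 1.0 * 1.0 = 0.51.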
+
+/**
+ * Update the affected partitions after a rectangle removal.
+ * Only recalculates density for partitions that intersect the removed rectangle.
+ *
+ * Args:
+ *   x1, x2, y1, y2: Boundaries of the rectangle that was removed
+ *   partitions: Array of partitions
+ *   partition_count: Number of partitions
+ *   data_matrix: Updated matrix after the rectangle removal
+ *   m: Matrix height
+ *   n: Matrix width
+ *   partition_size: Size of each partition
+ */
+__global__ void update_affected_partitions(int x1, int x2, int y1, int y2,
+                                           partition_t* partitions, int partition_count,
+                                           int* data_matrix, long m, long n, int partition_size) {
+    int partition_id = blockIdx.x;
+    if (partition_id >= partition_count || partitions == NULL || data_matrix == NULL) return;
+
+    // Additional safety checks for the matrix dimensions
+    if (m <= 0 || n <= 0 || partition_size <= 0) return;
+    if (x1 < 0 || x2 >= n || y1 < 0 || y2 >= m || x2 < x1 || y2 < y1) return;
+
+    partition_t* partition = &partitions[partition_id];
+
+    // Validate partition boundaries before use
+    if (partition->x_start < 0 || partition->x_end >= n ||
+        partition->y_start < 0 || partition->y_end >= m ||
+        partition->x_end < partition->x_start || partition->y_end < partition->y_start) {
+        // Partition boundaries invalid, skip the update
+        return;
+    }
+
+    // Check whether the partition intersects the removed rectangle (with validated boundaries)
+    bool intersects = !(x2 < partition->x_start || x1 > partition->x_end ||
+                        y2 < partition->y_start || y1 > partition->y_end);
+
+    if (!intersects) return;
+
+    // Recalculate density for the affected partition with robust bounds checking
+    __shared__ int block_sum;
+    if (threadIdx.x == 0) {
+        block_sum = 0;
+    }
+    __syncthreads();
+
+    int thread_count = 0;
+    int partition_width = partition->x_end - partition->x_start + 1;
+    int partition_height = partition->y_end - partition->y_start + 1;
+    int total_pixels = partition_width * partition_height;
+
+    // Safety check for the partition size
+    if (total_pixels <= 0 || total_pixels > partition_size * partition_size * 4) {
+        return;  // Skip if the partition seems corrupted
+    }
+
+    int pixels_per_thread = (total_pixels + blockDim.x - 1) / blockDim.x;
+
+    for (int i = 0; i < pixels_per_thread; i++) {
+        int pixel_idx = threadIdx.x * pixels_per_thread + i;
+
+        if (pixel_idx < total_pixels) {
+            int local_y = pixel_idx / partition_width;
+            int local_x = pixel_idx % partition_width;
+            int global_y = partition->y_start + local_y;
+            int global_x = partition->x_start + local_x;
+
+            // Double-check bounds before the memory access
+            if (global_y >= 0 && global_y < m && global_x >= 0 && global_x < n) {
+                long matrix_idx = global_y * n + global_x;
+                if (matrix_idx >= 0 && matrix_idx < m * n) {  // Final bounds check
+                    if (data_matrix[matrix_idx] == 1) {
+                        thread_count++;
+                    }
+                }
+            }
+        }
+    }
+
+    atomicAdd(&block_sum, thread_count);
+    __syncthreads();
+
+    if (threadIdx.x == 0) {
+        partition->density = total_pixels > 0 ? (float)block_sum / (float)total_pixels : 0.0f;
+        // Clamp density to the valid range
+        if (partition->density < 0.0f) partition->density = 0.0f;
+        if (partition->density > 1.0f) partition->density = 1.0f;
+    }
+}
+
+#endif // PARTITION_KERNELS_H
\ No newline at end of file
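The intersection test inside update_affected_partitions is the standard axis-aligned overlap predicate for inclusive bounds; isolated for reference (a sketch, not part of the patch):

    // Inclusive [x1,x2] x [y1,y2] rectangles overlap iff neither lies entirely
    // to one side of the other on either axis.
    inline bool rects_overlap(int ax1, int ax2, int ay1, int ay2,
                              int bx1, int bx2, int by1, int by2) {
        return !(ax2 < bx1 || ax1 > bx2 || ay2 < by1 || ay1 > by2);
    }

Blocks whose tile fails this test return before the shared-memory rescan, so each removal recomputes density only for the few tiles the rectangle actually touches.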
diff --git a/include/rectangular_explorer_kernel.h b/include/rectangular_explorer_kernel.h
index 3c5f721..0b5d3bc 100644
--- a/include/rectangular_explorer_kernel.h
+++ b/include/rectangular_explorer_kernel.h
@@ -4,6 +4,8 @@
 #include "./getters.h"
 // random
 #include "./random_generator.h"
+// partition graph
+#include "./partition_graph.h"
 
 // getters
 using namespace ng;
@@ -25,7 +27,8 @@ using namespace ng;
 	Each block keeps the largest rectangle it could find; to get the overall largest
 	rectangle, each block's area is computed and stored in the areas variable.
 */
-__global__ void find_largest_rectangle(curandState *state, long m, long n, int *data_matrix, int *out, int *areas){
+__global__ void find_largest_rectangle(curandState *state, long m, long n, int *data_matrix, int *out, int *areas,
+                                       partition_t* partitions, int partition_count){
 
 	const int coords_m = 5;
 
@@ -47,7 +50,8 @@ __global__ void find_largest_rectangle(curandState *state, long m, long n, int *
 
 	int b_n = gridDim.x;
 
-	/* GET RANDOM POINT: the value of that random point in the matrix must be one(1)
+	/* GET PRIORITY-GUIDED POINT: select a point from a high-priority partition.
+	 * Each thread gets a different high-priority partition to avoid clustering.
 	*/
 	if(j==0){
 		areas[b_i*b_n + b_j] = 0;
@@ -57,21 +61,63 @@ __global__ void find_largest_rectangle(curandState *state, long m, long n, int *
 
 	unsigned int xx;
 	unsigned int yy;
-	for(int g=0; g<100; g++){
-		xx = curand(&localState);
-		yy = curand(&localState);
-		idx_i = abs((int)xx)%(m);
-		idx_j = abs((int)yy)%(n);
-		if (data_matrix[idx_i*n + idx_j]==1){
-			is_sleeping = false;
-			break;
-		}else{
-			is_sleeping = true;
+
+	// Priority-guided selection: each thread gets a different high-priority partition
+	bool found_in_partition = false;
+	if (partitions != NULL && partition_count > 0) {
+		int thread_id = b_i * gridDim.x + b_j;  // Unique thread identifier
+		int partition_id = get_priority_guided_partition(partitions, partition_count, thread_id);
+
+		// Bounds check on partition_id
+		if (partition_id >= 0 && partition_id < partition_count) {
+			partition_t current_partition = partitions[partition_id];
+
+			// Validate partition boundaries
+			if (current_partition.x_start >= 0 && current_partition.x_end < n &&
+			    current_partition.y_start >= 0 && current_partition.y_end < m &&
+			    current_partition.x_end >= current_partition.x_start &&
+			    current_partition.y_end >= current_partition.y_start) {
+
+				// Sample within the high-priority partition
+				for(int g=0; g<50; g++){
+					xx = curand(&localState);
+					yy = curand(&localState);
+
+					int partition_width = current_partition.x_end - current_partition.x_start + 1;
+					int partition_height = current_partition.y_end - current_partition.y_start + 1;
+
+					idx_j = current_partition.x_start + (abs((int)xx) % partition_width);
+					idx_i = current_partition.y_start + (abs((int)yy) % partition_height);
+
+					if (data_matrix[idx_i*n + idx_j]==1){
+						is_sleeping = false;
+						found_in_partition = true;
+						break;
+					}
+				}
+			}
+		}
+	}
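+
+	// Tuning note: 50 draws inside a tile of density d miss every time with
+	// probability (1-d)^50 (about 0.5% at d = 0.1), so the global fallback
+	// below is rarely taken once partition priorities are populated.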
+	// Fallback to full matrix search if partition guidance failed
+	if (!found_in_partition) {
+		for(int g=0; g<100; g++){
+			xx = curand(&localState);
+			yy = curand(&localState);
+			idx_i = abs((int)xx)%(m);
+			idx_j = abs((int)yy)%(n);
+			if (data_matrix[idx_i*n + idx_j]==1){
+				is_sleeping = false;
+				break;
+			}else{
+				is_sleeping = true;
+			}
 		}
 	}
+
 	state[id] = localState;
-	//printf("idx_i %d idx_j %d \n",idx_i,idx_j);
+	//printf("priority_guided: thread_id=%d, found_in_partition=%d, idx_i=%d, idx_j=%d\n", b_i*gridDim.x+b_j, found_in_partition, idx_i, idx_j);
 	}
 
 	__syncthreads();
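Since much of the new device code defends against corrupted partition state, a host-side invariant check right after initialization is a cheap regression test. A minimal sketch that reuses the main program's variables and its CC() error-checking macro (names as in adaptive_boxes.cu; placement is illustrative):

    // After the initial density/connectivity/priority kernels have completed:
    std::vector<partition_t> parts(partition_count);
    CC(cudaMemcpy(parts.data(), partitions_d,
                  partition_count * sizeof(partition_t), cudaMemcpyDeviceToHost));
    for (const partition_t &p : parts) {
        assert(p.x_end >= p.x_start && p.y_end >= p.y_start);  // valid tile bounds
        assert(p.density >= 0.0f && p.density <= 1.0f);        // density is a ratio
        assert(p.connectivity >= 0 && p.connectivity <= 8);    // 8-neighbor grid
    }

If any assert fires before the first rectangle is removed, the bug is in the initialization kernels rather than in the incremental updates.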