diff --git a/ddprof-lib/src/main/cpp/javaApi.cpp b/ddprof-lib/src/main/cpp/javaApi.cpp
index 7d45e4227..a2bd6b4a6 100644
--- a/ddprof-lib/src/main/cpp/javaApi.cpp
+++ b/ddprof-lib/src/main/cpp/javaApi.cpp
@@ -406,3 +406,28 @@ Java_com_datadoghq_profiler_JVMAccess_healthCheck0(JNIEnv *env,
                                                    jobject unused) {
   return true;
 }
+
+// Returns the native address of the 64-bit filter word holding the bit for
+// `tid`, so that ActiveBitmaps.java can toggle the bit from Java.
+extern "C" DLLEXPORT jlong JNICALL
+Java_com_datadoghq_profiler_ActiveBitmaps_bitmapAddressFor0(JNIEnv *env,
+                                                            jclass unused,
+                                                            jint tid) {
+  u64* bitmap = Profiler::instance()->threadFilter()->bitmapAddressFor((int)tid);
+  return (jlong)bitmap;
+}
+
+// Verification helper: reports whether the thread filter accepts `tid`.
+extern "C" DLLEXPORT jboolean JNICALL
+Java_com_datadoghq_profiler_ActiveBitmaps_isActive(JNIEnv *env,
+                                                   jclass unused,
+                                                   jint tid) {
+  return Profiler::instance()->threadFilter()->accept((int)tid) ? JNI_TRUE : JNI_FALSE;
+}
+
+// Returns the native address of the thread filter's size counter.
+extern "C" DLLEXPORT jlong JNICALL
+Java_com_datadoghq_profiler_ActiveBitmaps_getActiveCountAddr0(JNIEnv *env,
+                                                              jclass unused) {
+  return (jlong)Profiler::instance()->threadFilter()->addressOfSize();
+}
diff --git a/ddprof-lib/src/main/cpp/reverse_bits.h b/ddprof-lib/src/main/cpp/reverse_bits.h
new file mode 100644
index 000000000..81e46e19a
--- /dev/null
+++ b/ddprof-lib/src/main/cpp/reverse_bits.h
@@ -0,0 +1,28 @@
+//
+// Borrow the implementation from openjdk
+// https://github.com/openjdk/jdk/blob/master/src/hotspot/share/utilities/reverse_bits.hpp
+//
+
+#ifndef REVERSE_BITS_H
+#define REVERSE_BITS_H
+#include "arch_dd.h"
+#include <stdint.h>
+
+// Repeating bit patterns used as swap masks below (truncated to 32 bits).
+static constexpr u32 rep_5555 = static_cast<u32>(UINT64_C(0x5555555555555555));
+static constexpr u32 rep_3333 = static_cast<u32>(UINT64_C(0x3333333333333333));
+static constexpr u32 rep_0F0F = static_cast<u32>(UINT64_C(0x0F0F0F0F0F0F0F0F));
+
+// Reverses the bit order of a 16-bit value (bit 0 <-> bit 15, bit 1 <-> bit 14, ...).
+// Applying it twice yields the original value.
+inline u16 reverse16(u16 v) {
+  u32 x = static_cast<u32>(v);
+  // Reverse the bits inside each byte: swap neighbours, then pairs, then nibbles.
+  x = ((x & rep_5555) << 1) | ((x >> 1) & rep_5555);
+  x = ((x & rep_3333) << 2) | ((x >> 2) & rep_3333);
+  x = ((x & rep_0F0F) << 4) | ((x >> 4) & rep_0F0F);
+  // Swapping the two bytes completes the 16-bit reversal.
+  return __builtin_bswap16(static_cast<u16>(x));
+}
+
+#endif //REVERSE_BITS_H
diff --git 
a/ddprof-lib/src/main/cpp/threadFilter.cpp b/ddprof-lib/src/main/cpp/threadFilter.cpp
index 034aabf9b..25c1cd387 100644
--- a/ddprof-lib/src/main/cpp/threadFilter.cpp
+++ b/ddprof-lib/src/main/cpp/threadFilter.cpp
@@ -17,6 +17,8 @@
 #include "threadFilter.h"
 #include "counters.h"
 #include "os.h"
+#include "reverse_bits.h"
+#include <cassert>
 #include <stdlib.h>
 #include <string.h>
 
@@ -85,17 +87,28 @@ void ThreadFilter::clear() {
   _size = 0;
 }
 
-bool ThreadFilter::accept(int thread_id) {
-  u64 *b = bitmap(thread_id);
-  return b != NULL && (word(b, thread_id) & (1ULL << (thread_id & 0x3f)));
+// Maps a thread id to the slot it occupies inside its bitmap by reversing the
+// low 16 bits; the upper bits are left untouched. Applying the mapping twice
+// gives back the original id.
+int ThreadFilter::mapThreadId(int thread_id) {
+  // We want to map the thread_id inside the same bitmap
+  static_assert(BITMAP_SIZE >= (u16)0xffff, "Potential overflow");
+  u16 lower16 = (u16)(thread_id & 0xffff);
+  lower16 = reverse16(lower16);
+  int tid = (thread_id & ~0xffff) | lower16;
+  return tid;
 }
 
-void ThreadFilter::add(int thread_id) {
-  u64 *b = bitmap(thread_id);
+// Returns the bitmap covering `thread_id`, allocating and publishing it on
+// first use.
+u64* ThreadFilter::getBitmapFor(int thread_id) {
+  int index = static_cast<u32>(thread_id) / BITMAP_CAPACITY;
+  // Read the slot through bitmap(), as add() did before this change.
+  u64* b = bitmap(thread_id);
   if (b == NULL) {
     b = (u64 *)OS::safeAlloc(BITMAP_SIZE);
     u64 *oldb = __sync_val_compare_and_swap(
-        &_bitmap[(u32)thread_id / BITMAP_CAPACITY], NULL, b);
+        &_bitmap[index], NULL, b);
     if (oldb != NULL) {
       OS::safeFree(b, BITMAP_SIZE);
       b = oldb;
@@ -103,7 +116,27 @@ void ThreadFilter::add(int thread_id) {
       trackPage();
     }
   }
+  return b;
+}
 
+// Returns the address of the 64-bit word that holds the filter bit of
+// `thread_id`, allocating the bitmap if needed.
+u64* ThreadFilter::bitmapAddressFor(int thread_id) {
+  u64* b = getBitmapFor(thread_id);
+  thread_id = mapThreadId(thread_id);
+  return wordAddress(b, thread_id);
+}
+
+bool ThreadFilter::accept(int thread_id) {
+  u64 *b = bitmap(thread_id);
+  thread_id = mapThreadId(thread_id);
+  return b != NULL && (word(b, thread_id) & (1ULL << (thread_id & 0x3f)));
+}
+
+void ThreadFilter::add(int thread_id) {
+  u64 *b = getBitmapFor(thread_id);
+  assert(b != NULL);
+  thread_id = mapThreadId(thread_id);
   u64 bit = 1ULL << (thread_id & 0x3f);
   if (!(__sync_fetch_and_or(&word(b, thread_id), bit) & bit)) {
     atomicInc(_size);
@@ -111,6 +144,7 @@ void ThreadFilter::add(int thread_id) {
 }
 
 void ThreadFilter::remove(int thread_id) {
   u64 *b = bitmap(thread_id);
+  thread_id = mapThreadId(thread_id);
   if (b == NULL) {
     return;
@@ -132,7 +166,10 @@ void ThreadFilter::collect(std::vector<int> &v) {
       // order here
       u64 word = __atomic_load_n(&b[j], __ATOMIC_ACQUIRE);
       while (word != 0) {
-        v.push_back(start_id + j * 64 + __builtin_ctzl(word));
+        int tid = start_id + j * 64 + __builtin_ctzll(word);
+        // mapThreadId() is its own inverse: turn the slot back into the thread id
+        tid = mapThreadId(tid);
+        v.push_back(tid);
         word &= (word - 1);
       }
     }
diff --git a/ddprof-lib/src/main/cpp/threadFilter.h b/ddprof-lib/src/main/cpp/threadFilter.h
index cec7e7048..a5654dbe4 100644
--- a/ddprof-lib/src/main/cpp/threadFilter.h
+++ b/ddprof-lib/src/main/cpp/threadFilter.h
@@ -45,11 +45,22 @@ class ThreadFilter {
                            __ATOMIC_ACQUIRE);
   }
 
+  // Bit-reverses the low 16 bits of thread_id; applying it twice is a no-op.
+  static int mapThreadId(int thread_id);
+
   u64 &word(u64 *bitmap, int thread_id) {
     // todo: add thread safe APIs
     return bitmap[((u32)thread_id % BITMAP_CAPACITY) >> 6];
   }
 
+  // Same word as word(), returned by address.
+  u64* wordAddress(u64 *bitmap, int thread_id) {
+    return &bitmap[((u32)thread_id % BITMAP_CAPACITY) >> 6];
+  }
+
+  // Returns the bitmap covering thread_id, allocating it on first use.
+  u64* getBitmapFor(int thread_id);
+
 public:
   ThreadFilter();
   ThreadFilter(ThreadFilter &threadFilter) = delete;
@@ -58,6 +69,8 @@ class ThreadFilter {
 
   bool enabled() { return _enabled; }
   int size() { return _size; }
+  // Address of the size counter; ActiveBitmaps.java updates it through Unsafe.
+  const volatile int* addressOfSize() const { return &_size; }
 
   void init(const char *filter);
   void clear();
@@ -65,6 +78,8 @@ class ThreadFilter {
   bool accept(int thread_id);
   void add(int thread_id);
   void remove(int thread_id);
+  // Address of the 64-bit word holding thread_id's filter bit.
+  u64* bitmapAddressFor(int thread_id);
 
   void collect(std::vector<int> &v);
 };
diff --git a/ddprof-lib/src/main/java/com/datadoghq/profiler/ActiveBitmaps.java b/ddprof-lib/src/main/java/com/datadoghq/profiler/ActiveBitmaps.java
new file mode 100644
index 000000000..167f01fb6
--- /dev/null
+++ b/ddprof-lib/src/main/java/com/datadoghq/profiler/ActiveBitmaps.java
@@ -0,0 +1,92 @@
+package com.datadoghq.profiler;
+
+import sun.misc.Unsafe;
+import java.lang.reflect.Field;
+
+/**
+ * Java-side access to the native thread filter: flips the calling thread's
+ * filter bit directly in native memory through {@link Unsafe}.
+ */
+class ActiveBitmaps {
+    private static final Unsafe UNSAFE;
+    static {
+        Unsafe unsafe = null;
+        try {
+            Field f = Unsafe.class.getDeclaredField("theUnsafe");
+            f.setAccessible(true);
+            unsafe = (Unsafe) f.get(null);
+        } catch (Exception ignore) {
+            // Best effort: setActive() reports a missing Unsafe explicitly.
+        }
+        UNSAFE = unsafe;
+    }
+
+    // Native address of the filter's size counter; 0 until initialize() runs.
+    private static long activeCountAddr = 0;
+
+    // Per-thread cache of the native word address; -1 means not resolved yet.
+    private static final ThreadLocal<Long> WORD_ADDRESS = new ThreadLocal<Long>() {
+        @Override protected Long initialValue() {
+            return -1L;
+        }
+    };
+
+    public static void initialize() {
+        activeCountAddr = getActiveCountAddr0();
+    }
+
+    // Native address of the 64-bit filter word that holds tid's bit
+    static native long bitmapAddressFor0(int tid);
+
+    /**
+     * Returns the single-bit mask of {@code tid} inside its 64-bit filter word.
+     * Must agree with the native side, which uses the low 6 bits of the thread
+     * id after bit-reversing its low 16 bits.
+     */
+    static long getBitmask(int tid) {
+        // Integer.reverse moves bits 0..15 to bits 31..16, so the unsigned shift
+        // leaves the 16-bit reversal; tid bit 15 ends up as bit 0 of the index.
+        int bits = Integer.reverse(tid) >>> 16;
+        return 1L << (bits & 0x3f);
+    }
+
+    /**
+     * Adds ({@code active == true}) or removes the thread from the native filter.
+     * The word address is cached per calling thread, so {@code tid} should be the
+     * caller's own thread id (the visible callers pass {@code TID.get()}).
+     *
+     * @throws IllegalStateException if Unsafe is unavailable or initialize() was not called
+     */
+    static void setActive(int tid, boolean active) {
+        if (UNSAFE == null) {
+            throw new IllegalStateException("sun.misc.Unsafe is not available");
+        }
+        if (activeCountAddr == 0) {
+            throw new IllegalStateException("ActiveBitmaps.initialize() was not called");
+        }
+        long addr = WORD_ADDRESS.get();
+        if (addr == -1) {
+            addr = bitmapAddressFor0(tid);
+            WORD_ADDRESS.set(addr);
+        }
+        long bitmask = getBitmask(tid);
+        long value;
+        long newVal;
+        do {
+            value = UNSAFE.getLongVolatile(null, addr);
+            newVal = active ? (value | bitmask) : (value & ~bitmask);
+            if (newVal == value) {
+                // Already in the requested state: leave the counter alone, as the
+                // native add() only counts a thread when its bit actually flips.
+                return;
+            }
+        } while (!UNSAFE.compareAndSwapLong(null, addr, value, newVal));
+        UNSAFE.getAndAddInt(null, activeCountAddr, active ? 1 : -1);
+        assert isActive(tid) == active;
+    }
+
+    // For verification
+    static native boolean isActive(int tid);
+
+    static native long getActiveCountAddr0();
+}
diff --git a/ddprof-lib/src/main/java/com/datadoghq/profiler/JavaProfiler.java b/ddprof-lib/src/main/java/com/datadoghq/profiler/JavaProfiler.java
index 4436273c6..8ca4bd59e 100644
--- a/ddprof-lib/src/main/java/com/datadoghq/profiler/JavaProfiler.java
+++ b/ddprof-lib/src/main/java/com/datadoghq/profiler/JavaProfiler.java
@@ -108,6 +108,7 @@ public static synchronized JavaProfiler getInstance(String libLocation, String s
             throw new IOException("Failed to load Datadog Java profiler library", result.error);
         }
         init0();
+        ActiveBitmaps.initialize();
 
         profiler.initializeContextStorage();
         instance = profiler;
@@ -208,7 +209,7 @@ public boolean recordTraceRoot(long rootSpanId, String endpoint, int sizeLimit)
      * 'filter' option must be enabled to use this method.
      */
     public void addThread() {
-        filterThread0(true);
+        ActiveBitmaps.setActive(TID.get(), true);
     }
 
     /**
@@ -216,7 +217,7 @@ public void addThread() {
      * 'filter' option must be enabled to use this method.
      */
     public void removeThread() {
-        filterThread0(false);
+        ActiveBitmaps.setActive(TID.get(), false);
     }
 
 
diff --git a/ddprof-stresstest/src/jmh/java/com/datadoghq/profiler/stresstest/scenarios/ThreadFilterBenchmark.java b/ddprof-stresstest/src/jmh/java/com/datadoghq/profiler/stresstest/scenarios/ThreadFilterBenchmark.java
new file mode 100644
index 000000000..682176af7
--- /dev/null
+++ b/ddprof-stresstest/src/jmh/java/com/datadoghq/profiler/stresstest/scenarios/ThreadFilterBenchmark.java
@@ -0,0 +1,252 @@
+package com.datadoghq.profiler.stresstest.scenarios;
+
+import com.datadoghq.profiler.JavaProfiler;
+import com.datadoghq.profiler.stresstest.Configuration;
+import org.openjdk.jmh.annotations.*;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.concurrent.*;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicIntegerArray;
+
+/**
+ * JMH stress scenario: NUM_THREADS workers repeatedly call addThread() and
+ * removeThread() around small array updates while the profiler runs with filter=1.
+ */
+@State(Scope.Benchmark)
+public class ThreadFilterBenchmark extends Configuration {
+
+    private static final int NUM_THREADS = 4;
+    private ExecutorService executorService;
+    private JavaProfiler profiler;
+    private AtomicBoolean running;
+    private CountDownLatch startLatch;
+    private CountDownLatch stopLatch;
+    private AtomicLong operationCount;
+    private long startTime;
+    private long stopTime;
+    private PrintWriter logWriter;
+    private static final int ARRAY_SIZE = 1024; // Larger array to stress memory
+    private static final int[] sharedArray = new int[ARRAY_SIZE];
+    private static final AtomicIntegerArray atomicArray = new AtomicIntegerArray(ARRAY_SIZE);
+    private static final int CACHE_LINE_SIZE = 64; // Typical cache line size
+    private static final int STRIDE = CACHE_LINE_SIZE / Integer.BYTES; // Elements per cache line
+    private boolean useThreadFilters = true; // Flag to control the use of thread filters
+    private AtomicLong addThreadCount = new AtomicLong(0);
+    private AtomicLong removeThreadCount = new AtomicLong(0);
+
+    @Setup(Level.Trial)
+    public void setup() throws IOException {
+        System.out.println("Setting up benchmark...");
+        System.out.println("Creating thread pool with " + NUM_THREADS + " threads");
+        executorService = Executors.newFixedThreadPool(NUM_THREADS);
+        System.out.println("Getting profiler instance");
+        profiler = JavaProfiler.getInstance();
+
+        // Stop the profiler if it's already running
+        try {
+            profiler.stop();
+        } catch (IllegalStateException e) {
+            System.out.println("Profiler was not active at setup.");
+        }
+
+        String config = "start,wall=10ms,filter=1,file=/tmp/thread_filter_profile.jfr";
+        System.out.println("Starting profiler with " + config);
+        profiler.execute(config);
+        System.out.println("Started profiler with output file");
+        running = new AtomicBoolean(true);
+        operationCount = new AtomicLong(0);
+        startTime = System.currentTimeMillis();
+        stopTime = startTime + 30000; // Run for 30 seconds
+        System.out.println("Benchmark setup completed at " + startTime);
+
+        try {
+            String logFile = "/tmp/thread_filter_benchmark.log";
+            System.out.println("Attempting to create log file at: " + logFile);
+            logWriter = new PrintWriter(new FileWriter(logFile));
+            logWriter.printf("Benchmark started at %d%n", startTime);
+            logWriter.flush();
+            System.out.println("Successfully created and wrote to log file");
+        } catch (IOException e) {
+            System.err.println("Failed to create log file: " + e.getMessage());
+            e.printStackTrace();
+            throw e;
+        }
+    }
+
+    @TearDown(Level.Trial)
+    public void tearDown() {
+        System.out.println("Tearing down benchmark...");
+        running.set(false);
+
+        // Wait for all threads to finish with a timeout
+        try {
+            if (stopLatch != null) {
+                if (!stopLatch.await(30, TimeUnit.SECONDS)) {
+                    System.err.println("Warning: Some threads did not finish within timeout");
+                }
+            }
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+        }
+
+        // Shutdown executor with timeout
+        executorService.shutdown();
+        try {
+            if (!executorService.awaitTermination(30, TimeUnit.SECONDS)) {
+                executorService.shutdownNow();
+                if (!executorService.awaitTermination(30, TimeUnit.SECONDS)) {
+                    System.err.println("Warning: Executor did not terminate");
+                }
+            }
+        } catch (InterruptedException e) {
+            executorService.shutdownNow();
+            Thread.currentThread().interrupt();
+        }
+
+        // Stop the profiler if it's active
+        // NOTE(review): stop() is commented out, so this try/catch currently does nothing.
+        try {
+//            profiler.stop();
+        } catch (IllegalStateException e) {
+            System.out.println("Profiler was not active at teardown.");
+        }
+
+        long endTime = System.currentTimeMillis();
+        long totalOps = operationCount.get();
+        double durationSecs = (endTime - startTime) / 1000.0;
+        double opsPerSec = totalOps / durationSecs;
+        double addOpsPerSec = addThreadCount.get() / durationSecs;
+        double removeOpsPerSec = removeThreadCount.get() / durationSecs;
+
+        String stats = String.format("Thread Filter Stats:%n" +
+            "Total operations: %,d%n" +
+            "Duration: %.2f seconds%n" +
+            "Operations/second: %,.0f%n" +
+            "Operations/second/thread: %,.0f%n" +
+            "AddThread operations/second: %,.0f%n" +
+            "RemoveThread operations/second: %,.0f%n",
+            totalOps, durationSecs, opsPerSec, opsPerSec / NUM_THREADS, addOpsPerSec, removeOpsPerSec);
+
+        System.out.print(stats);
+        if (logWriter != null) {
+            try {
+                logWriter.print(stats);
+                logWriter.flush();
+                logWriter.close();
+                System.out.println("Successfully closed log file");
+            } catch (Exception e) {
+                System.err.println("Error closing log file: " + e.getMessage());
+                e.printStackTrace();
+            }
+        }
+    }
+
+    public void setUseThreadFilters(boolean useThreadFilters) {
+        this.useThreadFilters = useThreadFilters;
+    }
+
+    @Benchmark
+    @BenchmarkMode(Mode.Throughput)
+    @Fork(value = 1, warmups = 0)
+    @Warmup(iterations = 1, time = 1)
+    @Measurement(iterations = 1, time = 2)
+    @Threads(1)
+    @OutputTimeUnit(TimeUnit.MILLISECONDS)
+    public long threadFilterStress() throws InterruptedException {
+        System.out.println("Starting benchmark iteration...");
+        startLatch = new CountDownLatch(NUM_THREADS);
+        stopLatch = new CountDownLatch(NUM_THREADS);
+
+        // Start all worker threads
+        for (int i = 0; i < NUM_THREADS; i++) {
+            final int threadId = i;
+            executorService.submit(() -> {
+                try {
+                    startLatch.countDown();
+                    startLatch.await(30, TimeUnit.SECONDS);
+
+                    String startMsg = String.format("Thread %d started%n", threadId);
+                    System.out.print(startMsg);
+                    if (logWriter != null) {
+                        logWriter.print(startMsg);
+                        logWriter.flush();
+                    }
+
+                    while (running.get() && System.currentTimeMillis() < stopTime) {
+                        // Memory-intensive operations that would be sensitive to false sharing
+                        for (int j = 0; j < ARRAY_SIZE; j += STRIDE) {
+                            if (useThreadFilters) {
+                                // Register thread at the start of each cache line operation
+                                profiler.addThread();
+                                addThreadCount.incrementAndGet();
+                            }
+
+                            // Each thread writes to its own cache line
+                            int baseIndex = (threadId * STRIDE) % ARRAY_SIZE;
+                            for (int k = 0; k < STRIDE; k++) {
+                                int index = (baseIndex + k) % ARRAY_SIZE;
+                                // Write to shared array
+                                sharedArray[index] = threadId;
+                                // Read and modify
+                                int value = sharedArray[index] + 1;
+                                // Atomic operation
+                                atomicArray.set(index, value);
+                            }
+
+                            if (useThreadFilters) {
+                                // Remove thread after cache line operation
+                                profiler.removeThread();
+                                removeThreadCount.incrementAndGet();
+                            }
+                            operationCount.incrementAndGet();
+                        }
+
+                        // More memory operations with thread registration
+                        for (int j = 0; j < ARRAY_SIZE; j += STRIDE) {
+                            if (useThreadFilters) {
+                                // Register thread at the start of each cache line operation
+                                profiler.addThread();
+                                addThreadCount.incrementAndGet();
+                            }
+
+                            int baseIndex = (threadId * STRIDE) % ARRAY_SIZE;
+                            for (int k = 0; k < STRIDE; k++) {
+                                int index = (baseIndex + k) % ARRAY_SIZE;
+                                int value = atomicArray.get(index);
+                                sharedArray[index] = value * 2;
+                            }
+
+                            if (useThreadFilters) {
+                                // Remove thread after cache line operation
+                                profiler.removeThread();
+                                removeThreadCount.incrementAndGet();
+                            }
+                            operationCount.incrementAndGet();
+                        }
+
+                        if (operationCount.get() % 1000 == 0) {
+                            String progressMsg = String.format("Thread %d completed %d operations%n", threadId, operationCount.get());
+                            System.out.print(progressMsg);
+                            if (logWriter != null) {
+                                logWriter.print(progressMsg);
+                                logWriter.flush();
+                            }
+                        }
+                    }
+                } catch (InterruptedException e) {
+                    Thread.currentThread().interrupt();
+                } finally {
+                    stopLatch.countDown();
+                }
+            });
+        }
+
+        stopLatch.await();
+        return operationCount.get();
+    }
+}