Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 35 additions & 14 deletions unified-runtime/test/conformance/event/urEventSetCallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,24 @@

#include "fixtures.h"
#include "uur/known_failure.h"
using namespace std::chrono_literals;

/* Using urEventReferenceTest to be able to release the event during the test */
using urEventSetCallbackTest = uur::event::urEventReferenceTest;
struct urEventSetCallbackTest : uur::event::urEventReferenceTest {
std::mutex m;
std::condition_variable cv;
int flag = 0;

void WaitForFlag(int Target = 1) {
std::unique_lock lk(m);
cv.wait_for(lk, 1000ms, [&] { return flag == Target; });
}

void SetFlag() {
flag++;
cv.notify_one();
}
};

/**
* Checks that the callback function is called.
Expand All @@ -22,19 +37,19 @@ TEST_P(urEventSetCallbackTest, Success) {
[[maybe_unused]] ur_execution_info_t execStatus,
void *pUserData) {

auto status = reinterpret_cast<bool *>(pUserData);
*status = true;
auto that = reinterpret_cast<urEventSetCallbackTest *>(pUserData);
that->SetFlag();
}
};

bool didRun = false;
ASSERT_SUCCESS(
urEventSetCallback(event, ur_execution_info_t::UR_EXECUTION_INFO_COMPLETE,
Callback::callback, &didRun));
Callback::callback, this));

ASSERT_SUCCESS(urEventWait(1, &event));
ASSERT_SUCCESS(urEventRelease(event));
ASSERT_TRUE(didRun);
WaitForFlag();
ASSERT_EQ(flag, 1);
}

/**
Expand All @@ -45,6 +60,7 @@ TEST_P(urEventSetCallbackTest, ValidateParameters) {
uur::LevelZeroV2{}, uur::NativeCPU{});

struct CallbackParameters {
urEventSetCallbackTest *test;
ur_event_handle_t event;
ur_execution_info_t execStatus;
};
Expand All @@ -56,17 +72,19 @@ TEST_P(urEventSetCallbackTest, ValidateParameters) {
auto parameters = reinterpret_cast<CallbackParameters *>(pUserData);
parameters->event = hEvent;
parameters->execStatus = execStatus;
parameters->test->SetFlag();
}
};

CallbackParameters parameters{};
CallbackParameters parameters{this, nullptr, UR_EXECUTION_INFO_QUEUED};

ASSERT_SUCCESS(
urEventSetCallback(event, ur_execution_info_t::UR_EXECUTION_INFO_COMPLETE,
Callback::callback, &parameters));

ASSERT_SUCCESS(urEventWait(1, &event));
ASSERT_SUCCESS(urEventRelease(event));
WaitForFlag();
ASSERT_EQ(event, parameters.event);
ASSERT_EQ(ur_execution_info_t::UR_EXECUTION_INFO_COMPLETE,
parameters.execStatus);
Expand All @@ -80,6 +98,7 @@ TEST_P(urEventSetCallbackTest, AllStates) {
uur::LevelZeroV2{}, uur::NativeCPU{});

struct CallbackStatus {
urEventSetCallbackTest *test = nullptr;
bool submitted = false;
bool running = false;
bool complete = false;
Expand Down Expand Up @@ -107,10 +126,12 @@ TEST_P(urEventSetCallbackTest, AllStates) {
FAIL() << "Invalid execution info enumeration";
}
}

status->test->SetFlag();
}
};

CallbackStatus status{};
CallbackStatus status{this};

ASSERT_SUCCESS(urEventSetCallback(
event, ur_execution_info_t::UR_EXECUTION_INFO_SUBMITTED,
Expand All @@ -124,6 +145,7 @@ TEST_P(urEventSetCallbackTest, AllStates) {

ASSERT_SUCCESS(urEventWait(1, &event));
ASSERT_SUCCESS(urEventRelease(event));
WaitForFlag(3);

ASSERT_TRUE(status.submitted);
ASSERT_TRUE(status.running);
Expand All @@ -145,19 +167,18 @@ TEST_P(urEventSetCallbackTest, EventAlreadyCompleted) {
[[maybe_unused]] ur_execution_info_t execStatus,
void *pUserData) {

auto status = reinterpret_cast<bool *>(pUserData);
*status = true;
auto that = reinterpret_cast<urEventSetCallbackTest *>(pUserData);
that->SetFlag();
}
};

bool didRun = false;

ASSERT_SUCCESS(
urEventSetCallback(event, ur_execution_info_t::UR_EXECUTION_INFO_COMPLETE,
Callback::callback, &didRun));
Callback::callback, this));

ASSERT_SUCCESS(urEventRelease(event));
ASSERT_TRUE(didRun);
WaitForFlag();
ASSERT_EQ(flag, 1);
}

UUR_INSTANTIATE_DEVICE_TEST_SUITE(urEventSetCallbackTest);
Expand Down