diff --git a/dipu/torch_dipu/csrc_dipu/diopirt/diopirt_impl.cpp b/dipu/torch_dipu/csrc_dipu/diopirt/diopirt_impl.cpp index db909e62e..5aa5bb7d3 100644 --- a/dipu/torch_dipu/csrc_dipu/diopirt/diopirt_impl.cpp +++ b/dipu/torch_dipu/csrc_dipu/diopirt/diopirt_impl.cpp @@ -183,6 +183,24 @@ DIOPI_RT_API diopiError_t diopiGeneratorSetState( return diopiSuccess; } +DIOPI_RT_API diopiError_t diopiGeneratorGetSeedAndOffset( + diopiGeneratorHandle_t th, uint64_t& seed, uint64_t& offset) { + auto generator = reinterpret_cast(th); + auto gen_impl = at::check_generator(*generator); + offset = gen_impl->get_offset(); + seed = gen_impl->current_seed(); + return diopiSuccess; +} + +DIOPI_RT_API diopiError_t diopiGeneratorSetSeedAndOffset( + diopiGeneratorHandle_t th, uint64_t seed, uint64_t offset) { + auto generator = reinterpret_cast(th); + auto gen_impl = at::check_generator(*generator); + gen_impl->set_offset(offset); + gen_impl->set_current_seed(seed); + return diopiSuccess; +} + DIOPI_RT_API diopiError_t diopiRecordStart(const char* record_name, void** record) { *record = new RecordBlockCreator(record_name);