diff --git a/NCEModule.lua b/NCEModule.lua index 881cf5a..3acedaa 100644 --- a/NCEModule.lua +++ b/NCEModule.lua @@ -99,7 +99,8 @@ function NCEModule:updateOutput(inputTable) elseif self.batchnoise then self.output = (torch.type(self.output) == 'table' and #self.output == 4) and self.output or {input.new(), input.new(), input.new(), input.new()} - assert(torch.type(target) == 'torch.CudaTensor' or torch.type(target) == 'torch.LongTensor') + assert(torch.type(target) == 'torch.CudaLongTensor' or torch.type(target) == 'torch.CudaTensor' + or torch.type(target) == 'torch.LongTensor') self.sampleidx = self.sampleidx or target.new() -- the last elements contain the target indices