@@ -489,12 +489,12 @@ def get_int(b: bytes) -> int:
489489
490490
491491SN3_PASCALVINCENT_TYPEMAP  =  {
492-     8 : ( torch .uint8 ,  np . uint8 ,  np . uint8 ) ,
493-     9 : ( torch .int8 ,  np . int8 ,  np . int8 ) ,
494-     11 : ( torch .int16 ,  np . dtype ( ">i2" ),  "i2" ) ,
495-     12 : ( torch .int32 ,  np . dtype ( ">i4" ),  "i4" ) ,
496-     13 : ( torch .float32 ,  np . dtype ( ">f4" ),  "f4" ) ,
497-     14 : ( torch .float64 ,  np . dtype ( ">f8" ),  "f8" ) ,
492+     8 : torch .uint8 ,
493+     9 : torch .int8 ,
494+     11 : torch .int16 ,
495+     12 : torch .int32 ,
496+     13 : torch .float32 ,
497+     14 : torch .float64 ,
498498}
499499
500500
@@ -511,11 +511,11 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
511511    ty  =  magic  //  256 
512512    assert  1  <=  nd  <=  3 
513513    assert  8  <=  ty  <=  14 
514-     m  =  SN3_PASCALVINCENT_TYPEMAP [ty ]
514+     torch_type  =  SN3_PASCALVINCENT_TYPEMAP [ty ]
515515    s  =  [get_int (data [4  *  (i  +  1 ) : 4  *  (i  +  2 )]) for  i  in  range (nd )]
516-     parsed  =  np .frombuffer (data , dtype = m [ 1 ] , offset = (4  *  (nd  +  1 )))
516+     parsed  =  torch .frombuffer (data , dtype = torch_type , offset = (4  *  (nd  +  1 )))
517517    assert  parsed .shape [0 ] ==  np .prod (s ) or  not  strict 
518-     return  torch . from_numpy ( parsed . astype ( m [ 2 ])) .view (* s )
518+     return  parsed .view (* s )
519519
520520
521521def  read_label_file (path : str ) ->  torch .Tensor :
0 commit comments