@@ -317,5 +317,30 @@ def test_draw_keypoints_errors():
317317 utils .draw_keypoints (image = img , keypoints = invalid_keypoints )
318318
319319
320+ def test_flow_to_image ():
321+ h , w = 100 , 100
322+ flow = torch .meshgrid (torch .arange (h ), torch .arange (w ), indexing = "ij" )
323+ flow = torch .stack (flow [::- 1 ], dim = 0 ).float ()
324+ flow [0 ] -= h / 2
325+ flow [1 ] -= w / 2
326+ img = utils .flow_to_image (flow )
327+ path = os .path .join (os .path .dirname (os .path .abspath (__file__ )), "assets" , "expected_flow.pt" )
328+ expected_img = torch .load (path , map_location = "cpu" )
329+ assert_equal (expected_img , img )
330+
331+
332+ def test_flow_to_image_errors ():
333+ wrong_flow1 = torch .full ((3 , 10 , 10 ), 0 , dtype = torch .float )
334+ wrong_flow2 = torch .full ((2 , 10 ), 0 , dtype = torch .float )
335+ wrong_flow3 = torch .full ((2 , 10 , 30 ), 0 , dtype = torch .int )
336+
337+ with pytest .raises (ValueError , match = "Input flow should have shape" ):
338+ utils .flow_to_image (flow = wrong_flow1 )
339+ with pytest .raises (ValueError , match = "Input flow should have shape" ):
340+ utils .flow_to_image (flow = wrong_flow2 )
341+ with pytest .raises (ValueError , match = "Flow should be of dtype torch.float" ):
342+ utils .flow_to_image (flow = wrong_flow3 )
343+
344+
320345if __name__ == "__main__" :
321346 pytest .main ([__file__ ])
0 commit comments