 _USE_FAST_NORM = False  # defaulting to False for now


+def get_autocast_dtype(device: str = 'cuda'):
+    try:
+        return torch.get_autocast_dtype(device)
+    except (AttributeError, TypeError):
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.get_autocast_cpu_dtype()
+        else:
+            assert device == 'cuda'
+            return torch.get_autocast_gpu_dtype()
+
+
+def is_autocast_enabled(device: str = 'cuda'):
+    try:
+        return torch.is_autocast_enabled(device)
+    except TypeError:
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.is_autocast_cpu_enabled()
+        else:
+            assert device == 'cuda'
+            return torch.is_autocast_enabled()  # defaults cuda (only cuda on older pytorch)
+
+
 def is_fast_norm():
     return _USE_FAST_NORM

@@ -48,14 +72,14 @@ def fast_group_norm(
         # currently cannot use is_autocast_enabled within torchscript
         return F.group_norm(x, num_groups, weight, bias, eps)

-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
         # FIXME what to do re CPU autocast?
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None

-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.group_norm(x, num_groups, weight, bias, eps)


@@ -73,14 +97,14 @@ def fast_layer_norm(
     if has_apex:
         return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)

-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         # FIXME what to do re CPU autocast?
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None

-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)


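For context, a minimal usage sketch of the version-compat helpers added above (not part of the diff). It assumes the helpers are importable from timm's fast_norm module and a PyTorch build that provides torch.amp.autocast; on older releases where torch.is_autocast_enabled / torch.get_autocast_dtype do not accept a device string, the wrappers fall back to the device-specific functions.

# usage sketch, assuming the helpers live in timm's fast_norm module
import torch
from timm.layers.fast_norm import get_autocast_dtype, is_autocast_enabled

# Inside a CPU autocast region, both the new device-string API (newer PyTorch)
# and the older torch.is_autocast_cpu_enabled()/get_autocast_cpu_dtype() paths
# report the same state, so the call below works on either version.
with torch.amp.autocast(device_type='cpu', dtype=torch.bfloat16):
    print(is_autocast_enabled('cpu'))  # True inside the autocast region
    print(get_autocast_dtype('cpu'))   # torch.bfloat16

print(is_autocast_enabled('cpu'))      # False outside the region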