@@ -67,7 +67,7 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor
 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
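For context, this hunk swaps the deprecated `torch.backends.cuda.sdp_kernel` flags for the public `torch.nn.attention.sdpa_kernel` context manager. A minimal sketch of the new API, assuming PyTorch 2.3 or newer (the tensor shapes are illustrative only):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Toy query/key/value tensors: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Restrict scaled_dot_product_attention to the math (decomposed) backend,
# which Inductor can trace through and codegen directly.
with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```

The old API toggled each backend with a boolean flag; the new one takes the backend(s) to allow, so `SDPBackend.MATH` here plays the role of `enable_flash=False, enable_mem_efficient=False, enable_math=True` in the removed line.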
@@ -345,15 +345,19 @@ def main(
             prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

     if memory_profile:
-        torch.cuda.memory._record_memory_history(True, trace_alloc_max_entries=250000, trace_alloc_record_context=True)
+        if device != "cuda":
+            print("Memory profiling only works on CUDA")
+        else:
+            torch.cuda.memory._record_memory_history(True, trace_alloc_max_entries=250000, trace_alloc_record_context=True)
     aggregate_metrics = {
         'tokens_per_sec': [],
     }
     start = -1 if compile else 0

     for i in range(start, num_samples):
         if i == 0:
-            torch.cuda.reset_peak_memory_stats()
+            if device == "cuda":
+                torch.cuda.reset_peak_memory_stats() # MKG
         device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
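Both changes in this hunk apply the same guard: the allocator-history recorder and the peak-stat reset are CUDA-only calls, so they are now gated on the device string. A standalone sketch of the pattern, noting that `torch.cuda.memory._record_memory_history` is a private API whose signature can change between releases:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

if device != "cuda":
    print("Memory profiling only works on CUDA")
else:
    # Start recording allocator events with stack traces (private API).
    torch.cuda.memory._record_memory_history(
        True, trace_alloc_max_entries=250000, trace_alloc_record_context=True
    )
    # Zero the peak counters so max_memory_allocated() reflects this run only.
    torch.cuda.reset_peak_memory_stats()
```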
@@ -421,15 +425,18 @@ def callback(x):
             print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")

         if memory_profile and i == 0:
-            snapshot = torch.cuda.memory._snapshot()
-            with open(f"{memory_profile}.pickle", 'wb') as f:
-                from pickle import dump
-                dump(snapshot, f)
-            print(
-                f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
-                "python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
-            )
-            break
+            if device != "cuda":
+                print("Memory profiling only works on CUDA")
+            else:
+                snapshot = torch.cuda.memory._snapshot()
+                with open(f"{memory_profile}.pickle", 'wb') as f:
+                    from pickle import dump
+                    dump(snapshot, f)
+                print(
+                    f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
+                    "python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
+                )
+                break

     print("==========")

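The snapshot-saving branch guarded above can be exercised on its own. A hedged sketch, assuming recording was started earlier with `_record_memory_history`; the snapshot call is likewise a private API, and `memory_profile.pickle` is a placeholder filename:

```python
import pickle
import torch

# Capture the allocator events recorded since _record_memory_history was enabled.
snapshot = torch.cuda.memory._snapshot()
with open("memory_profile.pickle", "wb") as f:
    pickle.dump(snapshot, f)

# Convert the pickle to an interactive HTML timeline with the viewer script
# shipped in the PyTorch source tree:
#   python pytorch/torch/cuda/_memory_viz.py trace_plot memory_profile.pickle -o profile.html
```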