@@ -869,6 +869,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         self.quant_input = self._mod_extra_config.inputs[0]
         self.dequant_output = self._mod_extra_config.outputs[0]
         if self.use_qdq:
+            self.qdq_input = self._mod_extra_config.inputs[1]
             self.update = self.update_qdq
             mod.update = self.update_qdq
         else:
@@ -885,8 +886,23 @@ def allocate(self, inp_seq_len, dtype, device, shape):
 
     # overwrite update function of original module to force quant and dequant of cache input and output
     def update_qdq(self, prev, cur, dim, idx, inp_seq_len):
-        qinput = self.quant_input(cur)
-        output = self.org_update(prev, qinput, dim, idx, inp_seq_len)
889+ """
890+ Explanation: If we want to optimize index_copy so it would run in fp8 instead of bf16
891+ we need the tensors to be in fp8 before calling index_copy.
892+ Also the `prev` and `curr` tensors need to be of the same dtype - and quanting them both
893+ from bf16 is no help, best we can do is have prev be initialized an fp8 tensor from the start.
894+ Since the initilization of `prev` is done in OHF (and that is not implemented yet) we
895+ currently need to support both options until the implementation in OHF is done, then
896+ can we remove the support for the bf16 `prev` option (the else here).
897+ """
+        if prev.dtype == torch.float8_e4m3fn:
+            qcurr = self.quant_input(cur)
+            qoutput = self.org_update(prev, qcurr, dim, idx, inp_seq_len)
+            output = self.dequant_output(qoutput)
+        # TODO: remove the `else` part once the lp_dtype is implemented in OHF
+        else:
+            curr = self.qdq_input(cur)
+            output = self.org_update(prev, curr, dim, idx, inp_seq_len)
         return output
 
     # overwrite update function of original module to force quant and dequant of cache input and output
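
For readers outside this stack, here is a minimal standalone sketch of the two paths `update_qdq` now takes. The `quant`/`dequant` helpers and the fixed `SCALE` are hypothetical stand-ins for the measured `quant_input`/`dequant_output`/`qdq_input` wrappers, and the fp8 `index_copy_` branch assumes a backend that supports indexing ops on `float8_e4m3fn` tensors; the demo call below exercises only the bf16 fallback:

```python
import torch

SCALE = 0.5  # hypothetical per-tensor scale; the real wrappers carry calibrated scales

def quant(t):
    # bf16 -> fp8 with a per-tensor scale (stand-in for self.quant_input)
    return (t / SCALE).to(torch.float8_e4m3fn)

def dequant(t):
    # fp8 -> bf16, undoing the scale (stand-in for self.dequant_output)
    return t.to(torch.bfloat16) * SCALE

def update_qdq_sketch(prev, cur, dim, idx):
    if prev.dtype == torch.float8_e4m3fn:
        # Cache already lives in fp8: quantize only `cur`, run index_copy
        # entirely in fp8 (backend permitting), and dequantize the result once.
        prev.index_copy_(dim, idx, quant(cur))
        return dequant(prev)
    # bf16 fallback (the `else` branch the TODO plans to remove): round-trip
    # `cur` through fp8 so its values match what the fp8 path would store,
    # then run index_copy in bf16 against the bf16 cache.
    prev.index_copy_(dim, idx, dequant(quant(cur)))
    return prev

prev = torch.zeros(4, 8, dtype=torch.bfloat16)  # bf16 cache -> fallback path
cur = torch.randn(1, 8, dtype=torch.bfloat16)
out = update_qdq_sketch(prev, cur, dim=0, idx=torch.tensor([2]))
print(out.dtype, out[2])
```

The key design point the docstring makes is visible here: the fp8 path pays one quantize and one dequantize per update, while the fallback must round-trip `cur` through fp8 anyway (to keep values consistent) yet still runs the copy in bf16.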