1414 Dict ,
1515 Generic ,
1616 List ,
17+ Literal ,
1718 Optional ,
19+ overload ,
1820 Tuple ,
1921 Type ,
2022 TYPE_CHECKING ,
6264TTargetValue = TypeVar ("TTargetValue" )
6365
6466
67+ @overload
68+ def _to_tensor (
69+ name : str , arr : Optional [npt .ArrayLike ], none_ok : Literal [True ] = ...
70+ ) -> Optional [Tensor ]: ...
71+ @overload
72+ def _to_tensor (
73+ name : str , arr : Optional [npt .ArrayLike ], none_ok : Literal [False ] = ...
74+ ) -> Tensor : ...
75+ def _to_tensor (
76+ name : str , arr : Optional [npt .ArrayLike ], none_ok : bool = False
77+ ) -> Optional [Tensor ]:
78+ if arr is None :
79+ if none_ok :
80+ return None
81+ raise TypeError (f"Expected array-like for `{ name } ` but received None!" )
82+ if not isinstance (arr , Tensor ):
83+ arr = torch .tensor (arr )
84+ return arr
85+
86+
6587@dataclass (kw_only = True )
6688class BaseLLMAttributionResult (ABC , Generic [TInputValue , TTargetValue ]):
6789 """
@@ -77,6 +99,8 @@ class BaseLLMAttributionResult(ABC, Generic[TInputValue, TTargetValue]):
7799 ] # value for each target name e.g. token prob
78100 _aggregate_attr : Tensor # 1D [# input_values]
79101 _element_attr : Optional [Tensor ] = None # 2D [# target_names, # input_values]
102+ _aggregate_attr_var : Optional [Tensor ] = None # 1D [# input_values]
103+ _element_attr_var : Optional [Tensor ] = None # 2D [# target_names, # input_values]
80104 aggregate_descriptor : str = "Aggregate"
81105 element_descriptor : str = "Element"
82106
@@ -88,6 +112,8 @@ def __init__(
88112 target_values : Optional [Union [npt .ArrayLike , List [TTargetValue ]]] = None ,
89113 aggregate_attr : npt .ArrayLike ,
90114 element_attr : Optional [npt .ArrayLike ] = None ,
115+ aggregate_attr_var : Optional [npt .ArrayLike ] = None ,
116+ element_attr_var : Optional [npt .ArrayLike ] = None ,
91117 aggregate_descriptor : str = "Aggregate" ,
92118 element_descriptor : str = "Element" ,
93119 ) -> None :
@@ -96,6 +122,8 @@ def __init__(
96122 self .target_values = target_values
97123 self .aggregate_attr = aggregate_attr
98124 self .element_attr = element_attr
125+ self .aggregate_attr_var = aggregate_attr_var
126+ self .element_attr_var = element_attr_var
99127 self .aggregate_descriptor = aggregate_descriptor
100128 self .element_descriptor = element_descriptor
101129
@@ -105,10 +133,9 @@ def aggregate_attr(self) -> Tensor:
105133
106134 @aggregate_attr .setter
107135 def aggregate_attr (self , aggregate_attr : npt .ArrayLike ) -> None :
108- if isinstance (aggregate_attr , Tensor ):
109- self ._aggregate_attr = aggregate_attr
110- else :
111- self ._aggregate_attr = torch .tensor (aggregate_attr )
136+ self ._aggregate_attr = _to_tensor (
137+ "aggregate_attr" , aggregate_attr , none_ok = False
138+ )
112139 # IDEA: in the future we might want to support higher dim seq_attr
113140 # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
114141 assert len (self ._aggregate_attr .shape ) == 1 , "seq_attr must be a 1D tensor"
@@ -122,12 +149,7 @@ def element_attr(self) -> Optional[Tensor]:
122149
123150 @element_attr .setter
124151 def element_attr (self , element_attr : Optional [npt .ArrayLike ]) -> None :
125- if element_attr is None :
126- self ._element_attr = None
127- elif isinstance (element_attr , Tensor ):
128- self ._element_attr = element_attr
129- else :
130- self ._element_attr = torch .tensor (element_attr )
152+ self ._element_attr = _to_tensor ("element_attr" , element_attr , none_ok = True )
131153
132154 if self ._element_attr is not None :
133155 # IDEA: in the future we might want to support higher dim seq_attr
@@ -141,6 +163,39 @@ def element_attr(self, element_attr: Optional[npt.ArrayLike]) -> None:
141163 f"got { self ._element_attr .shape } "
142164 )
143165
166+ @property
167+ def aggregate_attr_var (self ) -> Optional [Tensor ]:
168+ return self ._aggregate_attr_var
169+
170+ @aggregate_attr_var .setter
171+ def aggregate_attr_var (self , aggregate_attr_var : Optional [npt .ArrayLike ]) -> None :
172+ self ._aggregate_attr_var = _to_tensor (
173+ "aggregate_attr_var" , aggregate_attr_var , none_ok = True
174+ )
175+ if self ._aggregate_attr_var is not None :
176+ assert self ._aggregate_attr_var .shape == self ._aggregate_attr .shape , (
177+ f"aggregate_attr ({ self ._aggregate_attr .shape } ) must have same shape "
178+ f"as aggregate_attr_var ({ self ._aggregate_attr_var .shape } )"
179+ )
180+
181+ @property
182+ def element_attr_var (self ) -> Optional [Tensor ]:
183+ return self ._element_attr_var
184+
185+ @element_attr_var .setter
186+ def element_attr_var (self , element_attr_var : Optional [npt .ArrayLike ]) -> None :
187+ self ._element_attr_var = _to_tensor (
188+ "element_attr_var" , element_attr_var , none_ok = True
189+ )
190+ if self ._element_attr_var is not None :
191+ assert (
192+ self ._element_attr is not None
193+ ), "element_attr must be set before setting element_attr_var"
194+ assert self ._element_attr_var .shape == self ._element_attr .shape , (
195+ f"element_attr ({ self ._element_attr .shape } ) must have same shape "
196+ f"as element_attr_var ({ self ._element_attr_var .shape } )"
197+ )
198+
144199 @property
145200 def target_values (self ) -> Optional [List [TTargetValue ]]:
146201 return self ._target_values
@@ -377,6 +432,22 @@ def token_attr(self) -> Optional[Tensor]:
377432 def token_attr (self , token_attr : Optional [npt .ArrayLike ]) -> None :
378433 self .element_attr = token_attr
379434
435+ @property
436+ def seq_attr_var (self ) -> Optional [Tensor ]:
437+ return self .aggregate_attr_var
438+
439+ @seq_attr_var .setter
440+ def seq_attr_var (self , seq_attr_var : Optional [npt .ArrayLike ]) -> None :
441+ self .aggregate_attr_var = seq_attr_var
442+
443+ @property
444+ def token_attr_var (self ) -> Optional [Tensor ]:
445+ return self .element_attr_var
446+
447+ @token_attr_var .setter
448+ def token_attr_var (self , token_attr_var : Optional [npt .ArrayLike ]) -> None :
449+ self .element_attr_var = token_attr_var
450+
380451 @property
381452 def seq_attr_dict (self ) -> Dict [TInputValue , float ]:
382453 return self .aggregate_attr_dict
@@ -402,6 +473,8 @@ def __init__(
402473 output_tokens : List [str ],
403474 seq_attr : npt .ArrayLike ,
404475 token_attr : Optional [npt .ArrayLike ] = None ,
476+ seq_attr_var : Optional [npt .ArrayLike ] = None ,
477+ token_attr_var : Optional [npt .ArrayLike ] = None ,
405478 output_probs : Optional [npt .ArrayLike ] = None ,
406479 ) -> None :
407480 super ().__init__ (
@@ -410,6 +483,8 @@ def __init__(
410483 target_values = output_probs ,
411484 aggregate_attr = seq_attr ,
412485 element_attr = token_attr ,
486+ aggregate_attr_var = seq_attr_var ,
487+ element_attr_var = token_attr_var ,
413488 aggregate_descriptor = "Sequence" ,
414489 element_descriptor = "Token" ,
415490 )
0 commit comments