@@ -123,15 +123,18 @@ def preprocess(
123123 begins = []
124124 ends = []
125125
126- contexts_to_idx = {span : i for i , span in enumerate (contexts )}
126+ contexts_to_idx = {}
127+ for ctx in contexts :
128+ contexts_to_idx [ctx ] = len (contexts_to_idx )
129+ dedup_contexts = sorted (contexts_to_idx , key = contexts_to_idx .get )
127130 assert not pre_aligned or len (spans ) == len (contexts ), (
128131 "When `pre_aligned` is True, the number of spans and contexts must be the "
129132 "same."
130133 )
131134 aligned_contexts = (
132- [[c ] for c in contexts ]
135+ [[c ] for c in dedup_contexts ]
133136 if pre_aligned
134- else align_spans (contexts , spans , sort_by_overlap = True )
137+ else align_spans (dedup_contexts , spans , sort_by_overlap = True )
135138 )
136139 for i , (span , ctx ) in enumerate (zip (spans , aligned_contexts )):
137140 if len (ctx ) == 0 or ctx [0 ].start > span .start or ctx [0 ].end < span .end :
@@ -143,12 +146,16 @@ def preprocess(
143146 sequence_idx .append (contexts_to_idx [ctx [0 ]])
144147 begins .append (span .start - start )
145148 ends .append (span .end - start )
149+ assert begins [- 1 ] >= 0 , f"Begin offset is negative: { span .text } "
150+ assert ends [- 1 ] <= len (ctx [0 ]), f"End offset is out of bounds: { span .text } "
146151 return {
147152 "begins" : begins ,
148153 "ends" : ends ,
149154 "sequence_idx" : sequence_idx ,
150- "num_sequences" : len (contexts ),
151- "embedding" : self .embedding .preprocess (doc , contexts = contexts , ** kwargs ),
155+ "num_sequences" : len (dedup_contexts ),
156+ "embedding" : self .embedding .preprocess (
157+ doc , contexts = dedup_contexts , ** kwargs
158+ ),
152159 "stats" : {"spans" : len (begins )},
153160 }
154161
0 commit comments