2121import org .elasticsearch .client .ml .dataframe .evaluation .EvaluationMetric ;
2222import org .elasticsearch .common .Nullable ;
2323import org .elasticsearch .common .ParseField ;
24+ import org .elasticsearch .common .Strings ;
2425import org .elasticsearch .common .xcontent .ConstructingObjectParser ;
26+ import org .elasticsearch .common .xcontent .ToXContentObject ;
2527import org .elasticsearch .common .xcontent .XContentBuilder ;
2628import org .elasticsearch .common .xcontent .XContentParser ;
2729
2830import java .io .IOException ;
2931import java .util .Collections ;
30- import java .util .Map ;
32+ import java .util .List ;
3133import java .util .Objects ;
32- import java .util .TreeMap ;
3334
34- import static org .elasticsearch .common .xcontent .ConstructingObjectParser .constructorArg ;
3535import static org .elasticsearch .common .xcontent .ConstructingObjectParser .optionalConstructorArg ;
3636
3737/**
@@ -97,52 +97,52 @@ public int hashCode() {
9797 public static class Result implements EvaluationMetric .Result {
9898
9999 private static final ParseField CONFUSION_MATRIX = new ParseField ("confusion_matrix" );
100- private static final ParseField OTHER_CLASSES_COUNT = new ParseField ("_other_ " );
100+ private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField ("other_actual_class_count " );
101101
102102 @ SuppressWarnings ("unchecked" )
103103 private static final ConstructingObjectParser <Result , Void > PARSER =
104104 new ConstructingObjectParser <>(
105- "multiclass_confusion_matrix_result" , true , a -> new Result ((Map < String , Map < String , Long >> ) a [0 ], (long ) a [1 ]));
105+ "multiclass_confusion_matrix_result" , true , a -> new Result ((List < ActualClass > ) a [0 ], (Long ) a [1 ]));
106106
107107 static {
108- PARSER .declareObject (
109- constructorArg (),
110- (p , c ) -> p .map (TreeMap ::new , p2 -> p2 .map (TreeMap ::new , XContentParser ::longValue )),
111- CONFUSION_MATRIX );
112- PARSER .declareLong (constructorArg (), OTHER_CLASSES_COUNT );
108+ PARSER .declareObjectArray (optionalConstructorArg (), ActualClass .PARSER , CONFUSION_MATRIX );
109+ PARSER .declareLong (optionalConstructorArg (), OTHER_ACTUAL_CLASS_COUNT );
113110 }
114111
115112 public static Result fromXContent (XContentParser parser ) {
116113 return PARSER .apply (parser , null );
117114 }
118115
119- // Immutable
120- private final Map <String , Map <String , Long >> confusionMatrix ;
121- private final long otherClassesCount ;
116+ private final List <ActualClass > confusionMatrix ;
117+ private final Long otherActualClassCount ;
122118
123- public Result (Map < String , Map < String , Long >> confusionMatrix , long otherClassesCount ) {
124- this .confusionMatrix = Collections .unmodifiableMap (Objects .requireNonNull (confusionMatrix ));
125- this .otherClassesCount = otherClassesCount ;
119+ public Result (@ Nullable List < ActualClass > confusionMatrix , @ Nullable Long otherActualClassCount ) {
120+ this .confusionMatrix = confusionMatrix != null ? Collections .unmodifiableList (Objects .requireNonNull (confusionMatrix )) : null ;
121+ this .otherActualClassCount = otherActualClassCount ;
126122 }
127123
128124 @ Override
129125 public String getMetricName () {
130126 return NAME ;
131127 }
132128
133- public Map < String , Map < String , Long > > getConfusionMatrix () {
129+ public List < ActualClass > getConfusionMatrix () {
134130 return confusionMatrix ;
135131 }
136132
137- public long getOtherClassesCount () {
138- return otherClassesCount ;
133+ public Long getOtherActualClassCount () {
134+ return otherActualClassCount ;
139135 }
140136
141137 @ Override
142138 public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
143139 builder .startObject ();
144- builder .field (CONFUSION_MATRIX .getPreferredName (), confusionMatrix );
145- builder .field (OTHER_CLASSES_COUNT .getPreferredName (), otherClassesCount );
140+ if (confusionMatrix != null ) {
141+ builder .field (CONFUSION_MATRIX .getPreferredName (), confusionMatrix );
142+ }
143+ if (otherActualClassCount != null ) {
144+ builder .field (OTHER_ACTUAL_CLASS_COUNT .getPreferredName (), otherActualClassCount );
145+ }
146146 builder .endObject ();
147147 return builder ;
148148 }
@@ -153,12 +153,140 @@ public boolean equals(Object o) {
153153 if (o == null || getClass () != o .getClass ()) return false ;
154154 Result that = (Result ) o ;
155155 return Objects .equals (this .confusionMatrix , that .confusionMatrix )
156- && this .otherClassesCount == that .otherClassesCount ;
156+ && Objects . equals ( this .otherActualClassCount , that .otherActualClassCount ) ;
157157 }
158158
159159 @ Override
160160 public int hashCode () {
161- return Objects .hash (confusionMatrix , otherClassesCount );
161+ return Objects .hash (confusionMatrix , otherActualClassCount );
162+ }
163+ }
164+
165+ public static class ActualClass implements ToXContentObject {
166+
167+ private static final ParseField ACTUAL_CLASS = new ParseField ("actual_class" );
168+ private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField ("actual_class_doc_count" );
169+ private static final ParseField PREDICTED_CLASSES = new ParseField ("predicted_classes" );
170+ private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField ("other_predicted_class_doc_count" );
171+
172+ @ SuppressWarnings ("unchecked" )
173+ private static final ConstructingObjectParser <ActualClass , Void > PARSER =
174+ new ConstructingObjectParser <>(
175+ "multiclass_confusion_matrix_actual_class" ,
176+ true ,
177+ a -> new ActualClass ((String ) a [0 ], (Long ) a [1 ], (List <PredictedClass >) a [2 ], (Long ) a [3 ]));
178+
179+ static {
180+ PARSER .declareString (optionalConstructorArg (), ACTUAL_CLASS );
181+ PARSER .declareLong (optionalConstructorArg (), ACTUAL_CLASS_DOC_COUNT );
182+ PARSER .declareObjectArray (optionalConstructorArg (), PredictedClass .PARSER , PREDICTED_CLASSES );
183+ PARSER .declareLong (optionalConstructorArg (), OTHER_PREDICTED_CLASS_DOC_COUNT );
184+ }
185+
186+ private final String actualClass ;
187+ private final Long actualClassDocCount ;
188+ private final List <PredictedClass > predictedClasses ;
189+ private final Long otherPredictedClassDocCount ;
190+
191+ public ActualClass (@ Nullable String actualClass ,
192+ @ Nullable Long actualClassDocCount ,
193+ @ Nullable List <PredictedClass > predictedClasses ,
194+ @ Nullable Long otherPredictedClassDocCount ) {
195+ this .actualClass = actualClass ;
196+ this .actualClassDocCount = actualClassDocCount ;
197+ this .predictedClasses = predictedClasses != null ? Collections .unmodifiableList (predictedClasses ) : null ;
198+ this .otherPredictedClassDocCount = otherPredictedClassDocCount ;
199+ }
200+
201+ @ Override
202+ public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
203+ builder .startObject ();
204+ if (actualClass != null ) {
205+ builder .field (ACTUAL_CLASS .getPreferredName (), actualClass );
206+ }
207+ if (actualClassDocCount != null ) {
208+ builder .field (ACTUAL_CLASS_DOC_COUNT .getPreferredName (), actualClassDocCount );
209+ }
210+ if (predictedClasses != null ) {
211+ builder .field (PREDICTED_CLASSES .getPreferredName (), predictedClasses );
212+ }
213+ if (otherPredictedClassDocCount != null ) {
214+ builder .field (OTHER_PREDICTED_CLASS_DOC_COUNT .getPreferredName (), otherPredictedClassDocCount );
215+ }
216+ builder .endObject ();
217+ return builder ;
218+ }
219+
220+ @ Override
221+ public boolean equals (Object o ) {
222+ if (this == o ) return true ;
223+ if (o == null || getClass () != o .getClass ()) return false ;
224+ ActualClass that = (ActualClass ) o ;
225+ return Objects .equals (this .actualClass , that .actualClass )
226+ && Objects .equals (this .actualClassDocCount , that .actualClassDocCount )
227+ && Objects .equals (this .predictedClasses , that .predictedClasses )
228+ && Objects .equals (this .otherPredictedClassDocCount , that .otherPredictedClassDocCount );
229+ }
230+
231+ @ Override
232+ public int hashCode () {
233+ return Objects .hash (actualClass , actualClassDocCount , predictedClasses , otherPredictedClassDocCount );
234+ }
235+
236+ @ Override
237+ public String toString () {
238+ return Strings .toString (this );
239+ }
240+ }
241+
242+ public static class PredictedClass implements ToXContentObject {
243+
244+ private static final ParseField PREDICTED_CLASS = new ParseField ("predicted_class" );
245+ private static final ParseField COUNT = new ParseField ("count" );
246+
247+ @ SuppressWarnings ("unchecked" )
248+ private static final ConstructingObjectParser <PredictedClass , Void > PARSER =
249+ new ConstructingObjectParser <>(
250+ "multiclass_confusion_matrix_predicted_class" , true , a -> new PredictedClass ((String ) a [0 ], (Long ) a [1 ]));
251+
252+ static {
253+ PARSER .declareString (optionalConstructorArg (), PREDICTED_CLASS );
254+ PARSER .declareLong (optionalConstructorArg (), COUNT );
255+ }
256+
257+ private final String predictedClass ;
258+ private final Long count ;
259+
260+ public PredictedClass (@ Nullable String predictedClass , @ Nullable Long count ) {
261+ this .predictedClass = predictedClass ;
262+ this .count = count ;
263+ }
264+
265+ @ Override
266+ public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
267+ builder .startObject ();
268+ if (predictedClass != null ) {
269+ builder .field (PREDICTED_CLASS .getPreferredName (), predictedClass );
270+ }
271+ if (count != null ) {
272+ builder .field (COUNT .getPreferredName (), count );
273+ }
274+ builder .endObject ();
275+ return builder ;
276+ }
277+
278+ @ Override
279+ public boolean equals (Object o ) {
280+ if (this == o ) return true ;
281+ if (o == null || getClass () != o .getClass ()) return false ;
282+ PredictedClass that = (PredictedClass ) o ;
283+ return Objects .equals (this .predictedClass , that .predictedClass )
284+ && Objects .equals (this .count , that .count );
285+ }
286+
287+ @ Override
288+ public int hashCode () {
289+ return Objects .hash (predictedClass , count );
162290 }
163291 }
164292}
0 commit comments