88using Microsoft . ML . Data ;
99using Microsoft . ML . Data . IO ;
1010
11- [ assembly: LoadableClass ( typeof ( BinaryPredictionTransformer < object > ) , typeof ( BinaryPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
11+ [ assembly: LoadableClass ( typeof ( BinaryPredictionTransformer < IPredictorProducing < float > > ) , typeof ( BinaryPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
1212 "" , BinaryPredictionTransformer . LoaderSignature ) ]
1313
14- [ assembly: LoadableClass ( typeof ( MulticlassPredictionTransformer < object > ) , typeof ( MulticlassPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
14+ [ assembly: LoadableClass ( typeof ( MulticlassPredictionTransformer < IPredictorProducing < VBuffer < float > > > ) , typeof ( MulticlassPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
1515 "" , MulticlassPredictionTransformer . LoaderSignature ) ]
1616
17- [ assembly: LoadableClass ( typeof ( RegressionPredictionTransformer < object > ) , typeof ( RegressionPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
17+ [ assembly: LoadableClass ( typeof ( RegressionPredictionTransformer < IPredictorProducing < float > > ) , typeof ( RegressionPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
1818 "" , RegressionPredictionTransformer . LoaderSignature ) ]
1919
20- [ assembly: LoadableClass ( typeof ( RankingPredictionTransformer < object > ) , typeof ( RankingPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
20+ [ assembly: LoadableClass ( typeof ( RankingPredictionTransformer < IPredictorProducing < float > > ) , typeof ( RankingPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
2121 "" , RankingPredictionTransformer . LoaderSignature ) ]
2222
23- [ assembly: LoadableClass ( typeof ( AnomalyPredictionTransformer < object > ) , typeof ( AnomalyPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
23+ [ assembly: LoadableClass ( typeof ( AnomalyPredictionTransformer < IPredictorProducing < float > > ) , typeof ( AnomalyPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
2424 "" , AnomalyPredictionTransformer . LoaderSignature ) ]
2525
26- [ assembly: LoadableClass ( typeof ( ClusteringPredictionTransformer < object > ) , typeof ( ClusteringPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
26+ [ assembly: LoadableClass ( typeof ( ClusteringPredictionTransformer < IPredictorProducing < VBuffer < float > > > ) , typeof ( ClusteringPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
2727 "" , ClusteringPredictionTransformer . LoaderSignature ) ]
2828
2929namespace Microsoft . ML . Data
@@ -51,7 +51,8 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer
5151 private protected readonly IHost Host ;
5252 [ BestFriend ]
5353 private protected ISchemaBindableMapper BindableMapper ;
54- protected DataViewSchema TrainSchema ;
54+ [ BestFriend ]
55+ private protected DataViewSchema TrainSchema ;
5556
5657 /// <summary>
5758 /// Whether a call to <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
@@ -141,7 +142,8 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
141142
142143 private protected abstract void SaveModel ( ModelSaveContext ctx ) ;
143144
144- protected void SaveModelCore ( ModelSaveContext ctx )
145+ [ BestFriend ]
146+ private protected void SaveModelCore ( ModelSaveContext ctx )
145147 {
146148 // *** Binary format ***
147149 // <base info>
@@ -233,14 +235,14 @@ public sealed override DataViewSchema GetOutputSchema(DataViewSchema inputSchema
233235 return Transform ( new EmptyDataView ( Host , inputSchema ) ) . Schema ;
234236 }
235237
236- private protected override void SaveModel ( ModelSaveContext ctx )
238+ private protected sealed override void SaveModel ( ModelSaveContext ctx )
237239 {
238240 Host . CheckValue ( ctx , nameof ( ctx ) ) ;
239241 ctx . CheckAtModel ( ) ;
240242 SaveCore ( ctx ) ;
241243 }
242244
243- protected virtual void SaveCore ( ModelSaveContext ctx )
245+ private protected virtual void SaveCore ( ModelSaveContext ctx )
244246 {
245247 SaveModelCore ( ctx ) ;
246248 ctx . SaveStringOrNull ( FeatureColumn ) ;
@@ -295,7 +297,7 @@ private void SetScorer()
295297 Scorer = new BinaryClassifierScorer ( Host , args , new EmptyDataView ( Host , TrainSchema ) , BindableMapper . Bind ( Host , schema ) , schema ) ;
296298 }
297299
298- protected override void SaveCore ( ModelSaveContext ctx )
300+ private protected override void SaveCore ( ModelSaveContext ctx )
299301 {
300302 Contracts . AssertValue ( ctx ) ;
301303 ctx . SetVersionInfo ( GetVersionInfo ( ) ) ;
@@ -364,7 +366,7 @@ private void SetScorer()
364366 Scorer = new BinaryClassifierScorer ( Host , args , new EmptyDataView ( Host , TrainSchema ) , BindableMapper . Bind ( Host , schema ) , schema ) ;
365367 }
366368
367- protected override void SaveCore ( ModelSaveContext ctx )
369+ private protected override void SaveCore ( ModelSaveContext ctx )
368370 {
369371 Contracts . AssertValue ( ctx ) ;
370372 ctx . SetVersionInfo ( GetVersionInfo ( ) ) ;
@@ -428,7 +430,7 @@ private void SetScorer()
428430 Scorer = new MultiClassClassifierScorer ( Host , args , new EmptyDataView ( Host , TrainSchema ) , BindableMapper . Bind ( Host , schema ) , schema ) ;
429431 }
430432
431- protected override void SaveCore ( ModelSaveContext ctx )
433+ private protected override void SaveCore ( ModelSaveContext ctx )
432434 {
433435 Contracts . AssertValue ( ctx ) ;
434436 ctx . SetVersionInfo ( GetVersionInfo ( ) ) ;
@@ -473,7 +475,7 @@ internal RegressionPredictionTransformer(IHostEnvironment env, ModelLoadContext
473475 Scorer = GetGenericScorer ( ) ;
474476 }
475477
476- protected override void SaveCore ( ModelSaveContext ctx )
478+ private protected override void SaveCore ( ModelSaveContext ctx )
477479 {
478480 Contracts . AssertValue ( ctx ) ;
479481 ctx . SetVersionInfo ( GetVersionInfo ( ) ) ;
@@ -515,7 +517,7 @@ internal RankingPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx
515517 Scorer = GetGenericScorer ( ) ;
516518 }
517519
518- protected override void SaveCore ( ModelSaveContext ctx )
520+ private protected override void SaveCore ( ModelSaveContext ctx )
519521 {
520522 Contracts . AssertValue ( ctx ) ;
521523 ctx . SetVersionInfo ( GetVersionInfo ( ) ) ;
@@ -567,7 +569,7 @@ internal ClusteringPredictionTransformer(IHostEnvironment env, ModelLoadContext
567569 Scorer = new ClusteringScorer ( Host , args , new EmptyDataView ( Host , TrainSchema ) , BindableMapper . Bind ( Host , schema ) , schema ) ;
568570 }
569571
570- protected override void SaveCore ( ModelSaveContext ctx )
572+ private protected override void SaveCore ( ModelSaveContext ctx )
571573 {
572574 Contracts . AssertValue ( ctx ) ;
573575 ctx . SetVersionInfo ( GetVersionInfo ( ) ) ;
@@ -594,47 +596,47 @@ internal static class BinaryPredictionTransformer
594596 {
595597 public const string LoaderSignature = "BinaryPredXfer" ;
596598
597- public static BinaryPredictionTransformer < object > Create ( IHostEnvironment env , ModelLoadContext ctx )
598- => new BinaryPredictionTransformer < object > ( env , ctx ) ;
599+ public static BinaryPredictionTransformer < IPredictorProducing < float > > Create ( IHostEnvironment env , ModelLoadContext ctx )
600+ => new BinaryPredictionTransformer < IPredictorProducing < float > > ( env , ctx ) ;
599601 }
600602
601603 internal static class MulticlassPredictionTransformer
602604 {
603605 public const string LoaderSignature = "MulticlassPredXfer" ;
604606
605- public static MulticlassPredictionTransformer < object > Create ( IHostEnvironment env , ModelLoadContext ctx )
606- => new MulticlassPredictionTransformer < object > ( env , ctx ) ;
607+ public static MulticlassPredictionTransformer < IPredictorProducing < VBuffer < float > > > Create ( IHostEnvironment env , ModelLoadContext ctx )
608+ => new MulticlassPredictionTransformer < IPredictorProducing < VBuffer < float > > > ( env , ctx ) ;
607609 }
608610
609611 internal static class RegressionPredictionTransformer
610612 {
611613 public const string LoaderSignature = "RegressionPredXfer" ;
612614
613- public static RegressionPredictionTransformer < object > Create ( IHostEnvironment env , ModelLoadContext ctx )
614- => new RegressionPredictionTransformer < object > ( env , ctx ) ;
615+ public static RegressionPredictionTransformer < IPredictorProducing < float > > Create ( IHostEnvironment env , ModelLoadContext ctx )
616+ => new RegressionPredictionTransformer < IPredictorProducing < float > > ( env , ctx ) ;
615617 }
616618
617619 internal static class RankingPredictionTransformer
618620 {
619621 public const string LoaderSignature = "RankingPredXfer" ;
620622
621- public static RankingPredictionTransformer < object > Create ( IHostEnvironment env , ModelLoadContext ctx )
622- => new RankingPredictionTransformer < object > ( env , ctx ) ;
623+ public static RankingPredictionTransformer < IPredictorProducing < float > > Create ( IHostEnvironment env , ModelLoadContext ctx )
624+ => new RankingPredictionTransformer < IPredictorProducing < float > > ( env , ctx ) ;
623625 }
624626
625627 internal static class AnomalyPredictionTransformer
626628 {
627629 public const string LoaderSignature = "AnomalyPredXfer" ;
628630
629- public static AnomalyPredictionTransformer < object > Create ( IHostEnvironment env , ModelLoadContext ctx )
630- => new AnomalyPredictionTransformer < object > ( env , ctx ) ;
631+ public static AnomalyPredictionTransformer < IPredictorProducing < float > > Create ( IHostEnvironment env , ModelLoadContext ctx )
632+ => new AnomalyPredictionTransformer < IPredictorProducing < float > > ( env , ctx ) ;
631633 }
632634
633635 internal static class ClusteringPredictionTransformer
634636 {
635637 public const string LoaderSignature = "ClusteringPredXfer" ;
636638
637- public static ClusteringPredictionTransformer < object > Create ( IHostEnvironment env , ModelLoadContext ctx )
638- => new ClusteringPredictionTransformer < object > ( env , ctx ) ;
639+ public static ClusteringPredictionTransformer < IPredictorProducing < VBuffer < float > > > Create ( IHostEnvironment env , ModelLoadContext ctx )
640+ => new ClusteringPredictionTransformer < IPredictorProducing < VBuffer < float > > > ( env , ctx ) ;
639641 }
640642}
0 commit comments