88using System . Collections . Immutable ;
99using System . IO ;
1010using System . Linq ;
11+ using System . Reflection ;
1112using Microsoft . ML ;
1213using Microsoft . ML . Calibrators ;
1314using Microsoft . ML . CommandLine ;
@@ -396,6 +397,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
396397 }
397398
398399 [ BestFriend ]
400+ [ PredictionTransformerLoadType ( typeof ( CalibratedModelParametersBase < , > ) ) ]
399401 internal sealed class ValueMapperCalibratedModelParameters < TSubModel , TCalibrator > :
400402 ValueMapperCalibratedModelParametersBase < TSubModel , TCalibrator > , ICanSaveModel
401403 where TSubModel : class
@@ -430,8 +432,8 @@ private static VersionInfo GetVersionInfoBulk()
430432 loaderAssemblyName : typeof ( ValueMapperCalibratedModelParameters < TSubModel , TCalibrator > ) . Assembly . FullName ) ;
431433 }
432434
433- private ValueMapperCalibratedModelParameters ( IHostEnvironment env , ModelLoadContext ctx )
434- : base ( env , RegistrationName , GetPredictor ( env , ctx ) , GetCalibrator ( env , ctx ) )
435+ private ValueMapperCalibratedModelParameters ( IHostEnvironment env , ModelLoadContext ctx , TSubModel predictor , TCalibrator calibrator )
436+ : base ( env , RegistrationName , predictor , calibrator )
435437 {
436438 }
437439
@@ -443,7 +445,16 @@ private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelL
443445 var ver2 = GetVersionInfoBulk ( ) ;
444446 var ver = ctx . Header . ModelSignature == ver2 . ModelSignature ? ver2 : ver1 ;
445447 ctx . CheckAtModel ( ver ) ;
446- return new ValueMapperCalibratedModelParameters < TSubModel , TCalibrator > ( env , ctx ) ;
448+
449+ // Load first the predictor and calibrator
450+ var predictor = GetPredictor ( env , ctx ) ;
451+ var calibrator = GetCalibrator ( env , ctx ) ;
452+
453+ // Create a generic type using the correct parameter types of predictor and calibrator
454+ Type genericType = typeof ( ValueMapperCalibratedModelParameters < , > ) ;
455+ var genericInstance = CreateCalibratedModelParameters . Create ( env , ctx , predictor , calibrator , genericType ) ;
456+
457+ return ( CalibratedModelParametersBase ) genericInstance ;
447458 }
448459
449460 void ICanSaveModel . Save ( ModelSaveContext ctx )
@@ -456,6 +467,7 @@ void ICanSaveModel.Save(ModelSaveContext ctx)
456467 }
457468
458469 [ BestFriend ]
470+ [ PredictionTransformerLoadType ( typeof ( CalibratedModelParametersBase < , > ) ) ]
459471 internal sealed class FeatureWeightsCalibratedModelParameters < TSubModel , TCalibrator > :
460472 ValueMapperCalibratedModelParametersBase < TSubModel , TCalibrator > ,
461473 IPredictorWithFeatureWeights < float > ,
@@ -487,8 +499,9 @@ private static VersionInfo GetVersionInfo()
487499 loaderAssemblyName : typeof ( FeatureWeightsCalibratedModelParameters < TSubModel , TCalibrator > ) . Assembly . FullName ) ;
488500 }
489501
490- private FeatureWeightsCalibratedModelParameters ( IHostEnvironment env , ModelLoadContext ctx )
491- : base ( env , RegistrationName , GetPredictor ( env , ctx ) , GetCalibrator ( env , ctx ) )
502+ private FeatureWeightsCalibratedModelParameters ( IHostEnvironment env , ModelLoadContext ctx ,
503+ TSubModel predictor , TCalibrator calibrator )
504+ : base ( env , RegistrationName , predictor , calibrator )
492505 {
493506 Host . Check ( SubModel is IPredictorWithFeatureWeights < float > , "Predictor does not implement " + nameof ( IPredictorWithFeatureWeights < float > ) ) ;
494507 _featureWeights = ( IPredictorWithFeatureWeights < float > ) SubModel ;
@@ -499,7 +512,16 @@ private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelL
499512 Contracts . CheckValue ( env , nameof ( env ) ) ;
500513 env . CheckValue ( ctx , nameof ( ctx ) ) ;
501514 ctx . CheckAtModel ( GetVersionInfo ( ) ) ;
502- return new FeatureWeightsCalibratedModelParameters < TSubModel , TCalibrator > ( env , ctx ) ;
515+
516+ // Load first the predictor and calibrator
517+ var predictor = GetPredictor ( env , ctx ) ;
518+ var calibrator = GetCalibrator ( env , ctx ) ;
519+
520+ // Create a generic type using the correct parameter types of predictor and calibrator
521+ Type genericType = typeof ( FeatureWeightsCalibratedModelParameters < , > ) ;
522+ var genericInstance = CreateCalibratedModelParameters . Create ( env , ctx , predictor , calibrator , genericType ) ;
523+
524+ return ( CalibratedModelParametersBase ) genericInstance ;
503525 }
504526
505527 void ICanSaveModel . Save ( ModelSaveContext ctx )
@@ -520,6 +542,7 @@ public void GetFeatureWeights(ref VBuffer<float> weights)
520542 /// Encapsulates a predictor and a calibrator that implement <see cref="IParameterMixer"/>.
521543 /// Its implementation of <see cref="IParameterMixer.CombineParameters"/> combines both the predictors and the calibrators.
522544 /// </summary>
545+ [ PredictionTransformerLoadType ( typeof ( CalibratedModelParametersBase < , > ) ) ]
523546 internal sealed class ParameterMixingCalibratedModelParameters < TSubModel , TCalibrator > :
524547 ValueMapperCalibratedModelParametersBase < TSubModel , TCalibrator > ,
525548 IParameterMixer < float > ,
@@ -553,8 +576,8 @@ private static VersionInfo GetVersionInfo()
553576 loaderAssemblyName : typeof ( ParameterMixingCalibratedModelParameters < TSubModel , TCalibrator > ) . Assembly . FullName ) ;
554577 }
555578
556- private ParameterMixingCalibratedModelParameters ( IHostEnvironment env , ModelLoadContext ctx )
557- : base ( env , RegistrationName , GetPredictor ( env , ctx ) , GetCalibrator ( env , ctx ) )
579+ private ParameterMixingCalibratedModelParameters ( IHostEnvironment env , ModelLoadContext ctx , TSubModel predictor , TCalibrator calibrator )
580+ : base ( env , RegistrationName , predictor , calibrator )
558581 {
559582 Host . Check ( SubModel is IParameterMixer < float > , "Predictor does not implement " + nameof ( IParameterMixer ) ) ;
560583 Host . Check ( SubModel is IPredictorWithFeatureWeights < float > , "Predictor does not implement " + nameof ( IPredictorWithFeatureWeights < float > ) ) ;
@@ -566,7 +589,16 @@ private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelL
566589 Contracts . CheckValue ( env , nameof ( env ) ) ;
567590 env . CheckValue ( ctx , nameof ( ctx ) ) ;
568591 ctx . CheckAtModel ( GetVersionInfo ( ) ) ;
569- return new ParameterMixingCalibratedModelParameters < TSubModel , TCalibrator > ( env , ctx ) ;
592+
593+ // Load first the predictor and calibrator
594+ var predictor = GetPredictor ( env , ctx ) ;
595+ var calibrator = GetCalibrator ( env , ctx ) ;
596+
597+ // Create a generic type using the correct parameter types of predictor and calibrator
598+ Type genericType = typeof ( ParameterMixingCalibratedModelParameters < , > ) ;
599+ object genericInstance = CreateCalibratedModelParameters . Create ( env , ctx , predictor , calibrator , genericType ) ;
600+
601+ return ( CalibratedModelParametersBase ) genericInstance ;
570602 }
571603
572604 void ICanSaveModel . Save ( ModelSaveContext ctx )
@@ -777,6 +809,28 @@ ValueMapper<TSrc, VBuffer<float>> IFeatureContributionMapper.GetFeatureContribut
777809 }
778810 }
779811
812+ internal static class CreateCalibratedModelParameters
813+ {
814+ internal static object Create ( IHostEnvironment env , ModelLoadContext ctx , object predictor , ICalibrator calibrator , Type calibratedModelParametersType )
815+ {
816+ Type [ ] genericTypeArgs = { predictor . GetType ( ) , calibrator . GetType ( ) } ;
817+ Type constructed = calibratedModelParametersType . MakeGenericType ( genericTypeArgs ) ;
818+
819+ Type [ ] constructorArgs = {
820+ typeof ( IHostEnvironment ) ,
821+ typeof ( ModelLoadContext ) ,
822+ predictor . GetType ( ) ,
823+ calibrator . GetType ( )
824+ } ;
825+
826+ // Call the appropiate constructor of the created generic type passing on the previously loaded predictor and calibrator
827+ var genericCtor = constructed . GetConstructor ( BindingFlags . NonPublic | BindingFlags . Instance , null , constructorArgs , null ) ;
828+ object genericInstance = genericCtor . Invoke ( new object [ ] { env , ctx , predictor , calibrator } ) ;
829+
830+ return genericInstance ;
831+ }
832+ }
833+
780834 [ BestFriend ]
781835 internal static class CalibratorUtils
782836 {
0 commit comments