@@ -96,15 +96,17 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad
9696 ctx . LoadModelOrNull < ICalibrator , SignatureLoadModel > ( env , out calibrator , @"Calibrator" ) ;
9797 if ( calibrator == null )
9898 return predictor ;
99- return new SchemaBindableCalibratedPredictor ( env , predictor , calibrator ) ;
99+ return new SchemaBindableCalibratedModelParameters < FastTreeBinaryModelParameters , ICalibrator > ( env , predictor , calibrator ) ;
100100 }
101101
102102 public override PredictionKind PredictionKind => PredictionKind . BinaryClassification ;
103103 }
104104
105105 /// <include file = 'doc.xml' path='doc/members/member[@name="FastTree"]/*' />
106106 public sealed partial class FastTreeBinaryClassificationTrainer :
107- BoostingFastTreeTrainerBase < FastTreeBinaryClassificationTrainer . Options , BinaryPredictionTransformer < IPredictorWithFeatureWeights < float > > , IPredictorWithFeatureWeights < float > >
107+ BoostingFastTreeTrainerBase < FastTreeBinaryClassificationTrainer . Options ,
108+ BinaryPredictionTransformer < CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > > ,
109+ CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > >
108110 {
109111 /// <summary>
110112 /// The LoadName for the assembly containing the trainer.
@@ -156,7 +158,7 @@ internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options optio
156158
157159 public override PredictionKind PredictionKind => PredictionKind . BinaryClassification ;
158160
159- private protected override IPredictorWithFeatureWeights < float > TrainModelCore ( TrainContext context )
161+ private protected override CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > TrainModelCore ( TrainContext context )
160162 {
161163 Host . CheckValue ( context , nameof ( context ) ) ;
162164 var trainData = context . TrainingSet ;
@@ -185,7 +187,7 @@ private protected override IPredictorWithFeatureWeights<float> TrainModelCore(Tr
185187 // BinaryClassificationObjectiveFunction.GetGradientInOneQuery being consistent with the
186188 // description in section 6 of the paper.
187189 var cali = new PlattCalibrator ( Host , - 1 * _sigmoidParameter , 0 ) ;
188- return new FeatureWeightsCalibratedPredictor ( Host , pred , cali ) ;
190+ return new FeatureWeightsCalibratedModelParameters < FastTreeBinaryModelParameters , PlattCalibrator > ( Host , pred , cali ) ;
189191 }
190192
191193 protected override ObjectiveFunctionBase ConstructObjFunc ( IChannel ch )
@@ -273,10 +275,11 @@ protected override void InitializeTests()
273275 }
274276 }
275277
276- protected override BinaryPredictionTransformer < IPredictorWithFeatureWeights < float > > MakeTransformer ( IPredictorWithFeatureWeights < float > model , Schema trainSchema )
277- => new BinaryPredictionTransformer < IPredictorWithFeatureWeights < float > > ( Host , model , trainSchema , FeatureColumn . Name ) ;
278+ protected override BinaryPredictionTransformer < CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > > MakeTransformer (
279+ CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > model , Schema trainSchema )
280+ => new BinaryPredictionTransformer < CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > > ( Host , model , trainSchema , FeatureColumn . Name ) ;
278281
279- public BinaryPredictionTransformer < IPredictorWithFeatureWeights < float > > Train ( IDataView trainData , IDataView validationData = null )
282+ public BinaryPredictionTransformer < CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > > Train ( IDataView trainData , IDataView validationData = null )
280283 => TrainTransformer ( trainData , validationData ) ;
281284
282285 protected override SchemaShape . Column [ ] GetOutputColumnsCore ( SchemaShape inputSchema )
0 commit comments