22using System . Collections . Generic ;
33using Microsoft . ML ;
44using Microsoft . ML . Transforms . TimeSeries ;
5- using Microsoft . ML . TimeSeries ;
5+ using System . IO ;
66
77namespace Samples . Dynamic
88{
@@ -16,8 +16,7 @@ public static void Example()
1616 // as well as the source of randomness.
1717 var ml = new MLContext ( ) ;
1818
19- // Generate sample series data with a recurring pattern
20- const int SeasonalitySize = 5 ;
19+ // Generate sample series data with a recurring pattern.
2120 var data = new List < TimeSeriesData > ( )
2221 {
2322 new TimeSeriesData ( 0 ) ,
@@ -44,50 +43,58 @@ public static void Example()
4443
4544 // Setup arguments.
4645 var inputColumnName = nameof ( TimeSeriesData . Value ) ;
46+ var outputColumnName = nameof ( ForecastResult . Forecast ) ;
4747
48- // Instantiate forecasting model.
49- var model = ml . Forecasting . AdaptiveSingularSpectrumSequenceModeler ( inputColumnName , data . Count , SeasonalitySize + 1 , SeasonalitySize ,
50- 1 , AdaptiveSingularSpectrumSequenceModeler . RankSelectionMethod . Exact , null , SeasonalitySize / 2 , shouldComputeForecastIntervals : true , false ) ;
48+ // Instantiate the forecasting model.
49+ var model = ml . Forecasting . ForecastBySsa ( outputColumnName , inputColumnName , 5 , 11 , data . Count , 5 ,
50+ confidenceLevel : 0.95f ,
51+ forcastingConfidentLowerBoundColumnName : "ConfidenceLowerBound" ,
52+ forcastingConfidentUpperBoundColumnName : "ConfidenceUpperBound" ) ;
5153
5254 // Train.
53- model . Train ( dataView ) ;
54-
55- // Forecast next five values with confidence internal.
56- float [ ] forecast ;
57- float [ ] confidenceIntervalLowerBounds ;
58- float [ ] confidenceIntervalUpperBounds ;
59- model . ForecastWithConfidenceIntervals ( 5 , out forecast , out confidenceIntervalLowerBounds , out confidenceIntervalUpperBounds ) ;
60- PrintForecastValuesAndIntervals ( forecast , confidenceIntervalLowerBounds , confidenceIntervalUpperBounds ) ;
55+ var transformer = model . Fit ( dataView ) ;
56+
57+ // Forecast next five values.
58+ var forecastEngine = transformer . CreateTimeSeriesEngine < TimeSeriesData , ForecastResult > ( ml ) ;
59+ var forecast = forecastEngine . Predict ( ) ;
60+
61+ PrintForecastValuesAndIntervals ( forecast . Forecast , forecast . ConfidenceLowerBound , forecast . ConfidenceUpperBound ) ;
6162 // Forecasted values:
62- // [2.452744, 2.589339, 2.729183, 2.873005, 3.028931 ]
63+ // [1.977226, 1.020494, 1.760543, 3.437509, 4.266461 ]
6364 // Confidence intervals:
64- // [-0.2235315 - 5.12902 ] [-0.08777174 - 5.266451 ] [0.05076938 - 5.407597 ] [0.1925406 - 5.553469 ] [0.3469928 - 5.71087 ]
65+ // [0.3451088 - 3.609343 ] [-0.7967533 - 2.83774 ] [-0.058467 - 3.579552 ] [1.61505 - 5.259968 ] [2.349299 - 6.183623 ]
6566
6667 // Update with new observations.
67- dataView = ml . Data . LoadFromEnumerable ( new List < TimeSeriesData > ( ) { new TimeSeriesData ( 0 ) , new TimeSeriesData ( 0 ) , new TimeSeriesData ( 0 ) , new TimeSeriesData ( 0 ) } ) ;
68- model . Update ( dataView ) ;
68+ forecastEngine . Predict ( new TimeSeriesData ( 0 ) ) ;
69+ forecastEngine . Predict ( new TimeSeriesData ( 0 ) ) ;
70+ forecastEngine . Predict ( new TimeSeriesData ( 0 ) ) ;
71+ forecastEngine . Predict ( new TimeSeriesData ( 0 ) ) ;
6972
7073 // Checkpoint.
71- ml . Model . SaveForecastingModel ( model , "model.zip" ) ;
74+ forecastEngine . CheckPoint ( ml , "model.zip" ) ;
7275
7376 // Load the checkpointed model from disk.
74- var modelCopy = ml . Model . LoadForecastingModel < float > ( "model.zip" ) ;
77+ // Load the model.
78+ ITransformer modelCopy ;
79+ using ( var file = File . OpenRead ( "model.zip" ) )
80+ modelCopy = ml . Model . Load ( file , out DataViewSchema schema ) ;
81+
82+ // We must create a new prediction engine from the persisted model.
83+ var forecastEngineCopy = modelCopy . CreateTimeSeriesEngine < TimeSeriesData , ForecastResult > ( ml ) ;
7584
7685 // Forecast with the checkpointed model loaded from disk.
77- modelCopy . ForecastWithConfidenceIntervals ( 5 , out forecast , out confidenceIntervalLowerBounds , out confidenceIntervalUpperBounds ) ;
78- PrintForecastValuesAndIntervals ( forecast , confidenceIntervalLowerBounds , confidenceIntervalUpperBounds ) ;
79- // Forecasted values:
80- // [0.8681176, 0.8185108, 0.8069275, 0.84405, 0.9455081]
86+ forecast = forecastEngineCopy . Predict ( ) ;
87+ PrintForecastValuesAndIntervals ( forecast . Forecast , forecast . ConfidenceLowerBound , forecast . ConfidenceUpperBound ) ;
88+ // [1.791331, 1.255525, 0.3060154, -0.200446, 0.5657795]
8189 // Confidence intervals:
82- // [-1.808158 - 3.544394 ] [-1.8586 - 3.495622 ] [-1.871486 - 3.485341 ] [-1.836414 - 3.524514 ] [-1.736431 - 3.627447 ]
90+ // [0.1592142 - 3.423448 ] [-0.5617217 - 3.072772 ] [-1.512994 - 2.125025 ] [-2.022905 - 1.622013 ] [-1.351382 - 2.482941 ]
8391
8492 // Forecast with the original model(that was checkpointed to disk).
85- model . ForecastWithConfidenceIntervals ( 5 , out forecast , out confidenceIntervalLowerBounds , out confidenceIntervalUpperBounds ) ;
86- PrintForecastValuesAndIntervals ( forecast , confidenceIntervalLowerBounds , confidenceIntervalUpperBounds ) ;
87- // Forecasted values:
88- // [0.8681176, 0.8185108, 0.8069275, 0.84405, 0.9455081]
93+ forecast = forecastEngine . Predict ( ) ;
94+ PrintForecastValuesAndIntervals ( forecast . Forecast , forecast . ConfidenceLowerBound , forecast . ConfidenceUpperBound ) ;
95+ // [1.791331, 1.255525, 0.3060154, -0.200446, 0.5657795]
8996 // Confidence intervals:
90- // [-1.808158 - 3.544394 ] [-1.8586 - 3.495622 ] [-1.871486 - 3.485341 ] [-1.836414 - 3.524514 ] [-1.736431 - 3.627447 ]
97+ // [0.1592142 - 3.423448 ] [-0.5617217 - 3.072772 ] [-1.512994 - 2.125025 ] [-2.022905 - 1.622013 ] [-1.351382 - 2.482941 ]
9198 }
9299
93100 static void PrintForecastValuesAndIntervals ( float [ ] forecast , float [ ] confidenceIntervalLowerBounds , float [ ] confidenceIntervalUpperBounds )
@@ -100,6 +107,13 @@ static void PrintForecastValuesAndIntervals(float[] forecast, float[] confidence
100107 Console . WriteLine ( ) ;
101108 }
102109
110+ class ForecastResult
111+ {
112+ public float [ ] Forecast { get ; set ; }
113+ public float [ ] ConfidenceLowerBound { get ; set ; }
114+ public float [ ] ConfidenceUpperBound { get ; set ; }
115+ }
116+
103117 class TimeSeriesData
104118 {
105119 public float Value ;
0 commit comments