Skip to content

Commit 3401d7c

Browse files
authored
Create forecasting prediction engine and conform time series forecasting API to estimator standards. (#3910)
* forecast API. * refactor code. * refactor code. * fix. * fix. * fix. * fix. * fix. * Update documentation. * fix. * Update documentation. * Update tests. * clean up tests. * Update with PR review comment and entrypoint for forecasting API. * Make GrowthRatio accessible in nimbus ML. * update comments. * Add samples. * PR feedback. * Add TestEstimatorCore, samples and misc PR feedback. * update docs. * update entry points.
1 parent e0c4caa commit 3401d7c

27 files changed

+3171
-1760
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
### Input and Output Columns
2+
There is only one input column.
3+
The input column must be of type <xref:System.Single>, where each <xref:System.Single> value is the observation at one timestamp in the time series.
4+
5+
The transform produces either a single vector of forecasted values or three vectors: a vector of forecasted values, a vector of confidence lower bounds, and a vector of confidence upper bounds.

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectAnomalyBySrCnn.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public static void Example()
4040
ITransformer model = ml.Transforms.DetectAnomalyBySrCnn(outputColumnName, inputColumnName, 16, 5, 5, 3, 8, 0.35).Fit(dataView);
4141

4242
// Create a time series prediction engine from the model.
43-
var engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, SrCnnAnomalyDetection>(ml);
43+
var engine = model.CreateTimeSeriesEngine<TimeSeriesData, SrCnnAnomalyDetection>(ml);
4444

4545
Console.WriteLine($"{outputColumnName} column obtained post-transformation.");
4646
Console.WriteLine("Data\tAlert\tScore\tMag");

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectChangePointBySsa.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public static void Example()
5757
ITransformer model = ml.Transforms.DetectChangePointBySsa(outputColumnName, inputColumnName, confidence, changeHistoryLength, TrainingSize, SeasonalitySize + 1).Fit(dataView);
5858

5959
// Create a prediction engine from the model for feeding new data.
60-
var engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, ChangePointPrediction>(ml);
60+
var engine = model.CreateTimeSeriesEngine<TimeSeriesData, ChangePointPrediction>(ml);
6161

6262
// Start streaming new data points with no change point to the prediction engine.
6363
Console.WriteLine($"Output from ChangePoint predictions on new data:");
@@ -99,7 +99,7 @@ public static void Example()
9999
model = ml.Model.Load(file, out DataViewSchema schema);
100100

101101
// We must create a new prediction engine from the persisted model.
102-
engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, ChangePointPrediction>(ml);
102+
engine = model.CreateTimeSeriesEngine<TimeSeriesData, ChangePointPrediction>(ml);
103103

104104
// Run predictions on the loaded model.
105105
for (int i = 0; i < 5; i++)

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectChangePointBySsaStream.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public static void Example()
5757
ITransformer model = ml.Transforms.DetectChangePointBySsa(outputColumnName, inputColumnName, confidence, changeHistoryLength, TrainingSize, SeasonalitySize + 1).Fit(dataView);
5858

5959
// Create a prediction engine from the model for feeding new data.
60-
var engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, ChangePointPrediction>(ml);
60+
var engine = model.CreateTimeSeriesEngine<TimeSeriesData, ChangePointPrediction>(ml);
6161

6262
// Start streaming new data points with no change point to the prediction engine.
6363
Console.WriteLine($"Output from ChangePoint predictions on new data:");
@@ -103,7 +103,7 @@ public static void Example()
103103
model = ml.Model.Load(stream, out DataViewSchema schema);
104104

105105
// We must create a new prediction engine from the persisted model.
106-
engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, ChangePointPrediction>(ml);
106+
engine = model.CreateTimeSeriesEngine<TimeSeriesData, ChangePointPrediction>(ml);
107107

108108
// Run predictions on the loaded model.
109109
for (int i = 0; i < 5; i++)

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectIidChangePoint.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public static void Example()
5656
ITransformer model = ml.Transforms.DetectIidChangePoint(outputColumnName, inputColumnName, 95, Size / 4).Fit(dataView);
5757

5858
// Create a time series prediction engine from the model.
59-
var engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, ChangePointPrediction>(ml);
59+
var engine = model.CreateTimeSeriesEngine<TimeSeriesData, ChangePointPrediction>(ml);
6060

6161
Console.WriteLine($"{outputColumnName} column obtained post-transformation.");
6262
Console.WriteLine("Data\tAlert\tScore\tP-Value\tMartingale value");
@@ -97,7 +97,7 @@ public static void Example()
9797
model = ml.Model.Load(file, out DataViewSchema schema);
9898

9999
// Create a time series prediction engine from the checkpointed model.
100-
engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, ChangePointPrediction>(ml);
100+
engine = model.CreateTimeSeriesEngine<TimeSeriesData, ChangePointPrediction>(ml);
101101
for (int index = 0; index < 8; index++)
102102
{
103103
// Anomaly change point detection.

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectIidSpike.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public static void Example()
4848
ITransformer model = ml.Transforms.DetectIidSpike(outputColumnName, inputColumnName, 95, Size).Fit(dataView);
4949

5050
// Create a time series prediction engine from the model.
51-
var engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, IidSpikePrediction>(ml);
51+
var engine = model.CreateTimeSeriesEngine<TimeSeriesData, IidSpikePrediction>(ml);
5252

5353
Console.WriteLine($"{outputColumnName} column obtained post-transformation.");
5454
Console.WriteLine("Data\tAlert\tScore\tP-Value");

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectSpikeBySsa.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public static void Example()
5454
ITransformer model = ml.Transforms.DetectSpikeBySsa(outputColumnName, inputColumnName, 95, 8, TrainingSize, SeasonalitySize + 1).Fit(dataView);
5555

5656
// Create a prediction engine from the model for feeding new data.
57-
var engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, SsaSpikePrediction>(ml);
57+
var engine = model.CreateTimeSeriesEngine<TimeSeriesData, SsaSpikePrediction>(ml);
5858

5959
// Start streaming new data points with no change point to the prediction engine.
6060
Console.WriteLine($"Output from spike predictions on new data:");
@@ -94,7 +94,7 @@ public static void Example()
9494
model = ml.Model.Load(file, out DataViewSchema schema);
9595

9696
// We must create a new prediction engine from the persisted model.
97-
engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, SsaSpikePrediction>(ml);
97+
engine = model.CreateTimeSeriesEngine<TimeSeriesData, SsaSpikePrediction>(ml);
9898

9999
// Run predictions on the loaded model.
100100
for (int i = 0; i < 5; i++)

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/Forecasting.cs

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.IO;
34
using Microsoft.ML;
45
using Microsoft.ML.Transforms.TimeSeries;
5-
using Microsoft.ML.TimeSeries;
66

77
namespace Samples.Dynamic
88
{
@@ -16,8 +16,7 @@ public static void Example()
1616
// as well as the source of randomness.
1717
var ml = new MLContext();
1818

19-
// Generate sample series data with a recurring pattern
20-
const int SeasonalitySize = 5;
19+
// Generate sample series data with a recurring pattern.
2120
var data = new List<TimeSeriesData>()
2221
{
2322
new TimeSeriesData(0),
@@ -44,43 +43,58 @@ public static void Example()
4443

4544
// Setup arguments.
4645
var inputColumnName = nameof(TimeSeriesData.Value);
46+
var outputColumnName = nameof(ForecastResult.Forecast);
4747

4848
// Instantiate the forecasting model.
49-
var model = ml.Forecasting.AdaptiveSingularSpectrumSequenceModeler(inputColumnName, data.Count, SeasonalitySize + 1, SeasonalitySize,
50-
1, AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalitySize / 2, false, false);
49+
var model = ml.Forecasting.ForecastBySsa(outputColumnName, inputColumnName, 5, 11, data.Count, 5);
5150

5251
// Train.
53-
model.Train(dataView);
52+
var transformer = model.Fit(dataView);
5453

5554
// Forecast next five values.
56-
var forecast = model.Forecast(5);
55+
var forecastEngine = transformer.CreateTimeSeriesEngine<TimeSeriesData, ForecastResult>(ml);
56+
var forecast = forecastEngine.Predict();
57+
5758
Console.WriteLine($"Forecasted values:");
58-
Console.WriteLine("[{0}]", string.Join(", ", forecast));
59+
Console.WriteLine("[{0}]", string.Join(", ", forecast.Forecast));
5960
// Forecasted values:
60-
// [2.452744, 2.589339, 2.729183, 2.873005, 3.028931]
61+
// [1.977226, 1.020494, 1.760543, 3.437509, 4.266461]
6162

6263
// Update with new observations.
63-
dataView = ml.Data.LoadFromEnumerable(new List<TimeSeriesData>() { new TimeSeriesData(0), new TimeSeriesData(0), new TimeSeriesData(0), new TimeSeriesData(0) });
64-
model.Update(dataView);
64+
forecastEngine.Predict(new TimeSeriesData(0));
65+
forecastEngine.Predict(new TimeSeriesData(0));
66+
forecastEngine.Predict(new TimeSeriesData(0));
67+
forecastEngine.Predict(new TimeSeriesData(0));
6568

6669
// Checkpoint.
67-
ml.Model.SaveForecastingModel(model, "model.zip");
70+
forecastEngine.CheckPoint(ml, "model.zip");
6871

6972
// Load the checkpointed model from disk.
70-
var modelCopy = ml.Model.LoadForecastingModel<float>("model.zip");
73+
// Load the model.
74+
ITransformer modelCopy;
75+
using (var file = File.OpenRead("model.zip"))
76+
modelCopy = ml.Model.Load(file, out DataViewSchema schema);
77+
78+
// We must create a new prediction engine from the persisted model.
79+
var forecastEngineCopy = modelCopy.CreateTimeSeriesEngine<TimeSeriesData, ForecastResult>(ml);
7180

7281
// Forecast with the checkpointed model loaded from disk.
73-
forecast = modelCopy.Forecast(5);
74-
Console.WriteLine("[{0}]", string.Join(", ", forecast));
75-
// [0.8681176, 0.8185108, 0.8069275, 0.84405, 0.9455081]
82+
forecast = forecastEngineCopy.Predict();
83+
Console.WriteLine("[{0}]", string.Join(", ", forecast.Forecast));
84+
// [1.791331, 1.255525, 0.3060154, -0.200446, 0.5657795]
7685

7786
// Forecast with the original model(that was checkpointed to disk).
78-
forecast = model.Forecast(5);
79-
Console.WriteLine("[{0}]", string.Join(", ", forecast));
80-
// [0.8681176, 0.8185108, 0.8069275, 0.84405, 0.9455081]
87+
forecast = forecastEngine.Predict();
88+
Console.WriteLine("[{0}]", string.Join(", ", forecast.Forecast));
89+
// [1.791331, 1.255525, 0.3060154, -0.200446, 0.5657795]
8190

8291
}
8392

93+
class ForecastResult
94+
{
95+
public float[] Forecast { get; set; }
96+
}
97+
8498
class TimeSeriesData
8599
{
86100
public float Value;

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/ForecastingWithConfidenceInterval.cs

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using System.Collections.Generic;
33
using Microsoft.ML;
44
using Microsoft.ML.Transforms.TimeSeries;
5-
using Microsoft.ML.TimeSeries;
5+
using System.IO;
66

77
namespace Samples.Dynamic
88
{
@@ -16,8 +16,7 @@ public static void Example()
1616
// as well as the source of randomness.
1717
var ml = new MLContext();
1818

19-
// Generate sample series data with a recurring pattern
20-
const int SeasonalitySize = 5;
19+
// Generate sample series data with a recurring pattern.
2120
var data = new List<TimeSeriesData>()
2221
{
2322
new TimeSeriesData(0),
@@ -44,50 +43,58 @@ public static void Example()
4443

4544
// Setup arguments.
4645
var inputColumnName = nameof(TimeSeriesData.Value);
46+
var outputColumnName = nameof(ForecastResult.Forecast);
4747

48-
// Instantiate forecasting model.
49-
var model = ml.Forecasting.AdaptiveSingularSpectrumSequenceModeler(inputColumnName, data.Count, SeasonalitySize + 1, SeasonalitySize,
50-
1, AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalitySize / 2, shouldComputeForecastIntervals: true, false);
48+
// Instantiate the forecasting model.
49+
var model = ml.Forecasting.ForecastBySsa(outputColumnName, inputColumnName, 5, 11, data.Count, 5,
50+
confidenceLevel: 0.95f,
51+
forcastingConfidentLowerBoundColumnName: "ConfidenceLowerBound",
52+
forcastingConfidentUpperBoundColumnName: "ConfidenceUpperBound");
5153

5254
// Train.
53-
model.Train(dataView);
54-
55-
// Forecast next five values with confidence internal.
56-
float[] forecast;
57-
float[] confidenceIntervalLowerBounds;
58-
float[] confidenceIntervalUpperBounds;
59-
model.ForecastWithConfidenceIntervals(5, out forecast, out confidenceIntervalLowerBounds, out confidenceIntervalUpperBounds);
60-
PrintForecastValuesAndIntervals(forecast, confidenceIntervalLowerBounds, confidenceIntervalUpperBounds);
55+
var transformer = model.Fit(dataView);
56+
57+
// Forecast next five values.
58+
var forecastEngine = transformer.CreateTimeSeriesEngine<TimeSeriesData, ForecastResult>(ml);
59+
var forecast = forecastEngine.Predict();
60+
61+
PrintForecastValuesAndIntervals(forecast.Forecast, forecast.ConfidenceLowerBound, forecast.ConfidenceUpperBound);
6162
// Forecasted values:
62-
// [2.452744, 2.589339, 2.729183, 2.873005, 3.028931]
63+
// [1.977226, 1.020494, 1.760543, 3.437509, 4.266461]
6364
// Confidence intervals:
64-
// [-0.2235315 - 5.12902] [-0.08777174 - 5.266451] [0.05076938 - 5.407597] [0.1925406 - 5.553469] [0.3469928 - 5.71087]
65+
// [0.3451088 - 3.609343] [-0.7967533 - 2.83774] [-0.058467 - 3.579552] [1.61505 - 5.259968] [2.349299 - 6.183623]
6566

6667
// Update with new observations.
67-
dataView = ml.Data.LoadFromEnumerable(new List<TimeSeriesData>() { new TimeSeriesData(0), new TimeSeriesData(0), new TimeSeriesData(0), new TimeSeriesData(0) });
68-
model.Update(dataView);
68+
forecastEngine.Predict(new TimeSeriesData(0));
69+
forecastEngine.Predict(new TimeSeriesData(0));
70+
forecastEngine.Predict(new TimeSeriesData(0));
71+
forecastEngine.Predict(new TimeSeriesData(0));
6972

7073
// Checkpoint.
71-
ml.Model.SaveForecastingModel(model, "model.zip");
74+
forecastEngine.CheckPoint(ml, "model.zip");
7275

7376
// Load the checkpointed model from disk.
74-
var modelCopy = ml.Model.LoadForecastingModel<float>("model.zip");
77+
// Load the model.
78+
ITransformer modelCopy;
79+
using (var file = File.OpenRead("model.zip"))
80+
modelCopy = ml.Model.Load(file, out DataViewSchema schema);
81+
82+
// We must create a new prediction engine from the persisted model.
83+
var forecastEngineCopy = modelCopy.CreateTimeSeriesEngine<TimeSeriesData, ForecastResult>(ml);
7584

7685
// Forecast with the checkpointed model loaded from disk.
77-
modelCopy.ForecastWithConfidenceIntervals(5, out forecast, out confidenceIntervalLowerBounds, out confidenceIntervalUpperBounds);
78-
PrintForecastValuesAndIntervals(forecast, confidenceIntervalLowerBounds, confidenceIntervalUpperBounds);
79-
// Forecasted values:
80-
// [0.8681176, 0.8185108, 0.8069275, 0.84405, 0.9455081]
86+
forecast = forecastEngineCopy.Predict();
87+
PrintForecastValuesAndIntervals(forecast.Forecast, forecast.ConfidenceLowerBound, forecast.ConfidenceUpperBound);
88+
// [1.791331, 1.255525, 0.3060154, -0.200446, 0.5657795]
8189
// Confidence intervals:
82-
// [-1.808158 - 3.544394] [-1.8586 - 3.495622] [-1.871486 - 3.485341] [-1.836414 - 3.524514] [-1.736431 - 3.627447]
90+
// [0.1592142 - 3.423448] [-0.5617217 - 3.072772] [-1.512994 - 2.125025] [-2.022905 - 1.622013] [-1.351382 - 2.482941]
8391

8492
// Forecast with the original model (that was checkpointed to disk).
85-
model.ForecastWithConfidenceIntervals(5, out forecast, out confidenceIntervalLowerBounds, out confidenceIntervalUpperBounds);
86-
PrintForecastValuesAndIntervals(forecast, confidenceIntervalLowerBounds, confidenceIntervalUpperBounds);
87-
// Forecasted values:
88-
// [0.8681176, 0.8185108, 0.8069275, 0.84405, 0.9455081]
93+
forecast = forecastEngine.Predict();
94+
PrintForecastValuesAndIntervals(forecast.Forecast, forecast.ConfidenceLowerBound, forecast.ConfidenceUpperBound);
95+
// [1.791331, 1.255525, 0.3060154, -0.200446, 0.5657795]
8996
// Confidence intervals:
90-
// [-1.808158 - 3.544394] [-1.8586 - 3.495622] [-1.871486 - 3.485341] [-1.836414 - 3.524514] [-1.736431 - 3.627447]
97+
// [0.1592142 - 3.423448] [-0.5617217 - 3.072772] [-1.512994 - 2.125025] [-2.022905 - 1.622013] [-1.351382 - 2.482941]
9198
}
9299

93100
static void PrintForecastValuesAndIntervals(float[] forecast, float[] confidenceIntervalLowerBounds, float[] confidenceIntervalUpperBounds)
@@ -100,6 +107,13 @@ static void PrintForecastValuesAndIntervals(float[] forecast, float[] confidence
100107
Console.WriteLine();
101108
}
102109

110+
class ForecastResult
111+
{
112+
public float[] Forecast { get; set; }
113+
public float[] ConfidenceLowerBound { get; set; }
114+
public float[] ConfidenceUpperBound { get; set; }
115+
}
116+
103117
class TimeSeriesData
104118
{
105119
public float Value;

docs/samples/Microsoft.ML.Samples/Program.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Reflection;
3+
using Samples.Dynamic;
34

45
namespace Microsoft.ML.Samples
56
{

0 commit comments

Comments
 (0)