diff --git a/docs/README.md b/docs/README.md index ad01070e..12633350 100644 --- a/docs/README.md +++ b/docs/README.md @@ -12,4 +12,4 @@ Project Docs - [API](https://docs.microsoft.com/en-us/nimbusml/overview) - [Tutorials](https://docs.microsoft.com/en-us/nimbusml/tutorials) - [Developer Guide](developers/developer-guide.md) -- [Contributing to ML.NET](CONTRIBUTING.md) \ No newline at end of file +- [Contributing to ML.NET](CONTRIBUTING.md) diff --git a/src/DotNetBridge/Bridge.cs b/src/DotNetBridge/Bridge.cs index 26e5a84d..96100247 100644 --- a/src/DotNetBridge/Bridge.cs +++ b/src/DotNetBridge/Bridge.cs @@ -17,7 +17,7 @@ using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.LightGbm; using Microsoft.ML.Transforms; -using Microsoft.ML.TimeSeries; +using Microsoft.ML.Transforms.TimeSeries; namespace Microsoft.MachineLearning.DotNetBridge { @@ -329,7 +329,7 @@ private static unsafe int GenericExec(EnvironmentBlock* penv, sbyte* psz, int cd //env.ComponentCatalog.RegisterAssembly(typeof(SaveOnnxCommand).Assembly); //env.ComponentCatalog.RegisterAssembly(typeof(TimeSeriesProcessingEntryPoints).Assembly); //env.ComponentCatalog.RegisterAssembly(typeof(ParquetLoader).Assembly); - env.ComponentCatalog.RegisterAssembly(typeof(ForecastExtensions).Assembly); + env.ComponentCatalog.RegisterAssembly(typeof(SsaChangePointDetector).Assembly); using (var ch = host.Start("Executing")) { diff --git a/src/python/nimbusml.pyproj b/src/python/nimbusml.pyproj index d4b65307..159d5f43 100644 --- a/src/python/nimbusml.pyproj +++ b/src/python/nimbusml.pyproj @@ -89,6 +89,7 @@ + @@ -139,6 +140,7 @@ + @@ -235,6 +237,17 @@ + + + + + + + + + + + @@ -275,6 +288,7 @@ + @@ -580,8 +594,10 @@ + + @@ -590,6 +606,7 @@ + diff --git a/src/python/nimbusml/_pipeline.py b/src/python/nimbusml/_pipeline.py index 7237ef7a..74435d87 100644 --- a/src/python/nimbusml/_pipeline.py +++ b/src/python/nimbusml/_pipeline.py @@ -1839,7 +1839,7 @@ def predict_proba(self, X, verbose=0, **params): last_node 
= self.last_node last_node._check_implements_method('predict_proba') - scores = self.predict(X, verbose, **params) + scores, _ = self._predict(X, verbose=verbose, **params) # REVIEW: Consider adding an entry point that extracts the # probability column instead. @@ -1883,7 +1883,7 @@ def decision_function(self, X, verbose=0, **params): last_node = self.last_node last_node._check_implements_method('decision_function') - scores = self.predict(X, verbose, **params) + scores, _ = self._predict(X, verbose=verbose, **params) # REVIEW: Consider adding an entry point that extracts the score # column instead. diff --git a/src/python/nimbusml/examples/CharTokenizer.py b/src/python/nimbusml/examples/CharTokenizer.py index 8762ce6a..56c6ab04 100644 --- a/src/python/nimbusml/examples/CharTokenizer.py +++ b/src/python/nimbusml/examples/CharTokenizer.py @@ -1,25 +1,45 @@ ############################################################################### # CharTokenizer import numpy -from nimbusml import FileDataStream, DataSchema +from nimbusml import FileDataStream, DataSchema, Pipeline from nimbusml.datasets import get_dataset +from nimbusml.preprocessing import FromKey +from nimbusml.preprocessing.text import CharTokenizer +from nimbusml.preprocessing.schema import ColumnSelector +from nimbusml.feature_extraction.text import WordEmbedding # data input (as a FileDataStream) -path = get_dataset('infert').as_filepath() +path = get_dataset('wiki_detox_train').as_filepath() file_schema = DataSchema.read_schema( - path, sep=',', numeric_dtype=numpy.float32) + path, sep='\t', numeric_dtype=numpy.float32) data = FileDataStream(path, schema=file_schema) -print(data.schema) - -# Section below throws "System.Runtime.InteropServices.SEHException" -# Logged in https://github.com/Microsoft/NimbusML/issues/31 - -# # transform usage -# xf = CharTokenizer(columns={'id_1': 'id', 'edu_1': 'education'}) -# -# # fit and transform -# features = xf.fit_transform(data) -# -# # print features -# 
print(features.head()) +print(data.head()) + +# Sentiment SentimentText +# 0 1.0 ==RUDE== Dude, you are rude upload that carl p... +# 1 1.0 == OK! == IM GOING TO VANDALIZE WILD ONES WIK... +# 2 1.0 Stop trolling, zapatancas, calling me a liar m... +# 3 1.0 ==You're cool== You seem like a really cool g... +# 4 1.0 ::::: Why are you threatening me? I'm not bein... + +# CharTokenizer splits the text into a vector of characters stored as Key type. +# Use FromKey to convert the keys back to values before passing them to WordEmbedding. + +pipe = Pipeline([ + CharTokenizer(columns={'SentimentText_Transform': 'SentimentText'}), + FromKey(columns={'SentimentText_FromKey': 'SentimentText_Transform'}), + WordEmbedding(model_kind='GloVe50D', columns={'Feature': 'SentimentText_FromKey'}), + ColumnSelector(columns=['Sentiment', 'SentimentText', 'Feature']) + ]) + +print(pipe.fit_transform(data).head()) + +# Sentiment ... Feature.149 +# 0 1.0 ... 2.67440 +# 1 1.0 ... 0.78858 +# 2 1.0 ... 2.67440 +# 3 1.0 ... 2.67440 +# 4 1.0 ... 
2.67440 + +# [5 rows x 152 columns] diff --git a/src/python/nimbusml/examples/SsaForecaster.py b/src/python/nimbusml/examples/SsaForecaster.py new file mode 100644 index 00000000..da2bfa4a --- /dev/null +++ b/src/python/nimbusml/examples/SsaForecaster.py @@ -0,0 +1,49 @@ +############################################################################### +# SsaForecaster +import pandas as pd +from nimbusml import Pipeline, FileDataStream +from nimbusml.datasets import get_dataset +from nimbusml.timeseries import SsaForecaster + +# data input (as a FileDataStream) +path = get_dataset('timeseries').as_filepath() + +data = FileDataStream.read_csv(path) +print(data.head()) +# t1 t2 t3 +# 0 0.01 0.01 0.0100 +# 1 0.02 0.02 0.0200 +# 2 0.03 0.03 0.0200 +# 3 0.03 0.03 0.0250 +# 4 0.03 0.03 0.0005 + +# define the training pipeline +pipeline = Pipeline([ + SsaForecaster(forcasting_confident_lower_bound_column_name="cmin", + forcasting_confident_upper_bound_column_name="cmax", + series_length=6, + train_size=8, + window_size=3, + horizon=2, + # max_growth={'TimeSpan': 1, 'Growth': 1} + columns={'t2_fc': 't2'}) +]) + +result = pipeline.fit_transform(data) + +pd.set_option('display.float_format', lambda x: '%.2f' % x) +print(result) + +# Output +# t1 t2 t3 t2_fc.0 t2_fc.1 cmin.0 cmin.1 cmax.0 cmax.1 +# 0 0.01 0.01 0.01 0.10 0.12 0.09 0.11 0.11 0.13 +# 1 0.02 0.02 0.02 0.06 0.08 0.06 0.07 0.07 0.09 +# 2 0.03 0.03 0.02 0.04 0.05 0.03 0.04 0.05 0.07 +# 3 0.03 0.03 0.03 0.05 0.06 0.04 0.05 0.05 0.07 +# 4 0.03 0.03 0.00 0.05 0.07 0.04 0.05 0.06 0.08 +# 5 0.03 0.05 0.01 0.06 0.08 0.05 0.07 0.07 0.10 +# 6 0.05 0.07 0.05 0.09 0.12 0.08 0.10 0.10 0.13 +# 7 0.07 0.09 0.09 0.12 0.16 0.11 0.15 0.13 0.17 +# 8 0.09 99.00 99.00 57.92 82.88 57.91 82.87 57.93 82.89 +# 9 1.10 0.10 0.10 60.50 77.18 60.49 77.17 60.50 77.19 + diff --git a/src/python/nimbusml/examples/examples_from_dataframe/SsaForecaster_df.py b/src/python/nimbusml/examples/examples_from_dataframe/SsaForecaster_df.py new file mode 
100644 index 00000000..c5cf680d --- /dev/null +++ b/src/python/nimbusml/examples/examples_from_dataframe/SsaForecaster_df.py @@ -0,0 +1,88 @@ +############################################################################### +# SsaForecaster +import numpy as np +import pandas as pd +from nimbusml.timeseries import SsaForecaster + +# This example creates a time series (list of data with the +# i-th element corresponding to the i-th time slot). + +# Generate sample series data with a recurring pattern +seasonality_size = 5 +seasonal_data = np.arange(seasonality_size) + +data = np.tile(seasonal_data, 3) +X_train = pd.Series(data, name="ts") + +# X_train looks like this +# 0 0 +# 1 1 +# 2 2 +# 3 3 +# 4 4 +# 5 0 +# 6 1 +# 7 2 +# 8 3 +# 9 4 +# 10 0 +# 11 1 +# 12 2 +# 13 3 +# 14 4 + +x_test = X_train.copy() +x_test[-3:] = [100, 110, 120] + +# x_test looks like this +# 0 0 +# 1 1 +# 2 2 +# 3 3 +# 4 4 +# 5 0 +# 6 1 +# 7 2 +# 8 3 +# 9 4 +# 10 0 +# 11 1 +# 12 100 +# 13 110 +# 14 120 + +training_seasons = 3 +training_size = seasonality_size * training_seasons + +forecaster = SsaForecaster(forcasting_confident_lower_bound_column_name="cmin", + forcasting_confident_upper_bound_column_name="cmax", + series_length=8, + train_size=training_size, + window_size=seasonality_size + 1, + horizon=4, + # max_growth={'TimeSpan': 1, 'Growth': 1} + ) << {'fc': 'ts'} + +forecaster.fit(X_train, verbose=1) +data = forecaster.transform(x_test) + +pd.set_option('display.float_format', lambda x: '%.2f' % x) +print(data) + +# Output +# ts fc.0 fc.1 fc.2 fc.3 cmin.0 cmin.1 cmin.2 cmin.3 cmax.0 cmax.1 cmax.2 cmax.3 +# 0 0 1.00 2.00 3.00 4.00 1.00 2.00 3.00 4.00 1.00 2.00 3.00 4.00 +# 1 1 2.00 3.00 4.00 -0.00 2.00 3.00 4.00 -0.00 2.00 3.00 4.00 0.00 +# 2 2 3.00 4.00 -0.00 1.00 3.00 4.00 -0.00 1.00 3.00 4.00 0.00 1.00 +# 3 3 4.00 -0.00 1.00 2.00 4.00 -0.00 1.00 2.00 4.00 0.00 1.00 2.00 +# 4 4 -0.00 1.00 2.00 3.00 -0.00 1.00 2.00 3.00 0.00 1.00 2.00 3.00 +# 5 0 1.00 2.00 3.00 4.00 1.00 2.00 3.00 4.00 1.00 
2.00 3.00 4.00 +# 6 1 2.00 3.00 4.00 -0.00 2.00 3.00 4.00 -0.00 2.00 3.00 4.00 0.00 +# 7 2 3.00 4.00 -0.00 1.00 3.00 4.00 -0.00 1.00 3.00 4.00 0.00 1.00 +# 8 3 4.00 -0.00 1.00 2.00 4.00 -0.00 1.00 2.00 4.00 0.00 1.00 2.00 +# 9 4 -0.00 1.00 2.00 3.00 -0.00 1.00 2.00 3.00 0.00 1.00 2.00 3.00 +# 10 0 1.00 2.00 3.00 4.00 1.00 2.00 3.00 4.00 1.00 2.00 3.00 4.00 +# 11 1 2.00 3.00 4.00 -0.00 2.00 3.00 4.00 -0.00 2.00 3.00 4.00 0.00 +# 12 100 3.00 4.00 0.00 1.00 3.00 4.00 -0.00 1.00 3.00 4.00 0.00 1.00 +# 13 110 4.00 -0.00 1.00 75.50 4.00 -0.00 1.00 75.50 4.00 -0.00 1.00 75.50 +# 14 120 -0.00 1.00 83.67 83.25 -0.00 1.00 83.67 83.25 -0.00 1.00 83.67 83.25 diff --git a/src/python/nimbusml/internal/core/timeseries/_ssaforecaster.py b/src/python/nimbusml/internal/core/timeseries/_ssaforecaster.py new file mode 100644 index 00000000..164d631b --- /dev/null +++ b/src/python/nimbusml/internal/core/timeseries/_ssaforecaster.py @@ -0,0 +1,135 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------------------------- +# - Generated by tools/entrypoint_compiler.py: do not edit by hand +""" +SsaForecaster +""" + +__all__ = ["SsaForecaster"] + + +from ...entrypoints.timeseriesprocessingentrypoints_ssaforecasting import \ + timeseriesprocessingentrypoints_ssaforecasting +from ...utils.utils import trace +from ..base_pipeline_item import BasePipelineItem, DefaultSignature + + +class SsaForecaster(BasePipelineItem, DefaultSignature): + """ + **Description** + This transform forecasts using Singular Spectrum Analysis (SSA). + + :param window_size: The length of the window on the series for building the + trajectory matrix (parameter L). + + :param series_length: The length of series that is kept in buffer for + modeling (parameter N). 
+ + :param train_size: The length of series from the beginning used for + training. + + :param horizon: The number of values to forecast. + + :param confidence_level: The confidence level in [0, 1) for forecasting. + + :param forcasting_confident_lower_bound_column_name: The name of the + confidence interval lower bound column. + + :param forcasting_confident_upper_bound_column_name: The name of the + confidence interval upper bound column. + + :param rank_selection_method: The rank selection method. + + :param rank: The desired rank of the subspace used for SSA projection + (parameter r). This parameter should be in the range [1, + windowSize]. If set to null, the rank is automatically determined based + on prediction error minimization. + + :param max_rank: The maximum rank considered during the rank selection + process. If not provided (i.e. set to null), it is set to windowSize - + 1. + + :param should_stablize: The flag determining whether the model should be + stabilized. + + :param should_maintain_info: The flag determining whether the meta + information for the model needs to be maintained. + + :param max_growth: The maximum growth on the exponential trend. + + :param discount_factor: The discount factor in [0,1] used for online + updates. + + :param is_adaptive: The flag determining whether the model is adaptive. + + :param params: Additional arguments sent to compute engine. 
+ + """ + + @trace + def __init__( + self, + window_size=0, + series_length=0, + train_size=0, + horizon=0, + confidence_level=0.95, + forcasting_confident_lower_bound_column_name=None, + forcasting_confident_upper_bound_column_name=None, + rank_selection_method='Exact', + rank=None, + max_rank=None, + should_stablize=True, + should_maintain_info=False, + max_growth=None, + discount_factor=1.0, + is_adaptive=False, + **params): + BasePipelineItem.__init__( + self, type='transform', **params) + + self.window_size = window_size + self.series_length = series_length + self.train_size = train_size + self.horizon = horizon + self.confidence_level = confidence_level + self.forcasting_confident_lower_bound_column_name = forcasting_confident_lower_bound_column_name + self.forcasting_confident_upper_bound_column_name = forcasting_confident_upper_bound_column_name + self.rank_selection_method = rank_selection_method + self.rank = rank + self.max_rank = max_rank + self.should_stablize = should_stablize + self.should_maintain_info = should_maintain_info + self.max_growth = max_growth + self.discount_factor = discount_factor + self.is_adaptive = is_adaptive + + @property + def _entrypoint(self): + return timeseriesprocessingentrypoints_ssaforecasting + + @trace + def _get_node(self, **all_args): + algo_args = dict( + source=self.source, + name=self._name_or_source, + window_size=self.window_size, + series_length=self.series_length, + train_size=self.train_size, + horizon=self.horizon, + confidence_level=self.confidence_level, + forcasting_confident_lower_bound_column_name=self.forcasting_confident_lower_bound_column_name, + forcasting_confident_upper_bound_column_name=self.forcasting_confident_upper_bound_column_name, + rank_selection_method=self.rank_selection_method, + rank=self.rank, + max_rank=self.max_rank, + should_stablize=self.should_stablize, + should_maintain_info=self.should_maintain_info, + max_growth=self.max_growth, + discount_factor=self.discount_factor, + 
is_adaptive=self.is_adaptive) + + all_args.update(algo_args) + return self._entrypoint(**all_args) diff --git a/src/python/nimbusml/internal/entrypoints/timeseriesprocessingentrypoints_ssaforecasting.py b/src/python/nimbusml/internal/entrypoints/timeseriesprocessingentrypoints_ssaforecasting.py new file mode 100644 index 00000000..4d9cb05e --- /dev/null +++ b/src/python/nimbusml/internal/entrypoints/timeseriesprocessingentrypoints_ssaforecasting.py @@ -0,0 +1,206 @@ +# - Generated by tools/entrypoint_compiler.py: do not edit by hand +""" +TimeSeriesProcessingEntryPoints.SsaForecasting +""" + +import numbers + +from ..utils.entrypoints import EntryPoint +from ..utils.utils import try_set, unlist + + +def timeseriesprocessingentrypoints_ssaforecasting( + source, + data, + name, + output_data=None, + model=None, + window_size=0, + series_length=0, + train_size=0, + horizon=0, + confidence_level=0.95, + forcasting_confident_lower_bound_column_name=None, + forcasting_confident_upper_bound_column_name=None, + rank_selection_method='Exact', + rank=None, + max_rank=None, + should_stablize=True, + should_maintain_info=False, + max_growth=None, + discount_factor=1.0, + is_adaptive=False, + **params): + """ + **Description** + This transform forecasts using Singular Spectrum Analysis (SSA). + + :param source: The name of the source column. (inputs). + :param data: Input dataset (inputs). + :param name: The name of the new column. (inputs). + :param window_size: The length of the window on the series for + building the trajectory matrix (parameter L). (inputs). + :param series_length: The length of series that is kept in buffer + for modeling (parameter N). (inputs). + :param train_size: The length of series from the begining used + for training. (inputs). + :param horizon: The number of values to forecast. (inputs). + :param confidence_level: The confidence level in [0, 1) for + forecasting. (inputs). 
+ :param forcasting_confident_lower_bound_column_name: The name of + the confidence interval lower bound column. (inputs). + :param forcasting_confident_upper_bound_column_name: The name of + the confidence interval upper bound column. (inputs). + :param rank_selection_method: The rank selection method. + (inputs). + :param rank: The desired rank of the subspace used for SSA + projection (parameter r). This parameter should be in the + range [1, windowSize]. If set to null, the rank is + automatically determined based on prediction error + minimization. (inputs). + :param max_rank: The maximum rank considered during the rank + selection process. If not provided (i.e. set to null), it is + set to windowSize - 1. (inputs). + :param should_stablize: The flag determining whether the model + should be stabilized. (inputs). + :param should_maintain_info: The flag determining whether the + meta information for the model needs to be maintained. + (inputs). + :param max_growth: The maximum growth on the exponential trend. + (inputs). + :param discount_factor: The discount factor in [0,1] used for + online updates. (inputs). + :param is_adaptive: The flag determining whether the model is + adaptive (inputs). + :param output_data: Transformed dataset (outputs). + :param model: Transform model (outputs). 
+ """ + + entrypoint_name = 'TimeSeriesProcessingEntryPoints.SsaForecasting' + inputs = {} + outputs = {} + + if source is not None: + inputs['Source'] = try_set( + obj=source, + none_acceptable=False, + is_of_type=str, + is_column=True) + if data is not None: + inputs['Data'] = try_set( + obj=data, + none_acceptable=False, + is_of_type=str) + if name is not None: + inputs['Name'] = try_set( + obj=name, + none_acceptable=False, + is_of_type=str, + is_column=True) + if window_size is not None: + inputs['WindowSize'] = try_set( + obj=window_size, + none_acceptable=False, + is_of_type=numbers.Real) + if series_length is not None: + inputs['SeriesLength'] = try_set( + obj=series_length, + none_acceptable=False, + is_of_type=numbers.Real) + if train_size is not None: + inputs['TrainSize'] = try_set( + obj=train_size, + none_acceptable=False, + is_of_type=numbers.Real) + if horizon is not None: + inputs['Horizon'] = try_set( + obj=horizon, + none_acceptable=False, + is_of_type=numbers.Real) + if confidence_level is not None: + inputs['ConfidenceLevel'] = try_set( + obj=confidence_level, + none_acceptable=True, + is_of_type=numbers.Real) + if forcasting_confident_lower_bound_column_name is not None: + inputs['ForcastingConfidentLowerBoundColumnName'] = try_set( + obj=forcasting_confident_lower_bound_column_name, + none_acceptable=True, + is_of_type=str, + is_column=True) + if forcasting_confident_upper_bound_column_name is not None: + inputs['ForcastingConfidentUpperBoundColumnName'] = try_set( + obj=forcasting_confident_upper_bound_column_name, + none_acceptable=True, + is_of_type=str, + is_column=True) + if rank_selection_method is not None: + inputs['RankSelectionMethod'] = try_set( + obj=rank_selection_method, + none_acceptable=True, + is_of_type=str, + values=[ + 'Fixed', + 'Exact', + 'Fast']) + if rank is not None: + inputs['Rank'] = try_set( + obj=rank, + none_acceptable=True, + is_of_type=numbers.Real) + if max_rank is not None: + inputs['MaxRank'] = try_set( + 
obj=max_rank, + none_acceptable=True, + is_of_type=numbers.Real) + if should_stablize is not None: + inputs['ShouldStablize'] = try_set( + obj=should_stablize, + none_acceptable=True, + is_of_type=bool) + if should_maintain_info is not None: + inputs['ShouldMaintainInfo'] = try_set( + obj=should_maintain_info, + none_acceptable=True, + is_of_type=bool) + if max_growth is not None: + inputs['MaxGrowth'] = try_set( + obj=max_growth, + none_acceptable=True, + is_of_type=dict, + field_names=[ + 'TimeSpan', + 'Growth']) + if discount_factor is not None: + inputs['DiscountFactor'] = try_set( + obj=discount_factor, + none_acceptable=True, + is_of_type=numbers.Real) + if is_adaptive is not None: + inputs['IsAdaptive'] = try_set( + obj=is_adaptive, + none_acceptable=True, + is_of_type=bool) + if output_data is not None: + outputs['OutputData'] = try_set( + obj=output_data, + none_acceptable=False, + is_of_type=str) + if model is not None: + outputs['Model'] = try_set( + obj=model, + none_acceptable=False, + is_of_type=str) + + input_variables = { + x for x in unlist(inputs.values()) + if isinstance(x, str) and x.startswith("$")} + output_variables = { + x for x in unlist(outputs.values()) + if isinstance(x, str) and x.startswith("$")} + + entrypoint = EntryPoint( + name=entrypoint_name, inputs=inputs, outputs=outputs, + input_variables=input_variables, + output_variables=output_variables) + return entrypoint diff --git a/src/python/nimbusml/tests/pipeline/test_pipeline_subclassing.py b/src/python/nimbusml/tests/pipeline/test_pipeline_subclassing.py new file mode 100644 index 00000000..56d9ef42 --- /dev/null +++ b/src/python/nimbusml/tests/pipeline/test_pipeline_subclassing.py @@ -0,0 +1,73 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------------------------- +import os +import unittest + +import numpy as np +import pandas as pd +from nimbusml import Pipeline +from nimbusml.linear_model import LogisticRegressionBinaryClassifier + + +def generate_dataset_1(): + X = pd.DataFrame({'x1': [2, 3, 2, 2, 8, 9, 10, 8], + 'x2': [1, 2, 3, 1, 7, 10, 9, 8]}) + y = pd.DataFrame({'y': [1, 1, 1, 1, 0, 0, 0, 0]}) + return X, y + + +class CustomPipeline(Pipeline): + # Override the predict method + def predict(self, X, *args, **kwargs): + return kwargs.get('test_return_value') + + +class TestPipelineSubclassing(unittest.TestCase): + + def test_pipeline_subclass_can_override_predict(self): + X, y = generate_dataset_1() + + pipeline = Pipeline([LogisticRegressionBinaryClassifier()]) + pipeline.fit(X, y) + result = pipeline.predict(X)['PredictedLabel'] + + self.assertTrue(np.array_equal(result.values, y['y'].values)) + + pipeline = CustomPipeline([LogisticRegressionBinaryClassifier()]) + pipeline.fit(X, y) + + self.assertEqual(pipeline.predict(X, test_return_value=3), 3) + + + def test_pipeline_subclass_correctly_supports_predict_proba(self): + X, y = generate_dataset_1() + + pipeline = Pipeline([LogisticRegressionBinaryClassifier()]) + pipeline.fit(X, y) + orig_result = pipeline.predict_proba(X) + + pipeline = CustomPipeline([LogisticRegressionBinaryClassifier()]) + pipeline.fit(X, y) + new_result = pipeline.predict_proba(X) + + self.assertTrue(np.array_equal(orig_result, new_result)) + + + def test_pipeline_subclass_correctly_supports_decision_function(self): + X, y = generate_dataset_1() + + pipeline = Pipeline([LogisticRegressionBinaryClassifier()]) + pipeline.fit(X, y) + orig_result = pipeline.decision_function(X) + + pipeline = CustomPipeline([LogisticRegressionBinaryClassifier()]) + pipeline.fit(X, y) + new_result = pipeline.decision_function(X) + + self.assertTrue(np.array_equal(orig_result, new_result)) + + +if __name__ == '__main__': + 
unittest.main() diff --git a/src/python/nimbusml/tests/timeseries/test_ssaforecaster.py b/src/python/nimbusml/tests/timeseries/test_ssaforecaster.py new file mode 100644 index 00000000..ec2cb228 --- /dev/null +++ b/src/python/nimbusml/tests/timeseries/test_ssaforecaster.py @@ -0,0 +1,69 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------------------------- + +import unittest + +import numpy as np +import pandas as pd +from nimbusml import Pipeline, FileDataStream +from nimbusml.datasets import get_dataset +from nimbusml.timeseries import SsaForecaster + + +class TestSsaForecaster(unittest.TestCase): + + @unittest.skip('ml.net libraries containing timeseries forecasting are not included with nimbusml yet.') + def test_simple_forecast(self): + seasonality_size = 5 + seasonal_data = np.arange(seasonality_size) + + data = np.tile(seasonal_data, 3) + + X_train = pd.Series(data, name="ts") + + training_seasons = 3 + training_size = seasonality_size * training_seasons + + forecaster = SsaForecaster(forcasting_confident_lower_bound_column_name="cmin", + forcasting_confident_upper_bound_column_name="cmax", + series_length=8, + train_size=training_size, + window_size=seasonality_size + 1, + horizon=2) << {'fc': 'ts'} + + forecaster.fit(X_train, verbose=1) + data = forecaster.transform(X_train) + + self.assertEqual(round(data.loc[0, 'fc.0']), 1.0) + self.assertEqual(round(data.loc[0, 'fc.1']), 2.0) + + self.assertEqual(len(data['fc.0']), 15) + + @unittest.skip('ml.net libraries containing timeseries forecasting are not included with nimbusml yet.') + def test_multiple_user_specified_columns_is_not_allowed(self): + path = get_dataset('timeseries').as_filepath() + data = FileDataStream.read_csv(path) + + try: + pipeline = Pipeline([ + 
SsaForecaster(forcasting_confident_lower_bound_column_name="cmin", + forcasting_confident_upper_bound_column_name="cmax", + series_length=8, + train_size=15, + window_size=5, + horizon=2, + columns=['t2', 't3']) + ]) + pipeline.fit_transform(data) + + except RuntimeError as e: + self.assertTrue('Only one column is allowed' in str(e)) + return + + self.fail() + + +if __name__ == '__main__': + unittest.main() diff --git a/src/python/nimbusml/timeseries/__init__.py b/src/python/nimbusml/timeseries/__init__.py index 13db4520..626bcbc3 100644 --- a/src/python/nimbusml/timeseries/__init__.py +++ b/src/python/nimbusml/timeseries/__init__.py @@ -2,10 +2,12 @@ from ._iidchangepointdetector import IidChangePointDetector from ._ssaspikedetector import SsaSpikeDetector from ._ssachangepointdetector import SsaChangePointDetector +from ._ssaforecaster import SsaForecaster __all__ = [ 'IidSpikeDetector', 'IidChangePointDetector', 'SsaSpikeDetector', - 'SsaChangePointDetector' + 'SsaChangePointDetector', + 'SsaForecaster' ] diff --git a/src/python/nimbusml/timeseries/_ssaforecaster.py b/src/python/nimbusml/timeseries/_ssaforecaster.py new file mode 100644 index 00000000..6d0a52ce --- /dev/null +++ b/src/python/nimbusml/timeseries/_ssaforecaster.py @@ -0,0 +1,132 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------------------------- +# - Generated by tools/entrypoint_compiler.py: do not edit by hand +""" +SsaForecaster +""" + +__all__ = ["SsaForecaster"] + + +from sklearn.base import TransformerMixin + +from ..base_transform import BaseTransform +from ..internal.core.timeseries._ssaforecaster import SsaForecaster as core +from ..internal.utils.utils import trace + + +class SsaForecaster(core, BaseTransform, TransformerMixin): + """ + **Description** + This transform forecasts using Singular Spectrum Analysis (SSA). + + :param columns: see `Columns `_. + + :param window_size: The length of the window on the series for building the + trajectory matrix (parameter L). + + :param series_length: The length of series that is kept in buffer for + modeling (parameter N). + + :param train_size: The length of series from the begining used for + training. + + :param horizon: The number of values to forecast. + + :param confidence_level: The confidence level in [0, 1) for forecasting. + + :param forcasting_confident_lower_bound_column_name: The name of the + confidence interval lower bound column. + + :param forcasting_confident_upper_bound_column_name: The name of the + confidence interval upper bound column. + + :param rank_selection_method: The rank selection method. + + :param rank: The desired rank of the subspace used for SSA projection + (parameter r). This parameter should be in the range in [1, + windowSize]. If set to null, the rank is automatically determined based + on prediction error minimization. + + :param max_rank: The maximum rank considered during the rank selection + process. If not provided (i.e. set to null), it is set to windowSize - + 1. + + :param should_stablize: The flag determining whether the model should be + stabilized. + + :param should_maintain_info: The flag determining whether the meta + information for the model needs to be maintained. 
+ + :param max_growth: The maximum growth on the exponential trend. + + :param discount_factor: The discount factor in [0,1] used for online + updates. + + :param is_adaptive: The flag determing whether the model is adaptive. + + :param params: Additional arguments sent to compute engine. + + """ + + @trace + def __init__( + self, + window_size=0, + series_length=0, + train_size=0, + horizon=0, + confidence_level=0.95, + forcasting_confident_lower_bound_column_name=None, + forcasting_confident_upper_bound_column_name=None, + rank_selection_method='Exact', + rank=None, + max_rank=None, + should_stablize=True, + should_maintain_info=False, + max_growth=None, + discount_factor=1.0, + is_adaptive=False, + columns=None, + **params): + + if columns: + params['columns'] = columns + BaseTransform.__init__(self, **params) + core.__init__( + self, + window_size=window_size, + series_length=series_length, + train_size=train_size, + horizon=horizon, + confidence_level=confidence_level, + forcasting_confident_lower_bound_column_name=forcasting_confident_lower_bound_column_name, + forcasting_confident_upper_bound_column_name=forcasting_confident_upper_bound_column_name, + rank_selection_method=rank_selection_method, + rank=rank, + max_rank=max_rank, + should_stablize=should_stablize, + should_maintain_info=should_maintain_info, + max_growth=max_growth, + discount_factor=discount_factor, + is_adaptive=is_adaptive, + **params) + self._columns = columns + + def get_params(self, deep=False): + """ + Get the parameters for this operator. + """ + return core.get_params(self) + + def _nodes_with_presteps(self): + """ + Inserts preprocessing before this one. 
+ """ + from ..preprocessing.schema import TypeConverter + return [ + TypeConverter( + result_type='R4')._steal_io(self), + self] diff --git a/src/python/tests/test_estimator_checks.py b/src/python/tests/test_estimator_checks.py index 5dac16f5..d0e42a64 100644 --- a/src/python/tests/test_estimator_checks.py +++ b/src/python/tests/test_estimator_checks.py @@ -17,7 +17,8 @@ from nimbusml.preprocessing import TensorFlowScorer from nimbusml.preprocessing.filter import SkipFilter, TakeFilter from nimbusml.timeseries import (IidSpikeDetector, IidChangePointDetector, - SsaSpikeDetector, SsaChangePointDetector) + SsaSpikeDetector, SsaChangePointDetector, + SsaForecaster) from sklearn.utils.estimator_checks import _yield_all_checks, MULTI_OUTPUT this = os.path.abspath(os.path.dirname(__file__)) @@ -62,6 +63,8 @@ 'check_fit2d_1sample', # SSA requires more than one sample 'SsaChangePointDetector': 'check_estimator_sparse_data' 'check_fit2d_1sample', # SSA requires more than one sample + 'SsaForecaster': 'check_estimator_sparse_data' + 'check_fit2d_1sample', # SSA requires more than one sample # bug, low tolerance 'FastLinearRegressor': 'check_supervised_y_2d, ' 'check_regressor_data_not_an_array, ' @@ -193,6 +196,13 @@ 'IidChangePointDetector': IidChangePointDetector(columns=['F0']), 'SsaSpikeDetector': SsaSpikeDetector(columns=['F0'], seasonal_window_size=2), 'SsaChangePointDetector': SsaChangePointDetector(columns=['F0'], seasonal_window_size=2), + 'SsaForecaster': SsaForecaster(columns=['F0'], + window_size=2, + series_length=5, + train_size=5, + horizon=1, + forcasting_confident_lower_bound_column_name="cmin", + forcasting_confident_upper_bound_column_name="cmax"), 'TensorFlowScorer': TensorFlowScorer( model_location=os.path.join( this, @@ -270,6 +280,9 @@ def load_json(file_path): # skip SymSgdBinaryClassifier for now, because of crashes. 
if 'SymSgdBinaryClassifier' in class_name: continue + # skip for now because the ml.net binaries do not contain the SsaForecasting code. + if 'SsaForecaster' in class_name: + continue mod = __import__('nimbusml.' + e[0], fromlist=[str(class_name)]) the_class = getattr(mod, class_name) diff --git a/src/python/tests_extended/test_docs_example.py b/src/python/tests_extended/test_docs_example.py index 23fb2f82..27470667 100644 --- a/src/python/tests_extended/test_docs_example.py +++ b/src/python/tests_extended/test_docs_example.py @@ -65,13 +65,14 @@ def test_examples(self): 'SymSgdBinaryClassifier.py', 'SymSgdBinaryClassifier_infert_df.py', # MICROSOFTML_RESOURCE_PATH needs to be setup on linux + 'CharTokenizer.py', 'WordEmbedding.py', 'WordEmbedding_df.py', 'NaiveBayesClassifier_df.py' ]: continue # skip for ubuntu 14 tests - if platform.linux_distribution()[0] == 'Ubuntu' and platform.linux_distribution()[1][:2] == '14': + if platform.linux_distribution()[1] == 'jessie/sid': if name in [ # libdl needs to be setup 'Image.py', diff --git a/src/python/tools/compiler_utils.py b/src/python/tools/compiler_utils.py index 9a5e1e07..c64f5af3 100644 --- a/src/python/tools/compiler_utils.py +++ b/src/python/tools/compiler_utils.py @@ -135,6 +135,7 @@ def _nodes_with_presteps(self): 'IidChangePointDetector': timeseries_to_r4_converter, 'SsaSpikeDetector': timeseries_to_r4_converter, 'SsaChangePointDetector': timeseries_to_r4_converter, + 'SsaForecaster': timeseries_to_r4_converter, 'PcaTransformer': '''from ..preprocessing.schema import TypeConverter diff --git a/src/python/tools/manifest.json b/src/python/tools/manifest.json index 67951c74..b79dea53 100644 --- a/src/python/tools/manifest.json +++ b/src/python/tools/manifest.json @@ -4007,6 +4007,235 @@ "ITransformOutput" ] }, + { + "Name": "TimeSeriesProcessingEntryPoints.SsaForecasting", + "Desc": "This transform forecasts using Singular Spectrum Analysis (SSA).", + "FriendlyName": "SSA Forecasting", + "ShortName": "ssafcst", 
+ "Inputs": [ + { + "Name": "Source", + "Type": "String", + "Desc": "The name of the source column.", + "Aliases": [ + "src" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Name", + "Type": "String", + "Desc": "The name of the new column.", + "Required": true, + "SortOrder": 2.0, + "IsNullable": false + }, + { + "Name": "WindowSize", + "Type": "Int", + "Desc": "The length of the window on the series for building the trajectory matrix (parameter L).", + "Required": true, + "SortOrder": 2.0, + "IsNullable": false, + "Default": 0 + }, + { + "Name": "SeriesLength", + "Type": "Int", + "Desc": "The length of series that is kept in buffer for modeling (parameter N).", + "Required": true, + "SortOrder": 2.0, + "IsNullable": false, + "Default": 0 + }, + { + "Name": "TrainSize", + "Type": "Int", + "Desc": "The length of series from the beginning used for training.", + "Required": true, + "SortOrder": 2.0, + "IsNullable": false, + "Default": 0 + }, + { + "Name": "Horizon", + "Type": "Int", + "Desc": "The number of values to forecast.", + "Required": true, + "SortOrder": 2.0, + "IsNullable": false, + "Default": 0 + }, + { + "Name": "ConfidenceLevel", + "Type": "Float", + "Desc": "The confidence level in [0, 1) for forecasting.", + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": 0.95 + }, + { + "Name": "ForcastingConfidentLowerBoundColumnName", + "Type": "String", + "Desc": "The name of the confidence interval lower bound column.", + "Aliases": [ + "cnfminname" + ], + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "ForcastingConfidentUpperBoundColumnName", + "Type": "String", + "Desc": "The name of the confidence interval upper bound column.", + "Aliases": [ + "cnfmaxnname" + ], + "Required": false, + "SortOrder": 3.0, + 
"IsNullable": false, + "Default": null + }, + { + "Name": "RankSelectionMethod", + "Type": { + "Kind": "Enum", + "Values": [ + "Fixed", + "Exact", + "Fast" + ] + }, + "Desc": "The rank selection method.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": "Exact" + }, + { + "Name": "Rank", + "Type": "Int", + "Desc": "The desired rank of the subspace used for SSA projection (parameter r). This parameter should be in the range in [1, windowSize]. If set to null, the rank is automatically determined based on prediction error minimization.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "MaxRank", + "Type": "Int", + "Desc": "The maximum rank considered during the rank selection process. If not provided (i.e. set to null), it is set to windowSize - 1.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "ShouldStablize", + "Type": "Bool", + "Desc": "The flag determining whether the model should be stabilized.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": true + }, + { + "Name": "ShouldMaintainInfo", + "Type": "Bool", + "Desc": "The flag determining whether the meta information for the model needs to be maintained.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": false + }, + { + "Name": "MaxGrowth", + "Type": { + "Kind": "Struct", + "Fields": [ + { + "Name": "TimeSpan", + "Type": "Int", + "Desc": "Time span of growth ratio. Must be strictly positive.", + "Required": false, + "SortOrder": 1.0, + "IsNullable": false, + "Default": 0 + }, + { + "Name": "Growth", + "Type": "Float", + "Desc": "Growth. 
Must be non-negative.", + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": 0.0 + } + ] + }, + "Desc": "The maximum growth on the exponential trend.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "DiscountFactor", + "Type": "Float", + "Desc": "The discount factor in [0,1] used for online updates.", + "Aliases": [ + "disc" + ], + "Required": false, + "SortOrder": 5.0, + "IsNullable": false, + "Default": 1.0 + }, + { + "Name": "IsAdaptive", + "Type": "Bool", + "Desc": "The flag determining whether the model is adaptive", + "Aliases": [ + "adp" + ], + "Required": false, + "SortOrder": 6.0, + "IsNullable": false, + "Default": false + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, { "Name": "TimeSeriesProcessingEntryPoints.SsaSpikeDetector", "Desc": "This transform detects the spikes in a seasonal time-series using Singular Spectrum Analysis (SSA).", diff --git a/src/python/tools/manifest_diff.json b/src/python/tools/manifest_diff.json index 25708e21..58d6b3a5 100644 --- a/src/python/tools/manifest_diff.json +++ b/src/python/tools/manifest_diff.json @@ -563,6 +563,12 @@ "Module": "timeseries", "Type": "Transform" }, + { + "Name": "TimeSeriesProcessingEntryPoints.SsaForecasting", + "NewName": "SsaForecaster", + "Module": "timeseries", + "Type": "Transform" + }, { "Name": "Trainers.PoissonRegressor", "NewName": "PoissonRegressionRegressor",