|
10 | 10 |
|
11 | 11 | namespace Microsoft.ML |
12 | 12 | { |
13 | | - // REVIEW: Temporarly moving here since it is used by the Legacy project. Remove when removing the legacy project. |
14 | | - /// <summary> |
15 | | - /// A class that runs the previously trained model (and the preceding transform pipeline) on the |
16 | | - /// in-memory data in batch mode. |
17 | | - /// This can also be used with trained pipelines that do not end with a predictor: in this case, the |
18 | | - /// 'prediction' will be just the outcome of all the transformations. |
19 | | - /// </summary> |
20 | | - /// <typeparam name="TSrc">The user-defined type that holds the example.</typeparam> |
21 | | - /// <typeparam name="TDst">The user-defined type that holds the prediction.</typeparam> |
22 | | - [BestFriend] |
23 | | - internal sealed class BatchPredictionEngine<TSrc, TDst> |
24 | | - where TSrc : class |
25 | | - where TDst : class, new() |
26 | | - { |
27 | | - // The source data view. |
28 | | - private readonly DataViewConstructionUtils.StreamingDataView<TSrc> _srcDataView; |
29 | | - // The transformation engine. |
30 | | - private readonly PipeEngine<TDst> _pipeEngine; |
31 | | - |
32 | | - internal BatchPredictionEngine(IHostEnvironment env, Stream modelStream, bool ignoreMissingColumns, |
33 | | - SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) |
34 | | - { |
35 | | - Contracts.AssertValue(env); |
36 | | - Contracts.AssertValue(modelStream); |
37 | | - Contracts.AssertValueOrNull(inputSchemaDefinition); |
38 | | - Contracts.AssertValueOrNull(outputSchemaDefinition); |
39 | | - |
40 | | - // Initialize pipe. |
41 | | - _srcDataView = DataViewConstructionUtils.CreateFromEnumerable(env, new TSrc[] { }, inputSchemaDefinition); |
42 | | - var pipe = DataViewConstructionUtils.LoadPipeWithPredictor(env, modelStream, _srcDataView); |
43 | | - _pipeEngine = new PipeEngine<TDst>(env, pipe, ignoreMissingColumns, outputSchemaDefinition); |
44 | | - } |
45 | | - |
46 | | - internal BatchPredictionEngine(IHostEnvironment env, IDataView dataPipeline, bool ignoreMissingColumns, |
47 | | - SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) |
48 | | - { |
49 | | - Contracts.AssertValue(env); |
50 | | - Contracts.AssertValue(dataPipeline); |
51 | | - Contracts.AssertValueOrNull(inputSchemaDefinition); |
52 | | - Contracts.AssertValueOrNull(outputSchemaDefinition); |
53 | | - |
54 | | - // Initialize pipe. |
55 | | - _srcDataView = DataViewConstructionUtils.CreateFromEnumerable(env, new TSrc[] { }, inputSchemaDefinition); |
56 | | - var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, dataPipeline, _srcDataView); |
57 | | - |
58 | | - _pipeEngine = new PipeEngine<TDst>(env, pipe, ignoreMissingColumns, outputSchemaDefinition); |
59 | | - } |
60 | | - |
61 | | - /// <summary> |
62 | | - /// Run the prediction pipe. This will enumerate the <paramref name="examples"/> exactly once, |
63 | | - /// cache all the examples (by reference) into its internal representation and then run |
64 | | - /// the transformation pipe. |
65 | | - /// </summary> |
66 | | - /// <param name="examples">The examples to run the prediction on.</param> |
67 | | - /// <param name="reuseRowObjects">If <c>true</c>, the engine will not allocate memory per output, and |
68 | | - /// the returned <typeparamref name="TDst"/> objects will actually always be the same object. The user is |
69 | | - /// expected to clone the values himself if needed.</param> |
70 | | - /// <returns>The <see cref="IEnumerable{TDst}"/> that contains all the pipeline results.</returns> |
71 | | - public IEnumerable<TDst> Predict(IEnumerable<TSrc> examples, bool reuseRowObjects) |
72 | | - { |
73 | | - Contracts.CheckValue(examples, nameof(examples)); |
74 | | - |
75 | | - _pipeEngine.Reset(); |
76 | | - _srcDataView.SetData(examples); |
77 | | - return _pipeEngine.RunPipe(reuseRowObjects); |
78 | | - } |
79 | | - } |
80 | 13 |
|
81 | 14 | /// <summary> |
82 | 15 | /// Utility class to run the pipeline to completion and produce a strongly-typed IEnumerable as a result. |
|
0 commit comments