Skip to content

Commit 1e26350

Browse files
authored
Add Initialize() method to TensorFlow and ImageAnalytics (#795)
* Add Initialize() method to TensorFlow and ImageAnalytics * Trigger a build
1 parent 6e0d8d0 commit 1e26350

File tree

5 files changed

+29
-4
lines changed

5 files changed

+29
-4
lines changed

src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@ namespace Microsoft.ML.Runtime.ImageAnalytics.EntryPoints
1111
{
1212
public static class ImageAnalytics
1313
{
14+
// This method is needed for the Pipeline API, since ModuleCatalog does not load entry points that are located
15+
// in assemblies that aren't directly used in the code. Users who want to use ImageAnalytics components will have to call
16+
// ImageAnalytics.Initialize() before creating the pipeline.
17+
/// <summary>
18+
/// Initialize the Image Analytics environment. Call this method before adding Image components to a learning pipeline.
19+
/// </summary>
20+
public static void Initialize()
21+
{
22+
}
23+
1424
[TlcModule.EntryPoint(Name = "Transforms.ImageLoader", Desc = ImageLoaderTransform.Summary,
1525
UserName = ImageLoaderTransform.UserName, ShortName = ImageLoaderTransform.LoaderSignature)]
1626
public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, ImageLoaderTransform.Arguments input)

src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
using Microsoft.ML.Runtime.Internal.Utilities;
1919
using Microsoft.ML.Runtime.Model;
2020

21-
[assembly: LoadableClass(ImageResizerTransform.Summary, typeof(ImageResizerTransform), typeof(ImageResizerTransform.Arguments),
21+
[assembly: LoadableClass(ImageResizerTransform.Summary, typeof(IDataTransform), typeof(ImageResizerTransform), typeof(ImageResizerTransform.Arguments),
2222
typeof(SignatureDataTransform), ImageResizerTransform.UserName, "ImageResizerTransform", "ImageResizer")]
2323

24-
[assembly: LoadableClass(ImageResizerTransform.Summary, typeof(ImageResizerTransform), null, typeof(SignatureLoadDataTransform),
24+
[assembly: LoadableClass(ImageResizerTransform.Summary, typeof(IDataTransform), typeof(ImageResizerTransform), null, typeof(SignatureLoadDataTransform),
2525
ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)]
2626

2727
[assembly: LoadableClass(typeof(ImageResizerTransform), null, typeof(SignatureLoadModel),

src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
<ItemGroup>
1111
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
1212
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
13+
<ProjectReference Include="..\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" />
1314
</ItemGroup>
1415

1516
<ItemGroup>

src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,23 @@
55
using System;
66
using System.Runtime.InteropServices;
77
using Microsoft.ML.Runtime.Data;
8+
using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
89

910
namespace Microsoft.ML.Transforms.TensorFlow
1011
{
11-
internal partial class TensorFlowUtils
12+
public static class TensorFlowUtils
1213
{
14+
// This method is needed for the Pipeline API, since ModuleCatalog does not load entry points that are located
15+
// in assemblies that aren't directly used in the code. Users who want to use TensorFlow components will have to call
16+
// TensorFlowUtils.Initialize() before creating the pipeline.
17+
/// <summary>
18+
/// Initialize the TensorFlow environment. Call this method before adding TensorFlow components to a learning pipeline.
19+
/// </summary>
20+
public static void Initialize()
21+
{
22+
ImageAnalytics.Initialize();
23+
}
24+
1325
internal static PrimitiveType Tf2MlNetType(TFDataType type)
1426
{
1527
switch (type)
@@ -27,7 +39,7 @@ internal static PrimitiveType Tf2MlNetType(TFDataType type)
2739
}
2840
}
2941

30-
public static unsafe void FetchData<T>(IntPtr data, T[] result)
42+
internal static unsafe void FetchData<T>(IntPtr data, T[] result)
3143
{
3244
var size = result.Length;
3345

test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using Microsoft.ML.Runtime.LightGBM;
1010
using Microsoft.ML.Trainers;
1111
using Microsoft.ML.Transforms;
12+
using Microsoft.ML.Transforms.TensorFlow;
1213
using System.Collections.Generic;
1314
using System.IO;
1415
using Xunit;
@@ -57,6 +58,7 @@ public void TensorFlowTransformCifarLearningPipelineTest()
5758
pipeline.Add(new TextToKeyConverter("Label"));
5859
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
5960

61+
TensorFlowUtils.Initialize();
6062
var model = pipeline.Train<CifarData, CifarPrediction>();
6163
string[] scoreLabels;
6264
model.TryGetScoreLabelNames(out scoreLabels);

0 commit comments

Comments
 (0)