From 11200ff3231dcef2ed807c254ee030dc69f2e2a8 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Mon, 23 Jan 2023 11:54:08 -0700 Subject: [PATCH] fixes one dal dispatching issues --- src/Microsoft.ML.Data/MLContext.cs | 20 +++++++++++++++++++ .../RandomForestClassification.cs | 16 +-------------- .../RandomForestRegression.cs | 16 +-------------- .../OlsLinearRegression.cs | 16 +-------------- 4 files changed, 23 insertions(+), 45 deletions(-) diff --git a/src/Microsoft.ML.Data/MLContext.cs b/src/Microsoft.ML.Data/MLContext.cs index 89c07c5715..f0e1986c5e 100644 --- a/src/Microsoft.ML.Data/MLContext.cs +++ b/src/Microsoft.ML.Data/MLContext.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.Reflection; using Microsoft.ML.Data; using Microsoft.ML.Runtime; @@ -171,5 +172,24 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message) [BestFriend] internal void CancelExecution() => ((ICancelable)_env).CancelExecution(); + + [BestFriend] + internal static readonly bool OneDalDispatchingEnabled = InitializeOneDalDispatchingEnabled(); + + private static bool InitializeOneDalDispatchingEnabled() + { + try + { + var asm = Assembly.Load("Microsoft.ML.OneDal"); + var type = asm.GetType("Microsoft.ML.OneDal.OneDalUtils"); + var method = type.GetMethod("IsDispatchingEnabled", BindingFlags.Public | BindingFlags.Static | BindingFlags.NonPublic); + var result = method.Invoke(null, null); + return (bool)result; + } + catch + { + return false; + } + } } } diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 62809c333d..a63811576e 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -224,7 +224,7 @@ private protected override FastForestBinaryModelParameters TrainModelCore(TrainC FeatureCount = trainData.Schema.Feature.Value.Type.GetValueCount(); ConvertData(trainData); - if (!trainData.Schema.Weight.HasValue && IsDispatchingToOneDalEnabled()) + if (!trainData.Schema.Weight.HasValue && MLContext.OneDalDispatchingEnabled) { if (FastTreeTrainerOptions.FeatureFraction != 1.0) { @@ -262,20 +262,6 @@ public static extern unsafe int DecisionForestClassificationCompute( void* lteChildPtr, void* gtChildPtr, void* splitFeaturePtr, void* featureThresholdPtr, void* leafValuesPtr, void* modelPtr); } - [BestFriend] - private bool IsDispatchingToOneDalEnabled() - { - try - { - return OneDalUtils.IsDispatchingEnabled(); - } - catch (Exception) - { - // Bail to default implementation upon encountering any situation where dispatch failed - return false; - } - } - [BestFriend] private void TrainCoreOneDal(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount) { diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index f1969f2cb2..8744fdf16c 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -363,7 +363,7 @@ private protected override FastForestRegressionModelParameters TrainModelCore(Tr FeatureCount = trainData.Schema.Feature.Value.Type.GetValueCount(); ConvertData(trainData); - if (!trainData.Schema.Weight.HasValue && IsDispatchingToOneDalEnabled()) + if (!trainData.Schema.Weight.HasValue && MLContext.OneDalDispatchingEnabled) { if (FastTreeTrainerOptions.FeatureFraction != 1.0) { @@ -395,20 +395,6 @@ public static extern unsafe int DecisionForestRegressionCompute( void* lteChildPtr, void* gtChildPtr, void* splitFeaturePtr, void* featureThresholdPtr, void* leafValuesPtr, void* modelPtr); } - [BestFriend] - private bool IsDispatchingToOneDalEnabled() - { - try - { - return OneDalUtils.IsDispatchingEnabled(); - } - catch (Exception) - { - // fall back to original implementation for any circumstance that prevents dispatching - return false; - } - } - [BestFriend] private void TrainCoreOneDal(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount) { diff --git a/src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs b/src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs index 6f4f721121..1d072c1290 100644 --- a/src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs +++ b/src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs @@ -406,20 +406,6 @@ private void ComputeMklRegression(IChannel ch, FloatLabelCursor.Factory cursorFa xty = null; } - [BestFriend] - private bool IsDispatchingToOneDalEnabled() - { - try - { - return OneDalUtils.IsDispatchingEnabled(); - } - catch (Exception) - { - // Bail to default implementation upon any situation that prevents dispatching - return false; - } - } - private OlsModelParameters TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount) { Host.AssertValue(ch); @@ -440,7 +426,7 @@ private OlsModelParameters TrainCore(IChannel ch, FloatLabelCursor.Factory curso var beta = new Double[m]; Double yMean = 0; - if (IsDispatchingToOneDalEnabled()) + if (MLContext.OneDalDispatchingEnabled) { ComputeOneDalRegression(ch, cursorFactory, m, ref beta, xtx, ref n, ref yMean); }