-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Adding Factorization Machines #383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| // Licensed to the .NET Foundation under one or more agreements. | ||
| // The .NET Foundation licenses this file to you under the MIT license. | ||
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| using System.Reflection; | ||
| using System.Runtime.CompilerServices; | ||
| using System.Runtime.InteropServices; | ||
|
|
||
| [assembly: InternalsVisibleTo("Microsoft.ML.StandardLearners, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] | ||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,94 @@ | ||||
| // Licensed to the .NET Foundation under one or more agreements. | ||||
| // The .NET Foundation licenses this file to you under the MIT license. | ||||
| // See the LICENSE file in the project root for more information. | ||||
|
|
||||
| using Microsoft.ML.Runtime.Internal.CpuMath; | ||||
| using Microsoft.ML.Runtime.Internal.Utilities; | ||||
| using System.Runtime.InteropServices; | ||||
|
|
||||
| using System.Security; | ||||
|
|
||||
| namespace Microsoft.ML.Runtime.FactorizationMachine | ||||
| { | ||||
| internal unsafe static class FieldAwareFactorizationMachineInterface | ||||
| { | ||||
| internal const string NativePath = "FactorizationMachineNative"; | ||||
| public const int CbAlign = 16; | ||||
|
|
||||
| private static bool Compat(AlignedArray a) | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @sfilipi and @wschin , could I ask, was the usage of
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In a few small benchmark performance tests I've run,
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FFM's core computation is done by SSE code, which requires the memory blocks to be aligned. The main computation doesn't call any member functions of
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From my investigation, I concluded that the performance penalty being paid was where we are moving the array elements around in memory to manually align it.
In my experience, doing this copying is worse performance than just using unaligned reads. Another way to fix this issue is to use both a
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||
| { | ||||
| Contracts.AssertValue(a); | ||||
| Contracts.Assert(a.Size > 0); | ||||
| return a.CbAlign == CbAlign; | ||||
| } | ||||
|
|
||||
| private unsafe static float* Ptr(AlignedArray a, float* p) | ||||
| { | ||||
| Contracts.AssertValue(a); | ||||
| float* q = p + a.GetBase((long)p); | ||||
| Contracts.Assert(((long)q & (CbAlign - 1)) == 0); | ||||
| return q; | ||||
| } | ||||
|
|
||||
| [DllImport(NativePath), SuppressUnmanagedCodeSecurity] | ||||
| public static extern void CalculateIntermediateVariablesNative(int fieldCount, int latentDim, int count, int* /*const*/ fieldIndices, int* /*const*/ featureIndices, | ||||
| float* /*const*/ featureValues, float* /*const*/ linearWeights, float* /*const*/ latentWeights, float* latentSum, float* response); | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
are this comments here for a reason? #Resolved
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added those comments
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||
|
|
||||
| [DllImport(NativePath), SuppressUnmanagedCodeSecurity] | ||||
| public static extern void CalculateGradientAndUpdateNative(float lambdaLinear, float lambdaLatent, float learningRate, int fieldCount, int latentDim, float weight, | ||||
| int count, int* /*const*/ fieldIndices, int* /*const*/ featureIndices, float* /*const*/ featureValues, float* /*const*/ latentSum, float slope, | ||||
| float* linearWeights, float* latentWeights, float* linearAccumulatedSquaredGrads, float* latentAccumulatedSquaredGrads); | ||||
|
|
||||
| public static void CalculateIntermediateVariables(int fieldCount, int latentDim, int count, int[] fieldIndices, int[] featureIndices, float[] featureValues, | ||||
| float[] linearWeights, AlignedArray latentWeights, AlignedArray latentSum, ref float response) | ||||
| { | ||||
| Contracts.AssertNonEmpty(fieldIndices); | ||||
| Contracts.AssertNonEmpty(featureValues); | ||||
| Contracts.AssertNonEmpty(featureIndices); | ||||
| Contracts.AssertNonEmpty(linearWeights); | ||||
| Contracts.Assert(Compat(latentWeights)); | ||||
| Contracts.Assert(Compat(latentSum)); | ||||
|
|
||||
| unsafe | ||||
| { | ||||
| fixed (int* pf = &fieldIndices[0]) | ||||
| fixed (int* pi = &featureIndices[0]) | ||||
| fixed (float* px = &featureValues[0]) | ||||
| fixed (float* pw = &linearWeights[0]) | ||||
| fixed (float* pv = &latentWeights.Items[0]) | ||||
| fixed (float* pq = &latentSum.Items[0]) | ||||
| fixed (float* pr = &response) | ||||
| CalculateIntermediateVariablesNative(fieldCount, latentDim, count, pf, pi, px, pw, Ptr(latentWeights, pv), Ptr(latentSum, pq), pr); | ||||
| } | ||||
| } | ||||
|
|
||||
| public static void CalculateGradientAndUpdate(float lambdaLinear, float lambdaLatent, float learningRate, int fieldCount, int latentDim, | ||||
| float weight, int count, int[] fieldIndices, int[] featureIndices, float[] featureValues, AlignedArray latentSum, float slope, | ||||
| float[] linearWeights, AlignedArray latentWeights, float[] linearAccumulatedSquaredGrads, AlignedArray latentAccumulatedSquaredGrads) | ||||
| { | ||||
| Contracts.AssertNonEmpty(fieldIndices); | ||||
| Contracts.AssertNonEmpty(featureIndices); | ||||
| Contracts.AssertNonEmpty(featureValues); | ||||
| Contracts.Assert(Compat(latentSum)); | ||||
| Contracts.AssertNonEmpty(linearWeights); | ||||
| Contracts.Assert(Compat(latentWeights)); | ||||
| Contracts.AssertNonEmpty(linearAccumulatedSquaredGrads); | ||||
| Contracts.Assert(Compat(latentAccumulatedSquaredGrads)); | ||||
|
|
||||
| unsafe | ||||
| { | ||||
| fixed (int* pf = &fieldIndices[0]) | ||||
| fixed (int* pi = &featureIndices[0]) | ||||
| fixed (float* px = &featureValues[0]) | ||||
| fixed (float* pq = &latentSum.Items[0]) | ||||
| fixed (float* pw = &linearWeights[0]) | ||||
| fixed (float* pv = &latentWeights.Items[0]) | ||||
| fixed (float* phw = &linearAccumulatedSquaredGrads[0]) | ||||
| fixed (float* phv = &latentAccumulatedSquaredGrads.Items[0]) | ||||
| CalculateGradientAndUpdateNative(lambdaLinear, lambdaLatent, learningRate, fieldCount, latentDim, weight, count, pf, pi, px, | ||||
| Ptr(latentSum, pq), slope, pw, Ptr(latentWeights, pv), phw, Ptr(latentAccumulatedSquaredGrads, phv)); | ||||
| } | ||||
|
|
||||
| } | ||||
| } | ||||
| } | ||||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this necessary? #Pending
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AlignedArray.Items is internal, but gets accessed in FactorizationMachines.
Will get rid of it we move off AlignedArray.
In reply to: 198213276 [](ancestors = 198213276)