20 | 20 | new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }, FieldAwareFactorizationMachineTrainer.UserName, FieldAwareFactorizationMachineTrainer.LoadName, |
21 | 21 | FieldAwareFactorizationMachineTrainer.ShortName, DocName = "trainer/FactorizationMachine.md")] |
22 | 22 |
23 | | -[assembly: LoadableClass(typeof(FieldAwareFactorizationMachinePredictor), null, typeof(SignatureLoadModel), "Field Aware Factorization Machine", FieldAwareFactorizationMachinePredictor.LoaderSignature)] |
24 | | - |
25 | 23 | [assembly: LoadableClass(typeof(void), typeof(FieldAwareFactorizationMachineTrainer), null, typeof(SignatureEntryPointModule), FieldAwareFactorizationMachineTrainer.LoadName)] |
26 | 24 |
27 | 25 | namespace Microsoft.ML.Runtime.FactorizationMachine |
28 | 26 | { |
29 | | - internal sealed class FieldAwareFactorizationMachineUtils |
30 | | - { |
31 | | - internal static int GetAlignedVectorLength(int length) |
32 | | - { |
33 | | - int res = length % 4; |
34 | | - if (res == 0) |
35 | | - return length; |
36 | | - else |
37 | | - return length + (4 - res); |
38 | | - } |
39 | | - |
40 | | - internal static bool LoadOneExampleIntoBuffer(ValueGetter<VBuffer<float>>[] getters, VBuffer<float> featureBuffer, bool norm, ref int count, |
41 | | - int[] fieldIndexBuffer, int[] featureIndexBuffer, float[] featureValueBuffer) |
42 | | - { |
43 | | - count = 0; |
44 | | - float featureNorm = 0; |
45 | | - int bias = 0; |
46 | | - float annihilation = 0; |
47 | | - for (int f = 0; f < getters.Length; f++) |
48 | | - { |
49 | | - getters[f](ref featureBuffer); |
50 | | - foreach (var pair in featureBuffer.Items()) |
51 | | - { |
52 | | - fieldIndexBuffer[count] = f; |
53 | | - featureIndexBuffer[count] = bias + pair.Key; |
54 | | - featureValueBuffer[count] = pair.Value; |
55 | | - featureNorm += pair.Value * pair.Value; |
56 | | - annihilation += pair.Value - pair.Value; |
57 | | - count++; |
58 | | - } |
59 | | - bias += featureBuffer.Length; |
60 | | - } |
61 | | - featureNorm = MathUtils.Sqrt(featureNorm); |
62 | | - if (norm) |
63 | | - { |
64 | | - for (int i = 0; i < count; i++) |
65 | | - featureValueBuffer[i] /= featureNorm; |
66 | | - } |
67 | | - return FloatUtils.IsFinite(annihilation); |
68 | | - } |
69 | | - } |
70 | | - |
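For readers following the refactoring: the removed FieldAwareFactorizationMachineUtils helper does two small things. GetAlignedVectorLength rounds a latent dimension up to the next multiple of 4 so each latent vector fills SSE-aligned slots of an AlignedArray, and LoadOneExampleIntoBuffer flattens one example's per-field feature vectors into parallel field/index/value buffers while accumulating `pair.Value - pair.Value`, which is 0 for finite values and NaN otherwise, to reject examples with non-finite features. A minimal standalone sketch of both ideas (illustrative names, not part of this diff):

```csharp
// Sketch only: the two ideas behind the removed FieldAwareFactorizationMachineUtils.
static class FfmUtilsSketch
{
    // Round up to the next multiple of 4 so latent vectors stay SSE-aligned.
    public static int GetAlignedVectorLength(int length)
        => (length % 4 == 0) ? length : length + (4 - length % 4);

    // x - x is 0 for any finite x and NaN for NaN or +/-Infinity, so the sum
    // stays finite only when every feature value is finite.
    public static bool AllValuesFinite(float[] values)
    {
        float annihilation = 0;
        foreach (var v in values)
            annihilation += v - v;
        return !float.IsNaN(annihilation) && !float.IsInfinity(annihilation);
    }
}
```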
71 | 27 | /// <summary> |
72 | 28 | /// Train a field-aware factorization machine using ADAGRAD (an advanced stochastic gradient method). See references below |
73 | 29 | /// for details. This trainer is essentially faster than the one introduced in [2] because of some implementation tricks [3].
74 | 30 | /// [1] http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf |
75 | 31 | /// [2] http://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf |
76 | | - /// [3] fast-ffm.tex in FactorizationMachine project folder |
| 32 | + /// [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf |
77 | 33 | /// </summary> |
78 | 34 | public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase<RoleMappedData, FieldAwareFactorizationMachinePredictor>, |
79 | 35 | IIncrementalTrainer<RoleMappedData, FieldAwareFactorizationMachinePredictor>, IValidatingTrainer<RoleMappedData>, |
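For context on the summary above, [2] defines the field-aware factorization machine score as a linear term plus field-aware pairwise interactions, and [1] gives the per-coordinate ADAGRAD step used to fit it. A sketch in the references' notation (not the code's variable names):

```latex
% Model response for an example x; f(j) is the field of feature j and
% v_{j,f} is the latent vector of feature j paired against field f.
\hat{y}(\mathbf{x}) = \sum_{j} w_j x_j
  + \sum_{j_1 < j_2} \langle \mathbf{v}_{j_1, f(j_2)}, \mathbf{v}_{j_2, f(j_1)} \rangle \, x_{j_1} x_{j_2}

% ADAGRAD update for any parameter \theta, where g_\theta is the gradient of
% the (regularized) logistic loss on the current example and \eta the learning rate:
G_\theta \leftarrow G_\theta + g_\theta^2, \qquad
\theta \leftarrow \theta - \frac{\eta}{\sqrt{G_\theta}} \, g_\theta
```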
@@ -327,6 +283,8 @@ private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, R |
327 | 283 | Func<int, bool> pred = c => fieldColumnIndexes.Contains(c) || c == data.Schema.Label.Index || (data.Schema.Weight != null && c == data.Schema.Weight.Index); |
328 | 284 | InitializeTrainingState(fieldCount, totalFeatureCount, predictor, out float[] linearWeights, |
329 | 285 | out AlignedArray latentWeightsAligned, out float[] linearAccSqGrads, out AlignedArray latentAccSqGradsAligned); |
| 286 | + |
| 287 | + // refer to Algorithm 3 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf |
330 | 288 | while (iter++ < _numIterations) |
331 | 289 | { |
332 | 290 | using (var cursor = data.Data.GetRowCursor(pred, rng)) |
@@ -358,9 +316,13 @@ private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, R |
358 | 316 | badExampleCount++; |
359 | 317 | continue; |
360 | 318 | } |
| 319 | + |
| 320 | + // refer to Algorithm 1 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf |
361 | 321 | FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(fieldCount, _latentDimAligned, count, |
362 | 322 | featureFieldBuffer, featureIndexBuffer, featureValueBuffer, linearWeights, latentWeightsAligned, latentSum, ref modelResponse); |
363 | 323 | var slope = CalculateLossSlope(label, modelResponse); |
| 324 | + |
| 325 | + // refer to Algorithm 2 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf |
364 | 326 | FieldAwareFactorizationMachineInterface.CalculateGradientAndUpdate(_lambdaLinear, _lambdaLatent, _learningRate, fieldCount, _latentDimAligned, weight, count, |
365 | 327 | featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum, slope, linearWeights, latentWeightsAligned, linearAccSqGrads, latentAccSqGradsAligned); |
366 | 328 | loss += weight * CalculateLoss(label, modelResponse); |
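The calls above follow the structure of [3]: Algorithm 1 computes the latent sums and model response for one example, Algorithm 2 performs the per-coordinate ADAGRAD update of the linear and latent weights, and Algorithm 3 is the outer loop over iterations and examples. The real work happens in native code behind FieldAwareFactorizationMachineInterface; the sketch below only approximates the linear-weight half of Algorithm 2, with illustrative names:

```csharp
using System;

// Sketch only: a managed approximation of the linear-weight update inside
// CalculateGradientAndUpdate. The native code applies the same ADAGRAD rule
// to the latent factors, one (feature, field) pair at a time.
static void UpdateLinearWeightsSketch(
    float lambdaLinear, float learningRate, float exampleWeight, float lossSlope,
    int count, int[] featureIndices, float[] featureValues,
    float[] linearWeights, float[] linearAccSqGrads)
{
    for (int i = 0; i < count; i++)
    {
        int j = featureIndices[i];
        // Gradient of the weighted loss plus L2 regularization on w_j.
        float g = exampleWeight * (lambdaLinear * linearWeights[j] + lossSlope * featureValues[i]);
        // ADAGRAD: accumulate the squared gradient, then take a scaled step.
        linearAccSqGrads[j] += g * g;
        linearWeights[j] -= learningRate / (float)Math.Sqrt(linearAccSqGrads[j]) * g;
    }
}
```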
@@ -453,259 +415,4 @@ public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironm |
453 | 415 | () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); |
454 | 416 | } |
455 | 417 | } |
456 | | - |
457 | | - public sealed class FieldAwareFactorizationMachinePredictor : PredictorBase<float>, ISchemaBindableMapper, ICanSaveModel |
458 | | - { |
459 | | - public const string LoaderSignature = "FieldAwareFactMacPredict"; |
460 | | - public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; |
461 | | - private bool _norm; |
462 | | - internal int FieldCount { get; } |
463 | | - internal int FeatureCount { get; } |
464 | | - internal int LatentDim { get; } |
465 | | - internal int LatentDimAligned { get; } |
466 | | - private readonly float[] _linearWeights; |
467 | | - private readonly AlignedArray _latentWeightsAligned; |
468 | | - |
469 | | - private static VersionInfo GetVersionInfo() |
470 | | - { |
471 | | - return new VersionInfo( |
472 | | - modelSignature: "FAFAMAPD", |
473 | | - verWrittenCur: 0x00010001, |
474 | | - verReadableCur: 0x00010001, |
475 | | - verWeCanReadBack: 0x00010001, |
476 | | - loaderSignature: LoaderSignature); |
477 | | - } |
478 | | - |
479 | | - internal FieldAwareFactorizationMachinePredictor(IHostEnvironment env, bool norm, int fieldCount, int featureCount, int latentDim, |
480 | | - float[] linearWeights, AlignedArray latentWeightsAligned) : base(env, LoaderSignature) |
481 | | - { |
482 | | - Host.Assert(fieldCount > 0); |
483 | | - Host.Assert(featureCount > 0); |
484 | | - Host.Assert(latentDim > 0); |
485 | | - Host.Assert(Utils.Size(linearWeights) == featureCount); |
486 | | - LatentDimAligned = FieldAwareFactorizationMachineUtils.GetAlignedVectorLength(latentDim); |
487 | | - Host.Assert(latentWeightsAligned.Size == checked(featureCount * fieldCount * LatentDimAligned)); |
488 | | - |
489 | | - _norm = norm; |
490 | | - FieldCount = fieldCount; |
491 | | - FeatureCount = featureCount; |
492 | | - LatentDim = latentDim; |
493 | | - _linearWeights = linearWeights; |
494 | | - _latentWeightsAligned = latentWeightsAligned; |
495 | | - } |
496 | | - |
497 | | - private FieldAwareFactorizationMachinePredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature) |
498 | | - { |
499 | | - Host.AssertValue(ctx); |
500 | | - |
501 | | - // *** Binary format *** |
502 | | - // bool: whether to normalize feature vectors |
503 | | - // int: number of fields |
504 | | - // int: number of features |
505 | | - // int: latent dimension |
506 | | - // float[]: linear coefficients |
507 | | - // float[]: latent representation of features |
508 | | - |
509 | | - var norm = ctx.Reader.ReadBoolean(); |
510 | | - var fieldCount = ctx.Reader.ReadInt32(); |
511 | | - Host.CheckDecode(fieldCount > 0); |
512 | | - var featureCount = ctx.Reader.ReadInt32(); |
513 | | - Host.CheckDecode(featureCount > 0); |
514 | | - var latentDim = ctx.Reader.ReadInt32(); |
515 | | - Host.CheckDecode(latentDim > 0); |
516 | | - LatentDimAligned = FieldAwareFactorizationMachineUtils.GetAlignedVectorLength(latentDim); |
517 | | - Host.Check(checked(featureCount * fieldCount * LatentDimAligned) <= Utils.ArrayMaxSize, "Latent dimension too large"); |
518 | | - var linearWeights = ctx.Reader.ReadFloatArray(); |
519 | | - Host.CheckDecode(Utils.Size(linearWeights) == featureCount); |
520 | | - var latentWeights = ctx.Reader.ReadFloatArray(); |
521 | | - Host.CheckDecode(Utils.Size(latentWeights) == featureCount * fieldCount * latentDim); |
522 | | - |
523 | | - _norm = norm; |
524 | | - FieldCount = fieldCount; |
525 | | - FeatureCount = featureCount; |
526 | | - LatentDim = latentDim; |
527 | | - _linearWeights = linearWeights; |
528 | | - _latentWeightsAligned = new AlignedArray(FeatureCount * FieldCount * LatentDimAligned, 16); |
529 | | - for (int j = 0; j < FeatureCount; j++) |
530 | | - { |
531 | | - for (int f = 0; f < FieldCount; f++) |
532 | | - { |
533 | | - int vBias = j * FieldCount * LatentDim + f * LatentDim; |
534 | | - int vBiasAligned = j * FieldCount * LatentDimAligned + f * LatentDimAligned; |
535 | | - for (int k = 0; k < LatentDimAligned; k++) |
536 | | - { |
537 | | - if (k < LatentDim) |
538 | | - _latentWeightsAligned[vBiasAligned + k] = latentWeights[vBias + k]; |
539 | | - else |
540 | | - _latentWeightsAligned[vBiasAligned + k] = 0; |
541 | | - } |
542 | | - } |
543 | | - } |
544 | | - } |
545 | | - |
546 | | - public static FieldAwareFactorizationMachinePredictor Create(IHostEnvironment env, ModelLoadContext ctx) |
547 | | - { |
548 | | - Contracts.CheckValue(env, nameof(env)); |
549 | | - env.CheckValue(ctx, nameof(ctx)); |
550 | | - ctx.CheckAtModel(GetVersionInfo()); |
551 | | - return new FieldAwareFactorizationMachinePredictor(env, ctx); |
552 | | - } |
553 | | - |
554 | | - protected override void SaveCore(ModelSaveContext ctx) |
555 | | - { |
556 | | - Host.AssertValue(ctx); |
557 | | - ctx.SetVersionInfo(GetVersionInfo()); |
558 | | - |
559 | | - // *** Binary format *** |
560 | | - // bool: whether to normalize feature vectors |
561 | | - // int: number of fields |
562 | | - // int: number of features |
563 | | - // int: latent dimension |
564 | | - // float[]: linear coefficients |
565 | | - // float[]: latent representation of features |
566 | | - |
567 | | - Host.Assert(FieldCount > 0); |
568 | | - Host.Assert(FeatureCount > 0); |
569 | | - Host.Assert(LatentDim > 0); |
570 | | - Host.Assert(Utils.Size(_linearWeights) == FeatureCount); |
571 | | - Host.Assert(_latentWeightsAligned.Size == FeatureCount * FieldCount * LatentDimAligned); |
572 | | - |
573 | | - ctx.Writer.Write(_norm); |
574 | | - ctx.Writer.Write(FieldCount); |
575 | | - ctx.Writer.Write(FeatureCount); |
576 | | - ctx.Writer.Write(LatentDim); |
577 | | - ctx.Writer.WriteFloatArray(_linearWeights); |
578 | | - float[] latentWeights = new float[FeatureCount * FieldCount * LatentDim]; |
579 | | - for (int j = 0; j < FeatureCount; j++) |
580 | | - { |
581 | | - for (int f = 0; f < FieldCount; f++) |
582 | | - { |
583 | | - int vBias = j * FieldCount * LatentDim + f * LatentDim; |
584 | | - int vBiasAligned = j * FieldCount * LatentDimAligned + f * LatentDimAligned; |
585 | | - for (int k = 0; k < LatentDim; k++) |
586 | | - latentWeights[vBias + k] = _latentWeightsAligned[vBiasAligned + k]; |
587 | | - } |
588 | | - } |
589 | | - ctx.Writer.WriteFloatArray(latentWeights); |
590 | | - } |
591 | | - |
592 | | - internal float CalculateResponse(ValueGetter<VBuffer<float>>[] getters, VBuffer<float> featureBuffer, |
593 | | - int[] featureFieldBuffer, int[] featureIndexBuffer, float[] featureValueBuffer, AlignedArray latentSum) |
594 | | - { |
595 | | - int count = 0; |
596 | | - float modelResponse = 0; |
597 | | - FieldAwareFactorizationMachineUtils.LoadOneExampleIntoBuffer(getters, featureBuffer, _norm, ref count, |
598 | | - featureFieldBuffer, featureIndexBuffer, featureValueBuffer); |
599 | | - FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(FieldCount, LatentDimAligned, count, |
600 | | - featureFieldBuffer, featureIndexBuffer, featureValueBuffer, _linearWeights, _latentWeightsAligned, latentSum, ref modelResponse); |
601 | | - return modelResponse; |
602 | | - } |
603 | | - |
604 | | - public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) |
605 | | - { |
606 | | - return new FieldAwareFactorizationMachineScalarRowMapper(env, schema, new BinaryClassifierSchema(), this); |
607 | | - } |
608 | | - |
609 | | - internal void CopyLinearWeightsTo(float[] linearWeights) |
610 | | - { |
611 | | - Host.AssertValue(_linearWeights); |
612 | | - Host.AssertValue(linearWeights); |
613 | | - Array.Copy(_linearWeights, linearWeights, _linearWeights.Length); |
614 | | - } |
615 | | - |
616 | | - internal void CopyLatentWeightsTo(AlignedArray latentWeights) |
617 | | - { |
618 | | - Host.AssertValue(_latentWeightsAligned); |
619 | | - Host.AssertValue(latentWeights); |
620 | | - latentWeights.CopyFrom(_latentWeightsAligned); |
621 | | - } |
622 | | - } |
623 | | - |
624 | | - internal sealed class FieldAwareFactorizationMachineScalarRowMapper : ISchemaBoundRowMapper |
625 | | - { |
626 | | - private readonly FieldAwareFactorizationMachinePredictor _pred; |
627 | | - |
628 | | - public RoleMappedSchema InputSchema { get; } |
629 | | - |
630 | | - public ISchema OutputSchema { get; } |
631 | | - |
632 | | - public ISchemaBindableMapper Bindable => _pred; |
633 | | - |
634 | | - private readonly ColumnInfo[] _columns; |
635 | | - private readonly List<int> _inputColumnIndexes; |
636 | | - private readonly IHostEnvironment _env; |
637 | | - |
638 | | - public FieldAwareFactorizationMachineScalarRowMapper(IHostEnvironment env, RoleMappedSchema schema, |
639 | | - ISchema outputSchema, FieldAwareFactorizationMachinePredictor pred) |
640 | | - { |
641 | | - Contracts.AssertValue(env); |
642 | | - Contracts.AssertValue(schema); |
643 | | - Contracts.CheckParam(outputSchema.ColumnCount == 2, nameof(outputSchema)); |
644 | | - Contracts.CheckParam(outputSchema.GetColumnType(0).IsNumber, nameof(outputSchema)); |
645 | | - Contracts.CheckParam(outputSchema.GetColumnType(1).IsNumber, nameof(outputSchema)); |
646 | | - Contracts.AssertValue(pred); |
647 | | - |
648 | | - _env = env; |
649 | | - _columns = schema.GetColumns(RoleMappedSchema.ColumnRole.Feature).ToArray(); |
650 | | - _pred = pred; |
651 | | - |
652 | | - var inputFeatureColumns = _columns.Select(c => new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, c.Name)).ToList(); |
653 | | - InputSchema = RoleMappedSchema.Create(schema.Schema, inputFeatureColumns); |
654 | | - OutputSchema = outputSchema; |
655 | | - |
656 | | - _inputColumnIndexes = new List<int>(); |
657 | | - foreach (var kvp in inputFeatureColumns) |
658 | | - { |
659 | | - if (schema.Schema.TryGetColumnIndex(kvp.Value, out int index)) |
660 | | - _inputColumnIndexes.Add(index); |
661 | | - } |
662 | | - } |
663 | | - |
664 | | - public IRow GetOutputRow(IRow input, Func<int, bool> predicate, out Action action) |
665 | | - { |
666 | | - var latentSum = new AlignedArray(_pred.FieldCount * _pred.FieldCount * _pred.LatentDimAligned, 16); |
667 | | - var featureBuffer = new VBuffer<float>(); |
668 | | - var featureFieldBuffer = new int[_pred.FeatureCount]; |
669 | | - var featureIndexBuffer = new int[_pred.FeatureCount]; |
670 | | - var featureValueBuffer = new float[_pred.FeatureCount]; |
671 | | - var inputGetters = new ValueGetter<VBuffer<float>>[_pred.FieldCount]; |
672 | | - for (int f = 0; f < _pred.FieldCount; f++) |
673 | | - inputGetters[f] = input.GetGetter<VBuffer<float>>(_inputColumnIndexes[f]); |
674 | | - |
675 | | - action = null; |
676 | | - var getters = new Delegate[2]; |
677 | | - if (predicate(0)) |
678 | | - { |
679 | | - ValueGetter<float> responseGetter = (ref float value) => |
680 | | - { |
681 | | - value = _pred.CalculateResponse(inputGetters, featureBuffer, featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum); |
682 | | - }; |
683 | | - getters[0] = responseGetter; |
684 | | - } |
685 | | - if (predicate(1)) |
686 | | - { |
687 | | - ValueGetter<float> probGetter = (ref float value) => |
688 | | - { |
689 | | - value = _pred.CalculateResponse(inputGetters, featureBuffer, featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum); |
690 | | - value = MathUtils.SigmoidSlow(value); |
691 | | - }; |
692 | | - getters[1] = probGetter; |
693 | | - } |
694 | | - |
695 | | - return new SimpleRow(OutputSchema, input, getters); |
696 | | - } |
697 | | - |
698 | | - public Func<int, bool> GetDependencies(Func<int, bool> predicate) |
699 | | - { |
700 | | - if (Enumerable.Range(0, OutputSchema.ColumnCount).Any(predicate)) |
701 | | - return index => _inputColumnIndexes.Any(c => c == index); |
702 | | - else |
703 | | - return index => false; |
704 | | - } |
705 | | - |
706 | | - public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles() |
707 | | - { |
708 | | - return InputSchema.GetColumnRoles().Select(kvp => new KeyValuePair<RoleMappedSchema.ColumnRole, string>(kvp.Key, kvp.Value.Name)); |
709 | | - } |
710 | | - } |
711 | 418 | } |
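One note on the predictor removed in this diff: it keeps latent weights padded to LatentDimAligned (a multiple of 4) for SIMD, so loading and saving convert between the packed on-disk layout (LatentDim floats per (feature, field) pair) and the padded in-memory layout, as in the loops above. A minimal sketch of the padding direction, using a plain float[] in place of AlignedArray (illustrative name):

```csharp
// Sketch only: pad the packed on-disk latent weights into the aligned layout.
// Slot (j, f, k) lives at (j * fieldCount + f) * latentDim + k on disk and at
// (j * fieldCount + f) * latentDimAligned + k in memory; padding slots stay zero.
static float[] PadLatentWeights(
    float[] packed, int featureCount, int fieldCount, int latentDim, int latentDimAligned)
{
    var aligned = new float[featureCount * fieldCount * latentDimAligned];
    for (int j = 0; j < featureCount; j++)
    {
        for (int f = 0; f < fieldCount; f++)
        {
            int src = (j * fieldCount + f) * latentDim;
            int dst = (j * fieldCount + f) * latentDimAligned;
            for (int k = 0; k < latentDim; k++)
                aligned[dst + k] = packed[src + k];
        }
    }
    return aligned;
}
```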