Skip to content

Commit 6f2e94f

Browse files
authored
Data splits to default to MLContext seed when not specified (#4764)
1 parent 90df4e0 commit 6f2e94f

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -495,13 +495,14 @@ internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment
495495
/// </summary>
496496
internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
497497
{
498+
Contracts.CheckValue(env, nameof(env));
499+
var host = env.Register("rand");
498500
// We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to
499501
// build a single hash of it. If it is not, we generate a random number.
500-
501502
if (samplingKeyColumn == null)
502503
{
503504
samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
504-
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)seed);
505+
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? host.Rand.Next()));
505506
}
506507
else
507508
{
@@ -517,11 +518,7 @@ internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDa
517518
// instead of having two hash transformations.
518519
var origStratCol = samplingKeyColumn;
519520
samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
520-
HashingEstimator.ColumnOptionsInternal columnOptions;
521-
if (seed.HasValue)
522-
columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)seed.Value);
523-
else
524-
columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30);
521+
var columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)(seed ?? host.Rand.Next()));
525522
data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
526523
}
527524
else
@@ -533,7 +530,6 @@ internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDa
533530
data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true)).Fit(data).Transform(data);
534531
}
535532
}
536-
537533
}
538534
}
539535
}

0 commit comments

Comments
 (0)