Skip to content

Commit db4893e

Browse files
authored
[autorevert] requires same revision retry failure for job-track signals (#7396)
Based on observed autorevert performance, we noticed that all false positives come from job-track signals, as their failures are more likely to be caused by flakiness or infra issues. This PR increases the confidence of autorevert decisions specifically for job-track signals by always requiring a failed rerun on the same revision: a restart of the first failing commit is issued when one is missing, and the revert waits until that rerun has failed as well.

### Testing

* manual testing
* unit tests
1 parent 6498804 commit db4893e

File tree

3 files changed

+171
-1
lines changed

3 files changed

+171
-1
lines changed

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ def has_success(self) -> bool:
137137
def has_failure(self) -> bool:
    """Report whether ``SignalStatus.FAILURE`` is present in ``self.statuses``."""
    recorded = self.statuses
    return SignalStatus.FAILURE in recorded
139139

140+
def count_by_status(self, status: SignalStatus) -> int:
    """Return the number of events recorded under ``status``.

    The count is looked up in ``self.statuses``; a status that has no
    entry there yields ``0``.
    """
    tally = self.statuses
    return tally.get(status, 0)
143+
140144
def events_by_status(self, status: SignalStatus) -> List[SignalEvent]:
    """Collect the events in ``self.events`` whose ``status`` equals ``status``.

    The original ordering of ``self.events`` is preserved; an empty list is
    returned when nothing matches.
    """
    matching: List[SignalEvent] = []
    for candidate in self.events:
        if candidate.status == status:
            matching.append(candidate)
    return matching
@@ -270,6 +274,13 @@ class InfraCheckResult(Enum):
270274
RESTART_FAILURE = "restart_failure" # no failure after any success
271275

272276

277+
class SignalSource(Enum):
    """Which track a Signal originates from: test-track or job-track.

    The value is attached to a ``Signal`` through its ``source`` argument so
    that pattern detection can treat job-track signals differently.
    """

    # Test-track signal.
    TEST = "test"
    # Job-track signal.
    JOB = "job"
282+
283+
273284
class Signal:
274285
"""A refined, column-like view of raw CI data for pattern detection.
275286
@@ -285,12 +296,15 @@ def __init__(
285296
workflow_name: str,
286297
commits: List[SignalCommit],
287298
job_base_name: Optional[str] = None,
299+
source: SignalSource = SignalSource.TEST,
288300
):
289301
self.key = key
290302
self.workflow_name = workflow_name
291303
# commits are ordered from newest to oldest
292304
self.commits = commits
293305
self.job_base_name = job_base_name
306+
# Track the origin of the signal (test-track or job-track).
307+
self.source = source
294308

295309
def detect_fixed(self) -> bool:
296310
"""
@@ -451,6 +465,16 @@ def process_valid_autorevert_pattern(
451465
):
452466
restart_commits.add(partition.successful[0].head_sha)
453467

468+
# Job-track specific requirement: when there is no gap (unknown empty),
469+
# require a failed rerun on the first failing commit to increase confidence.
470+
if (
471+
not partition.unknown
472+
and self.source == SignalSource.JOB
473+
and not partition.failed[-1].has_pending
474+
and len(partition.failed[-1].events) < 2
475+
):
476+
restart_commits.add(partition.failed[-1].head_sha)
477+
454478
if restart_commits:
455479
return RestartCommits(commit_shas=restart_commits)
456480

@@ -472,6 +496,15 @@ def process_valid_autorevert_pattern(
472496
f"not enough successes to make call: {partition.success_events_count()}",
473497
)
474498

499+
if (
500+
self.source == SignalSource.JOB
501+
and partition.failed[-1].count_by_status(SignalStatus.FAILURE) < 2
502+
):
503+
return Ineligible(
504+
IneligibleReason.INSUFFICIENT_FAILURES,
505+
"job-track signal requires at least 2 failures on the first failing commit",
506+
)
507+
475508
if partition.unknown:
476509
# there are still pending/missing commits in the unknown partition
477510
unknown_shas = ", ".join(c.head_sha for c in partition.unknown)

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import Dict, Iterable, List, Optional, Set, Tuple
1313

1414
from .job_agg_index import JobAggIndex, JobMeta, SignalStatus as AggStatus
15-
from .signal import Signal, SignalCommit, SignalEvent, SignalStatus
15+
from .signal import Signal, SignalCommit, SignalEvent, SignalSource, SignalStatus
1616
from .signal_extraction_datasource import SignalExtractionDatasource
1717
from .signal_extraction_types import (
1818
JobBaseName,
@@ -127,6 +127,7 @@ def _dedup_signal_events(self, signals: List[Signal]) -> List[Signal]:
127127
workflow_name=s.workflow_name,
128128
commits=new_commits,
129129
job_base_name=s.job_base_name,
130+
source=s.source,
130131
)
131132
)
132133
return deduped
@@ -211,6 +212,7 @@ def _inject_pending_workflow_events(
211212
workflow_name=s.workflow_name,
212213
commits=new_commits,
213214
job_base_name=s.job_base_name,
215+
source=s.source,
214216
)
215217
)
216218
return out
@@ -422,6 +424,7 @@ def _build_test_signals(
422424
workflow_name=wf_name,
423425
commits=commit_objs,
424426
job_base_name=str(job_base_name),
427+
source=SignalSource.TEST,
425428
)
426429
)
427430

@@ -528,6 +531,7 @@ def _build_non_test_signals(
528531
workflow_name=wf_name,
529532
commits=commit_objs,
530533
job_base_name=str(base_name),
534+
source=SignalSource.JOB,
531535
)
532536
)
533537

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_signal.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Signal,
1111
SignalCommit,
1212
SignalEvent,
13+
SignalSource,
1314
SignalStatus,
1415
)
1516

@@ -439,6 +440,138 @@ def test_success_restart_even_when_failed_side_pending_and_insufficient_failures
439440
self.assertIn("sha_success_ok", res.commit_shas)
440441
self.assertNotIn("sha_fail_pend", res.commit_shas)
441442

443+
def test_job_track_requires_failed_rerun_when_no_gap_missing_rerun(self):
    # Job-track signals need a failed rerun on the suspected commit when no
    # gap separates it from the last good commit. Here the suspected commit
    # has only its first failed attempt, so a restart of it is expected.

    def failed_run(minute, run_id, attempt=1):
        # One failed "job" event started `minute` minutes after self.t0.
        return SignalEvent(
            name="job",
            status=SignalStatus.FAILURE,
            started_at=ts(self.t0, minute),
            wf_run_id=run_id,
            run_attempt=attempt,
        )

    def commit(sha, events):
        # Every commit in this scenario shares the same timestamp.
        return SignalCommit(
            head_sha=sha,
            timestamp=ts(self.t0, 0),
            events=events,
        )

    # Commits are listed newest -> oldest.
    history = [
        commit("sha_fail_newest", [failed_run(7, 100)]),
        commit("sha_fail_new", [failed_run(5, 101)]),
        # Suspected commit: first failure only, the failed rerun is missing.
        commit("sha_suspected", [failed_run(4, 321)]),
        # Base commit with two successes.
        commit(
            "sha_base",
            [
                self._ev("job", SignalStatus.SUCCESS, 3),
                self._ev("job", SignalStatus.SUCCESS, 6),
            ],
        ),
    ]

    signal = Signal(
        key="job",
        workflow_name="wf",
        commits=history,
        source=SignalSource.JOB,
    )
    outcome = signal.process_valid_autorevert_pattern()

    # No AutorevertPattern yet; a restart of the suspected commit is proposed.
    self.assertNotIsInstance(outcome, AutorevertPattern)
    self.assertTrue(hasattr(outcome, "commit_shas"))
    self.assertIn("sha_suspected", outcome.commit_shas)
507+
508+
def test_job_track_allows_autorevert_when_failed_rerun_present(self):
    # Job-track signal where the suspected commit already has a failed rerun
    # (attempt 2 on the same wf_run_id), so the autorevert pattern is allowed.

    def failed_run(minute, run_id, attempt=1):
        # One failed "job" event started `minute` minutes after self.t0.
        return SignalEvent(
            name="job",
            status=SignalStatus.FAILURE,
            started_at=ts(self.t0, minute),
            wf_run_id=run_id,
            run_attempt=attempt,
        )

    def commit(sha, events):
        # Every commit in this scenario shares the same timestamp.
        return SignalCommit(
            head_sha=sha,
            timestamp=ts(self.t0, 0),
            events=events,
        )

    # Commits are listed newest -> oldest.
    history = [
        commit("sha_fail_newest", [failed_run(7, 100)]),
        commit("sha_fail_new", [failed_run(5, 101)]),
        # Suspected commit: attempt 1 failed, then attempt 2 failed on the same run id.
        commit(
            "sha_suspected",
            [
                failed_run(4, 321),
                failed_run(6, 321, attempt=2),
            ],
        ),
        # Base commit with two successes.
        commit(
            "sha_base",
            [
                self._ev("job", SignalStatus.SUCCESS, 3),
                self._ev("job", SignalStatus.SUCCESS, 6),
            ],
        ),
    ]

    signal = Signal(
        key="job",
        workflow_name="wf",
        commits=history,
        source=SignalSource.JOB,
    )
    outcome = signal.process_valid_autorevert_pattern()
    self.assertIsInstance(outcome, AutorevertPattern)
574+
442575

443576
# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)