Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ def has_success(self) -> bool:
def has_failure(self) -> bool:
return SignalStatus.FAILURE in self.statuses

def count_by_status(self, status: SignalStatus) -> int:
"""Get the count of events with the specified status."""
return self.statuses.get(status, 0)

def events_by_status(self, status: SignalStatus) -> List[SignalEvent]:
"""Get all events with the specified status."""
return [event for event in self.events if event.status == status]
Expand Down Expand Up @@ -270,6 +274,13 @@ class InfraCheckResult(Enum):
RESTART_FAILURE = "restart_failure" # no failure after any success


class SignalSource(Enum):
"""Origin of a Signal: test-track or job-track."""

TEST = "test"
JOB = "job"


class Signal:
"""A refined, column-like view of raw CI data for pattern detection.

Expand All @@ -285,12 +296,15 @@ def __init__(
workflow_name: str,
commits: List[SignalCommit],
job_base_name: Optional[str] = None,
source: SignalSource = SignalSource.TEST,
):
self.key = key
self.workflow_name = workflow_name
# commits are ordered from newest to oldest
self.commits = commits
self.job_base_name = job_base_name
# Track the origin of the signal (test-track or job-track).
self.source = source

def detect_fixed(self) -> bool:
"""
Expand Down Expand Up @@ -451,6 +465,16 @@ def process_valid_autorevert_pattern(
):
restart_commits.add(partition.successful[0].head_sha)

# Job-track specific requirement: when there is no gap (unknown empty),
# require a failed rerun on the first failing commit to increase confidence.
if (
not partition.unknown
and self.source == SignalSource.JOB
and not partition.failed[-1].has_pending
and len(partition.failed[-1].events) < 2
):
restart_commits.add(partition.failed[-1].head_sha)

if restart_commits:
return RestartCommits(commit_shas=restart_commits)

Expand All @@ -472,6 +496,15 @@ def process_valid_autorevert_pattern(
f"not enough successes to make call: {partition.success_events_count()}",
)

if (
self.source == SignalSource.JOB
and partition.failed[-1].count_by_status(SignalStatus.FAILURE) < 2
):
return Ineligible(
IneligibleReason.INSUFFICIENT_FAILURES,
"job-track signal requires at least 2 failures on the first failing commit",
)

if partition.unknown:
# there are still pending/missing commits in the unknown partition
unknown_shas = ", ".join(c.head_sha for c in partition.unknown)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Dict, Iterable, List, Optional, Set, Tuple

from .job_agg_index import JobAggIndex, JobMeta, SignalStatus as AggStatus
from .signal import Signal, SignalCommit, SignalEvent, SignalStatus
from .signal import Signal, SignalCommit, SignalEvent, SignalSource, SignalStatus
from .signal_extraction_datasource import SignalExtractionDatasource
from .signal_extraction_types import (
JobBaseName,
Expand Down Expand Up @@ -127,6 +127,7 @@ def _dedup_signal_events(self, signals: List[Signal]) -> List[Signal]:
workflow_name=s.workflow_name,
commits=new_commits,
job_base_name=s.job_base_name,
source=s.source,
)
)
return deduped
Expand Down Expand Up @@ -211,6 +212,7 @@ def _inject_pending_workflow_events(
workflow_name=s.workflow_name,
commits=new_commits,
job_base_name=s.job_base_name,
source=s.source,
)
)
return out
Expand Down Expand Up @@ -422,6 +424,7 @@ def _build_test_signals(
workflow_name=wf_name,
commits=commit_objs,
job_base_name=str(job_base_name),
source=SignalSource.TEST,
)
)

Expand Down Expand Up @@ -529,6 +532,7 @@ def _build_non_test_signals(
workflow_name=wf_name,
commits=commit_objs,
job_base_name=str(base_name),
source=SignalSource.JOB,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Signal,
SignalCommit,
SignalEvent,
SignalSource,
SignalStatus,
)

Expand Down Expand Up @@ -439,6 +440,138 @@ def test_success_restart_even_when_failed_side_pending_and_insufficient_failures
self.assertIn("sha_success_ok", res.commit_shas)
self.assertNotIn("sha_fail_pend", res.commit_shas)

def test_job_track_requires_failed_rerun_when_no_gap_missing_rerun(self):
# Job-track: require a failed rerun on the suspected commit when there is no gap.
# Build commits newest -> older
c_fail_newest = SignalCommit(
head_sha="sha_fail_newest",
timestamp=ts(self.t0, 0),
events=[
SignalEvent(
name="job",
status=SignalStatus.FAILURE,
started_at=ts(self.t0, 7),
wf_run_id=100,
run_attempt=1,
)
],
)
c_fail_new = SignalCommit(
head_sha="sha_fail_new",
timestamp=ts(self.t0, 0),
events=[
SignalEvent(
name="job",
status=SignalStatus.FAILURE,
started_at=ts(self.t0, 5),
wf_run_id=101,
run_attempt=1,
)
],
)
# Suspected commit: first failure attempt=1, no rerun yet (missing failed rerun)
c_suspected = SignalCommit(
head_sha="sha_suspected",
timestamp=ts(self.t0, 0),
events=[
SignalEvent(
name="job",
status=SignalStatus.FAILURE,
started_at=ts(self.t0, 4),
wf_run_id=321,
run_attempt=1,
),
],
)
# Base successful commit with two successes
c_base = SignalCommit(
head_sha="sha_base",
timestamp=ts(self.t0, 0),
events=[
self._ev("job", SignalStatus.SUCCESS, 3),
self._ev("job", SignalStatus.SUCCESS, 6),
],
)

s = Signal(
key="job",
workflow_name="wf",
commits=[c_fail_newest, c_fail_new, c_suspected, c_base],
source=SignalSource.JOB,
)
res = s.process_valid_autorevert_pattern()
# Should not produce an AutorevertPattern; instead propose restart of suspected commit
self.assertNotIsInstance(res, AutorevertPattern)
self.assertTrue(hasattr(res, "commit_shas"))
self.assertIn("sha_suspected", res.commit_shas)

def test_job_track_allows_autorevert_when_failed_rerun_present(self):
# Same as above, but suspected has a failed rerun (attempt 2) on the same wf_run_id.
c_fail_newest = SignalCommit(
head_sha="sha_fail_newest",
timestamp=ts(self.t0, 0),
events=[
SignalEvent(
name="job",
status=SignalStatus.FAILURE,
started_at=ts(self.t0, 7),
wf_run_id=100,
run_attempt=1,
)
],
)
c_fail_new = SignalCommit(
head_sha="sha_fail_new",
timestamp=ts(self.t0, 0),
events=[
SignalEvent(
name="job",
status=SignalStatus.FAILURE,
started_at=ts(self.t0, 5),
wf_run_id=101,
run_attempt=1,
)
],
)
# Suspected commit: failure attempt=1 then failure attempt=2 on same run id
c_suspected = SignalCommit(
head_sha="sha_suspected",
timestamp=ts(self.t0, 0),
events=[
SignalEvent(
name="job",
status=SignalStatus.FAILURE,
started_at=ts(self.t0, 4),
wf_run_id=321,
run_attempt=1,
),
SignalEvent(
name="job",
status=SignalStatus.FAILURE,
started_at=ts(self.t0, 6),
wf_run_id=321,
run_attempt=2,
),
],
)
c_base = SignalCommit(
head_sha="sha_base",
timestamp=ts(self.t0, 0),
events=[
self._ev("job", SignalStatus.SUCCESS, 3),
self._ev("job", SignalStatus.SUCCESS, 6),
],
)

s = Signal(
key="job",
workflow_name="wf",
commits=[c_fail_newest, c_fail_new, c_suspected, c_base],
source=SignalSource.JOB,
)
res = s.process_valid_autorevert_pattern()
self.assertIsInstance(res, AutorevertPattern)


if __name__ == "__main__":
unittest.main()