Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 36 additions & 19 deletions gs_quant/timeseries/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ class ThresholdType(str, Enum):


@plot_function
def smooth_spikes(x: pd.Series, threshold: float,
threshold_type: ThresholdType = ThresholdType.percentage) -> pd.Series:
def smooth_spikes(
x: pd.Series,
threshold: float,
threshold_type: ThresholdType = ThresholdType.percentage,
) -> pd.Series:
"""
Smooth out the spikes of a series. With a percentage threshold, a point is a spike if it is more than
(1 + threshold) times both neighbors, or if both neighbors are more than (1 + threshold) times the point. A spike is
replaced with the average of its two neighbors. Note: the first and last points in the input series are dropped.
Expand All @@ -66,7 +69,9 @@ def smooth_spikes(x: pd.Series, threshold: float,
"""

def check_percentage(previous, current, next_, multiplier) -> bool:
current_higher = current > previous * multiplier and current > next_ * multiplier
current_higher = (
current > previous * multiplier and current > next_ * multiplier
)
current_lower = previous > current * multiplier and next_ > current * multiplier
return current_higher or current_lower

Expand All @@ -78,8 +83,11 @@ def check_absolute(previous, current, next_, absolute) -> bool:
if len(x) < 3:
return pd.Series(dtype=float)

threshold_value, check_spike = (threshold, check_absolute) if threshold_type == ThresholdType.absolute else (
(1 + threshold), check_percentage)
threshold_value, check_spike = (
(threshold, check_absolute)
if threshold_type == ThresholdType.absolute
else ((1 + threshold), check_percentage)
)

result = x.copy()
current, next_ = x.iloc[0:2]
Expand Down Expand Up @@ -109,11 +117,11 @@ def repeat(x: pd.Series, n: int = 1) -> pd.Series:
Fill missing values with the last seen value, e.g. to combine daily data with weekly or monthly data.
"""
if not 0 < n < 367:
raise MqValueError('n must be between 0 and 367')
raise MqValueError("n must be between 0 and 367")
if x.empty:
return x
index = pd.date_range(freq=f'{n}D', start=x.index[0], end=x.index[-1])
return x.reindex(index, method='ffill')
index = pd.date_range(freq=f"{n}D", start=x.index[0], end=x.index[-1])
return x.reindex(index, method="ffill")


@plot_function
Expand Down Expand Up @@ -277,8 +285,11 @@ def diff(x: pd.Series, obs: Union[Window, int, str] = 1) -> pd.Series:


@plot_function
def compare(x: Union[pd.Series, Real], y: Union[pd.Series, Real], method: Interpolate = Interpolate.STEP) \
-> Union[pd.Series, Real]:
def compare(
x: Union[pd.Series, Real],
y: Union[pd.Series, Real],
method: Interpolate = Interpolate.STEP,
) -> Union[pd.Series, Real]:
"""
Compare two series or scalars against each other

Expand Down Expand Up @@ -328,7 +339,9 @@ class LagMode(Enum):


@plot_function
def lag(x: pd.Series, obs: Union[Window, int, str] = 1, mode: LagMode = LagMode.EXTEND) -> pd.Series:
def lag(
x: pd.Series, obs: Union[Window, int, str] = 1, mode: LagMode = LagMode.EXTEND
) -> pd.Series:
"""
Lag timeseries by a number of observations or a relative date.

Expand Down Expand Up @@ -365,28 +378,32 @@ def lag(x: pd.Series, obs: Union[Window, int, str] = 1, mode: LagMode = LagMode.
end = x.index[-1]
y = x.copy() # avoid mutating the provided series

match = re.fullmatch('(\\d+)y', obs)
match = re.fullmatch("(\\d+)y", obs)
if match:
y.index += pd.DateOffset(years=int(match.group(1)))
y = y.groupby(y.index).first()
else:
y.index = pd.DatetimeIndex([(i + pd.DateOffset(relative_date_add(obs))).date() for i in y.index])
y.index = pd.DatetimeIndex(
[(i + pd.DateOffset(relative_date_add(obs))).date() for i in y.index]
)

if mode == LagMode.EXTEND:
return y
return y[:end]

obs = getattr(obs, 'w', obs)
obs = getattr(obs, "w", obs)
# Determine how we want to handle observations prior to start date
if mode == LagMode.EXTEND:
if x.empty:
return x
if x.index.resolution != 'day':
raise MqValueError(f'unable to extend index with resolution {x.index.resolution}')
kwargs = {'periods': abs(obs) + 1, 'freq': 'D'}
if x.index.resolution != "day":
raise MqValueError(
f"unable to extend index with resolution {x.index.resolution}"
)
kwargs = {"periods": abs(obs) + 1, "freq": "D"}
if obs > 0:
kwargs['start'] = x.index[-1]
kwargs["start"] = x.index[-1]
else:
kwargs['end'] = x.index[0]
kwargs["end"] = x.index[0]
x = x.reindex(x.index.union(pd.date_range(**kwargs)))
return x.shift(obs)
Loading