Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions gs_quant/backtests/generic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,12 @@ def apply_action(self,
class AddScaledTradeActionImpl(OrderBasedActionImpl):
def __init__(self, action: AddScaledTradeAction):
super().__init__(action)
self._scaling_level_signal = interpolate_signal(self.action.scaling_level) \
if isinstance(self.action.scaling_level, dict) else None
if isinstance(self.action.scaling_level, dict):
self._scaling_level_signal = interpolate_signal(self.action.scaling_level)
self._scaling_level_signal_values = self._scaling_level_signal.values
self._scaling_level_signal_index_map = {date: idx for idx, date in enumerate(self._scaling_level_signal.index)}
else:
self._scaling_level_signal = None

@staticmethod
def __portfolio_scaling_for_available_cash(portfolio, available_cash, cur_day, unscaled_prices_by_day,
Expand Down Expand Up @@ -248,9 +252,11 @@ def _nav_scale_orders(self, orders, price_measure, trigger_infos):
orders[day].scale(scaling_factors_by_day[day])

def _scaling_level_for_date(self, d: dt.date) -> float:
# Avoid pandas __contains__ and __getitem__ in hotpath, use a cached index mapping
if self._scaling_level_signal is not None:
if d in self._scaling_level_signal:
return self._scaling_level_signal[d]
idx = self._scaling_level_signal_index_map.get(d)
if idx is not None:
return self._scaling_level_signal_values[idx]
return 0
else:
return self.action.scaling_level
Expand Down