diff --git a/gs_quant/backtests/generic_engine.py b/gs_quant/backtests/generic_engine.py index cd8adc19..faf291ac 100644 --- a/gs_quant/backtests/generic_engine.py +++ b/gs_quant/backtests/generic_engine.py @@ -65,11 +65,10 @@ def __init__(self, action: Action): def get_base_orders_for_states(self, states: Collection[dt.date], **kwargs): orders = {} dated_priceables = getattr(self.action, 'dated_priceables', {}) - with PricingContext(): - for s in states: - active_portfolio = dated_priceables.get(s) or self.action.priceables - with PricingContext(pricing_date=s): - orders[s] = Portfolio(active_portfolio).calc(tuple(self._order_valuations)) + for s in states: + active_portfolio = dated_priceables.get(s) or self.action.priceables + with PricingContext(pricing_date=s): + orders[s] = Portfolio(active_portfolio).calc(tuple(self._order_valuations)) return orders def get_instrument_final_date(self, inst: Instrument, order_date: dt.date, info: namedtuple): @@ -85,10 +84,8 @@ def _raise_order(self, trigger_info: Optional[Union[AddTradeActionInfo, Iterable[AddTradeActionInfo]]] = None): state_list = make_list(state) if trigger_info is None or isinstance(trigger_info, AddTradeActionInfo): - trigger_info = [trigger_info for _ in range(len(state_list))] - ti_by_state = {} - for s, ti in zip_longest(state_list, trigger_info): - ti_by_state[s] = ti + trigger_info = [trigger_info] * len(state_list) + ti_by_state = dict(zip_longest(state_list, trigger_info)) orders = self.get_base_orders_for_states(state_list, trigger_infos=ti_by_state) final_orders = {} for d, p in orders.items():