diff --git a/gs_quant/risk/core.py b/gs_quant/risk/core.py
index 0969c9bc..991147d4 100644
--- a/gs_quant/risk/core.py
+++ b/gs_quant/risk/core.py
@@ -458,17 +458,28 @@ def aggregate_risk(results: Iterable[Union[DataFrameWithInfo, Future]],
     delta and vega are Dataframes, representing the merged risk of the individual instruments
     """
+    # Resolve any Futures and extract the raw value from each result
     def get_df(result_obj):
         if isinstance(result_obj, Future):
             result_obj = result_obj.result()
         if isinstance(result_obj, pd.Series) and allow_heterogeneous_types:
-            return pd.DataFrame(result_obj.raw_value).T
+            # Promote a Series result to a one-row frame so it can be concatenated
+            return result_obj.raw_value.to_frame().T
         return result_obj.raw_value
 
-    dfs = [get_df(r) for r in results]
-    result = pd.concat(dfs).fillna(0)
-    result = result.groupby([c for c in result.columns if c != 'value'], as_index=False).sum()
+    # Materialize the frames once; results may be a one-shot iterable such as a generator
+    dfs = [get_df(r) for r in results]
+
+    # The index is irrelevant to the groupby below, so let concat discard it
+    result = pd.concat(dfs, ignore_index=True).fillna(0)
+
+    # Group on every column except 'value' and sum the values across instruments
+    value_col = 'value'
+    group_cols = [c for c in result.columns if c != value_col]
+    result = result.groupby(group_cols, as_index=False).sum()
 
     if threshold is not None:
-        result = result[result.value.abs() > threshold]
+        # Drop aggregated rows whose absolute value does not exceed the threshold
+        result = result[result[value_col].abs() > threshold]
 
+    # sort_risk returns the merged frame in canonical column/row order
     return sort_risk(result)
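Note on the one-row promotion in the hunk above: passing index=[0] to the DataFrame constructor does not transpose the Series, it reindexes it, silently producing NaNs, which is why the hunk uses to_frame().T instead. A quick standalone pandas check (hypothetical values, not from this module):

    import pandas as pd

    s = pd.Series({'mkt_type': 'IR', 'value': 1.5})  # stand-in for a Series result's raw_value

    pd.DataFrame(s).T            # one row, columns ['mkt_type', 'value'] -- the original behavior
    s.to_frame().T               # identical one-row frame
    pd.DataFrame(s, index=[0])   # reindexes s to [0]: a single column of NaN, silently wrong
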
@@ -484,40 +495,51 @@ def aggregate_results(results: Iterable[ResultType], allow_mismatch_risk_keys=Fa
     risk_key = None
     results = tuple(results)
 
-    if not len(results):
+    # An empty tuple is falsy, so no len() call is needed
+    if not results:
         return None
 
+    # type(results[0]) is loop-invariant, so hoist it out of the loop
+    first_type = type(results[0])
     for result in results:
         if isinstance(result, Exception):
             raise Exception
 
         if result.error:
             raise ValueError('Cannot aggregate results in error')
 
-        if not allow_heterogeneous_types and not isinstance(result, type(results[0])):
-            raise ValueError(f'Cannot aggregate heterogeneous types: {type(result)} vs {type(results[0])}')
+        if not allow_heterogeneous_types and not isinstance(result, first_type):
+            raise ValueError(f'Cannot aggregate heterogeneous types: {type(result)} vs {first_type}')
 
-        if result.unit:
-            if unit and unit != result.unit:
+        # Read result.unit once per iteration
+        res_unit = result.unit
+        if res_unit:
+            if unit and unit != res_unit:
                 raise ValueError(f'Cannot aggregate results with different units for {result.risk_key.risk_measure}')
+            unit = unit or res_unit
 
-            unit = unit or result.unit
-
-        if not allow_mismatch_risk_keys and risk_key and risk_key.ex_historical_diddle != result.risk_key.ex_historical_diddle:
+        # Read result.risk_key once per iteration
+        rk = result.risk_key
+        if not allow_mismatch_risk_keys and risk_key and risk_key.ex_historical_diddle != rk.ex_historical_diddle:
             raise ValueError('Cannot aggregate results with different pricing keys')
+        risk_key = risk_key or rk
 
-        risk_key = risk_key or result.risk_key
-
-    inst = next(iter(results))
+    # results is a tuple, so index it directly rather than via next(iter(...))
+    inst = results[0]
 
+    # Dict: aggregate each key across results; the inner generator is consumed immediately by the recursive call
     if isinstance(inst, dict):
-        return dict((k, aggregate_results([r[k] for r in results])) for k in inst.keys())
+        return {k: aggregate_results((r[k] for r in results)) for k in inst.keys()}
+    # Tuple: flatten and de-duplicate across results
     elif isinstance(inst, tuple):
         return tuple(set(itertools.chain.from_iterable(results)))
+    # FloatWithInfo: scalar values simply sum; unit and risk_key carry through
     elif isinstance(inst, FloatWithInfo):
         return FloatWithInfo(risk_key, sum(results), unit=unit)
+    # SeriesWithInfo: element-wise sum across the series
     elif isinstance(inst, SeriesWithInfo):
         return SeriesWithInfo(sum(results), risk_key=risk_key, unit=unit)
+    # DataFrameWithInfo: aggregate_risk performs the frame merge
     elif isinstance(inst, DataFrameWithInfo):
         return DataFrameWithInfo(aggregate_risk(results, allow_heterogeneous_types=allow_heterogeneous_types),
                                  risk_key=risk_key, unit=unit)
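For context on what the DataFrameWithInfo branch delegates to, here is a minimal sketch (plain pandas, hypothetical market-data columns, not from this module) of the merge that aggregate_risk performs: rows sharing every non-value column collapse into a single summed row, and rows whose absolute value does not exceed the threshold are dropped.

    import pandas as pd

    inst1 = pd.DataFrame({'mkt_type': ['IR', 'IR'], 'mkt_asset': ['USD', 'EUR'], 'value': [1.0, 0.05]})
    inst2 = pd.DataFrame({'mkt_type': ['IR'], 'mkt_asset': ['USD'], 'value': [2.0]})

    merged = pd.concat([inst1, inst2], ignore_index=True).fillna(0)
    merged = merged.groupby([c for c in merged.columns if c != 'value'], as_index=False).sum()
    merged = merged[merged['value'].abs() > 0.1]
    # -> a single row IR / USD / 3.0: the USD rows summed, the 0.05 EUR row fell below the threshold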