79 changes: 52 additions & 27 deletions gs_quant/risk/result_handlers.py
@@ -15,7 +15,7 @@
"""
import datetime as dt
import logging
from typing import Iterable, Optional, Union
from typing import List, Iterable, Optional, Union

from gs_quant.base import InstrumentBase, RiskKey
from gs_quant.common import RiskMeasure, AssetClass, RiskMeasureType
@@ -28,22 +28,31 @@

def __dataframe_handler(result: Iterable, mappings: tuple, risk_key: RiskKey, request_id: Optional[str] = None) \
-> DataFrameWithInfo:
first_row = next(iter(result), None)
if first_row is None:
# Materialise the rows as a list: the input may be a generator, and we need both an emptiness check and a later pass
result_list = list(result)
if not result_list:
return DataFrameWithInfo(risk_key=risk_key, request_id=request_id)

columns = ()
indices = [False] * len(first_row.keys())
# Only need to use the first row to determine columns
first_row = result_list[0]

# Prepare a mapping of destination-to-source so that source column names can be checked efficiently
mappings_lookup = {v: k for k, v in mappings}
first_row_keys = list(first_row.keys())

for idx, src in enumerate(first_row.keys()):
if src in mappings_lookup:
indices[idx] = True
columns += ((mappings_lookup[src]),)
# Build the ordered output columns as a tuple, preserving the input row's column order
columns = tuple(mappings_lookup[src] for src in first_row_keys if src in mappings_lookup)

records = tuple(
sort_values((tuple(v for i, v in enumerate(r.values()) if indices[i]) for r in result), columns, columns)
)
# Precompute the selected source keys once, so per-row extraction avoids repeated membership checks
selected_keys: List[str] = [src for src in first_row_keys if src in mappings_lookup]
records_unsorted = (tuple(r[k] for k in selected_keys) for r in result_list)
records = tuple(sort_values(records_unsorted, columns, columns))

df = DataFrameWithInfo(records, risk_key=risk_key, request_id=request_id)
df.columns = columns
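
For reference, here is a minimal standalone sketch of the column-selection logic above, using plain dict rows and a hypothetical `mappings` tuple of (destination, source) pairs; `DataFrameWithInfo` and `sort_values` are gs_quant internals and are not reproduced here.

```python
# Hypothetical rows and mappings, for illustration only.
rows = [
    {'date': '2024-01-02', 'val': 1.5, 'ignored': 'x'},
    {'date': '2024-01-03', 'val': 2.5, 'ignored': 'y'},
]
mappings = (('Date', 'date'), ('Value', 'val'))  # (destination, source) pairs

# Invert to source -> destination so source column names can be checked efficiently
mappings_lookup = {src: dest for dest, src in mappings}
first_row_keys = list(rows[0].keys())
columns = tuple(mappings_lookup[k] for k in first_row_keys if k in mappings_lookup)
selected_keys = [k for k in first_row_keys if k in mappings_lookup]
records = [tuple(r[k] for k in selected_keys) for r in rows]

assert columns == ('Date', 'Value')
assert records == [('2024-01-02', 1.5), ('2024-01-03', 2.5)]
```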
@@ -53,15 +62,23 @@ def __dataframe_handler(result: Iterable, mappings: tuple, risk_key: RiskKey, re

def __dataframe_handler_unsorted(result: Iterable, mappings: tuple, date_cols: tuple, risk_key: RiskKey,
request_id: Optional[str] = None) -> DataFrameWithInfo:
first_row = next(iter(result), None)
if first_row is None:
result_list = list(result)
if not result_list:
return DataFrameWithInfo(risk_key=risk_key, request_id=request_id)

records = ([row.get(field_from) for field_to, field_from in mappings] for row in result)
# Extract the source field names once, then build the records in a single pass over the rows
field_froms = [field_from for _, field_from in mappings]
records = ([row.get(field) for field in field_froms] for row in result_list)

df = DataFrameWithInfo(records, risk_key=risk_key, request_id=request_id)
df.columns = [m[0] for m in mappings]
for dt_col in date_cols:
df[dt_col] = df[dt_col].map(lambda x: dt.datetime.strptime(x, '%Y-%m-%d').date() if isinstance(x, str) else x)
if date_cols:
    # Bind strptime to a local name so the lambda avoids repeated attribute lookups
    strptime = dt.datetime.strptime
    for dt_col in date_cols:
        df[dt_col] = df[dt_col].map(lambda x: strptime(x, '%Y-%m-%d').date() if isinstance(x, str) else x)

return df
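
The date conversion only touches string values, leaving already-parsed dates and missing values untouched. A small sketch of that per-value behaviour, independent of the DataFrame wrapper (the column contents are hypothetical):

```python
import datetime as dt

strptime = dt.datetime.strptime  # local binding, as in the handler above
raw = ['2024-01-02', dt.date(2024, 1, 3), None]
converted = [strptime(x, '%Y-%m-%d').date() if isinstance(x, str) else x for x in raw]
assert converted == [dt.date(2024, 1, 2), dt.date(2024, 1, 3), None]
```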

@@ -149,25 +166,33 @@ def risk_by_class_handler(result: dict, risk_key: RiskKey, _instrument: Instrume
# list of risk by class measures exposed in gs-quant
external_risk_by_class_val = ['IRBasisParallel', 'IRDeltaParallel', 'IRVegaParallel', 'PnlExplain']
if str(risk_key.risk_measure.name) in external_risk_by_class_val and len(types) <= 2 and len(set(types)) == 1:
return FloatWithInfo(risk_key, sum(result.get('values', (float('nan'),))), unit=result.get('unit'),
request_id=request_id)
else:
classes = []
skip = []
skip_set = set()

crosses_idx = next((i for i, c in enumerate(result['classes']) if c['type'] == 'CROSSES'), None)
for idx, (clazz, value) in enumerate(zip(result['classes'], result['values'])):
rc_classes = result['classes']
rc_values = result['values']

# Locate the CROSSES class once, up front, so spike/jump values can be folded into it inside the loop
crosses_idx = next((i for i, c in enumerate(rc_classes) if c['type'] == 'CROSSES'), None)

# Track skipped indices in a set so the membership test in the output pass below is O(1)
for idx, (clazz, value) in enumerate(zip(rc_classes, rc_values)):
mkt_type = clazz['type']
if 'SPIKE' in mkt_type or 'JUMP' in mkt_type:
skip.append(idx)

skip_set.add(idx)
if crosses_idx is not None:
result['classes'][crosses_idx]['value'] += value

clazz.update({'value': value})
rc_classes[crosses_idx]['value'] += value
clazz['value'] = value  # Direct assignment is cheaper than update() for a single key

for idx, clazz in enumerate(result['classes']):
if idx not in skip:
# Filter in a single scan; set membership makes each test O(1) rather than O(len(skip))
for idx, clazz in enumerate(rc_classes):
if idx not in skip_set:
classes.append(clazz)

mappings = (
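
A standalone sketch of the SPIKE/JUMP folding above, run against a hypothetical payload (the class types and values are illustrative; real payloads come from the risk service):

```python
rc_classes = [{'type': 'CROSSES'}, {'type': 'IR_SPIKE'}, {'type': 'IR'}]
rc_values = [1.0, 2.0, 3.0]

skip_set = set()
crosses_idx = next((i for i, c in enumerate(rc_classes) if c['type'] == 'CROSSES'), None)
for idx, (clazz, value) in enumerate(zip(rc_classes, rc_values)):
    if 'SPIKE' in clazz['type'] or 'JUMP' in clazz['type']:
        # Fold spike/jump contributions into the CROSSES bucket and drop the row
        skip_set.add(idx)
        if crosses_idx is not None:
            rc_classes[crosses_idx]['value'] += value
    clazz['value'] = value
classes = [c for i, c in enumerate(rc_classes) if i not in skip_set]

assert [c['type'] for c in classes] == ['CROSSES', 'IR']
assert rc_classes[0]['value'] == 3.0  # own 1.0 plus 2.0 folded in from the spike
```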