164 changes: 108 additions & 56 deletions torch_frame/data/dataset.py
@@ -64,7 +64,8 @@ def _requires_pre_materialization(self, *args, **kwargs):
         if self.is_materialized:
             raise RuntimeError(
                 f"'{self}' cannot be modified via '{func.__name__}' post "
-                f"materialization")
+                f"materialization"
+            )
         return func(self, *args, **kwargs)
 
     return _requires_pre_materialization
@@ -76,7 +77,8 @@ def _requires_post_materialization(self, *args, **kwargs):
         if not self.is_materialized:
             raise RuntimeError(
                 f"'{func.__name__}' requires a materialized dataset. Please "
-                f"call `dataset.materialize(...)` first.")
+                f"call `dataset.materialize(...)` first."
+            )
         return func(self, *args, **kwargs)
 
     return _requires_post_materialization
@@ -126,7 +128,8 @@ def canonicalize_col_to_pattern(
             raise ValueError(
                 f"{col_to_pattern_name} requires all columns to be "
                 f"specified but the following columns are missing from "
-                f"the dictionary: {list(missing_cols)}.")
+                f"the dictionary: {list(missing_cols)}."
+            )
         else:
             for col in missing_cols:
                 col_to_pattern[col] = None
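
Note: `canonicalize_col_to_pattern` is what lets callers pass either one pattern or a per-column dict; missing columns are either rejected or padded with None, per the branch above. A hedged sketch of the intended call-site behavior, with a made-up column name and separator:

    # A single separator fans out to every multicategorical column; a partial
    # dict is padded with None where None is allowed.
    dataset = Dataset(
        df,
        col_to_stype={"genres": torch_frame.multicategorical},
        col_to_sep="|",  # canonicalized internally to {"genres": "|"}
    )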
@@ -171,14 +174,14 @@ class DataFrameToTensorFrameConverter:
             to_datetime function will be used to auto parse time columns.
             (default: :obj:`None`)
     """
+
     def __init__(
         self,
         col_to_stype: dict[str, torch_frame.stype],
         col_stats: dict[str, dict[StatType, Any]],
         target_col: str | None = None,
         col_to_sep: dict[str, str | None] | None = None,
-        col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
-        | None = None,
+        col_to_text_embedder_cfg: dict[str, TextEmbedderConfig] | None = None,
         col_to_text_tokenizer_cfg: dict[str, TextTokenizerConfig]
         | None = None,
         col_to_time_format: dict[str, str | None] = None,
@@ -219,8 +222,9 @@ def _get_mapper(self, col: str) -> TensorMapper:
             return CategoricalTensorMapper(index)
         elif stype == torch_frame.multicategorical:
             index, _ = self.col_stats[col][StatType.MULTI_COUNT]
-            return MultiCategoricalTensorMapper(index,
-                                                sep=self.col_to_sep[col])
+            return MultiCategoricalTensorMapper(
+                index, sep=self.col_to_sep[col]
+            )
         elif stype == torch_frame.timestamp:
             return TimestampTensorMapper(format=self.col_to_time_format[col])
         elif stype == torch_frame.text_embedded:
@@ -240,8 +244,9 @@ def _get_mapper(self, col: str) -> TensorMapper:
         elif stype == torch_frame.embedding:
             return EmbeddingTensorMapper()
         else:
-            raise NotImplementedError(f"Unable to process the semantic "
-                                      f"type '{stype.value}'")
+            raise NotImplementedError(
+                f"Unable to process the semantic type '{stype.value}'"
+            )
 
     def _merge_feat(self, tf: TensorFrame) -> TensorFrame:
         r"""Merge child and parent :obj:`stypes<torch_frame.stype>` in the
@@ -258,13 +263,16 @@ def _merge_feat(self, tf: TensorFrame) -> TensorFrame:
             if stype.parent in tf.stypes:
                 parent_feat = tf.feat_dict[stype.parent]
                 tf.feat_dict[stype.parent] = torch_frame.cat(
-                    [parent_feat, child_feat], dim=1)
+                    [parent_feat, child_feat], dim=1
+                )
             else:
                 tf.feat_dict[stype.parent] = child_feat
 
             # Unify col_names_dict
-            tf.col_names_dict[stype.parent] = tf.col_names_dict.get(
-                stype.parent, []) + tf.col_names_dict[stype]
+            tf.col_names_dict[stype.parent] = (
+                tf.col_names_dict.get(stype.parent, [])
+                + tf.col_names_dict[stype]
+            )
 
             tf.feat_dict.pop(stype)
             tf.col_names_dict.pop(stype)
@@ -293,7 +301,8 @@ def __call__(
                 feat_dict[stype]: dict[str, MultiNestedTensor] = {}
                 for key in xs[0].keys():
                     feat_dict[stype][key] = MultiNestedTensor.cat(
-                        [x[key] for x in xs], dim=1)
+                        [x[key] for x in xs], dim=1
+                    )
             elif stype.use_multi_embedding_tensor:
                 feat_dict[stype] = MultiEmbeddingTensor.cat(xs, dim=1)
             else:
@@ -302,7 +311,8 @@ def __call__(
         y: Tensor | None = None
         if self.target_col is not None and self.target_col in df:
             y = self._get_mapper(self.target_col).forward(
-                df[self.target_col], device=device)
+                df[self.target_col], device=device
+            )
 
         tf = TensorFrame(feat_dict, self.col_names_dict, y)
         return self._merge_feat(tf)
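
For orientation, the converter's flow above (map each column, concatenate per stype, map the target, then `_merge_feat`) is used roughly like this. A minimal sketch only: the column name and DataFrame are hypothetical, and `col_stats` is assumed to be a precomputed `{col: {StatType: value}}` dict such as a materialized `Dataset` provides:

    converter = DataFrameToTensorFrameConverter(
        col_to_stype={"age": torch_frame.numerical},
        col_stats=col_stats,  # assumed precomputed per-column stats
    )
    tf = converter(df, device=None)  # TensorFrame with child stypes merged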
@@ -352,31 +362,38 @@ class Dataset(ABC):
             to_datetime function will be used to auto parse time columns.
             (default: :obj:`None`)
     """
+
     def __init__(
         self,
         df: DataFrame,
         col_to_stype: dict[str, torch_frame.stype],
+        in_memory: bool = True,
         target_col: str | None = None,
         split_col: str | None = None,
         col_to_sep: str | None | dict[str, str | None] = None,
         col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
-        | TextEmbedderConfig | None = None,
+        | TextEmbedderConfig
+        | None = None,
         col_to_text_tokenizer_cfg: dict[str, TextTokenizerConfig]
-        | TextTokenizerConfig | None = None,
+        | TextTokenizerConfig
+        | None = None,
         col_to_time_format: str | None | dict[str, str | None] = None,
     ):
         self.df = df
         self.target_col = target_col
+        self.in_memory = in_memory
 
         if split_col is not None:
             if split_col not in df.columns:
                 raise ValueError(
                     f"Given split_col ({split_col}) does not match columns of "
-                    f"the given df.")
+                    f"the given df."
+                )
             if split_col in col_to_stype:
                 raise ValueError(
                     f"col_to_stype should not contain the split_col "
-                    f"({col_to_stype}).")
+                    f"({split_col})."
+                )
             if not set(df[split_col]).issubset(set(SPLIT_TO_NUM.values())):
                 raise ValueError(
                     f"split_col must only contain {set(SPLIT_TO_NUM.values())}"
@@ -387,26 +404,38 @@ def __init__(
         cols = self.feat_cols + ([] if target_col is None else [target_col])
         missing_cols = set(cols) - set(df.columns)
         if len(missing_cols) > 0:
-            raise ValueError(f"The column(s) '{missing_cols}' are specified "
-                             f"but missing in the data frame")
+            raise ValueError(
+                f"The column(s) '{missing_cols}' are specified "
+                f"but missing in the data frame"
+            )
 
-        if (target_col is not None and self.col_to_stype[target_col]
-                == torch_frame.multicategorical):
+        if (
+            target_col is not None
+            and self.col_to_stype[target_col] == torch_frame.multicategorical
+        ):
             raise ValueError(
-                "Multilabel classification task is not yet supported.")
+                "Multilabel classification task is not yet supported."
+            )
 
         # Canonicalize and validate
         self.col_to_sep = self.canonicalize_and_validate_col_to_pattern(
-            col_to_sep, "col_to_sep")
-        (self.col_to_time_format
-         ) = self.canonicalize_and_validate_col_to_pattern(
-            col_to_time_format, "col_to_time_format")
-        (self.col_to_text_embedder_cfg
-         ) = self.canonicalize_and_validate_col_to_pattern(
-            col_to_text_embedder_cfg, "col_to_text_embedder_cfg")
-        (self.col_to_text_tokenizer_cfg
-         ) = self.canonicalize_and_validate_col_to_pattern(
-            col_to_text_tokenizer_cfg, "col_to_text_tokenizer_cfg")
+            col_to_sep, "col_to_sep"
+        )
+        self.col_to_time_format = (
+            self.canonicalize_and_validate_col_to_pattern(
+                col_to_time_format, "col_to_time_format"
+            )
+        )
+        self.col_to_text_embedder_cfg = (
+            self.canonicalize_and_validate_col_to_pattern(
+                col_to_text_embedder_cfg, "col_to_text_embedder_cfg"
+            )
+        )
+        self.col_to_text_tokenizer_cfg = (
+            self.canonicalize_and_validate_col_to_pattern(
+                col_to_text_tokenizer_cfg, "col_to_text_tokenizer_cfg"
+            )
+        )
 
         self._is_materialized: bool = False
         self._col_stats: dict[str, dict[StatType, Any]] = {}
Expand All @@ -421,19 +450,22 @@ def canonicalize_and_validate_col_to_pattern(
col_to_pattern_name=col_to_pattern_name,
col_to_pattern=col_to_pattern,
columns=[
col for col, stype in self.col_to_stype.items()
col
for col, stype in self.col_to_stype.items()
if stype == COL_TO_PATTERN_STYPE_MAPPING[col_to_pattern_name]
],
requires_all_inclusive=not COL_TO_PATTERN_ALLOW_NONE_MAPPING[
col_to_pattern_name],
col_to_pattern_name
],
)
assert isinstance(canonical_col_to_pattern, dict)

# Validate types of values.
for col, pattern in canonical_col_to_pattern.items():
pass_validation = False
required_type = COL_TO_PATTERN_REQUIRED_TYPE_MAPPING[
col_to_pattern_name]
col_to_pattern_name
]
allow_none = COL_TO_PATTERN_ALLOW_NONE_MAPPING[col_to_pattern_name]
if isinstance(pattern, required_type):
pass_validation = True
@@ -479,8 +511,11 @@ def __len__(self) -> int:
 
     def __getitem__(self, index: IndexSelectType) -> Dataset:
         is_col_select = isinstance(index, str)
-        is_col_select |= (isinstance(index, (list, tuple)) and len(index) > 0
-                          and isinstance(index[0], str))
+        is_col_select |= (
+            isinstance(index, (list, tuple))
+            and len(index) > 0
+            and isinstance(index[0], str)
+        )
 
         if is_col_select:
             return self.col_select(index)
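
The reworked predicate keeps the existing dispatch: strings (or sequences of strings) select columns, everything else selects rows. Illustrative only, with hypothetical column names and an assumed `import torch`:

    dataset["age"]                 # str -> col_select
    dataset[["age", "job"]]        # list of str -> col_select
    dataset[torch.tensor([0, 2])]  # indices -> index_select (post-materialization)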
@@ -522,7 +557,8 @@ def num_classes(self) -> int:
                 f"num_classes attribute is only supported when the target "
                 f"column ({self.target_col}) stats contains StatType.COUNT, "
                 f"but only the following target column stats are calculated: "
-                f"{list(self.col_stats[self.target_col].keys())}.")
+                f"{list(self.col_stats[self.target_col].keys())}."
+            )
         num_classes = len(self.col_stats[self.target_col][StatType.COUNT][0])
         assert num_classes > 1
         return num_classes
@@ -552,14 +588,21 @@ def materialize(
         if self.is_materialized:
             # Materialized without specifying path at first and materialize
             # again by specifying the path
+            if self._tensor_frame is None and self.df is not None:
+                self._tensor_frame = self._to_tensor_frame_converter(
+                    self.df, device
+                )
+                self._update_col_stats()
+
             if path is not None and not osp.isfile(path):
                 torch_frame.save(self._tensor_frame, self._col_stats, path)
             return self
 
         if path is not None and osp.isfile(path):
             # Load tensor_frame and col_stats
             self._tensor_frame, self._col_stats = torch_frame.load(
-                path, device)
+                path, device
+            )
             # Instantiate the converter
             self._to_tensor_frame_converter = self._get_tensorframe_converter()
             # Mark that the dataset has been materialized
@@ -587,18 +630,20 @@ def materialize(
 
         # 2. Create the `TensorFrame`:
         self._to_tensor_frame_converter = self._get_tensorframe_converter()
-        self._tensor_frame = self._to_tensor_frame_converter(self.df, device)
-
-        # 3. Update col stats based on `TensorFrame`
-        self._update_col_stats()
+        if self.in_memory:
+            self._tensor_frame = self._to_tensor_frame_converter(
+                self.df, device
+            )
+
+            # 3. Update col stats based on `TensorFrame`
+            self._update_col_stats()
+
+            if path is not None:
+                # Cache the dataset if user specifies the path
+                torch_frame.save(self._tensor_frame, self._col_stats, path)
 
         # 4. Mark the dataset as materialized:
         self._is_materialized = True
 
-        if path is not None:
-            # Cache the dataset if user specifies the path
-            torch_frame.save(self._tensor_frame, self._col_stats, path)
-
         return self
 
     def _get_tensorframe_converter(self) -> DataFrameToTensorFrameConverter:
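
Reading the two materialize hunks together, `path` now acts as a cache in both directions, and with `in_memory=False` the heavy `TensorFrame` construction is skipped entirely. A hedged sketch of the caching behavior (the path is a placeholder):

    dataset.materialize(path="cache/tf.pt")  # first run: computes stats, saves to path
    # ... later, on a fresh Dataset built from the same definition:
    dataset.materialize(path="cache/tf.pt")  # loads tensor_frame and col_stats instead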
@@ -620,9 +665,11 @@ def _update_col_stats(self):
             offset = self._tensor_frame.feat_dict[torch_frame.embedding].offset
             emb_dim_list = offset[1:] - offset[:-1]
             for i, col_name in enumerate(
-                    self._tensor_frame.col_names_dict[torch_frame.embedding]):
+                self._tensor_frame.col_names_dict[torch_frame.embedding]
+            ):
                 self._col_stats[col_name][StatType.EMB_DIM] = int(
-                    emb_dim_list[i])
+                    emb_dim_list[i]
+                )
 
     @property
     def is_materialized(self) -> bool:
@@ -665,7 +712,8 @@ def index_select(self, index: IndexSelectType) -> Dataset:
         iloc = index.cpu().numpy() if isinstance(index, Tensor) else index
         dataset.df = self.df.iloc[iloc]
 
-        dataset._tensor_frame = self._tensor_frame[index]
+        if self.in_memory:
+            dataset._tensor_frame = self._tensor_frame[index]
 
         return dataset
 
@@ -706,12 +754,16 @@ def get_split(self, split: str) -> Dataset:
         if self.split_col is None:
             raise ValueError(
                 f"'get_split' is not supported for '{self}' since 'split_col' "
-                f"is not specified.")
+                f"is not specified."
+            )
         if split not in ["train", "val", "test"]:
-            raise ValueError(f"The split named '{split}' is not available. "
-                             f"Needs to be either 'train', 'val', or 'test'.")
-        indices = self.df.index[self.df[self.split_col] ==
-                                SPLIT_TO_NUM[split]].tolist()
+            raise ValueError(
+                f"The split named '{split}' is not available. "
+                f"Needs to be either 'train', 'val', or 'test'."
+            )
+        indices = self.df.index[
+            self.df[self.split_col] == SPLIT_TO_NUM[split]
+        ].tolist()
         return self[indices]
 
     def split(self) -> tuple[Dataset, Dataset, Dataset]:
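
Net effect on the Dataset side: with the new `in_memory=False`, `materialize()` prepares `col_stats` and the converter but builds no `TensorFrame`, and `index_select` only slices the underlying DataFrame. A usage sketch; the file, columns, and stypes are hypothetical, and it assumes `Dataset` is directly instantiable as in the existing docs:

    import pandas as pd
    import torch_frame
    from torch_frame.data import Dataset

    df = pd.read_csv("bank.csv")  # placeholder data
    dataset = Dataset(
        df,
        col_to_stype={
            "age": torch_frame.numerical,
            "job": torch_frame.categorical,
        },
        target_col="job",
        in_memory=False,  # new flag added by this PR
    )
    dataset.materialize()   # computes col_stats; skips TensorFrame creation
    part = dataset[[0, 1]]  # row selection copies only the DataFrame slice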
16 changes: 13 additions & 3 deletions torch_frame/data/loader.py
@@ -30,16 +30,23 @@ class DataLoader(torch.utils.data.DataLoader):
         **kwargs (optional): Additional keyword arguments of
             :class:`torch.utils.data.DataLoader`.
     """
+
     def __init__(
         self,
         dataset: Dataset | TensorFrame,
         *args,
         **kwargs,
     ):
-        kwargs.pop('collate_fn', None)
+        kwargs.pop("collate_fn", None)
+        self.in_memory = isinstance(dataset, TensorFrame) or dataset.in_memory
+
         if isinstance(dataset, Dataset):
-            self.tensor_frame: TensorFrame = dataset.materialize().tensor_frame
+            if dataset.in_memory:
+                self.tensor_frame: TensorFrame = dataset.materialize().tensor_frame
+            else:
+                # Out-of-memory case: store a callable that materializes
+                # only the requested rows on demand (used by collate_fn).
+                self.tensor_frame = (
+                    lambda index: dataset[index].materialize().tensor_frame
+                )
         else:
             self.tensor_frame: TensorFrame = dataset
 
@@ -51,4 +58,7 @@ def __init__(
         )
 
     def collate_fn(self, index: IndexSelectType) -> TensorFrame:
-        return self.tensor_frame[index]
+        if self.in_memory:
+            return self.tensor_frame[index]
+        else:
+            return self.tensor_frame(index)
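
And on the loader side, batches from a non-in-memory dataset are materialized lazily inside `collate_fn`. A sketch continuing the hypothetical `dataset` built above with `in_memory=False`:

    from torch_frame.data import DataLoader

    loader = DataLoader(dataset, batch_size=128, shuffle=True)
    for tf in loader:
        # Each `tf` is a TensorFrame; with in_memory=False it is built
        # on the fly from the sampled row indices.
        ...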