diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py
index 04b2f80bb..b3ec901c3 100644
--- a/torch_frame/data/dataset.py
+++ b/torch_frame/data/dataset.py
@@ -64,7 +64,8 @@ def _requires_pre_materialization(self, *args, **kwargs):
         if self.is_materialized:
             raise RuntimeError(
                 f"'{self}' cannot be modified via '{func.__name__}' post "
-                f"materialization")
+                f"materialization"
+            )
         return func(self, *args, **kwargs)

     return _requires_pre_materialization
@@ -76,7 +77,8 @@ def _requires_post_materialization(self, *args, **kwargs):
         if not self.is_materialized:
             raise RuntimeError(
                 f"'{func.__name__}' requires a materialized dataset. Please "
-                f"call `dataset.materialize(...)` first.")
+                f"call `dataset.materialize(...)` first."
+            )
         return func(self, *args, **kwargs)

     return _requires_post_materialization
@@ -126,7 +128,8 @@ def canonicalize_col_to_pattern(
             raise ValueError(
                 f"{col_to_pattern_name} requires all columns to be "
                 f"specified but the following columns are missing from "
-                f"the dictionary: {list(missing_cols)}.")
+                f"the dictionary: {list(missing_cols)}."
+            )
         else:
             for col in missing_cols:
                 col_to_pattern[col] = None
@@ -171,14 +174,14 @@ class DataFrameToTensorFrameConverter:
             to_datetime function will be used to auto parse time columns.
             (default: :obj:`None`)
     """
+
     def __init__(
         self,
         col_to_stype: dict[str, torch_frame.stype],
         col_stats: dict[str, dict[StatType, Any]],
         target_col: str | None = None,
         col_to_sep: dict[str, str | None] | None = None,
-        col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
-        | None = None,
+        col_to_text_embedder_cfg: dict[str, TextEmbedderConfig] | None = None,
         col_to_text_tokenizer_cfg: dict[str, TextTokenizerConfig]
         | None = None,
-        col_to_time_format: dict[str, str | None] = None,
+        col_to_time_format: dict[str, str | None] | None = None,
@@ -219,8 +222,9 @@ def _get_mapper(self, col: str) -> TensorMapper:
             return CategoricalTensorMapper(index)
         elif stype == torch_frame.multicategorical:
             index, _ = self.col_stats[col][StatType.MULTI_COUNT]
-            return MultiCategoricalTensorMapper(index,
-                                                sep=self.col_to_sep[col])
+            return MultiCategoricalTensorMapper(
+                index, sep=self.col_to_sep[col]
+            )
         elif stype == torch_frame.timestamp:
             return TimestampTensorMapper(format=self.col_to_time_format[col])
         elif stype == torch_frame.text_embedded:
@@ -240,8 +244,9 @@ def _get_mapper(self, col: str) -> TensorMapper:
         elif stype == torch_frame.embedding:
             return EmbeddingTensorMapper()
         else:
-            raise NotImplementedError(f"Unable to process the semantic "
-                                      f"type '{stype.value}'")
+            raise NotImplementedError(
+                f"Unable to process the semantic type '{stype.value}'"
+            )

     def _merge_feat(self, tf: TensorFrame) -> TensorFrame:
         r"""Merge child and parent :obj:`stypes` in the
@@ -258,13 +263,16 @@ def _merge_feat(self, tf: TensorFrame) -> TensorFrame:
             if stype.parent in tf.stypes:
                 parent_feat = tf.feat_dict[stype.parent]
                 tf.feat_dict[stype.parent] = torch_frame.cat(
-                    [parent_feat, child_feat], dim=1)
+                    [parent_feat, child_feat], dim=1
+                )
             else:
                 tf.feat_dict[stype.parent] = child_feat

             # Unify col_names_dict
-            tf.col_names_dict[stype.parent] = tf.col_names_dict.get(
-                stype.parent, []) + tf.col_names_dict[stype]
+            tf.col_names_dict[stype.parent] = (
+                tf.col_names_dict.get(stype.parent, [])
+                + tf.col_names_dict[stype]
+            )

             tf.feat_dict.pop(stype)
             tf.col_names_dict.pop(stype)
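The converter above is normally created inside Dataset.materialize(), but it can also be driven directly. A minimal sketch, assuming hand-rolled column stats (the column name and stat values below are illustrative, not from this PR):

    import pandas as pd

    import torch_frame
    from torch_frame.data.dataset import DataFrameToTensorFrameConverter
    from torch_frame.data.stats import StatType

    df = pd.DataFrame({"age": [23.0, 41.0]})
    # Stats like these are normally computed by Dataset.materialize():
    col_stats = {
        "age": {
            StatType.MEAN: 32.0,
            StatType.STD: 9.0,
            StatType.QUANTILES: [23.0, 27.5, 32.0, 36.5, 41.0],
        }
    }
    converter = DataFrameToTensorFrameConverter(
        col_to_stype={"age": torch_frame.numerical},
        col_stats=col_stats,
    )
    # TensorFrame with one numerical column; child stypes, if present,
    # would be merged into their parents via _merge_feat():
    tf = converter(df)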
@@ -293,7 +301,8 @@ def __call__(
                 feat_dict[stype]: dict[str, MultiNestedTensor] = {}
                 for key in xs[0].keys():
                     feat_dict[stype][key] = MultiNestedTensor.cat(
-                        [x[key] for x in xs], dim=1)
+                        [x[key] for x in xs], dim=1
+                    )
             elif stype.use_multi_embedding_tensor:
                 feat_dict[stype] = MultiEmbeddingTensor.cat(xs, dim=1)
             else:
@@ -302,7 +311,8 @@ def __call__(
         y: Tensor | None = None
         if self.target_col is not None and self.target_col in df:
             y = self._get_mapper(self.target_col).forward(
-                df[self.target_col], device=device)
+                df[self.target_col], device=device
+            )

         tf = TensorFrame(feat_dict, self.col_names_dict, y)
         return self._merge_feat(tf)
@@ -352,31 +362,38 @@ class Dataset(ABC):
             to_datetime function will be used to auto parse time columns.
             (default: :obj:`None`)
+        in_memory (bool, optional): Whether to keep the materialized
+            :class:`TensorFrame` of the full dataset in memory.
+            (default: :obj:`True`)
     """
+
     def __init__(
         self,
         df: DataFrame,
         col_to_stype: dict[str, torch_frame.stype],
+        in_memory: bool = True,
         target_col: str | None = None,
         split_col: str | None = None,
         col_to_sep: str | None | dict[str, str | None] = None,
         col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
-        | TextEmbedderConfig | None = None,
+        | TextEmbedderConfig
+        | None = None,
         col_to_text_tokenizer_cfg: dict[str, TextTokenizerConfig]
-        | TextTokenizerConfig | None = None,
+        | TextTokenizerConfig
+        | None = None,
         col_to_time_format: str | None | dict[str, str | None] = None,
     ):
         self.df = df
         self.target_col = target_col
+        self.in_memory = in_memory

         if split_col is not None:
             if split_col not in df.columns:
                 raise ValueError(
                     f"Given split_col ({split_col}) does not match columns of "
-                    f"the given df.")
+                    f"the given df."
+                )
             if split_col in col_to_stype:
                 raise ValueError(
                     f"col_to_stype should not contain the split_col "
-                    f"({col_to_stype}).")
+                    f"({split_col})."
+                )
             if not set(df[split_col]).issubset(set(SPLIT_TO_NUM.values())):
                 raise ValueError(
                     f"split_col must only contain {set(SPLIT_TO_NUM.values())}"
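The new in_memory flag is the substantive change of this diff: the lazy path defers TensorFrame construction until rows are actually selected. A usage sketch, with made-up column names:

    import pandas as pd

    import torch_frame
    from torch_frame.data import Dataset

    df = pd.DataFrame({"x": [0.1, 0.2, 0.3], "y": [0, 1, 0]})
    dataset = Dataset(
        df,
        col_to_stype={"x": torch_frame.numerical, "y": torch_frame.categorical},
        in_memory=False,
        target_col="y",
    )
    dataset.materialize()       # computes col_stats; _tensor_frame stays None
    two_rows = dataset[[0, 2]]  # slices only the underlying DataFrame
    tf = two_rows.materialize().tensor_frame  # converts just those two rows

This row-subset conversion is exactly what the DataLoader change at the end of the diff relies on.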
+            )
@@ -387,26 +404,38 @@ def __init__(
         cols = self.feat_cols + ([] if target_col is None else [target_col])
         missing_cols = set(cols) - set(df.columns)
         if len(missing_cols) > 0:
-            raise ValueError(f"The column(s) '{missing_cols}' are specified "
-                             f"but missing in the data frame")
+            raise ValueError(
+                f"The column(s) '{missing_cols}' are specified "
+                f"but missing in the data frame"
+            )

-        if (target_col is not None and self.col_to_stype[target_col]
-                == torch_frame.multicategorical):
+        if (
+            target_col is not None
+            and self.col_to_stype[target_col] == torch_frame.multicategorical
+        ):
             raise ValueError(
-                "Multilabel classification task is not yet supported.")
+                "Multilabel classification task is not yet supported."
+            )

         # Canonicalize and validate
         self.col_to_sep = self.canonicalize_and_validate_col_to_pattern(
-            col_to_sep, "col_to_sep")
-        (self.col_to_time_format
-         ) = self.canonicalize_and_validate_col_to_pattern(
-             col_to_time_format, "col_to_time_format")
-        (self.col_to_text_embedder_cfg
-         ) = self.canonicalize_and_validate_col_to_pattern(
-             col_to_text_embedder_cfg, "col_to_text_embedder_cfg")
-        (self.col_to_text_tokenizer_cfg
-         ) = self.canonicalize_and_validate_col_to_pattern(
-             col_to_text_tokenizer_cfg, "col_to_text_tokenizer_cfg")
+            col_to_sep, "col_to_sep"
+        )
+        self.col_to_time_format = self.canonicalize_and_validate_col_to_pattern(
+            col_to_time_format, "col_to_time_format"
+        )
+        self.col_to_text_embedder_cfg = self.canonicalize_and_validate_col_to_pattern(
+            col_to_text_embedder_cfg, "col_to_text_embedder_cfg"
+        )
+        self.col_to_text_tokenizer_cfg = self.canonicalize_and_validate_col_to_pattern(
+            col_to_text_tokenizer_cfg, "col_to_text_tokenizer_cfg"
+        )

         self._is_materialized: bool = False
         self._col_stats: dict[str, dict[StatType, Any]] = {}
@@ -421,11 +450,13 @@ def canonicalize_and_validate_col_to_pattern(
             col_to_pattern_name=col_to_pattern_name,
             col_to_pattern=col_to_pattern,
             columns=[
-                col for col, stype in self.col_to_stype.items()
+                col
+                for col, stype in self.col_to_stype.items()
                 if stype == COL_TO_PATTERN_STYPE_MAPPING[col_to_pattern_name]
             ],
             requires_all_inclusive=not COL_TO_PATTERN_ALLOW_NONE_MAPPING[
-                col_to_pattern_name],
+                col_to_pattern_name
+            ],
         )
         assert isinstance(canonical_col_to_pattern, dict)
@@ -433,7 +464,8 @@ def canonicalize_and_validate_col_to_pattern(
         for col, pattern in canonical_col_to_pattern.items():
             pass_validation = False
             required_type = COL_TO_PATTERN_REQUIRED_TYPE_MAPPING[
-                col_to_pattern_name]
+                col_to_pattern_name
+            ]
             allow_none = COL_TO_PATTERN_ALLOW_NONE_MAPPING[col_to_pattern_name]
             if isinstance(pattern, required_type):
                 pass_validation = True
@@ -479,8 +511,11 @@ def __len__(self) -> int:

     def __getitem__(self, index: IndexSelectType) -> Dataset:
         is_col_select = isinstance(index, str)
-        is_col_select |= (isinstance(index, (list, tuple)) and len(index) > 0
-                          and isinstance(index[0], str))
+        is_col_select |= (
+            isinstance(index, (list, tuple))
+            and len(index) > 0
+            and isinstance(index[0], str)
+        )

         if is_col_select:
             return self.col_select(index)
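As the list comprehension above shows, each col_to_* argument is matched against the columns of the corresponding stype and canonicalized to a full dict. A sketch of the two accepted call forms, reusing the df/col_to_stype names from the earlier example with made-up separators:

    # One separator, broadcast to every multicategorical column:
    Dataset(df, col_to_stype, col_to_sep="|")
    # Or per column; columns omitted here are filled in with None when the
    # mapping permits it (requires_all_inclusive=False):
    Dataset(df, col_to_stype, col_to_sep={"genres": "|", "tags": ","})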
+                f"{list(self.col_stats[self.target_col].keys())}."
+            )
         num_classes = len(self.col_stats[self.target_col][StatType.COUNT][0])
         assert num_classes > 1
         return num_classes
@@ -552,6 +588,12 @@ def materialize(
         if self.is_materialized:
             # Materialized without specifying path at first and materialize
             # again by specifying the path
+            # Built lazily (`in_memory=False`) but a `TensorFrame` is
+            # needed now, e.g., to save it to disk:
+            if self._tensor_frame is None and self.df is not None:
+                self._tensor_frame = self._to_tensor_frame_converter(
+                    self.df, device
+                )
+                self._update_col_stats()
+
             if path is not None and not osp.isfile(path):
                 torch_frame.save(self._tensor_frame, self._col_stats, path)
             return self
@@ -559,7 +601,8 @@ def materialize(
         if path is not None and osp.isfile(path):
             # Load tensor_frame and col_stats
             self._tensor_frame, self._col_stats = torch_frame.load(
-                path, device)
+                path, device
+            )
             # Instantiate the converter
             self._to_tensor_frame_converter = self._get_tensorframe_converter()
             # Mark the dataset has been materialized
@@ -587,18 +630,20 @@ def materialize(

         # 2. Create the `TensorFrame`:
         self._to_tensor_frame_converter = self._get_tensorframe_converter()
-        self._tensor_frame = self._to_tensor_frame_converter(self.df, device)
+        if self.in_memory:
+            self._tensor_frame = self._to_tensor_frame_converter(
+                self.df, device
+            )

-        # 3. Update col stats based on `TensorFrame`
-        self._update_col_stats()
+            # 3. Update col stats based on `TensorFrame`
+            self._update_col_stats()
+
+        if path is not None:
+            # Cache the dataset if user specifies the path
+            torch_frame.save(self._tensor_frame, self._col_stats, path)

         # 4. Mark the dataset as materialized:
         self._is_materialized = True
-
-        if path is not None:
-            # Cache the dataset if user specifies the path
-            torch_frame.save(self._tensor_frame, self._col_stats, path)
-
         return self

     def _get_tensorframe_converter(self) -> DataFrameToTensorFrameConverter:
@@ -620,9 +665,11 @@ def _update_col_stats(self):
         offset = self._tensor_frame.feat_dict[torch_frame.embedding].offset
         emb_dim_list = offset[1:] - offset[:-1]
         for i, col_name in enumerate(
-                self._tensor_frame.col_names_dict[torch_frame.embedding]):
+            self._tensor_frame.col_names_dict[torch_frame.embedding]
+        ):
             self._col_stats[col_name][StatType.EMB_DIM] = int(
-                emb_dim_list[i])
+                emb_dim_list[i]
+            )

     @property
     def is_materialized(self) -> bool:
@@ -665,7 +712,8 @@ def index_select(self, index: IndexSelectType) -> Dataset:
         iloc = index.cpu().numpy() if isinstance(index, Tensor) else index
         dataset.df = self.df.iloc[iloc]

-        dataset._tensor_frame = self._tensor_frame[index]
+        if self.in_memory:
+            dataset._tensor_frame = self._tensor_frame[index]

         return dataset
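Taken together, materialize(path=...) now doubles as an on-disk cache in both modes. A usage sketch with an illustrative path:

    # First call: computes stats, builds the TensorFrame (when in_memory)
    # and saves both to disk.
    dataset.materialize(path="data/cached.pt")
    # Later calls with the same path load from, or backfill, the cache via
    # the early-return branch above instead of reconverting the DataFrame.
    dataset.materialize(path="data/cached.pt")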
+            )
+        indices = self.df.index[
+            self.df[self.split_col] == SPLIT_TO_NUM[split]
+        ].tolist()
         return self[indices]

     def split(self) -> tuple[Dataset, Dataset, Dataset]:
diff --git a/torch_frame/data/loader.py b/torch_frame/data/loader.py
index b41142f60..139435d9e 100644
--- a/torch_frame/data/loader.py
+++ b/torch_frame/data/loader.py
@@ -30,16 +30,23 @@ class DataLoader(torch.utils.data.DataLoader):
         **kwargs (optional): Additional keyword arguments of
             :class:`torch.utils.data.DataLoader`.
     """
+
     def __init__(
         self,
         dataset: Dataset | TensorFrame,
         *args,
         **kwargs,
     ):
-        kwargs.pop('collate_fn', None)
+        kwargs.pop("collate_fn", None)
+        self.in_memory = isinstance(dataset, TensorFrame) or dataset.in_memory
         if isinstance(dataset, Dataset):
-            self.tensor_frame: TensorFrame = dataset.materialize().tensor_frame
+            if dataset.in_memory:
+                self.tensor_frame: TensorFrame = dataset.materialize().tensor_frame
+            else:
+                # Not a TensorFrame but a callable that materializes only
+                # the sampled rows at collate time:
+                self.tensor_frame = (
+                    lambda index: dataset[index].materialize().tensor_frame
+                )
         else:
             self.tensor_frame: TensorFrame = dataset

@@ -51,4 +58,7 @@ def __init__(
         )

     def collate_fn(self, index: IndexSelectType) -> TensorFrame:
-        return self.tensor_frame[index]
+        if self.in_memory:
+            return self.tensor_frame[index]
+        else:
+            return self.tensor_frame(index)
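End to end, the loader API is unchanged for callers; only the batching strategy differs. A sketch, assuming dataset was built as in the earlier example and already materialized:

    from torch_frame.data import DataLoader

    loader = DataLoader(dataset, batch_size=128, shuffle=True)
    for tf in loader:
        # tf is a TensorFrame in both modes. With in_memory=True it is
        # sliced out of the cached TensorFrame; with in_memory=False only
        # the sampled DataFrame rows are converted per batch, trading
        # per-step CPU time for lower peak memory.
        ...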