diff --git a/cirro/helpers/preprocess_dataset.py b/cirro/helpers/preprocess_dataset.py
index 8682cbbe..1e4366c5 100644
--- a/cirro/helpers/preprocess_dataset.py
+++ b/cirro/helpers/preprocess_dataset.py
@@ -16,6 +16,23 @@
 logger = logging.getLogger(__name__)
 
 
+def _fix_s3_path(path: str) -> str:
+    """
+    Repair an S3 URI whose 's3://' scheme was collapsed to 's3:/'.
+
+    Separators are normalized to '/' and surrounding whitespace is stripped
+    before the check. Anything that is not an S3 URI is returned unchanged.
+    """
+    normalized_path = path.replace(os.sep, '/').strip()
+    if not normalized_path.startswith("s3:/"):
+        # Not an S3 URI: keep the caller's path (and OS separators) as given.
+        return path
+    if normalized_path.startswith("s3://"):
+        # Scheme is intact; still return the cleaned-up form, not the raw input.
+        return normalized_path
+    return normalized_path.replace("s3:/", "s3://", 1)
+
+
 def write_json(dat, local_path: str, indent=4):
     """Write a JSON object to a local file."""
     with Path(local_path).open(mode="wt") as handle:
@@ -26,7 +43,7 @@ def read_csv(path: str, required_columns=None) -> 'DataFrame':
     """Read a CSV from the dataset and check for any required columns."""
     if required_columns is None:
         required_columns = []
-
+    path = _fix_s3_path(path)
     import pandas as pd
     df = pd.read_csv(path)
     for col in required_columns:
@@ -36,6 +53,7 @@ def read_csv(path: str, required_columns=None) -> 'DataFrame':
 
 def read_json(path: str):
     """Read a JSON object from a local file or S3 path."""
+    path = _fix_s3_path(path)
     s3_path = S3Path(path)
 
     if s3_path.valid:
diff --git a/pyproject.toml b/pyproject.toml
index 265a7cbe..dbc7cbf2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "cirro"
-version = "1.7.0"
+version = "1.7.1"
 description = "CLI tool and SDK for interacting with the Cirro platform"
 authors = ["Cirro Bio "]
 license = "MIT"