From 5f04d02d86db2f572fa58a31fe464303c04f3014 Mon Sep 17 00:00:00 2001 From: kennedyuche Date: Fri, 1 Nov 2024 20:09:47 -0500 Subject: [PATCH 01/10] transform garden init commit --- src/transforms_garden/main.py | 226 ++++++++++++++++++++++++++++++++++ 1 file changed, 226 insertions(+) create mode 100644 src/transforms_garden/main.py diff --git a/src/transforms_garden/main.py b/src/transforms_garden/main.py new file mode 100644 index 0000000..0633b16 --- /dev/null +++ b/src/transforms_garden/main.py @@ -0,0 +1,226 @@ +import os +import sys +import base64 +import subprocess +import yaml +import logging +from typing import Any, Dict, List +import importlib.util +import apache_beam as beam +from apache_beam.options.pipeline_options import PipelineOptions + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class YAMLValidator: + """Validate the YAML file structure.""" + + @staticmethod + def validate(yaml_data: Dict[str, Any]) -> None: + required_fields = ['metadata', 'source', 'dependencies', 'transforms', 'source_code', 'encoding'] + + for field in required_fields: + if field not in yaml_data: + logger.error(f"Missing required field: {field}") + raise ValueError(f"Missing required field: {field}") + + if yaml_data['encoding'] != 'base64': + logger.error("Invalid encoding type; must be 'base64'.") + raise ValueError("Invalid encoding type; must be 'base64'.") + + +class VirtualEnvManager: + """Manage the virtual environment and package installations.""" + + def __init__(self, venv_dir: str): + self.venv_dir = venv_dir + + def create_venv(self) -> None: + logger.info(f"Creating virtual environment in {self.venv_dir}") + subprocess.run(['python', '-m', 'venv', self.venv_dir], check=True) + + def install_packages(self, packages: List[str]) -> None: + logger.info(f"Installing packages: {', '.join(packages)}") + + # Get the list of currently installed packages + installed_packages = subprocess.check_output( + [os.path.join(self.venv_dir, 'bin', 'pip'), 'list', '--format=freeze'] + ).decode('utf-8').splitlines() + + installed_packages_set = {pkg.split('==')[0] for pkg in installed_packages} + + for package in packages: + if package not in installed_packages_set: + logger.info(f"Installing package: {package}") + subprocess.run([os.path.join(self.venv_dir, 'bin', 'pip'), 'install', package], check=True) + else: + logger.info(f"Package '{package}' is already installed; skipping installation.") + + +class PipelineDataHandler: + """Handles reading from source and writing to sink in the Beam pipeline.""" + + def __init__(self, source_path: str, sink_path: str): + self.source_path = source_path + self.sink_path = sink_path + + def read_source(self, p: beam.Pipeline) -> beam.PCollection: + """Read data from the source path into a Beam PCollection.""" + logger.info(f"Reading data from source path: {self.source_path}") + return p | "ReadFromSource" >> beam.io.ReadFromText(self.source_path) + + def write_sink(self, pcollection: beam.PCollection) -> None: + """Write the resulting PCollection to the sink path.""" + logger.info(f"Writing data to sink path: {self.sink_path}") + pcollection | "WriteToSink" >> beam.io.WriteToText(self.sink_path) + + +class TransformManager: + """Manage decoding and handling of transform source codes.""" + + def __init__(self, transforms: Dict[str, str], source_code: Dict[str, str]): + self.transforms = transforms + self.source_code = source_code + self.decoded_code = {} + + def decode_source_code(self) -> None: + logger.info("Decoding source code 
from base64.") + for source, code in self.source_code.items(): + decoded = base64.b64decode(code).decode('utf-8') + self.decoded_code[source] = decoded + logger.debug(f"Decoded source for {source}: {decoded}...") + +class PipelineBuilder: + """Build and run the Apache Beam pipeline.""" + + def __init__(self, + transforms: Dict[str, str], + decoded_code: Dict[str, str], + venv_dir: str, + source_path: str, + sink_path: str, + pipeline_options: Dict[str, Any]): + + self.transforms = transforms + self.decoded_code = decoded_code + self.venv_dir = venv_dir + self.pipeline_data = PipelineDataHandler(source_path, sink_path) + self.pipeline_options = PipelineOptions.from_dictionary(pipeline_options) + + def load_transform_class(self, module_name: str, class_name: str, source_code: str): + """Dynamically load a class from a decoded source code string.""" + # Create a temporary module name + temp_module_name = f"temp_{module_name.replace('.', '_')}" + spec = importlib.util.spec_from_loader(temp_module_name, loader=None) + module = importlib.util.module_from_spec(spec) + + # Execute the source code in the module's dictionary + exec(source_code, module.__dict__) + + # Register the module in sys.modules + sys.modules[temp_module_name] = module + + # Access the class from the module + transform_class = getattr(module, class_name) + + # Check if it's a valid Beam PTransform + if not issubclass(transform_class, beam.PTransform): + raise TypeError(f"{class_name} is not a subclass of beam.PTransform") + + return transform_class + + def build_pipeline(self) -> None: + logger.info("Building the Apache Beam pipeline.") + + # Building the pipeline + with beam.Pipeline(options=self.pipeline_options) as pipeline: + + # Read data from the source + input_data = self.pipeline_data.read_source(pipeline) + + # Apply each transform in transforms list + for transform in self.transforms: + try: + transform_name, transform_info = list(transform.items())[0] + source_file, class_name = transform_info.split(":") + source_code = self.decoded_code.get(source_file) + + if not source_code: + logger.error(f"Source code for {source_file} not found.") + continue + + # Dynamically load the transform class + transform_class = self.load_transform_class(source_file, class_name, source_code) + logger.info(f"Successfully loaded transform: {transform_name} ({class_name})") + + # Apply the transform + logger.info(f"Applying {transform_name} to the pipeline.") + input_data = input_data | f"Apply_{transform_name}" >> transform_class() + + except Exception as e: + logger.error(f"Error applying transform {transform_name}: {e}") + + # Write the transformed data to the sink + self.pipeline_data.write_sink(input_data) + + +class BeamstackSDK: + """Main SDK class to orchestrate the process.""" + + def __init__(self, yaml_file: str, source_path: str, sink_path: str, pipeline_options: Dict[str, Any]): + self.source_path = source_path + self.sink_path = sink_path + self.pipeline_options = pipeline_options + self.yaml_file = yaml_file + self.yaml_data = self.load_yaml() + self.venv_dir = 'venv' ### virtual env path + + def load_yaml(self) -> Dict[str, Any]: + with open(self.yaml_file, 'r') as file: + return yaml.safe_load(file) + + def run(self) -> None: + # Validate YAML + YAMLValidator.validate(self.yaml_data) + + # Create virtual environment and install dependencies + venv_manager = VirtualEnvManager(self.venv_dir) + venv_manager.create_venv() + venv_manager.install_packages(self.yaml_data['dependencies']) + + # Decode source code + 
transform_manager = TransformManager( + transforms=self.yaml_data['transforms'], + source_code=self.yaml_data['source_code'] + ) + transform_manager.decode_source_code() + + # Build and run the pipeline + pipeline_builder = PipelineBuilder( + transforms=self.yaml_data['transforms'], + decoded_code=transform_manager.decoded_code, + venv_dir=self.venv_dir, + source_path=self.source_path, + sink_path=self.sink_path, + pipeline_options=self.pipeline_options + ) + pipeline_builder.build_pipeline() + + +if __name__ == "__main__": + pipeline_options = { + 'runner': 'DirectRunner' + } + + yaml_path = 'pipeline.yaml' + source_path = 'input_data.txt' + sink_path = 'output_data.txt' + + transformsdk = BeamstackSDK( + yaml_file=yaml_path, + source_path=source_path, + sink_path=sink_path, + pipeline_options=pipeline_options + ) + + transformsdk.run() \ No newline at end of file From c2345eb940ba83e9cafea4a651001da052b8e2a2 Mon Sep 17 00:00:00 2001 From: kennedyuche Date: Fri, 1 Nov 2024 20:29:41 -0500 Subject: [PATCH 02/10] updated the structure --- src/transforms_garden/main.py | 170 +--------------------- src/transforms_garden/packages_handler.py | 33 +++++ src/transforms_garden/pipeline_handler.py | 115 +++++++++++++++ src/transforms_garden/utils.py | 0 src/transforms_garden/yaml_handler.py | 20 +++ 5 files changed, 176 insertions(+), 162 deletions(-) create mode 100644 src/transforms_garden/packages_handler.py create mode 100644 src/transforms_garden/pipeline_handler.py create mode 100644 src/transforms_garden/utils.py create mode 100644 src/transforms_garden/yaml_handler.py diff --git a/src/transforms_garden/main.py b/src/transforms_garden/main.py index 0633b16..4429739 100644 --- a/src/transforms_garden/main.py +++ b/src/transforms_garden/main.py @@ -1,170 +1,16 @@ -import os -import sys -import base64 -import subprocess import yaml import logging -from typing import Any, Dict, List -import importlib.util -import apache_beam as beam -from apache_beam.options.pipeline_options import PipelineOptions +from typing import Any, Dict +from yaml_handler import YAMLValidator +from packages_handler import VirtualEnvManager +from pipeline_handler import TransformManager, PipelineBuilder +from utils import LogHandler -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) -class YAMLValidator: - """Validate the YAML file structure.""" - - @staticmethod - def validate(yaml_data: Dict[str, Any]) -> None: - required_fields = ['metadata', 'source', 'dependencies', 'transforms', 'source_code', 'encoding'] - - for field in required_fields: - if field not in yaml_data: - logger.error(f"Missing required field: {field}") - raise ValueError(f"Missing required field: {field}") - - if yaml_data['encoding'] != 'base64': - logger.error("Invalid encoding type; must be 'base64'.") - raise ValueError("Invalid encoding type; must be 'base64'.") - - -class VirtualEnvManager: - """Manage the virtual environment and package installations.""" - - def __init__(self, venv_dir: str): - self.venv_dir = venv_dir - - def create_venv(self) -> None: - logger.info(f"Creating virtual environment in {self.venv_dir}") - subprocess.run(['python', '-m', 'venv', self.venv_dir], check=True) - - def install_packages(self, packages: List[str]) -> None: - logger.info(f"Installing packages: {', '.join(packages)}") - - # Get the list of currently installed packages - installed_packages = subprocess.check_output( - [os.path.join(self.venv_dir, 'bin', 'pip'), 'list', '--format=freeze'] - ).decode('utf-8').splitlines() - - 
installed_packages_set = {pkg.split('==')[0] for pkg in installed_packages} - - for package in packages: - if package not in installed_packages_set: - logger.info(f"Installing package: {package}") - subprocess.run([os.path.join(self.venv_dir, 'bin', 'pip'), 'install', package], check=True) - else: - logger.info(f"Package '{package}' is already installed; skipping installation.") - - -class PipelineDataHandler: - """Handles reading from source and writing to sink in the Beam pipeline.""" - - def __init__(self, source_path: str, sink_path: str): - self.source_path = source_path - self.sink_path = sink_path - - def read_source(self, p: beam.Pipeline) -> beam.PCollection: - """Read data from the source path into a Beam PCollection.""" - logger.info(f"Reading data from source path: {self.source_path}") - return p | "ReadFromSource" >> beam.io.ReadFromText(self.source_path) - - def write_sink(self, pcollection: beam.PCollection) -> None: - """Write the resulting PCollection to the sink path.""" - logger.info(f"Writing data to sink path: {self.sink_path}") - pcollection | "WriteToSink" >> beam.io.WriteToText(self.sink_path) - - -class TransformManager: - """Manage decoding and handling of transform source codes.""" - - def __init__(self, transforms: Dict[str, str], source_code: Dict[str, str]): - self.transforms = transforms - self.source_code = source_code - self.decoded_code = {} - - def decode_source_code(self) -> None: - logger.info("Decoding source code from base64.") - for source, code in self.source_code.items(): - decoded = base64.b64decode(code).decode('utf-8') - self.decoded_code[source] = decoded - logger.debug(f"Decoded source for {source}: {decoded}...") - -class PipelineBuilder: - """Build and run the Apache Beam pipeline.""" - - def __init__(self, - transforms: Dict[str, str], - decoded_code: Dict[str, str], - venv_dir: str, - source_path: str, - sink_path: str, - pipeline_options: Dict[str, Any]): - - self.transforms = transforms - self.decoded_code = decoded_code - self.venv_dir = venv_dir - self.pipeline_data = PipelineDataHandler(source_path, sink_path) - self.pipeline_options = PipelineOptions.from_dictionary(pipeline_options) - - def load_transform_class(self, module_name: str, class_name: str, source_code: str): - """Dynamically load a class from a decoded source code string.""" - # Create a temporary module name - temp_module_name = f"temp_{module_name.replace('.', '_')}" - spec = importlib.util.spec_from_loader(temp_module_name, loader=None) - module = importlib.util.module_from_spec(spec) - - # Execute the source code in the module's dictionary - exec(source_code, module.__dict__) - - # Register the module in sys.modules - sys.modules[temp_module_name] = module - - # Access the class from the module - transform_class = getattr(module, class_name) - - # Check if it's a valid Beam PTransform - if not issubclass(transform_class, beam.PTransform): - raise TypeError(f"{class_name} is not a subclass of beam.PTransform") - - return transform_class - - def build_pipeline(self) -> None: - logger.info("Building the Apache Beam pipeline.") - - # Building the pipeline - with beam.Pipeline(options=self.pipeline_options) as pipeline: - - # Read data from the source - input_data = self.pipeline_data.read_source(pipeline) - - # Apply each transform in transforms list - for transform in self.transforms: - try: - transform_name, transform_info = list(transform.items())[0] - source_file, class_name = transform_info.split(":") - source_code = self.decoded_code.get(source_file) - - if not 
source_code: - logger.error(f"Source code for {source_file} not found.") - continue - - # Dynamically load the transform class - transform_class = self.load_transform_class(source_file, class_name, source_code) - logger.info(f"Successfully loaded transform: {transform_name} ({class_name})") - - # Apply the transform - logger.info(f"Applying {transform_name} to the pipeline.") - input_data = input_data | f"Apply_{transform_name}" >> transform_class() - - except Exception as e: - logger.error(f"Error applying transform {transform_name}: {e}") - - # Write the transformed data to the sink - self.pipeline_data.write_sink(input_data) +logger = LogHandler.logger -class BeamstackSDK: +class BeamstackTransforms: """Main SDK class to orchestrate the process.""" def __init__(self, yaml_file: str, source_path: str, sink_path: str, pipeline_options: Dict[str, Any]): @@ -216,7 +62,7 @@ def run(self) -> None: source_path = 'input_data.txt' sink_path = 'output_data.txt' - transformsdk = BeamstackSDK( + transformsdk = BeamstackTransforms( yaml_file=yaml_path, source_path=source_path, sink_path=sink_path, diff --git a/src/transforms_garden/packages_handler.py b/src/transforms_garden/packages_handler.py new file mode 100644 index 0000000..eebc9f6 --- /dev/null +++ b/src/transforms_garden/packages_handler.py @@ -0,0 +1,33 @@ +import os +import subprocess +from typing import List +from utils import LogHandler + +logger = LogHandler.logger + +class VirtualEnvManager: + """Manage the virtual environment and package installations.""" + + def __init__(self, venv_dir: str): + self.venv_dir = venv_dir + + def create_venv(self) -> None: + logger.info(f"Creating virtual environment in {self.venv_dir}") + subprocess.run(['python', '-m', 'venv', self.venv_dir], check=True) + + def install_packages(self, packages: List[str]) -> None: + logger.info(f"Installing packages: {', '.join(packages)}") + + # Get the list of currently installed packages + installed_packages = subprocess.check_output( + [os.path.join(self.venv_dir, 'bin', 'pip'), 'list', '--format=freeze'] + ).decode('utf-8').splitlines() + + installed_packages_set = {pkg.split('==')[0] for pkg in installed_packages} + + for package in packages: + if package not in installed_packages_set: + logger.info(f"Installing package: {package}") + subprocess.run([os.path.join(self.venv_dir, 'bin', 'pip'), 'install', package], check=True) + else: + logger.info(f"Package '{package}' is already installed; skipping installation.") \ No newline at end of file diff --git a/src/transforms_garden/pipeline_handler.py b/src/transforms_garden/pipeline_handler.py new file mode 100644 index 0000000..6b98f6e --- /dev/null +++ b/src/transforms_garden/pipeline_handler.py @@ -0,0 +1,115 @@ +import sys +import base64 +from typing import Any, Dict +import importlib.util +import apache_beam as beam +from apache_beam.options.pipeline_options import PipelineOptions +from utils import LogHandler + +logger = LogHandler.logger + +class PipelineDataHandler: + """Handles reading from source and writing to sink in the Beam pipeline.""" + + def __init__(self, source_path: str, sink_path: str): + self.source_path = source_path + self.sink_path = sink_path + + def read_source(self, p: beam.Pipeline) -> beam.PCollection: + """Read data from the source path into a Beam PCollection.""" + logger.info(f"Reading data from source path: {self.source_path}") + return p | "ReadFromSource" >> beam.io.ReadFromText(self.source_path) + + def write_sink(self, pcollection: beam.PCollection) -> None: + """Write the 
resulting PCollection to the sink path.""" + logger.info(f"Writing data to sink path: {self.sink_path}") + pcollection | "WriteToSink" >> beam.io.WriteToText(self.sink_path) + + +class TransformManager: + """Manage decoding and handling of transform source codes.""" + + def __init__(self, transforms: Dict[str, str], source_code: Dict[str, str]): + self.transforms = transforms + self.source_code = source_code + self.decoded_code = {} + + def decode_source_code(self) -> None: + logger.info("Decoding source code from base64.") + for source, code in self.source_code.items(): + decoded = base64.b64decode(code).decode('utf-8') + self.decoded_code[source] = decoded + logger.debug(f"Decoded source for {source}: {decoded}...") + +class PipelineBuilder: + """Build and run the Apache Beam pipeline.""" + + def __init__(self, + transforms: Dict[str, str], + decoded_code: Dict[str, str], + venv_dir: str, + source_path: str, + sink_path: str, + pipeline_options: Dict[str, Any]): + + self.transforms = transforms + self.decoded_code = decoded_code + self.venv_dir = venv_dir + self.pipeline_data = PipelineDataHandler(source_path, sink_path) + self.pipeline_options = PipelineOptions.from_dictionary(pipeline_options) + + def load_transform_class(self, module_name: str, class_name: str, source_code: str): + """Dynamically load a class from a decoded source code string.""" + # Create a temporary module name + temp_module_name = f"temp_{module_name.replace('.', '_')}" + spec = importlib.util.spec_from_loader(temp_module_name, loader=None) + module = importlib.util.module_from_spec(spec) + + # Execute the source code in the module's dictionary + exec(source_code, module.__dict__) + + # Register the module in sys.modules + sys.modules[temp_module_name] = module + + # Access the class from the module + transform_class = getattr(module, class_name) + + # Check if it's a valid Beam PTransform + if not issubclass(transform_class, beam.PTransform): + raise TypeError(f"{class_name} is not a subclass of beam.PTransform") + + return transform_class + + def build_pipeline(self) -> None: + logger.info("Building the Apache Beam pipeline.") + + # Building the pipeline + with beam.Pipeline(options=self.pipeline_options) as pipeline: + + # Read data from the source + input_data = self.pipeline_data.read_source(pipeline) + + # Apply each transform in transforms list + for transform in self.transforms: + try: + transform_name, transform_info = list(transform.items())[0] + source_file, class_name = transform_info.split(":") + source_code = self.decoded_code.get(source_file) + + if not source_code: + logger.error(f"Source code for {source_file} not found.") + continue + + # Dynamically load the transform class + transform_class = self.load_transform_class(source_file, class_name, source_code) + logger.info(f"Successfully loaded transform: {transform_name} ({class_name})") + + # Apply the transform + logger.info(f"Applying {transform_name} to the pipeline.") + input_data = input_data | f"Apply_{transform_name}" >> transform_class() + + except Exception as e: + logger.error(f"Error applying transform {transform_name}: {e}") + + # Write the transformed data to the sink + self.pipeline_data.write_sink(input_data) \ No newline at end of file diff --git a/src/transforms_garden/utils.py b/src/transforms_garden/utils.py new file mode 100644 index 0000000..e69de29 diff --git a/src/transforms_garden/yaml_handler.py b/src/transforms_garden/yaml_handler.py new file mode 100644 index 0000000..9fdd402 --- /dev/null +++ 
b/src/transforms_garden/yaml_handler.py @@ -0,0 +1,20 @@ +from typing import Any, Dict +from utils import LogHandler + +logger = LogHandler.logger + +class YAMLValidator: + """Validate the YAML file structure.""" + + @staticmethod + def validate(yaml_data: Dict[str, Any]) -> None: + required_fields = ['metadata', 'source', 'dependencies', 'transforms', 'source_code', 'encoding'] + + for field in required_fields: + if field not in yaml_data: + logger.error(f"Missing required field: {field}") + raise ValueError(f"Missing required field: {field}") + + if yaml_data['encoding'] != 'base64': + logger.error("Invalid encoding type; must be 'base64'.") + raise ValueError("Invalid encoding type; must be 'base64'.") \ No newline at end of file From d81aace53a6f63b7841c9aec4eb89d4f5b2b9b73 Mon Sep 17 00:00:00 2001 From: kennedyuche Date: Fri, 1 Nov 2024 20:30:34 -0500 Subject: [PATCH 03/10] minor fix --- src/transforms_garden/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/transforms_garden/utils.py b/src/transforms_garden/utils.py index e69de29..e3f5f12 100644 --- a/src/transforms_garden/utils.py +++ b/src/transforms_garden/utils.py @@ -0,0 +1,8 @@ +import logging + +class LogHandler: + def logger(self): + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + return logger \ No newline at end of file From cae7516f1b76633581a2566bc41a4fb368f0ec82 Mon Sep 17 00:00:00 2001 From: kennedyuche Date: Tue, 5 Nov 2024 10:13:59 -0600 Subject: [PATCH 04/10] provider logic update wip --- src/transforms_garden/beamstack_transforms.py | 153 ++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 src/transforms_garden/beamstack_transforms.py diff --git a/src/transforms_garden/beamstack_transforms.py b/src/transforms_garden/beamstack_transforms.py new file mode 100644 index 0000000..4cc72a4 --- /dev/null +++ b/src/transforms_garden/beamstack_transforms.py @@ -0,0 +1,153 @@ +import os +import json +import hashlib +import base64 +import subprocess +import sys +import yaml +import apache_beam as beam +from apache_beam.transforms import external +from typing import Any, Iterable, Mapping, Optional, Callable +from apache_beam.yaml.yaml_provider import ExternalProvider +from apache_beam.utils import subprocess_server +import logging +import importlib.util + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +@ExternalProvider.register_provider_type('BeamstackTransform') +def BeamstackTransform(urns, path): + with open(path, 'r') as f: + transform_yaml = yaml.safe_load(f) + + config = { + 'urns': urns, + 'yaml_path': path, + 'dependencies': transform_yaml.get('dependencies', []), + 'runner': transform_yaml.get('runner', 'DirectRunner') + } + + return BeamstackTransformProvider(urns, config) + + +class BeamstackTransformProvider(ExternalProvider): + def __init__(self, urns, config): + super().__init__(urns, BeamstackExpansionService(config)) + self.config = config + self.transforms = config['urns'] + + def available(self) -> bool: + return True + + def cache_artifacts(self) -> Optional[Iterable[str]]: + return [self._service._venv()] + + def create_transform(self, typ: str, args: Mapping[str, Any], yaml_create_transform: Callable) -> beam.PTransform: + """Create a PTransform based on decoded source code and configurations.""" + if callable(self._service): + self._service = self._service() + + if typ in self.transforms: + transform_class = self._load_transform_class(typ) + if callable(transform_class): + processed_args = 
yaml_create_transform(args) + return transform_class(**processed_args) + else: + logger.error(f"{typ} is not a callable transform class.") + else: + logger.error(f"Transform type {typ} is not recognized.") + return None + + def _load_transform_class(self, transform_name): + """Dynamically loads and returns a transform class by name.""" + module_name, class_name = transform_name.split(":") + try: + spec = importlib.util.spec_from_file_location(module_name, os.path.join(self._service._venv_path(), f"{module_name}.py")) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + transform_class = getattr(module, class_name) + return transform_class + except Exception as e: + logger.error(f"Failed to load transform {transform_name}: {e}") + raise e + + @classmethod + def provider_from_spec(cls, spec): + urns = spec['transforms'] + config = spec['config'] + return cls(urns, config) + + +class BeamstackExpansionService: + VENV_CACHE = os.path.expanduser("~/.apache_beam/cache/beamstack_venvs") + + def __init__(self, config): + self.config = config + self.runner = config.get('runner') + self.yaml_path = config.get('yaml_path') + self.base_python = sys.executable + self._packages = config.get('dependencies', []) + self._service = None + + self._load_yaml() + + def _load_yaml(self): + """Loads and decodes the transforms.yaml file.""" + with open(self.yaml_path, 'r') as file: + data = yaml.safe_load(file) + self._packages = data.get('dependencies', []) + self.transforms = data['transforms'] + self.source_code = data['source_code'] + self.encoding = data['encoding'] + + for src_name, encoded_code in self.source_code.items(): + decoded_code = base64.b64decode(encoded_code).decode('utf-8') + self._write_source_file(src_name, decoded_code) + + def _write_source_file(self, src_name, code): + """Writes decoded code to file for each source.""" + file_path = os.path.join(self._venv_path(), f"{src_name}.py") + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, 'w') as f: + f.write(code) + + def _venv_path(self): + """Returns the path for the virtual environment directory based on the packages and runner.""" + key = json.dumps({'binary': self.base_python, 'packages': sorted(self._packages), 'runner': self.runner}) + venv_hash = hashlib.sha256(key.encode('utf-8')).hexdigest() + return os.path.join(self.VENV_CACHE, venv_hash) + + def _venv(self): + """Creates and returns the virtual environment path if not exists.""" + venv = self._venv_path() + if not os.path.exists(venv): + subprocess.run([self.base_python, '-m', 'venv', venv], check=True) + venv_pip = os.path.join(venv, 'bin', 'pip') + subprocess.run([venv_pip, 'install'] + self._packages, check=True) + return venv + + def __enter__(self): + venv = self._venv() + self._service_provider = subprocess_server.SubprocessServer( + external.ExpansionAndArtifactRetrievalStub, + [ + os.path.join(venv, 'bin', 'python'), + '-m', + 'apache_beam.runners.portability.expansion_service_main', + '--port', + '{{PORT}}', + '--fully_qualified_name_glob=*', + '--pickle_library=cloudpickle', + ] + ) + self._service = self._service_provider.__enter__() + return self._service + + def __exit__(self, *args): + self._service_provider.__exit__(*args) + self._service = None + + if os.path.exists(self._venv_path()): + subprocess.run(['rm', '-rf', self._venv_path()]) + logger.info("Cleaned up virtual environment after pipeline run.") \ No newline at end of file From ff493c42e653edc13df26c84884d621ab8d07ef9 Mon Sep 17 00:00:00 2001 From: 
kennedyuche Date: Fri, 8 Nov 2024 14:10:29 -0600 Subject: [PATCH 05/10] updated the transforms script --- src/transforms_garden/beamstack_transforms.py | 54 +++++++++++++++---- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/src/transforms_garden/beamstack_transforms.py b/src/transforms_garden/beamstack_transforms.py index 4cc72a4..1291416 100644 --- a/src/transforms_garden/beamstack_transforms.py +++ b/src/transforms_garden/beamstack_transforms.py @@ -25,7 +25,7 @@ def BeamstackTransform(urns, path): 'urns': urns, 'yaml_path': path, 'dependencies': transform_yaml.get('dependencies', []), - 'runner': transform_yaml.get('runner', 'DirectRunner') + 'runner': transform_yaml.get('runner', []) } return BeamstackTransformProvider(urns, config) @@ -35,7 +35,9 @@ class BeamstackTransformProvider(ExternalProvider): def __init__(self, urns, config): super().__init__(urns, BeamstackExpansionService(config)) self.config = config - self.transforms = config['urns'] + self.transforms = config.get('urns', {}) + + logger.info(f"Transforms: {self.transforms}") def available(self) -> bool: return True @@ -43,35 +45,65 @@ def available(self) -> bool: def cache_artifacts(self) -> Optional[Iterable[str]]: return [self._service._venv()] - def create_transform(self, typ: str, args: Mapping[str, Any], yaml_create_transform: Callable) -> beam.PTransform: + def create_transform(self, + typ: str, + args: Mapping[str, Any], + yaml_create_transform: Callable[[Mapping[str, Any], Iterable[beam.PCollection]], beam.PTransform]) -> Optional[beam.PTransform]: """Create a PTransform based on decoded source code and configurations.""" if callable(self._service): self._service = self._service() + logger.info(f"Creating transform of type: {typ} with args: {args}") + if typ in self.transforms: transform_class = self._load_transform_class(typ) if callable(transform_class): - processed_args = yaml_create_transform(args) - return transform_class(**processed_args) + config_args = args.get('config', {}) + try: + return transform_class(**config_args) + except TypeError as e: + logger.error(f"Error initializing transform '{typ}': {e}") + raise else: logger.error(f"{typ} is not a callable transform class.") else: - logger.error(f"Transform type {typ} is not recognized.") + logger.error(f"Transform type '{typ}' is not recognized in BeamstackTransform.") return None + + def _module_class_map(self) -> dict: + """Transform module and class dictionary map""" + self.yaml_path = self.config.get('yaml_path') + + with open(self.yaml_path, 'r') as file: + data = yaml.safe_load(file) + self.transforms = data['transforms'] + + transform_map = {} + for item in self.transforms: + for _, value in item.items(): + file, transform_class = value.split(':') + transform_map[transform_class] = file + + return transform_map + def _load_transform_class(self, transform_name): """Dynamically loads and returns a transform class by name.""" - module_name, class_name = transform_name.split(":") + transform_map = self._module_class_map() + try: - spec = importlib.util.spec_from_file_location(module_name, os.path.join(self._service._venv_path(), f"{module_name}.py")) + logger.info(f"Loading transform class for: {transform_name}") + spec = importlib.util.spec_from_file_location(transform_map[transform_name], os.path.join(self._service._venv_path(), transform_map[transform_name])) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - transform_class = getattr(module, class_name) + transform_class = getattr(module, 
transform_name) + logger.info(f"Loaded transform class: {transform_class}") return transform_class except Exception as e: logger.error(f"Failed to load transform {transform_name}: {e}") raise e + @classmethod def provider_from_spec(cls, spec): urns = spec['transforms'] @@ -104,10 +136,11 @@ def _load_yaml(self): for src_name, encoded_code in self.source_code.items(): decoded_code = base64.b64decode(encoded_code).decode('utf-8') self._write_source_file(src_name, decoded_code) + self._source_module = src_name def _write_source_file(self, src_name, code): """Writes decoded code to file for each source.""" - file_path = os.path.join(self._venv_path(), f"{src_name}.py") + file_path = os.path.join(self._venv_path(), src_name) os.makedirs(os.path.dirname(file_path), exist_ok=True) with open(file_path, 'w') as f: f.write(code) @@ -121,6 +154,7 @@ def _venv_path(self): def _venv(self): """Creates and returns the virtual environment path if not exists.""" venv = self._venv_path() + print(f"Virtual Environment: {venv}") if not os.path.exists(venv): subprocess.run([self.base_python, '-m', 'venv', venv], check=True) venv_pip = os.path.join(venv, 'bin', 'pip') From c23fd94a6a2169133fd70e6dd03bf9bd678a82af Mon Sep 17 00:00:00 2001 From: kennedyuche Date: Fri, 8 Nov 2024 15:45:43 -0600 Subject: [PATCH 06/10] fixed multiple class call issue --- src/transforms_garden/beamstack_transforms.py | 110 ++++++++++-------- 1 file changed, 59 insertions(+), 51 deletions(-) diff --git a/src/transforms_garden/beamstack_transforms.py b/src/transforms_garden/beamstack_transforms.py index 1291416..5a4ef55 100644 --- a/src/transforms_garden/beamstack_transforms.py +++ b/src/transforms_garden/beamstack_transforms.py @@ -55,20 +55,17 @@ def create_transform(self, logger.info(f"Creating transform of type: {typ} with args: {args}") - if typ in self.transforms: - transform_class = self._load_transform_class(typ) - if callable(transform_class): - config_args = args.get('config', {}) - try: - return transform_class(**config_args) - except TypeError as e: - logger.error(f"Error initializing transform '{typ}': {e}") - raise - else: - logger.error(f"{typ} is not a callable transform class.") + transform_class = self._load_transform_class(typ) + + if callable(transform_class): + config_args = args.get('config', {}) + try: + return transform_class(**config_args) + except TypeError as e: + logger.error(f"Error initializing transform '{typ}': {e}") + raise else: - logger.error(f"Transform type '{typ}' is not recognized in BeamstackTransform.") - return None + logger.error(f"{typ} is not a callable transform class.") def _module_class_map(self) -> dict: @@ -82,8 +79,8 @@ def _module_class_map(self) -> dict: transform_map = {} for item in self.transforms: for _, value in item.items(): - file, transform_class = value.split(':') - transform_map[transform_class] = file + module_name, transform_class = value.split(':') + transform_map[transform_class] = module_name return transform_map @@ -93,7 +90,11 @@ def _load_transform_class(self, transform_name): try: logger.info(f"Loading transform class for: {transform_name}") - spec = importlib.util.spec_from_file_location(transform_map[transform_name], os.path.join(self._service._venv_path(), transform_map[transform_name])) + + spec = importlib.util.spec_from_file_location(f"{transform_map[transform_name]}.py", + os.path.join(self._service._venv_path(), + f"{transform_map[transform_name]}.py")) + module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) transform_class = 
getattr(module, transform_name) @@ -103,7 +104,6 @@ def _load_transform_class(self, transform_name): logger.error(f"Failed to load transform {transform_name}: {e}") raise e - @classmethod def provider_from_spec(cls, spec): urns = spec['transforms'] @@ -129,14 +129,13 @@ def _load_yaml(self): with open(self.yaml_path, 'r') as file: data = yaml.safe_load(file) self._packages = data.get('dependencies', []) - self.transforms = data['transforms'] self.source_code = data['source_code'] self.encoding = data['encoding'] - for src_name, encoded_code in self.source_code.items(): + for module_name, encoded_code in self.source_code.items(): decoded_code = base64.b64decode(encoded_code).decode('utf-8') - self._write_source_file(src_name, decoded_code) - self._source_module = src_name + self._write_source_file(f"{module_name}.py", decoded_code) + self._source_module = f"{module_name}.py" def _write_source_file(self, src_name, code): """Writes decoded code to file for each source.""" @@ -149,39 +148,48 @@ def _venv_path(self): """Returns the path for the virtual environment directory based on the packages and runner.""" key = json.dumps({'binary': self.base_python, 'packages': sorted(self._packages), 'runner': self.runner}) venv_hash = hashlib.sha256(key.encode('utf-8')).hexdigest() - return os.path.join(self.VENV_CACHE, venv_hash) - - def _venv(self): - """Creates and returns the virtual environment path if not exists.""" - venv = self._venv_path() - print(f"Virtual Environment: {venv}") + venv = os.path.join(self.VENV_CACHE, venv_hash) + if not os.path.exists(venv): + installed_packages = subprocess.check_output( + [os.path.join(venv, 'bin', 'pip'), 'list', '--format=freeze'] + ).decode('utf-8').splitlines() + + installed_packages_set = {pkg.split('==')[0] for pkg in installed_packages} + subprocess.run([self.base_python, '-m', 'venv', venv], check=True) venv_pip = os.path.join(venv, 'bin', 'pip') - subprocess.run([venv_pip, 'install'] + self._packages, check=True) + + for package in self._packages: + if package not in installed_packages_set: + logger.info(f"Installing package: {package}") + subprocess.run([venv_pip, 'install'] + self._packages, check=True) + else: + logger.info(f"Package '{package}' is already installed; skipping installation.") + return venv - def __enter__(self): - venv = self._venv() - self._service_provider = subprocess_server.SubprocessServer( - external.ExpansionAndArtifactRetrievalStub, - [ - os.path.join(venv, 'bin', 'python'), - '-m', - 'apache_beam.runners.portability.expansion_service_main', - '--port', - '{{PORT}}', - '--fully_qualified_name_glob=*', - '--pickle_library=cloudpickle', - ] - ) - self._service = self._service_provider.__enter__() - return self._service - - def __exit__(self, *args): - self._service_provider.__exit__(*args) - self._service = None + # def __enter__(self): + # venv = self._venv_path + # self._service_provider = subprocess_server.SubprocessServer( + # external.ExpansionAndArtifactRetrievalStub, + # [ + # os.path.join(venv, 'bin', 'python3'), + # '-m', + # 'apache_beam.runners.portability.expansion_service_main', + # '--port', + # '{{PORT}}', + # '--fully_qualified_name_glob=*', + # '--pickle_library=cloudpickle', + # ] + # ) + # self._service = self._service_provider.__enter__() + # return self._service + + # def __exit__(self, *args): + # self._service_provider.__exit__(*args) + # self._service = None - if os.path.exists(self._venv_path()): - subprocess.run(['rm', '-rf', self._venv_path()]) - logger.info("Cleaned up virtual environment after 
pipeline run.") \ No newline at end of file + # if os.path.exists(self._venv_path()): + # subprocess.run(['rm', '-rf', self._venv_path()]) + # logger.info("Cleaned up virtual environment after pipeline run.") \ No newline at end of file From d5c684970776fd35b54e1ed68be7509382add2b7 Mon Sep 17 00:00:00 2001 From: kennedyuche Date: Wed, 27 Nov 2024 09:14:38 -0600 Subject: [PATCH 07/10] updated beamstack provider impl --- ...ck_transforms.py => beamstack_provider.py} | 158 ++++++++++++------ src/transforms_garden/main.py | 72 -------- src/transforms_garden/packages_handler.py | 33 ---- src/transforms_garden/pipeline_handler.py | 115 ------------- src/transforms_garden/utils.py | 8 - src/transforms_garden/yaml_handler.py | 20 --- 6 files changed, 109 insertions(+), 297 deletions(-) rename src/{transforms_garden/beamstack_transforms.py => beamstack_provider.py} (55%) delete mode 100644 src/transforms_garden/main.py delete mode 100644 src/transforms_garden/packages_handler.py delete mode 100644 src/transforms_garden/pipeline_handler.py delete mode 100644 src/transforms_garden/utils.py delete mode 100644 src/transforms_garden/yaml_handler.py diff --git a/src/transforms_garden/beamstack_transforms.py b/src/beamstack_provider.py similarity index 55% rename from src/transforms_garden/beamstack_transforms.py rename to src/beamstack_provider.py index 5a4ef55..5b9dcc1 100644 --- a/src/transforms_garden/beamstack_transforms.py +++ b/src/beamstack_provider.py @@ -6,31 +6,108 @@ import sys import yaml import apache_beam as beam -from apache_beam.transforms import external from typing import Any, Iterable, Mapping, Optional, Callable from apache_beam.yaml.yaml_provider import ExternalProvider -from apache_beam.utils import subprocess_server import logging import importlib.util +import urllib.request +from urllib.parse import urlparse logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +class BeamstackProviderPathHandler: + def is_local_file(self, path: str) -> bool: + """Check if the path is a local file.""" + return os.path.exists(path) + + def is_github_file(self, path: str) -> bool: + """Check if the path is a GitHub file URL.""" + parsed_url = urlparse(path) + + if parsed_url.netloc == "raw.githubusercontent.com": + return True + if parsed_url.netloc == "github.com": + return True + + return False + + def is_gcs_file(self, path: str) -> bool: + """Check if the path is a Google Cloud Storage URL.""" + return path.startswith('gs://') + + def handle_local_file(self, path: str): + """Handle local file or directory path.""" + logger.info(f"Pulling transforms yaml from: {path}") + if not os.path.exists(path): + raise FileNotFoundError(f"The local path '{path}' does not exist.") + return path + + def handle_github_file(self, file_url: str, target_dir: str): + """Download file from a public GitHub repository to the target path.""" + logger.info(f"Pulling transforms yaml from GitHub url: {file_url}") + + if "github.com" in file_url and "/blob/" in file_url: + file_url = file_url.replace("github.com", "raw.githubusercontent.com").replace("blob/", "") + + os.makedirs(target_dir, exist_ok=True) + + file_name = os.path.basename(file_url) + local_file_path = os.path.join(target_dir, file_name) + + try: + logger.info(f"Downloading {file_name} to {local_file_path}") + urllib.request.urlretrieve(file_url, local_file_path) + except Exception as e: + logger.info(f"Error occured during file download: {e}") + + return local_file_path + + def handle_gcs_file(self, gcs_path: str, target_dir: str): + 
"""Download files from a public GCS bucket to a target path.""" + logger.info(f"Pulling transforms yaml from GCS path: {gcs_path}") + + gcs_path = gcs_path[len("gs://"):] + bucket_name, _, object_name = gcs_path.partition('/') + public_url = f"https://storage.googleapis.com/{bucket_name}/{object_name}" + + os.makedirs(target_dir, exist_ok=True) + local_file_path = os.path.join(target_dir, os.path.basename(object_name)) + + try: + logger.info(f"Downloading {os.path.basename(object_name)} to {target_dir}") + urllib.request.urlretrieve(public_url, local_file_path) + except Exception as e: + logger.info(f"Error downloading file from {public_url}: {e}") + + return local_file_path + @ExternalProvider.register_provider_type('BeamstackTransform') def BeamstackTransform(urns, path): - with open(path, 'r') as f: + target_dir = '/tmp/beamstack_transforms' + + path_handler = BeamstackProviderPathHandler() + + if path_handler.is_local_file(path): + transform_yaml_path = path_handler.handle_local_file(path) + elif path_handler.is_github_file(path): + transform_yaml_path = path_handler.handle_github_file(path, target_dir) + elif path_handler.is_gcs_file(path): + transform_yaml_path = path_handler.handle_gcs_file(path, target_dir) + else: + raise ValueError(f"Unsupported path type: {path}") + + with open(transform_yaml_path, 'r') as f: transform_yaml = yaml.safe_load(f) config = { 'urns': urns, - 'yaml_path': path, - 'dependencies': transform_yaml.get('dependencies', []), - 'runner': transform_yaml.get('runner', []) + 'yaml_path': transform_yaml_path, + 'dependencies': transform_yaml.get('dependencies', []) } return BeamstackTransformProvider(urns, config) - class BeamstackTransformProvider(ExternalProvider): def __init__(self, urns, config): super().__init__(urns, BeamstackExpansionService(config)) @@ -91,11 +168,16 @@ def _load_transform_class(self, transform_name): try: logger.info(f"Loading transform class for: {transform_name}") - spec = importlib.util.spec_from_file_location(f"{transform_map[transform_name]}.py", - os.path.join(self._service._venv_path(), - f"{transform_map[transform_name]}.py")) - + spec = importlib.util.spec_from_file_location( + f"{transform_map[transform_name]}.py", + os.path.join(self._service._venv_path(), f"{transform_map[transform_name]}.py") + ) + if spec is None: + logger.error(f"Specification for module '{transform_map[transform_name]}' could not be found.") + return None + module = importlib.util.module_from_spec(spec) + sys.path.insert(0, os.path.dirname(spec.origin)) spec.loader.exec_module(module) transform_class = getattr(module, transform_name) logger.info(f"Loaded transform class: {transform_class}") @@ -104,12 +186,6 @@ def _load_transform_class(self, transform_name): logger.error(f"Failed to load transform {transform_name}: {e}") raise e - @classmethod - def provider_from_spec(cls, spec): - urns = spec['transforms'] - config = spec['config'] - return cls(urns, config) - class BeamstackExpansionService: VENV_CACHE = os.path.expanduser("~/.apache_beam/cache/beamstack_venvs") @@ -139,7 +215,8 @@ def _load_yaml(self): def _write_source_file(self, src_name, code): """Writes decoded code to file for each source.""" - file_path = os.path.join(self._venv_path(), src_name) + venv = self._venv_path() + file_path = os.path.join(venv, src_name) os.makedirs(os.path.dirname(file_path), exist_ok=True) with open(file_path, 'w') as f: f.write(code) @@ -151,45 +228,28 @@ def _venv_path(self): venv = os.path.join(self.VENV_CACHE, venv_hash) if not os.path.exists(venv): + 
subprocess.run([self.base_python, '-m', 'venv', venv], check=True) + + site_packages_path = os.path.join(venv, 'lib', f'python{sys.version_info.major}.{sys.version_info.minor}', 'site-packages') + if site_packages_path not in sys.path: + sys.path.insert(0, site_packages_path) + + venv_pip = os.path.join(venv, 'bin', 'pip') + + if os.path.exists(venv_pip): installed_packages = subprocess.check_output( - [os.path.join(venv, 'bin', 'pip'), 'list', '--format=freeze'] + [venv_pip, 'list', '--format=freeze'] ).decode('utf-8').splitlines() installed_packages_set = {pkg.split('==')[0] for pkg in installed_packages} - subprocess.run([self.base_python, '-m', 'venv', venv], check=True) - venv_pip = os.path.join(venv, 'bin', 'pip') - for package in self._packages: if package not in installed_packages_set: logger.info(f"Installing package: {package}") - subprocess.run([venv_pip, 'install'] + self._packages, check=True) + subprocess.run([venv_pip, 'install', package], check=True) else: logger.info(f"Package '{package}' is already installed; skipping installation.") + else: + raise FileNotFoundError(f"Could not find pip at expected location: {venv_pip}") - return venv - - # def __enter__(self): - # venv = self._venv_path - # self._service_provider = subprocess_server.SubprocessServer( - # external.ExpansionAndArtifactRetrievalStub, - # [ - # os.path.join(venv, 'bin', 'python3'), - # '-m', - # 'apache_beam.runners.portability.expansion_service_main', - # '--port', - # '{{PORT}}', - # '--fully_qualified_name_glob=*', - # '--pickle_library=cloudpickle', - # ] - # ) - # self._service = self._service_provider.__enter__() - # return self._service - - # def __exit__(self, *args): - # self._service_provider.__exit__(*args) - # self._service = None - - # if os.path.exists(self._venv_path()): - # subprocess.run(['rm', '-rf', self._venv_path()]) - # logger.info("Cleaned up virtual environment after pipeline run.") \ No newline at end of file + return venv \ No newline at end of file diff --git a/src/transforms_garden/main.py b/src/transforms_garden/main.py deleted file mode 100644 index 4429739..0000000 --- a/src/transforms_garden/main.py +++ /dev/null @@ -1,72 +0,0 @@ -import yaml -import logging -from typing import Any, Dict -from yaml_handler import YAMLValidator -from packages_handler import VirtualEnvManager -from pipeline_handler import TransformManager, PipelineBuilder -from utils import LogHandler - - -logger = LogHandler.logger - - -class BeamstackTransforms: - """Main SDK class to orchestrate the process.""" - - def __init__(self, yaml_file: str, source_path: str, sink_path: str, pipeline_options: Dict[str, Any]): - self.source_path = source_path - self.sink_path = sink_path - self.pipeline_options = pipeline_options - self.yaml_file = yaml_file - self.yaml_data = self.load_yaml() - self.venv_dir = 'venv' ### virtual env path - - def load_yaml(self) -> Dict[str, Any]: - with open(self.yaml_file, 'r') as file: - return yaml.safe_load(file) - - def run(self) -> None: - # Validate YAML - YAMLValidator.validate(self.yaml_data) - - # Create virtual environment and install dependencies - venv_manager = VirtualEnvManager(self.venv_dir) - venv_manager.create_venv() - venv_manager.install_packages(self.yaml_data['dependencies']) - - # Decode source code - transform_manager = TransformManager( - transforms=self.yaml_data['transforms'], - source_code=self.yaml_data['source_code'] - ) - transform_manager.decode_source_code() - - # Build and run the pipeline - pipeline_builder = PipelineBuilder( - 
transforms=self.yaml_data['transforms'], - decoded_code=transform_manager.decoded_code, - venv_dir=self.venv_dir, - source_path=self.source_path, - sink_path=self.sink_path, - pipeline_options=self.pipeline_options - ) - pipeline_builder.build_pipeline() - - -if __name__ == "__main__": - pipeline_options = { - 'runner': 'DirectRunner' - } - - yaml_path = 'pipeline.yaml' - source_path = 'input_data.txt' - sink_path = 'output_data.txt' - - transformsdk = BeamstackTransforms( - yaml_file=yaml_path, - source_path=source_path, - sink_path=sink_path, - pipeline_options=pipeline_options - ) - - transformsdk.run() \ No newline at end of file diff --git a/src/transforms_garden/packages_handler.py b/src/transforms_garden/packages_handler.py deleted file mode 100644 index eebc9f6..0000000 --- a/src/transforms_garden/packages_handler.py +++ /dev/null @@ -1,33 +0,0 @@ -import os -import subprocess -from typing import List -from utils import LogHandler - -logger = LogHandler.logger - -class VirtualEnvManager: - """Manage the virtual environment and package installations.""" - - def __init__(self, venv_dir: str): - self.venv_dir = venv_dir - - def create_venv(self) -> None: - logger.info(f"Creating virtual environment in {self.venv_dir}") - subprocess.run(['python', '-m', 'venv', self.venv_dir], check=True) - - def install_packages(self, packages: List[str]) -> None: - logger.info(f"Installing packages: {', '.join(packages)}") - - # Get the list of currently installed packages - installed_packages = subprocess.check_output( - [os.path.join(self.venv_dir, 'bin', 'pip'), 'list', '--format=freeze'] - ).decode('utf-8').splitlines() - - installed_packages_set = {pkg.split('==')[0] for pkg in installed_packages} - - for package in packages: - if package not in installed_packages_set: - logger.info(f"Installing package: {package}") - subprocess.run([os.path.join(self.venv_dir, 'bin', 'pip'), 'install', package], check=True) - else: - logger.info(f"Package '{package}' is already installed; skipping installation.") \ No newline at end of file diff --git a/src/transforms_garden/pipeline_handler.py b/src/transforms_garden/pipeline_handler.py deleted file mode 100644 index 6b98f6e..0000000 --- a/src/transforms_garden/pipeline_handler.py +++ /dev/null @@ -1,115 +0,0 @@ -import sys -import base64 -from typing import Any, Dict -import importlib.util -import apache_beam as beam -from apache_beam.options.pipeline_options import PipelineOptions -from utils import LogHandler - -logger = LogHandler.logger - -class PipelineDataHandler: - """Handles reading from source and writing to sink in the Beam pipeline.""" - - def __init__(self, source_path: str, sink_path: str): - self.source_path = source_path - self.sink_path = sink_path - - def read_source(self, p: beam.Pipeline) -> beam.PCollection: - """Read data from the source path into a Beam PCollection.""" - logger.info(f"Reading data from source path: {self.source_path}") - return p | "ReadFromSource" >> beam.io.ReadFromText(self.source_path) - - def write_sink(self, pcollection: beam.PCollection) -> None: - """Write the resulting PCollection to the sink path.""" - logger.info(f"Writing data to sink path: {self.sink_path}") - pcollection | "WriteToSink" >> beam.io.WriteToText(self.sink_path) - - -class TransformManager: - """Manage decoding and handling of transform source codes.""" - - def __init__(self, transforms: Dict[str, str], source_code: Dict[str, str]): - self.transforms = transforms - self.source_code = source_code - self.decoded_code = {} - - def 
decode_source_code(self) -> None: - logger.info("Decoding source code from base64.") - for source, code in self.source_code.items(): - decoded = base64.b64decode(code).decode('utf-8') - self.decoded_code[source] = decoded - logger.debug(f"Decoded source for {source}: {decoded}...") - -class PipelineBuilder: - """Build and run the Apache Beam pipeline.""" - - def __init__(self, - transforms: Dict[str, str], - decoded_code: Dict[str, str], - venv_dir: str, - source_path: str, - sink_path: str, - pipeline_options: Dict[str, Any]): - - self.transforms = transforms - self.decoded_code = decoded_code - self.venv_dir = venv_dir - self.pipeline_data = PipelineDataHandler(source_path, sink_path) - self.pipeline_options = PipelineOptions.from_dictionary(pipeline_options) - - def load_transform_class(self, module_name: str, class_name: str, source_code: str): - """Dynamically load a class from a decoded source code string.""" - # Create a temporary module name - temp_module_name = f"temp_{module_name.replace('.', '_')}" - spec = importlib.util.spec_from_loader(temp_module_name, loader=None) - module = importlib.util.module_from_spec(spec) - - # Execute the source code in the module's dictionary - exec(source_code, module.__dict__) - - # Register the module in sys.modules - sys.modules[temp_module_name] = module - - # Access the class from the module - transform_class = getattr(module, class_name) - - # Check if it's a valid Beam PTransform - if not issubclass(transform_class, beam.PTransform): - raise TypeError(f"{class_name} is not a subclass of beam.PTransform") - - return transform_class - - def build_pipeline(self) -> None: - logger.info("Building the Apache Beam pipeline.") - - # Building the pipeline - with beam.Pipeline(options=self.pipeline_options) as pipeline: - - # Read data from the source - input_data = self.pipeline_data.read_source(pipeline) - - # Apply each transform in transforms list - for transform in self.transforms: - try: - transform_name, transform_info = list(transform.items())[0] - source_file, class_name = transform_info.split(":") - source_code = self.decoded_code.get(source_file) - - if not source_code: - logger.error(f"Source code for {source_file} not found.") - continue - - # Dynamically load the transform class - transform_class = self.load_transform_class(source_file, class_name, source_code) - logger.info(f"Successfully loaded transform: {transform_name} ({class_name})") - - # Apply the transform - logger.info(f"Applying {transform_name} to the pipeline.") - input_data = input_data | f"Apply_{transform_name}" >> transform_class() - - except Exception as e: - logger.error(f"Error applying transform {transform_name}: {e}") - - # Write the transformed data to the sink - self.pipeline_data.write_sink(input_data) \ No newline at end of file diff --git a/src/transforms_garden/utils.py b/src/transforms_garden/utils.py deleted file mode 100644 index e3f5f12..0000000 --- a/src/transforms_garden/utils.py +++ /dev/null @@ -1,8 +0,0 @@ -import logging - -class LogHandler: - def logger(self): - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - return logger \ No newline at end of file diff --git a/src/transforms_garden/yaml_handler.py b/src/transforms_garden/yaml_handler.py deleted file mode 100644 index 9fdd402..0000000 --- a/src/transforms_garden/yaml_handler.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Any, Dict -from utils import LogHandler - -logger = LogHandler.logger - -class YAMLValidator: - """Validate the YAML file structure.""" 
-
-    @staticmethod
-    def validate(yaml_data: Dict[str, Any]) -> None:
-        required_fields = ['metadata', 'source', 'dependencies', 'transforms', 'source_code', 'encoding']
-
-        for field in required_fields:
-            if field not in yaml_data:
-                logger.error(f"Missing required field: {field}")
-                raise ValueError(f"Missing required field: {field}")
-
-        if yaml_data['encoding'] != 'base64':
-            logger.error("Invalid encoding type; must be 'base64'.")
-            raise ValueError("Invalid encoding type; must be 'base64'.")
\ No newline at end of file

From 450e13f1d63f420933892377293de7c4cb6687c8 Mon Sep 17 00:00:00 2001
From: kennedyuche
Date: Tue, 10 Dec 2024 12:39:47 -0600
Subject: [PATCH 08/10] added some preprocessing transforms

---
 .../preprocessing/augment_text.py      | 79 +++++++++++++++++++
 .../preprocessing/clean_text.py        | 69 ++++++++++++++++
 .../preprocessing/detect_text_lang.py  | 41 ++++++++++
 .../preprocessing/normalize_tokens.py  | 57 +++++++++++++
 .../preprocessing/tokenize_text.py     | 41 ++++++++++
 5 files changed, 287 insertions(+)
 create mode 100644 src/beamstack_transforms/preprocessing/augment_text.py
 create mode 100644 src/beamstack_transforms/preprocessing/clean_text.py
 create mode 100644 src/beamstack_transforms/preprocessing/detect_text_lang.py
 create mode 100644 src/beamstack_transforms/preprocessing/normalize_tokens.py
 create mode 100644 src/beamstack_transforms/preprocessing/tokenize_text.py

diff --git a/src/beamstack_transforms/preprocessing/augment_text.py b/src/beamstack_transforms/preprocessing/augment_text.py
new file mode 100644
index 0000000..409ae8b
--- /dev/null
+++ b/src/beamstack_transforms/preprocessing/augment_text.py
@@ -0,0 +1,79 @@
+import apache_beam as beam
+import random
+from typing import List, Dict, Any
+from nltk.corpus import wordnet
+
+class TextAugmentation(beam.PTransform):
+    def __init__(self, techniques: List[str], augment_factor: int = 1):
+        """
+        Initializes transform class for augmenting text data using specified techniques.
+
+        :param techniques (List[str]): List of augmentation techniques to apply (e.g., 'synonym_replacement', 'back_translation').
+        :param augment_factor (int): Number of augmented examples to generate per input. Default is 1.
+        """
+        super().__init__()
+        self.techniques = techniques
+        self.augment_factor = augment_factor
+
+    def expand(self, pcoll):
+        return pcoll | "Augment Text" >> beam.ParDo(
+            self._TextAugmentationFn(self.techniques, self.augment_factor)
+        )
+
+    class _TextAugmentationFn(beam.DoFn):
+        def __init__(self, techniques: List[str], augment_factor: int):
+            """
+            A DoFn for applying text augmentation techniques.
+
+            :param techniques (List[str]): List of techniques for augmentation.
+            :param augment_factor (int): Number of augmented examples to generate per input.
+            """
+            self.techniques = techniques
+            self.augment_factor = augment_factor
+
+        def process(self, element: Dict[str, Any]):
+            """
+            Augments the input text using specified techniques.
+
+            Args:
+            :param element (Dict[str, Any]): Input dictionary containing the text to augment.
+            :param yield (Dict[str, Any]): Augmented examples.
+            """
+            text = element.get("text", "")
+            for _ in range(self.augment_factor):
+                augmented_text = self._apply_augmentation(text)
+                augmented_element = element.copy()
+                augmented_element["text"] = augmented_text
+                yield augmented_element
+
+        def _apply_augmentation(self, text: str) -> str:
+            """
+            Applies augmentation techniques to the input text.
+
+            :param text (str): Original text.
+
+            Returns:
+                str: Augmented text.
+            """
+            if "synonym_replacement" in self.techniques:
+                text = self._synonym_replacement(text)
+            # Additional techniques can be added here.
+            return text
+
+        def _synonym_replacement(self, text: str) -> str:
+            """
+            Replaces random words with synonyms.
+
+            Args:
+                text (str): Original text.
+
+            Returns:
+                str: Text with synonyms replaced.
+            """
+            words = text.split()
+            for i, word in enumerate(words):
+                synonyms = wordnet.synsets(word)
+                if synonyms:
+                    synonym = random.choice(synonyms).lemmas()[0].name()
+                    words[i] = synonym
+            return " ".join(words)
\ No newline at end of file
diff --git a/src/beamstack_transforms/preprocessing/clean_text.py b/src/beamstack_transforms/preprocessing/clean_text.py
new file mode 100644
index 0000000..4be02d9
--- /dev/null
+++ b/src/beamstack_transforms/preprocessing/clean_text.py
@@ -0,0 +1,69 @@
+import apache_beam as beam
+import re
+from typing import List
+from nltk.corpus import stopwords
+import nltk
+
+nltk.download('stopwords')
+
+class CleanText(beam.PTransform):
+    def __init__(self, stop_words: List[str] = None, additional_patterns: List[str] = None):
+        """
+        Initializes the Transform class for cleaning texts.
+
+        :param stop_words: List of custom stop words to add to the default list.
+        :param additional_patterns: List of regex patterns to remove from text.
+        """
+        super().__init__()
+        self.stop_words = stop_words
+        self.additional_patterns = additional_patterns
+
+    def expand(self, pcoll):
+        return (
+            pcoll
+            | "Remove Patterns" >> beam.ParDo(self._RemovePatternsFn(self.additional_patterns))
+            | "Remove Stop Words" >> beam.ParDo(self._RemoveStopWordsFn(self.stop_words))
+        )
+
+    class _RemovePatternsFn(beam.DoFn):
+        def __init__(self, additional_patterns: List[str] = None):
+            """
+            Initializes the class to remove regex patterns from text.
+
+            :param additional_patterns: List of regex patterns to remove from text.
+            """
+            self.additional_patterns = additional_patterns if additional_patterns else []
+
+        def process(self, element: str):
+            """
+            Removes specified regex patterns from the text.
+
+            :param element: Input text.
+            :yield: Text with patterns removed.
+            """
+            text = re.sub(r'[^a-zA-Z0-9\s]', '', element) # Remove non-alphanumeric characters.
+            for pattern in self.additional_patterns:
+                text = re.sub(pattern, '', text)
+            yield text
+
+    class _RemoveStopWordsFn(beam.DoFn):
+        def __init__(self, stop_words: List[str] = None):
+            """
+            Initializes the transform class for removing stop words from text.
+
+            :param stop_words: List of custom stop words to add to the default list.
+            """
+            nltk_stop_words = set(stopwords.words('english'))
+            self.stop_words = nltk_stop_words.union(set(stop_words)) if stop_words else nltk_stop_words
+
+        def process(self, element: str):
+            """
+            Removes stop words from the text.
+
+            :param element: Input text.
+            :yield: Text without stop words.
+            """
+            text = element.lower()
+            words = text.split()
+            text = ' '.join(word for word in words if word not in self.stop_words)
+            yield text
\ No newline at end of file
diff --git a/src/beamstack_transforms/preprocessing/detect_text_lang.py b/src/beamstack_transforms/preprocessing/detect_text_lang.py
new file mode 100644
index 0000000..7c89dd4
--- /dev/null
+++ b/src/beamstack_transforms/preprocessing/detect_text_lang.py
@@ -0,0 +1,41 @@
+import apache_beam as beam
+from langdetect import detect
+from typing import Any, Optional
+
+class LanguageDetection(beam.PTransform):
+    def __init__(self, output_key: Optional[str] = None):
+        """
+        Initializes the transform class for detecting text language.
+
+        :param output_key (Optional[str]): Key to store the detected language in the output element.
+        """
+        super().__init__()
+        self.output_key = output_key
+
+    def expand(self, pcoll):
+        return pcoll | "Detect Language" >> beam.ParDo(self._DetectLanguageFn(self.output_key))
+
+    class _DetectLanguageFn(beam.DoFn):
+        def __init__(self, output_key: Optional[str]):
+            """
+            Initializes class for detecting the language of input text.
+
+            :param output_key (Optional[str]): Key to store the detected language in the output element.
+            """
+            self.output_key = output_key
+
+        def process(self, element: Any):
+            """
+            Detects the language of the input text.
+
+            :param element: Input text element. Can be plain text or a dictionary containing text.
+            :param yield: Output text with detected language.
+            """
+            text = element if isinstance(element, str) else element.get("text", "")
+            detected_language = detect(text)
+
+            if self.output_key:
+                element[self.output_key] = detected_language
+                yield element
+            else:
+                yield {"text": text, "language": detected_language}
\ No newline at end of file
diff --git a/src/beamstack_transforms/preprocessing/normalize_tokens.py b/src/beamstack_transforms/preprocessing/normalize_tokens.py
new file mode 100644
index 0000000..a206599
--- /dev/null
+++ b/src/beamstack_transforms/preprocessing/normalize_tokens.py
@@ -0,0 +1,57 @@
+import apache_beam as beam
+from nltk.stem import PorterStemmer
+from nltk.stem.wordnet import WordNetLemmatizer
+from typing import List
+import nltk
+
+nltk.download('wordnet')
+nltk.download('omw-1.4')
+
+
+class NormalizeTokens(beam.PTransform):
+    def __init__(self, lemmatize: bool = True, stem: bool = False):
+        """
+        Initializes transform class for normalizing tokens.
+
+        :param lemmatize (bool): Whether to apply lemmatization. Default is True.
+        :param stem (bool): Whether to apply stemming. Default is False.
+        """
+        super().__init__()
+        self.lemmatize = lemmatize
+        self.stem = stem
+
+    def expand(self, pcoll):
+        return pcoll | "Normalize Tokens" >> beam.ParDo(
+            self._NormalizeTokensFn(lemmatize=self.lemmatize, stem=self.stem)
+        )
+
+    class _NormalizeTokensFn(beam.DoFn):
+        def __init__(self, lemmatize: bool, stem: bool):
+            """
+            Initialize class for normalizing tokens using lemmatization and/or stemming.
+
+            :param lemmatize (bool): Whether to apply lemmatization.
+            :param stem (bool): Whether to apply stemming.
+            """
+            self.lemmatize = lemmatize
+            self.stem = stem
+            self.lemmatizer = WordNetLemmatizer() if lemmatize else None
+            self.stemmer = PorterStemmer() if stem else None
+
+        def process(self, element: List[str]):
+            """
+            Normalizes tokens in the input element.
+
+            Args:
+            :param element (List[str]): List of tokens.
+            :param yield (List[str]): Normalized tokens.
+ """ + normalized_tokens = [] + for token in element: + word = token + if self.lemmatize: + word = self.lemmatizer.lemmatize(word) + if self.stem: + word = self.stemmer.stem(word) + normalized_tokens.append(word) + yield normalized_tokens \ No newline at end of file diff --git a/src/beamstack_transforms/preprocessing/tokenize_text.py b/src/beamstack_transforms/preprocessing/tokenize_text.py new file mode 100644 index 0000000..7a836ac --- /dev/null +++ b/src/beamstack_transforms/preprocessing/tokenize_text.py @@ -0,0 +1,41 @@ +import re +import apache_beam as beam +from typing import List, Optional + +class TokenizeText(beam.PTransform): + def __init__(self, lowercase: bool = True, custom_delimiters: Optional[List[str]] = None): + """ + Initializes transform class for tokenizing text. + + :param lowercase (bool): Whether to lowercase the text before tokenization. + :param custom_delimiters (Optional[List[str]]): Additional delimiters for tokenization. + """ + super().__init__() + self.lowercase = lowercase + self.custom_delimiters = custom_delimiters + + def expand(self, pcoll): + return pcoll | "Tokenize Text" >> beam.ParDo(self._TokenizeTextFn(self.lowercase, self.custom_delimiters)) + + class _TokenizeTextFn(beam.DoFn): + def __init__(self, lowercase: bool = True, custom_delimiters: Optional[List[str]] = None): + """ + Initializes transform class for tokenizing text, with optional lowercasing. + + :param lowercase (bool): Whether to convert text to lowercase before tokenization. + :param custom_delimiters (Optional[List[str]]): Additional delimiters for tokenization. + """ + self.lowercase = lowercase + self.custom_delimiters = custom_delimiters or [" ", "\n", "\t", ".", ",", "!", "?"] + + def process(self, element: str): + """ + Tokenizes the input text. + + :param element (str): Input text. + :param yield (List[str]): Tokenized words. 
+ """ + text = element.lower() if self.lowercase else element + delimiter_pattern = "|".join(map(re.escape, self.custom_delimiters)) + tokens = re.split(delimiter_pattern, text) + yield [token for token in tokens if token] \ No newline at end of file From 74b767ac642f64ab330cd6e457d34eac7d467572 Mon Sep 17 00:00:00 2001 From: kennedyuche Date: Thu, 9 Jan 2025 10:31:43 -0500 Subject: [PATCH 09/10] added sentence processing transforms --- .../embeddings/huggingface.py | 29 ++-------- .../embeddings/sentence_similarity.py | 49 +++++++++++++++++ .../embeddings/sentence_summarize.py | 53 +++++++++++++++++++ .../preprocessing/tokenize_text.py | 44 +++++++++++---- 4 files changed, 142 insertions(+), 33 deletions(-) create mode 100644 src/beamstack_transforms/embeddings/sentence_similarity.py create mode 100644 src/beamstack_transforms/embeddings/sentence_summarize.py diff --git a/src/beamstack_transforms/embeddings/huggingface.py b/src/beamstack_transforms/embeddings/huggingface.py index b711c58..71c2a2d 100644 --- a/src/beamstack_transforms/embeddings/huggingface.py +++ b/src/beamstack_transforms/embeddings/huggingface.py @@ -1,14 +1,11 @@ import logging - from apache_beam import DoFn, PTransform, ParDo -from beamstack_transforms.utils import import_package, ImportParams, install_package +from sentence_transformers import SentenceTransformer +import numpy as np logger = logging.getLogger(__file__) logging.basicConfig(level=logging.INFO) -REQUIRED_PACKAGES = ["sentence-transformers", "numpy"] - - class CreateEmbeddings(PTransform): def __init__(self, embed_model: str, encode_kwargs: dict = {}, label: str | None = None) -> None: super().__init__(label) @@ -22,34 +19,18 @@ def __init__(self, embed_model, encode_kwargs: dict = {}): self.embed_model = embed_model self.encode_kwargs = encode_kwargs - def start_bundle(self): - try: - install_package(REQUIRED_PACKAGES) - SentenceTransformer, self.np = import_package( - modules=[ - ImportParams( - module="sentence_transformers", - objects=["SentenceTransformer"] - ), - ImportParams( - module="numpy" - ) - ] - ) - except Exception as e: - logger.error(e) - quit() + def setup(self): self.embedder = SentenceTransformer(self.embed_model) def process(self, element): if hasattr(element, '_asdict'): embeddings = {key: self.embedder.encode( - str(value), **self.encode_kwargs).astype(self.np.float32).tolist() + str(value), **self.encode_kwargs).astype(np.float32).tolist() for key, value in element._asdict().items() } else: embeddings = self.embedder.encode( - str(element)).astype(self.np.float32).tolist() + str(element)).astype(np.float32).tolist() yield embeddings return pcol | ParDo(createEmbedding(self.embed_model, self.encode_kwargs)) diff --git a/src/beamstack_transforms/embeddings/sentence_similarity.py b/src/beamstack_transforms/embeddings/sentence_similarity.py new file mode 100644 index 0000000..08aebd0 --- /dev/null +++ b/src/beamstack_transforms/embeddings/sentence_similarity.py @@ -0,0 +1,49 @@ +import apache_beam as beam +from typing import Tuple, List +from sentence_transformers import SentenceTransformer +from sklearn.metrics.pairwise import cosine_similarity + +class SentenceSimilarityTransform(beam.PTransform): + def __init__(self, model_name: str = "all-MiniLM-L6-v2"): + """ + Initializes the transform for sentence similarity. + + :param model_name (str): Pre-trained sentence embedding model from Hugging Face. 
+ """ + super().__init__() + self.model_name = model_name + + def expand(self, pcoll): + return ( + pcoll + | "Compute Sentence Similarity" >> beam.ParDo(self._ComputeSimilarityFn(self.model_name)) + ) + + class _ComputeSimilarityFn(beam.DoFn): + def __init__(self, model_name: str): + """ + Initializes the function to compute similarity. + + :param model_name (str): Pre-trained sentence embedding model name. + """ + self.model_name = model_name + self.model = None + + def setup(self): + """Load the sentence embedding model.""" + self.model = SentenceTransformer(self.model_name) + + def process(self, element: Tuple[str, str]): + """ + Computes the similarity between sentences. + + :param element (Tuple[str, str]): A pair of sentences to compare. + :yield (Tuple[str, str, float]): Sentences and their similarity score. + """ + sentences = element + embeddings = self.model.encode(sentences, convert_to_tensor=True) + similarity_matrix = cosine_similarity(embeddings) + + for i in range(len(sentences)): + for j in range(i + 1, len(sentences)): + yield (sentences[i], sentences[j], similarity_matrix[i][j]) \ No newline at end of file diff --git a/src/beamstack_transforms/embeddings/sentence_summarize.py b/src/beamstack_transforms/embeddings/sentence_summarize.py new file mode 100644 index 0000000..b7263dc --- /dev/null +++ b/src/beamstack_transforms/embeddings/sentence_summarize.py @@ -0,0 +1,53 @@ +import apache_beam as beam +from typing import List +from transformers import pipeline + +class SummarizationTransform(beam.PTransform): + def __init__(self, model_name: str, max_length: int = 130, min_length: int = 30): + """ + Initializes the transform for summarization. + + :param model_name (str): The name of the summarization model to use. + :param max_length (int): The maximum length of the generated summary. + :param min_length (int): The minimum length of the generated summary. + """ + super().__init__() + self.model_name = model_name + self.max_length = max_length + self.min_length = min_length + + def expand(self, pcoll): + return pcoll | "Summarize Text" >> beam.ParDo(self._SummarizeTextFn(self.model_name, self.max_length, self.min_length)) + + class _SummarizeTextFn(beam.DoFn): + def __init__(self, model_name: str, max_length: int, min_length: int): + """ + Initializes the function for summarization. + + :param model_name (str): The name of the summarization model. + :param max_length (int): The maximum length of the generated summary. + :param min_length (int): The minimum length of the generated summary. + """ + self.model_name = model_name + self.max_length = max_length + self.min_length = min_length + self.summarizer = None + + def setup(self): + """Load the summarization model.""" + self.summarizer = pipeline("summarization", model=self.model_name) + + def process(self, element: str): + """ + Summarizes a large block of text. + + :param element (str): Input text block. + :yield (str): The generated summary. 
+ """ + summary = self.summarizer( + element, + max_length=self.max_length, + min_length=self.min_length, + do_sample=False + ) + yield summary[0]["summary_text"] \ No newline at end of file diff --git a/src/beamstack_transforms/preprocessing/tokenize_text.py b/src/beamstack_transforms/preprocessing/tokenize_text.py index 7a836ac..4f8ca07 100644 --- a/src/beamstack_transforms/preprocessing/tokenize_text.py +++ b/src/beamstack_transforms/preprocessing/tokenize_text.py @@ -3,39 +3,65 @@ from typing import List, Optional class TokenizeText(beam.PTransform): - def __init__(self, lowercase: bool = True, custom_delimiters: Optional[List[str]] = None): + def __init__( + self, + lowercase: bool = True, + custom_delimiters: Optional[List[str]] = None, + keep_punctuation: bool = False + ): """ Initializes transform class for tokenizing text. :param lowercase (bool): Whether to lowercase the text before tokenization. :param custom_delimiters (Optional[List[str]]): Additional delimiters for tokenization. + :param keep_punctuation (bool): Whether to keep punctuation as separate tokens. """ super().__init__() self.lowercase = lowercase self.custom_delimiters = custom_delimiters + self.keep_punctuation = keep_punctuation def expand(self, pcoll): - return pcoll | "Tokenize Text" >> beam.ParDo(self._TokenizeTextFn(self.lowercase, self.custom_delimiters)) + return pcoll | "Tokenize Text" >> beam.ParDo( + self._TokenizeTextFn(self.lowercase, self.custom_delimiters, self.keep_punctuation) + ) class _TokenizeTextFn(beam.DoFn): - def __init__(self, lowercase: bool = True, custom_delimiters: Optional[List[str]] = None): + DEFAULT_DELIMITERS = [" ", "\n", "\t", ".", ",", "!", "?", ":", ";", "(", ")", "-", "_"] + + def __init__( + self, + lowercase: bool, + custom_delimiters: Optional[List[str]], + keep_punctuation: bool + ): """ - Initializes transform class for tokenizing text, with optional lowercasing. + Initializes the tokenization function. - :param lowercase (bool): Whether to convert text to lowercase before tokenization. + :param lowercase (bool): Whether to lowercase the text before tokenization. :param custom_delimiters (Optional[List[str]]): Additional delimiters for tokenization. + :param keep_punctuation (bool): Whether to keep punctuation as separate tokens. """ self.lowercase = lowercase - self.custom_delimiters = custom_delimiters or [" ", "\n", "\t", ".", ",", "!", "?"] + self.keep_punctuation = keep_punctuation + self.delimiters = custom_delimiters or self.DEFAULT_DELIMITERS + self.pattern = self._build_regex_pattern() + + def _build_regex_pattern(self) -> re.Pattern: + """ + Builds a compiled regex pattern for tokenization. + """ + if self.keep_punctuation: + return re.compile(r"(\w+|[" + re.escape("".join(self.delimiters)) + r"])") + return re.compile(r"|".join(map(re.escape, self.delimiters))) def process(self, element: str): """ Tokenizes the input text. :param element (str): Input text. - :param yield (List[str]): Tokenized words. + :return: A list of tokenized words. 
""" text = element.lower() if self.lowercase else element - delimiter_pattern = "|".join(map(re.escape, self.custom_delimiters)) - tokens = re.split(delimiter_pattern, text) + tokens = self.pattern.findall(text) if self.keep_punctuation else re.split(self.pattern, text) yield [token for token in tokens if token] \ No newline at end of file From bddcc756649990d43a5737a0f815e806f64515c7 Mon Sep 17 00:00:00 2001 From: kennedyuche Date: Fri, 17 Jan 2025 11:16:51 -0500 Subject: [PATCH 10/10] added sentence completion transform --- .../embeddings/sentence_completion.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 src/beamstack_transforms/embeddings/sentence_completion.py diff --git a/src/beamstack_transforms/embeddings/sentence_completion.py b/src/beamstack_transforms/embeddings/sentence_completion.py new file mode 100644 index 0000000..3706bcd --- /dev/null +++ b/src/beamstack_transforms/embeddings/sentence_completion.py @@ -0,0 +1,77 @@ +import apache_beam as beam +from typing import Optional +from transformers import pipeline +import openai + +class TextCompletionTransform(beam.PTransform): + def __init__(self, backend: str, model_name: str, max_length: int = 50, openai_api_key: Optional[str] = None): + """ + Initializes the transform for text completion. + + :param backend (str): The backend to use ('huggingface' or 'openai'). + :param model_name (str): The model name to use for text completion. + :param max_length (int): The maximum length of the generated completion. + :param openai_api_key (Optional[str]): The API key for OpenAI (required if backend is 'openai'). + """ + super().__init__() + self.backend = backend.lower() + self.model_name = model_name + self.max_length = max_length + self.openai_api_key = openai_api_key + + if self.backend not in ["huggingface", "openai"]: + raise ValueError("Invalid backend. Choose 'huggingface' or 'openai'.") + + def expand(self, pcoll): + return pcoll | "Generate Text Completions" >> beam.ParDo( + self._GenerateCompletionFn(self.backend, self.model_name, self.max_length, self.openai_api_key) + ) + + class _GenerateCompletionFn(beam.DoFn): + def __init__(self, backend: str, model_name: str, max_length: int, openai_api_key: Optional[str]): + """ + Initializes the function for text completion. + + :param backend (str): The backend to use ('huggingface' or 'openai'). + :param model_name (str): The model name to use. + :param max_length (int): The maximum length of the generated completion. + :param openai_api_key (Optional[str]): The API key for OpenAI (required if backend is 'openai'). + """ + self.backend = backend + self.model_name = model_name + self.max_length = max_length + self.openai_api_key = openai_api_key + self.generator = None + + def setup(self): + """Load the model or initialize API connection based on the backend.""" + if self.backend == "huggingface": + self.generator = pipeline("text-generation", model=self.model_name) + elif self.backend == "openai": + if not self.openai_api_key: + raise ValueError("OpenAI API key must be provided for the OpenAI backend.") + openai.api_key = self.openai_api_key + + def process(self, element: str): + """ + Generates a text completion for the input partial text. + + :param element (str): The partial text to complete. + :yield (str): The completed text. 
+ """ + if self.backend == "huggingface": + completions = self.generator( + element, + max_length=self.max_length, + num_return_sequences=1, + do_sample=True + ) + yield completions[0]["generated_text"] + elif self.backend == "openai": + response = openai.Completion.create( + engine=self.model_name, + prompt=element, + max_tokens=self.max_length, + temperature=0.7 + ) + yield response.choices[0].text.strip() \ No newline at end of file