diff --git a/src/beamstack_provider.py b/src/beamstack_provider.py
new file mode 100644
index 0000000..5b9dcc1
--- /dev/null
+++ b/src/beamstack_provider.py
@@ -0,0 +1,255 @@
+import os
+import json
+import hashlib
+import base64
+import subprocess
+import sys
+import yaml
+import apache_beam as beam
+from typing import Any, Iterable, Mapping, Optional, Callable
+from apache_beam.yaml.yaml_provider import ExternalProvider
+import logging
+import importlib.util
+import urllib.request
+from urllib.parse import urlparse
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class BeamstackProviderPathHandler:
+    def is_local_file(self, path: str) -> bool:
+        """Check if the path is a local file."""
+        return os.path.exists(path)
+
+    def is_github_file(self, path: str) -> bool:
+        """Check if the path is a GitHub file URL."""
+        parsed_url = urlparse(path)
+
+        if parsed_url.netloc == "raw.githubusercontent.com":
+            return True
+        if parsed_url.netloc == "github.com":
+            return True
+
+        return False
+
+    def is_gcs_file(self, path: str) -> bool:
+        """Check if the path is a Google Cloud Storage URL."""
+        return path.startswith('gs://')
+
+    def handle_local_file(self, path: str):
+        """Handle local file or directory path."""
+        logger.info(f"Pulling transforms yaml from: {path}")
+        if not os.path.exists(path):
+            raise FileNotFoundError(f"The local path '{path}' does not exist.")
+        return path
+
+    def handle_github_file(self, file_url: str, target_dir: str):
+        """Download a file from a public GitHub repository to the target directory."""
+        logger.info(f"Pulling transforms yaml from GitHub url: {file_url}")
+
+        if "github.com" in file_url and "/blob/" in file_url:
+            file_url = file_url.replace("github.com", "raw.githubusercontent.com").replace("blob/", "")
+
+        os.makedirs(target_dir, exist_ok=True)
+
+        file_name = os.path.basename(file_url)
+        local_file_path = os.path.join(target_dir, file_name)
+
+        try:
+            logger.info(f"Downloading {file_name} to {local_file_path}")
+            urllib.request.urlretrieve(file_url, local_file_path)
+        except Exception as e:
+            logger.error(f"Error occurred during file download: {e}")
+            raise
+
+        return local_file_path
+
+    def handle_gcs_file(self, gcs_path: str, target_dir: str):
+        """Download a file from a public GCS bucket to the target directory."""
+        logger.info(f"Pulling transforms yaml from GCS path: {gcs_path}")
+
+        gcs_path = gcs_path[len("gs://"):]
+        bucket_name, _, object_name = gcs_path.partition('/')
+        public_url = f"https://storage.googleapis.com/{bucket_name}/{object_name}"
+
+        os.makedirs(target_dir, exist_ok=True)
+        local_file_path = os.path.join(target_dir, os.path.basename(object_name))
+
+        try:
+            logger.info(f"Downloading {os.path.basename(object_name)} to {target_dir}")
+            urllib.request.urlretrieve(public_url, local_file_path)
+        except Exception as e:
+            logger.error(f"Error downloading file from {public_url}: {e}")
+            raise
+
+        return local_file_path
+
+@ExternalProvider.register_provider_type('BeamstackTransform')
+def BeamstackTransform(urns, path):
+    target_dir = '/tmp/beamstack_transforms'
+
+    path_handler = BeamstackProviderPathHandler()
+
+    if path_handler.is_local_file(path):
+        transform_yaml_path = path_handler.handle_local_file(path)
+    elif path_handler.is_github_file(path):
+        transform_yaml_path = path_handler.handle_github_file(path, target_dir)
+    elif path_handler.is_gcs_file(path):
+        transform_yaml_path = path_handler.handle_gcs_file(path, target_dir)
+    else:
+        raise ValueError(f"Unsupported path type: {path}")
+
+    with open(transform_yaml_path, 'r') as f:
+        transform_yaml = yaml.safe_load(f)
+
+    config = {
+        'urns': urns,
+        'yaml_path': transform_yaml_path,
+        'dependencies': transform_yaml.get('dependencies', [])
+    }
+
+    return BeamstackTransformProvider(urns, config)
+
+class BeamstackTransformProvider(ExternalProvider):
+    def __init__(self, urns, config):
+        super().__init__(urns, BeamstackExpansionService(config))
+        self.config = config
+        self.transforms = config.get('urns', {})
+
+        logger.info(f"Transforms: {self.transforms}")
+
+    def available(self) -> bool:
+        return True
+
+    def cache_artifacts(self) -> Optional[Iterable[str]]:
+        return [self._service._venv_path()]
+
+    def create_transform(self,
+                         typ: str,
+                         args: Mapping[str, Any],
+                         yaml_create_transform: Callable[[Mapping[str, Any], Iterable[beam.PCollection]], beam.PTransform]) -> Optional[beam.PTransform]:
+        """Create a PTransform based on decoded source code and configurations."""
+        if callable(self._service):
+            self._service = self._service()
+
+        logger.info(f"Creating transform of type: {typ} with args: {args}")
+
+        transform_class = self._load_transform_class(typ)
+
+        if callable(transform_class):
+            config_args = args.get('config', {})
+            try:
+                return transform_class(**config_args)
+            except TypeError as e:
+                logger.error(f"Error initializing transform '{typ}': {e}")
+                raise
+        else:
+            raise TypeError(f"{typ} is not a callable transform class.")
+
+    def _module_class_map(self) -> dict:
+        """Builds a map from transform class names to their module names."""
+        self.yaml_path = self.config.get('yaml_path')
+
+        with open(self.yaml_path, 'r') as file:
+            data = yaml.safe_load(file)
+            self.transforms = data['transforms']
+
+        transform_map = {}
+        for item in self.transforms:
+            for _, value in item.items():
+                module_name, transform_class = value.split(':')
+                transform_map[transform_class] = module_name
+
+        return transform_map
+
+    def _load_transform_class(self, transform_name):
+        """Dynamically loads and returns a transform class by name."""
+        transform_map = self._module_class_map()
+
+        try:
+            logger.info(f"Loading transform class for: {transform_name}")
+
+            spec = importlib.util.spec_from_file_location(
+                f"{transform_map[transform_name]}.py",
+                os.path.join(self._service._venv_path(), f"{transform_map[transform_name]}.py")
+            )
+            if spec is None:
+                logger.error(f"Specification for module '{transform_map[transform_name]}' could not be found.")
+                return None
+
+            module = importlib.util.module_from_spec(spec)
+            sys.path.insert(0, os.path.dirname(spec.origin))
+            spec.loader.exec_module(module)
+            transform_class = getattr(module, transform_name)
+            logger.info(f"Loaded transform class: {transform_class}")
+            return transform_class
+        except Exception as e:
+            logger.error(f"Failed to load transform {transform_name}: {e}")
+            raise
+
+
+class BeamstackExpansionService:
+    VENV_CACHE = os.path.expanduser("~/.apache_beam/cache/beamstack_venvs")
+
+    def __init__(self, config):
+        self.config = config
+        self.runner = config.get('runner')
+        self.yaml_path = config.get('yaml_path')
+        self.base_python = sys.executable
+        self._packages = config.get('dependencies', [])
+        self._service = None
+
+        self._load_yaml()
+
+    def _load_yaml(self):
+        """Loads and decodes the transforms.yaml file."""
+        with open(self.yaml_path, 'r') as file:
+            data = yaml.safe_load(file)
+            self._packages = data.get('dependencies', [])
+            self.source_code = data['source_code']
+            self.encoding = data['encoding']
+
+        for module_name, encoded_code in self.source_code.items():
+            decoded_code = base64.b64decode(encoded_code).decode('utf-8')
self._write_source_file(f"{module_name}.py", decoded_code) + self._source_module = f"{module_name}.py" + + def _write_source_file(self, src_name, code): + """Writes decoded code to file for each source.""" + venv = self._venv_path() + file_path = os.path.join(venv, src_name) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, 'w') as f: + f.write(code) + + def _venv_path(self): + """Returns the path for the virtual environment directory based on the packages and runner.""" + key = json.dumps({'binary': self.base_python, 'packages': sorted(self._packages), 'runner': self.runner}) + venv_hash = hashlib.sha256(key.encode('utf-8')).hexdigest() + venv = os.path.join(self.VENV_CACHE, venv_hash) + + if not os.path.exists(venv): + subprocess.run([self.base_python, '-m', 'venv', venv], check=True) + + site_packages_path = os.path.join(venv, 'lib', f'python{sys.version_info.major}.{sys.version_info.minor}', 'site-packages') + if site_packages_path not in sys.path: + sys.path.insert(0, site_packages_path) + + venv_pip = os.path.join(venv, 'bin', 'pip') + + if os.path.exists(venv_pip): + installed_packages = subprocess.check_output( + [venv_pip, 'list', '--format=freeze'] + ).decode('utf-8').splitlines() + + installed_packages_set = {pkg.split('==')[0] for pkg in installed_packages} + + for package in self._packages: + if package not in installed_packages_set: + logger.info(f"Installing package: {package}") + subprocess.run([venv_pip, 'install', package], check=True) + else: + logger.info(f"Package '{package}' is already installed; skipping installation.") + else: + raise FileNotFoundError(f"Could not find pip at expected location: {venv_pip}") + + return venv \ No newline at end of file diff --git a/src/beamstack_transforms/embeddings/huggingface.py b/src/beamstack_transforms/embeddings/huggingface.py index b711c58..71c2a2d 100644 --- a/src/beamstack_transforms/embeddings/huggingface.py +++ b/src/beamstack_transforms/embeddings/huggingface.py @@ -1,14 +1,11 @@ import logging - from apache_beam import DoFn, PTransform, ParDo -from beamstack_transforms.utils import import_package, ImportParams, install_package +from sentence_transformers import SentenceTransformer +import numpy as np logger = logging.getLogger(__file__) logging.basicConfig(level=logging.INFO) -REQUIRED_PACKAGES = ["sentence-transformers", "numpy"] - - class CreateEmbeddings(PTransform): def __init__(self, embed_model: str, encode_kwargs: dict = {}, label: str | None = None) -> None: super().__init__(label) @@ -22,34 +19,18 @@ def __init__(self, embed_model, encode_kwargs: dict = {}): self.embed_model = embed_model self.encode_kwargs = encode_kwargs - def start_bundle(self): - try: - install_package(REQUIRED_PACKAGES) - SentenceTransformer, self.np = import_package( - modules=[ - ImportParams( - module="sentence_transformers", - objects=["SentenceTransformer"] - ), - ImportParams( - module="numpy" - ) - ] - ) - except Exception as e: - logger.error(e) - quit() + def setup(self): self.embedder = SentenceTransformer(self.embed_model) def process(self, element): if hasattr(element, '_asdict'): embeddings = {key: self.embedder.encode( - str(value), **self.encode_kwargs).astype(self.np.float32).tolist() + str(value), **self.encode_kwargs).astype(np.float32).tolist() for key, value in element._asdict().items() } else: embeddings = self.embedder.encode( - str(element)).astype(self.np.float32).tolist() + str(element)).astype(np.float32).tolist() yield embeddings return pcol | ParDo(createEmbedding(self.embed_model, 
diff --git a/src/beamstack_transforms/embeddings/sentence_completion.py b/src/beamstack_transforms/embeddings/sentence_completion.py
new file mode 100644
index 0000000..3706bcd
--- /dev/null
+++ b/src/beamstack_transforms/embeddings/sentence_completion.py
@@ -0,0 +1,77 @@
+import apache_beam as beam
+from typing import Optional
+from transformers import pipeline
+import openai
+
+class TextCompletionTransform(beam.PTransform):
+    def __init__(self, backend: str, model_name: str, max_length: int = 50, openai_api_key: Optional[str] = None):
+        """
+        Initializes the transform for text completion.
+
+        :param backend (str): The backend to use ('huggingface' or 'openai').
+        :param model_name (str): The model name to use for text completion.
+        :param max_length (int): The maximum length of the generated completion.
+        :param openai_api_key (Optional[str]): The API key for OpenAI (required if backend is 'openai').
+        """
+        super().__init__()
+        self.backend = backend.lower()
+        self.model_name = model_name
+        self.max_length = max_length
+        self.openai_api_key = openai_api_key
+
+        if self.backend not in ["huggingface", "openai"]:
+            raise ValueError("Invalid backend. Choose 'huggingface' or 'openai'.")
+
+    def expand(self, pcoll):
+        return pcoll | "Generate Text Completions" >> beam.ParDo(
+            self._GenerateCompletionFn(self.backend, self.model_name, self.max_length, self.openai_api_key)
+        )
+
+    class _GenerateCompletionFn(beam.DoFn):
+        def __init__(self, backend: str, model_name: str, max_length: int, openai_api_key: Optional[str]):
+            """
+            Initializes the function for text completion.
+
+            :param backend (str): The backend to use ('huggingface' or 'openai').
+            :param model_name (str): The model name to use.
+            :param max_length (int): The maximum length of the generated completion.
+            :param openai_api_key (Optional[str]): The API key for OpenAI (required if backend is 'openai').
+            """
+            self.backend = backend
+            self.model_name = model_name
+            self.max_length = max_length
+            self.openai_api_key = openai_api_key
+            self.generator = None
+
+        def setup(self):
+            """Load the model or initialize the API connection based on the backend."""
+            if self.backend == "huggingface":
+                self.generator = pipeline("text-generation", model=self.model_name)
+            elif self.backend == "openai":
+                if not self.openai_api_key:
+                    raise ValueError("OpenAI API key must be provided for the OpenAI backend.")
+                openai.api_key = self.openai_api_key
+
+        def process(self, element: str):
+            """
+            Generates a text completion for the input partial text.
+
+            :param element (str): The partial text to complete.
+            :yield (str): The completed text.
+            """
+            if self.backend == "huggingface":
+                completions = self.generator(
+                    element,
+                    max_length=self.max_length,
+                    num_return_sequences=1,
+                    do_sample=True
+                )
+                yield completions[0]["generated_text"]
+            elif self.backend == "openai":
+                response = openai.Completion.create(
+                    engine=self.model_name,
+                    prompt=element,
+                    max_tokens=self.max_length,
+                    temperature=0.7
+                )
+                yield response.choices[0].text.strip()
\ No newline at end of file
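
A minimal usage sketch of `TextCompletionTransform` with the Hugging Face backend on a local runner; the model name and prompt are placeholders:

```python
import apache_beam as beam
from beamstack_transforms.embeddings.sentence_completion import TextCompletionTransform

with beam.Pipeline() as p:
    (
        p
        | beam.Create(["Apache Beam is a unified model for"])   # partial text to complete
        | TextCompletionTransform(backend="huggingface", model_name="gpt2", max_length=30)
        | beam.Map(print)                                        # emits the generated completion
    )
```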
+ """ + if self.backend == "huggingface": + completions = self.generator( + element, + max_length=self.max_length, + num_return_sequences=1, + do_sample=True + ) + yield completions[0]["generated_text"] + elif self.backend == "openai": + response = openai.Completion.create( + engine=self.model_name, + prompt=element, + max_tokens=self.max_length, + temperature=0.7 + ) + yield response.choices[0].text.strip() \ No newline at end of file diff --git a/src/beamstack_transforms/embeddings/sentence_similarity.py b/src/beamstack_transforms/embeddings/sentence_similarity.py new file mode 100644 index 0000000..08aebd0 --- /dev/null +++ b/src/beamstack_transforms/embeddings/sentence_similarity.py @@ -0,0 +1,49 @@ +import apache_beam as beam +from typing import Tuple, List +from sentence_transformers import SentenceTransformer +from sklearn.metrics.pairwise import cosine_similarity + +class SentenceSimilarityTransform(beam.PTransform): + def __init__(self, model_name: str = "all-MiniLM-L6-v2"): + """ + Initializes the transform for sentence similarity. + + :param model_name (str): Pre-trained sentence embedding model from Hugging Face. + """ + super().__init__() + self.model_name = model_name + + def expand(self, pcoll): + return ( + pcoll + | "Compute Sentence Similarity" >> beam.ParDo(self._ComputeSimilarityFn(self.model_name)) + ) + + class _ComputeSimilarityFn(beam.DoFn): + def __init__(self, model_name: str): + """ + Initializes the function to compute similarity. + + :param model_name (str): Pre-trained sentence embedding model name. + """ + self.model_name = model_name + self.model = None + + def setup(self): + """Load the sentence embedding model.""" + self.model = SentenceTransformer(self.model_name) + + def process(self, element: Tuple[str, str]): + """ + Computes the similarity between sentences. + + :param element (Tuple[str, str]): A pair of sentences to compare. + :yield (Tuple[str, str, float]): Sentences and their similarity score. + """ + sentences = element + embeddings = self.model.encode(sentences, convert_to_tensor=True) + similarity_matrix = cosine_similarity(embeddings) + + for i in range(len(sentences)): + for j in range(i + 1, len(sentences)): + yield (sentences[i], sentences[j], similarity_matrix[i][j]) \ No newline at end of file diff --git a/src/beamstack_transforms/embeddings/sentence_summarize.py b/src/beamstack_transforms/embeddings/sentence_summarize.py new file mode 100644 index 0000000..b7263dc --- /dev/null +++ b/src/beamstack_transforms/embeddings/sentence_summarize.py @@ -0,0 +1,53 @@ +import apache_beam as beam +from typing import List +from transformers import pipeline + +class SummarizationTransform(beam.PTransform): + def __init__(self, model_name: str, max_length: int = 130, min_length: int = 30): + """ + Initializes the transform for summarization. + + :param model_name (str): The name of the summarization model to use. + :param max_length (int): The maximum length of the generated summary. + :param min_length (int): The minimum length of the generated summary. + """ + super().__init__() + self.model_name = model_name + self.max_length = max_length + self.min_length = min_length + + def expand(self, pcoll): + return pcoll | "Summarize Text" >> beam.ParDo(self._SummarizeTextFn(self.model_name, self.max_length, self.min_length)) + + class _SummarizeTextFn(beam.DoFn): + def __init__(self, model_name: str, max_length: int, min_length: int): + """ + Initializes the function for summarization. + + :param model_name (str): The name of the summarization model. 
diff --git a/src/beamstack_transforms/embeddings/sentence_summarize.py b/src/beamstack_transforms/embeddings/sentence_summarize.py
new file mode 100644
index 0000000..b7263dc
--- /dev/null
+++ b/src/beamstack_transforms/embeddings/sentence_summarize.py
@@ -0,0 +1,53 @@
+import apache_beam as beam
+from transformers import pipeline
+
+class SummarizationTransform(beam.PTransform):
+    def __init__(self, model_name: str, max_length: int = 130, min_length: int = 30):
+        """
+        Initializes the transform for summarization.
+
+        :param model_name (str): The name of the summarization model to use.
+        :param max_length (int): The maximum length of the generated summary.
+        :param min_length (int): The minimum length of the generated summary.
+        """
+        super().__init__()
+        self.model_name = model_name
+        self.max_length = max_length
+        self.min_length = min_length
+
+    def expand(self, pcoll):
+        return pcoll | "Summarize Text" >> beam.ParDo(self._SummarizeTextFn(self.model_name, self.max_length, self.min_length))
+
+    class _SummarizeTextFn(beam.DoFn):
+        def __init__(self, model_name: str, max_length: int, min_length: int):
+            """
+            Initializes the function for summarization.
+
+            :param model_name (str): The name of the summarization model.
+            :param max_length (int): The maximum length of the generated summary.
+            :param min_length (int): The minimum length of the generated summary.
+            """
+            self.model_name = model_name
+            self.max_length = max_length
+            self.min_length = min_length
+            self.summarizer = None
+
+        def setup(self):
+            """Load the summarization model."""
+            self.summarizer = pipeline("summarization", model=self.model_name)
+
+        def process(self, element: str):
+            """
+            Summarizes a large block of text.
+
+            :param element (str): Input text block.
+            :yield (str): The generated summary.
+            """
+            summary = self.summarizer(
+                element,
+                max_length=self.max_length,
+                min_length=self.min_length,
+                do_sample=False
+            )
+            yield summary[0]["summary_text"]
\ No newline at end of file
diff --git a/src/beamstack_transforms/preprocessing/augment_text.py b/src/beamstack_transforms/preprocessing/augment_text.py
new file mode 100644
index 0000000..409ae8b
--- /dev/null
+++ b/src/beamstack_transforms/preprocessing/augment_text.py
@@ -0,0 +1,79 @@
+import apache_beam as beam
+import random
+from typing import List, Dict, Any
+from nltk.corpus import wordnet
+import nltk
+
+nltk.download('wordnet')
+
+class TextAugmentation(beam.PTransform):
+    def __init__(self, techniques: List[str], augment_factor: int = 1):
+        """
+        Initializes the transform class for augmenting text data using specified techniques.
+
+        :param techniques (List[str]): List of augmentation techniques to apply (e.g., 'synonym_replacement', 'back_translation').
+        :param augment_factor (int): Number of augmented examples to generate per input. Default is 1.
+        """
+        super().__init__()
+        self.techniques = techniques
+        self.augment_factor = augment_factor
+
+    def expand(self, pcoll):
+        return pcoll | "Augment Text" >> beam.ParDo(
+            self._TextAugmentationFn(self.techniques, self.augment_factor)
+        )
+
+    class _TextAugmentationFn(beam.DoFn):
+        def __init__(self, techniques: List[str], augment_factor: int):
+            """
+            A DoFn for applying text augmentation techniques.
+
+            :param techniques (List[str]): List of techniques for augmentation.
+            :param augment_factor (int): Number of augmented examples to generate per input.
+            """
+            self.techniques = techniques
+            self.augment_factor = augment_factor
+
+        def process(self, element: Dict[str, Any]):
+            """
+            Augments the input text using specified techniques.
+
+            :param element (Dict[str, Any]): Input dictionary containing the text to augment.
+            :yield (Dict[str, Any]): Augmented copies of the input element.
+            """
+            text = element.get("text", "")
+            for _ in range(self.augment_factor):
+                augmented_text = self._apply_augmentation(text)
+                augmented_element = element.copy()
+                augmented_element["text"] = augmented_text
+                yield augmented_element
+
+        def _apply_augmentation(self, text: str) -> str:
+            """
+            Applies augmentation techniques to the input text.
+
+            :param text (str): Original text.
+            :return (str): Augmented text.
+            """
+            if "synonym_replacement" in self.techniques:
+                text = self._synonym_replacement(text)
+            # Additional techniques can be added here.
+            return text
+
+        def _synonym_replacement(self, text: str) -> str:
+            """
+            Replaces random words with synonyms.
+
+            :param text (str): Original text.
+            :return (str): Text with synonyms replaced.
+            """
+            words = text.split()
+            for i, word in enumerate(words):
+                synonyms = wordnet.synsets(word)
+                if synonyms:
+                    synonym = random.choice(synonyms).lemmas()[0].name()
+                    words[i] = synonym
+            return " ".join(words)
\ No newline at end of file
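
A sketch of `TextAugmentation` applied to dict elements carrying a `text` field; the values are illustrative, and each input yields `augment_factor` augmented copies:

```python
import apache_beam as beam
from beamstack_transforms.preprocessing.augment_text import TextAugmentation

with beam.Pipeline() as p:
    (
        p
        | beam.Create([{"text": "beam pipelines process data", "label": "docs"}])
        | TextAugmentation(techniques=["synonym_replacement"], augment_factor=2)
        | beam.Map(print)   # two augmented copies of the input element
    )
```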
+ """ + words = text.split() + for i, word in enumerate(words): + synonyms = wordnet.synsets(word) + if synonyms: + synonym = random.choice(synonyms).lemmas()[0].name() + words[i] = synonym + return " ".join(words) \ No newline at end of file diff --git a/src/beamstack_transforms/preprocessing/clean_text.py b/src/beamstack_transforms/preprocessing/clean_text.py new file mode 100644 index 0000000..4be02d9 --- /dev/null +++ b/src/beamstack_transforms/preprocessing/clean_text.py @@ -0,0 +1,69 @@ +import apache_beam as beam +import re +from typing import List +from nltk.corpus import stopwords +import nltk + +nltk.download('stopwords') + +class CleanText(beam.PTransform): + def __init__(self, stop_words: List[str] = None, additional_patterns: List[str] = None): + """ + Initializes the Transform class for cleaning texts. + + :param stop_words: List of custom stop words to add to the default list. + :param additional_patterns: List of regex patterns to remove from text. + """ + super().__init__() + self.stop_words = stop_words + self.additional_patterns = additional_patterns + + def expand(self, pcoll): + return ( + pcoll + | "Remove Patterns" >> beam.ParDo(self._RemovePatternsFn(self.additional_patterns)) + | "Remove Stop Words" >> beam.ParDo(self._RemoveStopWordsFn(self.stop_words)) + ) + + class _RemovePatternsFn(beam.DoFn): + def __init__(self, additional_patterns: List[str] = None): + """ + Initializes the class to remove regex patterns from text. + + :param additional_patterns: List of regex patterns to remove from text. + """ + self.additional_patterns = additional_patterns if additional_patterns else [] + + def process(self, element: str): + """ + Removes specified regex patterns from the text. + + :param element: Input text. + :yield: Text with patterns removed. + """ + text = re.sub(r'[^a-zA-Z0-9\s]', '', element) # Remove non-alphanumeric characters. + for pattern in self.additional_patterns: + text = re.sub(pattern, '', text) + yield text + + class _RemoveStopWordsFn(beam.DoFn): + def __init__(self, stop_words: List[str] = None): + """ + Initializes the transform class for removing stop words from text. + + :param stop_words: List of custom stop words to add to the default list. + """ + nltk_stop_words = set(stopwords.words('english')) + self.stop_words = nltk_stop_words.union(set(stop_words)) if stop_words else nltk_stop_words + + def process(self, element: str): + """ + Removes stop words from the text. + + :param element: Input text. + :yield: Text without stop words. + """ + text = element.lower() + words = text.split() + text = ' '.join(word for word in words if word not in self.stop_words) + yield text \ No newline at end of file diff --git a/src/beamstack_transforms/preprocessing/detect_text_lang.py b/src/beamstack_transforms/preprocessing/detect_text_lang.py new file mode 100644 index 0000000..7c89dd4 --- /dev/null +++ b/src/beamstack_transforms/preprocessing/detect_text_lang.py @@ -0,0 +1,41 @@ +import apache_beam as beam +from langdetect import detect +from typing import Any, Optional + +class LanguageDetection(beam.PTransform): + def __init__(self, output_key: Optional[str] = None): + """ + Initializes the transform class for detecting text language. + + :param output_key (Optional[str]): Key to store the detected language in the output element. 
+ """ + super().__init__() + self.output_key = output_key + + def expand(self, pcoll): + return pcoll | "Detect Language" >> beam.ParDo(self._DetectLanguageFn(self.output_key)) + + class _DetectLanguageFn(beam.DoFn): + def __init__(self, output_key: Optional[str]): + """ + Initializes class for detecting the language of input text. + + :param output_key (Optional[str]): Key to store the detected language in the output element. + """ + self.output_key = output_key + + def process(self, element: Any): + """ + Detects the language of the input text. + + :param element: Input text element. Can be plain text or a dictionary containing text. + :param yield: Output text with detected language. + """ + text = element if isinstance(element, str) else element.get("text", "") + detected_language = detect(text) + + if self.output_key: + element[self.output_key] = detected_language + yield element + else: + yield {"text": text, "language": detected_language} \ No newline at end of file diff --git a/src/beamstack_transforms/preprocessing/normalize_tokens.py b/src/beamstack_transforms/preprocessing/normalize_tokens.py new file mode 100644 index 0000000..a206599 --- /dev/null +++ b/src/beamstack_transforms/preprocessing/normalize_tokens.py @@ -0,0 +1,57 @@ +import apache_beam as beam +from nltk.stem import PorterStemmer +from nltk.stem.wordnet import WordNetLemmatizer +from typing import List +import nltk + +nltk.download('wordnet') +nltk.download('omw-1.4') + + +class NormalizeTokens(beam.PTransform): + def __init__(self, lemmatize: bool = True, stem: bool = False): + """ + Initializes transform class for normalizing tokens. + + :param lemmatize (bool): Whether to apply lemmatization. Default is True. + :param stem (bool): Whether to apply stemming. Default is False. + """ + super().__init__() + self.lemmatize = lemmatize + self.stem = stem + + def expand(self, pcoll): + return pcoll | "Normalize Tokens" >> beam.ParDo( + self._NormalizeTokensFn(lemmatize=self.lemmatize, stem=self.stem) + ) + + class _NormalizeTokensFn(beam.DoFn): + def __init__(self, lemmatize: bool, stem: bool): + """ + Initialize class for normalizing tokens using lemmatization and/or stemming. + + :param lemmatize (bool): Whether to apply lemmatization. + :param stem (bool): Whether to apply stemming. + """ + self.lemmatize = lemmatize + self.stem = stem + self.lemmatizer = WordNetLemmatizer() if lemmatize else None + self.stemmer = PorterStemmer() if stem else None + + def process(self, element: List[str]): + """ + Normalizes tokens in the input element. + + Args: + :param element (List[str]): List of tokens. + :param yield (List[str]): Normalized tokens. + """ + normalized_tokens = [] + for token in element: + word = token + if self.lemmatize: + word = self.lemmatizer.lemmatize(word) + if self.stem: + word = self.stemmer.stem(word) + normalized_tokens.append(word) + yield normalized_tokens \ No newline at end of file diff --git a/src/beamstack_transforms/preprocessing/tokenize_text.py b/src/beamstack_transforms/preprocessing/tokenize_text.py new file mode 100644 index 0000000..4f8ca07 --- /dev/null +++ b/src/beamstack_transforms/preprocessing/tokenize_text.py @@ -0,0 +1,67 @@ +import re +import apache_beam as beam +from typing import List, Optional + +class TokenizeText(beam.PTransform): + def __init__( + self, + lowercase: bool = True, + custom_delimiters: Optional[List[str]] = None, + keep_punctuation: bool = False + ): + """ + Initializes transform class for tokenizing text. 
diff --git a/src/beamstack_transforms/preprocessing/tokenize_text.py b/src/beamstack_transforms/preprocessing/tokenize_text.py
new file mode 100644
index 0000000..4f8ca07
--- /dev/null
+++ b/src/beamstack_transforms/preprocessing/tokenize_text.py
@@ -0,0 +1,67 @@
+import re
+import apache_beam as beam
+from typing import List, Optional
+
+class TokenizeText(beam.PTransform):
+    def __init__(
+        self,
+        lowercase: bool = True,
+        custom_delimiters: Optional[List[str]] = None,
+        keep_punctuation: bool = False
+    ):
+        """
+        Initializes the transform class for tokenizing text.
+
+        :param lowercase (bool): Whether to lowercase the text before tokenization.
+        :param custom_delimiters (Optional[List[str]]): Additional delimiters for tokenization.
+        :param keep_punctuation (bool): Whether to keep punctuation as separate tokens.
+        """
+        super().__init__()
+        self.lowercase = lowercase
+        self.custom_delimiters = custom_delimiters
+        self.keep_punctuation = keep_punctuation
+
+    def expand(self, pcoll):
+        return pcoll | "Tokenize Text" >> beam.ParDo(
+            self._TokenizeTextFn(self.lowercase, self.custom_delimiters, self.keep_punctuation)
+        )
+
+    class _TokenizeTextFn(beam.DoFn):
+        DEFAULT_DELIMITERS = [" ", "\n", "\t", ".", ",", "!", "?", ":", ";", "(", ")", "-", "_"]
+
+        def __init__(
+            self,
+            lowercase: bool,
+            custom_delimiters: Optional[List[str]],
+            keep_punctuation: bool
+        ):
+            """
+            Initializes the tokenization function.
+
+            :param lowercase (bool): Whether to lowercase the text before tokenization.
+            :param custom_delimiters (Optional[List[str]]): Additional delimiters for tokenization.
+            :param keep_punctuation (bool): Whether to keep punctuation as separate tokens.
+            """
+            self.lowercase = lowercase
+            self.keep_punctuation = keep_punctuation
+            self.delimiters = custom_delimiters or self.DEFAULT_DELIMITERS
+            self.pattern = self._build_regex_pattern()
+
+        def _build_regex_pattern(self) -> re.Pattern:
+            """
+            Builds a compiled regex pattern for tokenization.
+            """
+            if self.keep_punctuation:
+                return re.compile(r"(\w+|[" + re.escape("".join(self.delimiters)) + r"])")
+            return re.compile(r"|".join(map(re.escape, self.delimiters)))
+
+        def process(self, element: str):
+            """
+            Tokenizes the input text.
+
+            :param element (str): Input text.
+            :yield: A list of tokens.
+            """
+            text = element.lower() if self.lowercase else element
+            tokens = self.pattern.findall(text) if self.keep_punctuation else self.pattern.split(text)
+            yield [token for token in tokens if token]
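
Finally, an end-to-end sketch of how a Beam YAML pipeline might wire the provider above to one of these transforms. The provider entry follows the generic Beam YAML provider listing shape (`type`/`config`/`transforms`); the manifest path, URN value, and model name are placeholders, and the exact keys accepted for `BeamstackTransform` should be confirmed against `ExternalProvider.provider_from_spec`:

```yaml
# Hypothetical Beam YAML pipeline using the BeamstackTransform provider (illustrative).
pipeline:
  transforms:
    - type: Create
      config:
        elements: ["hello beamstack"]
    - type: CreateEmbeddings
      input: Create
      config:
        embed_model: all-MiniLM-L6-v2

providers:
  - type: BeamstackTransform
    config:
      path: gs://example-bucket/transforms.yaml   # local paths and GitHub URLs are also handled
    transforms:
      CreateEmbeddings: "beamstack:transform:create_embeddings:v1"
```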