|
| 1 | +import typing as ty |
| 2 | +import inspect |
| 3 | +import re |
| 4 | +from typing import dataclass_transform |
| 5 | +from . import field |
| 6 | +from .task import Task, Outputs |
| 7 | +from pydra.compose.base import ( |
| 8 | + ensure_field_objects, |
| 9 | + build_task_class, |
| 10 | + check_explicit_fields_are_none, |
| 11 | + extract_fields_from_class, |
| 12 | +) |
| 13 | + |
| 14 | + |
| 15 | +@dataclass_transform( |
| 16 | + kw_only_default=True, |
| 17 | + field_specifiers=(field.arg,), |
| 18 | +) |
| 19 | +def define( |
| 20 | + wrapped: type | ty.Callable | None = None, |
| 21 | + /, |
| 22 | + inputs: list[str | field.arg] | dict[str, field.arg | type] | None = None, |
| 23 | + outputs: list[str | field.out] | dict[str, field.out | type] | type | None = None, |
| 24 | + bases: ty.Sequence[type] = (), |
| 25 | + outputs_bases: ty.Sequence[type] = (), |
| 26 | + auto_attribs: bool = True, |
| 27 | + name: str | None = None, |
| 28 | + xor: ty.Sequence[str | None] | ty.Sequence[ty.Sequence[str | None]] = (), |
| 29 | +) -> "Task": |
| 30 | + """ |
| 31 | + Create an interface for a function or a class. |
| 32 | +
|
| 33 | + Parameters |
| 34 | + ---------- |
| 35 | + wrapped : type | callable | None |
| 36 | + The function or class to create an interface for. |
| 37 | + inputs : list[str | Arg] | dict[str, Arg | type] | None |
| 38 | + The inputs to the function or class. |
| 39 | + outputs : list[str | base.Out] | dict[str, base.Out | type] | type | None |
| 40 | + The outputs of the function or class. |
| 41 | + auto_attribs : bool |
| 42 | + Whether to use auto_attribs mode when creating the class. |
| 43 | + name: str | None |
| 44 | + The name of the returned class |
| 45 | + xor: Sequence[str | None] | Sequence[Sequence[str | None]], optional |
| 46 | + Names of args that are exclusive mutually exclusive, which must include |
| 47 | + the name of the current field. If this list includes None, then none of the |
| 48 | + fields need to be set. |
| 49 | +
|
| 50 | + Returns |
| 51 | + ------- |
| 52 | + Task |
| 53 | + The task class for the Python function |
| 54 | + """ |
| 55 | + |
| 56 | + def make(wrapped: ty.Callable | type) -> Task: |
| 57 | + if inspect.isclass(wrapped): |
| 58 | + klass = wrapped |
| 59 | + function = klass.function |
| 60 | + class_name = klass.__name__ |
| 61 | + check_explicit_fields_are_none(klass, inputs, outputs) |
| 62 | + parsed_inputs, parsed_outputs = extract_fields_from_class( |
| 63 | + Task, |
| 64 | + Outputs, |
| 65 | + klass, |
| 66 | + field.arg, |
| 67 | + field.out, |
| 68 | + auto_attribs, |
| 69 | + skip_fields=["function"], |
| 70 | + ) |
| 71 | + else: |
| 72 | + if not isinstance(wrapped, str): |
| 73 | + raise ValueError( |
| 74 | + f"wrapped must be a class or a string containing a MATLAB snipped, not {wrapped!r}" |
| 75 | + ) |
| 76 | + klass = None |
| 77 | + input_helps, output_helps = {}, {} |
| 78 | + |
| 79 | + function_name, inferred_inputs, inferred_outputs = ( |
| 80 | + parse_matlab_function( |
| 81 | + wrapped, |
| 82 | + inputs=inputs, |
| 83 | + outputs=outputs, |
| 84 | + ) |
| 85 | + ) |
| 86 | + |
| 87 | + parsed_inputs, parsed_outputs = ensure_field_objects( |
| 88 | + arg_type=field.arg, |
| 89 | + out_type=field.out, |
| 90 | + inputs=inferred_inputs, |
| 91 | + outputs=inferred_outputs, |
| 92 | + input_helps=input_helps, |
| 93 | + output_helps=output_helps, |
| 94 | + ) |
| 95 | + |
| 96 | + if name: |
| 97 | + class_name = name |
| 98 | + else: |
| 99 | + class_name = function_name |
| 100 | + class_name = re.sub(r"[^\w]", "_", class_name) |
| 101 | + if class_name[0].isdigit(): |
| 102 | + class_name = f"_{class_name}" |
| 103 | + |
| 104 | + # Add in fields from base classes |
| 105 | + parsed_inputs.update({n: getattr(Task, n) for n in Task.BASE_ATTRS}) |
| 106 | + parsed_outputs.update({n: getattr(Outputs, n) for n in Outputs.BASE_ATTRS}) |
| 107 | + |
| 108 | + function = wrapped |
| 109 | + |
| 110 | + parsed_inputs["function"] = field.arg( |
| 111 | + name="function", |
| 112 | + type=str, |
| 113 | + default=function, |
| 114 | + help=Task.FUNCTION_HELP, |
| 115 | + ) |
| 116 | + |
| 117 | + defn = build_task_class( |
| 118 | + Task, |
| 119 | + Outputs, |
| 120 | + parsed_inputs, |
| 121 | + parsed_outputs, |
| 122 | + name=class_name, |
| 123 | + klass=klass, |
| 124 | + bases=bases, |
| 125 | + outputs_bases=outputs_bases, |
| 126 | + xor=xor, |
| 127 | + ) |
| 128 | + |
| 129 | + return defn |
| 130 | + |
| 131 | + if wrapped is not None: |
| 132 | + if not isinstance(wrapped, (str, type)): |
| 133 | + raise ValueError(f"wrapped must be a class or a string, not {wrapped!r}") |
| 134 | + return make(wrapped) |
| 135 | + return make |
| 136 | + |
| 137 | + |
| 138 | +def parse_matlab_function( |
| 139 | + function: str, |
| 140 | + inputs: list[str | field.arg] | dict[str, field.arg | type] | None = None, |
| 141 | + outputs: list[str | field.out] | dict[str, field.out | type] | type | None = None, |
| 142 | +) -> tuple[str, dict[str, field.arg], dict[str, field.out]]: |
| 143 | + """ |
| 144 | + Parse a MATLAB function string to extract inputs and outputs. |
| 145 | +
|
| 146 | + Parameters |
| 147 | + ---------- |
| 148 | + function : str |
| 149 | + The MATLAB function string. |
| 150 | + inputs : list or dict, optional |
| 151 | + The inputs to the function. |
| 152 | + outputs : list or dict, optional |
| 153 | + The outputs of the function. |
| 154 | +
|
| 155 | + Returns |
| 156 | + ------- |
| 157 | + tuple |
| 158 | + A tuple containing the function name, inferred inputs, and inferred outputs. |
| 159 | + """ |
| 160 | + raise NotImplementedError |
0 commit comments