# Source code for hgraph._wiring._map

import inspect
from itertools import chain
from typing import Any, Callable, cast, TYPE_CHECKING, List

from frozendict import frozendict

from hgraph._types._scalar_types import STATE, Size, SCALAR
from hgraph._types._ts_meta_data import HgTSTypeMetaData
from hgraph._types._ts_type import TS
from hgraph._types._tsb_type import TS_SCHEMA, TS_SCHEMA_1, TSB
from hgraph._types._tss_type import TSS
from hgraph._types._ts_type_var_meta_data import HgTsTypeVarTypeMetaData
from hgraph._types._time_series_meta_data import HgTimeSeriesTypeMetaData
from hgraph._types._tsd_meta_data import HgTSDTypeMetaData
from hgraph._types._tsd_type import K
from hgraph._types._tsl_type import TSL, Size
from hgraph._types._type_meta_data import HgTypeMetaData
from hgraph._types._tsl_meta_data import HgTSLTypeMetaData
from hgraph._wiring._decorators import graph, operator
from hgraph._wiring._markers import _PassthroughMarker, _NoKeyMarker
from hgraph._wiring._wiring_node_class._map_wiring_node import (
    TsdMapWiringNodeClass,
    TsdMapWiringSignature,
    TslMapWiringSignature,
    TslMapWiringNodeClass,
)
from hgraph._wiring._wiring_node_class._wiring_node_class import WiringNodeClass
from hgraph._wiring._wiring_node_signature import WiringNodeSignature, WiringNodeType
from hgraph._wiring._wiring_port import WiringPort
from hgraph._wiring._wiring_node_class._wiring_node_class import extract_kwargs
from hgraph._wiring._wiring_context import WiringContext
from hgraph._wiring._wiring_errors import CustomMessageWiringError
from hgraph._wiring._wiring_errors import NoTimeSeriesInputsError
from hgraph._wiring._wiring_utils import stub_wiring_port, as_reference, wire_nested_graph
from hgraph._types._time_series_types import TIME_SERIES_TYPE

if TYPE_CHECKING:
    from hgraph._types._scalar_type_meta_data import HgAtomicType

__all__ = ("map_", "KEYS_ARG")

KEYS_ARG = "__keys__"
_KEY_ARG = "__key_arg__"


@operator
def map_(
    func: Callable[..., Any],
    *args: TSB[TS_SCHEMA],
    __label__: str = None,
    __keys__: TSS[K] = None,
    __key_arg__: str = None,
    **kwargs: TSB[TS_SCHEMA_1],
) -> TIME_SERIES_TYPE:
    """
    Apply a node, graph or lambda element-wise over multiplexed time-series inputs (TSD or TSL).

    ``map_`` demultiplexes its inputs, calls ``func`` once per key/index, and collects the results into a
    multiplexed output. The types of ``func``'s parameters are inferred at wiring time from the demultiplexed
    components of the supplied inputs.

    :param func: A ``@graph``/``@compute_node`` decorated function or a lambda.
        When a **TSD** input is supplied, ``func`` may take the key of the current entry as its first parameter
        (``key: TS[K]``), followed by the per-key time-series.
        When a **TSL** input is supplied, ``func`` receives the per-index time-series and may take the index as its
        first parameter (``ndx: TS[int]``).
        Lambdas require a TSD input (or ``__keys__``) from which the key type can be deduced.
    :param args: Time-series inputs to demultiplex and pass to ``func``.
    :param __label__: Optional label for debugging/tracing.
    :param __keys__: Optional explicit key set (TSD maps only), see below.
    :param __key_arg__: Optional name of the parameter of ``func`` that receives the key, see below.
    :param kwargs: Named time-series inputs to demultiplex and pass to ``func``.
    :returns: A ``WiringPort`` representing the multiplexed output (a TSD for TSD maps, a TSL for TSL maps),
        or ``None`` when ``func`` produces no output (e.g. it is a sink node).

    **TSD extensions:**

    * ``__keys__: TSS[K]`` — explicit key set for demultiplexing. When omitted, the keys are the union of the
      key sets of the multiplexed inputs that are not marked ``no_key()``.
    * ``__key_arg__: str`` — name of the parameter that receives the key; it must be the first parameter of
      ``func``. When omitted, a first parameter named ``key`` (TSD) or ``ndx`` (TSL) is used as the key
      argument (lambdas default to ``key``). For decorated functions, ``""`` disables key-argument detection.
    * Wrap an input with ``no_key()`` to exclude it from key-set inference.
    * Wrap an input with ``pass_through()`` to pass it directly without demultiplexing.

    **Example — map over a TSD:**

    ::

        config: TSD[str, TSB[Config]] = ...
        # func receives (key: TS[str], c: TSB[Config]) per entry
        map_(lambda key, c: publish_multitable("data", key, random_values(c)), config)

    **Example — map a decorated node:**

    ::

        @compute_node
        def process(key: TS[str], value: TS[float]) -> TS[float]:
            ...

        map_(process, my_tsd)
    """
    ...
@graph(overloads=map_,resolvers={K: lambda m: object}) def map_default( func: Callable[..., Any], *args: TSB[TS_SCHEMA], __label__: str = None, __keys__: TSS[K] = None, __key_arg__: str = None, **kwargs: TSB[TS_SCHEMA_1] ) -> TIME_SERIES_TYPE: args = tuple(i.value if i.is_auto_const else i for i in args.as_dict().values()) if args else () kwargs = {k: v.value if v.is_auto_const else v for k, v in kwargs.as_dict().items()} if kwargs else {} if len(args) + len(kwargs) == 0 and __keys__ is None: raise NoTimeSeriesInputsError() from inspect import isfunction if isinstance(func, WiringNodeClass): with WiringContext(current_signature=STATE(signature=f"map_('{func.signature.signature}', ...)")): signature: WiringNodeSignature = func.signature return _build_and_wire_map(func, signature, *args, **kwargs, __keys__=__keys__, __key_arg__=__key_arg__, __label__=__label__) elif isfunction(func) and func.__name__ == "<lambda>": graph = _deduce_signature_from_lambda_and_args(func, *args, __keys__=__keys__, __key_arg__=__key_arg__, **kwargs) signature: WiringNodeSignature = graph.signature with WiringContext(current_signature=STATE(signature=f"map_('{signature.signature}', ...)")): return _build_and_wire_map(graph, signature, *args, __keys__=__keys__, __key_arg__=__key_arg__, **kwargs, __label__=__label__) else: raise RuntimeError(f"The supplied function is not a graph or node function or lambda: '{func.__name__}'") def _deduce_signature_from_lambda_and_args(func, *args, __keys__=None, __key_arg__=None, **kwargs) -> WiringNodeClass: """ A lambda was provided for map_ so it will not have a signature to be used. This function will try to work out the signature from the names of the lambda arguments and the incoming arguments and their types. 
The logic here duplicates a little what is found in the _build_map_wiring_node_and_inputs and friends but it is essentially the inside out of it """ from inspect import signature, Parameter sig = signature(func) input_has_key_arg = False input_key_tp = None input_key_name = __key_arg__ or "key" # 1. First figure out what is the type of the keys if __keys__ is not None: key_set = __keys__ else: key_set = next( ( i.key_set for i in chain(args, kwargs.values()) if isinstance(i.output_type.dereference(), HgTSDTypeMetaData) ), None, ) if key_set is None: raise CustomMessageWiringError(f"No multiplexed inputs found when deducing the signature of the lambda {func} defined at {inspect.getfile(func)}:{inspect.getsourcelines(func)[1]} passed into map_") key_type = key_set.output_type.dereference().value_scalar_tp # 2. Put together annotations for the lambda from the parameter types annotations = {} values = {} for i, (n, p) in enumerate(sig.parameters.items()): if p.kind == Parameter.VAR_KEYWORD: raise CustomMessageWiringError("lambdas with variable keyword argument list are not supported for map_()") if p.kind == Parameter.VAR_POSITIONAL: var_args = args[i:] var_types = [] for j, arg in enumerate(var_args): if isinstance(arg, WiringPort): tp = arg.output_type.dereference() if ( isinstance(tp, HgTSDTypeMetaData) and key_type.matches(tp.key_tp) and _PassthroughMarker not in (args[i].markers or ()) ): var_types.append(tp.value_tp) else: var_types.append(tp) else: var_types.append(HgTypeMetaData.parse_type(SCALAR)) var_types = [t for t in var_types if isinstance(t, HgTimeSeriesTypeMetaData)] var_types = [t for t in var_types if all(t.matches(v) for v in var_types)] if len(var_types) == 0: raise CustomMessageWiringError(f"Could not deduce type for {n} from [{','.join(map(str, var_types))}]") annotations[n] = HgTypeMetaData.parse_type(TSL[var_types[0], Size[len(var_args)]]) values[n] = [v for v in var_args] i = len(args) continue if i == 0: if n == input_key_name: # this is the 
key input input_has_key_arg = True input_key_tp = HgTimeSeriesTypeMetaData.parse_type(TS[key_type.py_type]) annotations[input_key_name] = input_key_tp values[input_key_name] = key_set continue if input_has_key_arg: i -= 1 if i < len(args): # provided as positional and not key if isinstance(args[i], WiringPort): tp = args[i].output_type.dereference() if ( isinstance(tp, HgTSDTypeMetaData) and key_type.matches(tp.key_tp) and _PassthroughMarker not in (args[i].markers or ()) ): annotations[n] = tp.value_tp else: annotations[n] = tp else: annotations[n] = HgTypeMetaData.parse_type(SCALAR) values[n] = args[i] continue if n in kwargs: # provided as keyword if isinstance(kwargs[n], WiringPort): tp = kwargs[n].output_type.dereference() if ( isinstance(tp, HgTSDTypeMetaData) and key_type.matches(tp.key_tp) and _PassthroughMarker not in (kwargs[n].markers or ()) ): annotations[n] = tp.value_tp else: annotations[n] = tp else: annotations[n] = HgTypeMetaData.parse_type(SCALAR) values[n] = kwargs[n] continue raise CustomMessageWiringError(f"no input for the parameter {n} of the lambda passed into map_") if (unused := kwargs.keys() - sig.parameters.keys()) != set(): raise CustomMessageWiringError(f"keyword arguments {unused} are not used in the lambda signature") if i + 1 < len(args): raise CustomMessageWiringError(f"{len(args) - 2} of positional arguments not used in the lambda signature") # 3. 
now we have annotations for the parameters of the lambda the only way to figure out the output type is to # try to wire it inputs_ = {} for k, v in annotations.items(): if v.is_scalar: inputs_[k] = values[k] else: from hgraph import create_input_stub inputs_[k] = create_input_stub(k, cast(HgTimeSeriesTypeMetaData, v), k == input_key_name) from hgraph import WiringGraphContext from hgraph import with_signature, graph with WiringGraphContext(None, temporary=True) as context: f = graph(with_signature(func, annotations=annotations, return_annotation=TIME_SERIES_TYPE)) out = f(**inputs_) if out is not None: output_type = out.output_type else: output_type = None # 4. Now create a graph with the signature we worked out and return return graph(with_signature(func, annotations=annotations, return_annotation=output_type)) def _build_map_wiring( fn: Callable, signature: WiringNodeSignature, *args, __keys__=None, __key_arg__=None, **kwargs ) -> (WiringNodeClass, dict, list, list): """ Build the maps wiring signature. This will process the inputs looking to work out which are multiplexed inputs, which are pass-through, etc. It will perform basic validation that will ensure the signature of the mapped function and the inputs provided are compatible. """ # 1. See if the first argument of the signature is a key argument. # A key argument has a name of either 'key' (for TSD) or 'ndx' (for TSL) # The key is a TS[SCALAR] for TSD and TS[int] for TSL. input_has_key_arg, input_key_name, input_key_tp = _extract_map_fn_key_arg_and_type(signature, __key_arg__) # 2. Now we can safely extract the kwargs. kwargs_ = extract_kwargs( signature, *args, _ensure_match=False, _args_offset=1 if input_has_key_arg else 0, **kwargs ) if signature.var_arg: # explode var args into individual arguments for i, arg in enumerate(kwargs_[signature.var_arg]): kwargs_[f"{signature.var_arg}-{i}"] = arg kwargs_.pop(signature.var_arg) # 3. 
Split out the inputs into multiplexed, no_key, pass_through and direct and key_tp multiplex_args, no_key_args, pass_through_args, _, map_type, key_tp_ = _split_inputs(signature, kwargs_, __keys__) # 4. If the key is present, make sure the extracted key type matches what we found in the multiplexed inputs. if map_type == "TSL": tp = HgTSTypeMetaData.parse_type(TS[int]) else: tp = key_tp_ if input_has_key_arg and not input_key_tp.matches(tp): raise CustomMessageWiringError(f"The ndx argument '{signature.args[0]}: {input_key_tp}' does not match '{tp}'") input_key_tp = tp # 5. Extract provided key signature # We use the output_type of wiring ports, but for scalar values, they must take the form of the underlying # function signature, so we just use from that signature. input_types = { k: v.output_type.dereference() if isinstance(v, WiringPort) else signature.input_types[k] for k, v in kwargs_.items() } # 6. Create the wiring nodes for the map function. match map_type: case "TSD": if __keys__ is not None: kwargs_[KEYS_ARG] = __keys__ else: if len(multiplex_args) > 1: from hgraph import union __keys__ = union(*tuple(kwargs_[k].key_set for k in multiplex_args if k not in no_key_args)) else: __keys__ = kwargs_[next(iter(multiplex_args))].key_set kwargs_[KEYS_ARG] = __keys__ input_types = input_types | {KEYS_ARG: __keys__.output_type.dereference()} map_wiring_node, ri = _create_tsd_map_wiring_node( fn, kwargs_, input_types, multiplex_args, no_key_args, input_key_tp, input_key_name if input_has_key_arg else None, ) case "TSL": from hgraph._types._scalar_type_meta_data import HgAtomicType map_wiring_node, ri = _create_tsl_map_signature( fn, kwargs_, input_types, multiplex_args, HgAtomicType.parse_type(key_tp_), input_key_name if input_has_key_arg else None, ) case _: raise CustomMessageWiringError(f"Unable to determine map type for given inputs: {kwargs_}") return map_wiring_node, kwargs_, ri def _build_and_wire_map( fn: Callable, signature: WiringNodeSignature, *args, 
__keys__=None, __key_arg__=None, __label__=None, **kwargs ) -> WiringPort: map_wiring_node, kwargs_, ri = _build_map_wiring( fn, signature, *args, __keys__=__keys__, __key_arg__=__key_arg__, **kwargs ) port = map_wiring_node(**kwargs_, __return_sink_wp__=True, __label__=__label__) from hgraph import WiringGraphContext WiringGraphContext.instance().reassign_items(ri, port.node_instance) return port if port.output_type else None def _extract_map_fn_key_arg_and_type( signature: WiringNodeSignature, __key_arg__ ) -> tuple[bool, str | None, HgTSTypeMetaData | None]: """ Attempt to detect if the mapping fn has a key argument and if so, what is the type of the key is. """ input_has_key_arg = False input_key_tp = None input_key_name = __key_arg__ if input_key_name == "": # If the user supplied an emtpy string for _key_arg, interpret as ignore any input named as key / ndx # and that no key arg is present. return False, None, None if input_key_name: input_has_key_arg = True if input_key_name != signature.args[0]: raise CustomMessageWiringError( f"The key argument '{input_key_name}' is not the first argument of the function:" f" '{signature.signature}'" ) input_key_tp = signature.input_types[input_key_name] elif signature.args[0] in ("key", "ndx"): from hgraph._types._time_series_meta_data import HgTimeSeriesTypeMetaData input_key_tp = signature.input_types[signature.args[0]] match_tp = None match signature.args[0]: case "key": input_key_name = "key" match_tp = HgTimeSeriesTypeMetaData.parse_type(TS[K]) case "ndx": input_key_name = "ndx" match_tp = HgTimeSeriesTypeMetaData.parse_type(TS[int]) if not (input_has_key_arg := (match_tp and match_tp.matches(signature.input_types[signature.args[0]]))): if match_tp: raise CustomMessageWiringError( f"The key argument '{signature.args[0]}: {signature.input_types[signature.args[0]]}' " f"does not match the expected type: '{match_tp}'" ) return input_has_key_arg, input_key_name, cast(HgTSTypeMetaData, input_key_tp) def _split_inputs( 
signature: WiringNodeSignature, kwargs_, tsd_keys ) -> tuple[frozenset[str], frozenset[str], frozenset[str], frozenset[str], str, HgTimeSeriesTypeMetaData]: # multiplex, no_key passthrough, direct, tp, key_tp """ Splits out the inputs into three groups: #. multiplex_args: These are the inputs that need to be de-multiplexed. #. no_key_args: These are the inputs that are marked as pass through or no key. #. pass_through_args: These are the inputs that are marked as pass through. #. direct_args: These are the inputs that match the signature of the underlying signature. This will also validate that the inputs are correctly typed if requested to do so, for the map_ function it is useful to by-pass some of the checks are it is really only interested in guessing the correct map type. Key type is only present if validate_type is True. """ if non_ts_inputs := [ arg for arg in kwargs_ if not isinstance(kwargs_[arg], WiringPort) and not arg == signature.var_arg ]: if not all(k in signature.scalar_inputs for k in non_ts_inputs): raise CustomMessageWiringError( f" The following args are not time-series inputs, but should be: {non_ts_inputs}" ) if signature.var_arg and signature.var_arg in kwargs_: for i, arg in enumerate(kwargs_[signature.var_arg]): kwargs_[signature.var_arg + f"-{i}"] = arg kwargs_.pop(signature.var_arg) no_key_args = frozenset(arg for arg in kwargs_ if isinstance(kwargs_[arg], WiringPort) and _NoKeyMarker in (kwargs_[arg].markers or ())) pass_through_args = frozenset(arg for arg in kwargs_ if isinstance(kwargs_[arg], WiringPort) and _PassthroughMarker in (kwargs_[arg].markers or ())) _validate_pass_through(signature, kwargs_, pass_through_args) # Ensure the pass through args are correctly typed. 
input_types = {k: v.output_type.dereference() for k, v in kwargs_.items() if k not in non_ts_inputs} signature_types = { k: ( signature.input_types[k] if k.split("-")[0] != signature.var_arg else signature.input_types[signature.var_arg].value_tp ) for k in input_types } # Figure out if the map is done over a TSD or TSL by finding the first miltiplexing input map_type = None multiplex_type = None if tsd_keys is not None: # corner case where there are no other inputs but explicitly provided keys map_type = "TSD" multiplex_type = HgTSDTypeMetaData elif input_types: for k, v in input_types.items(): if k not in (no_key_args | pass_through_args) and type(v_tp := v.dereference()) in (HgTSDTypeMetaData, HgTSLTypeMetaData): sig_tp = signature_types[k] if sig_tp.matches(v) and not isinstance(sig_tp, HgTsTypeVarTypeMetaData): # not multiplexing continue if isinstance(v_tp, HgTSDTypeMetaData) and sig_tp.matches(v_tp.value_tp): map_type = "TSD" multiplex_type = HgTSDTypeMetaData break elif isinstance(v_tp, HgTSLTypeMetaData) and sig_tp.matches(v_tp.value_tp): map_type = "TSL" multiplex_type = HgTSLTypeMetaData break else: raise CustomMessageWiringError( f"parameter {k}:{sig_tp} of the mapped graph does not match the input type {v_tp.py_type}" "for either direct match or multiplexing" ) if map_type is None: raise CustomMessageWiringError( f"failed to determine the type of mapping over {signature} with parameters " f"of {','.join(f'{k}:{v}' for k, v in input_types.items())}" ) direct_args = frozenset( k for k, v in input_types.items() if k not in (no_key_args | pass_through_args) and signature_types[k].matches(v) if # (type(signature.input_types[k]) is not HgTsTypeVarTypeMetaData and # All time-series value match this! 
(type(v) is not multiplex_type) ) # So if it is possibly not direct, don't mark it direct multiplex_args = frozenset( k for k, v in input_types.items() if k not in pass_through_args and k not in direct_args and type(v) is multiplex_type ) _validate_multiplex_types(signature_types, kwargs_, multiplex_args, no_key_args) if not tsd_keys and (len(no_key_args) + len(multiplex_args) == 0): raise CustomMessageWiringError(f"No multiplexed inputs found for TSD map over {signature} with parameters {input_types}") if len(multiplex_args) + len(direct_args) + len(pass_through_args) + len(non_ts_inputs) != len(kwargs_): raise CustomMessageWiringError(f"Unable to determine how to split inputs with args:\n {kwargs_}") if map_type == "TSL": key_tp = _extract_tsl_size(kwargs_, multiplex_args, no_key_args) else: key_tp = _validate_tsd_keys(kwargs_, multiplex_args, no_key_args, tsd_keys) return ( multiplex_args, no_key_args, pass_through_args, direct_args, map_type, key_tp if map_type == "TSL" else HgTSTypeMetaData(key_tp), ) def _prepare_stub_inputs( kwargs_: dict[str, WiringPort | SCALAR], input_types: dict[str, HgTypeMetaData], multiplex_args: frozenset[str], no_key_args: frozenset[str], input_key_tp: HgTSTypeMetaData, input_key_name: str | None, ): call_kwargs = {} for key, arg in input_types.items(): if key in multiplex_args or key in no_key_args: arg: HgTSDTypeMetaData | HgTSLTypeMetaData call_kwargs[key] = stub_wiring_port(arg.value_tp) elif key == KEYS_ARG: continue elif arg.is_scalar: call_kwargs[key] = kwargs_[key] else: call_kwargs[key] = stub_wiring_port(arg) if input_key_name: call_kwargs[input_key_name] = stub_wiring_port(input_key_tp) return call_kwargs def _create_tsd_map_wiring_node( fn: WiringNodeClass, kwargs_: dict[str, WiringPort | SCALAR], input_types: dict[str, HgTypeMetaData], multiplex_args: frozenset[str], no_key_args: frozenset[str], input_key_tp: HgTSTypeMetaData, input_key_name: str | None, ) -> [TsdMapWiringNodeClass, tuple]: # Resolve the mapped function 
signature stub_inputs = _prepare_stub_inputs(kwargs_, input_types, multiplex_args, no_key_args, input_key_tp, input_key_name) resolved_signature = fn.resolve_signature(**stub_inputs) reference_inputs = frozendict({ k: as_reference(v, k in multiplex_args) if isinstance(v, HgTimeSeriesTypeMetaData) and k != KEYS_ARG else v for k, v in input_types.items() }) # NOTE: The wrapper node does not need to set it valid and tick to that of the underlying node, it just # needs to ensure that it gets notified when the key sets tick. Likewise with validity. # Build provisional signature first so we can pass it in as context into inner graph wiring provisional_signature = WiringNodeSignature( node_type=WiringNodeType.COMPUTE_NODE if resolved_signature.output_type else WiringNodeType.SINK_NODE, name="map", # All actual inputs are encoded in the input_types, so we just need to add the keys if present. args=tuple(input_types.keys()), defaults=frozendict(), # Defaults would have already been applied. input_types=reference_inputs, output_type=( HgTSDTypeMetaData(input_key_tp.value_scalar_tp, resolved_signature.output_type.dereference().as_reference()) if resolved_signature.output_type else None ), src_location=resolved_signature.src_location, # TODO: Figure out something better for this. active_inputs=None, # We will follow a copy approach to transfer the inputs to inner graphs valid_inputs=frozenset({ KEYS_ARG, }), # We have constructed the map so that the key are is always present. 
all_valid_inputs=None, context_inputs=None, unresolved_args=frozenset(), time_series_args=frozenset(k for k, v in input_types.items() if not v.is_scalar), # label=f"map('{resolved_signature.signature}', {', '.join(input_types.keys())})", has_nested_graphs=True, ) if resolved_signature.var_arg: resolved_input_types = dict(resolved_signature.input_types) for i in range(resolved_signature.input_types[resolved_signature.var_arg].size.SIZE): resolved_input_types[f"{resolved_signature.var_arg}-{i}"] = resolved_signature.input_types[ resolved_signature.var_arg ].value_tp resolved_input_types.pop(resolved_signature.var_arg) else: resolved_input_types = resolved_signature.input_types graph, ri = wire_nested_graph( fn, resolved_input_types, { k: kwargs_[k] for k, v in resolved_input_types.items() if not isinstance(v, HgTimeSeriesTypeMetaData) and k != KEYS_ARG }, provisional_signature, input_key_name, depth=2, ) map_signature = TsdMapWiringSignature( **provisional_signature.as_dict(), map_fn_signature=resolved_signature, key_tp=input_key_tp.value_scalar_tp, key_arg=input_key_name, multiplexed_args=multiplex_args, inner_graph=graph, ) wiring_node = TsdMapWiringNodeClass(map_signature, fn) return wiring_node, ri def _create_tsl_map_signature( fn: WiringNodeClass, kwargs_: dict[str, WiringPort | SCALAR], input_types: dict[str, HgTypeMetaData], multiplex_args: frozenset[str], size_tp: "HgAtomicType", input_key_name: str | None, ): # Resolve the mapped function signature stub_inputs = _prepare_stub_inputs( kwargs_, input_types, multiplex_args, frozenset(), HgTSTypeMetaData.parse_type(TS[int]), input_key_name ) resolved_signature = fn.resolve_signature(**stub_inputs) reference_inputs = frozendict({ k: as_reference(v, k in multiplex_args) if isinstance(v, HgTimeSeriesTypeMetaData) else v for k, v in input_types.items() }) # Build provisional signature first so we can pass it in as context into inner graph wiring provisional_signature = WiringNodeSignature( 
node_type=WiringNodeType.COMPUTE_NODE if resolved_signature.output_type else WiringNodeType.SINK_NODE, name="map", # All actual inputs are encoded in the input_types, so we just need to add the keys if present. args=tuple(input_types.keys()), defaults=frozendict(), # Defaults would have already been applied. input_types=frozendict(reference_inputs), output_type=( HgTSLTypeMetaData(resolved_signature.output_type.as_reference(), size_tp) if resolved_signature.output_type else None ), src_location=resolved_signature.src_location, # TODO: Figure out something better for this. active_inputs=frozenset(), valid_inputs=frozenset(), all_valid_inputs=None, context_inputs=None, unresolved_args=frozenset(), time_series_args=frozenset(k for k, v in input_types.items() if not v.is_scalar), label=f"map('{resolved_signature.signature}', {', '.join(input_types.keys())})", has_nested_graphs=True, ) graph, reassignables = wire_nested_graph( fn, resolved_signature.input_types, { k: kwargs_[k] for k, v in resolved_signature.input_types.items() if not isinstance(v, HgTimeSeriesTypeMetaData) and k != KEYS_ARG }, provisional_signature, input_key_name, depth=2, ) map_signature = TslMapWiringSignature( **provisional_signature.as_dict(), map_fn_signature=resolved_signature, size_tp=size_tp, key_arg=input_key_name, multiplexed_args=multiplex_args, inner_graph=graph, ) wiring_node = TslMapWiringNodeClass(map_signature, fn) return wiring_node, reassignables def _validate_tsd_keys(kwargs_, multiplex_args, no_key_args, tsd_keys): """ Ensure all the multiplexed inputs use the same input key. 
""" types = set(kwargs_[arg].output_type.dereference().key_tp for arg in chain(multiplex_args, no_key_args)) if tsd_keys: types.add(tsd_keys.output_type.dereference().value_scalar_tp) if len(types) > 1: raise CustomMessageWiringError(f"The TSD multiplexed inputs have different key types: {types}") return next(iter(types)) def _validate_pass_through(signature: WiringNodeSignature, kwargs_, pass_through_args): """ Validates that the pass-through inputs are valid. """ for arg in pass_through_args: if isinstance(kwargs_[arg], WiringPort) and kwargs_[arg].markers and _PassthroughMarker in kwargs_[arg].markers: if not (in_type := signature.input_types[arg]).matches(kwargs_[arg].output_type.dereference()): raise CustomMessageWiringError( f"The input '{arg}: {kwargs_[arg].output_type}' is marked as pass_through," f"but is not compatible with the input type: {in_type}" ) def _extract_tsl_size(kwargs_: dict[str, WiringPort], multiplex_args, marker_args) -> type[Size]: """ With a TSL multiplexed input, we need to determine the size of the output. This is done by looking at all the inputs that could be multiplexed. """ sizes: List[type[Size]] = [ cast(type[Size], cast(HgTSLTypeMetaData, kwargs_[arg].output_type).size_tp.py_type) for arg in chain( multiplex_args, (m_arg for m_arg in marker_args if not (kwargs_[m_arg].markers and _PassthroughMarker in kwargs_[m_arg].markers)) ) ] size: type[Size] = Size for sz in sizes: if sz.FIXED_SIZE: if size.FIXED_SIZE: size = size if size.SIZE < sz.SIZE else sz else: size = sz return size def _validate_multiplex_types(signature_types, kwargs_, multiplex_args, no_key_args): """ Validates that the multiplexed inputs are valid. 
""" for arg in chain(multiplex_args, no_key_args): if not (in_type := signature_types[arg].dereference()).matches( (m_type := kwargs_[arg].output_type.dereference()).value_tp ): raise CustomMessageWiringError( f"The input '{arg}: {m_type}' is a multiplexed type, but its value type '{m_type.value_tp}' is not" f" compatible with the input type of the inner graph: {in_type}" )