import inspect
from itertools import chain
from typing import Any, Callable, cast, TYPE_CHECKING, List
from frozendict import frozendict
from hgraph._types._scalar_types import STATE, Size, SCALAR
from hgraph._types._ts_meta_data import HgTSTypeMetaData
from hgraph._types._ts_type import TS
from hgraph._types._tsb_type import TS_SCHEMA, TS_SCHEMA_1, TSB
from hgraph._types._tss_type import TSS
from hgraph._types._ts_type_var_meta_data import HgTsTypeVarTypeMetaData
from hgraph._types._time_series_meta_data import HgTimeSeriesTypeMetaData
from hgraph._types._tsd_meta_data import HgTSDTypeMetaData
from hgraph._types._tsd_type import K
from hgraph._types._tsl_type import TSL, Size
from hgraph._types._type_meta_data import HgTypeMetaData
from hgraph._types._tsl_meta_data import HgTSLTypeMetaData
from hgraph._wiring._decorators import graph, operator
from hgraph._wiring._markers import _PassthroughMarker, _NoKeyMarker
from hgraph._wiring._wiring_node_class._map_wiring_node import (
TsdMapWiringNodeClass,
TsdMapWiringSignature,
TslMapWiringSignature,
TslMapWiringNodeClass,
)
from hgraph._wiring._wiring_node_class._wiring_node_class import WiringNodeClass
from hgraph._wiring._wiring_node_signature import WiringNodeSignature, WiringNodeType
from hgraph._wiring._wiring_port import WiringPort
from hgraph._wiring._wiring_node_class._wiring_node_class import extract_kwargs
from hgraph._wiring._wiring_context import WiringContext
from hgraph._wiring._wiring_errors import CustomMessageWiringError
from hgraph._wiring._wiring_errors import NoTimeSeriesInputsError
from hgraph._wiring._wiring_utils import stub_wiring_port, as_reference, wire_nested_graph
from hgraph._types._time_series_types import TIME_SERIES_TYPE
if TYPE_CHECKING:
from hgraph._types._scalar_type_meta_data import HgAtomicType
__all__ = ("map_", "KEYS_ARG")
# Name of the ``map_`` keyword carrying an explicit key set; the same name is used as the input name under
# which the key set is handed to the TSD map wiring node.
KEYS_ARG = "__keys__"
# Name of the ``map_`` keyword that selects which parameter of the mapped function receives the key.
_KEY_ARG = "__key_arg__"
# Public ``map_`` operator (signature and documentation only).
@operator
def map_(
    func: Callable[..., Any],
    *args: TSB[TS_SCHEMA],
    __label__: str = None,
    __keys__: TSS[K] = None,
    __key_arg__: str = None,
    **kwargs: TSB[TS_SCHEMA_1]
) -> TIME_SERIES_TYPE:
    """
    Apply a node or lambda element-wise over multiplexed time-series inputs (TSD or TSL).

    ``map_`` demultiplexes its inputs, calls ``func`` once per key/index, and collects
    the results into a multiplexed output. The types of ``func``'s parameters are
    inferred at wiring time from the demultiplexed components of the supplied inputs.

    :param func: A ``@graph``/``@compute_node`` decorated function or a lambda.
        When a **TSD** input is supplied, ``func`` receives ``(key, value, ...)`` where
        ``key`` is the dictionary key and ``value`` is the per-key time-series.
        When a **TSL** input is supplied, ``func`` receives the per-index time-series.
        A first parameter named ``key`` (TSD) or ``ndx`` (TSL) is treated as the key/index input.
    :param args: Time-series inputs to demultiplex and pass to ``func``.
    :param __label__: Optional label for debugging/tracing.
    :param __keys__: Optional explicit key set (TSD maps only), see below.
    :param __key_arg__: Optional name of the parameter of ``func`` that receives the key, see below.
    :param kwargs: Named time-series inputs to demultiplex and pass to ``func``.
    :returns: A ``WiringPort`` representing the multiplexed output. The default implementation
        returns ``None`` when ``func`` has no output.

    **TSD extensions:**

    * ``__keys__: TSS[SCALAR]`` — explicit key set for demultiplexing.
    * ``__key_arg__: str = 'key'`` — name of the parameter that receives the key. It must be the
      first parameter of ``func``. Passing an empty string states that ``func`` takes no key
      argument, even if its first parameter happens to be called ``key`` or ``ndx``.
    * Wrap an input with ``no_key()`` to exclude it from key-set inference.
    * Wrap an input with ``pass_through()`` to pass it directly without demultiplexing.

    **Example — map over a TSD:**

    ::

        config: TSD[str, TSB[Config]] = ...
        # func receives (key: TS[str], c: TSB[Config]) per entry
        map_(lambda key, c: publish_multitable("data", key, random_values(c)), config)

    **Example — map a decorated node:**

    ::

        @compute_node
        def process(key: TS[str], value: TS[float]) -> TS[float]: ...

        map_(process, my_tsd)
    """
    ...
@graph(overloads=map_,resolvers={K: lambda m: object})
def map_default(
    func: Callable[..., Any],
    *args: TSB[TS_SCHEMA],
    __label__: str = None,
    __keys__: TSS[K] = None,
    __key_arg__: str = None,
    **kwargs: TSB[TS_SCHEMA_1]
) -> TIME_SERIES_TYPE:
    """
    Default implementation of ``map_``.

    Unwraps auto-constant inputs back into their scalar values, then wires the map over either a wiring
    node (``@graph`` / ``@compute_node`` etc.) or a lambda. For a lambda the signature is first deduced from the
    supplied inputs.

    :raises NoTimeSeriesInputsError: when neither inputs nor ``__keys__`` are supplied.
    :raises RuntimeError: when ``func`` is neither a wiring node nor a lambda.
    """
    # Inputs arrive bundled; scalars were auto-wrapped as constants, so unwrap them again.
    args = tuple(i.value if i.is_auto_const else i for i in args.as_dict().values()) if args else ()
    kwargs = {k: v.value if v.is_auto_const else v for k, v in kwargs.as_dict().items()} if kwargs else {}

    if not args and not kwargs and __keys__ is None:
        raise NoTimeSeriesInputsError()

    if isinstance(func, WiringNodeClass):
        with WiringContext(current_signature=STATE(signature=f"map_('{func.signature.signature}', ...)")):
            signature: WiringNodeSignature = func.signature
            return _build_and_wire_map(
                func, signature, *args, **kwargs, __keys__=__keys__, __key_arg__=__key_arg__, __label__=__label__
            )

    if inspect.isfunction(func) and func.__name__ == "<lambda>":
        # Named ``lambda_graph`` (not ``graph``) so the imported ``graph`` decorator is not shadowed.
        lambda_graph = _deduce_signature_from_lambda_and_args(
            func, *args, __keys__=__keys__, __key_arg__=__key_arg__, **kwargs
        )
        signature: WiringNodeSignature = lambda_graph.signature
        with WiringContext(current_signature=STATE(signature=f"map_('{signature.signature}', ...)")):
            return _build_and_wire_map(
                lambda_graph,
                signature,
                *args,
                __keys__=__keys__,
                __key_arg__=__key_arg__,
                **kwargs,
                __label__=__label__,
            )

    # Not every callable has ``__name__`` (e.g. ``functools.partial``); fall back to ``repr`` so that the
    # intended error is raised rather than an AttributeError.
    func_name = getattr(func, "__name__", repr(func))
    raise RuntimeError(f"The supplied function is not a graph or node function or lambda: '{func_name}'")
def _deduce_signature_from_lambda_and_args(func, *args, __keys__=None, __key_arg__=None, **kwargs) -> WiringNodeClass:
    """
    Build a graph for a lambda passed to ``map_`` by deducing its signature from the supplied inputs.

    A lambda carries no annotations, so the parameter types are worked out from the names of the lambda
    parameters and the types of the incoming arguments. This is roughly the inverse of what
    ``_build_map_wiring`` and friends do. The output type is discovered by wiring the lambda once inside a
    temporary wiring context.

    :param func: The lambda to wrap.
    :param args: Positional inputs supplied to ``map_`` (wiring ports or scalars). The key is never part of these.
    :param __keys__: Optional explicit key set; when absent the key set of the first TSD input is used.
    :param __key_arg__: Name of the lambda parameter that receives the key (defaults to ``"key"``). It is only
        recognised as the key when it is the first parameter.
    :param kwargs: Keyword inputs supplied to ``map_``.
    :returns: The lambda wrapped as a graph with the deduced annotations.
    :raises CustomMessageWiringError: when no key set can be found, a parameter has no input, inputs are left
        unused, or the lambda declares ``**kwargs``.
    """
    from inspect import signature, Parameter

    sig = signature(func)
    input_has_key_arg = False
    input_key_name = __key_arg__ or "key"

    # 1. First figure out what is the type of the keys
    if __keys__ is not None:
        key_set = __keys__
    else:
        key_set = next(
            (
                i.key_set
                for i in chain(args, kwargs.values())
                # Scalars may be mixed in with the inputs; only wiring ports have an output type.
                if isinstance(i, WiringPort) and isinstance(i.output_type.dereference(), HgTSDTypeMetaData)
            ),
            None,
        )

    if key_set is None:
        raise CustomMessageWiringError(
            f"No multiplexed inputs found when deducing the signature of the lambda {func} defined at "
            f"{inspect.getfile(func)}:{inspect.getsourcelines(func)[1]} passed into map_"
        )

    key_type = key_set.output_type.dereference().value_scalar_tp

    def _deduce_type(value):
        """Type the lambda sees for ``value``: the TSD value type when demultiplexed, else the input's own type."""
        if not isinstance(value, WiringPort):
            return HgTypeMetaData.parse_type(SCALAR)
        tp = value.output_type.dereference()
        if (
            isinstance(tp, HgTSDTypeMetaData)
            and key_type.matches(tp.key_tp)
            and _PassthroughMarker not in (value.markers or ())
        ):
            return tp.value_tp
        return tp

    # 2. Put together annotations for the lambda from the parameter types
    annotations = {}
    values = {}
    consumed = 0  # Number of positional inputs that have been bound to a lambda parameter.
    for i, (n, p) in enumerate(sig.parameters.items()):
        if p.kind == Parameter.VAR_KEYWORD:
            raise CustomMessageWiringError("lambdas with variable keyword argument list are not supported for map_()")

        # The key is not part of ``args``; when the lambda takes a key, parameter i maps to input i - 1.
        pos = i - 1 if input_has_key_arg else i

        if p.kind == Parameter.VAR_POSITIONAL:
            var_args = args[pos:]
            candidate_types = [_deduce_type(arg) for arg in var_args]
            ts_types = [t for t in candidate_types if isinstance(t, HgTimeSeriesTypeMetaData)]
            common_types = [t for t in ts_types if all(t.matches(v) for v in ts_types)]
            if len(common_types) == 0:
                # Report the candidates that were considered (not the empty filtered list).
                raise CustomMessageWiringError(
                    f"Could not deduce type for {n} from [{','.join(map(str, candidate_types))}]"
                )
            annotations[n] = HgTypeMetaData.parse_type(TSL[common_types[0], Size[len(var_args)]])
            values[n] = list(var_args)
            consumed = len(args)  # *args swallows every remaining positional input
            continue

        if i == 0 and n == input_key_name:  # this is the key input
            input_has_key_arg = True
            input_key_tp = HgTimeSeriesTypeMetaData.parse_type(TS[key_type.py_type])
            annotations[input_key_name] = input_key_tp
            values[input_key_name] = key_set
            continue

        # Keyword-only parameters can never be satisfied by a positional input.
        if p.kind != Parameter.KEYWORD_ONLY and pos < len(args):  # provided as positional and not key
            annotations[n] = _deduce_type(args[pos])
            values[n] = args[pos]
            consumed = pos + 1
            continue

        if n in kwargs:  # provided as keyword
            annotations[n] = _deduce_type(kwargs[n])
            values[n] = kwargs[n]
            continue

        raise CustomMessageWiringError(f"no input for the parameter {n} of the lambda passed into map_")

    if (unused := kwargs.keys() - sig.parameters.keys()) != set():
        raise CustomMessageWiringError(f"keyword arguments {unused} are not used in the lambda signature")

    if consumed < len(args):
        raise CustomMessageWiringError(
            f"{len(args) - consumed} of positional arguments not used in the lambda signature"
        )

    # 3. now we have annotations for the parameters of the lambda the only way to figure out the output type is to
    #    try to wire it
    from hgraph import create_input_stub
    from hgraph import WiringGraphContext
    from hgraph import with_signature, graph

    inputs_ = {}
    for k, v in annotations.items():
        if v.is_scalar:
            inputs_[k] = values[k]
        else:
            inputs_[k] = create_input_stub(k, cast(HgTimeSeriesTypeMetaData, v), k == input_key_name)

    with WiringGraphContext(None, temporary=True):
        f = graph(with_signature(func, annotations=annotations, return_annotation=TIME_SERIES_TYPE))
        out = f(**inputs_)
        output_type = out.output_type if out is not None else None

    # 4. Now create a graph with the signature we worked out and return
    return graph(with_signature(func, annotations=annotations, return_annotation=output_type))
def _build_map_wiring(
fn: Callable, signature: WiringNodeSignature, *args, __keys__=None, __key_arg__=None, **kwargs
) -> (WiringNodeClass, dict, list, list):
"""
Build the maps wiring signature. This will process the inputs looking to work out which are multiplexed inputs,
which are pass-through, etc. It will perform basic validation that will ensure the signature of the mapped function
and the inputs provided are compatible.
"""
# 1. See if the first argument of the signature is a key argument.
# A key argument has a name of either 'key' (for TSD) or 'ndx' (for TSL)
# The key is a TS[SCALAR] for TSD and TS[int] for TSL.
input_has_key_arg, input_key_name, input_key_tp = _extract_map_fn_key_arg_and_type(signature, __key_arg__)
# 2. Now we can safely extract the kwargs.
kwargs_ = extract_kwargs(
signature, *args, _ensure_match=False, _args_offset=1 if input_has_key_arg else 0, **kwargs
)
if signature.var_arg: # explode var args into individual arguments
for i, arg in enumerate(kwargs_[signature.var_arg]):
kwargs_[f"{signature.var_arg}-{i}"] = arg
kwargs_.pop(signature.var_arg)
# 3. Split out the inputs into multiplexed, no_key, pass_through and direct and key_tp
multiplex_args, no_key_args, pass_through_args, _, map_type, key_tp_ = _split_inputs(signature, kwargs_, __keys__)
# 4. If the key is present, make sure the extracted key type matches what we found in the multiplexed inputs.
if map_type == "TSL":
tp = HgTSTypeMetaData.parse_type(TS[int])
else:
tp = key_tp_
if input_has_key_arg and not input_key_tp.matches(tp):
raise CustomMessageWiringError(f"The ndx argument '{signature.args[0]}: {input_key_tp}' does not match '{tp}'")
input_key_tp = tp
# 5. Extract provided key signature
# We use the output_type of wiring ports, but for scalar values, they must take the form of the underlying
# function signature, so we just use from that signature.
input_types = {
k: v.output_type.dereference() if isinstance(v, WiringPort) else signature.input_types[k]
for k, v in kwargs_.items()
}
# 6. Create the wiring nodes for the map function.
match map_type:
case "TSD":
if __keys__ is not None:
kwargs_[KEYS_ARG] = __keys__
else:
if len(multiplex_args) > 1:
from hgraph import union
__keys__ = union(*tuple(kwargs_[k].key_set for k in multiplex_args if k not in no_key_args))
else:
__keys__ = kwargs_[next(iter(multiplex_args))].key_set
kwargs_[KEYS_ARG] = __keys__
input_types = input_types | {KEYS_ARG: __keys__.output_type.dereference()}
map_wiring_node, ri = _create_tsd_map_wiring_node(
fn,
kwargs_,
input_types,
multiplex_args,
no_key_args,
input_key_tp,
input_key_name if input_has_key_arg else None,
)
case "TSL":
from hgraph._types._scalar_type_meta_data import HgAtomicType
map_wiring_node, ri = _create_tsl_map_signature(
fn,
kwargs_,
input_types,
multiplex_args,
HgAtomicType.parse_type(key_tp_),
input_key_name if input_has_key_arg else None,
)
case _:
raise CustomMessageWiringError(f"Unable to determine map type for given inputs: {kwargs_}")
return map_wiring_node, kwargs_, ri
def _build_and_wire_map(
    fn: Callable, signature: WiringNodeSignature, *args, __keys__=None, __key_arg__=None, __label__=None, **kwargs
) -> WiringPort:
    """
    Build the map wiring node for ``fn`` and wire it with the supplied inputs.

    Items returned by the build step are re-assigned to the created node instance via the current
    ``WiringGraphContext``. Returns the wired port, or ``None`` when the port has no output type.
    """
    wiring_node, wiring_inputs, reassignable_items = _build_map_wiring(
        fn, signature, *args, __keys__=__keys__, __key_arg__=__key_arg__, **kwargs
    )
    wired = wiring_node(**wiring_inputs, __return_sink_wp__=True, __label__=__label__)

    from hgraph import WiringGraphContext

    WiringGraphContext.instance().reassign_items(reassignable_items, wired.node_instance)
    if wired.output_type:
        return wired
    return None
def _extract_map_fn_key_arg_and_type(
    signature: WiringNodeSignature, __key_arg__
) -> tuple[bool, str | None, HgTSTypeMetaData | None]:
    """
    Detect whether the mapped function has a key argument and, if so, its name and type.

    :param signature: The wiring signature of the mapped function.
    :param __key_arg__: The user supplied key argument name. ``""`` means "there is no key argument", ``None``
        means "detect it": a first parameter named ``key`` or ``ndx`` is treated as the key.
    :returns: ``(has_key_arg, key_arg_name, key_arg_type)``; ``(False, None, None)`` when there is no key argument.
    :raises CustomMessageWiringError: when the named key argument is not the first parameter, or when a
        detected ``key`` / ``ndx`` parameter does not have the expected type (``TS[K]`` / ``TS[int]``).
    """
    if __key_arg__ == "":
        # If the user supplied an empty string for __key_arg__, interpret as ignore any input named as key / ndx
        # and that no key arg is present.
        return False, None, None

    # The mapped function may have no parameters at all (e.g. when only __keys__ is supplied).
    first_arg = signature.args[0] if signature.args else None

    if __key_arg__:
        if __key_arg__ != first_arg:
            raise CustomMessageWiringError(
                f"The key argument '{__key_arg__}' is not the first argument of the function:"
                f" '{signature.signature}'"
            )
        return True, __key_arg__, cast(HgTSTypeMetaData, signature.input_types[__key_arg__])

    if first_arg == "key":
        expected_tp = HgTimeSeriesTypeMetaData.parse_type(TS[K])
    elif first_arg == "ndx":
        expected_tp = HgTimeSeriesTypeMetaData.parse_type(TS[int])
    else:
        return False, None, None

    input_key_tp = signature.input_types[first_arg]
    if not expected_tp.matches(input_key_tp):
        raise CustomMessageWiringError(
            f"The key argument '{first_arg}: {input_key_tp}' "
            f"does not match the expected type: '{expected_tp}'"
        )
    return True, first_arg, cast(HgTSTypeMetaData, input_key_tp)
def _split_inputs(
    signature: WiringNodeSignature, kwargs_, tsd_keys
) -> tuple[frozenset[str], frozenset[str], frozenset[str], frozenset[str], str, HgTimeSeriesTypeMetaData]:
    # Returns: multiplex, no_key, pass_through, direct, map type, key type
    """
    Classifies the inputs of a map and works out the kind of map (``"TSD"`` or ``"TSL"``).

    The inputs are split into the following sets of argument names:

    #. multiplex_args: These are the inputs that need to be de-multiplexed.
    #. no_key_args: These are the inputs carrying the no-key marker.
    #. pass_through_args: These are the inputs that are marked as pass through.
    #. direct_args: These are the inputs that match the signature of the underlying function as they are.

    NOTE: ``kwargs_`` is modified in place, if the signature has a var-arg that is still present in ``kwargs_``
    it is replaced with individual ``<name>-<i>`` entries.

    :param signature: The signature of the mapped function.
    :param kwargs_: The inputs by name (wiring ports or scalar values).
    :param tsd_keys: The explicitly supplied key set or None. When supplied the map is treated as a TSD map.
    :returns: ``(multiplex_args, no_key_args, pass_through_args, direct_args, map_type, key_tp)``. For a TSL map
        ``key_tp`` is the size type produced by ``_extract_tsl_size``, for a TSD map it is the key type wrapped
        in ``HgTSTypeMetaData``.
    :raises CustomMessageWiringError: when a non time-series value is supplied for a non-scalar input, the map
        type can't be determined, an input matches neither directly nor as a multiplexed value, there is
        nothing to multiplex over, or the inputs can't be fully classified.
    """
    # Values that are not wiring ports must correspond to scalar inputs of the mapped function
    # (the var-arg entry holds a collection and is handled separately below).
    if non_ts_inputs := [
        arg for arg in kwargs_ if not isinstance(kwargs_[arg], WiringPort) and not arg == signature.var_arg
    ]:
        if not all(k in signature.scalar_inputs for k in non_ts_inputs):
            raise CustomMessageWiringError(
                f" The following args are not time-series inputs, but should be: {non_ts_inputs}"
            )
    # Explode the var-arg into individual '<name>-<i>' entries (a no-op if the caller has already done so).
    if signature.var_arg and signature.var_arg in kwargs_:
        for i, arg in enumerate(kwargs_[signature.var_arg]):
            kwargs_[signature.var_arg + f"-{i}"] = arg
        kwargs_.pop(signature.var_arg)
    no_key_args = frozenset(arg for arg in kwargs_ if isinstance(kwargs_[arg], WiringPort) and _NoKeyMarker in (kwargs_[arg].markers or ()))
    pass_through_args = frozenset(arg for arg in kwargs_ if isinstance(kwargs_[arg], WiringPort) and _PassthroughMarker in (kwargs_[arg].markers or ()))
    _validate_pass_through(signature, kwargs_, pass_through_args)  # Ensure the pass through args are correctly typed.
    # Types of the supplied time-series inputs (references are looked through).
    input_types = {k: v.output_type.dereference() for k, v in kwargs_.items() if k not in non_ts_inputs}
    # The type the mapped function expects for each input, exploded var-arg entries use the element type of
    # the var-arg.
    signature_types = {
        k: (
            signature.input_types[k]
            if k.split("-")[0] != signature.var_arg
            else signature.input_types[signature.var_arg].value_tp
        )
        for k in input_types
    }
    # Figure out if the map is done over a TSD or TSL by finding the first multiplexing input
    map_type = None
    multiplex_type = None
    if tsd_keys is not None:  # corner case where there are no other inputs but explicitly provided keys
        map_type = "TSD"
        multiplex_type = HgTSDTypeMetaData
    elif input_types:
        for k, v in input_types.items():
            if k not in (no_key_args | pass_through_args) and type(v_tp := v.dereference()) in (HgTSDTypeMetaData, HgTSLTypeMetaData):
                sig_tp = signature_types[k]
                if sig_tp.matches(v) and not isinstance(sig_tp, HgTsTypeVarTypeMetaData):  # not multiplexing
                    continue
                if isinstance(v_tp, HgTSDTypeMetaData) and sig_tp.matches(v_tp.value_tp):
                    map_type = "TSD"
                    multiplex_type = HgTSDTypeMetaData
                    break
                elif isinstance(v_tp, HgTSLTypeMetaData) and sig_tp.matches(v_tp.value_tp):
                    map_type = "TSL"
                    multiplex_type = HgTSLTypeMetaData
                    break
                else:
                    raise CustomMessageWiringError(
                        f"parameter {k}:{sig_tp} of the mapped graph does not match the input type {v_tp.py_type}"
                        "for either direct match or multiplexing"
                    )
    if map_type is None:
        raise CustomMessageWiringError(
            f"failed to determine the type of mapping over {signature} with parameters "
            f"of {','.join(f'{k}:{v}' for k, v in input_types.items())}"
        )
    # Direct inputs match the expected type as supplied and are not of the multiplexed collection type.
    direct_args = frozenset(
        k
        for k, v in input_types.items()
        if k not in (no_key_args | pass_through_args) and signature_types[k].matches(v)
        if
        # A type-var input matches every time-series type, so a match alone is not sufficient,
        # if the input could be multiplexed, don't mark it direct.
        (type(v) is not multiplex_type)
    )
    # Everything of the multiplexed collection type that is neither pass-through nor direct gets de-multiplexed.
    multiplex_args = frozenset(
        k
        for k, v in input_types.items()
        if k not in pass_through_args and k not in direct_args and type(v) is multiplex_type
    )
    _validate_multiplex_types(signature_types, kwargs_, multiplex_args, no_key_args)
    if not tsd_keys and (len(no_key_args) + len(multiplex_args) == 0):
        raise CustomMessageWiringError(f"No multiplexed inputs found for TSD map over {signature} with parameters {input_types}")
    # Sanity check: every input must have landed in exactly one of the groups.
    if len(multiplex_args) + len(direct_args) + len(pass_through_args) + len(non_ts_inputs) != len(kwargs_):
        raise CustomMessageWiringError(f"Unable to determine how to split inputs with args:\n {kwargs_}")
    if map_type == "TSL":
        key_tp = _extract_tsl_size(kwargs_, multiplex_args, no_key_args)
    else:
        key_tp = _validate_tsd_keys(kwargs_, multiplex_args, no_key_args, tsd_keys)
    return (
        multiplex_args,
        no_key_args,
        pass_through_args,
        direct_args,
        map_type,
        key_tp if map_type == "TSL" else HgTSTypeMetaData(key_tp),
    )
def _prepare_stub_inputs(
    kwargs_: dict[str, WiringPort | SCALAR],
    input_types: dict[str, HgTypeMetaData],
    multiplex_args: frozenset[str],
    no_key_args: frozenset[str],
    input_key_tp: HgTSTypeMetaData,
    input_key_name: str | None,
):
    """
    Create the inputs used to resolve the signature of the mapped function.

    Multiplexed and no-key inputs are replaced with a stub of their value type, scalars are passed as supplied,
    all other time-series inputs are stubbed with their own type. The keys input is left out and, when the
    mapped function takes a key, a stub of the key type is added under the key argument name.
    """
    stubs = {}
    for name, tp in input_types.items():
        if name in multiplex_args or name in no_key_args:
            # The mapped function sees the element of the collection, not the collection.
            stubs[name] = stub_wiring_port(tp.value_tp)
            continue
        if name == KEYS_ARG:
            continue
        stubs[name] = kwargs_[name] if tp.is_scalar else stub_wiring_port(tp)

    if input_key_name:
        stubs[input_key_name] = stub_wiring_port(input_key_tp)
    return stubs
def _create_tsd_map_wiring_node(
    fn: WiringNodeClass,
    kwargs_: dict[str, WiringPort | SCALAR],
    input_types: dict[str, HgTypeMetaData],
    multiplex_args: frozenset[str],
    no_key_args: frozenset[str],
    input_key_tp: HgTSTypeMetaData,
    input_key_name: str | None,
) -> tuple[TsdMapWiringNodeClass, tuple]:
    """
    Create the wiring node for a map over TSD inputs.

    :param fn: The function to map.
    :param kwargs_: The inputs by name, scalar values are taken from here when wiring the nested graph.
    :param input_types: The types of the inputs (including the keys input).
    :param multiplex_args: Names of the inputs that are de-multiplexed.
    :param no_key_args: Names of the inputs carrying the no-key marker.
    :param input_key_tp: The time-series type of the key.
    :param input_key_name: Name of the key argument of ``fn`` or None if it does not take one.
    :returns: The wiring node and the second item returned by ``wire_nested_graph``.
    """
    # Resolve the mapped function signature
    stub_inputs = _prepare_stub_inputs(kwargs_, input_types, multiplex_args, no_key_args, input_key_tp, input_key_name)
    resolved_signature = fn.resolve_signature(**stub_inputs)

    # Time-series inputs (other than the keys) are taken as references.
    reference_inputs = frozendict({
        k: as_reference(v, k in multiplex_args) if isinstance(v, HgTimeSeriesTypeMetaData) and k != KEYS_ARG else v
        for k, v in input_types.items()
    })

    # The map has an output only if the mapped function has one.
    if resolved_signature.output_type:
        node_type = WiringNodeType.COMPUTE_NODE
        output_type = HgTSDTypeMetaData(
            input_key_tp.value_scalar_tp, resolved_signature.output_type.dereference().as_reference()
        )
    else:
        node_type = WiringNodeType.SINK_NODE
        output_type = None

    # NOTE: The wrapper node does not need to set it valid and tick to that of the underlying node, it just
    #       needs to ensure that it gets notified when the key sets tick. Likewise with validity.
    # Build provisional signature first so we can pass it in as context into inner graph wiring
    provisional_signature = WiringNodeSignature(
        node_type=node_type,
        name="map",
        # All actual inputs are encoded in the input_types, so we just need to add the keys if present.
        args=tuple(input_types.keys()),
        defaults=frozendict(),  # Defaults would have already been applied.
        input_types=reference_inputs,
        output_type=output_type,
        src_location=resolved_signature.src_location,  # TODO: Figure out something better for this.
        active_inputs=None,  # We will follow a copy approach to transfer the inputs to inner graphs
        valid_inputs=frozenset({
            KEYS_ARG,
        }),  # We have constructed the map so that the keys arg is always present.
        all_valid_inputs=None,
        context_inputs=None,
        unresolved_args=frozenset(),
        time_series_args=frozenset(k for k, v in input_types.items() if not v.is_scalar),
        has_nested_graphs=True,
    )

    if resolved_signature.var_arg:
        # Replace the var-arg with one '<name>-<i>' entry per element, matching the exploded inputs.
        var_name = resolved_signature.var_arg
        var_tp = resolved_signature.input_types[var_name]
        resolved_input_types = dict(resolved_signature.input_types)
        for i in range(var_tp.size.SIZE):
            resolved_input_types[f"{var_name}-{i}"] = var_tp.value_tp
        resolved_input_types.pop(var_name)
    else:
        resolved_input_types = resolved_signature.input_types

    graph, ri = wire_nested_graph(
        fn,
        resolved_input_types,
        {
            k: kwargs_[k]
            for k, v in resolved_input_types.items()
            if not isinstance(v, HgTimeSeriesTypeMetaData) and k != KEYS_ARG
        },
        provisional_signature,
        input_key_name,
        depth=2,
    )

    map_signature = TsdMapWiringSignature(
        **provisional_signature.as_dict(),
        map_fn_signature=resolved_signature,
        key_tp=input_key_tp.value_scalar_tp,
        key_arg=input_key_name,
        multiplexed_args=multiplex_args,
        inner_graph=graph,
    )
    wiring_node = TsdMapWiringNodeClass(map_signature, fn)
    return wiring_node, ri
def _create_tsl_map_signature(
    fn: WiringNodeClass,
    kwargs_: dict[str, WiringPort | SCALAR],
    input_types: dict[str, HgTypeMetaData],
    multiplex_args: frozenset[str],
    size_tp: "HgAtomicType",
    input_key_name: str | None,
):
    """
    Create the wiring node for a map over TSL inputs.

    Resolves the signature of ``fn`` against stub inputs (the index being a ``TS[int]``), wires the nested
    graph and wraps the result in a ``TslMapWiringNodeClass``. Returns the wiring node together with the
    second item returned by ``wire_nested_graph``.
    """
    # Resolve the mapped function signature
    ndx_tp = HgTSTypeMetaData.parse_type(TS[int])
    fn_signature = fn.resolve_signature(
        **_prepare_stub_inputs(kwargs_, input_types, multiplex_args, frozenset(), ndx_tp, input_key_name)
    )

    # Time-series inputs are taken as references, scalars are left as they are.
    as_refs = {}
    for name, tp in input_types.items():
        if isinstance(tp, HgTimeSeriesTypeMetaData):
            as_refs[name] = as_reference(tp, name in multiplex_args)
        else:
            as_refs[name] = tp
    reference_inputs = frozendict(as_refs)

    has_output = fn_signature.output_type
    if has_output:
        map_output_tp = HgTSLTypeMetaData(fn_signature.output_type.as_reference(), size_tp)
    else:
        map_output_tp = None

    # Build provisional signature first so we can pass it in as context into inner graph wiring
    provisional_signature = WiringNodeSignature(
        node_type=WiringNodeType.COMPUTE_NODE if has_output else WiringNodeType.SINK_NODE,
        name="map",
        # All actual inputs are encoded in the input_types, so we just need to add the keys if present.
        args=tuple(input_types.keys()),
        defaults=frozendict(),  # Defaults would have already been applied.
        input_types=frozendict(reference_inputs),
        output_type=map_output_tp,
        src_location=fn_signature.src_location,  # TODO: Figure out something better for this.
        active_inputs=frozenset(),
        valid_inputs=frozenset(),
        all_valid_inputs=None,
        context_inputs=None,
        unresolved_args=frozenset(),
        time_series_args=frozenset(name for name, tp in input_types.items() if not tp.is_scalar),
        label=f"map('{fn_signature.signature}', {', '.join(input_types.keys())})",
        has_nested_graphs=True,
    )

    # Only the non time-series inputs are handed over as values to the nested graph.
    scalar_inputs = {}
    for name, tp in fn_signature.input_types.items():
        if not isinstance(tp, HgTimeSeriesTypeMetaData) and name != KEYS_ARG:
            scalar_inputs[name] = kwargs_[name]

    inner_graph, reassignables = wire_nested_graph(
        fn,
        fn_signature.input_types,
        scalar_inputs,
        provisional_signature,
        input_key_name,
        depth=2,
    )

    map_signature = TslMapWiringSignature(
        **provisional_signature.as_dict(),
        map_fn_signature=fn_signature,
        size_tp=size_tp,
        key_arg=input_key_name,
        multiplexed_args=multiplex_args,
        inner_graph=inner_graph,
    )
    return TslMapWiringNodeClass(map_signature, fn), reassignables
def _validate_tsd_keys(kwargs_, multiplex_args, no_key_args, tsd_keys):
"""
Ensure all the multiplexed inputs use the same input key.
"""
types = set(kwargs_[arg].output_type.dereference().key_tp for arg in chain(multiplex_args, no_key_args))
if tsd_keys:
types.add(tsd_keys.output_type.dereference().value_scalar_tp)
if len(types) > 1:
raise CustomMessageWiringError(f"The TSD multiplexed inputs have different key types: {types}")
return next(iter(types))
def _validate_pass_through(signature: WiringNodeSignature, kwargs_, pass_through_args):
    """
    Check that each input marked as pass-through is compatible with the matching input type of the
    mapped function's signature, raising ``CustomMessageWiringError`` for the first one that is not.
    """
    for name in pass_through_args:
        port = kwargs_[name]
        is_marked = isinstance(port, WiringPort) and port.markers and _PassthroughMarker in port.markers
        if not is_marked:
            continue
        in_type = signature.input_types[name]
        if in_type.matches(port.output_type.dereference()):
            continue
        raise CustomMessageWiringError(
            f"The input '{name}: {port.output_type}' is marked as pass_through,"
            f"but is not compatible with the input type: {in_type}"
        )
def _extract_tsl_size(kwargs_: dict[str, WiringPort], multiplex_args, marker_args) -> type[Size]:
    """
    With a TSL multiplexed input, we need to determine the size of the output. This is done by looking at all the
    inputs that could be multiplexed.

    :param kwargs_: The inputs by name.
    :param multiplex_args: Names of the de-multiplexed inputs.
    :param marker_args: Names of marked inputs; those marked as pass-through are ignored.
    :returns: The smallest fixed size found, or the un-sized ``Size`` when no input has a fixed size.
    """
    non_pass_through = (
        m_arg
        for m_arg in marker_args
        if not (kwargs_[m_arg].markers and _PassthroughMarker in kwargs_[m_arg].markers)
    )
    # Inputs are classified as multiplexed by their de-referenced type, so de-reference here as well before
    # reading the size.
    sizes: List[type[Size]] = [
        cast(type[Size], cast(HgTSLTypeMetaData, kwargs_[arg].output_type.dereference()).size_tp.py_type)
        for arg in chain(multiplex_args, non_pass_through)
    ]
    size: type[Size] = Size
    for sz in sizes:
        if not sz.FIXED_SIZE:
            continue
        if size.FIXED_SIZE:
            # Keep the smaller of the two fixed sizes.
            size = size if size.SIZE < sz.SIZE else sz
        else:
            size = sz
    return size
def _validate_multiplex_types(signature_types, kwargs_, multiplex_args, no_key_args):
"""
Validates that the multiplexed inputs are valid.
"""
for arg in chain(multiplex_args, no_key_args):
if not (in_type := signature_types[arg].dereference()).matches(
(m_type := kwargs_[arg].output_type.dereference()).value_tp
):
raise CustomMessageWiringError(
f"The input '{arg}: {m_type}' is a multiplexed type, but its value type '{m_type.value_tp}' is not"
f" compatible with the input type of the inner graph: {in_type}"
)