# Source code for hgraph._wiring._reduce

import inspect
from typing import Callable, cast

from hgraph._types._ts_meta_data import HgTSTypeMetaData
from hgraph._types._scalar_types import SIZE, ZERO, STATE
from hgraph._types._scalar_type_meta_data import HgTupleCollectionScalarType
from hgraph._types._time_series_types import TIME_SERIES_TYPE, TIME_SERIES_TYPE_1, K
from hgraph._types._tsd_meta_data import HgTSDTypeMetaData
from hgraph._types._tsd_type import TSD
from hgraph._types._tsl_meta_data import HgTSLTypeMetaData
from hgraph._types._tsl_type import TSL
from hgraph._types._typing_utils import with_signature
from hgraph._wiring._decorators import compute_node, graph
from hgraph._wiring._wiring_context import WiringContext
from hgraph._wiring._wiring_errors import CustomMessageWiringError
from hgraph._wiring._wiring_node_class._reduce_wiring_node import (
    TsdReduceWiringNodeClass,
    ReduceWiringSignature,
    TsdNonAssociativeReduceWiringNodeClass,
)
from hgraph._wiring._wiring_node_class._wiring_node_class import WiringNodeClass
from hgraph._wiring._wiring_node_signature import WiringNodeSignature
from hgraph._wiring._wiring_port import WiringPort
from hgraph._wiring._wiring_utils import wire_nested_graph

__all__ = ("reduce",)


def reduce(
    func: Callable[[TIME_SERIES_TYPE, TIME_SERIES_TYPE_1], TIME_SERIES_TYPE],
    ts: TSD[K, TIME_SERIES_TYPE_1] | TSL[TIME_SERIES_TYPE_1, SIZE],
    zero: TIME_SERIES_TYPE = ZERO,
    is_associative: bool = True,
) -> TIME_SERIES_TYPE:
    """
    Reduce the input time-series collection into a single time-series value.

    The zero must be compatible with the TIME_SERIES_TYPE value and be constructable as const(zero, TIME_SERIES_TYPE).

    If the function is associative, then TIME_SERIES_TYPE must be the same as TIME_SERIES_TYPE_1.

    When the function is associative, the 'reduce' will perform a tree reduction, otherwise it will perform a linear
    reduction. The tree reduction is much faster on change.

    By definition, the reduce function over a TSD must be commutative and associative in the sense that the order of
    the inputs is not guaranteed. Only a TSL supports non-commutative reduce functions.

    Example [TSD]:

    ::

        tsd: TSD[str, TS[int]] = ...
        out = reduce(add_, tsd, 0)
        >> tsd <- {'a': [1], 'b': [4], 'c': [7]}
        >> out -> 12

    Example [TSL]:

    ::

        tsl: TSL[TS[int], SIZE] = ...
        out = reduce(add_, tsl, 0)
        >> tsl <- ([1], [2], [3], [4], [5])
        >> out -> 15

    Example [TS[tuple[SCALAR, ...]]]:

    ::

        ts: TS[tuple[int, ...]] = ...
        initial_value: TS[str] = ...
        out = reduce(lambda x, y: format_("{}, {}", x, y), ts, initial_value, is_associative=False)

    NOTE: TSD[int, TIME_SERIES_TYPE_1] with is_associative=False is the only TSD non-associative reduce supported.
    The expectation is that the integer values represent a uniform list from 0 to size-1. There cannot be holes in
    the sequence. Removals of keys must be of the form [0:n] where n is the last element in the new set.
    This allows for processing an input that is different from the output type, as for the tuple example.
    The ``zero`` element is used as the default result if no value is supplied, it is also used as the input for the
    chain of keys. The output of the n-1th element is used as the input to the nth element lhs. The values from the
    dict are used as the rhs input.

    :param func: A wiring node/graph (``WiringNodeClass``) or a lambda taking two time-series and returning one.
    :param ts: The collection to reduce: a TSL, a TSD, or a ``TS[tuple[...]]`` (the latter only with
        ``is_associative=False``).
    :param zero: The zero / seed value. ``ZERO`` (the default) resolves the zero for the item type via the ``zero``
        operator, ``None`` becomes ``nothing``, any other non time-series value is wrapped with ``const``.
        Non-associative reductions require ``zero`` to already be a time-series.
    :param is_associative: Selects tree (``True``) versus linear (``False``) reduction.
    :return: The reduced time-series.
    :raises RuntimeError: If ``func`` is neither a ``WiringNodeClass`` nor a lambda, or ``ts`` is of an
        unsupported time-series type.
    :raises CustomMessageWiringError: For a non-associative reduce over a TSD that is not integer keyed, or an
        associative reduce over a ``TS[tuple[...]]``.
    """
    if isinstance(func, WiringNodeClass):
        signature = func.signature.signature
    elif callable(func) and getattr(func, "__name__", None) == "<lambda>":
        # Only lambdas are accepted as plain callables; ``getattr`` keeps callables without a ``__name__``
        # (e.g. ``functools.partial``) on the error path below instead of failing with ``AttributeError``.
        signature = str(inspect.signature(func))
    else:
        # The validation above concerns ``func``, so report ``func`` (not ``ts``) to the user.
        raise RuntimeError(
            f"The supplied function is not a valid reduce function (expected a wiring node/graph or a lambda): "
            f"'{func}'"
        )

    _tp = ts.output_type.dereference()
    with WiringContext(current_signature=STATE(signature=f"reduce('{signature}', {_tp}, {zero})")):
        if type(_tp) is HgTSLTypeMetaData:
            return _reduce_tsl(func, ts, zero, is_associative)
        if type(_tp) is HgTSDTypeMetaData:
            # Non-associative TSD reduction relies on the keys forming the sequence 0..size-1 (see NOTE above).
            if not is_associative and _tp.key_tp.py_type is not int:
                raise CustomMessageWiringError(
                    "Non-associative operators are not supported using TSD inputs that are not integer keyed"
                )
            return _reduce_tsd(func, ts, zero, is_associative)
        if type(_tp) is HgTSTypeMetaData and type(_tp.value_scalar_tp) is HgTupleCollectionScalarType:
            if is_associative:
                raise CustomMessageWiringError("Associative operators are not supported using TS[tuple[...]] inputs")
            return _reduce_tuple(func, ts, zero)
        raise RuntimeError(f"Unexpected time-series type: {ts.output_type}")
def _to_zero_port(zero, item_tp, func, is_associative):
    """
    Normalise ``zero`` into a wiring port of type ``item_tp``.

    A ``WiringPort`` is returned unchanged. Otherwise:

    * non-associative reductions reject it (they require ``zero`` to already be a time-series);
    * ``ZERO`` is resolved via the ``zero`` operator for ``item_tp`` and ``func``;
    * ``None`` becomes ``nothing(item_tp)``;
    * any other value becomes ``const(zero, item_tp)``.

    :raises CustomMessageWiringError: If ``zero`` is not a port and the reduction is non-associative.
    """
    if isinstance(zero, WiringPort):
        return zero
    if not is_associative:
        raise CustomMessageWiringError("Non-associative operators require a time-series value for zero to be provided")
    if zero is ZERO:
        import hgraph

        return hgraph._operators._operators.zero(item_tp, func)
    if zero is None:
        from hgraph import nothing

        return nothing(item_tp)
    from hgraph import const

    return const(zero, item_tp)


def _reduce_tsl(func, ts, zero, is_associative):
    """
    Reduce a fixed-size TSL by laying the reduction out statically in the graph.

    For the moment, only fixed size TSLs are supported, so the shape of the reduction is known at wiring time.
    Each element is wrapped with ``default(ts[i], zero)`` so a missing value falls back to ``zero``; an empty TSL
    reduces to ``zero`` itself.

    Non-associative functions, and TSLs with fewer than 4 elements, are reduced linearly left-to-right. Otherwise
    a balanced pairwise tree is built; operand order (left-to-right) is preserved and every element is combined
    exactly once.
    """
    from hgraph import default

    tp_ = ts.output_type
    item_tp = tp_.value_tp.py_type
    zero = _to_zero_port(zero, item_tp, func, is_associative)

    if (sz := tp_.size_tp.py_type.SIZE) == 0:
        return zero

    if not is_associative or sz < 4:
        out = default(ts[0], zero)
        for i in range(1, sz):
            out = func(out, default(ts[i], zero))
        return out

    # First level: combine adjacent pairs. For an odd size the last element is held back in ``carry``.
    level = [func(default(ts[i], zero), default(ts[i + 1], zero)) for i in range(0, sz - sz % 2, 2)]
    # ``carry`` holds at most one not-yet-combined partial result, always to the right of everything in ``level``.
    carry = default(ts[sz - 1], zero) if sz % 2 else None
    while len(level) > 1:
        if len(level) % 2:
            if carry is None:
                carry = level.pop()
            else:
                # Fold the carried value back in at the right-hand end; it must be cleared afterwards or it
                # would be combined a second time (at a later odd level or in the final step).
                level.append(carry)
                carry = None
        level = [func(level[i], level[i + 1]) for i in range(0, len(level), 2)]
    return level[0] if carry is None else func(level[0], carry)


def _reduce_tsd(func, ts, zero, is_associative=True) -> TIME_SERIES_TYPE:
    """
    Wire a TSD reduction as a single ``reduce`` node owning a nested graph built from ``func``.

    For associative reductions every input of the nested graph takes the TSD value type. For non-associative
    reductions the first input takes the ``zero`` type and the second the TSD value type.

    :raises CustomMessageWiringError: If a non-associative plain callable does not take exactly two arguments,
        or a non-associative reduction is given a ``zero`` that is not a time-series.
    """
    from hgraph._types._ref_type import REF

    # We need to ensure that the reduction graph contains no push nodes. (We should be able to support pull nodes)

    @compute_node
    def _reduce_tsd_signature(
        ts: TSD[K, REF[TIME_SERIES_TYPE_1]], zero: REF[TIME_SERIES_TYPE]
    ) -> REF[TIME_SERIES_TYPE]: ...  # Used to create a WiringNodeClass template

    tp = ts.output_type.dereference()
    item_tp_md = tp.value_tp
    item_tp = item_tp_md.py_type
    zero = _to_zero_port(zero, item_tp, func, is_associative)

    # Resolve the template against the actual inputs, then re-label it as a "reduce" node with a nested graph.
    wp = _reduce_tsd_signature(ts, zero)
    resolved_signature = cast(WiringPort, wp).node_instance.resolved_signature
    resolved_signature = WiringNodeSignature(
        node_type=resolved_signature.node_type,
        name="reduce",
        args=resolved_signature.args,
        defaults=resolved_signature.defaults,
        input_types=resolved_signature.input_types,
        output_type=resolved_signature.output_type,
        src_location=resolved_signature.src_location,
        active_inputs=resolved_signature.active_inputs,
        valid_inputs=resolved_signature.valid_inputs,
        all_valid_inputs=resolved_signature.all_valid_inputs,
        context_inputs=resolved_signature.context_inputs,
        unresolved_args=resolved_signature.unresolved_args,
        time_series_args=resolved_signature.time_series_args,
        injectables=resolved_signature.injectables,
        # label=resolved_signature.label,
        has_nested_graphs=True,
    )

    if not isinstance(func, WiringNodeClass):
        # Plain callables (lambdas) are typed from the reduction's inputs and wrapped as a graph.
        if is_associative:
            func = graph(
                with_signature(
                    func,
                    annotations={k: item_tp for k in inspect.signature(func).parameters},
                    return_annotation=item_tp,
                )
            )
        else:
            zero_tp = zero.output_type.dereference()
            parameters = list(inspect.signature(func).parameters)
            if len(parameters) != 2:
                raise CustomMessageWiringError(
                    f"The function must have exactly two arguments, but has {len(parameters)}"
                )
            annotations = {parameters[0]: zero_tp, parameters[1]: item_tp}
            func = graph(
                with_signature(
                    func,
                    annotations=annotations,
                    return_annotation=zero_tp,
                )
            )

    if is_associative:
        input_types = {k: item_tp_md for k in func.signature.input_types}
    else:
        input_types = dict(zip(func.signature.args, (zero.output_type, item_tp_md)))

    builder, ri = wire_nested_graph(func, input_types, {}, resolved_signature, None, depth=2)
    reduce_signature = ReduceWiringSignature(**resolved_signature.as_dict(), inner_graph=builder)
    wiring_node = (
        TsdReduceWiringNodeClass(reduce_signature, func)
        if is_associative
        else TsdNonAssociativeReduceWiringNodeClass(reduce_signature, func)
    )
    port = wiring_node(ts, zero)

    from hgraph import WiringGraphContext

    WiringGraphContext.instance().reassign_items(ri, port.node_instance)
    return port


def _reduce_tuple(func, ts, zero):
    """
    Reduce a ``TS[tuple[...]]`` by converting it to a de-duplicated TSD and running the non-associative TSD reduce.

    NOTE(review): relies on ``convert[TSD]`` producing an int-keyed TSD (tuple index -> value) — confirm.
    """
    from hgraph import convert, dedup

    tsd = dedup(convert[TSD](ts))
    return _reduce_tsd(func, tsd, zero, False)