"""Source code for hgraph.adaptors.tornado.websocket_server_adaptor: a Tornado-backed websocket server adaptor."""

import asyncio
import logging
from dataclasses import dataclass
from typing import Callable, Generic, Type, TypeVar

import tornado.websocket
from frozendict import frozendict

from hgraph import (
    AUTO_RESOLVE,
    REMOVE,
    STATE,
    TIME_SERIES_TYPE,
    TS,
    TS_SCHEMA,
    TSB,
    TSD,
    CompoundScalar,
    GlobalState,
    HgTSBTypeMetaData,
    HgTypeMetaData,
    TimeSeriesSchema,
    adaptor,
    adaptor_impl,
    combine,
    graph,
    map_,
    merge,
    partition,
    push_queue,
    register_service,
    sink_node,
)
from hgraph.adaptors.tornado._tornado_web import BaseHandler, TornadoWeb

# Module logger, registered under a fixed name rather than the module path.
logger = logging.getLogger(name="websocket_server_adaptor")


@dataclass(frozen=True)
class WebSocketConnectRequest(CompoundScalar):
    """Description of an incoming websocket connection, pushed into the graph when a client connects.

    ``WebSocketHandler.open`` fills it with the handler path, the positional arguments Tornado extracted
    from the url pattern, the request headers, the cookies (each mapped to its value and attributes) and
    the handler's ``current_user`` (or ``None``).
    """

    url: str  # handler path the connection was routed to
    url_parsed_args: tuple[str, ...] = ()  # positional url-pattern arguments
    headers: dict[str, str] = frozendict()  # HTTP request headers
    cookies: dict[str, dict[str, object]] = frozendict()  # cookie name -> {"value": ..., attributes...}
    auth: object = None  # authenticated user, if any
# Payload type of websocket messages: constrained to exactly ``bytes`` (binary frames) or ``str`` (text frames).
STR_OR_BYTES = TypeVar("STR_OR_BYTES", bytes, str)
@dataclass(frozen=True)
class WebSocketServerRequest(TimeSeriesSchema, Generic[STR_OR_BYTES]):
    """Per-connection input delivered to a websocket handler graph.

    ``connect_request`` ticks with the connection details; ``messages`` ticks with the batch of client
    messages (``str`` or ``bytes``) received since the previous tick.
    """

    connect_request: TS[WebSocketConnectRequest]
    messages: TS[tuple[STR_OR_BYTES, ...]]
@dataclass(frozen=True)
class WebSocketClientRequest(TimeSeriesSchema, Generic[STR_OR_BYTES]):
    """Schema pairing a connect request with a single message of type ``STR_OR_BYTES``.

    NOTE(review): not referenced elsewhere in this module; presumably used by a websocket client adaptor.
    """

    connect_request: TS[WebSocketConnectRequest]
    message: TS[STR_OR_BYTES]
@dataclass(frozen=True)
class WebSocketResponse(TimeSeriesSchema, Generic[STR_OR_BYTES]):
    """Per-connection output produced by a websocket handler graph.

    ``connect_response`` accepts (``True``) or rejects (``False``) the connection; each tick of ``message``
    is written to the client.
    """

    connect_response: TS[bool]
    message: TS[STR_OR_BYTES]
class WebSocketAdaptorManager:
    """Bridge between Tornado websocket connections and the graph for one message type.

    One instance exists per message type (``str`` or ``bytes``), obtained through :meth:`instance`. It keeps
    the url -> handler registrations, installs :class:`WebSocketHandler` routes on the Tornado server,
    allocates request ids for new connections and pushes connection/message events into the graph through
    the push-queue senders given to :meth:`set_queues`. Events arriving before the senders are available are
    buffered and flushed once they are set.
    """

    # url -> wiring node to auto-wire, or None when the url is wired explicitly elsewhere
    handlers: dict[str, Callable | None]

    def __init__(self, binary):
        # url -> handler registrations
        self.handlers = {}
        # request id -> future resolved by ``complete_request`` with the connect outcome
        self.requests = {}
        # request id -> callable writing an outgoing message to that client connection
        self.message_handlers = {}
        # True when messages are bytes, False when they are str
        self.binary = binary
        # Monotonic source of request ids handed to the graph; ids are never reused
        self._next_request_id = 1
        # Retained for compatibility (cleared by ``shutdown``); request ids are no longer keyed by ``id()``
        self._pyid_to_id: dict[int, int] = {}
        # port -> url paths already registered on that port's Tornado server
        self._registered_paths_by_port: dict[int, set[str]] = {}
        # Queues are set from adaptor start; buffer early events until ready
        self.connect_queue = None
        self.message_queue = None
        self._pending_connect: list[tuple[int, WebSocketConnectRequest]] = []
        self._pending_messages: list[tuple[int, object]] = []

    @classmethod
    def instance(cls, tp):
        """Return the shared manager for message type ``tp`` (``str`` or ``bytes``), creating it on first use."""
        if not hasattr(cls, "_instance"):
            cls._instance = {}
        if tp not in cls._instance:
            cls._instance[tp] = cls(tp == bytes)
        return cls._instance[tp]

    def set_queues(self, connect_queue, message_queue):
        """Install the push-queue senders and flush everything buffered before they were available.

        A failure to flush one item is logged as a warning (the event is lost) and the rest are still sent.
        """
        self.connect_queue = connect_queue
        self.message_queue = message_queue
        for rid, req in self._pending_connect:
            logger.debug(f"[WS] Flushing buffered connect rid={rid} path={getattr(req, 'url', '?')}")
            try:
                self.connect_queue({rid: req})
            except Exception:
                logger.warning("[WS] Failed to flush buffered connect", exc_info=True)
        self._pending_connect.clear()
        for rid, msg in self._pending_messages:
            logger.debug(f"[WS] Flushing buffered message rid={rid} len={len(msg) if hasattr(msg,'__len__') else 'n/a'}")
            try:
                self.message_queue({rid: msg})
            except Exception:
                logger.warning("[WS] Failed to flush buffered message", exc_info=True)
        self._pending_messages.clear()

    def start(self, port):
        """Register a ``WebSocketHandler`` route for each handler path not yet on ``port``, then start the server."""
        self.tornado_web = TornadoWeb.instance(port)
        reg = self._registered_paths_by_port.setdefault(port, set())
        for path in self.handlers.keys():
            if path not in reg:
                logger.info("[WS] Register handler path=%s on port=%s", path, port)
                self.tornado_web.add_handler(path, WebSocketHandler, {"path": path, "binary": self.binary, "mgr": self})
                reg.add(path)
        self.tornado_web.start()

    def stop(self):
        """Stop the Tornado server this manager was started on."""
        self.tornado_web.stop()

    def shutdown(self, path: str | None = None):
        """Cancel pending connects, drop all per-connection state and buffers, and detach from the queues.

        :param path: the typed adaptor path whose queue entries should also be removed from ``GlobalState``.
        """
        for fut in list(self.requests.values()):
            try:
                if fut is not None and not fut.done():
                    fut.cancel()
            except Exception:
                logger.debug("[WS] Ignoring exception while cancelling pending future", exc_info=True)
        self.requests.clear()
        self.message_handlers.clear()
        self._pyid_to_id.clear()
        self._pending_connect.clear()
        self._pending_messages.clear()
        if path is not None:
            for key in (
                f"websocket_server_adaptor://{path}/connect_queue",
                f"websocket_server_adaptor://{path}/message_queue",
            ):
                try:
                    del GlobalState.instance()[key]
                except Exception:
                    # Best effort: the key (or the global state itself) may already be gone
                    logger.debug("[WS] Could not remove %s from GlobalState", key, exc_info=True)
        self.connect_queue = None
        self.message_queue = None

    def add_handler(self, path, handler):
        """Record ``handler`` for ``path``; ``None`` marks the path as explicitly wired (no auto-wiring)."""
        self.handlers[path] = handler

    def add_request(self, request_id, request, message_handler):
        """Register a new connection and forward its connect request to the graph.

        :param request_id: caller-provided identifier; kept for interface compatibility but not used as
            the engine id. The caller passes ``id()`` of a short-lived object and CPython reuses such ids
            once the object is freed, so keying on it would give a new connection the id (and graph state)
            of a connection that is still open. A fresh id is always allocated instead.
        :param request: the ``WebSocketConnectRequest`` to push into the graph.
        :param message_handler: callable invoked with each outgoing message for this connection.
        :return: a future that ``complete_request`` resolves with ``(accepted, enqueue, close)``.
        """
        rid = self._next_request_id
        self._next_request_id += 1
        try:
            future = asyncio.get_running_loop().create_future()
        except RuntimeError:
            # No running event loop in this thread
            future = asyncio.Future()
        self.requests[rid] = future
        self.message_handlers[rid] = message_handler
        if self.connect_queue is None:
            # Buffer until queues are set by adaptor start
            logger.debug(f"[WS] Buffer connect rid={rid} path={getattr(request, 'url', '?')} (queues not ready)")
            self._pending_connect.append((rid, request))
        else:
            self.connect_queue({rid: request})
        return future

    def remove_message_handler(self, request_id):
        """Forget connection ``request_id`` and tell the graph to drop it.

        A no-op if the connection is unknown (already removed, or cleared by ``shutdown``). If the queues are
        not set, buffered events for the connection are discarded so a later flush does not resurrect it.
        """
        if request_id not in self.message_handlers:
            return
        del self.message_handlers[request_id]
        self.requests.pop(request_id, None)
        if self.connect_queue is None:
            self._pending_connect[:] = [item for item in self._pending_connect if item[0] != request_id]
        else:
            self.connect_queue({request_id: REMOVE})
        if self.message_queue is None:
            self._pending_messages[:] = [item for item in self._pending_messages if item[0] != request_id]
        else:
            self.message_queue({request_id: REMOVE})

    def enqueue_message_for(self, rid: int, msg: object):
        """Push an incoming client message for connection ``rid`` into the graph (buffered if not ready)."""
        if self.message_queue is None:
            logger.debug(f"[WS] Buffer message rid={rid} len={len(msg) if hasattr(msg,'__len__') else 'n/a'} (queue not ready)")
            self._pending_messages.append((rid, msg))
        else:
            self.message_queue({rid: msg})

    def complete_request(self, request_id, response):
        """Apply one graph response tick for connection ``request_id`` (the engine id).

        A ``connect_response`` tick (``True`` or ``False``) resolves the pending connect future once;
        later ticks are ignored. A ``message`` tick is written to the client if the connection is known.
        """
        connect_response = response.get("connect_response")
        if connect_response is not None:
            future = self.requests.get(request_id)
            # Resolve rejections too, otherwise ``WebSocketHandler.open`` would wait forever
            if future is not None and not future.done():
                future.set_result(
                    (
                        connect_response,
                        lambda m, rid=request_id: self.enqueue_message_for(rid, m),
                        lambda rid=request_id: self.remove_message_handler(rid),
                    )
                )
                logger.info(f"[WS] Completed websocket open request rid={request_id} with response keys={list(response.keys())}")
        message = response.get("message")
        if message is not None:
            handler = self.message_handlers.get(request_id)
            if handler is not None:
                handler(message)


class WebSocketHandler(tornado.websocket.WebSocketHandler):
    """Tornado handler for one websocket connection, relaying it to the graph through the manager."""

    def initialize(self, path, binary, mgr: WebSocketAdaptorManager):
        self.path = path
        self.mgr = mgr
        self.binary = binary

    async def prepare(self):
        # Delegate request preparation to BaseHandler.prepare
        await BaseHandler.prepare(self)

    async def open(self, *args):
        """Announce the connection to the graph and wait for it to accept or reject the connection."""
        request_obj = object()
        request_id = id(request_obj)
        cookies = frozendict(
            {
                k: frozendict({"value": v.value, **{p: w for p, w in v.items()}})
                for k, v in self.request.cookies.items()
            }
        )
        response, enqueue, close = await self.mgr.add_request(
            request_id,
            WebSocketConnectRequest(
                url=self.path,
                url_parsed_args=args,
                headers=self.request.headers,
                cookies=cookies,
                auth=getattr(self, "current_user", None),
            ),
            lambda m: self.write_message(m, binary=self.binary),
        )
        if response:
            self.enqueue_message = enqueue
            # From now on, closing this connection unregisters it from the graph
            self.close = close
        else:
            # Rejected: drop the graph-side request, then close the socket
            close()
            self.close()

    def on_message(self, message):
        # Normalise the payload to the manager's message type before handing it to the graph
        if self.binary:
            self.enqueue_message(message if type(message) is bytes else message.encode())
        else:
            self.enqueue_message(message if type(message) is str else message.decode())

    def on_close(self):
        self.close()
def websocket_server_handler(fn: Callable = None, *, url: str):
    """Register ``fn`` as the graph serving websocket connections on ``url``.

    Usable as ``@websocket_server_handler(url=...)`` or called directly. ``fn`` (wrapped with ``graph`` if it is
    not already a wiring node) must take a ``request`` input of type ``TSB[WebSocketServerRequest[T]]`` (one
    instance per connection) or ``TSD[int, TSB[WebSocketServerRequest[T]]]`` (all connections), and produce the
    matching ``TSB[WebSocketResponse[T]]`` / ``TSD[int, TSB[WebSocketResponse[T]]]`` with the same ``T``
    (``str`` or ``bytes``). These checks are made with ``assert``.

    The handler is recorded with the manager for ``T`` so the adaptor implementation can wire it
    automatically. Wiring the returned graph explicitly disables that auto-wiring for ``url``.
    """
    if fn is None:
        # Decorator-with-arguments form: hand back the actual decorator
        return lambda fn: websocket_server_handler(fn, url=url)

    from hgraph import WiringNodeClass

    if not isinstance(fn, WiringNodeClass):
        fn = graph(fn)

    ts_inputs = fn.signature.time_series_inputs
    assert "request" in ts_inputs.keys(), "Websocket graph must have an input named 'request'"
    request_tp = ts_inputs["request"]

    # Request side: a per-connection bundle or a keyed dict of them, carrying str or bytes
    single_request_type = HgTypeMetaData.parse_type(TSB[WebSocketServerRequest[STR_OR_BYTES]])
    multi_request_type = HgTypeMetaData.parse_type(TSD[int, TSB[WebSocketServerRequest[STR_OR_BYTES]]])
    if single_request_type.matches(request_tp):
        is_single = True
        request_resolution = {}
        single_request_type.build_resolution_dict(request_resolution, request_tp)
        assert (
            STR_OR_BYTES in request_resolution
        ), f"STR_OR_BYTES is expected in the resolution of the request type, but got {request_resolution.keys()}"
        is_binary = request_resolution[STR_OR_BYTES].matches_type(bytes)
    elif multi_request_type.matches(request_tp):
        is_single = False
        request_resolution = {}
        multi_request_type.build_resolution_dict(request_resolution, request_tp)
        is_binary = request_resolution[STR_OR_BYTES].matches_type(bytes)
    else:
        assert False, (
            "WebSocket graph must have a single input named 'request' of type TSB[WebSocketServerRequest] or TSD[int,"
            " TSB[WebSocketServerRequest]]"
        )

    # Response side must mirror the request side in shape and payload type
    output_type = fn.signature.output_type
    single_response_type = HgTypeMetaData.parse_type(TSB[WebSocketResponse[STR_OR_BYTES]])
    multi_response_type = HgTypeMetaData.parse_type(TSD[int, TSB[WebSocketResponse[STR_OR_BYTES]]])
    if single_response_type.matches(output_type):
        response_resolution = {}
        single_response_type.build_resolution_dict(response_resolution, output_type)
        assert is_binary == response_resolution[STR_OR_BYTES].matches_type(bytes), (
            "WebSocket graph must have a single output of type TSB[WebSocketResponse] with the same str/binary type as"
            " the input"
        )
        assert is_single, (
            "WebSocket graph must have a single output of type TSB[WebSocketResponse] when the input is"
            " TSB[WebSocketServerRequest]"
        )
    elif multi_response_type.matches(output_type):
        response_resolution = {}
        multi_response_type.build_resolution_dict(response_resolution, output_type)
        assert is_binary == response_resolution[STR_OR_BYTES].matches_type(bytes), (
            "WebSocket graph must have a single output of type TSD[int, TSB[WebSocketResponse]] with the same"
            " str/binary type as the input"
        )
        assert not is_single, (
            "WebSocket graph must have a single output of type TSD[int, TSB[WebSocketResponse]] when the input is"
            " TSD[int, TSB[WebSocketServerRequest]]"
        )
    else:
        assert False, (
            "WebSocket graph must have a single output of type TSB[WebSocketResponse] or TSD[int,"
            " TSB[WebSocketResponse]]"
        )

    payload_tp = bytes if is_binary else str
    manager = WebSocketAdaptorManager.instance(payload_tp)
    # Registering the node here lets the adaptor implementation auto-wire it
    manager.add_handler(url, fn)

    @graph
    def websocket_server_handler_graph(**inputs: TSB[TS_SCHEMA]) -> TIME_SERIES_TYPE:
        # Explicit wiring wins: clear the auto-wiring entry for this url
        manager.add_handler(url, None)
        requests = websocket_server_adaptor[STR_OR_BYTES:payload_tp].to_graph(path=url, __no_ts_inputs__=True)

        per_connection = fn.signature.time_series_inputs["request"].matches_type(
            TSB[WebSocketServerRequest[payload_tp]]
        )
        if not per_connection:
            responses = fn(request=requests, **inputs)
        elif inputs.as_dict():
            responses = map_(lambda r, i: fn(request=r, **i.as_dict()), requests, inputs)
        else:
            responses = map_(lambda r: fn(request=r), requests)

        out_tp = responses.output_type
        bundles_response = isinstance(out_tp, HgTSBTypeMetaData) and "response" in out_tp.bundle_schema_tp.meta_data_schema
        if not bundles_response:
            websocket_server_adaptor[STR_OR_BYTES:payload_tp].from_graph(responses, path=url)
            return combine()
        # The graph returns a bundle with a ``response`` field: route that field and expose the rest
        websocket_server_adaptor[STR_OR_BYTES:payload_tp].from_graph(responses.response, path=url)
        return responses

    return websocket_server_handler_graph
@adaptor
def websocket_server_adaptor(
    response: TSD[int, TSB[WebSocketResponse[STR_OR_BYTES]]], path: str
) -> TSD[int, TSB[WebSocketServerRequest[STR_OR_BYTES]]]:
    """Adaptor interface for a websocket endpoint at ``path``.

    Takes the per-connection responses keyed by request id and yields the per-connection requests
    (connect request plus incoming messages) under the same ids. The implementation is supplied
    through ``adaptor_impl`` registrations.
    """
@adaptor_impl(interfaces=(websocket_server_adaptor, websocket_server_adaptor))
def websocket_server_adaptor_helper(path: str, port: int):
    """Implement ``websocket_server_adaptor`` by registering the shared server implementation on ``port``.

    ``path`` is part of the adaptor signature but is not needed here; per-url routing happens in
    ``websocket_server_adaptor_impl``.
    """
    register_service("websocket_server_adaptor", websocket_server_adaptor_impl, port=port)
@adaptor_impl(interfaces=())
def websocket_server_adaptor_impl(path: str, port: int):
    """Wire the Tornado websocket server for all registered handlers and adaptor clients.

    Message types are handled in turn: ``str`` first, then ``bytes``. A type is skipped unless its manager
    has registered handlers. For each remaining type, this function:

    - builds the request stream from the web (``from_web``) and partitions it by url;
    - wires auto-registered handler graphs and explicit ``websocket_server_adaptor`` clients onto their
      partitions;
    - sends the merged responses back through ``to_web``.

    :raises ValueError: for an invalid handler registration or a duplicate adaptor client.
    """
    from hgraph import WiringGraphContext, WiringNodeClass

    @push_queue(TSD[int, TS[WebSocketConnectRequest]])
    def connections_from_web(
        sender, path: str = "tornado_websocket_server_adaptor", elide: bool = True
    ) -> TSD[int, TS[WebSocketConnectRequest]]:
        # Publish the sender so ``to_web_start`` can hand it to the manager
        GlobalState.instance()[f"websocket_server_adaptor://{path}/connect_queue"] = sender
        return None

    @push_queue(TSD[int, TS[tuple[STR_OR_BYTES, ...]]])
    def messages_from_web(
        sender, path: str = "tornado_websocket_server_adaptor", batch: bool = True
    ) -> TSD[int, TS[tuple[STR_OR_BYTES, ...]]]:
        # Annotation kept generic to agree with the push_queue type (it is resolved per message type)
        GlobalState.instance()[f"websocket_server_adaptor://{path}/message_queue"] = sender
        return None

    @graph
    def from_web(
        path: str, _tp: Type[STR_OR_BYTES] = AUTO_RESOLVE
    ) -> TSD[int, TSB[WebSocketServerRequest[STR_OR_BYTES]]]:
        # Queue keys are suffixed with the message type, e.g. "<path>[str]", so str and bytes do not collide
        path = f"{path}[{_tp.__name__.lower()}]"
        requests = connections_from_web(path=path)
        messages = messages_from_web[STR_OR_BYTES:_tp](path=path)
        return map_(
            lambda r, m: combine[TSB[WebSocketServerRequest[_tp]]](connect_request=r, messages=m), requests, messages
        )

    @sink_node
    def to_web(
        responses: TSD[int, TSB[WebSocketResponse[STR_OR_BYTES]]],
        port: int,
        path: str = "tornado_websocket_server_adaptor",
        _tp: Type[STR_OR_BYTES] = AUTO_RESOLVE,
        _state: STATE = None,
    ):
        # Hand each modified response to the Tornado loop; the manager completes it there
        for response_id, response in responses.modified_items():
            TornadoWeb.get_loop().add_callback(_state.mgr.complete_request, response_id, response.delta_value)

    @to_web.start
    def to_web_start(port: int, path: str, _tp: Type[STR_OR_BYTES] = AUTO_RESOLVE, _state: STATE = None):
        _state.mgr = WebSocketAdaptorManager.instance(_tp)
        path = f"{path}[{_tp.__name__.lower()}]"
        _state.mgr.set_queues(
            connect_queue=GlobalState.instance()[f"websocket_server_adaptor://{path}/connect_queue"],
            message_queue=GlobalState.instance()[f"websocket_server_adaptor://{path}/message_queue"],
        )
        _state.mgr.start(port)

    @to_web.stop
    def to_web_stop(path: str, _tp: Type[STR_OR_BYTES] = AUTO_RESOLVE, _state: STATE = None):
        # Clean up pending futures/queues and remove global keys before stopping the server
        try:
            typed_path = f"{path}[{_tp.__name__.lower()}]"
            _state.mgr.shutdown(typed_path)
        finally:
            _state.mgr.stop()

    adaptors_dedup = set()
    for msg_type in (str, bytes):
        mgr = WebSocketAdaptorManager.instance(msg_type)
        if not mgr.handlers:
            continue

        requests = from_web[STR_OR_BYTES:msg_type](path=path)
        requests_by_url = partition(requests, requests.connect_request.url)
        single_request_tp = HgTypeMetaData.parse_type(TSB[WebSocketServerRequest])
        multi_request_tp = HgTypeMetaData.parse_type(TSD[int, TSB[WebSocketServerRequest]])

        responses = {}
        # Snapshot: wiring a handler may register further urls on this manager
        for url, handler in list(mgr.handlers.items()):
            if isinstance(handler, WiringNodeClass):
                logger.info("Adding WS handler: [%s] %s", url, handler.signature.signature)
                request_tp = handler.signature.time_series_inputs["request"]
                if single_request_tp.matches(request_tp):
                    responses[url] = map_(handler, request=requests_by_url[url])
                elif multi_request_tp.matches(request_tp):
                    responses[url] = handler(request=requests_by_url[url])
                else:
                    # Previously such handlers were silently left unwired, leaving their clients hanging
                    raise ValueError(
                        f"Invalid 'request' input type for the websocket_ adaptor handler at {url}: {request_tp}"
                    )
            elif handler is None:
                logger.info("Pre-wired WS handler: [%s]", url)
            else:
                raise ValueError(f"Invalid handler type for the websocket_ adaptor: {handler}")

        adaptors = set()
        for handler_path, type_map, _node, receive in WiringGraphContext.__stack__[0].registered_service_clients(
            websocket_server_adaptor
        ):
            logger.info(
                f"Adding WS adaptor: {handler_path} for type {type_map[STR_OR_BYTES]} when msg_type is {msg_type}"
            )
            if type_map[STR_OR_BYTES].py_type != msg_type:
                continue
            if (handler_path, receive) in adaptors_dedup:
                raise ValueError(
                    f"Duplicate websocket_ adaptor client for handler_path {handler_path}: only one client is"
                    " allowed"
                )
            adaptors_dedup.add((handler_path, receive))
            adaptors.add(handler_path.replace("/from_graph", "").replace("/to_graph", ""))

        for handler_path in adaptors:
            url = websocket_server_adaptor.path_from_full_path(handler_path)
            mgr.add_handler(url, None)  # prevent auto-wiring the handler
            responses[url] = (
                websocket_server_adaptor[STR_OR_BYTES:msg_type].wire_impl_inputs_stub(handler_path).response
            )
            websocket_server_adaptor[STR_OR_BYTES:msg_type].wire_impl_out_stub(handler_path, requests_by_url[url])

        to_web(merge(*responses.values(), disjoint=True), port, path=path)