Source code for hgraph.adaptors.tornado.websocket_client_adaptor

from collections import defaultdict, deque
from typing import Callable, Type

from tornado import httpclient
from tornado.websocket import websocket_connect

from hgraph import (
    AUTO_RESOLVE,
    service_adaptor,
    service_adaptor_impl,
    TSD,
    push_queue,
    GlobalState,
    sink_node,
    STATE,
    TSB,
)
from hgraph.adaptors.tornado._tornado_web import TornadoWeb
from hgraph.adaptors.tornado.websocket_server_adaptor import STR_OR_BYTES, WebSocketClientRequest, WebSocketResponse


@service_adaptor
def websocket_client_adaptor(
    request: TSB[WebSocketClientRequest[STR_OR_BYTES]],
    path: str = "websocket_client",
) -> TSB[WebSocketResponse[STR_OR_BYTES]]:
    """
    Service-adaptor interface for an outbound websocket client connection.

    This is a declaration only (the body is ``...``); the behaviour comes from the
    implementation registered against it with
    ``@service_adaptor_impl(interfaces=websocket_client_adaptor)``.

    :param request: Request bundle. The implementation in this module reads its
        ``connect_request`` and ``message`` fields.
    :param path: Name under which the adaptor is addressed. Defaults to
        ``"websocket_client"``.
    :return: Response bundle. The implementation in this module publishes
        ``connect_response`` and ``message`` values on it.
    """
    ...
@service_adaptor_impl(interfaces=websocket_client_adaptor)
def websocket_client_adaptor_impl(
    request: TSD[int, TSB[WebSocketClientRequest[STR_OR_BYTES]]],
    path: str = "websocket_client",
    _tp: Type[STR_OR_BYTES] = AUTO_RESOLVE,
) -> TSD[int, TSB[WebSocketResponse[STR_OR_BYTES]]]:
    """
    Tornado-backed implementation of ``websocket_client_adaptor``.

    Each key of ``request`` identifies one client connection. A tick on
    ``connect_request`` schedules a websocket connection on the Tornado loop; a tick
    on ``message`` schedules a write on that connection. Results come back into the
    graph through a push queue as ``{key: {"connect_response": bool}}`` and
    ``{key: {"message": payload}}`` entries.

    :param request: Per-connection request bundles, keyed by an integer id.
    :param path: Adaptor path; used to build the ``GlobalState`` key under which the
        push-queue sender is shared between the receiving and sending nodes.
    :param _tp: Resolved payload type. When it is ``bytes`` messages are written as
        binary frames and incoming text is encoded to UTF-8; otherwise incoming
        bytes are decoded from UTF-8.
    :return: Per-connection response bundles, keyed by the same integer id.
    """
    is_binary = _tp is bytes

    @push_queue(TSD[int, TSB[WebSocketResponse[_tp]]])
    def from_web(sender, path: str = "websocket_client") -> TSD[int, TSB[WebSocketResponse[_tp]]]:
        # Publish the sender so that ``to_web`` can hand it to the Tornado callbacks.
        GlobalState.instance()[f"websocket_client_adaptor://{path}/queue"] = sender
        return None

    async def make_websocket_request(
        state: STATE, conn_id: int, client_request: WebSocketClientRequest, sender: Callable
    ):
        """Open the connection, flush any backlog, then pump incoming messages until it closes."""
        try:
            ws = await websocket_connect(
                httpclient.HTTPRequest(client_request.url, headers=client_request.headers),
                ping_interval=1,
                ping_timeout=3,
            )
        except Exception:
            # Any connection failure is reported to the graph as a failed connect
            # instead of being raised inside the Tornado loop.
            sender({conn_id: {"connect_response": False}})
            return

        sender({conn_id: {"connect_response": True}})

        try:
            # Messages requested before the socket existed were parked in the queue.
            # Drain them *before* registering the socket: while we await a write, new
            # messages keep landing in this same deque (the socket is not yet
            # visible), so ordering is preserved. There is no await between the final
            # emptiness check and the registration below.
            backlog = state.queues[conn_id]
            while backlog:
                await ws.write_message(backlog.popleft(), binary=is_binary)
            state.queues.pop(conn_id, None)
            state.sockets[conn_id] = ws

            while True:
                msg = await ws.read_message()
                if msg is None:
                    # ``None`` signals that the connection has been closed.
                    break

                # Normalise the payload to the type the graph was wired with.
                if is_binary and isinstance(msg, str):
                    msg = msg.encode("utf-8")
                elif not is_binary and isinstance(msg, bytes):
                    msg = msg.decode("utf-8")

                sender({conn_id: {"message": msg}})
        finally:
            ws.close()
            # Only remove the registry entry if it is still this connection; a later
            # connect request for the same id may already have replaced it.
            if state.sockets.get(conn_id) is ws:
                del state.sockets[conn_id]
            sender({conn_id: {"connect_response": False}})

    async def send_websocket_message(state: STATE, conn_id: int, message: STR_OR_BYTES):
        """Write ``message`` on the open socket, or park it until the socket is available."""
        ws = state.sockets.get(conn_id)
        if ws:
            await ws.write_message(message, binary=is_binary)
        else:
            state.queues[conn_id].append(message)

    @sink_node
    def to_web(request: TSD[int, TSB[WebSocketClientRequest[STR_OR_BYTES]]], _state: STATE = None):
        sender = GlobalState.instance()[f"websocket_client_adaptor://{path}/queue"]

        for i, r in request.modified_items():
            if r.connect_request.modified:
                TornadoWeb.get_loop().add_callback(
                    make_websocket_request, _state, i, r.connect_request.value, sender
                )
            if r.message.modified:
                TornadoWeb.get_loop().add_callback(send_websocket_message, _state, i, r.message.value)

    @to_web.start
    def to_web_start(_state: STATE):
        TornadoWeb.start_loop()
        _state.queues = defaultdict(deque)  # id -> messages waiting for a socket
        _state.sockets = {}  # id -> open websocket connection

    @to_web.stop
    def to_web_stop():
        TornadoWeb.stop_loop()

    to_web(request)
    # Pass the path through so the sender is stored under the same key that
    # ``to_web`` looks up; relying on the default breaks any non-default path.
    return from_web(path=path)