from __future__ import annotations import json import loguru import asyncio from weakref import WeakValueDictionary, WeakSet, finalize from collections import UserDict from uuid import uuid4 from fastapi import WebSocket, FastAPI, WebSocketDisconnect, WebSocketException from .event import call_event class TopicWeakDict(UserDict): # Аналог WeakValueDict, но с сильными ссылками для возможности поддержки WeakSet def __init__(self, initial_data = None): super().__init__(initial_data or {}) def __setitem__(self, key, value): finalize(value, lambda k=key, s=self: s.pop(k)) super().__setitem__(key, value) class Server: _logger: loguru.Logger = loguru.logger _static_tokens: list[str] = [] _temp_tokens: list[str] = [] _requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary() _topics = TopicWeakDict() _topic_lock = asyncio.Lock() @classmethod def init(cls, *, app: FastAPI, static_tokens: list[str] = []): cls._static_tokens = static_tokens cls._requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary() cls._register_routes(app) @classmethod def publish(cls, topic: str, message) -> asyncio.Future: if topic not in cls._topics: raise KeyError('Клиентов прослушивающих этот топик не существует') request, data = cls._make_request() data['data'] = message data = json.dumps(data) # Меняем синтаксис под Squirrel data = data.replace("'", '\\"').replace('True', 'true').replace('False', 'false') asyncio.create_task(cls._send_to_topic(topic, data)) return request @classmethod async def _send_to_topic(cls, topic, data): for connection in cls._topics[topic]: await connection.send_text(data) @classmethod async def send(cls, connection: WebSocket, message, uuid: str): data = { 'uuid': uuid, 'data': message, } data = json.dumps(data) # Меняем синтаксис под Squirrel data = data.replace("'", '\\"').replace('True', 'true').replace('False', 'false') await connection.send_text(data) @classmethod def _make_request(cls): request_id = str(uuid4()) request = asyncio.Future() cls._requests[request_id] = request data = { 'uuid': request_id, 'data': None, } return request, data @classmethod def _register_routes(cls, app): @app.websocket('/pyg2o') async def pyg2o(websocket: WebSocket, token: str, topics: str | None = None): await cls._handle_connection(websocket, token, topics) _ = pyg2o @classmethod async def _subscribe(cls, topic_list: list[str], connection: WebSocket): async with cls._topic_lock: for topic in topic_list: if topic not in cls._topics: cls._topics[topic] = WeakSet() cls._topics[topic].add(connection) @classmethod async def _unsubscribe(cls, topic_list: list[str], connection: WebSocket): async with cls._topic_lock: for topic in topic_list: cls._topics[topic].discard(connection) @classmethod async def _handle_connection(cls, connection: WebSocket, token: str, topics: str | None): if not await cls._process_query_params(connection, token, topics): await connection.close() return await connection.accept() cls._logger.info('WebSocket клиент подключился') try: while True: try: data = await connection.receive_text() message_data = json.loads(data) asyncio.create_task(cls._process_message(connection, message_data)) except json.JSONDecodeError as e: cls._logger.exception(f'Ошибка декодирования JSON: {e}') except WebSocketDisconnect: cls._logger.info('WebSocket клиент отключился') except WebSocketException as e: cls._logger.exception(f'Ошибка WebSocket подключения: {e}') @classmethod async def _process_query_params(cls, connection: WebSocket, token: str, topics: str | None) -> bool: if token not in cls._static_tokens and token not in cls._temp_tokens: return False if topics is not None: topic_list = [s.strip() for s in topics.split(',')] await cls._subscribe(topic_list, connection) return True @classmethod async def _process_message(cls, connection: WebSocket, message: dict): match message: case {'event': event, **kwargs}: try: cls._requests[kwargs['uuid']].set_result(kwargs.get('data')) except KeyError: uuid = kwargs.get('uuid') if uuid is not None: del kwargs['uuid'] asyncio.create_task(call_event(event, connection, uuid, **kwargs)) case {'subscribe': topics}: await cls._subscribe(topics, connection) case {'unsubscribe': topics}: await cls._unsubscribe(topics, connection) case {'create_temp_token': token}: cls._temp_tokens.append(token) case {'remove_temp_token': token}: cls._temp_tokens.remove(token) case {'init_temp_tokens': tokens}: cls._temp_tokens = cls._temp_tokens + list(tokens.items()) case _: raise ValueError(f'Неподдерживаемый тип PyG2O сообщения: {message}')