165 lines
5.8 KiB
Python
165 lines
5.8 KiB
Python
from __future__ import annotations
|
||
import json
|
||
import loguru
|
||
import asyncio
|
||
from weakref import WeakValueDictionary, WeakSet, finalize
|
||
from collections import UserDict
|
||
from uuid import uuid4
|
||
from fastapi import WebSocket, FastAPI, WebSocketDisconnect, WebSocketException
|
||
from .event import call_event
|
||
|
||
class TopicWeakDict(UserDict):
    """Dictionary that drops a key once the value stored under it is finalized.

    Analogue of ``WeakValueDictionary`` that keeps *strong* references to its
    values, so that values which are themselves weak containers (``WeakSet``)
    can be stored. Every stored value must support weak references, because a
    ``weakref.finalize`` callback is attached to it on assignment.
    """

    def __init__(self, initial_data=None):
        # Must exist before UserDict.__init__ runs: the initial items are
        # routed through __setitem__, which records a token per key.
        self._tokens = {}
        super().__init__(initial_data or {})

    def __setitem__(self, key, value):
        """Store *value* under *key* and arrange for the key's removal when
        *value* is finalized.

        Raises:
            TypeError: if *value* cannot be weakly referenced.
        """
        # A fresh token identifies this particular assignment. Without it the
        # finalizer of a value that was overwritten (or deleted) would remove
        # whatever is stored under the key *now*.
        token = object()
        finalize(value, self._drop_if_current, key, token)
        self._tokens[key] = token
        super().__setitem__(key, value)

    def _drop_if_current(self, key, token):
        """Remove *key*, but only if *token* still belongs to its latest assignment."""
        if self._tokens.get(key) is token:
            del self._tokens[key]
            # The entry may already have been deleted explicitly.
            self.data.pop(key, None)
|
||
|
||
class Server:
|
||
_logger: loguru.Logger = loguru.logger
|
||
_static_tokens: list[str] = []
|
||
_temp_tokens: list[str] = []
|
||
_requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary()
|
||
_topics = TopicWeakDict()
|
||
_topic_lock = asyncio.Lock()
|
||
|
||
@classmethod
|
||
def init(cls, *, app: FastAPI, static_tokens: list[str] = []):
|
||
cls._static_tokens = static_tokens
|
||
|
||
cls._requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary()
|
||
cls._register_routes(app)
|
||
|
||
@classmethod
|
||
def publish(cls, topic: str, message) -> asyncio.Future:
|
||
if topic not in cls._topics:
|
||
raise KeyError('Клиентов прослушивающих этот топик не существует')
|
||
|
||
request, data = cls._make_request()
|
||
data['data'] = message
|
||
data = json.dumps(data)
|
||
# Меняем синтаксис под Squirrel
|
||
data = data.replace("'", '\\"').replace('True', 'true').replace('False', 'false')
|
||
|
||
asyncio.create_task(cls._send_to_topic(topic, data))
|
||
return request
|
||
|
||
@classmethod
|
||
async def _send_to_topic(cls, topic, data):
|
||
for connection in cls._topics[topic]:
|
||
await connection.send_text(data)
|
||
|
||
@classmethod
|
||
async def send(cls, connection: WebSocket, message, uuid: str):
|
||
data = {
|
||
'uuid': uuid,
|
||
'data': message,
|
||
}
|
||
data = json.dumps(data)
|
||
# Меняем синтаксис под Squirrel
|
||
data = data.replace("'", '\\"').replace('True', 'true').replace('False', 'false')
|
||
await connection.send_text(data)
|
||
|
||
@classmethod
|
||
def _make_request(cls):
|
||
request_id = str(uuid4())
|
||
request = asyncio.Future()
|
||
cls._requests[request_id] = request
|
||
|
||
data = {
|
||
'uuid': request_id,
|
||
'data': None,
|
||
}
|
||
|
||
return request, data
|
||
|
||
@classmethod
|
||
def _register_routes(cls, app):
|
||
@app.websocket('/pyg2o')
|
||
async def pyg2o(websocket: WebSocket, token: str, topics: str | None = None):
|
||
await cls._handle_connection(websocket, token, topics)
|
||
|
||
_ = pyg2o
|
||
|
||
@classmethod
|
||
async def _subscribe(cls, topic_list: list[str], connection: WebSocket):
|
||
async with cls._topic_lock:
|
||
for topic in topic_list:
|
||
if topic not in cls._topics:
|
||
cls._topics[topic] = WeakSet()
|
||
cls._topics[topic].add(connection)
|
||
|
||
@classmethod
|
||
async def _unsubscribe(cls, topic_list: list[str], connection: WebSocket):
|
||
async with cls._topic_lock:
|
||
for topic in topic_list:
|
||
cls._topics[topic].discard(connection)
|
||
|
||
@classmethod
|
||
async def _handle_connection(cls, connection: WebSocket, token: str, topics: str | None):
|
||
|
||
if not await cls._process_query_params(connection, token, topics):
|
||
await connection.close()
|
||
return
|
||
|
||
await connection.accept()
|
||
cls._logger.info('WebSocket клиент подключился')
|
||
|
||
try:
|
||
while True:
|
||
try:
|
||
data = await connection.receive_text()
|
||
message_data = json.loads(data)
|
||
asyncio.create_task(cls._process_message(connection, message_data))
|
||
except json.JSONDecodeError as e:
|
||
cls._logger.exception(f'Ошибка декодирования JSON: {e}')
|
||
except WebSocketDisconnect:
|
||
cls._logger.info('WebSocket клиент отключился')
|
||
except WebSocketException as e:
|
||
cls._logger.exception(f'Ошибка WebSocket подключения: {e}')
|
||
|
||
@classmethod
|
||
async def _process_query_params(cls, connection: WebSocket, token: str, topics: str | None) -> bool:
|
||
|
||
if token not in cls._static_tokens and token not in cls._temp_tokens:
|
||
return False
|
||
|
||
if topics is not None:
|
||
topic_list = [s.strip() for s in topics.split(',')]
|
||
await cls._subscribe(topic_list, connection)
|
||
|
||
return True
|
||
|
||
@classmethod
|
||
async def _process_message(cls, connection: WebSocket, message: dict):
|
||
match message:
|
||
|
||
case {'event': event, **kwargs}:
|
||
try:
|
||
cls._requests[kwargs['uuid']].set_result(kwargs.get('data'))
|
||
except KeyError:
|
||
uuid = kwargs.get('uuid')
|
||
if uuid is not None:
|
||
del kwargs['uuid']
|
||
asyncio.create_task(call_event(event, connection, uuid, **kwargs))
|
||
|
||
case {'subscribe': topics}:
|
||
await cls._subscribe(topics, connection)
|
||
|
||
case {'unsubscribe': topics}:
|
||
await cls._unsubscribe(topics, connection)
|
||
|
||
case {'create_temp_token': token}:
|
||
cls._temp_tokens.append(token)
|
||
|
||
case {'remove_temp_token': token}:
|
||
cls._temp_tokens.remove(token)
|
||
|
||
case {'init_temp_tokens': tokens}:
|
||
cls._temp_tokens = cls._temp_tokens + list(tokens.items())
|
||
|
||
case _:
|
||
raise ValueError(f'Неподдерживаемый тип PyG2O сообщения: {message}')
|