160 lines
5.7 KiB
Python
160 lines
5.7 KiB
Python
from __future__ import annotations
|
||
import json
|
||
import loguru
|
||
import asyncio
|
||
from weakref import WeakValueDictionary, WeakSet, finalize
|
||
from collections import UserDict
|
||
from uuid import uuid4
|
||
from fastapi import WebSocket, FastAPI, WebSocketDisconnect, WebSocketException
|
||
from .event import call_event
|
||
|
||
class TopicWeakDict(UserDict):
    """Dictionary with strongly referenced values and a finalizer per value.

    Counterpart of ``WeakValueDictionary`` for values such as ``WeakSet``
    that have to be referenced strongly: every stored value gets a
    ``weakref.finalize`` hook that drops its key once the value is garbage
    collected. Values must therefore support weak references.

    NOTE(review): the mapping itself keeps each value alive, so a hook can
    only fire after its value has already left the mapping. Hooks of
    replaced or deleted values are detached so that they can never evict a
    newer value stored under the same key.
    """

    def __init__(self, initial_data=None):
        # Created first: UserDict.__init__ feeds the initial items through
        # __setitem__, which records a finalizer for each of them.
        self._finalizers = {}
        super().__init__(initial_data or {})

    def __setitem__(self, key, value):
        """Store *value* under *key* and register its clean-up hook."""
        # Local import: the module-level import only brings in ``finalize``.
        from weakref import ref

        # The hook gets a weak reference to the mapping, so the global
        # finalizer registry does not keep the mapping alive. It is created
        # before anything is changed, so a value that cannot be weakly
        # referenced leaves the mapping untouched.
        hook = finalize(value, self._evict, ref(self), key)
        self._detach(key)
        self._finalizers[key] = hook
        super().__setitem__(key, value)

    def __delitem__(self, key):
        """Remove *key* and cancel the hook registered for its value."""
        super().__delitem__(key)
        self._detach(key)

    def _detach(self, key):
        """Cancel the finalizer recorded for *key*, if there is one."""
        stale = self._finalizers.pop(key, None)
        if stale is not None:
            stale.detach()

    @staticmethod
    def _evict(owner_ref, key):
        """Finalizer callback: drop *key* if the mapping still exists."""
        owner = owner_ref()
        if owner is not None:
            owner._finalizers.pop(key, None)
            # A missing key is fine here; it must not raise in a finalizer.
            owner.data.pop(key, None)
|
||
|
||
class Server:
    """WebSocket server for PyG2O clients; all state is kept at class level."""

    # Shared loguru logger used by the handlers.
    _logger: loguru.Logger = loguru.logger
    # Tokens that are always accepted; replaced by init().
    _static_tokens: list[str] = []
    # Tokens added and removed at runtime through client messages.
    _temp_tokens: list[str] = []
    # Pending requests by uuid. Values are weak references, so an entry
    # disappears once nothing else references its future.
    _requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary()
    # Topic name -> WeakSet of subscribed connections.
    _topics = TopicWeakDict()
    # Held while subscriptions are added or removed.
    _topic_lock = asyncio.Lock()
|
||
|
||
@classmethod
def init(cls, *, app: FastAPI, static_tokens: list[str] | None = None):
    """Configure the server and register its WebSocket route on *app*.

    Args:
        app: Application passed on to ``_register_routes``.
        static_tokens: Tokens that are always accepted. The list is stored
            by reference; when omitted a fresh empty list is installed.
    """
    # A literal ``[]`` default would be a single list shared by every call
    # and by the class attribute, so the default is created per call.
    cls._static_tokens = [] if static_tokens is None else static_tokens

    # Start from an empty registry of pending requests.
    cls._requests = WeakValueDictionary()
    cls._register_routes(app)
|
||
|
||
@classmethod
async def publish(cls, topic: str, message: str) -> asyncio.Future:
    """Send *message* to every connection subscribed to *topic*.

    Args:
        topic: Name of the topic to broadcast to.
        message: Payload placed in the ``data`` field of the request.

    Returns:
        The future created for this request by ``_make_request``.

    Raises:
        KeyError: If *topic* is not present in the topic registry.
    """
    if topic not in cls._topics:
        raise KeyError('Клиентов прослушивающих этот топик не существует')

    request, payload = cls._make_request()
    payload['data'] = message
    text = json.dumps(payload)
    # Adapt the syntax for Squirrel.
    text = text.replace("'", '\\"').replace('True', 'true').replace('False', 'false')

    # Iterate over a snapshot: every send suspends this coroutine, and a
    # subscription change made in the meantime would otherwise abort the
    # loop with "Set changed size during iteration".
    for connection in list(cls._topics[topic]):
        await connection.send_text(text)

    return request
|
||
|
||
@classmethod
async def send(cls, connection: WebSocket, message: str, uuid: str):
    """Send *message* to a single connection, tagged with *uuid*.

    Args:
        connection: Connection whose ``send_text`` receives the text.
        message: Value placed in the ``data`` field.
        uuid: Value placed in the ``uuid`` field.
    """
    envelope = json.dumps({'uuid': uuid, 'data': message})
    # Adapt the syntax for Squirrel (replacements applied in this order).
    for old, new in (("'", '\\"'), ('True', 'true'), ('False', 'false')):
        envelope = envelope.replace(old, new)
    await connection.send_text(envelope)
|
||
|
||
@classmethod
def _make_request(cls):
    """Create and register a new pending request.

    Returns:
        A ``(future, payload)`` pair: the future stored in ``_requests``
        under a freshly generated UUID string, and a payload dict that
        carries the same UUID together with an empty ``data`` field.
    """
    pending = asyncio.Future()
    key = str(uuid4())
    cls._requests[key] = pending
    return pending, {'uuid': key, 'data': None}
|
||
|
||
@classmethod
def _register_routes(cls, app):
    """Register the ``/pyg2o`` WebSocket endpoint on *app*.

    Args:
        app: Object providing the ``websocket(path)`` route decorator.
    """
    @app.websocket('/pyg2o')
    async def pyg2o(websocket: WebSocket, token: str, topics: str | None = None):
        # ``topics`` needs an explicit default: FastAPI treats a query
        # parameter without one as required, even when it is typed as
        # ``str | None``, while the handler already accepts ``None``.
        await cls._handle_connection(websocket, token, topics)

    # Presumably kept to silence "unused function" warnings.
    _ = pyg2o
|
||
|
||
@classmethod
async def _subscribe(cls, topic_list: list[str], connection: WebSocket):
    """Add *connection* to every topic in *topic_list*.

    A topic that is not registered yet is created with an empty
    ``WeakSet``. The whole update runs while ``_topic_lock`` is held.
    """
    async with cls._topic_lock:
        for name in topic_list:
            try:
                listeners = cls._topics[name]
            except KeyError:
                listeners = cls._topics[name] = WeakSet()
            listeners.add(connection)
|
||
|
||
@classmethod
async def _unsubscribe(cls, topic_list: list[str], connection: WebSocket):
    """Remove *connection* from every topic in *topic_list*.

    Topics that are not registered are skipped, so a request naming an
    unknown topic does not raise ``KeyError``. The whole update runs while
    ``_topic_lock`` is held.
    """
    async with cls._topic_lock:
        for topic in topic_list:
            subscribers = cls._topics.get(topic)
            if subscribers is not None:
                subscribers.discard(connection)
|
||
|
||
@classmethod
async def _handle_connection(cls, connection: WebSocket, token: str, topics: str | None):
    """Authorize a WebSocket connection and dispatch its incoming messages.

    The connection is closed without being accepted when
    ``_process_query_params`` rejects it. Otherwise every received text
    frame is decoded as JSON and handed to ``_process_message`` in its own
    task until the client disconnects. Frames that are not valid JSON are
    logged and skipped.

    Args:
        connection: The WebSocket being served.
        token: Token supplied by the client.
        topics: Comma-separated topic names to subscribe to, or ``None``.
    """
    if not await cls._process_query_params(connection, token, topics):
        # Logged for rejected connections only, and without the list of
        # valid static tokens: secrets must not end up in the log.
        cls._logger.warning(f'PyG2O соединение отклонено: получен токен {token}')
        await connection.close()
        return

    await connection.accept()
    cls._logger.info('WebSocket клиент подключился')

    # asyncio keeps only weak references to tasks, so they are held here
    # while this connection is being served.
    running: set[asyncio.Task] = set()

    def _finished(task: asyncio.Task) -> None:
        # Retrieve the outcome so that a failure is logged instead of being
        # reported as "Task exception was never retrieved".
        running.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            cls._logger.error(f'Ошибка обработки PyG2O сообщения: {error!r}')

    try:
        while True:
            try:
                data = await connection.receive_text()
                message_data = json.loads(data)
            except json.JSONDecodeError as e:
                cls._logger.exception(f'Ошибка декодирования JSON: {e}')
                continue

            task = asyncio.create_task(cls._process_message(connection, message_data))
            running.add(task)
            task.add_done_callback(_finished)
    except WebSocketDisconnect:
        cls._logger.info('WebSocket клиент отключился')
    except WebSocketException as e:
        cls._logger.exception(f'Ошибка WebSocket подключения: {e}')
|
||
|
||
@classmethod
async def _process_query_params(cls, connection: WebSocket, token: str, topics: str | None) -> bool:
    """Check the connection token and apply the requested subscriptions.

    Args:
        connection: Connection to subscribe.
        token: Token supplied by the client.
        topics: Comma-separated topic names, or ``None`` for no topics.

    Returns:
        ``False`` when *token* is neither a static nor a temporary token,
        ``True`` otherwise.
    """
    if token not in cls._static_tokens and token not in cls._temp_tokens:
        return False

    if topics is not None:
        # Blank entries are dropped so that an empty value or a stray comma
        # ("a,,b", "a,") does not register a topic with an empty name.
        stripped = (part.strip() for part in topics.split(','))
        topic_list = [name for name in stripped if name]
        if topic_list:
            await cls._subscribe(topic_list, connection)

    return True
|
||
|
||
@classmethod
|
||
async def _process_message(cls, connection: WebSocket, message: dict):
|
||
match message:
|
||
|
||
case {'uuid': id, 'data': data}:
|
||
if id in cls._requests:
|
||
cls._requests[id].set_result(data)
|
||
else:
|
||
asyncio.create_task(call_event('onWebsocketMessage', connection, id, data))
|
||
|
||
case {'event': event, **args}:
|
||
asyncio.create_task(call_event(event, **args))
|
||
|
||
case {'subscribe': topics}:
|
||
await cls._subscribe(topics, connection)
|
||
|
||
case {'unsubscribe': topics}:
|
||
await cls._unsubscribe(topics, connection)
|
||
|
||
case {'create_temp_token': token}:
|
||
cls._temp_tokens.append(token)
|
||
|
||
case {'remove_temp_token': token}:
|
||
cls._temp_tokens.remove(token)
|
||
|
||
case _:
|
||
raise ValueError(f'Неподдерживаемый тип PyG2O сообщения: {message}')
|