Files
PyG2O/src/pyg2o/server.py

160 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import loguru
import asyncio
from weakref import WeakValueDictionary, WeakSet, finalize
from collections import UserDict
from uuid import uuid4
from fastapi import WebSocket, FastAPI, WebSocketDisconnect, WebSocketException
from .event import call_event
class TopicWeakDict(UserDict):
    """Mapping of topic name -> subscriber container (used with ``WeakSet`` values).

    Intended as an analogue of ``WeakValueDictionary`` that holds its values
    strongly, so that a freshly created ``WeakSet`` stored here is not
    collected right away.

    Because values are held strongly, a value can never be garbage-collected
    while it is stored, so a per-value ``weakref.finalize`` hook can only fire
    *after* its entry has already been replaced or deleted. Such a hook that
    pops the key would then delete the newer value stored under that key (or
    fail with ``KeyError`` when the key is gone), and registering it rejects
    values that do not support weak references. No hook is registered:
    entries live until they are removed explicitly.
    """

    def __init__(self, initial_data=None):
        """Create the mapping, optionally pre-filled from ``initial_data``."""
        super().__init__(initial_data or {})
class Server:
_logger: loguru.Logger = loguru.logger
_static_tokens: list[str] = []
_temp_tokens: list[str] = []
_requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary()
_topics = TopicWeakDict()
_topic_lock = asyncio.Lock()
@classmethod
def init(cls, *, app: FastAPI, static_tokens: list[str] = []):
cls._static_tokens = static_tokens
cls._requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary()
cls._register_routes(app)
@classmethod
async def publish(cls, topic: str, message: str) -> asyncio.Future:
if topic not in cls._topics:
raise KeyError('Клиентов прослушивающих этот топик не существует')
request, data = cls._make_request()
data['data'] = message
data = json.dumps(data)
# Меняем синтаксис под Squirrel
data = data.replace("'", '\\"').replace('True', 'true').replace('False', 'false')
for connection in cls._topics[topic]:
await connection.send_text(data)
return request
@classmethod
async def send(cls, connection: WebSocket, message: str, uuid: str):
data = {
'uuid': uuid,
'data': message,
}
data = json.dumps(data)
# Меняем синтаксис под Squirrel
data = data.replace("'", '\\"').replace('True', 'true').replace('False', 'false')
await connection.send_text(data)
@classmethod
def _make_request(cls):
request_id = str(uuid4())
request = asyncio.Future()
cls._requests[request_id] = request
data = {
'uuid': request_id,
'data': None,
}
return request, data
@classmethod
def _register_routes(cls, app):
@app.websocket('/pyg2o')
async def pyg2o(websocket: WebSocket, token: str, topics: str | None):
await cls._handle_connection(websocket, token, topics)
_ = pyg2o
@classmethod
async def _subscribe(cls, topic_list: list[str], connection: WebSocket):
async with cls._topic_lock:
for topic in topic_list:
if topic not in cls._topics:
cls._topics[topic] = WeakSet()
cls._topics[topic].add(connection)
@classmethod
async def _unsubscribe(cls, topic_list: list[str], connection: WebSocket):
async with cls._topic_lock:
for topic in topic_list:
cls._topics[topic].discard(connection)
@classmethod
async def _handle_connection(cls, connection: WebSocket, token: str, topics: str | None):
cls._logger.exception(f'PyG2O соединение отклонено: получен токен {token}\nStatic Tokens: {cls._static_tokens}')
if not await cls._process_query_params(connection, token, topics):
await connection.close()
return
await connection.accept()
cls._logger.info('WebSocket клиент подключился')
try:
while True:
try:
data = await connection.receive_text()
message_data = json.loads(data)
asyncio.create_task(cls._process_message(connection, message_data))
except json.JSONDecodeError as e:
cls._logger.exception(f'Ошибка декодирования JSON: {e}')
except WebSocketDisconnect:
cls._logger.info('WebSocket клиент отключился')
except WebSocketException as e:
cls._logger.exception(f'Ошибка WebSocket подключения: {e}')
@classmethod
async def _process_query_params(cls, connection: WebSocket, token: str, topics: str | None) -> bool:
if token not in cls._static_tokens and token not in cls._temp_tokens:
return False
if topics is not None:
topic_list = [s.strip() for s in topics.split(',')]
await cls._subscribe(topic_list, connection)
return True
@classmethod
async def _process_message(cls, connection: WebSocket, message: dict):
match message:
case {'uuid': id, 'data': data}:
if id in cls._requests:
cls._requests[id].set_result(data)
else:
asyncio.create_task(call_event('onWebsocketMessage', connection, id, data))
case {'event': event, **args}:
asyncio.create_task(call_event(event, **args))
case {'subscribe': topics}:
await cls._subscribe(topics, connection)
case {'unsubscribe': topics}:
await cls._unsubscribe(topics, connection)
case {'create_temp_token': token}:
cls._temp_tokens.append(token)
case {'remove_temp_token': token}:
cls._temp_tokens.remove(token)
case _:
raise ValueError(f'Неподдерживаемый тип PyG2O сообщения: {message}')