Files
PyG2O/src/pyg2o/server.py

188 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
from types import SimpleNamespace
import loguru
import asyncio
from weakref import WeakValueDictionary, WeakSet, finalize
from collections import UserDict
from uuid import uuid4
from fastapi import WebSocket, FastAPI, WebSocketDisconnect, WebSocketException
from .event import call_event
class TopicWeakDict(UserDict):
    """Dict that stores its values strongly, for use as a topic -> subscribers map.

    It plays the role of a ``WeakValueDictionary``, but the values are kept alive
    by the dict itself. This lets them be ``WeakSet`` instances: a ``WeakSet``
    referenced only weakly would be collected immediately.

    Each stored value gets a ``weakref.finalize`` hook that forgets its key once
    that value is collected. The hook is detached as soon as the key is
    overwritten or deleted. A replaced value's late finalizer therefore can
    never evict the newer value stored under the same key. Nor can it fail on a
    key that has already been removed.
    """

    def __init__(self, initial_data=None):
        # key -> finalizer of the value currently stored under that key.
        # Must exist before UserDict.__init__ runs, because it calls
        # __setitem__ (through update()) for every item of initial_data.
        self._finalizers = {}
        super().__init__(initial_data or {})

    def __setitem__(self, key, value):
        # Register the new hook first, so a value that cannot be weakly
        # referenced raises TypeError before any existing state is touched.
        new_finalizer = finalize(value, self._drop_collected, key)
        self._detach(key)
        self._finalizers[key] = new_finalizer
        super().__setitem__(key, value)

    def __delitem__(self, key):
        # pop(), popitem() and clear() from MutableMapping all route through here.
        super().__delitem__(key)
        self._detach(key)

    def copy(self):
        """Return a new instance with the same items and its own finalizers."""
        # UserDict.copy() would shallow-copy __dict__ and share _finalizers.
        return type(self)(self.data)

    def _detach(self, key):
        """Cancel the finalizer registered for *key*, if there is one."""
        previous = self._finalizers.pop(key, None)
        if previous is not None:
            previous.detach()

    def _drop_collected(self, key):
        """Finalizer callback: forget *key* after its value has been collected."""
        self._finalizers.pop(key, None)
        self.data.pop(key, None)
class Server:
_logger: loguru.Logger = loguru.logger
_static_tokens: list[str] = []
_temp_tokens: list[str] = []
_server_connection: WebSocket | None = None
_registered_clients: dict[int, list] = {}
_requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary()
_topics = TopicWeakDict()
_topic_lock = asyncio.Lock()
@classmethod
def init(cls, *, app: FastAPI, static_tokens: list[str] = []):
cls._static_tokens = static_tokens
cls._requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary()
cls._register_routes(app)
@classmethod
async def _async_send(cls, connection_list: list[WebSocket], data):
for connection in connection_list:
await connection.send_text(data)
@classmethod
def send(cls, connection: WebSocket | str | int, event: str, message: dict, uuid: str | None = None):
try:
if isinstance(connection, WebSocket):
connection_list = [connection]
elif isinstance(connection, int):
connection_list = cls._registered_clients[connection]
else:
connection_list = cls._topics[connection]
request, data = cls._make_request(event, uuid)
data.update(message)
except KeyError:
raise KeyError('Нет зарегистрированного клиента с таким ID')
except ValueError:
raise ValueError('message должен быть типа dict')
data = json.dumps(data)
# Меняем синтаксис под Squirrel
data = data.replace("'", '\\"').replace('True', 'true').replace('False', 'false')
asyncio.create_task(cls._async_send(connection_list, data))
return request
@classmethod
def sq_execute(cls, code: str) -> asyncio.Future:
if cls._server_connection is not None:
return cls.send(cls._server_connection, 'sq_execute', {'code': code})
else:
raise ConnectionError('Сервер не подключен к PyG2O')
@classmethod
def _make_request(cls, event: str, uuid: str | None = None):
if uuid is None:
request_id = str(uuid4())
else:
request_id = uuid
request = asyncio.Future()
cls._requests[request_id] = request
data = {
'event': event,
'uuid': request_id,
}
return request, data
@classmethod
def _register_routes(cls, app):
@app.websocket('/pyg2o')
async def pyg2o(websocket: WebSocket, token: str):
await cls._handle_connection(websocket, token)
_ = pyg2o
@classmethod
async def _subscribe(cls, topic_list: list[str], connection: WebSocket):
async with cls._topic_lock:
for topic in topic_list:
if topic not in cls._topics:
cls._topics[topic] = WeakSet()
cls._topics[topic].add(connection)
@classmethod
async def _unsubscribe(cls, topic_list: list[str], connection: WebSocket):
async with cls._topic_lock:
for topic in topic_list:
cls._topics[topic].discard(connection)
@classmethod
async def _handle_connection(cls, connection: WebSocket, token: str):
if token not in cls._static_tokens and token not in cls._temp_tokens:
await connection.close()
return
await connection.accept()
await cls._subscribe(['all'], connection)
cls._logger.info('WebSocket клиент подключился')
try:
while True:
try:
data = await connection.receive_text()
message_data = json.loads(data)
asyncio.create_task(cls._process_message(connection, message_data))
except json.JSONDecodeError as e:
cls._logger.exception(f'Ошибка декодирования JSON: {e}')
except WebSocketDisconnect:
cls._logger.info('WebSocket клиент отключился')
if connection == cls._server_connection:
cls._server_connection = None
else:
playerid = next((key for key, values in cls._registered_clients.items() if connection in values), None)
if playerid is not None: cls._registered_clients[playerid].remove(connection)
except WebSocketException as e:
cls._logger.exception(f'Ошибка WebSocket подключения: {e}')
@classmethod
async def _process_message(cls, connection: WebSocket, message: dict):
match message:
case {'event': 'subscribe', 'topics': topics}:
await cls._subscribe(topics, connection)
case {'event': 'unsubscribe', 'topics': topics}:
await cls._unsubscribe(topics, connection)
case {'event': 'create_temp_token', 'token': token}:
cls._temp_tokens.append(token)
case {'event': 'remove_temp_token', 'token': token}:
cls._temp_tokens.remove(token)
case {'event': 'init_temp_tokens', 'tokens': tokens}:
cls._temp_tokens = cls._temp_tokens + list(tokens.items())
case {'event': 'register_client', 'playerid': playerid}:
try:
cls._registered_clients[playerid].append(connection)
except KeyError:
cls._registered_clients[playerid] = [connection]
case {'event': 'register_server'}:
if cls._server_connection is None:
cls._server_connection = connection
case {'event': 'sq_response', 'uuid': uuid, 'data': data}:
try:
cls._requests[uuid].set_result(data)
except KeyError:
...
case {'event': event, 'uuid': uuid, **kwargs}:
try:
cls._requests[uuid].set_result(SimpleNamespace(**kwargs))
except KeyError:
kwargs['uuid'] = uuid
kwargs['connection'] = connection
playerid = next((key for key, values in cls._registered_clients.items() if connection in values), None)
if playerid is not None: kwargs['playerid'] = playerid
asyncio.create_task(call_event(event, **kwargs))
case _:
raise ValueError(f'Неподдерживаемый тип PyG2O сообщения: {message}')