188 lines
7.2 KiB
Python
188 lines
7.2 KiB
Python
from __future__ import annotations
|
||
import json
|
||
from types import SimpleNamespace
|
||
import loguru
|
||
import asyncio
|
||
from weakref import WeakValueDictionary, WeakSet, finalize
|
||
from collections import UserDict
|
||
from uuid import uuid4
|
||
from fastapi import WebSocket, FastAPI, WebSocketDisconnect, WebSocketException
|
||
from .event import call_event
|
||
|
||
class TopicWeakDict(UserDict):
    """Mapping of topic name -> subscriber set (normally a ``WeakSet``).

    Unlike ``WeakValueDictionary``, values are held by strong references:
    a freshly created ``WeakSet`` typically has no other owner, so a
    weak-valued mapping would drop it immediately. The subscribers stored
    *inside* each ``WeakSet`` are still weak references.

    Entries are removed only explicitly (``del``/``pop``). No
    ``weakref.finalize``-based auto-removal is attempted: because this
    mapping itself keeps every value alive, such a finalizer could only run
    after the entry had already been replaced or deleted. At that point it
    would pop the *new* value stored under the same key, or raise
    ``KeyError`` for a deleted one.
    """

    def __init__(self, initial_data=None):
        """Create the mapping, optionally pre-filled from *initial_data*."""
        super().__init__(initial_data or {})
|
||
|
||
class Server:
|
||
_logger: loguru.Logger = loguru.logger
|
||
_static_tokens: list[str] = []
|
||
_temp_tokens: list[str] = []
|
||
_server_connection: WebSocket | None = None
|
||
_registered_clients: dict[int, list] = {}
|
||
_requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary()
|
||
_topics = TopicWeakDict()
|
||
_topic_lock = asyncio.Lock()
|
||
|
||
@classmethod
|
||
def init(cls, *, app: FastAPI, static_tokens: list[str] = []):
|
||
cls._static_tokens = static_tokens
|
||
|
||
cls._requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary()
|
||
cls._register_routes(app)
|
||
|
||
@classmethod
|
||
async def _async_send(cls, connection_list: list[WebSocket], data):
|
||
for connection in connection_list:
|
||
await connection.send_text(data)
|
||
|
||
@classmethod
|
||
def send(cls, connection: WebSocket | str | int, event: str, message: dict, uuid: str | None = None):
|
||
try:
|
||
if isinstance(connection, WebSocket):
|
||
connection_list = [connection]
|
||
elif isinstance(connection, int):
|
||
connection_list = cls._registered_clients[connection]
|
||
else:
|
||
connection_list = cls._topics[connection]
|
||
|
||
request, data = cls._make_request(event, uuid)
|
||
data.update(message)
|
||
except KeyError:
|
||
raise KeyError('Нет зарегистрированного клиента с таким ID')
|
||
except ValueError:
|
||
raise ValueError('message должен быть типа dict')
|
||
|
||
data = json.dumps(data)
|
||
# Меняем синтаксис под Squirrel
|
||
data = data.replace("'", '\\"').replace('True', 'true').replace('False', 'false')
|
||
asyncio.create_task(cls._async_send(connection_list, data))
|
||
return request
|
||
|
||
@classmethod
|
||
def sq_execute(cls, code: str) -> asyncio.Future:
|
||
if cls._server_connection is not None:
|
||
return cls.send(cls._server_connection, 'sq_execute', {'code': code})
|
||
else:
|
||
raise ConnectionError('Сервер не подключен к PyG2O')
|
||
|
||
@classmethod
|
||
def _make_request(cls, event: str, uuid: str | None = None):
|
||
if uuid is None:
|
||
request_id = str(uuid4())
|
||
else:
|
||
request_id = uuid
|
||
|
||
request = asyncio.Future()
|
||
cls._requests[request_id] = request
|
||
|
||
data = {
|
||
'event': event,
|
||
'uuid': request_id,
|
||
}
|
||
|
||
return request, data
|
||
|
||
@classmethod
|
||
def _register_routes(cls, app):
|
||
@app.websocket('/pyg2o')
|
||
async def pyg2o(websocket: WebSocket, token: str):
|
||
await cls._handle_connection(websocket, token)
|
||
|
||
_ = pyg2o
|
||
|
||
@classmethod
|
||
async def _subscribe(cls, topic_list: list[str], connection: WebSocket):
|
||
async with cls._topic_lock:
|
||
for topic in topic_list:
|
||
if topic not in cls._topics:
|
||
cls._topics[topic] = WeakSet()
|
||
cls._topics[topic].add(connection)
|
||
|
||
@classmethod
|
||
async def _unsubscribe(cls, topic_list: list[str], connection: WebSocket):
|
||
async with cls._topic_lock:
|
||
for topic in topic_list:
|
||
cls._topics[topic].discard(connection)
|
||
|
||
@classmethod
|
||
async def _handle_connection(cls, connection: WebSocket, token: str):
|
||
|
||
if token not in cls._static_tokens and token not in cls._temp_tokens:
|
||
await connection.close()
|
||
return
|
||
|
||
await connection.accept()
|
||
await cls._subscribe(['all'], connection)
|
||
cls._logger.info('WebSocket клиент подключился')
|
||
|
||
try:
|
||
while True:
|
||
try:
|
||
data = await connection.receive_text()
|
||
message_data = json.loads(data)
|
||
asyncio.create_task(cls._process_message(connection, message_data))
|
||
except json.JSONDecodeError as e:
|
||
cls._logger.exception(f'Ошибка декодирования JSON: {e}')
|
||
except WebSocketDisconnect:
|
||
cls._logger.info('WebSocket клиент отключился')
|
||
if connection == cls._server_connection:
|
||
cls._server_connection = None
|
||
else:
|
||
playerid = next((key for key, values in cls._registered_clients.items() if connection in values), None)
|
||
if playerid is not None: cls._registered_clients[playerid].remove(connection)
|
||
except WebSocketException as e:
|
||
cls._logger.exception(f'Ошибка WebSocket подключения: {e}')
|
||
|
||
@classmethod
|
||
async def _process_message(cls, connection: WebSocket, message: dict):
|
||
match message:
|
||
|
||
case {'event': 'subscribe', 'topics': topics}:
|
||
await cls._subscribe(topics, connection)
|
||
|
||
case {'event': 'unsubscribe', 'topics': topics}:
|
||
await cls._unsubscribe(topics, connection)
|
||
|
||
case {'event': 'create_temp_token', 'token': token}:
|
||
cls._temp_tokens.append(token)
|
||
|
||
case {'event': 'remove_temp_token', 'token': token}:
|
||
cls._temp_tokens.remove(token)
|
||
|
||
case {'event': 'init_temp_tokens', 'tokens': tokens}:
|
||
cls._temp_tokens = cls._temp_tokens + list(tokens.items())
|
||
|
||
case {'event': 'register_client', 'playerid': playerid}:
|
||
try:
|
||
cls._registered_clients[playerid].append(connection)
|
||
except KeyError:
|
||
cls._registered_clients[playerid] = [connection]
|
||
|
||
case {'event': 'register_server'}:
|
||
if cls._server_connection is None:
|
||
cls._server_connection = connection
|
||
|
||
case {'event': 'sq_response', 'uuid': uuid, 'data': data}:
|
||
try:
|
||
cls._requests[uuid].set_result(data)
|
||
except KeyError:
|
||
...
|
||
|
||
case {'event': event, 'uuid': uuid, **kwargs}:
|
||
try:
|
||
cls._requests[uuid].set_result(SimpleNamespace(**kwargs))
|
||
except KeyError:
|
||
kwargs['uuid'] = uuid
|
||
kwargs['connection'] = connection
|
||
playerid = next((key for key, values in cls._registered_clients.items() if connection in values), None)
|
||
if playerid is not None: kwargs['playerid'] = playerid
|
||
asyncio.create_task(call_event(event, **kwargs))
|
||
|
||
case _:
|
||
raise ValueError(f'Неподдерживаемый тип PyG2O сообщения: {message}')
|