Files
PyG2O/src/pyg2o/server.py

275 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import asyncio
from types import SimpleNamespace
from weakref import WeakValueDictionary, WeakSet, finalize
from collections import UserDict
from uuid import uuid4
from loguru import logger
from fastapi import WebSocket, FastAPI, WebSocketDisconnect, WebSocketException
from .event import call_event
class TopicWeakDict(UserDict):
    """Dict whose entries are removed by the ``weakref.finalize`` hook of their value.

    Similar in spirit to ``WeakValueDictionary``, but values are held by strong
    references. This matters for ``WeakSet`` values, which nothing else would
    keep alive. Because the dict itself keeps a stored value alive, a value's
    finalizer fires only after the value has been replaced or removed and then
    garbage-collected, or at interpreter exit.
    """

    def __init__(self, initial_data = None):
        super().__init__(initial_data or {})

    def __setitem__(self, key, value):
        # Pass only id(value): the callback must not keep *value* alive itself.
        finalize(value, self._drop, key, id(value))
        super().__setitem__(key, value)

    def _drop(self, key, value_id):
        """Remove *key* only if it still maps to the value being finalized.

        This guards against deleting a newer value stored under the same key,
        and against a KeyError when the key was already removed.
        """
        if key in self.data and id(self.data[key]) == value_id:
            del self.data[key]
class Server:
_static_tokens: list[str] = []
_temp_tokens: list[str] = []
_server_connection: WebSocket | None = None
_registered_clients: dict[int, list] = {}
_requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary()
_topics = TopicWeakDict()
_topic_lock = asyncio.Lock()
@classmethod
def init(cls, *, app: FastAPI, static_tokens: list[str] = []):
cls._static_tokens = static_tokens
cls._requests: WeakValueDictionary[str, asyncio.Future] = WeakValueDictionary()
cls._register_routes(app)
@classmethod
async def _async_send(cls, connection_list: list[WebSocket], data):
for connection in connection_list:
await connection.send_text(data)
@classmethod
def send(cls, connection: WebSocket | str | int, event: str, message: dict, uuid: str | None = None):
try:
if isinstance(connection, WebSocket):
connection_list = [connection]
elif isinstance(connection, int):
connection_list = cls._registered_clients[connection]
else:
connection_list = cls._topics[connection]
request, data = cls._make_request(event, uuid)
data.update(message)
except KeyError:
raise KeyError('Нет зарегистрированного клиента с таким ID')
except ValueError:
raise ValueError('message должен быть типа dict')
message_uuid = data['uuid']
data = json.dumps(data)
# Меняем синтаксис под Squirrel
data = data.replace("'", '\\"').replace('True', 'true').replace('False', 'false')
logs_connection_list = [f"{item.client.host}:{item.client.port}" for item in connection_list]
logger.info(
'Отправлено новое сообщение по каналу WebSocket',
log_type = 'PyG2O',
receivers = logs_connection_list,
message_uuid = message_uuid,
message_data = data,
)
asyncio.create_task(cls._async_send(connection_list, data))
return request
@classmethod
def sq_execute(cls, code: str) -> asyncio.Future:
if cls._server_connection is not None:
return cls.send(cls._server_connection, 'sq_execute', {'code': code})
else:
raise ConnectionError('Сервер не подключен к PyG2O')
@classmethod
def _make_request(cls, event: str, uuid: str | None = None):
if uuid is None:
request_id = str(uuid4())
else:
request_id = uuid
request = asyncio.Future()
cls._requests[request_id] = request
data = {
'event': event,
'uuid': request_id,
}
return request, data
@classmethod
def _register_routes(cls, app):
@app.websocket('/pyg2o')
async def pyg2o(websocket: WebSocket, token: str):
await cls._handle_connection(websocket, token)
_ = pyg2o
@classmethod
async def _subscribe(cls, topic_list: list[str], connection: WebSocket):
async with cls._topic_lock:
for topic in topic_list:
if topic not in cls._topics:
cls._topics[topic] = WeakSet()
cls._topics[topic].add(connection)
@classmethod
async def _unsubscribe(cls, topic_list: list[str], connection: WebSocket):
async with cls._topic_lock:
for topic in topic_list:
cls._topics[topic].discard(connection)
@classmethod
async def _handle_connection(cls, connection: WebSocket, token: str):
if token not in cls._static_tokens and token not in cls._temp_tokens:
await connection.close()
return
await connection.accept()
await cls._subscribe(['all'], connection)
logger.info(
'WebSocket клиент подключился',
log_type = 'PyG2O',
connection = f"{connection.client.host}:{connection.client.port}",
)
try:
while True:
try:
data = await connection.receive_text()
message_data = json.loads(data)
asyncio.create_task(cls._process_message(connection, message_data))
except json.JSONDecodeError as e:
logger.info(
'Ошибка декодирования JSON сообщения',
log_type = 'PyG2O',
description = e,
message_data = data,
connection = f"{connection.client.host}:{connection.client.port}",
)
except WebSocketDisconnect:
if connection == cls._server_connection:
cls._server_connection = None
logger.info(
'WebSocket G2O сервер отключился',
log_type = 'PyG2O',
connection = f"{connection.client.host}:{connection.client.port}",
)
else:
playerid = next((key for key, values in cls._registered_clients.items() if connection in values), None)
if playerid is not None: cls._registered_clients[playerid].remove(connection)
logger.info(
'WebSocket клиент отключился',
log_type = 'PyG2O',
connection = f"{connection.client.host}:{connection.client.port}",
playerid = playerid,
)
except WebSocketException as e:
logger.exception(
'Ошибка при обработке WebSocket сообщения',
log_type = 'PyG2O',
connection = f"{connection.client.host}:{connection.client.port}",
description = e,
)
@classmethod
async def _process_message(cls, connection: WebSocket, message: dict):
match message:
case {'event': 'subscribe', 'topics': topics}:
await cls._subscribe(topics, connection)
case {'event': 'unsubscribe', 'topics': topics}:
await cls._unsubscribe(topics, connection)
case {'event': 'create_temp_token', 'token': token}:
cls._temp_tokens.append(token)
case {'event': 'remove_temp_token', 'token': token}:
cls._temp_tokens.remove(token)
case {'event': 'init_temp_tokens', 'tokens': tokens}:
cls._temp_tokens = cls._temp_tokens + list(tokens.items())
case {'event': 'register_client', 'playerid': playerid}:
try:
cls._registered_clients[playerid].append(connection)
except KeyError:
cls._registered_clients[playerid] = [connection]
logs_connection_list = [f"{item.client.host}:{item.client.port}" for item in cls._registered_clients[playerid]]
logger.info(
'Зарегистрирован новый WebSocket клиент',
log_type = 'PyG2O',
connection = f"{connection.client.host}:{connection.client.port}",
playerid = playerid,
total_playerid_connections = logs_connection_list,
)
case {'event': 'register_server'}:
if cls._server_connection is None:
cls._server_connection = connection
logger.info(
'Зарегистрирован новый WebSocket G2O сервер',
log_type = 'PyG2O',
connection = f"{connection.client.host}:{connection.client.port}",
)
case {'event': 'sq_response', 'uuid': uuid, 'data': data}:
try:
cls._requests[uuid].set_result(data)
logger.info(
'Получен ответ от G2O сервера (sq_response)',
log_type = 'PyG2O',
connection = f"{connection.client.host}:{connection.client.port}",
message_uuid = uuid,
message_data = data,
)
except KeyError:
logger.warning(
'Получен неожиданный ответ от G2O сервера',
log_type = 'PyG2O',
connection = f"{connection.client.host}:{connection.client.port}",
message_uuid = uuid,
message_data = data,
)
case {'event': event, 'uuid': uuid, **kwargs}:
try:
cls._requests[uuid].set_result(SimpleNamespace(**kwargs))
logger.info(
'Получен ответ от клиента',
log_type = 'PyG2O',
connection = f"{connection.client.host}:{connection.client.port}",
message_event = event,
message_uuid = uuid,
message_data = kwargs,
)
except KeyError:
kwargs['uuid'] = uuid
kwargs['connection'] = connection
playerid = next((key for key, values in cls._registered_clients.items() if connection in values), None)
if playerid is not None: kwargs['playerid'] = playerid
asyncio.create_task(call_event(event, **kwargs))
if (event != 'onTick' and event != 'onTime' and event != 'onPlayerChangeChunk'):
logger.info(
'Получено сообщение от сервера' if connection == cls._server_connection else 'Получено сообщение от клиента',
log_type = 'PyG2O',
connection = f"{connection.client.host}:{connection.client.port}",
message_event = event,
message_uuid = uuid,
message_data = kwargs,
)
case _:
logger.error(
'Получено неподдерживаемое сообщение от сервера' if connection == cls._server_connection else 'Получено неподдерживаемое сообщение от клиента',
log_type = 'PyG2O',
connection = f"{connection.client.host}:{connection.client.port}",
message_data = message,
)
raise ValueError(f'Неподдерживаемый тип PyG2O сообщения: {message}')