""" WebSocket服务器测试脚本 测试2.1和2.3接口的基本功能 """ import asyncio import base64 import json import sys from pathlib import Path # 添加src目录到路径 sys.path.insert(0, str(Path(__file__).parent)) import websockets async def test_2_1_abnormal_trigger(): """测试2.1 异常状态触发对话""" print("\n" + "="*60) print("测试 2.1: 异常状态触发对话") print("="*60) uri = "ws://localhost:8765" try: async with websockets.connect(uri) as websocket: # 发送异常状态触发请求 request = { "type": "abnormal_trigger", "trigger_reason": "poor_skin", "enable_streaming": True, "context_data": { "emotion": "neutral", "skin_status": {"acne": True} } } print(f"\n[发送] 请求: {json.dumps(request, ensure_ascii=False, indent=2)}") await websocket.send(json.dumps(request)) # 接收响应 response_count = 0 while True: try: message = await asyncio.wait_for(websocket.recv(), timeout=30) data = json.loads(message) if data.get("type") == "abnormal_trigger_response": print(f"\n[响应] 确认响应: {data}") elif data.get("type") == "audio_stream_download": response_count += 1 audio_size = len(base64.b64decode(data["data"])) is_final = data.get("is_final") print( f"[响应] 音频块#{response_count}: " f"大小={audio_size}bytes, is_final={is_final}" ) if is_final: print(f"\n[完成] 共收到 {response_count} 个音频块") break else: print(f"[响应] {data}") except asyncio.TimeoutError: print("[超时] 等待响应超时(30秒)") break except ConnectionRefusedError: print("[错误] 无法连接到服务器 (localhost:8765)") print("请先启动WebSocket服务器: python src/MainServices.py") return False print("\n✓ 2.1 测试完成\n") return True async def test_2_3_audio_stream(): """测试2.3 双向音频流对话""" print("\n" + "="*60) print("测试 2.3: 双向音频流对话") print("="*60) uri = "ws://localhost:8765" try: async with websockets.connect(uri) as websocket: # 1. 初始化音频流 init_request = { "type": "audio_stream_init", "session_id": "test_session_001", "audio_config": { "sample_rate": 16000, "bit_depth": 16, "channels": 1, "encoding": "pcm" } } print(f"\n[发送] 初始化请求: {json.dumps(init_request, ensure_ascii=False, indent=2)}") await websocket.send(json.dumps(init_request)) # 接收初始化响应 message = await asyncio.wait_for(websocket.recv(), timeout=5) init_response = json.loads(message) print(f"[响应] 初始化响应: {init_response}") if not init_response.get("success"): print("[错误] 初始化失败") return False # 2. 模拟发送音频块(这里用虚假数据) print("\n[模拟] 发送虚假音频数据(测试用)") # 创建简单的PCM音频数据(1000个采样点) import struct audio_data = struct.pack('h' * 1000, *[0] * 1000) # 1000个16bit零 audio_b64 = base64.b64encode(audio_data).decode() for i in range(3): upload_request = { "type": "audio_stream_upload", "session_id": "test_session_001", "data": audio_b64, "sequence": i } print(f"[发送] 音频块 #{i+1}") await websocket.send(json.dumps(upload_request)) await asyncio.sleep(0.5) # 等待响应(如果有的话) print("\n[等待] 响应(30秒超时)...") response_count = 0 try: while True: message = await asyncio.wait_for(websocket.recv(), timeout=30) data = json.loads(message) if data.get("type") == "audio_stream_download": response_count += 1 audio_size = len(base64.b64decode(data["data"])) print( f"[响应] 音频块#{response_count}: " f"大小={audio_size}bytes" ) if data.get("is_final"): print(f"\n[完成] 共收到 {response_count} 个音频块") break elif data.get("type") == "error": print(f"[错误] {data['message']}") break except asyncio.TimeoutError: print("[超时] 未收到响应(这在测试数据下是正常的)") # 3. 发送结束信号 print("\n[发送] 结束信号") control_request = { "type": "audio_stream_control", "session_id": "test_session_001", "action": "end" } await websocket.send(json.dumps(control_request)) except ConnectionRefusedError: print("[错误] 无法连接到服务器 (localhost:8765)") print("请先启动WebSocket服务器: python src/MainServices.py") return False print("\n✓ 2.3 测试完成\n") return True async def main(): """主测试函数""" print("\n" + "="*60) print("WebSocket服务器测试套件") print("="*60) success = True # 测试2.1 if not await test_2_1_abnormal_trigger(): success = False # 等待一下,避免连接冲突 await asyncio.sleep(2) # 测试2.3 if not await test_2_3_audio_stream(): success = False # 总结 print("\n" + "="*60) if success: print("✓ 所有测试完成!") else: print("✗ 某些测试失败") print("="*60 + "\n") if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: print("\n\n[中断] 测试已中断\n") except Exception as e: print(f"\n[错误] {e}\n") import traceback traceback.print_exc()