IntuitionX_agent/test_ws.py

217 lines
6.8 KiB
Python
Raw Normal View History

"""
WebSocket服务器测试脚本
测试2.1和2.3接口的基本功能
"""
import asyncio
import base64
import json
import sys
from pathlib import Path
# 添加src目录到路径
sys.path.insert(0, str(Path(__file__).parent))
import websockets
async def test_2_1_abnormal_trigger():
"""测试2.1 异常状态触发对话"""
print("\n" + "="*60)
print("测试 2.1: 异常状态触发对话")
print("="*60)
uri = "ws://localhost:8765"
try:
async with websockets.connect(uri) as websocket:
# 发送异常状态触发请求
request = {
"type": "abnormal_trigger",
"trigger_reason": "poor_skin",
"enable_streaming": True,
"context_data": {
"emotion": "neutral",
"skin_status": {"acne": True}
}
}
print(f"\n[发送] 请求: {json.dumps(request, ensure_ascii=False, indent=2)}")
await websocket.send(json.dumps(request))
# 接收响应
response_count = 0
while True:
try:
message = await asyncio.wait_for(websocket.recv(), timeout=30)
data = json.loads(message)
if data.get("type") == "abnormal_trigger_response":
print(f"\n[响应] 确认响应: {data}")
elif data.get("type") == "audio_stream_download":
response_count += 1
audio_size = len(base64.b64decode(data["data"]))
is_final = data.get("is_final")
print(
f"[响应] 音频块#{response_count}: "
f"大小={audio_size}bytes, is_final={is_final}"
)
if is_final:
print(f"\n[完成] 共收到 {response_count} 个音频块")
break
else:
print(f"[响应] {data}")
except asyncio.TimeoutError:
print("[超时] 等待响应超时30秒")
break
except ConnectionRefusedError:
print("[错误] 无法连接到服务器 (localhost:8765)")
print("请先启动WebSocket服务器: python src/MainServices.py")
return False
print("\n✓ 2.1 测试完成\n")
return True
async def test_2_3_audio_stream():
"""测试2.3 双向音频流对话"""
print("\n" + "="*60)
print("测试 2.3: 双向音频流对话")
print("="*60)
uri = "ws://localhost:8765"
try:
async with websockets.connect(uri) as websocket:
# 1. 初始化音频流
init_request = {
"type": "audio_stream_init",
"session_id": "test_session_001",
"audio_config": {
"sample_rate": 16000,
"bit_depth": 16,
"channels": 1,
"encoding": "pcm"
}
}
print(f"\n[发送] 初始化请求: {json.dumps(init_request, ensure_ascii=False, indent=2)}")
await websocket.send(json.dumps(init_request))
# 接收初始化响应
message = await asyncio.wait_for(websocket.recv(), timeout=5)
init_response = json.loads(message)
print(f"[响应] 初始化响应: {init_response}")
if not init_response.get("success"):
print("[错误] 初始化失败")
return False
# 2. 模拟发送音频块(这里用虚假数据)
print("\n[模拟] 发送虚假音频数据(测试用)")
# 创建简单的PCM音频数据1000个采样点
import struct
audio_data = struct.pack('h' * 1000, *[0] * 1000) # 1000个16bit零
audio_b64 = base64.b64encode(audio_data).decode()
for i in range(3):
upload_request = {
"type": "audio_stream_upload",
"session_id": "test_session_001",
"data": audio_b64,
"sequence": i
}
print(f"[发送] 音频块 #{i+1}")
await websocket.send(json.dumps(upload_request))
await asyncio.sleep(0.5)
# 等待响应(如果有的话)
print("\n[等待] 响应30秒超时...")
response_count = 0
try:
while True:
message = await asyncio.wait_for(websocket.recv(), timeout=30)
data = json.loads(message)
if data.get("type") == "audio_stream_download":
response_count += 1
audio_size = len(base64.b64decode(data["data"]))
print(
f"[响应] 音频块#{response_count}: "
f"大小={audio_size}bytes"
)
if data.get("is_final"):
print(f"\n[完成] 共收到 {response_count} 个音频块")
break
elif data.get("type") == "error":
print(f"[错误] {data['message']}")
break
except asyncio.TimeoutError:
print("[超时] 未收到响应(这在测试数据下是正常的)")
# 3. 发送结束信号
print("\n[发送] 结束信号")
control_request = {
"type": "audio_stream_control",
"session_id": "test_session_001",
"action": "end"
}
await websocket.send(json.dumps(control_request))
except ConnectionRefusedError:
print("[错误] 无法连接到服务器 (localhost:8765)")
print("请先启动WebSocket服务器: python src/MainServices.py")
return False
print("\n✓ 2.3 测试完成\n")
return True
async def main():
"""主测试函数"""
print("\n" + "="*60)
print("WebSocket服务器测试套件")
print("="*60)
success = True
# 测试2.1
if not await test_2_1_abnormal_trigger():
success = False
# 等待一下,避免连接冲突
await asyncio.sleep(2)
# 测试2.3
if not await test_2_3_audio_stream():
success = False
# 总结
print("\n" + "="*60)
if success:
print("✓ 所有测试完成!")
else:
print("✗ 某些测试失败")
print("="*60 + "\n")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
print("\n\n[中断] 测试已中断\n")
except Exception as e:
print(f"\n[错误] {e}\n")
import traceback
traceback.print_exc()