# agent_server.py — HTTP server exposing /abnormal_trigger (LLM + TTS audio response)
import http.server
|
|
import socketserver
|
|
import json
|
|
import sys
|
|
import os
|
|
import asyncio
|
|
import threading
|
|
from pathlib import Path
|
|
|
|
# Make the project root importable (so `src....` imports resolve when
# this file is run directly rather than as part of a package).
PROJECT_ROOT = Path(__file__).parent
_root = str(PROJECT_ROOT)
if _root not in sys.path:
    sys.path.insert(0, _root)
|
|
|
|
from src.Module.llm.llm import StreamingLLM
|
|
from src.Module.tts.tts import StreamingTTS
|
|
from src.utils.prompts import get_trigger_prompt
|
|
from loguru import logger
|
|
|
|
# Configuration
PORT = 8000  # TCP port the agent HTTP server listens on (all interfaces)
|
|
|
|
class AgentRequestHandler(http.server.BaseHTTPRequestHandler):
    """HTTP handler that turns abnormal-trigger events into spoken audio.

    POST /abnormal_trigger with a JSON body containing "trigger_reason"
    responds with raw synthesized audio bytes (application/octet-stream).
    Any other path gets a 404.
    """

    def _set_headers(self, content_type='application/json'):
        """Send a 200 status line plus a Content-type header."""
        self.send_response(200)
        self.send_header('Content-type', content_type)
        self.end_headers()

    def do_POST(self):
        """Route POST requests; only /abnormal_trigger is supported."""
        # Guard clause instead of a long if/else: unknown paths bail early.
        if self.path != '/abnormal_trigger':
            self.send_error(404, "Not Found")
            return

        # Fix: self.headers['Content-Length'] returns None when the header
        # is missing, so int(...) raised an uncaught TypeError and killed
        # the connection. Reject such requests with 411 explicitly.
        content_length = self.headers.get('Content-Length')
        if content_length is None:
            self.send_error(411, "Length Required")
            return

        post_data = self.rfile.read(int(content_length))

        try:
            data = json.loads(post_data.decode('utf-8'))
            logger.info(f"Received abnormal_trigger: {data}")

            trigger_reason = data.get("trigger_reason")
            # context_data = data.get("context_data")  # Not used in prompt yet, but available

            if not trigger_reason:
                self.send_error(400, "Missing trigger_reason")
                return

            # Generate Audio
            audio_data = self.generate_response_audio(trigger_reason)

            if audio_data:
                self._set_headers('application/octet-stream')
                self.wfile.write(audio_data)
                logger.info("Sent audio response")
            else:
                self.send_error(500, "Failed to generate audio")

        except json.JSONDecodeError:
            self.send_error(400, "Invalid JSON")
        except Exception as e:
            # Top-level request boundary: log and surface a 500 rather
            # than letting the handler die silently.
            logger.error(f"Error processing request: {e}")
            self.send_error(500, str(e))

    def generate_response_audio(self, trigger_reason):
        """Orchestrate LLM and TTS to produce audio bytes.

        Args:
            trigger_reason: string identifying the abnormal state; fed to
                get_trigger_prompt() to build the system prompt.

        Returns:
            Complete audio as bytes, or None if generation failed.
        """
        system_prompt = get_trigger_prompt(trigger_reason)
        logger.info(f"Generated prompt for {trigger_reason}")

        # Run async logic synchronously; each request gets a fresh event
        # loop via asyncio.run (the base handler is synchronous).
        return asyncio.run(self._async_generate(system_prompt))

    async def _async_generate(self, system_prompt):
        """Stream LLM text into TTS and accumulate the resulting audio."""
        llm = StreamingLLM()
        # Ensure we use a voice that works. 'Cherry' is used in other handlers.
        tts = StreamingTTS(voice='Cherry')

        full_audio = bytearray()

        def text_generator():
            # The system_prompt carries the full instruction; the fixed user
            # message only nudges the model to start speaking (same pattern
            # as src/handlers/abnormal_trigger.py). Per-chunk errors are
            # logged and skipped so one bad chunk doesn't abort the stream.
            for chunk in llm.chat(message="(用户触发了异常状态,请直接根据系统提示开始说话)", system_prompt=system_prompt):
                if chunk.error:
                    logger.error(f"LLM Error: {chunk.error}")
                    continue
                if chunk.content:
                    yield chunk.content

        # Stream TTS from the LLM text generator, collecting audio chunks.
        try:
            for audio_chunk in tts.stream_from_generator(text_generator()):
                if audio_chunk.error:
                    logger.error(f"TTS Error: {audio_chunk.error}")
                    continue
                if audio_chunk.data:
                    full_audio.extend(audio_chunk.data)

            return bytes(full_audio)

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return None
|
|
|
|
def run(server_class=http.server.HTTPServer, handler_class=AgentRequestHandler, port=PORT):
    """Start the agent HTTP server and block until interrupted.

    Args:
        server_class: server class to instantiate (default HTTPServer).
        handler_class: request handler class (default AgentRequestHandler).
        port: TCP port to bind on all interfaces.
    """
    server_address = ('0.0.0.0', port)
    httpd = server_class(server_address, handler_class)
    logger.info(f"Starting Agent HTTP Server on port {port}...")
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        # Fix: previously server_close() was skipped when serve_forever
        # raised anything other than KeyboardInterrupt, leaking the
        # listening socket. finally guarantees cleanup on every exit path.
        httpd.server_close()
    logger.info("Server stopped.")
|
|
|
|
# Script entry point: start the blocking HTTP server on PORT.
if __name__ == '__main__':
    run()