feat: ✨ http简单用户状态异常接口
This commit is contained in:
parent
0ab542bfb4
commit
4c91bc650a
122
main.py
122
main.py
@ -1,6 +1,120 @@
|
||||
def main():
|
||||
print("Hello from Heart_Mirror_Agent!")
|
||||
import http.server
|
||||
import socketserver
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to sys.path
|
||||
PROJECT_ROOT = Path(__file__).parent
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
from src.Module.llm.llm import StreamingLLM
|
||||
from src.Module.tts.tts import StreamingTTS
|
||||
from src.utils.prompts import get_trigger_prompt
|
||||
from loguru import logger
|
||||
|
||||
# Configuration
|
||||
PORT = 8000
|
||||
|
||||
class AgentRequestHandler(http.server.BaseHTTPRequestHandler):
|
||||
def _set_headers(self, content_type='application/json'):
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', content_type)
|
||||
self.end_headers()
|
||||
|
||||
def do_POST(self):
|
||||
if self.path == '/abnormal_trigger':
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
post_data = self.rfile.read(content_length)
|
||||
|
||||
try:
|
||||
data = json.loads(post_data.decode('utf-8'))
|
||||
logger.info(f"Received abnormal_trigger: {data}")
|
||||
|
||||
trigger_reason = data.get("trigger_reason")
|
||||
# context_data = data.get("context_data") # Not used in prompt yet, but available
|
||||
|
||||
if not trigger_reason:
|
||||
self.send_error(400, "Missing trigger_reason")
|
||||
return
|
||||
|
||||
# Generate Audio
|
||||
audio_data = self.generate_response_audio(trigger_reason)
|
||||
|
||||
if audio_data:
|
||||
self._set_headers('application/octet-stream')
|
||||
self.wfile.write(audio_data)
|
||||
logger.info("Sent audio response")
|
||||
else:
|
||||
self.send_error(500, "Failed to generate audio")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
self.send_error(400, "Invalid JSON")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing request: {e}")
|
||||
self.send_error(500, str(e))
|
||||
else:
|
||||
self.send_error(404, "Not Found")
|
||||
|
||||
def generate_response_audio(self, trigger_reason):
|
||||
"""
|
||||
Orchestrates LLM and TTS to produce audio bytes.
|
||||
"""
|
||||
system_prompt = get_trigger_prompt(trigger_reason)
|
||||
logger.info(f"Generated prompt for {trigger_reason}")
|
||||
|
||||
# Run async logic synchronously
|
||||
return asyncio.run(self._async_generate(system_prompt))
|
||||
|
||||
async def _async_generate(self, system_prompt):
|
||||
llm = StreamingLLM()
|
||||
# Ensure we use a voice that works. 'Cherry' is used in other handlers.
|
||||
tts = StreamingTTS(voice='Cherry')
|
||||
|
||||
full_audio = bytearray()
|
||||
|
||||
# Generator for LLM text
|
||||
def text_generator():
|
||||
# We use chat with empty message because the system_prompt contains the instruction
|
||||
# and the trigger implies "user just triggered this state".
|
||||
# Or we can put a dummy user message like "." or "START".
|
||||
# Looking at src/handlers/abnormal_trigger.py: it calls llm.chat(message="", system_prompt=system_prompt)
|
||||
for chunk in llm.chat(message="(用户触发了异常状态,请直接根据系统提示开始说话)", system_prompt=system_prompt):
|
||||
if chunk.error:
|
||||
logger.error(f"LLM Error: {chunk.error}")
|
||||
continue
|
||||
if chunk.content:
|
||||
# logger.debug(f"LLM Chunk: {chunk.content}")
|
||||
yield chunk.content
|
||||
|
||||
# Stream TTS from LLM text generator
|
||||
try:
|
||||
for audio_chunk in tts.stream_from_generator(text_generator()):
|
||||
if audio_chunk.error:
|
||||
logger.error(f"TTS Error: {audio_chunk.error}")
|
||||
continue
|
||||
if audio_chunk.data:
|
||||
full_audio.extend(audio_chunk.data)
|
||||
|
||||
return bytes(full_audio)
|
||||
except Exception as e:
|
||||
logger.error(f"Generation failed: {e}")
|
||||
return None
|
||||
|
||||
def run(server_class=http.server.HTTPServer, handler_class=AgentRequestHandler, port=PORT):
|
||||
server_address = ('0.0.0.0', port)
|
||||
httpd = server_class(server_address, handler_class)
|
||||
logger.info(f"Starting Agent HTTP Server on port {port}...")
|
||||
try:
|
||||
httpd.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
httpd.server_close()
|
||||
logger.info("Server stopped.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
58
test_agent_http.py
Normal file
58
test_agent_http.py
Normal file
@ -0,0 +1,58 @@
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import json
|
||||
import time
|
||||
|
||||
def test_abnormal_trigger():
|
||||
url = "http://172.20.10.2:8000/abnormal_trigger"
|
||||
|
||||
payload = {
|
||||
"type": "abnormal_trigger",
|
||||
"trigger_reason": "poor_skin",
|
||||
"enable_streaming": True, # The server handles it by accumulating, but we keep the field
|
||||
"context_data": {
|
||||
"emotion": "sad",
|
||||
"skin_status": {
|
||||
"acne": True,
|
||||
"dark_circles": True
|
||||
},
|
||||
"timestamp": "2024-01-01 12:30:45"
|
||||
}
|
||||
}
|
||||
|
||||
data = json.dumps(payload).encode('utf-8')
|
||||
req = urllib.request.Request(url, data=data, headers={'Content-Type': 'application/json'})
|
||||
|
||||
print(f"Sending request to {url}...")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req) as response:
|
||||
if response.status == 200:
|
||||
print("Request successful!")
|
||||
output_file = "response.pcm"
|
||||
|
||||
# Read content
|
||||
audio_data = response.read()
|
||||
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(audio_data)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"Audio saved to {output_file} ({len(audio_data)} bytes)")
|
||||
print(f"Time taken: {elapsed:.2f} seconds")
|
||||
|
||||
# Optional: detailed check
|
||||
if len(audio_data) < 100:
|
||||
print("Warning: Audio file seems too small.")
|
||||
else:
|
||||
print(f"Error: {response.status}")
|
||||
print(response.read().decode('utf-8'))
|
||||
|
||||
except urllib.error.URLError as e:
|
||||
print(f"Error: {e}")
|
||||
if hasattr(e, 'read'):
|
||||
print(e.read().decode('utf-8'))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_abnormal_trigger()
|
||||
Loading…
x
Reference in New Issue
Block a user