From 4c91bc650a11e1101848f142b1b8ebfcc7de59e5 Mon Sep 17 00:00:00 2001
From: wds
Date: Fri, 2 Jan 2026 05:13:33 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E2=9C=A8=20http=E7=AE=80=E5=8D=95?=
 =?UTF-8?q?=E7=94=A8=E6=88=B7=E7=8A=B6=E6=80=81=E5=BC=82=E5=B8=B8=E6=8E=A5?=
 =?UTF-8?q?=E5=8F=A3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py            | 122 +++++++++++++++++++++++++++++++++++++++++++--
 test_agent_http.py |  58 +++++++++++++++++++++
 2 files changed, 176 insertions(+), 4 deletions(-)
 create mode 100644 test_agent_http.py

diff --git a/main.py b/main.py
index f30cb69..7c6e3fc 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,120 @@
-def main():
-    print("Hello from Heart_Mirror_Agent!")
+import http.server
+import socketserver
+import json
+import sys
+import os
+import asyncio
+import threading
+from pathlib import Path
 
+# Add project root to sys.path
+PROJECT_ROOT = Path(__file__).parent
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
 
-if __name__ == "__main__":
-    main()
+from src.Module.llm.llm import StreamingLLM
+from src.Module.tts.tts import StreamingTTS
+from src.utils.prompts import get_trigger_prompt
+from loguru import logger
+
+# Configuration
+PORT = 8000
+
+class AgentRequestHandler(http.server.BaseHTTPRequestHandler):
+    def _set_headers(self, content_type='application/json'):
+        self.send_response(200)
+        self.send_header('Content-type', content_type)
+        self.end_headers()
+
+    def do_POST(self):
+        if self.path == '/abnormal_trigger':
+            content_length = int(self.headers['Content-Length'])
+            post_data = self.rfile.read(content_length)
+
+            try:
+                data = json.loads(post_data.decode('utf-8'))
+                logger.info(f"Received abnormal_trigger: {data}")
+
+                trigger_reason = data.get("trigger_reason")
+                # context_data = data.get("context_data") # Not used in prompt yet, but available
+
+                if not trigger_reason:
+                    self.send_error(400, "Missing trigger_reason")
+                    return
+
+                # Generate Audio
+                audio_data = self.generate_response_audio(trigger_reason)
+
+                if audio_data:
+                    self._set_headers('application/octet-stream')
+                    self.wfile.write(audio_data)
+                    logger.info("Sent audio response")
+                else:
+                    self.send_error(500, "Failed to generate audio")
+
+            except json.JSONDecodeError:
+                self.send_error(400, "Invalid JSON")
+            except Exception as e:
+                logger.error(f"Error processing request: {e}")
+                self.send_error(500, str(e))
+        else:
+            self.send_error(404, "Not Found")
+
+    def generate_response_audio(self, trigger_reason):
+        """
+        Orchestrates LLM and TTS to produce audio bytes.
+        """
+        system_prompt = get_trigger_prompt(trigger_reason)
+        logger.info(f"Generated prompt for {trigger_reason}")
+
+        # Run async logic synchronously
+        return asyncio.run(self._async_generate(system_prompt))
+
+    async def _async_generate(self, system_prompt):
+        llm = StreamingLLM()
+        # Ensure we use a voice that works. 'Cherry' is used in other handlers.
+        tts = StreamingTTS(voice='Cherry')
+
+        full_audio = bytearray()
+
+        # Generator for LLM text
+        def text_generator():
+            # We use chat with empty message because the system_prompt contains the instruction
+            # and the trigger implies "user just triggered this state".
+            # Or we can put a dummy user message like "." or "START".
+            # Looking at src/handlers/abnormal_trigger.py: it calls llm.chat(message="", system_prompt=system_prompt)
+            for chunk in llm.chat(message="(用户触发了异常状态,请直接根据系统提示开始说话)", system_prompt=system_prompt):
+                if chunk.error:
+                    logger.error(f"LLM Error: {chunk.error}")
+                    continue
+                if chunk.content:
+                    # logger.debug(f"LLM Chunk: {chunk.content}")
+                    yield chunk.content
+
+        # Stream TTS from LLM text generator
+        try:
+            for audio_chunk in tts.stream_from_generator(text_generator()):
+                if audio_chunk.error:
+                    logger.error(f"TTS Error: {audio_chunk.error}")
+                    continue
+                if audio_chunk.data:
+                    full_audio.extend(audio_chunk.data)
+
+            return bytes(full_audio)
+        except Exception as e:
+            logger.error(f"Generation failed: {e}")
+            return None
+
+def run(server_class=http.server.HTTPServer, handler_class=AgentRequestHandler, port=PORT):
+    server_address = ('0.0.0.0', port)
+    httpd = server_class(server_address, handler_class)
+    logger.info(f"Starting Agent HTTP Server on port {port}...")
+    try:
+        httpd.serve_forever()
+    except KeyboardInterrupt:
+        pass
+    httpd.server_close()
+    logger.info("Server stopped.")
+
+if __name__ == '__main__':
+    run()
\ No newline at end of file
diff --git a/test_agent_http.py b/test_agent_http.py
new file mode 100644
index 0000000..31b66b2
--- /dev/null
+++ b/test_agent_http.py
@@ -0,0 +1,58 @@
+import urllib.request
+import urllib.error
+import json
+import time
+
+def test_abnormal_trigger():
+    url = "http://172.20.10.2:8000/abnormal_trigger"
+
+    payload = {
+        "type": "abnormal_trigger",
+        "trigger_reason": "poor_skin",
+        "enable_streaming": True, # The server handles it by accumulating, but we keep the field
+        "context_data": {
+            "emotion": "sad",
+            "skin_status": {
+                "acne": True,
+                "dark_circles": True
+            },
+            "timestamp": "2024-01-01 12:30:45"
+        }
+    }
+
+    data = json.dumps(payload).encode('utf-8')
+    req = urllib.request.Request(url, data=data, headers={'Content-Type': 'application/json'})
+
+    print(f"Sending request to {url}...")
+    start_time = time.time()
+
+    try:
+        with urllib.request.urlopen(req) as response:
+            if response.status == 200:
+                print("Request successful!")
+                output_file = "response.pcm"
+
+                # Read content
+                audio_data = response.read()
+
+                with open(output_file, "wb") as f:
+                    f.write(audio_data)
+
+                elapsed = time.time() - start_time
+                print(f"Audio saved to {output_file} ({len(audio_data)} bytes)")
+                print(f"Time taken: {elapsed:.2f} seconds")
+
+                # Optional: detailed check
+                if len(audio_data) < 100:
+                    print("Warning: Audio file seems too small.")
+            else:
+                print(f"Error: {response.status}")
+                print(response.read().decode('utf-8'))
+
+    except urllib.error.URLError as e:
+        print(f"Error: {e}")
+        if hasattr(e, 'read'):
+            print(e.read().decode('utf-8'))
+
+if __name__ == "__main__":
+    test_abnormal_trigger()