diff --git a/.env b/.env new file mode 100644 index 0000000..edd2db7 --- /dev/null +++ b/.env @@ -0,0 +1 @@ +DASHSCOPE_API_KEY=your-dashscope-api-key-here diff --git a/.gitignore b/.gitignore index 56aa68b..17849c3 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,5 @@ wheels/ .venv .vscode/ tmp/ +ref/ +.env diff --git a/pyproject.toml b/pyproject.toml index 46a1386..7ba1e0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,9 +6,11 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "agentscope>=1.0.10", + "dashscope>=1.25.5", "loguru>=0.7.3", "pre-commit>=4.3.0", + "pydantic>=2.11.7", "python-dotenv>=1.2.1", "pyyaml>=6.0.2", "ruff>=0.12.11", ] diff --git a/src/Module/tts/tts.py b/src/Module/tts/tts.py new file mode 100644 index 0000000..b4bb74e --- /dev/null +++ b/src/Module/tts/tts.py @@ -0,0 +1,361 @@ +""" +实时语音合成模块 - 基于阿里云 DashScope 通义千问 TTS +核心特性:使用 yield 生成器实现流式音频返回 +""" + +import os +import base64 +import threading +import queue +from typing import Generator, Optional +from pathlib import Path +from dataclasses import dataclass +from dotenv import load_dotenv + +import dashscope +from dashscope.audio.qwen_tts_realtime import ( QwenTtsRealtime, QwenTtsRealtimeCallback, AudioFormat ) + +# NOTE(review): a live-looking DashScOpe API key was committed in this hunk. Redacted here — the original key must be considered compromised and rotated in the DashScope console; .env is now added to .gitignore above. +# 导入配置 +from . 
import ttsconfig as config + + +@dataclass +class AudioChunk: + """音频数据块""" + data: bytes # 音频数据 (PCM) + is_final: bool = False # 是否为最后一块 + error: Optional[str] = None # 错误信息 + + +@dataclass +class AudioInfo: + """音频参数信息""" + sample_rate: int = config.SAMPLE_RATE + channels: int = config.CHANNELS + sample_width: int = config.SAMPLE_WIDTH + + +class _StreamingCallback(QwenTtsRealtimeCallback): + """内部流式回调类""" + + def __init__(self): + self._audio_queue: queue.Queue[AudioChunk] = queue.Queue() + self._connected = threading.Event() + self._session_updated = threading.Event() + self.session_id: Optional[str] = None + + def on_open(self) -> None: + self._connected.set() + + def on_close(self, close_status_code, close_msg) -> None: + self._audio_queue.put(AudioChunk(data=b'', is_final=True)) + + def on_event(self, response: dict) -> None: + try: + event_type = response.get('type', '') + + if event_type == 'session.created': + self.session_id = response.get('session', {}).get('id') + + elif event_type == 'session.updated': + self._session_updated.set() + + elif event_type == 'response.audio.delta': + audio_b64 = response.get('delta', '') + if audio_b64: + audio_bytes = base64.b64decode(audio_b64) + self._audio_queue.put(AudioChunk(data=audio_bytes)) + + elif event_type == 'session.finished': + self._audio_queue.put(AudioChunk(data=b'', is_final=True)) + + elif event_type == 'error': + error_msg = response.get('error', {}).get('message', '未知错误') + self._audio_queue.put( + AudioChunk( + data=b'', + is_final=True, + error=error_msg)) + + except Exception as e: + self._audio_queue.put(AudioChunk(data=b'', is_final=True, error=str(e))) + + def get_chunk(self, timeout: float = config.CHUNK_TIMEOUT) -> Optional[AudioChunk]: + try: + return self._audio_queue.get(timeout=timeout) + except queue.Empty: + return AudioChunk(data=b'', is_final=True, error='超时') + + def wait_connected(self, timeout: float = config.CONNECT_TIMEOUT) -> bool: + return self._connected.wait(timeout=timeout) 
+ + +class StreamingTTS: + """ + 流式语音合成类 + + 使用方式: + tts = StreamingTTS(voice='Cherry') + for chunk in tts.stream("你好,世界"): + if chunk.data: + audio_player.play(chunk.data) + """ + + # 音频参数信息 (只读) + audio_info = AudioInfo() + + def __init__( + self, + voice: str = None, + language: str = None, + speech_rate: float = None, + volume: int = None, + pitch_rate: float = None, + model: str = None, + ): + """ + 初始化流式TTS + + 所有参数可选,未传入时使用 ttsconfig.py 中的默认值 + + Args: + voice: 音色 ('Cherry', 'Serena', 'Ethan', 'Chelsie') + language: 语言 ('Chinese', 'English', etc.) + speech_rate: 语速 (0.5-2.0) + volume: 音量 (0-100) + pitch_rate: 语调 (0.5-2.0) + model: 模型名称 + """ + self._load_api_key() + + # 使用传入值或配置文件默认值 + self.voice = voice or config.VOICE + self.language = language or config.LANGUAGE + self.speech_rate = self._clamp(speech_rate or config.SPEECH_RATE, 0.5, 2.0) + self.volume = self._clamp(volume or config.VOLUME, 0, 100) + self.pitch_rate = self._clamp(pitch_rate or config.PITCH_RATE, 0.5, 2.0) + self.model = model or config.MODEL + + self._client: Optional[QwenTtsRealtime] = None + self._callback: Optional[_StreamingCallback] = None + + @staticmethod + def _clamp(value, min_val, max_val): + return max(min_val, min(max_val, value)) + + def _load_api_key(self) -> None: + """从 .env 加载 API Key""" + current_dir = Path(__file__).parent + for _ in range(5): + env_path = current_dir / '.env' + if env_path.exists(): + load_dotenv(env_path) + break + current_dir = current_dir.parent + + api_key = os.environ.get('DASHSCOPE_API_KEY') + if api_key: + dashscope.api_key = api_key + else: + raise ValueError('未找到 DASHSCOPE_API_KEY') + + def stream( + self, + text: str, + voice: str = None, + language: str = None, + speech_rate: float = None, + volume: int = None, + pitch_rate: float = None, + ) -> Generator[AudioChunk, None, None]: + """ + 流式合成语音 - 核心接口 + + Args: + text: 要合成的文本 + voice: 临时覆盖音色 + language: 临时覆盖语言 + speech_rate: 临时覆盖语速 + volume: 临时覆盖音量 + pitch_rate: 临时覆盖语调 + + Yields: + 
AudioChunk: 音频数据块 + + Example: + >>> tts = StreamingTTS() + >>> for chunk in tts.stream("你好", voice='Serena'): + ... process_audio(chunk.data) + """ + # 合并参数:stream 传入的优先,否则用实例属性 + _voice = voice or self.voice + _language = language or self.language + _speech_rate = self._clamp(speech_rate or self.speech_rate, 0.5, 2.0) + _volume = self._clamp(volume or self.volume, 0, 100) + _pitch_rate = self._clamp(pitch_rate or self.pitch_rate, 0.5, 2.0) + + self._callback = _StreamingCallback() + self._client = QwenTtsRealtime( + model=self.model, + callback=self._callback, + url=config.WS_URL + ) + + # 启动发送线程 + send_thread = threading.Thread( + target=self._send_text, + args=(text, _voice, _language, _speech_rate, _volume, _pitch_rate), + daemon=True + ) + + try: + self._client.connect() + send_thread.start() + + while True: + chunk = self._callback.get_chunk() + if chunk is None: + break + yield chunk + if chunk.is_final: + break + + finally: + self._cleanup() + + def _send_text( + self, + text: str, + voice: str, + language: str, + speech_rate: float, + volume: int, + pitch_rate: float + ) -> None: + """在后台发送文本""" + try: + if not self._callback.wait_connected(): + return + + self._client.update_session( + voice=voice, + response_format=AudioFormat.PCM_24000HZ_MONO_16BIT, + mode=config.MODE, + language_type=language, + speech_rate=speech_rate, + volume=volume, + pitch_rate=pitch_rate + ) + + self._client.append_text(text) + self._client.finish() + + except Exception as e: + self._callback._audio_queue.put( + AudioChunk(data=b'', is_final=True, error=str(e)) + ) + + def _cleanup(self) -> None: + if self._client: + try: + self._client.close() + except BaseException: + pass + self._client = None + + def stream_to_file(self, text: str, output_file: str, **kwargs) -> bool: + """ + 流式合成并保存到文件 + + Args: + text: 要合成的文本 + output_file: 输出文件路径 (.pcm 或 .wav) + **kwargs: 传递给 stream() 的参数 + + Returns: + 是否成功 + """ + audio_data = bytearray() + error = None + + for chunk in 
self.stream(text, **kwargs): + if chunk.error: + error = chunk.error + break + if chunk.data: + audio_data.extend(chunk.data) + + if error: + print(f'[TTS Error] {error}') + return False + + if output_file.endswith('.wav'): + return self._save_as_wav(bytes(audio_data), output_file) + else: + with open(output_file, 'wb') as f: + f.write(audio_data) + return True + + def _save_as_wav(self, pcm_data: bytes, wav_file: str) -> bool: + import wave + try: + with wave.open(wav_file, 'wb') as wav: + wav.setnchannels(config.CHANNELS) + wav.setsampwidth(config.SAMPLE_WIDTH) + wav.setframerate(config.SAMPLE_RATE) + wav.writeframes(pcm_data) + return True + except Exception as e: + print(f'[TTS Error] WAV 保存失败: {e}') + return False + + +# ============================================================ +# 便捷函数 +# ============================================================ + +def synthesize(text: str, **kwargs) -> Generator[AudioChunk, None, None]: + """ + 便捷的流式合成函数 + + Args: + text: 要合成的文本 + **kwargs: 传递给 StreamingTTS/stream() 的参数 + + Yields: + AudioChunk: 音频数据块 + """ + # 分离实例化参数和 stream 参数 + init_keys = {'model'} + stream_keys = {'voice', 'language', 'speech_rate', 'volume', 'pitch_rate'} + + init_kwargs = {k: v for k, v in kwargs.items() if k in init_keys} + stream_kwargs = {k: v for k, v in kwargs.items() if k in stream_keys} + + tts = StreamingTTS(**init_kwargs) + yield from tts.stream(text, **stream_kwargs) + + +def synthesize_to_file(text: str, output_file: str, **kwargs) -> bool: + """ + 合成并保存到文件 + + Args: + text: 要合成的文本 + output_file: 输出文件路径 + **kwargs: 传递给 stream() 的参数 + + Returns: + 是否成功 + """ + tts = StreamingTTS() + return tts.stream_to_file(text, output_file, **kwargs) + + +def get_audio_info() -> AudioInfo: + """获取音频参数信息""" + return AudioInfo() diff --git a/src/Module/tts/ttsconfig.py b/src/Module/tts/ttsconfig.py new file mode 100644 index 0000000..c38f7f7 --- /dev/null +++ b/src/Module/tts/ttsconfig.py @@ -0,0 +1,73 @@ +""" +TTS 配置文件 +定义语音合成的默认参数 +""" + +# 
============================================================ +# 音频输出配置 +# ============================================================ + +# 采样率 (Hz) - SDK 当前仅支持 24000 +SAMPLE_RATE = 24000 + +# 声道数 - 单声道 +CHANNELS = 1 + +# 采样位宽 (字节) - 16bit = 2字节 +SAMPLE_WIDTH = 2 + +# 输出格式 +OUTPUT_FORMAT = 'pcm' # 'pcm', 'wav' + + +# ============================================================ +# 语音合成参数 +# ============================================================ + +# 默认模型 +MODEL = 'qwen3-tts-flash-realtime' + +# 默认音色 +# 可选: 'Cherry', 'Serena', 'Ethan', 'Chelsie' +VOICE = 'Cherry' + +# 默认语言 +# 可选: 'Auto', 'Chinese', 'English', 'Japanese', 'Korean', +# 'French', 'German', 'Spanish', 'Italian', 'Portuguese', 'Russian' +LANGUAGE = 'Chinese' + +# 语速 (0.5 - 2.0, 默认 1.0) +SPEECH_RATE = 1.0 + +# 音量 (0 - 100, 默认 50) +VOLUME = 50 + +# 语调 (0.5 - 2.0, 默认 1.0) +PITCH_RATE = 1.0 + + +# ============================================================ +# 服务配置 +# ============================================================ + +# WebSocket URL (北京地域) +WS_URL = 'wss://dashscope.aliyuncs.com/api-ws/v1/realtime' + +# 新加坡地域 URL (备用) +# WS_URL = 'wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime' + +# 连接超时 (秒) +CONNECT_TIMEOUT = 10.0 + +# 数据接收超时 (秒) +CHUNK_TIMEOUT = 30.0 + + +# ============================================================ +# 交互模式 +# ============================================================ + +# 模式: 'server_commit' 或 'commit' +# server_commit: 服务器决定断句 +# commit: 客户端主动触发 +MODE = 'server_commit' diff --git a/test.py b/test.py deleted file mode 100644 index e69de29..0000000 diff --git a/test/asr/mirror_hello.mp3 b/test/asr/mirror_hello.mp3 new file mode 100644 index 0000000..3ec77af Binary files /dev/null and b/test/asr/mirror_hello.mp3 differ diff --git a/test/tts/output/stream_test.wav b/test/tts/output/stream_test.wav new file mode 100644 index 0000000..6995057 Binary files /dev/null and b/test/tts/output/stream_test.wav differ diff --git a/test/tts/test_tts.py b/test/tts/test_tts.py 
new file mode 100644 index 0000000..7012c68 --- /dev/null +++ b/test/tts/test_tts.py @@ -0,0 +1,96 @@ +""" +TTS 流式合成测试 +""" + +from pathlib import Path +from src.Module.tts.tts import StreamingTTS, synthesize, AudioChunk + + +def test_stream_synthesis(): + """测试流式合成 - 使用 yield 生成器""" + print("=" * 50) + print("流式语音合成测试") + print("=" * 50) + + # 创建 TTS 实例 + tts = StreamingTTS( + voice='Cherry', + language='Chinese', + speech_rate=1.0, + ) + + text = "你好,我是通义千问语音合成系统。这是一段流式合成测试。" + print(f"合成文本: {text}\n") + + # 流式接收音频 + total_bytes = 0 + chunk_count = 0 + audio_data = bytearray() + + print("开始流式接收音频...") + for chunk in tts.stream(text): + if chunk.error: + print(f" 错误: {chunk.error}") + break + + if chunk.data: + chunk_count += 1 + total_bytes += len(chunk.data) + audio_data.extend(chunk.data) + print( + f" [Chunk {chunk_count:02d}] 收到 {len(chunk.data):5d} 字节 | 累计: {total_bytes:6d} 字节") + + if chunk.is_final: + print("\n流式传输完成!") + break + + # 保存为 WAV 文件 + output_dir = Path(__file__).parent / 'output' + output_dir.mkdir(exist_ok=True) + wav_file = output_dir / 'stream_test.wav' + + import wave + with wave.open(str(wav_file), 'wb') as wav: + wav.setnchannels(1) + wav.setsampwidth(2) + wav.setframerate(24000) + wav.writeframes(audio_data) + + print(f"\n总计: {chunk_count} 个数据块, {total_bytes} 字节") + print(f"音频已保存: {wav_file}") + + return chunk_count > 0 + + +def test_convenient_function(): + """测试便捷函数""" + print("\n" + "=" * 50) + print("便捷函数测试") + print("=" * 50) + + text = "这是使用便捷函数合成的语音。" + print(f"合成文本: {text}\n") + + total_bytes = 0 + for chunk in synthesize(text, voice='Cherry'): + if chunk.data: + total_bytes += len(chunk.data) + if chunk.is_final: + break + + print(f"合成完成,总数据量: {total_bytes} 字节") + return total_bytes > 0 + + +if __name__ == '__main__': + print("=" * 60) + print(" TTS 流式合成测试") + print("=" * 60) + + success1 = test_stream_synthesis() + success2 = test_convenient_function() + + print("\n" + "=" * 60) + print("测试结果:") + print(f" 流式合成: {'✓ 通过' if 
success1 else '✗ 失败'}") + print(f" 便捷函数: {'✓ 通过' if success2 else '✗ 失败'}") diff --git a/uv.lock b/uv.lock index a2f14a4..a15ad5d 100644 --- a/uv.lock +++ b/uv.lock @@ -631,6 +631,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, ] +[[package]] +name = "heart-mirror-agent" +version = "0.1.0" +source = { virtual = "." } +dependencies = [ + { name = "agentscope" }, + { name = "dashscope" }, + { name = "loguru" }, + { name = "pre-commit" }, + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "pyyaml" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "agentscope", specifier = ">=1.0.10" }, + { name = "dashscope", specifier = ">=1.25.5" }, + { name = "loguru", specifier = ">=0.7.3" }, + { name = "pre-commit", specifier = ">=4.3.0" }, + { name = "pydantic", specifier = ">=2.11.7" }, + { name = "python-dotenv", specifier = ">=1.2.1" }, + { name = "pyyaml", specifier = ">=6.0.2" }, + { name = "ruff", specifier = ">=0.12.11" }, +] + [[package]] name = "httpcore" version = "1.0.9" @@ -1436,29 +1463,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/28/d2/2ccc2b69a187b80fda3152745670cfba936704f296a9fa54c6c8ac694d12/python_socketio-5.16.0-py3-none-any.whl", hash = "sha256:d95802961e15c7bd54ecf884c6e7644f81be8460f0a02ee66b473df58088ee8a", size = 79607, upload-time = "2025-12-24T23:51:47.2Z" }, ] -[[package]] -name = "Heart_Mirror_Agent" -version = "0.1.0" -source = { virtual = "." 
} -dependencies = [ - { name = "agentscope" }, - { name = "loguru" }, - { name = "pre-commit" }, - { name = "pydantic" }, - { name = "pyyaml" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "agentscope", specifier = ">=1.0.10" }, - { name = "loguru", specifier = ">=0.7.3" }, - { name = "pre-commit", specifier = ">=4.3.0" }, - { name = "pydantic", specifier = ">=2.11.7" }, - { name = "pyyaml", specifier = ">=6.0.2" }, - { name = "ruff", specifier = ">=0.12.11" }, -] - [[package]] name = "pywin32" version = "311"