feat: 流式语音合成

This commit is contained in:
DongShengWu 2026-01-01 21:10:04 +08:00
parent cc74e2b880
commit 7f9ae0e036
10 changed files with 561 additions and 23 deletions

1
.env Normal file
View File

@ -0,0 +1 @@
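# Never commit real credentials: set the key locally, keep .env out of version control, and revoke any key that was exposed.
DASHSCOPE_API_KEY=your-dashscope-api-key-here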

1
.gitignore vendored
View File

@ -10,3 +10,4 @@ wheels/
.venv .venv
.vscode/ .vscode/
tmp/ tmp/
ref/

View File

@ -6,9 +6,11 @@ readme = "README.md"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [ dependencies = [
"agentscope>=1.0.10", "agentscope>=1.0.10",
"dashscope>=1.25.5",
"loguru>=0.7.3", "loguru>=0.7.3",
"pre-commit>=4.3.0", "pre-commit>=4.3.0",
"pydantic>=2.11.7", "pydantic>=2.11.7",
"python-dotenv>=1.2.1",
"pyyaml>=6.0.2", "pyyaml>=6.0.2",
"ruff>=0.12.11", "ruff>=0.12.11",
] ]

361
src/Module/tts/tts.py Normal file
View File

@ -0,0 +1,361 @@
"""
实时语音合成模块 - 基于阿里云 DashScope 通义千问 TTS
核心特性使用 yield 生成器实现流式音频返回
"""
import os
import base64
import threading
import queue
from typing import Generator, Optional
from pathlib import Path
from dataclasses import dataclass
from dotenv import load_dotenv
import dashscope
from dashscope.audio.qwen_tts_realtime import (
QwenTtsRealtime,
QwenTtsRealtimeCallback,
AudioFormat
)
# 导入配置
from . import ttsconfig as config
@dataclass
class AudioChunk:
    """One unit of synthesized audio passed from the callback to the consumer."""

    # Raw audio payload (PCM bytes); terminator/error chunks carry b''.
    data: bytes
    # True when this chunk marks the end of the stream.
    is_final: bool = False
    # Error description when the stream ended abnormally, otherwise None.
    error: Optional[str] = None
@dataclass
class AudioInfo:
    """Format description of the audio produced by this module.

    Every default is read from ttsconfig at class-definition time.
    """

    # Sample rate, taken from config.SAMPLE_RATE.
    sample_rate: int = config.SAMPLE_RATE
    # Channel count, taken from config.CHANNELS.
    channels: int = config.CHANNELS
    # Sample width, taken from config.SAMPLE_WIDTH.
    sample_width: int = config.SAMPLE_WIDTH
class _StreamingCallback(QwenTtsRealtimeCallback):
"""内部流式回调类"""
def __init__(self):
self._audio_queue: queue.Queue[AudioChunk] = queue.Queue()
self._connected = threading.Event()
self._session_updated = threading.Event()
self.session_id: Optional[str] = None
def on_open(self) -> None:
self._connected.set()
def on_close(self, close_status_code, close_msg) -> None:
self._audio_queue.put(AudioChunk(data=b'', is_final=True))
def on_event(self, response: dict) -> None:
try:
event_type = response.get('type', '')
if event_type == 'session.created':
self.session_id = response.get('session', {}).get('id')
elif event_type == 'session.updated':
self._session_updated.set()
elif event_type == 'response.audio.delta':
audio_b64 = response.get('delta', '')
if audio_b64:
audio_bytes = base64.b64decode(audio_b64)
self._audio_queue.put(AudioChunk(data=audio_bytes))
elif event_type == 'session.finished':
self._audio_queue.put(AudioChunk(data=b'', is_final=True))
elif event_type == 'error':
error_msg = response.get('error', {}).get('message', '未知错误')
self._audio_queue.put(
AudioChunk(
data=b'',
is_final=True,
error=error_msg))
except Exception as e:
self._audio_queue.put(AudioChunk(data=b'', is_final=True, error=str(e)))
def get_chunk(self, timeout: float = config.CHUNK_TIMEOUT) -> Optional[AudioChunk]:
try:
return self._audio_queue.get(timeout=timeout)
except queue.Empty:
return AudioChunk(data=b'', is_final=True, error='超时')
def wait_connected(self, timeout: float = config.CONNECT_TIMEOUT) -> bool:
return self._connected.wait(timeout=timeout)
class StreamingTTS:
    """
    Streaming text-to-speech client.

    Each call to stream() opens a new QwenTtsRealtime connection, sends the
    text from a background thread and yields AudioChunk objects as they are
    received.

    Usage:
        tts = StreamingTTS(voice='Cherry')
        for chunk in tts.stream("你好,世界"):
            if chunk.data:
                audio_player.play(chunk.data)
    """

    # Audio format description (treat as read-only); shared by all instances.
    audio_info = AudioInfo()

    def __init__(
        self,
        voice: Optional[str] = None,
        language: Optional[str] = None,
        speech_rate: Optional[float] = None,
        volume: Optional[int] = None,
        pitch_rate: Optional[float] = None,
        model: Optional[str] = None,
    ):
        """
        Initialise the streaming TTS client.

        Every argument is optional; omitted values fall back to the defaults
        in ttsconfig.py. Numeric values are clamped to their documented range.

        Args:
            voice: Voice name ('Cherry', 'Serena', 'Ethan', 'Chelsie').
            language: Language ('Chinese', 'English', etc.).
            speech_rate: Speaking rate (0.5-2.0).
            volume: Volume (0-100).
            pitch_rate: Pitch (0.5-2.0).
            model: Model name.

        Raises:
            ValueError: If DASHSCOPE_API_KEY cannot be found.
        """
        self._load_api_key()
        # Empty strings are treated as "not provided" for the textual options.
        self.voice = voice or config.VOICE
        self.language = language or config.LANGUAGE
        self.model = model or config.MODEL
        # Numeric options must be compared against None: `value or default`
        # would discard a legitimate explicit 0 (e.g. volume=0 for silence).
        self.speech_rate = self._clamp(
            self._or_default(speech_rate, config.SPEECH_RATE), 0.5, 2.0)
        self.volume = self._clamp(
            self._or_default(volume, config.VOLUME), 0, 100)
        self.pitch_rate = self._clamp(
            self._or_default(pitch_rate, config.PITCH_RATE), 0.5, 2.0)
        self._client: Optional[QwenTtsRealtime] = None
        self._callback: Optional[_StreamingCallback] = None

    @staticmethod
    def _or_default(value, default):
        """Return default only when value is None (0 and 0.0 are kept)."""
        return default if value is None else value

    @staticmethod
    def _clamp(value, min_val, max_val):
        """Limit value to the closed interval [min_val, max_val]."""
        return max(min_val, min(max_val, value))

    def _load_api_key(self) -> None:
        """Load DASHSCOPE_API_KEY and hand it to the dashscope SDK.

        Looks for a .env file in this file's directory and up to four parent
        directories; the first one found is loaded. The key is then read from
        the process environment.

        Raises:
            ValueError: If DASHSCOPE_API_KEY is not set after the lookup.
        """
        current_dir = Path(__file__).parent
        for _ in range(5):
            env_path = current_dir / '.env'
            if env_path.exists():
                load_dotenv(env_path)
                break
            current_dir = current_dir.parent
        api_key = os.environ.get('DASHSCOPE_API_KEY')
        if not api_key:
            raise ValueError('未找到 DASHSCOPE_API_KEY')
        dashscope.api_key = api_key

    def stream(
        self,
        text: str,
        voice: Optional[str] = None,
        language: Optional[str] = None,
        speech_rate: Optional[float] = None,
        volume: Optional[int] = None,
        pitch_rate: Optional[float] = None,
    ) -> Generator[AudioChunk, None, None]:
        """
        Synthesize text and yield audio as it arrives - the core interface.

        Args:
            text: Text to synthesize.
            voice: One-off override of the voice.
            language: One-off override of the language.
            speech_rate: One-off override of the speaking rate.
            volume: One-off override of the volume.
            pitch_rate: One-off override of the pitch.

        Yields:
            AudioChunk: Audio data chunks. The last chunk has is_final=True
            and carries `error` when the stream ended abnormally.

        Example:
            >>> tts = StreamingTTS()
            >>> for chunk in tts.stream("你好", voice='Serena'):
            ...     process_audio(chunk.data)
        """
        # Per-call values win over instance attributes. Numeric overrides are
        # checked against None so an explicit 0 is honoured.
        _voice = voice or self.voice
        _language = language or self.language
        _speech_rate = self._clamp(
            self._or_default(speech_rate, self.speech_rate), 0.5, 2.0)
        _volume = self._clamp(
            self._or_default(volume, self.volume), 0, 100)
        _pitch_rate = self._clamp(
            self._or_default(pitch_rate, self.pitch_rate), 0.5, 2.0)
        self._callback = _StreamingCallback()
        self._client = QwenTtsRealtime(
            model=self.model,
            callback=self._callback,
            url=config.WS_URL
        )
        # The text is sent from a background thread so this generator can
        # start draining the audio queue immediately.
        send_thread = threading.Thread(
            target=self._send_text,
            args=(text, _voice, _language, _speech_rate, _volume, _pitch_rate),
            daemon=True
        )
        try:
            self._client.connect()
            send_thread.start()
            while True:
                chunk = self._callback.get_chunk()
                if chunk is None:
                    break
                yield chunk
                if chunk.is_final:
                    break
        finally:
            # Also runs when the consumer abandons or closes the generator.
            self._cleanup()

    def _send_text(
        self,
        text: str,
        voice: str,
        language: str,
        speech_rate: float,
        volume: int,
        pitch_rate: float
    ) -> None:
        """Configure the session and send the text (runs in the sender thread).

        Failures are reported to the consumer as a final error chunk instead
        of being raised, because nothing joins this thread.
        """
        try:
            if not self._callback.wait_connected():
                # Report the timeout explicitly; otherwise the consumer would
                # wait for the full chunk timeout and see a generic error.
                self._callback._audio_queue.put(
                    AudioChunk(data=b'', is_final=True, error='连接超时')
                )
                return
            self._client.update_session(
                voice=voice,
                response_format=AudioFormat.PCM_24000HZ_MONO_16BIT,
                mode=config.MODE,
                language_type=language,
                speech_rate=speech_rate,
                volume=volume,
                pitch_rate=pitch_rate
            )
            self._client.append_text(text)
            self._client.finish()
        except Exception as e:
            self._callback._audio_queue.put(
                AudioChunk(data=b'', is_final=True, error=str(e))
            )

    def _cleanup(self) -> None:
        """Close the current client, if any, on a best-effort basis."""
        if self._client:
            try:
                self._client.close()
            except Exception:
                # Closing is best-effort; only ordinary errors are ignored so
                # KeyboardInterrupt/SystemExit still propagate.
                pass
            finally:
                self._client = None

    def stream_to_file(self, text: str, output_file: str, **kwargs) -> bool:
        """
        Synthesize text and save the complete audio to a file.

        Args:
            text: Text to synthesize.
            output_file: Output path (str or path-like). A name ending in
                '.wav' (case-insensitive) is written as a WAV file; anything
                else receives the raw PCM bytes.
            **kwargs: Forwarded to stream().

        Returns:
            True on success, False if synthesis or writing the file failed.
        """
        # wave.open() is given a plain string, so normalise path-likes once.
        target = os.fspath(output_file)
        audio_data = bytearray()
        error = None
        chunks = self.stream(text, **kwargs)
        try:
            for chunk in chunks:
                if chunk.error:
                    error = chunk.error
                    break
                if chunk.data:
                    audio_data.extend(chunk.data)
        finally:
            # Close the generator explicitly so the connection is released
            # now rather than whenever the generator is garbage-collected.
            chunks.close()
        if error:
            print(f'[TTS Error] {error}')
            return False
        if target.lower().endswith('.wav'):
            return self._save_as_wav(bytes(audio_data), target)
        try:
            with open(target, 'wb') as f:
                f.write(audio_data)
        except OSError as e:
            # Mirror the WAV branch: report the failure and return False.
            print(f'[TTS Error] PCM 保存失败: {e}')
            return False
        return True

    def _save_as_wav(self, pcm_data: bytes, wav_file: str) -> bool:
        """Write pcm_data to wav_file using the audio parameters from ttsconfig.

        Returns:
            True on success, False if writing failed (the error is printed).
        """
        import wave
        try:
            with wave.open(wav_file, 'wb') as wav:
                wav.setnchannels(config.CHANNELS)
                wav.setsampwidth(config.SAMPLE_WIDTH)
                wav.setframerate(config.SAMPLE_RATE)
                wav.writeframes(pcm_data)
            return True
        except Exception as e:
            print(f'[TTS Error] WAV 保存失败: {e}')
            return False
# ============================================================
# 便捷函数
# ============================================================
def synthesize(text: str, **kwargs) -> Generator[AudioChunk, None, None]:
    """
    Convenience wrapper for one-shot streaming synthesis.

    Args:
        text: Text to synthesize.
        **kwargs: 'model' is passed to the StreamingTTS constructor;
            'voice', 'language', 'speech_rate', 'volume' and 'pitch_rate'
            are passed to stream(). Any other key is ignored.

    Yields:
        AudioChunk: Audio data chunks.
    """
    constructor_args = {}
    per_call_args = {}
    for key, value in kwargs.items():
        if key == 'model':
            constructor_args[key] = value
        elif key in ('voice', 'language', 'speech_rate', 'volume', 'pitch_rate'):
            per_call_args[key] = value
    engine = StreamingTTS(**constructor_args)
    yield from engine.stream(text, **per_call_args)
def synthesize_to_file(text: str, output_file: str, **kwargs) -> bool:
    """
    Synthesize text and save the audio to a file.

    Args:
        text: Text to synthesize.
        output_file: Output file path.
        **kwargs: 'model' selects the model used to construct StreamingTTS;
            every other key is forwarded to stream() via stream_to_file().

    Returns:
        True on success, False otherwise.
    """
    # stream() has no 'model' parameter, so forwarding it would raise
    # TypeError. Route it to the constructor, as synthesize() does.
    init_kwargs = {}
    if 'model' in kwargs:
        init_kwargs['model'] = kwargs.pop('model')
    tts = StreamingTTS(**init_kwargs)
    return tts.stream_to_file(text, output_file, **kwargs)
def get_audio_info() -> AudioInfo:
    """Return a newly constructed AudioInfo holding the default audio parameters."""
    info = AudioInfo()
    return info

View File

@ -0,0 +1,73 @@
"""
TTS configuration.
Defines the default parameters for speech synthesis.
"""
# ============================================================
# Audio output settings
# ============================================================
# Sample rate (Hz) - the SDK currently supports only 24000
SAMPLE_RATE = 24000
# Number of channels - mono
CHANNELS = 1
# Sample width (bytes) - 16 bit = 2 bytes
SAMPLE_WIDTH = 2
# Output format
OUTPUT_FORMAT = 'pcm' # 'pcm', 'wav'
# ============================================================
# Speech synthesis parameters
# ============================================================
# Default model
MODEL = 'qwen3-tts-flash-realtime'
# Default voice
# Options: 'Cherry', 'Serena', 'Ethan', 'Chelsie'
VOICE = 'Cherry'
# Default language
# Options: 'Auto', 'Chinese', 'English', 'Japanese', 'Korean',
# 'French', 'German', 'Spanish', 'Italian', 'Portuguese', 'Russian'
LANGUAGE = 'Chinese'
# Speaking rate (0.5 - 2.0, default 1.0)
SPEECH_RATE = 1.0
# Volume (0 - 100, default 50)
VOLUME = 50
# Pitch (0.5 - 2.0, default 1.0)
PITCH_RATE = 1.0
# ============================================================
# Service settings
# ============================================================
# WebSocket URL (Beijing region)
WS_URL = 'wss://dashscope.aliyuncs.com/api-ws/v1/realtime'
# Singapore region URL (alternative)
# WS_URL = 'wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime'
# Connection timeout (seconds)
CONNECT_TIMEOUT = 10.0
# Timeout while waiting for the next audio chunk (seconds)
CHUNK_TIMEOUT = 30.0
# ============================================================
# Interaction mode
# ============================================================
# Mode: 'server_commit' or 'commit'
# server_commit: the server decides where to split the text
# commit: the client triggers synthesis explicitly
MODE = 'server_commit'

View File

BIN
test/asr/mirror_hello.mp3 Normal file

Binary file not shown.

Binary file not shown.

96
test/tts/test_tts.py Normal file
View File

@ -0,0 +1,96 @@
"""
TTS 流式合成测试
"""
from pathlib import Path
from src.Module.tts.tts import StreamingTTS, synthesize, AudioChunk
def test_stream_synthesis():
    """Exercise StreamingTTS.stream() end to end and save the audio as WAV.

    This is an integration check: it needs network access and a valid
    DASHSCOPE_API_KEY.

    Returns:
        True only if at least one audio chunk arrived and no error chunk was
        received. The __main__ driver prints this value as pass/fail.
    """
    print("=" * 50)
    print("流式语音合成测试")
    print("=" * 50)
    # Create the TTS instance
    tts = StreamingTTS(
        voice='Cherry',
        language='Chinese',
        speech_rate=1.0,
    )
    text = "你好,我是通义千问语音合成系统。这是一段流式合成测试。"
    print(f"合成文本: {text}\n")
    # Receive the audio chunk by chunk
    total_bytes = 0
    chunk_count = 0
    audio_data = bytearray()
    failed = False
    print("开始流式接收音频...")
    for chunk in tts.stream(text):
        if chunk.error:
            print(f" 错误: {chunk.error}")
            # Remember the failure so a partially received stream is not
            # reported as a success.
            failed = True
            break
        if chunk.data:
            chunk_count += 1
            total_bytes += len(chunk.data)
            audio_data.extend(chunk.data)
            print(
                f" [Chunk {chunk_count:02d}] 收到 {len(chunk.data):5d} 字节 | 累计: {total_bytes:6d} 字节")
        if chunk.is_final:
            print("\n流式传输完成!")
            break
    print(f"\n总计: {chunk_count} 个数据块, {total_bytes} 字节")
    # Only write a WAV file when there is audio to save.
    if audio_data:
        import wave
        output_dir = Path(__file__).parent / 'output'
        output_dir.mkdir(exist_ok=True)
        wav_file = output_dir / 'stream_test.wav'
        # Take the format from the TTS object rather than repeating the
        # literals, so the file stays correct if the configuration changes.
        info = tts.audio_info
        with wave.open(str(wav_file), 'wb') as wav:
            wav.setnchannels(info.channels)
            wav.setsampwidth(info.sample_width)
            wav.setframerate(info.sample_rate)
            wav.writeframes(audio_data)
        print(f"音频已保存: {wav_file}")
    return chunk_count > 0 and not failed
def test_convenient_function():
    """Exercise the module-level synthesize() helper.

    Returns:
        True if any audio bytes were received.
    """
    separator = "=" * 50
    print("\n" + separator)
    print("便捷函数测试")
    print(separator)
    text = "这是使用便捷函数合成的语音。"
    print(f"合成文本: {text}\n")
    received = 0
    for piece in synthesize(text, voice='Cherry'):
        if piece.data:
            received += len(piece.data)
        if piece.is_final:
            break
    total_bytes = received
    print(f"合成完成,总数据量: {total_bytes} 字节")
    return total_bytes > 0
if __name__ == '__main__':
    # Run both checks in order, then print a pass/fail summary.
    banner = "=" * 60
    print(banner)
    print(" TTS 流式合成测试")
    print(banner)
    success1 = test_stream_synthesis()
    success2 = test_convenient_function()
    print("\n" + banner)
    print("测试结果:")
    stream_verdict = '✓ 通过' if success1 else '✗ 失败'
    helper_verdict = '✓ 通过' if success2 else '✗ 失败'
    print(f" 流式合成: {stream_verdict}")
    print(f" 便捷函数: {helper_verdict}")

50
uv.lock generated
View File

@ -631,6 +631,33 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
] ]
[[package]]
name = "heart-mirror-agent"
version = "0.1.0"
source = { virtual = "." }
dependencies = [
{ name = "agentscope" },
{ name = "dashscope" },
{ name = "loguru" },
{ name = "pre-commit" },
{ name = "pydantic" },
{ name = "python-dotenv" },
{ name = "pyyaml" },
{ name = "ruff" },
]
[package.metadata]
requires-dist = [
{ name = "agentscope", specifier = ">=1.0.10" },
{ name = "dashscope", specifier = ">=1.25.5" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "pre-commit", specifier = ">=4.3.0" },
{ name = "pydantic", specifier = ">=2.11.7" },
{ name = "python-dotenv", specifier = ">=1.2.1" },
{ name = "pyyaml", specifier = ">=6.0.2" },
{ name = "ruff", specifier = ">=0.12.11" },
]
[[package]] [[package]]
name = "httpcore" name = "httpcore"
version = "1.0.9" version = "1.0.9"
@ -1436,29 +1463,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/28/d2/2ccc2b69a187b80fda3152745670cfba936704f296a9fa54c6c8ac694d12/python_socketio-5.16.0-py3-none-any.whl", hash = "sha256:d95802961e15c7bd54ecf884c6e7644f81be8460f0a02ee66b473df58088ee8a", size = 79607, upload-time = "2025-12-24T23:51:47.2Z" }, { url = "https://files.pythonhosted.org/packages/28/d2/2ccc2b69a187b80fda3152745670cfba936704f296a9fa54c6c8ac694d12/python_socketio-5.16.0-py3-none-any.whl", hash = "sha256:d95802961e15c7bd54ecf884c6e7644f81be8460f0a02ee66b473df58088ee8a", size = 79607, upload-time = "2025-12-24T23:51:47.2Z" },
] ]
[[package]]
name = "Heart_Mirror_Agent"
version = "0.1.0"
source = { virtual = "." }
dependencies = [
{ name = "agentscope" },
{ name = "loguru" },
{ name = "pre-commit" },
{ name = "pydantic" },
{ name = "pyyaml" },
{ name = "ruff" },
]
[package.metadata]
requires-dist = [
{ name = "agentscope", specifier = ">=1.0.10" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "pre-commit", specifier = ">=4.3.0" },
{ name = "pydantic", specifier = ">=2.11.7" },
{ name = "pyyaml", specifier = ">=6.0.2" },
{ name = "ruff", specifier = ">=0.12.11" },
]
[[package]] [[package]]
name = "pywin32" name = "pywin32"
version = "311" version = "311"