diff --git a/src/Module/asr/asr.py b/src/Module/asr/asr.py
new file mode 100644
index 0000000..a1a1d3a
--- /dev/null
+++ b/src/Module/asr/asr.py
@@ -0,0 +1,258 @@
+"""
+Speech recognition module built on Alibaba Cloud DashScope Qwen ASR.
+Supports recognition of audio files.
+"""
+
+import os
+from typing import Optional
+from pathlib import Path
+from dataclasses import dataclass
+from dotenv import load_dotenv
+
+import dashscope
+from dashscope import MultiModalConversation
+
+from . import asrconfig as config
+
+
+@dataclass
+class ASRResult:
+    """Outcome of a single recognition request."""
+    text: str  # recognized text ("" on failure)
+    success: bool = True  # whether recognition succeeded
+    error: Optional[str] = None  # error message on failure
+    request_id: Optional[str] = None  # request ID
+
+
+class ASR:
+    """
+    Speech recognizer.
+
+    Usage:
+        asr = ASR()
+        result = asr.recognize("audio.wav")
+        print(result.text)
+    """
+
+    # Supported formats
+    SUPPORTED_FORMATS = config.SUPPORTED_FORMATS
+
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        language: Optional[str] = None,
+        enable_itn: Optional[bool] = None,
+        context: Optional[str] = None,
+    ):
+        """
+        Initialize the recognizer.
+
+        Args:
+            model: Model name ('qwen3-asr-flash' or
+                'qwen3-asr-flash-filetrans'); defaults to config.MODEL.
+            language: Language code ('zh', 'en', ...); defaults to
+                config.LANGUAGE, where None means auto-detection.
+            enable_itn: Whether to enable ITN; defaults to
+                config.ENABLE_ITN.
+            context: Context enhancement text.
+
+        Raises:
+            ValueError: If DASHSCOPE_API_KEY is not set.
+        """
+        self._load_api_key()
+
+        self.model = model or config.MODEL
+        self.language = language if language is not None else config.LANGUAGE
+        self.enable_itn = enable_itn if enable_itn is not None else config.ENABLE_ITN
+        self.context = context or ""
+
+        # Point the DashScope SDK at the configured endpoint
+        dashscope.base_http_api_url = config.API_URL
+
+    def _load_api_key(self) -> None:
+        """
+        Load the API key from a .env file if one is found.
+
+        Checks this file's directory and then its ancestors (five
+        directories in total) and loads the first .env found.
+
+        Raises:
+            ValueError: If DASHSCOPE_API_KEY is missing or empty.
+        """
+        current_dir = Path(__file__).parent
+        for _ in range(5):
+            env_path = current_dir / '.env'
+            if env_path.exists():
+                load_dotenv(env_path)
+                break
+            current_dir = current_dir.parent
+
+        if not os.environ.get('DASHSCOPE_API_KEY'):
+            raise ValueError('未找到 DASHSCOPE_API_KEY')
+
+    def recognize(
+        self,
+        audio_path: str,
+        language: Optional[str] = None,
+        context: Optional[str] = None,
+    ) -> ASRResult:
+        """
+        Recognize an audio file.
+
+        Args:
+            audio_path: Audio location (local path or URL).
+            language: Overrides the instance language for this call.
+            context: Overrides the instance context for this call.
+
+        Returns:
+            ASRResult: The recognition result. Failures are reported
+            through success=False and error instead of being raised.
+        """
+        try:
+            audio_uri = self._prepare_audio_uri(audio_path)
+            messages = self._build_messages(audio_uri, context)
+
+            # The language option is only sent when one is set
+            asr_options = {"enable_itn": self.enable_itn}
+            effective_language = language or self.language
+            if effective_language:
+                asr_options["language"] = effective_language
+
+            response = MultiModalConversation.call(
+                api_key=os.environ.get('DASHSCOPE_API_KEY'),
+                model=self.model,
+                messages=messages,
+                result_format="message",
+                asr_options=asr_options,
+            )
+            return self._parse_response(response)
+
+        except Exception as e:
+            # Deliberate catch-all: every failure becomes a result
+            return ASRResult(
+                text="",
+                success=False,
+                error=str(e)
+            )
+
+    def _prepare_audio_uri(self, audio_path: str) -> str:
+        """
+        Convert an audio location into the URI sent to the API.
+
+        http://, https:// and file:// inputs are returned unchanged
+        without validation; anything else is treated as a local path.
+
+        Raises:
+            FileNotFoundError: If the local path does not exist.
+            ValueError: If the file exceeds config.MAX_FILE_SIZE or its
+                extension is not in SUPPORTED_FORMATS.
+        """
+        if audio_path.startswith(('http://', 'https://', 'file://')):
+            return audio_path
+
+        path = Path(audio_path)
+        if not path.exists():
+            raise FileNotFoundError(f"音频文件不存在: {audio_path}")
+
+        file_size = path.stat().st_size
+        if file_size > config.MAX_FILE_SIZE:
+            # Two single-line f-strings: a replacement field that spans
+            # lines is a SyntaxError before Python 3.12.
+            raise ValueError(
+                f"文件大小 ({file_size} bytes) "
+                f"超过限制 ({config.MAX_FILE_SIZE} bytes)")
+
+        # Extension is compared lower-cased, without the leading dot
+        suffix = path.suffix.lower().lstrip('.')
+        if suffix not in self.SUPPORTED_FORMATS:
+            raise ValueError(f"不支持的音频格式: {suffix}")
+
+        return f"file://{path.resolve()}"
+
+    def _build_messages(
+        self,
+        audio_uri: str,
+        context: Optional[str] = None,
+    ) -> list:
+        """
+        Build the message list for the API call.
+
+        The system message carries the context text (the per-call value,
+        or the instance context when none is given); the user message
+        carries the audio URI.
+        """
+        ctx = context or self.context
+        return [
+            {"role": "system", "content": [{"text": ctx}]},
+            {"role": "user", "content": [{"audio": audio_uri}]},
+        ]
+
+    def _parse_response(self, response) -> ASRResult:
+        """
+        Turn an API response into an ASRResult.
+
+        A status code other than 200, or a body whose text cannot be
+        extracted, yields a failed result.
+        """
+        if response.status_code != 200:
+            return ASRResult(
+                text="",
+                success=False,
+                error=f"API 错误: {response.code} - {response.message}",
+                request_id=response.request_id
+            )
+
+        try:
+            content = response.output.choices[0].message.content
+            if isinstance(content, list):
+                # Only the first content item is read
+                text = content[0].get("text", "")
+            else:
+                text = str(content)
+
+            return ASRResult(
+                text=text,
+                success=True,
+                request_id=response.request_id
+            )
+        except Exception as e:
+            return ASRResult(
+                text="",
+                success=False,
+                error=f"解析响应失败: {e}",
+                request_id=getattr(response, 'request_id', None)
+            )
+
+
+# ============================================================
+# Convenience functions
+# ============================================================
+
+def recognize(audio_path: str, **kwargs) -> ASRResult:
+    """
+    Recognize an audio file with a one-off ASR instance.
+
+    Args:
+        audio_path: Path of the audio file.
+        **kwargs: Arguments passed to the ASR constructor.
+
+    Returns:
+        ASRResult: The recognition result.
+    """
+    asr = ASR(**kwargs)
+    return asr.recognize(audio_path)
+
+
+def recognize_text(audio_path: str, **kwargs) -> str:
+    """
+    Recognize an audio file and return only the text.
+
+    Args:
+        audio_path: Path of the audio file.
+        **kwargs: Arguments passed to the ASR constructor.
+
+    Returns:
+        str: The recognized text, or an empty string on failure.
+    """
+    result = recognize(audio_path, **kwargs)
+    return result.text if result.success else ""
diff --git a/src/Module/asr/asrconfig.py b/src/Module/asr/asrconfig.py
new file mode 100644
index 0000000..4a056eb
--- /dev/null
+++ b/src/Module/asr/asrconfig.py
@@ -0,0 +1,49 @@
+"""
+ASR configuration.
+Defines the default speech recognition parameters.
+"""
+
+# ============================================================
+# Model configuration
+# ============================================================
+
+# Default model
+# qwen3-asr-flash: short audio (<= 5 minutes)
+# qwen3-asr-flash-filetrans: long audio (<= 12 hours)
+MODEL = 'qwen3-asr-flash'
+
+# API URL (Beijing region)
+API_URL = 'https://dashscope.aliyuncs.com/api/v1'
+
+# Singapore region URL (alternative)
+# API_URL = 'https://dashscope-intl.aliyuncs.com/api/v1'
+
+
+# ============================================================
+# Recognition parameters
+# ============================================================
+
+# Language (allowed values: 'zh', 'en', 'ja', 'ko', 'de', 'fr', 'ru', 'es', 'it', 'pt', 'ar', etc.)
+# None means automatic detection
+LANGUAGE = None
+
+# Whether to enable ITN (Inverse Text Normalization)
+# Converts spoken-form numbers to written form, e.g. "一百二十三" -> "123"
+ENABLE_ITN = False
+
+
+# ============================================================
+# Supported audio formats
+# ============================================================
+
+SUPPORTED_FORMATS = [
+    'aac', 'amr', 'avi', 'aiff', 'flac', 'flv',
+    'm4a', 'mkv', 'mp3', 'mpeg', 'ogg', 'opus',
+    'wav', 'webm', 'wma', 'wmv'
+]
+
+# Maximum file size (bytes) - 10MB
+MAX_FILE_SIZE = 10 * 1024 * 1024
+
+# Maximum audio duration (seconds) - 5 minutes
+MAX_DURATION = 5 * 60
diff --git a/test/asr/test_asr.py b/test/asr/test_asr.py
new file mode 100644
index 0000000..18af493
--- /dev/null
+++ b/test/asr/test_asr.py
@@ -0,0 +1,83 @@
+"""
+ASR speech recognition tests.
+Runs recognition on audio files generated by the TTS tests.
+"""
+
+from pathlib import Path
+from src.Module.asr.asr import ASR, recognize, recognize_text
+
+
+def test_recognize_wav():
+    """Recognize a WAV file through an ASR instance."""
+    print("=" * 60)
+    print(" ASR 语音识别测试")
+    print("=" * 60)
+
+    # Use an audio file produced by the TTS tests
+    audio_file = Path(__file__).parent.parent / 'tts' / 'output' / 'stream_test.wav'
+
+    if not audio_file.exists():
+        print(f"[跳过] 音频文件不存在: {audio_file}")
+        print("请先运行 TTS 测试生成音频文件")
+        return False
+
+    print(f"\n音频文件: {audio_file}")
+    print(f"文件大小: {audio_file.stat().st_size / 1024:.1f} KB")
+
+    # Create the ASR instance
+    asr = ASR(
+        model='qwen3-asr-flash',
+        language='zh',
+    )
+
+    print("\n开始识别...")
+    result = asr.recognize(str(audio_file))
+
+    print("\n识别结果:")
+    print(f" 成功: {result.success}")
+    print(f" 文本: {result.text}")
+    if result.error:
+        print(f" 错误: {result.error}")
+    if result.request_id:
+        print(f" 请求ID: {result.request_id}")
+
+    return result.success
+
+
+def test_convenient_function():
+    """Recognize a WAV file through the recognize_text helper."""
+    print("\n" + "=" * 60)
+    print("便捷函数测试")
+    print("=" * 60)
+
+    audio_file = (Path(__file__).parent.parent / 'tts' / 'output'
+                  / 'bidirectional_test.wav')
+
+    if not audio_file.exists():
+        print(f"[跳过] 音频文件不存在: {audio_file}")
+        return False
+
+    print(f"\n音频文件: {audio_file.name}")
+
+    # Use the convenience function
+    text = recognize_text(str(audio_file), language='zh')
+
+    print(f"识别文本: {text}")
+
+    return len(text) > 0
+
+
+if __name__ == '__main__':
+    results = []
+
+    success1 = test_recognize_wav()
+    results.append(("WAV 文件识别", success1))
+
+    success2 = test_convenient_function()
+    results.append(("便捷函数", success2))
+
+    print("\n" + "=" * 60)
+    print("测试结果:")
+    for name, success in results:
+        status = "✓ 通过" if success else "✗ 失败/跳过"
+        print(f" {name}: {status}")