feat: ✨ 语音识别
This commit is contained in:
parent
48fe2f37ae
commit
35c9b9eb58
238
src/Module/asr/asr.py
Normal file
238
src/Module/asr/asr.py
Normal file
@ -0,0 +1,238 @@
|
||||
"""
|
||||
语音识别模块 - 基于阿里云 DashScope 通义千问 ASR
|
||||
支持音频文件识别
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import dashscope
|
||||
from dashscope import MultiModalConversation
|
||||
|
||||
from . import asrconfig as config
|
||||
|
||||
|
||||
@dataclass
|
||||
class ASRResult:
|
||||
"""识别结果"""
|
||||
text: str # 识别文本
|
||||
success: bool = True # 是否成功
|
||||
error: Optional[str] = None # 错误信息
|
||||
request_id: Optional[str] = None # 请求ID
|
||||
|
||||
|
||||
class ASR:
|
||||
"""
|
||||
语音识别类
|
||||
|
||||
使用方式:
|
||||
asr = ASR()
|
||||
result = asr.recognize("audio.wav")
|
||||
print(result.text)
|
||||
"""
|
||||
|
||||
# 支持的格式
|
||||
SUPPORTED_FORMATS = config.SUPPORTED_FORMATS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = None,
|
||||
language: str = None,
|
||||
enable_itn: bool = None,
|
||||
context: str = None,
|
||||
):
|
||||
"""
|
||||
初始化 ASR
|
||||
|
||||
Args:
|
||||
model: 模型名称 ('qwen3-asr-flash' 或 'qwen3-asr-flash-filetrans')
|
||||
language: 语言代码 ('zh', 'en', 等),None 为自动检测
|
||||
enable_itn: 是否启用 ITN
|
||||
context: 上下文增强文本
|
||||
"""
|
||||
self._load_api_key()
|
||||
|
||||
self.model = model or config.MODEL
|
||||
self.language = language if language is not None else config.LANGUAGE
|
||||
self.enable_itn = enable_itn if enable_itn is not None else config.ENABLE_ITN
|
||||
self.context = context or ""
|
||||
|
||||
# 设置 API URL
|
||||
dashscope.base_http_api_url = config.API_URL
|
||||
|
||||
def _load_api_key(self) -> None:
|
||||
"""从 .env 加载 API Key"""
|
||||
current_dir = Path(__file__).parent
|
||||
for _ in range(5):
|
||||
env_path = current_dir / '.env'
|
||||
if env_path.exists():
|
||||
load_dotenv(env_path)
|
||||
break
|
||||
current_dir = current_dir.parent
|
||||
|
||||
api_key = os.environ.get('DASHSCOPE_API_KEY')
|
||||
if not api_key:
|
||||
raise ValueError('未找到 DASHSCOPE_API_KEY')
|
||||
|
||||
def recognize(
|
||||
self,
|
||||
audio_path: str,
|
||||
language: str = None,
|
||||
context: str = None,
|
||||
) -> ASRResult:
|
||||
"""
|
||||
识别音频文件
|
||||
|
||||
Args:
|
||||
audio_path: 音频文件路径(本地路径或 URL)
|
||||
language: 临时覆盖语言设置
|
||||
context: 临时覆盖上下文
|
||||
|
||||
Returns:
|
||||
ASRResult: 识别结果
|
||||
"""
|
||||
try:
|
||||
# 处理文件路径
|
||||
audio_uri = self._prepare_audio_uri(audio_path)
|
||||
|
||||
# 构建消息
|
||||
messages = self._build_messages(audio_uri, context)
|
||||
|
||||
# 构建 ASR 选项
|
||||
asr_options = {"enable_itn": self.enable_itn}
|
||||
if language or self.language:
|
||||
asr_options["language"] = language or self.language
|
||||
|
||||
# 调用 API
|
||||
response = MultiModalConversation.call(
|
||||
api_key=os.environ.get('DASHSCOPE_API_KEY'),
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
result_format="message",
|
||||
asr_options=asr_options
|
||||
)
|
||||
|
||||
# 解析结果
|
||||
return self._parse_response(response)
|
||||
|
||||
except Exception as e:
|
||||
return ASRResult(
|
||||
text="",
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def _prepare_audio_uri(self, audio_path: str) -> str:
|
||||
"""准备音频 URI"""
|
||||
# 如果已经是 URL
|
||||
if audio_path.startswith('http://') or audio_path.startswith('https://'):
|
||||
return audio_path
|
||||
|
||||
# 如果已经是 file:// 格式
|
||||
if audio_path.startswith('file://'):
|
||||
return audio_path
|
||||
|
||||
# 本地文件,转换为 file:// 格式
|
||||
path = Path(audio_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"音频文件不存在: {audio_path}")
|
||||
|
||||
# 检查文件大小
|
||||
file_size = path.stat().st_size
|
||||
if file_size > config.MAX_FILE_SIZE:
|
||||
raise ValueError(
|
||||
f"文件大小 ({file_size} bytes) 超过限制 ({
|
||||
config.MAX_FILE_SIZE} bytes)")
|
||||
|
||||
# 检查格式
|
||||
suffix = path.suffix.lower().lstrip('.')
|
||||
if suffix not in self.SUPPORTED_FORMATS:
|
||||
raise ValueError(f"不支持的音频格式: {suffix}")
|
||||
|
||||
# 转换为绝对路径
|
||||
abs_path = path.resolve()
|
||||
return f"file://{abs_path}"
|
||||
|
||||
def _build_messages(self, audio_uri: str, context: str = None) -> list:
|
||||
"""构建消息"""
|
||||
ctx = context or self.context
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"text": ctx}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"audio": audio_uri}]
|
||||
}
|
||||
]
|
||||
return messages
|
||||
|
||||
def _parse_response(self, response) -> ASRResult:
|
||||
"""解析 API 响应"""
|
||||
if response.status_code != 200:
|
||||
return ASRResult(
|
||||
text="",
|
||||
success=False,
|
||||
error=f"API 错误: {response.code} - {response.message}",
|
||||
request_id=response.request_id
|
||||
)
|
||||
|
||||
# 提取文本
|
||||
try:
|
||||
content = response.output.choices[0].message.content
|
||||
if isinstance(content, list):
|
||||
text = content[0].get("text", "")
|
||||
else:
|
||||
text = str(content)
|
||||
|
||||
return ASRResult(
|
||||
text=text,
|
||||
success=True,
|
||||
request_id=response.request_id
|
||||
)
|
||||
except Exception as e:
|
||||
return ASRResult(
|
||||
text="",
|
||||
success=False,
|
||||
error=f"解析响应失败: {e}",
|
||||
request_id=getattr(response, 'request_id', None)
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 便捷函数
|
||||
# ============================================================
|
||||
|
||||
def recognize(audio_path: str, **kwargs) -> ASRResult:
|
||||
"""
|
||||
便捷的识别函数
|
||||
|
||||
Args:
|
||||
audio_path: 音频文件路径
|
||||
**kwargs: 传递给 ASR 的参数
|
||||
|
||||
Returns:
|
||||
ASRResult: 识别结果
|
||||
"""
|
||||
asr = ASR(**kwargs)
|
||||
return asr.recognize(audio_path)
|
||||
|
||||
|
||||
def recognize_text(audio_path: str, **kwargs) -> str:
|
||||
"""
|
||||
便捷函数,直接返回识别文本
|
||||
|
||||
Args:
|
||||
audio_path: 音频文件路径
|
||||
**kwargs: 传递给 ASR 的参数
|
||||
|
||||
Returns:
|
||||
str: 识别的文本,失败返回空字符串
|
||||
"""
|
||||
result = recognize(audio_path, **kwargs)
|
||||
return result.text if result.success else ""
|
||||
49
src/Module/asr/asrconfig.py
Normal file
49
src/Module/asr/asrconfig.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""
|
||||
ASR 配置文件
|
||||
定义语音识别的默认参数
|
||||
"""
|
||||
|
||||
# ============================================================
|
||||
# 模型配置
|
||||
# ============================================================
|
||||
|
||||
# 默认模型
|
||||
# qwen3-asr-flash: 短音频 (≤5分钟)
|
||||
# qwen3-asr-flash-filetrans: 长音频 (≤12小时)
|
||||
MODEL = 'qwen3-asr-flash'
|
||||
|
||||
# API URL (北京地域)
|
||||
API_URL = 'https://dashscope.aliyuncs.com/api/v1'
|
||||
|
||||
# 新加坡地域 URL (备用)
|
||||
# API_URL = 'https://dashscope-intl.aliyuncs.com/api/v1'
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 识别参数
|
||||
# ============================================================
|
||||
|
||||
# 语言 (可选值: 'zh', 'en', 'ja', 'ko', 'de', 'fr', 'ru', 'es', 'it', 'pt', 'ar', 等)
|
||||
# None 表示自动检测
|
||||
LANGUAGE = None
|
||||
|
||||
# 是否启用 ITN (Inverse Text Normalization)
|
||||
# 将口语数字转为书面形式,如"一百二十三"→"123"
|
||||
ENABLE_ITN = False
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 支持的音频格式
|
||||
# ============================================================
|
||||
|
||||
SUPPORTED_FORMATS = [
|
||||
'aac', 'amr', 'avi', 'aiff', 'flac', 'flv',
|
||||
'm4a', 'mkv', 'mp3', 'mpeg', 'ogg', 'opus',
|
||||
'wav', 'webm', 'wma', 'wmv'
|
||||
]
|
||||
|
||||
# 最大文件大小 (字节) - 10MB
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
# 最大音频时长 (秒) - 5分钟
|
||||
MAX_DURATION = 5 * 60
|
||||
83
test/asr/test_asr.py
Normal file
83
test/asr/test_asr.py
Normal file
@ -0,0 +1,83 @@
|
||||
"""
|
||||
ASR 语音识别测试
|
||||
使用 TTS 生成的音频文件测试识别
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from src.Module.asr.asr import ASR, recognize, recognize_text
|
||||
|
||||
|
||||
def test_recognize_wav():
|
||||
"""测试识别 WAV 文件"""
|
||||
print("=" * 60)
|
||||
print(" ASR 语音识别测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 使用 TTS 测试生成的音频文件
|
||||
audio_file = Path(__file__).parent.parent / 'tts' / 'output' / 'stream_test.wav'
|
||||
|
||||
if not audio_file.exists():
|
||||
print(f"[跳过] 音频文件不存在: {audio_file}")
|
||||
print("请先运行 TTS 测试生成音频文件")
|
||||
return False
|
||||
|
||||
print(f"\n音频文件: {audio_file}")
|
||||
print(f"文件大小: {audio_file.stat().st_size / 1024:.1f} KB")
|
||||
|
||||
# 创建 ASR 实例
|
||||
asr = ASR(
|
||||
model='qwen3-asr-flash',
|
||||
language='zh',
|
||||
)
|
||||
|
||||
print("\n开始识别...")
|
||||
result = asr.recognize(str(audio_file))
|
||||
|
||||
print(f"\n识别结果:")
|
||||
print(f" 成功: {result.success}")
|
||||
print(f" 文本: {result.text}")
|
||||
if result.error:
|
||||
print(f" 错误: {result.error}")
|
||||
if result.request_id:
|
||||
print(f" 请求ID: {result.request_id}")
|
||||
|
||||
return result.success
|
||||
|
||||
|
||||
def test_convenient_function():
|
||||
"""测试便捷函数"""
|
||||
print("\n" + "=" * 60)
|
||||
print("便捷函数测试")
|
||||
print("=" * 60)
|
||||
|
||||
audio_file = Path(__file__).parent.parent / 'tts' / \
|
||||
'output' / 'bidirectional_test.wav'
|
||||
|
||||
if not audio_file.exists():
|
||||
print(f"[跳过] 音频文件不存在: {audio_file}")
|
||||
return False
|
||||
|
||||
print(f"\n音频文件: {audio_file.name}")
|
||||
|
||||
# 使用便捷函数
|
||||
text = recognize_text(str(audio_file), language='zh')
|
||||
|
||||
print(f"识别文本: {text}")
|
||||
|
||||
return len(text) > 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
results = []
|
||||
|
||||
success1 = test_recognize_wav()
|
||||
results.append(("WAV 文件识别", success1))
|
||||
|
||||
success2 = test_convenient_function()
|
||||
results.append(("便捷函数", success2))
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试结果:")
|
||||
for name, success in results:
|
||||
status = "✓ 通过" if success else "✗ 失败/跳过"
|
||||
print(f" {name}: {status}")
|
||||
Loading…
x
Reference in New Issue
Block a user