feat: ✨ 语音识别
This commit is contained in:
parent
48fe2f37ae
commit
35c9b9eb58
238
src/Module/asr/asr.py
Normal file
238
src/Module/asr/asr.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
"""
|
||||||
|
语音识别模块 - 基于阿里云 DashScope 通义千问 ASR
|
||||||
|
支持音频文件识别
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
from pathlib import Path
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
import dashscope
|
||||||
|
from dashscope import MultiModalConversation
|
||||||
|
|
||||||
|
from . import asrconfig as config
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ASRResult:
|
||||||
|
"""识别结果"""
|
||||||
|
text: str # 识别文本
|
||||||
|
success: bool = True # 是否成功
|
||||||
|
error: Optional[str] = None # 错误信息
|
||||||
|
request_id: Optional[str] = None # 请求ID
|
||||||
|
|
||||||
|
|
||||||
|
class ASR:
|
||||||
|
"""
|
||||||
|
语音识别类
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
asr = ASR()
|
||||||
|
result = asr.recognize("audio.wav")
|
||||||
|
print(result.text)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 支持的格式
|
||||||
|
SUPPORTED_FORMATS = config.SUPPORTED_FORMATS
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = None,
|
||||||
|
language: str = None,
|
||||||
|
enable_itn: bool = None,
|
||||||
|
context: str = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化 ASR
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: 模型名称 ('qwen3-asr-flash' 或 'qwen3-asr-flash-filetrans')
|
||||||
|
language: 语言代码 ('zh', 'en', 等),None 为自动检测
|
||||||
|
enable_itn: 是否启用 ITN
|
||||||
|
context: 上下文增强文本
|
||||||
|
"""
|
||||||
|
self._load_api_key()
|
||||||
|
|
||||||
|
self.model = model or config.MODEL
|
||||||
|
self.language = language if language is not None else config.LANGUAGE
|
||||||
|
self.enable_itn = enable_itn if enable_itn is not None else config.ENABLE_ITN
|
||||||
|
self.context = context or ""
|
||||||
|
|
||||||
|
# 设置 API URL
|
||||||
|
dashscope.base_http_api_url = config.API_URL
|
||||||
|
|
||||||
|
def _load_api_key(self) -> None:
|
||||||
|
"""从 .env 加载 API Key"""
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
for _ in range(5):
|
||||||
|
env_path = current_dir / '.env'
|
||||||
|
if env_path.exists():
|
||||||
|
load_dotenv(env_path)
|
||||||
|
break
|
||||||
|
current_dir = current_dir.parent
|
||||||
|
|
||||||
|
api_key = os.environ.get('DASHSCOPE_API_KEY')
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError('未找到 DASHSCOPE_API_KEY')
|
||||||
|
|
||||||
|
def recognize(
|
||||||
|
self,
|
||||||
|
audio_path: str,
|
||||||
|
language: str = None,
|
||||||
|
context: str = None,
|
||||||
|
) -> ASRResult:
|
||||||
|
"""
|
||||||
|
识别音频文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_path: 音频文件路径(本地路径或 URL)
|
||||||
|
language: 临时覆盖语言设置
|
||||||
|
context: 临时覆盖上下文
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ASRResult: 识别结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 处理文件路径
|
||||||
|
audio_uri = self._prepare_audio_uri(audio_path)
|
||||||
|
|
||||||
|
# 构建消息
|
||||||
|
messages = self._build_messages(audio_uri, context)
|
||||||
|
|
||||||
|
# 构建 ASR 选项
|
||||||
|
asr_options = {"enable_itn": self.enable_itn}
|
||||||
|
if language or self.language:
|
||||||
|
asr_options["language"] = language or self.language
|
||||||
|
|
||||||
|
# 调用 API
|
||||||
|
response = MultiModalConversation.call(
|
||||||
|
api_key=os.environ.get('DASHSCOPE_API_KEY'),
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
result_format="message",
|
||||||
|
asr_options=asr_options
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析结果
|
||||||
|
return self._parse_response(response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ASRResult(
|
||||||
|
text="",
|
||||||
|
success=False,
|
||||||
|
error=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_audio_uri(self, audio_path: str) -> str:
|
||||||
|
"""准备音频 URI"""
|
||||||
|
# 如果已经是 URL
|
||||||
|
if audio_path.startswith('http://') or audio_path.startswith('https://'):
|
||||||
|
return audio_path
|
||||||
|
|
||||||
|
# 如果已经是 file:// 格式
|
||||||
|
if audio_path.startswith('file://'):
|
||||||
|
return audio_path
|
||||||
|
|
||||||
|
# 本地文件,转换为 file:// 格式
|
||||||
|
path = Path(audio_path)
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"音频文件不存在: {audio_path}")
|
||||||
|
|
||||||
|
# 检查文件大小
|
||||||
|
file_size = path.stat().st_size
|
||||||
|
if file_size > config.MAX_FILE_SIZE:
|
||||||
|
raise ValueError(
|
||||||
|
f"文件大小 ({file_size} bytes) 超过限制 ({
|
||||||
|
config.MAX_FILE_SIZE} bytes)")
|
||||||
|
|
||||||
|
# 检查格式
|
||||||
|
suffix = path.suffix.lower().lstrip('.')
|
||||||
|
if suffix not in self.SUPPORTED_FORMATS:
|
||||||
|
raise ValueError(f"不支持的音频格式: {suffix}")
|
||||||
|
|
||||||
|
# 转换为绝对路径
|
||||||
|
abs_path = path.resolve()
|
||||||
|
return f"file://{abs_path}"
|
||||||
|
|
||||||
|
def _build_messages(self, audio_uri: str, context: str = None) -> list:
|
||||||
|
"""构建消息"""
|
||||||
|
ctx = context or self.context
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [{"text": ctx}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"audio": audio_uri}]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def _parse_response(self, response) -> ASRResult:
|
||||||
|
"""解析 API 响应"""
|
||||||
|
if response.status_code != 200:
|
||||||
|
return ASRResult(
|
||||||
|
text="",
|
||||||
|
success=False,
|
||||||
|
error=f"API 错误: {response.code} - {response.message}",
|
||||||
|
request_id=response.request_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提取文本
|
||||||
|
try:
|
||||||
|
content = response.output.choices[0].message.content
|
||||||
|
if isinstance(content, list):
|
||||||
|
text = content[0].get("text", "")
|
||||||
|
else:
|
||||||
|
text = str(content)
|
||||||
|
|
||||||
|
return ASRResult(
|
||||||
|
text=text,
|
||||||
|
success=True,
|
||||||
|
request_id=response.request_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return ASRResult(
|
||||||
|
text="",
|
||||||
|
success=False,
|
||||||
|
error=f"解析响应失败: {e}",
|
||||||
|
request_id=getattr(response, 'request_id', None)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 便捷函数
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def recognize(audio_path: str, **kwargs) -> ASRResult:
|
||||||
|
"""
|
||||||
|
便捷的识别函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_path: 音频文件路径
|
||||||
|
**kwargs: 传递给 ASR 的参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ASRResult: 识别结果
|
||||||
|
"""
|
||||||
|
asr = ASR(**kwargs)
|
||||||
|
return asr.recognize(audio_path)
|
||||||
|
|
||||||
|
|
||||||
|
def recognize_text(audio_path: str, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
便捷函数,直接返回识别文本
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_path: 音频文件路径
|
||||||
|
**kwargs: 传递给 ASR 的参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 识别的文本,失败返回空字符串
|
||||||
|
"""
|
||||||
|
result = recognize(audio_path, **kwargs)
|
||||||
|
return result.text if result.success else ""
|
||||||
49
src/Module/asr/asrconfig.py
Normal file
49
src/Module/asr/asrconfig.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
"""
|
||||||
|
ASR 配置文件
|
||||||
|
定义语音识别的默认参数
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 模型配置
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# 默认模型
|
||||||
|
# qwen3-asr-flash: 短音频 (≤5分钟)
|
||||||
|
# qwen3-asr-flash-filetrans: 长音频 (≤12小时)
|
||||||
|
MODEL = 'qwen3-asr-flash'
|
||||||
|
|
||||||
|
# API URL (北京地域)
|
||||||
|
API_URL = 'https://dashscope.aliyuncs.com/api/v1'
|
||||||
|
|
||||||
|
# 新加坡地域 URL (备用)
|
||||||
|
# API_URL = 'https://dashscope-intl.aliyuncs.com/api/v1'
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 识别参数
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# 语言 (可选值: 'zh', 'en', 'ja', 'ko', 'de', 'fr', 'ru', 'es', 'it', 'pt', 'ar', 等)
|
||||||
|
# None 表示自动检测
|
||||||
|
LANGUAGE = None
|
||||||
|
|
||||||
|
# 是否启用 ITN (Inverse Text Normalization)
|
||||||
|
# 将口语数字转为书面形式,如"一百二十三"→"123"
|
||||||
|
ENABLE_ITN = False
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 支持的音频格式
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
SUPPORTED_FORMATS = [
|
||||||
|
'aac', 'amr', 'avi', 'aiff', 'flac', 'flv',
|
||||||
|
'm4a', 'mkv', 'mp3', 'mpeg', 'ogg', 'opus',
|
||||||
|
'wav', 'webm', 'wma', 'wmv'
|
||||||
|
]
|
||||||
|
|
||||||
|
# 最大文件大小 (字节) - 10MB
|
||||||
|
MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||||
|
|
||||||
|
# 最大音频时长 (秒) - 5分钟
|
||||||
|
MAX_DURATION = 5 * 60
|
||||||
83
test/asr/test_asr.py
Normal file
83
test/asr/test_asr.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
"""
|
||||||
|
ASR 语音识别测试
|
||||||
|
使用 TTS 生成的音频文件测试识别
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from src.Module.asr.asr import ASR, recognize, recognize_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_recognize_wav():
|
||||||
|
"""测试识别 WAV 文件"""
|
||||||
|
print("=" * 60)
|
||||||
|
print(" ASR 语音识别测试")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 使用 TTS 测试生成的音频文件
|
||||||
|
audio_file = Path(__file__).parent.parent / 'tts' / 'output' / 'stream_test.wav'
|
||||||
|
|
||||||
|
if not audio_file.exists():
|
||||||
|
print(f"[跳过] 音频文件不存在: {audio_file}")
|
||||||
|
print("请先运行 TTS 测试生成音频文件")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print(f"\n音频文件: {audio_file}")
|
||||||
|
print(f"文件大小: {audio_file.stat().st_size / 1024:.1f} KB")
|
||||||
|
|
||||||
|
# 创建 ASR 实例
|
||||||
|
asr = ASR(
|
||||||
|
model='qwen3-asr-flash',
|
||||||
|
language='zh',
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n开始识别...")
|
||||||
|
result = asr.recognize(str(audio_file))
|
||||||
|
|
||||||
|
print(f"\n识别结果:")
|
||||||
|
print(f" 成功: {result.success}")
|
||||||
|
print(f" 文本: {result.text}")
|
||||||
|
if result.error:
|
||||||
|
print(f" 错误: {result.error}")
|
||||||
|
if result.request_id:
|
||||||
|
print(f" 请求ID: {result.request_id}")
|
||||||
|
|
||||||
|
return result.success
|
||||||
|
|
||||||
|
|
||||||
|
def test_convenient_function():
|
||||||
|
"""测试便捷函数"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("便捷函数测试")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
audio_file = Path(__file__).parent.parent / 'tts' / \
|
||||||
|
'output' / 'bidirectional_test.wav'
|
||||||
|
|
||||||
|
if not audio_file.exists():
|
||||||
|
print(f"[跳过] 音频文件不存在: {audio_file}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print(f"\n音频文件: {audio_file.name}")
|
||||||
|
|
||||||
|
# 使用便捷函数
|
||||||
|
text = recognize_text(str(audio_file), language='zh')
|
||||||
|
|
||||||
|
print(f"识别文本: {text}")
|
||||||
|
|
||||||
|
return len(text) > 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
results = []
|
||||||
|
|
||||||
|
success1 = test_recognize_wav()
|
||||||
|
results.append(("WAV 文件识别", success1))
|
||||||
|
|
||||||
|
success2 = test_convenient_function()
|
||||||
|
results.append(("便捷函数", success2))
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("测试结果:")
|
||||||
|
for name, success in results:
|
||||||
|
status = "✓ 通过" if success else "✗ 失败/跳过"
|
||||||
|
print(f" {name}: {status}")
|
||||||
Loading…
x
Reference in New Issue
Block a user