# [stray pasted message, not code — kept verbatim] 你的代码太可怕了居然把我的源文件删除了!!!这个是严重的问题!!
import asyncio
import aiohttp
import json
import struct
import gzip
import uuid
import logging
import os
import subprocess
from typing import Optional, List, Dict, Any, Tuple, AsyncGenerator
# Logging setup: INFO and above goes to both 'run.log' and the console stream.
logging.basicConfig(
    handlers=(
        logging.FileHandler('run.log'),
        logging.StreamHandler(),
    ),
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
# General constants
DEFAULT_SAMPLE_RATE = 16000
DEFAULT_OUTPUT_ENCODING = 'utf-8'


# Protocol constants: plain namespaces holding the 4-bit field values
# that are packed into / read from the binary frame header.
class ProtocolVersion:
    V1 = 0b0001


class MessageType:
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ERROR_RESPONSE = 0b1111


class MessageTypeSpecificFlags:
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    NEG_SEQUENCE = 0b0010
    NEG_WITH_SEQUENCE = 0b0011


class SerializationType:
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001


class CompressionType:
    GZIP = 0b0001
# Configuration
class Config:
    """Holds the credentials used to authenticate against the ASR service.

    The credential literals in this file were truncated, so the values are
    read from the environment instead of being hard-coded:

    * ``VOLC_ASR_APP_KEY``    -> ``app_key``
    * ``VOLC_ASR_ACCESS_KEY`` -> ``access_key``

    A missing variable yields an empty string, so importing the module never
    fails; the server will reject the connection if the keys are empty.
    """

    def __init__(self):
        self.auth = {
            "app_key": os.environ.get("VOLC_ASR_APP_KEY", ""),
            "access_key": os.environ.get("VOLC_ASR_ACCESS_KEY", ""),
        }

    @property
    def app_key(self) -> str:
        """Application key sent with the connection request."""
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        """Access key sent with the connection request."""
        return self.auth["access_key"]


# Shared module-level configuration instance.
config = Config()
# General-purpose helpers
class CommonUtils:
    """Stateless helpers for gzip (de)compression and audio transcoding."""

    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        """Return *data* compressed with gzip."""
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        """Return the gzip-decompressed form of *data*."""
        return gzip.decompress(data)

    @staticmethod
    def convert_audio_to_pcm(audio_path: str, sample_rate: int = DEFAULT_SAMPLE_RATE) -> bytes:
        """Convert an audio file (MP3/WAV, ...) to raw PCM using ffmpeg.

        The output is mono, signed 16-bit little-endian PCM at *sample_rate*,
        read from ffmpeg's stdout.

        Args:
            audio_path: Path of the input audio file.
            sample_rate: Target sample rate in Hz.

        Returns:
            The raw PCM bytes.

        Raises:
            RuntimeError: If ffmpeg exits with a non-zero status.
        """
        cmd = [
            # "-v error" (rather than "quiet") so that stderr actually carries
            # the failure reason that is reported below.
            "ffmpeg", "-v", "error", "-y", "-i", audio_path,
            "-acodec", "pcm_s16le", "-ac", "1", "-ar", str(sample_rate),
            "-f", "s16le", "-"
        ]
        try:
            result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            # ffmpeg output is not guaranteed to be valid UTF-8; never let the
            # decode itself mask the real error.
            detail = (e.stderr or b"").decode("utf-8", errors="replace").strip()
            logger.error(f"FFmpeg转换失败: {detail}")
            raise RuntimeError(f"音频转换失败: {detail}") from e
        return result.stdout
# Request header builder
class AsrRequestHeader:
    """Mutable description of the 4-byte binary header that prefixes a request."""

    def __init__(self):
        # Defaults describe a gzip-compressed JSON "full client request"
        # carrying a positive sequence number.
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])

    def to_bytes(self) -> bytes:
        """Pack the header fields, two 4-bit values per byte, followed by the reserved data."""
        version_and_size = (ProtocolVersion.V1 << 4) | 1
        type_and_flags = (self.message_type << 4) | self.message_type_specific_flags
        serialization_and_compression = (self.serialization_type << 4) | self.compression_type
        packed = bytearray((version_and_size, type_and_flags, serialization_and_compression))
        packed.extend(self.reserved_data)
        return bytes(packed)

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        """Return a header populated with the default field values."""
        return AsrRequestHeader()
# Request builder
class RequestBuilder:
    """Builds the HTTP auth headers and the binary frames sent to the ASR server."""

    @staticmethod
    def new_auth_headers() -> Dict[str, str]:
        """Return the headers for the WebSocket handshake.

        A fresh request id is generated on every call; the keys come from the
        module-level ``config``.
        """
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": str(uuid.uuid4()),
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key,
        }

    @staticmethod
    def _build_frame(header: AsrRequestHeader, seq: int, payload: bytes) -> bytes:
        """Assemble header + signed sequence + payload length + gzip-compressed payload."""
        compressed = CommonUtils.gzip_compress(payload)
        # Big-endian: signed 32-bit sequence, then unsigned 32-bit payload size.
        return header.to_bytes() + struct.pack('>iI', seq, len(compressed)) + compressed

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:
        """Build the initial request frame carrying the JSON session parameters.

        Args:
            seq: Sequence number written into the frame.
        """
        header = AsrRequestHeader.default_header()
        payload = {
            "user": {"uid": "demo_uid"},
            # The declared rate follows the rate the audio is converted to.
            "audio": {"format": "pcm", "codec": "raw", "rate": DEFAULT_SAMPLE_RATE, "bits": 16, "channel": 1},
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }
        payload_bytes = json.dumps(payload).encode('utf-8')
        return RequestBuilder._build_frame(header, seq, payload_bytes)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        """Build an audio-only frame for one PCM segment.

        Args:
            seq: Positive sequence number of this segment.
            segment: Raw audio bytes for this segment.
            is_last: When true, the frame is flagged as the final one and the
                sequence number is written negated.
        """
        header = AsrRequestHeader.default_header()
        header.message_type = MessageType.CLIENT_AUDIO_ONLY_REQUEST
        if is_last:
            header.message_type_specific_flags = MessageTypeSpecificFlags.NEG_WITH_SEQUENCE
            seq = -seq
        else:
            header.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        return RequestBuilder._build_frame(header, seq, segment)
# Response handling
class AsrResponse:
    """Parsed server response plus a helper that extracts newly added text."""

    def __init__(self):
        self.code = 0
        self.event = 0
        self.is_last_package = False
        self.payload_sequence = 0
        self.payload_size = 0
        self.payload_msg = None
        # Full text seen by the previous get_incremental_text() call (used for de-duplication).
        self.last_text = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the response fields (without ``last_text``) as a plain dict."""
        field_names = (
            "code",
            "event",
            "is_last_package",
            "payload_sequence",
            "payload_size",
            "payload_msg",
        )
        return {field: getattr(self, field) for field in field_names}

    def get_incremental_text(self) -> str:
        """Return the part of the current text that extends past ``last_text``.

        The current text is ``result["text"]`` when present, otherwise the
        concatenation of the ``text`` of every dict in ``result["utterances"]``.
        ``last_text`` is then updated to the current text.
        """
        message = self.payload_msg
        result = message.get("result") if message and isinstance(message, dict) else None

        current_text = ""
        if isinstance(result, dict):
            if "text" in result:
                # The full text takes precedence over per-utterance pieces.
                current_text = result["text"]
            else:
                utterances = result.get("utterances")
                if isinstance(utterances, list):
                    current_text = "".join(
                        item["text"]
                        for item in utterances
                        if isinstance(item, dict) and "text" in item
                    )

        # Only the tail beyond the previously seen length counts as new.
        already_seen = len(self.last_text)
        self.last_text = current_text
        return current_text[already_seen:]
class ResponseParser:
    """Decodes binary frames received from the server into AsrResponse objects."""

    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        """Parse one binary frame.

        Reads the header nibbles, then the optional sequence / event fields
        selected by the flag bits, then the type-specific fields, and finally
        decompresses and JSON-decodes the remaining payload. Decompression and
        JSON failures are logged and leave ``payload_msg`` unset.
        """
        parsed = AsrResponse()

        header_words = msg[0] & 0x0f
        msg_type = msg[1] >> 4
        flags = msg[1] & 0x0f
        serialization = msg[2] >> 4
        compression = msg[2] & 0x0f

        # Read position inside the frame; the header length is given in 4-byte words.
        cursor = header_words * 4

        # Flag bits
        if flags & 0x01:
            (parsed.payload_sequence,) = struct.unpack_from('>i', msg, cursor)
            cursor += 4
        if flags & 0x02:
            parsed.is_last_package = True
        if flags & 0x04:
            (parsed.event,) = struct.unpack_from('>i', msg, cursor)
            cursor += 4

        # Message-type specific fields
        if msg_type == MessageType.SERVER_FULL_RESPONSE:
            (parsed.payload_size,) = struct.unpack_from('>I', msg, cursor)
            cursor += 4
        elif msg_type == MessageType.SERVER_ERROR_RESPONSE:
            parsed.code, parsed.payload_size = struct.unpack_from('>iI', msg, cursor)
            cursor += 8

        body = msg[cursor:]
        if not body:
            return parsed

        # Decompress
        if compression == CompressionType.GZIP:
            try:
                body = CommonUtils.gzip_decompress(body)
            except Exception as e:
                logger.error(f"解压缩失败: {e}")
                return parsed

        # Decode JSON
        try:
            if serialization == SerializationType.JSON:
                parsed.payload_msg = json.loads(body.decode('utf-8'))
        except Exception as e:
            logger.error(f"解析JSON失败: {e}")
        return parsed
# Core client
class AsrWsClient:
    """Streams one audio file to the ASR WebSocket endpoint and collects the text.

    Use it as an async context manager so the aiohttp session is closed:

        async with AsrWsClient(url, 200, "out.txt") as client:
            await client.execute("audio.mp3")
    """

    def __init__(self, url: str, segment_duration: int = 200, output_file: Optional[str] = None):
        """
        Args:
            url: WebSocket endpoint.
            segment_duration: Audio duration carried by each packet, in ms.
            output_file: Where the recognized text is written; ``None`` disables saving.
        """
        self.seq = 1
        self.url = url
        self.segment_duration = segment_duration
        self.conn = None
        self.session = None
        self.output_file = output_file
        self.full_recognized_text = ""  # Accumulated recognized text
        self.response_parser = ResponseParser()

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self.conn and not self.conn.closed:
            await self.conn.close()
        if self.session and not self.session.closed:
            await self.session.close()

    async def read_audio_data(self, file_path: str) -> bytes:
        """Convert the audio file to PCM and return the raw bytes."""
        logger.info(f"正在转换音频文件 {file_path} 为PCM格式...")
        # The conversion is a blocking subprocess call; run it in the default
        # executor so the event loop is not stalled while ffmpeg works.
        loop = asyncio.get_running_loop()
        pcm_data = await loop.run_in_executor(
            None, CommonUtils.convert_audio_to_pcm, file_path, DEFAULT_SAMPLE_RATE
        )
        logger.info(f"PCM转换完成,大小: {len(pcm_data)} 字节")
        return pcm_data

    def get_segment_size(self) -> int:
        """Return the number of PCM bytes per packet for ``segment_duration``.

        Raises:
            ValueError: If ``segment_duration`` does not yield a positive size.
        """
        bytes_per_sample = 2  # 16-bit samples
        samples_per_ms = DEFAULT_SAMPLE_RATE / 1000
        segment_size = int(samples_per_ms * self.segment_duration * bytes_per_sample)
        if segment_size <= 0:
            # A zero/negative size would make the chunking range() fail or be empty.
            raise ValueError(f"分段时长必须为正数: {self.segment_duration}")
        # Keep whole 16-bit samples in every segment.
        if segment_size % 2 != 0:
            segment_size += 1
        logger.info(f"计算分段大小: {segment_size} 字节 (时长: {self.segment_duration}ms)")
        return segment_size

    async def create_connection(self) -> None:
        """Open the WebSocket connection using fresh auth headers."""
        headers = RequestBuilder.new_auth_headers()
        self.conn = await self.session.ws_connect(self.url, headers=headers)
        logger.info(f"已连接到 {self.url}")

    async def send_full_client_request(self) -> None:
        """Send the initial request and log the server's first response."""
        request = RequestBuilder.new_full_client_request(self.seq)
        self.seq += 1
        await self.conn.send_bytes(request)
        logger.info(f"发送初始请求,序列号: {self.seq-1}")
        # Receive the initial response
        msg = await self.conn.receive()
        if msg.type == aiohttp.WSMsgType.BINARY:
            response = self.response_parser.parse_response(msg.data)
            logger.info(f"收到初始响应: {response.to_dict()}")

    async def send_audio_segments(self, content: bytes) -> None:
        """Split *content* into packets and send them paced at real-time speed."""
        segment_size = self.get_segment_size()
        audio_segments = [
            content[start:start + segment_size]
            for start in range(0, len(content), segment_size)
        ]
        if not audio_segments:
            # Empty audio: still send one (empty) final packet, otherwise the
            # session is never terminated and the receiver waits forever.
            audio_segments = [b""]
        total_segments = len(audio_segments)
        logger.info(f"音频分为 {total_segments} 个分段发送")
        for i, segment in enumerate(audio_segments):
            is_last = (i == total_segments - 1)
            request = RequestBuilder.new_audio_only_request(self.seq, segment, is_last)
            await self.conn.send_bytes(request)
            logger.debug(f"发送音频分段 {i+1}/{total_segments} (最后一段: {is_last})")
            if not is_last:
                self.seq += 1
                # Pace the stream; no need to wait after the final packet.
                await asyncio.sleep(self.segment_duration / 1000)

    async def recv_messages(self) -> None:
        """Receive recognition results until the last package, an error, or close."""
        try:
            async for msg in self.conn:
                if msg.type == aiohttp.WSMsgType.BINARY:
                    response = self.response_parser.parse_response(msg.data)
                    # The parser creates a new AsrResponse per message, so its
                    # de-duplication baseline is always empty. Seed it with
                    # the text accumulated so far; otherwise every message
                    # would append its whole text again.
                    # Assumes each message carries the cumulative text so far.
                    response.last_text = self.full_recognized_text
                    incremental_text = response.get_incremental_text()
                    if incremental_text:
                        self.full_recognized_text += incremental_text
                        logger.debug(f"新增识别文本: {incremental_text[:50]}...")
                    if response.code != 0:
                        logger.error(f"服务端返回错误: code={response.code}, payload={response.payload_msg}")
                    # Save on the last package or on a server error
                    if response.is_last_package or response.code != 0:
                        if self.output_file:
                            self.save_recognized_text()
                        break
                elif msg.type in [aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED]:
                    logger.error(f"WebSocket连接异常: {msg.type}")
                    break
        except Exception as e:
            logger.error(f"接收消息出错: {e}")
            raise

    def save_recognized_text(self):
        """Write the accumulated text to ``output_file``, creating its directory."""
        try:
            # Create the output directory
            output_dir = os.path.dirname(self.output_file)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            # Write the file
            with open(self.output_file, 'w', encoding=DEFAULT_OUTPUT_ENCODING) as f:
                f.write(self.full_recognized_text)
            logger.info(f"识别结果已保存到 {self.output_file}")
            logger.info(f"总识别文本长度: {len(self.full_recognized_text)} 字符")
        except Exception as e:
            logger.error(f"保存文件失败: {e}")
            raise

    async def execute(self, file_path: str) -> None:
        """Run the whole recognition flow for one audio file.

        Raises:
            FileNotFoundError: If *file_path* does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"音频文件不存在: {file_path}")
        # Reset per-run state
        self.seq = 1
        self.full_recognized_text = ""
        tasks = []
        try:
            # 1. Convert the audio
            content = await self.read_audio_data(file_path)
            # 2. Open the connection
            await self.create_connection()
            # 3. Send the initial request
            await self.send_full_client_request()
            # 4. Send audio and receive results concurrently
            send_task = asyncio.create_task(self.send_audio_segments(content))
            recv_task = asyncio.create_task(self.recv_messages())
            tasks = [send_task, recv_task]
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            if recv_task in done:
                # The receiver ending (last package / error / close) ends the
                # session; a still-running sender is cancelled below.
                recv_task.result()
                if send_task.done():
                    send_task.result()
            else:
                # Sender finished first: surface its failure, if any, then
                # wait for the final result.
                send_task.result()
                await recv_task
        finally:
            # Never leave a sibling task running after a failure or early exit.
            for task in tasks:
                if not task.done():
                    task.cancel()
            if tasks:
                await asyncio.gather(*tasks, return_exceptions=True)
            if self.conn:
                await self.conn.close()
async def main():
    """Command-line entry point: parse arguments and run one recognition."""
    import argparse
    parser = argparse.ArgumentParser(description="ASR WebSocket客户端(支持MP3直接处理)")
    parser.add_argument("--file", type=str, required=True, help="音频文件路径(MP3/WAV)")
    parser.add_argument("--output", type=str, help="输出文本文件路径")
    parser.add_argument("--url", type=str, default="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream",
                        help="WebSocket URL")
    parser.add_argument("--seg-duration", type=int, default=200,
                        help="每包音频时长(ms),默认200")
    args = parser.parse_args()
    # A zero or negative duration produces an unusable packet size, so reject
    # it up front with a normal usage error.
    if args.seg_duration <= 0:
        parser.error("--seg-duration 必须为正整数")
    # Derive the output file name from the input when none is given
    if not args.output:
        base_name = os.path.splitext(os.path.basename(args.file))[0]
        args.output = f"{base_name}_asr_result.txt"
    # Run the recognition
    async with AsrWsClient(args.url, args.seg_duration, args.output) as client:
        try:
            await client.execute(args.file)
        except Exception as e:
            logger.error(f"ASR处理失败: {e}")
            raise


if __name__ == "__main__":
    asyncio.run(main())