feat: Support iFLYTEK large model for Chinese-English speech recognition
parent
f9f96fd2cd
commit
4786970689
|
|
@ -0,0 +1,56 @@
|
||||||
|
# coding=utf-8
|
||||||
|
import traceback
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from django.utils.translation import gettext as _
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ZhEnXunFeiSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the iFLYTEK Chinese-English speech recognition model.

    Declares the four connection fields shown to the user and validates them
    by building the model through the provider and calling its ``check_auth``.
    """

    spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat.xf-yun.com/v1')
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self,
                 model_type: str,
                 model_name,
                 model_credential: Dict[str, object],
                 model_params, provider,
                 raise_exception=False):
        """Check that the credential is complete and accepted by the model.

        :param model_type: model type value; must be one of the provider's types.
        :param model_name: name of the model to instantiate for the check.
        :param model_credential: mapping holding the ``spark_*`` credential fields.
        :param model_params: not used by this validator.
        :param provider: provider exposing ``get_model_type_list`` and ``get_model``.
        :param raise_exception: when true, failures raise instead of returning ``False``.
        :return: ``True`` when every check passes, ``False`` otherwise.
        :raises AppApiException: for an unsupported model type (always), for an
            ``AppApiException`` raised during the auth check (always), and for
            any other failure when ``raise_exception`` is true.
        """
        matching_types = [entry for entry in provider.get_model_type_list()
                          if entry.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        # Stop at the first credential field that is absent.
        required_keys = ('spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret')
        missing_key = next((name for name in required_keys if name not in model_credential), None)
        if missing_key is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=missing_key))

        try:
            provider.get_model(model_type, model_name, model_credential).check_auth()
        except AppApiException:
            # Application-level errors are always surfaced to the caller.
            traceback.print_exc()
            raise
        except Exception as error:
            traceback.print_exc()
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value,
                                  _('Verification failed, please check whether the parameters are correct: {error}').format(
                                      error=str(error)))
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``spark_api_secret`` passed through ``encryption``."""
        masked = {**model}
        masked['spark_api_secret'] = super().encryption(model.get('spark_api_secret', ''))
        return masked

    def get_model_params_setting_form(self, model_name):
        """Return ``None``: no parameter-setting form is defined for this model."""
        return None
|
||||||
|
|
@ -0,0 +1,192 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
import hmac
|
||||||
|
import hashlib
|
||||||
|
import ssl
|
||||||
|
import traceback
|
||||||
|
from typing import Dict
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
from datetime import datetime, timezone, UTC
|
||||||
|
import websockets
|
||||||
|
import os
|
||||||
|
|
||||||
|
from future.backports.urllib.parse import urlparse
|
||||||
|
|
||||||
|
from common.utils.logger import maxkb_logger
|
||||||
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from models_provider.impl.base_stt import BaseSpeechToText
|
||||||
|
|
||||||
|
# TLS context shared by the WebSocket connections opened in this module.
# Certificate and hostname verification stay enabled (the defaults of
# ``ssl.create_default_context``) so the signed URL and the audio stream are
# only sent to a server that presents a valid certificate.
ssl_context = ssl.create_default_context()
|
||||||
|
|
||||||
|
|
||||||
|
class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text client for the iFLYTEK Chinese-English recognition service.

    Audio is read from a binary file object, streamed over a signed WebSocket
    connection as base64-encoded JSON frames, and the recognised text is
    assembled from the JSON replies.
    """

    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str

    def __init__(self, **kwargs):
        """Store the connection settings passed as keyword arguments."""
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')

    @staticmethod
    def is_cache_model():
        """Return ``False``."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a credential mapping.

        ``max_tokens`` and ``temperature`` are forwarded to the constructor
        only when they are present in ``model_kwargs`` and not ``None``.
        """
        optional_params = {
            name: model_kwargs[name]
            for name in ('max_tokens', 'temperature')
            if model_kwargs.get(name) is not None
        }
        return XFZhEnSparkSpeechToText(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            **optional_params
        )

    def create_url(self):
        """Return ``spark_api_url`` with the signed authentication query string appended.

        The signature is an HMAC-SHA256 (keyed with ``spark_api_secret``) over
        the host, the current GMT date and the request line, base64-encoded
        and wrapped in a base64-encoded ``authorization`` parameter.
        """
        url = self.spark_api_url
        parsed = urlparse(url)
        host = parsed.hostname
        # Sign the path that is actually requested rather than a fixed "/v1",
        # so a configured URL with a different path still yields a matching
        # request line.
        path = parsed.path or '/'

        date = datetime.now(UTC).strftime('%a, %d %b %Y %H:%M:%S GMT')

        # String to sign: host, date and request line, newline-separated.
        signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        signature_sha = hmac.new(
            self.spark_api_secret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            hashlib.sha256
        ).digest()
        signature = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        params = {
            'authorization': authorization,
            'date': date,
            'host': host
        }
        return url + '?' + urlencode(params)

    def check_auth(self):
        """Validate the credentials by recognising the bundled sample audio.

        Unlike ``speech_to_text``, failures are NOT swallowed: any connection,
        authentication or service error propagates so the caller can report
        the credentials as invalid.
        """
        cwd = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(cwd, 'iat_mp3_16k.mp3'), 'rb') as f:
            self._recognize(f)

    def speech_to_text(self, audio_file_path):
        """Recognise speech and return the text, or ``""`` on any error.

        :param audio_file_path: despite the name, a binary file object; it is
            consumed through ``read(size)``.
        :return: the recognised text; an empty string when recognition fails
            (the error is logged).
        """
        try:
            return self._recognize(audio_file_path)
        except Exception as err:
            maxkb_logger.error(f"语音识别错误: {str(err)}: {traceback.format_exc()}")
        return ""

    def _recognize(self, audio_file):
        """Run one recognition session and return the text; errors propagate."""

        async def handle():
            async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
                # Upload the whole audio first, then collect the results.
                await self.send_audio(ws, audio_file)
                return await self.handle_message(ws)

        return asyncio.run(handle())

    def _build_frame(self, seq, status, audio_base64):
        """Build one audio frame; status 0 additionally carries the recognition parameters."""
        frame = {"header": {"app_id": self.spark_app_id, "status": status}}
        if status == 0:
            frame["parameter"] = {
                "iat": {
                    "domain": "slm", "language": "zh_cn", "accent": "mandarin",
                    "eos": 10000, "vinfo": 1,
                    "result": {"encoding": "utf8", "compress": "raw", "format": "json"}
                }
            }
        frame["payload"] = {
            "audio": {
                "encoding": "lame", "sample_rate": 16000, "channels": 1,
                "bit_depth": 16, "seq": seq, "status": status, "audio": audio_base64
            }
        }
        return frame

    async def send_audio(self, ws, audio_file):
        """Stream ``audio_file`` to ``ws`` as a sequence of JSON frames.

        The first chunk is sent with status 0 (including the recognition
        parameters), later chunks with status 1, and an empty closing frame
        with status 2. At most ``max_chunks`` data frames are sent; any audio
        beyond that is not transmitted.
        """
        chunk_size = 4000
        max_chunks = 10000
        seq = 1
        while seq <= max_chunks:
            chunk = audio_file.read(chunk_size)
            if not chunk:
                break
            status = 0 if seq == 1 else 1
            chunk_base64 = base64.b64encode(chunk).decode('utf-8')
            await ws.send(json.dumps(self._build_frame(seq, status, chunk_base64)))
            seq += 1

        # Closing frame: no audio, status 2.
        await ws.send(json.dumps(self._build_frame(seq, 2, "")))

    async def handle_message(self, ws):
        """Collect recognition results from ``ws`` and return the joined text.

        Reading stops when a message with header status 2 arrives, or when no
        message is received for 30 seconds (the text gathered so far is
        returned in that case).

        :raises Exception: when a message reports a non-zero header code; the
            code and the message text supplied in the header are included.
        """
        pieces = []
        while True:
            try:
                message = await asyncio.wait_for(ws.recv(), timeout=30.0)
            except asyncio.TimeoutError:
                break

            data = json.loads(message)
            header = data['header']
            if header['code'] != 0:
                raise Exception(
                    f"iFLYTEK speech recognition failed (code={header['code']}): {header.get('message', '')}")

            result = (data.get('payload') or {}).get('result') or {}
            text = result.get('text', '')
            if text:
                # The text field is base64-encoded JSON holding the word lists.
                text_data = json.loads(base64.b64decode(text).decode('utf-8'))
                for ws_item in text_data.get('ws', []):
                    for cw in ws_item.get('cw', []):
                        pieces.extend(sw['w'] for sw in cw.get('sw', []))

            if header.get('status') == 2:
                break

        return ''.join(pieces)
|
||||||
|
|
@ -17,6 +17,7 @@ from models_provider.impl.xf_model_provider.credential.image import XunFeiImageM
|
||||||
from models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
from models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
||||||
from models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
|
from models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
|
||||||
from models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
|
from models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
|
||||||
|
from models_provider.impl.xf_model_provider.credential.zh_en_stt import ZhEnXunFeiSTTModelCredential
|
||||||
from models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
|
from models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
|
||||||
from models_provider.impl.xf_model_provider.model.image import XFSparkImage
|
from models_provider.impl.xf_model_provider.model.image import XFSparkImage
|
||||||
from models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
from models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
||||||
|
|
@ -25,10 +26,13 @@ from models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
|
||||||
from maxkb.conf import PROJECT_DIR
|
from maxkb.conf import PROJECT_DIR
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
|
|
||||||
|
from models_provider.impl.xf_model_provider.model.zh_en_stt import XFZhEnSparkSpeechToText
|
||||||
|
|
||||||
ssl._create_default_https_context = ssl.create_default_context()
|
ssl._create_default_https_context = ssl.create_default_context()
|
||||||
|
|
||||||
xunfei_model_credential = XunFeiLLMModelCredential()
|
xunfei_model_credential = XunFeiLLMModelCredential()
|
||||||
stt_model_credential = XunFeiSTTModelCredential()
|
stt_model_credential = XunFeiSTTModelCredential()
|
||||||
|
zh_en_stt_credential = ZhEnXunFeiSTTModelCredential()
|
||||||
image_model_credential = XunFeiImageModelCredential()
|
image_model_credential = XunFeiImageModelCredential()
|
||||||
tts_model_credential = XunFeiTTSModelCredential()
|
tts_model_credential = XunFeiTTSModelCredential()
|
||||||
embedding_model_credential = XFEmbeddingCredential()
|
embedding_model_credential = XFEmbeddingCredential()
|
||||||
|
|
@ -36,7 +40,10 @@ model_info_list = [
|
||||||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
|
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
|
||||||
ModelInfo('generalv3', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
|
ModelInfo('generalv3', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
|
||||||
ModelInfo('generalv2', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
|
ModelInfo('generalv2', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
|
||||||
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
|
||||||
|
XFSparkSpeechToText),
|
||||||
|
ModelInfo('slm', _('Chinese and English recognition'), ModelTypeConst.STT, zh_en_stt_credential,
|
||||||
|
XFZhEnSparkSpeechToText),
|
||||||
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
||||||
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
|
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
|
||||||
]
|
]
|
||||||
|
|
@ -47,7 +54,8 @@ model_info_manage = (
|
||||||
.append_default_model_info(
|
.append_default_model_info(
|
||||||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM))
|
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM))
|
||||||
.append_default_model_info(
|
.append_default_model_info(
|
||||||
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
|
||||||
|
XFSparkSpeechToText),
|
||||||
)
|
)
|
||||||
.append_default_model_info(
|
.append_default_model_info(
|
||||||
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech))
|
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue