Compare commits
2 Commits
9c98e670e1
...
982425e21e
| Author | SHA1 | Date |
|---|---|---|
|
|
982425e21e | |
|
|
06c4f5d74e |
|
|
@ -183,6 +183,11 @@
|
|||
<artifactId>tencentcloud-speech-sdk-java</artifactId>
|
||||
<version>1.0.67</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.tencentcloudapi</groupId>
|
||||
<artifactId>tencentcloud-sdk-java-asr</artifactId>
|
||||
<version>3.1.1470</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
|
|
|||
|
|
@ -956,16 +956,16 @@ public class AiModelServiceImpl implements AiModelService {
|
|||
}
|
||||
Map<String, Object> mediaConfig = dto.getMediaConfig() == null ? Collections.emptyMap() : dto.getMediaConfig();
|
||||
if (readConfigString(mediaConfig.get(MEDIA_TENCENT_APP_ID)) == null) {
|
||||
throw new RuntimeException("腾讯实时 ASR 模型必须配置 mediaConfig.tencentAppId");
|
||||
throw new RuntimeException("腾讯 ASR 模型必须配置 mediaConfig.tencentAppId");
|
||||
}
|
||||
if (readConfigString(mediaConfig.get(MEDIA_TENCENT_SECRET_ID)) == null) {
|
||||
throw new RuntimeException("腾讯实时 ASR 模型必须配置 mediaConfig.tencentSecretId");
|
||||
throw new RuntimeException("腾讯 ASR 模型必须配置 mediaConfig.tencentSecretId");
|
||||
}
|
||||
if (readConfigString(mediaConfig.get(MEDIA_TENCENT_SECRET_KEY)) == null) {
|
||||
throw new RuntimeException("腾讯实时 ASR 模型必须配置 mediaConfig.tencentSecretKey");
|
||||
throw new RuntimeException("腾讯 ASR 模型必须配置 mediaConfig.tencentSecretKey");
|
||||
}
|
||||
if (dto.getModelCode() == null || dto.getModelCode().isBlank()) {
|
||||
throw new RuntimeException("腾讯实时 ASR 模型必须配置 modelCode");
|
||||
throw new RuntimeException("腾讯 ASR 模型必须配置 modelCode");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,16 @@ import com.imeeting.support.redis.MeetingLockCache;
|
|||
import com.unisbase.entity.SysUser;
|
||||
import com.unisbase.mapper.SysUserMapper;
|
||||
import com.unisbase.service.SysParamService;
|
||||
import com.tencentcloudapi.asr.v20190614.AsrClient;
|
||||
import com.tencentcloudapi.asr.v20190614.models.CreateRecTaskRequest;
|
||||
import com.tencentcloudapi.asr.v20190614.models.CreateRecTaskResponse;
|
||||
import com.tencentcloudapi.asr.v20190614.models.DescribeTaskStatusRequest;
|
||||
import com.tencentcloudapi.asr.v20190614.models.DescribeTaskStatusResponse;
|
||||
import com.tencentcloudapi.asr.v20190614.models.SentenceDetail;
|
||||
import com.tencentcloudapi.asr.v20190614.models.TaskStatus;
|
||||
import com.tencentcloudapi.common.Credential;
|
||||
import com.tencentcloudapi.common.exception.TencentCloudSDKException;
|
||||
import com.tencentcloudapi.common.profile.ClientProfile;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
|
|
@ -69,6 +79,12 @@ public class AiTaskServiceImpl extends ServiceImpl<AiTaskMapper, AiTask> impleme
|
|||
private static final Duration ASR_QUERY_REQUEST_TIMEOUT = Duration.ofSeconds(30);
|
||||
private static final String DISPATCH_MODE_PARALLEL = "PARALLEL";
|
||||
private static final String DISPATCH_MODE_SERIAL = "SERIAL";
|
||||
private static final String TENCENT_PROVIDER = "tencent";
|
||||
private static final String TENCENT_TASK_ID_KEY = "taskId";
|
||||
private static final String MEDIA_TENCENT_APP_ID = "tencentAppId";
|
||||
private static final String MEDIA_TENCENT_SECRET_ID = "tencentSecretId";
|
||||
private static final String MEDIA_TENCENT_SECRET_KEY = "tencentSecretKey";
|
||||
private static final String TENCENT_ASR_REGION = "ap-guangzhou";
|
||||
|
||||
private final MeetingMapper meetingMapper;
|
||||
private final MeetingTranscriptMapper transcriptMapper;
|
||||
|
|
@ -609,6 +625,14 @@ public class AiTaskServiceImpl extends ServiceImpl<AiTaskMapper, AiTask> impleme
|
|||
log.info("[ASR-PROC] 解析ASR模型成功: meetingId={}, asrTaskId={}, asrModelId={}, baseUrl={}",
|
||||
meeting.getId(), taskRecord.getId(), asrModelId, asrModel.getBaseUrl());
|
||||
|
||||
if ("tencent".equalsIgnoreCase(firstNonBlank(asrModel.getProvider()))) {
|
||||
String transcriptText = processTencentOfflineAsr(meeting, taskRecord, asrModel);
|
||||
log.info("[ASR-PROC] Tencent offline transcript persisted: meetingId={}, asrTaskId={}, transcriptLength={}",
|
||||
meeting.getId(), taskRecord.getId(), transcriptText == null ? 0 : transcriptText.length());
|
||||
meetingPointsService.recordAsrSuccessCharge(meeting, taskRecord);
|
||||
return transcriptText;
|
||||
}
|
||||
|
||||
String submitUrl = appendPath(asrModel.getBaseUrl(), "api/v1/asr/transcriptions");
|
||||
String taskId = taskRecord.getResponseData() != null
|
||||
? String.valueOf(taskRecord.getResponseData().getOrDefault("task_id", ""))
|
||||
|
|
@ -702,6 +726,124 @@ public class AiTaskServiceImpl extends ServiceImpl<AiTaskMapper, AiTask> impleme
|
|||
return transcriptText;
|
||||
}
|
||||
|
||||
protected String processTencentOfflineAsr(Meeting meeting, AiTask taskRecord, AiModelVO asrModel) throws Exception {
|
||||
Long taskId = readTencentOfflineTaskId(taskRecord);
|
||||
if (taskId == null) {
|
||||
taskId = submitTencentOfflineTask(meeting, taskRecord, asrModel);
|
||||
}
|
||||
|
||||
TaskStatus taskStatus = null;
|
||||
for (int i = 0; i < 600; i++) {
|
||||
Thread.sleep(2000);
|
||||
taskStatus = queryTencentOfflineTask(asrModel, taskId);
|
||||
if (taskStatus == null) {
|
||||
throw new RuntimeException("腾讯离线 ASR 查询结果为空");
|
||||
}
|
||||
String status = firstNonBlank(taskStatus.getStatusStr(), "");
|
||||
if ("success".equalsIgnoreCase(status)) {
|
||||
updateAiTaskSuccess(taskRecord, objectMapper.valueToTree(buildTencentTaskStatusSnapshot(taskStatus)));
|
||||
return saveTencentOfflineTranscripts(meeting, taskStatus.getResultDetail());
|
||||
}
|
||||
if ("failed".equalsIgnoreCase(status)) {
|
||||
String errorMsg = firstNonBlank(taskStatus.getErrorMsg(), "腾讯离线 ASR 识别失败");
|
||||
updateAiTaskFail(taskRecord, errorMsg);
|
||||
throw new RuntimeException(errorMsg);
|
||||
}
|
||||
updateProgress(meeting.getId(), 5, "腾讯离线 ASR 识别中...", 0);
|
||||
}
|
||||
throw new RuntimeException("腾讯离线 ASR 轮询超时");
|
||||
}
|
||||
|
||||
protected Map<String, Object> buildTencentOfflineCreateRequest(Meeting meeting, AiTask taskRecord, AiModelVO asrModel) {
|
||||
Map<String, Object> req = new HashMap<>();
|
||||
req.put("engineModelType", asrModel.getModelCode());
|
||||
req.put("channelNum", 1L);
|
||||
req.put("resTextFormat", 2L);
|
||||
req.put("sourceType", 0L);
|
||||
req.put("url", resolveTencentOfflineAudioUrl(meeting));
|
||||
req.put("speakerDiarization", resolveTencentSpeakerDiarization(taskRecord));
|
||||
req.put("speakerNumber", 0L);
|
||||
String hotwordList = buildTencentHotwordList(taskRecord);
|
||||
if (hotwordList != null) {
|
||||
req.put("hotwordList", hotwordList);
|
||||
}
|
||||
return req;
|
||||
}
|
||||
|
||||
protected Long submitTencentOfflineTask(Meeting meeting, AiTask taskRecord, AiModelVO asrModel) throws Exception {
|
||||
updateProgress(meeting.getId(), 5, "提交腾讯离线 ASR 任务...", 0);
|
||||
meetingPointsService.assertSufficientPointsBeforeAsrSubmit(meeting, taskRecord);
|
||||
|
||||
Map<String, Object> reqSnapshot = buildTencentOfflineCreateRequest(meeting, taskRecord, asrModel);
|
||||
taskRecord.setRequestData(reqSnapshot);
|
||||
this.updateById(taskRecord);
|
||||
|
||||
CreateRecTaskRequest request = new CreateRecTaskRequest();
|
||||
request.setEngineModelType(String.valueOf(reqSnapshot.get("engineModelType")));
|
||||
request.setChannelNum(longValue(reqSnapshot.get("channelNum")));
|
||||
request.setResTextFormat(longValue(reqSnapshot.get("resTextFormat")));
|
||||
request.setSourceType(longValue(reqSnapshot.get("sourceType")));
|
||||
request.setUrl(String.valueOf(reqSnapshot.get("url")));
|
||||
request.setSpeakerDiarization(longValue(reqSnapshot.get("speakerDiarization")));
|
||||
request.setSpeakerNumber(longValue(reqSnapshot.get("speakerNumber")));
|
||||
String hotwordList = stringValue(reqSnapshot.get("hotwordList"));
|
||||
if (hotwordList != null && !hotwordList.isBlank()) {
|
||||
request.setHotwordList(hotwordList);
|
||||
}
|
||||
|
||||
CreateRecTaskResponse response = buildTencentOfflineAsrClient(asrModel).CreateRecTask(request);
|
||||
if (response == null || response.getData() == null || response.getData().getTaskId() == null) {
|
||||
throw new RuntimeException("腾讯离线 ASR 提交失败:未返回 TaskId");
|
||||
}
|
||||
Long taskId = response.getData().getTaskId();
|
||||
writeTencentOfflineTaskId(taskRecord, taskId);
|
||||
this.updateById(taskRecord);
|
||||
return taskId;
|
||||
}
|
||||
|
||||
protected TaskStatus queryTencentOfflineTask(AiModelVO asrModel, Long taskId) throws TencentCloudSDKException {
|
||||
DescribeTaskStatusRequest request = new DescribeTaskStatusRequest();
|
||||
request.setTaskId(taskId);
|
||||
DescribeTaskStatusResponse response = buildTencentOfflineAsrClient(asrModel).DescribeTaskStatus(request);
|
||||
return response == null ? null : response.getData();
|
||||
}
|
||||
|
||||
protected String saveTencentOfflineTranscripts(Meeting meeting, SentenceDetail[] resultDetail) {
|
||||
transcriptMapper.delete(new LambdaQueryWrapper<MeetingTranscript>().eq(MeetingTranscript::getMeetingId, meeting.getId()));
|
||||
|
||||
if (resultDetail == null || resultDetail.length == 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
int order = 0;
|
||||
for (SentenceDetail detail : resultDetail) {
|
||||
if (detail == null) {
|
||||
continue;
|
||||
}
|
||||
String content = firstNonBlank(detail.getFinalSentence(), detail.getWrittenText(), "");
|
||||
if (content == null || content.isBlank()) {
|
||||
continue;
|
||||
}
|
||||
String speakerId = String.valueOf(detail.getSpeakerId() == null ? 0L : detail.getSpeakerId());
|
||||
String speakerName = "未知说话人" + speakerId;
|
||||
|
||||
MeetingTranscript mt = new MeetingTranscript();
|
||||
mt.setMeetingId(meeting.getId());
|
||||
mt.setSpeakerId(speakerId);
|
||||
mt.setSpeakerName(speakerName);
|
||||
mt.setContent(content.trim());
|
||||
fillTencentTranscriptTime(mt, detail);
|
||||
mt.setSortOrder(order++);
|
||||
transcriptMapper.insert(mt);
|
||||
sb.append(speakerName).append(": ").append(mt.getContent()).append("\n");
|
||||
}
|
||||
if (order > 0) {
|
||||
meetingTranscriptFileService.initializeTranscriptFileIfAbsent(meeting.getId());
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
private Map<String, Object> buildAsrRequest(Meeting meeting, AiTask taskRecord, AiModelVO asrModel) {
|
||||
Map<String, Object> req = new HashMap<>();
|
||||
String rawAudioUrl = meeting.getAudioUrl();
|
||||
|
|
@ -1417,6 +1559,93 @@ public class AiTaskServiceImpl extends ServiceImpl<AiTaskMapper, AiTask> impleme
|
|||
}
|
||||
}
|
||||
|
||||
private AsrClient buildTencentOfflineAsrClient(AiModelVO asrModel) {
|
||||
Map<String, Object> mediaConfig = asrModel.getMediaConfig() == null ? Map.of() : asrModel.getMediaConfig();
|
||||
String secretId = requireTencentMediaConfig(mediaConfig, MEDIA_TENCENT_SECRET_ID);
|
||||
String secretKey = requireTencentMediaConfig(mediaConfig, MEDIA_TENCENT_SECRET_KEY);
|
||||
Credential credential = new Credential(secretId, secretKey);
|
||||
ClientProfile profile = new ClientProfile();
|
||||
return new AsrClient(credential, TENCENT_ASR_REGION, profile);
|
||||
}
|
||||
|
||||
private String requireTencentMediaConfig(Map<String, Object> mediaConfig, String key) {
|
||||
String value = stringValue(mediaConfig.get(key));
|
||||
if (value == null || value.isBlank()) {
|
||||
throw new RuntimeException("腾讯离线 ASR 缺少配置: " + key);
|
||||
}
|
||||
return value.trim();
|
||||
}
|
||||
|
||||
private Long resolveTencentSpeakerDiarization(AiTask taskRecord) {
|
||||
Object useSpkObj = taskRecord.getTaskConfig() == null ? null : taskRecord.getTaskConfig().get("useSpkId");
|
||||
return "1".equals(String.valueOf(useSpkObj)) ? 1L : 0L;
|
||||
}
|
||||
|
||||
private String buildTencentHotwordList(AiTask taskRecord) {
|
||||
if (taskRecord == null || taskRecord.getTaskConfig() == null) {
|
||||
return null;
|
||||
}
|
||||
Object hotWordsObj = taskRecord.getTaskConfig().get("hotWords");
|
||||
if (!(hotWordsObj instanceof List<?> words) || words.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
return words.stream()
|
||||
.filter(Objects::nonNull)
|
||||
.map(String::valueOf)
|
||||
.map(String::trim)
|
||||
.filter(word -> !word.isBlank())
|
||||
.map(word -> word + "|5")
|
||||
.collect(Collectors.joining(","));
|
||||
}
|
||||
|
||||
private String resolveTencentOfflineAudioUrl(Meeting meeting) {
|
||||
if (meeting == null || meeting.getAudioUrl() == null || meeting.getAudioUrl().isBlank()) {
|
||||
throw new RuntimeException("腾讯离线 ASR 缺少音频地址");
|
||||
}
|
||||
String audioUrl = meeting.getAudioUrl().trim();
|
||||
if (audioUrl.startsWith("http://") || audioUrl.startsWith("https://")) {
|
||||
return audioUrl;
|
||||
}
|
||||
return serverBaseUrl + (audioUrl.startsWith("/") ? "" : "/") + audioUrl;
|
||||
}
|
||||
|
||||
private void writeTencentOfflineTaskId(AiTask taskRecord, Long taskId) {
|
||||
Map<String, Object> responseData = taskRecord.getResponseData() == null
|
||||
? new HashMap<>()
|
||||
: new HashMap<>(taskRecord.getResponseData());
|
||||
responseData.put(TENCENT_TASK_ID_KEY, taskId);
|
||||
taskRecord.setResponseData(responseData);
|
||||
}
|
||||
|
||||
private Long readTencentOfflineTaskId(AiTask taskRecord) {
|
||||
if (taskRecord == null || taskRecord.getResponseData() == null) {
|
||||
return null;
|
||||
}
|
||||
Object taskId = taskRecord.getResponseData().get(TENCENT_TASK_ID_KEY);
|
||||
return longValue(taskId);
|
||||
}
|
||||
|
||||
private Map<String, Object> buildTencentTaskStatusSnapshot(TaskStatus taskStatus) {
|
||||
Map<String, Object> snapshot = new HashMap<>();
|
||||
snapshot.put("taskId", taskStatus.getTaskId());
|
||||
snapshot.put("status", taskStatus.getStatus());
|
||||
snapshot.put("statusStr", taskStatus.getStatusStr());
|
||||
snapshot.put("result", taskStatus.getResult());
|
||||
snapshot.put("errorMsg", taskStatus.getErrorMsg());
|
||||
snapshot.put("audioDuration", taskStatus.getAudioDuration());
|
||||
return snapshot;
|
||||
}
|
||||
|
||||
private void fillTencentTranscriptTime(MeetingTranscript transcript, SentenceDetail detail) {
|
||||
if (transcript == null || detail == null) {
|
||||
return;
|
||||
}
|
||||
int startTime = detail.getStartMs() == null ? 0 : detail.getStartMs().intValue();
|
||||
int endTime = detail.getEndMs() == null ? startTime : detail.getEndMs().intValue();
|
||||
transcript.setStartTime(startTime);
|
||||
transcript.setEndTime(endTime);
|
||||
}
|
||||
|
||||
private AiModelVO resolveAsrModelForRevision(AiTask asrTask) {
|
||||
if (asrTask == null || asrTask.getTaskConfig() == null) {
|
||||
return null;
|
||||
|
|
|
|||
|
|
@ -525,7 +525,6 @@ public class LocalRealtimeAsrChannel implements RealtimeAsrChannel {
|
|||
log.error("上游 ASR websocket 异常:meetingId={}, sessionId={}, upstream={}",
|
||||
context.getMeetingId(), currentConnectionId(context), context.getTargetWsUrl(), error);
|
||||
context.getChannelState().remove(STATE_UPSTREAM_SOCKET);
|
||||
context.getCallback().removeMeetingSession(context.getMeetingId());
|
||||
context.getCallback().sendFrontendError(context.getMeetingId(),
|
||||
"REALTIME_UPSTREAM_ERROR",
|
||||
error == null || error.getMessage() == null || error.getMessage().isBlank()
|
||||
|
|
|
|||
|
|
@ -349,6 +349,30 @@ class AiModelServiceImplTest {
|
|||
assertNull(captor.getValue().getApiKey());
|
||||
}
|
||||
|
||||
@Test
|
||||
void saveModelShouldRejectTencentAsrWithoutAppId() {
|
||||
AiModelServiceImpl service = new AiModelServiceImpl(
|
||||
objectMapper,
|
||||
mock(AsrModelMapper.class),
|
||||
mock(LlmModelMapper.class)
|
||||
);
|
||||
|
||||
AiModelDTO dto = new AiModelDTO();
|
||||
dto.setModelType("ASR");
|
||||
dto.setModelName("tencent-asr");
|
||||
dto.setProvider("tencent");
|
||||
dto.setModelCode("16k_zh");
|
||||
dto.setIsDefault(0);
|
||||
dto.setStatus(1);
|
||||
dto.setMediaConfig(Map.of(
|
||||
"tencentSecretId", "secret-id",
|
||||
"tencentSecretKey", "secret-key"
|
||||
));
|
||||
|
||||
RuntimeException ex = assertThrows(RuntimeException.class, () -> service.saveModel(dto));
|
||||
assertEquals("腾讯 ASR 模型必须配置 mediaConfig.tencentAppId", ex.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void saveModelShouldRejectTencentAsrWithoutSecretKey() {
|
||||
AiModelServiceImpl service = new AiModelServiceImpl(
|
||||
|
|
@ -370,7 +394,7 @@ class AiModelServiceImplTest {
|
|||
));
|
||||
|
||||
RuntimeException ex = assertThrows(RuntimeException.class, () -> service.saveModel(dto));
|
||||
assertEquals("腾讯实时 ASR 模型必须配置 mediaConfig.tencentSecretKey", ex.getMessage());
|
||||
assertEquals("腾讯 ASR 模型必须配置 mediaConfig.tencentSecretKey", ex.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
|||
|
|
@ -1,152 +1,304 @@
|
|||
//package com.imeeting.service.biz.impl;
|
||||
//
|
||||
//import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
//import com.imeeting.dto.biz.AiModelVO;
|
||||
//import com.imeeting.entity.biz.AiTask;
|
||||
//import com.imeeting.entity.biz.HotWord;
|
||||
//import com.imeeting.entity.biz.Meeting;
|
||||
//import com.imeeting.mapper.biz.MeetingMapper;
|
||||
//import com.imeeting.mapper.biz.MeetingTranscriptMapper;
|
||||
//import com.imeeting.service.biz.AiModelService;
|
||||
//import com.imeeting.service.biz.HotWordService;
|
||||
//import com.imeeting.service.biz.MeetingProgressService;
|
||||
//import com.imeeting.service.biz.MeetingSummaryFileService;
|
||||
//import com.imeeting.service.biz.MeetingTranscriptChapterService;
|
||||
//import com.imeeting.service.biz.MeetingTranscriptFileService;
|
||||
//import com.imeeting.support.RedisValueSupport;
|
||||
//import com.imeeting.support.TaskSecurityContextRunner;
|
||||
//import com.unisbase.mapper.SysUserMapper;
|
||||
//import com.unisbase.service.SysParamService;
|
||||
//import org.junit.jupiter.api.Test;
|
||||
//import org.springframework.test.util.ReflectionTestUtils;
|
||||
//
|
||||
//import java.util.HashMap;
|
||||
//import java.util.List;
|
||||
//import java.util.Map;
|
||||
//
|
||||
//import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
//import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
//import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
//import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
//import static org.mockito.ArgumentMatchers.any;
|
||||
//import static org.mockito.Mockito.mock;
|
||||
//import static org.mockito.Mockito.when;
|
||||
//
|
||||
//class AiTaskServiceImplTest {
|
||||
//
|
||||
// @Test
|
||||
// void buildAsrRequestShouldFollowCurrentOfflineAsrContract() {
|
||||
// HotWordService hotWordService = mock(HotWordService.class);
|
||||
// HotWord hotWord = new HotWord();
|
||||
// hotWord.setWord("汇智");
|
||||
// hotWord.setWeight(25);
|
||||
// when(hotWordService.list(any())).thenReturn(List.of(hotWord));
|
||||
//
|
||||
// AiTaskServiceImpl service = new AiTaskServiceImpl(
|
||||
// mock(MeetingMapper.class),
|
||||
// mock(MeetingTranscriptMapper.class),
|
||||
// mock(AiModelService.class),
|
||||
// new ObjectMapper(),
|
||||
// mock(SysUserMapper.class),
|
||||
// hotWordService,
|
||||
// mock(RedisValueSupport.class),
|
||||
// mock(MeetingProgressService.class),
|
||||
// mock(MeetingSummaryFileService.class),
|
||||
// mock(MeetingTranscriptFileService.class),
|
||||
// mock(MeetingTranscriptChapterService.class),
|
||||
// mock(MeetingSummaryPromptAssembler.class),
|
||||
// mock(TaskSecurityContextRunner.class),
|
||||
// mock(MeetingExternalSummaryWebhookTrigger.class),
|
||||
// mock(SysParamService.class)
|
||||
// );
|
||||
// ReflectionTestUtils.setField(service, "serverBaseUrl", "http://localhost:8080");
|
||||
//
|
||||
// Meeting meeting = new Meeting();
|
||||
// meeting.setAudioUrl("/api/static/meetings/12/source audio.mp4");
|
||||
//
|
||||
// AiTask task = new AiTask();
|
||||
// Map<String, Object> taskConfig = new HashMap<>();
|
||||
// taskConfig.put("useSpkId", 1);
|
||||
// taskConfig.put("enableTextRefine", true);
|
||||
// taskConfig.put("hotWords", List.of("汇智"));
|
||||
// task.setTaskConfig(taskConfig);
|
||||
//
|
||||
// AiModelVO asrModel = new AiModelVO();
|
||||
// asrModel.setModelCode("legacy-model-code");
|
||||
//
|
||||
// @SuppressWarnings("unchecked")
|
||||
// Map<String, Object> request = (Map<String, Object>) ReflectionTestUtils.invokeMethod(
|
||||
// service,
|
||||
// "buildAsrRequest",
|
||||
// meeting,
|
||||
// task,
|
||||
// asrModel
|
||||
// );
|
||||
//
|
||||
// assertEquals("http://localhost:8080/api/static/meetings/12/source%20audio.mp4", request.get("audio_address"));
|
||||
// assertFalse(request.containsKey("file_url"));
|
||||
//
|
||||
// @SuppressWarnings("unchecked")
|
||||
// Map<String, Object> config = (Map<String, Object>) request.get("config");
|
||||
// assertEquals(Boolean.TRUE, config.get("enable_speaker"));
|
||||
// assertEquals(Boolean.TRUE, config.get("match_speaker_registry"));
|
||||
// assertEquals(Boolean.TRUE, config.get("enable_text_cleanup"));
|
||||
// assertFalse(config.containsKey("enable_text_refine"));
|
||||
// assertFalse(config.containsKey("enable_two_pass"));
|
||||
// assertFalse(config.containsKey("model"));
|
||||
//
|
||||
// @SuppressWarnings("unchecked")
|
||||
// List<Map<String, Object>> hotwords = (List<Map<String, Object>>) config.get("hotwords");
|
||||
// assertEquals(1, hotwords.size());
|
||||
// assertEquals("汇智", hotwords.get(0).get("hotword"));
|
||||
// assertEquals(2.5, hotwords.get(0).get("weight"));
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// void buildAsrRequestShouldDisableRegistryMatchWhenSpeakerSplitDisabled() {
|
||||
// AiTaskServiceImpl service = new AiTaskServiceImpl(
|
||||
// mock(MeetingMapper.class),
|
||||
// mock(MeetingTranscriptMapper.class),
|
||||
// mock(AiModelService.class),
|
||||
// new ObjectMapper(),
|
||||
// mock(SysUserMapper.class),
|
||||
// mock(HotWordService.class),
|
||||
// mock(RedisValueSupport.class),
|
||||
// mock(MeetingProgressService.class),
|
||||
// mock(MeetingSummaryFileService.class),
|
||||
// mock(MeetingTranscriptFileService.class),
|
||||
// mock(MeetingTranscriptChapterService.class),
|
||||
// mock(MeetingSummaryPromptAssembler.class),
|
||||
// mock(TaskSecurityContextRunner.class),
|
||||
// mock(MeetingExternalSummaryWebhookTrigger.class),
|
||||
// mock(SysParamService.class)
|
||||
// );
|
||||
// ReflectionTestUtils.setField(service, "serverBaseUrl", "http://localhost:8080");
|
||||
//
|
||||
// Meeting meeting = new Meeting();
|
||||
// meeting.setAudioUrl("/api/static/audio/demo.wav");
|
||||
//
|
||||
// AiTask task = new AiTask();
|
||||
// Map<String, Object> taskConfig = new HashMap<>();
|
||||
// taskConfig.put("useSpkId", 0);
|
||||
// taskConfig.put("enableTextRefine", false);
|
||||
// task.setTaskConfig(taskConfig);
|
||||
//
|
||||
// @SuppressWarnings("unchecked")
|
||||
// Map<String, Object> request = (Map<String, Object>) ReflectionTestUtils.invokeMethod(
|
||||
// service,
|
||||
// "buildAsrRequest",
|
||||
// meeting,
|
||||
// task,
|
||||
// new AiModelVO()
|
||||
// );
|
||||
//
|
||||
// @SuppressWarnings("unchecked")
|
||||
// Map<String, Object> config = (Map<String, Object>) request.get("config");
|
||||
// assertEquals(Boolean.FALSE, config.get("enable_speaker"));
|
||||
// assertEquals(Boolean.FALSE, config.get("match_speaker_registry"));
|
||||
// assertEquals(Boolean.FALSE, config.get("enable_text_cleanup"));
|
||||
// assertTrue(((List<?>) config.get("hotwords")).isEmpty());
|
||||
// assertNull(request.get("file_url"));
|
||||
// }
|
||||
//}
|
||||
package com.imeeting.service.biz.impl;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.imeeting.dto.biz.AiModelVO;
|
||||
import com.imeeting.entity.biz.AiTask;
|
||||
import com.imeeting.entity.biz.Meeting;
|
||||
import com.imeeting.mapper.biz.AiTaskMapper;
|
||||
import com.imeeting.mapper.biz.MeetingMapper;
|
||||
import com.imeeting.mapper.biz.MeetingTranscriptMapper;
|
||||
import com.imeeting.service.android.AndroidMeetingPushService;
|
||||
import com.imeeting.service.biz.AiModelService;
|
||||
import com.imeeting.service.biz.HotWordService;
|
||||
import com.imeeting.service.biz.MeetingPointsService;
|
||||
import com.imeeting.service.biz.MeetingProgressService;
|
||||
import com.imeeting.service.biz.MeetingSummaryFileService;
|
||||
import com.imeeting.service.biz.MeetingTranscriptChapterService;
|
||||
import com.imeeting.service.biz.MeetingTranscriptFileService;
|
||||
import com.imeeting.support.TaskSecurityContextRunner;
|
||||
import com.imeeting.support.redis.MeetingAsrPermitCache;
|
||||
import com.imeeting.support.redis.MeetingLockCache;
|
||||
import com.unisbase.mapper.SysUserMapper;
|
||||
import com.unisbase.service.SysParamService;
|
||||
import com.tencentcloudapi.asr.v20190614.models.SentenceDetail;
|
||||
import com.tencentcloudapi.asr.v20190614.models.TaskStatus;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.test.util.ReflectionTestUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.doReturn;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.spy;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class AiTaskServiceImplTest {
|
||||
|
||||
@Test
|
||||
void processAsrTaskShouldUseTencentOfflineBranchWhenProviderIsTencent() throws Exception {
|
||||
MeetingMapper meetingMapper = mock(MeetingMapper.class);
|
||||
MeetingTranscriptMapper transcriptMapper = mock(MeetingTranscriptMapper.class);
|
||||
AiModelService aiModelService = mock(AiModelService.class);
|
||||
HotWordService hotWordService = mock(HotWordService.class);
|
||||
MeetingLockCache meetingLockCache = mock(MeetingLockCache.class);
|
||||
MeetingAsrPermitCache meetingAsrPermitCache = mock(MeetingAsrPermitCache.class);
|
||||
MeetingProgressService meetingProgressService = mock(MeetingProgressService.class);
|
||||
MeetingPointsService meetingPointsService = mock(MeetingPointsService.class);
|
||||
MeetingSummaryFileService meetingSummaryFileService = mock(MeetingSummaryFileService.class);
|
||||
MeetingTranscriptFileService meetingTranscriptFileService = mock(MeetingTranscriptFileService.class);
|
||||
MeetingTranscriptChapterService meetingTranscriptChapterService = mock(MeetingTranscriptChapterService.class);
|
||||
MeetingSummaryPromptAssembler meetingSummaryPromptAssembler = mock(MeetingSummaryPromptAssembler.class);
|
||||
TaskSecurityContextRunner taskSecurityContextRunner = mock(TaskSecurityContextRunner.class);
|
||||
MeetingExternalSummaryWebhookTrigger meetingExternalSummaryWebhookTrigger = mock(MeetingExternalSummaryWebhookTrigger.class);
|
||||
SysParamService sysParamService = mock(SysParamService.class);
|
||||
|
||||
AiTaskServiceImpl service = spy(new AiTaskServiceImpl(
|
||||
meetingMapper,
|
||||
transcriptMapper,
|
||||
aiModelService,
|
||||
new ObjectMapper(),
|
||||
mock(SysUserMapper.class),
|
||||
hotWordService,
|
||||
meetingLockCache,
|
||||
meetingAsrPermitCache,
|
||||
meetingProgressService,
|
||||
meetingPointsService,
|
||||
meetingSummaryFileService,
|
||||
meetingTranscriptFileService,
|
||||
meetingTranscriptChapterService,
|
||||
meetingSummaryPromptAssembler,
|
||||
taskSecurityContextRunner,
|
||||
meetingExternalSummaryWebhookTrigger,
|
||||
sysParamService
|
||||
));
|
||||
ReflectionTestUtils.setField(service, "baseMapper", mock(AiTaskMapper.class));
|
||||
ReflectionTestUtils.setField(service, "androidMeetingPushService", mock(AndroidMeetingPushService.class));
|
||||
|
||||
Meeting meeting = new Meeting();
|
||||
meeting.setId(1L);
|
||||
meeting.setAudioUrl("https://cdn.example.com/audio/demo.m4a");
|
||||
|
||||
AiTask task = new AiTask();
|
||||
task.setId(11L);
|
||||
task.setMeetingId(1L);
|
||||
task.setTaskType("ASR");
|
||||
task.setTaskConfig(new HashMap<>(Map.of(
|
||||
"asrModelId", 101L,
|
||||
"useSpkId", 1,
|
||||
"enableTextRefine", true
|
||||
)));
|
||||
|
||||
AiModelVO model = new AiModelVO();
|
||||
model.setId(101L);
|
||||
model.setProvider("tencent");
|
||||
model.setModelCode("16k_zh");
|
||||
model.setMediaConfig(Map.of(
|
||||
"tencentAppId", "123456",
|
||||
"tencentSecretId", "secret-id",
|
||||
"tencentSecretKey", "secret-key"
|
||||
));
|
||||
when(aiModelService.getModelById(101L, "ASR")).thenReturn(model);
|
||||
|
||||
doReturn(true).when(service).updateById(any(AiTask.class));
|
||||
doReturn("腾讯离线转写结果").when(service).processTencentOfflineAsr(eq(meeting), eq(task), eq(model));
|
||||
|
||||
String result = ReflectionTestUtils.invokeMethod(service, "processAsrTask", meeting, task);
|
||||
|
||||
assertEquals("腾讯离线转写结果", result);
|
||||
verify(service).processTencentOfflineAsr(meeting, task, model);
|
||||
verify(meetingPointsService).recordAsrSuccessCharge(meeting, task);
|
||||
}
|
||||
|
||||
@Test
|
||||
void buildTencentOfflineCreateRequestShouldMapMeetingAndTaskConfig() {
|
||||
AiTaskServiceImpl service = createService(mock(MeetingPointsService.class));
|
||||
ReflectionTestUtils.setField(service, "serverBaseUrl", "https://server.example.com");
|
||||
|
||||
Meeting meeting = new Meeting();
|
||||
meeting.setId(2L);
|
||||
meeting.setAudioUrl("/upload/audio/demo.m4a");
|
||||
|
||||
AiTask task = new AiTask();
|
||||
task.setTaskConfig(new HashMap<>(Map.of(
|
||||
"useSpkId", 1,
|
||||
"enableTextRefine", true,
|
||||
"hotWords", java.util.List.of("腾讯会议", "离线转写")
|
||||
)));
|
||||
|
||||
AiModelVO model = new AiModelVO();
|
||||
model.setModelCode("16k_zh");
|
||||
model.setMediaConfig(Map.of(
|
||||
"tencentAppId", "123456",
|
||||
"tencentSecretId", "secret-id",
|
||||
"tencentSecretKey", "secret-key"
|
||||
));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> request = (Map<String, Object>) ReflectionTestUtils.invokeMethod(
|
||||
service,
|
||||
"buildTencentOfflineCreateRequest",
|
||||
meeting,
|
||||
task,
|
||||
model
|
||||
);
|
||||
|
||||
assertEquals("16k_zh", request.get("engineModelType"));
|
||||
assertEquals(1L, request.get("channelNum"));
|
||||
assertEquals(2L, request.get("resTextFormat"));
|
||||
assertEquals(0L, request.get("sourceType"));
|
||||
assertEquals("https://server.example.com/upload/audio/demo.m4a", request.get("url"));
|
||||
assertEquals(1L, request.get("speakerDiarization"));
|
||||
assertEquals(0L, request.get("speakerNumber"));
|
||||
assertEquals("腾讯会议|5,离线转写|5", request.get("hotwordList"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void buildTencentOfflineCreateRequestShouldKeepAbsoluteAudioUrl() {
|
||||
AiTaskServiceImpl service = createService(mock(MeetingPointsService.class));
|
||||
ReflectionTestUtils.setField(service, "serverBaseUrl", "https://server.example.com");
|
||||
|
||||
Meeting meeting = new Meeting();
|
||||
meeting.setId(22L);
|
||||
meeting.setAudioUrl("https://cdn.example.com/audio/demo.m4a");
|
||||
|
||||
AiTask task = new AiTask();
|
||||
task.setTaskConfig(new HashMap<>(Map.of("useSpkId", 0)));
|
||||
|
||||
AiModelVO model = new AiModelVO();
|
||||
model.setModelCode("16k_zh");
|
||||
model.setMediaConfig(Map.of(
|
||||
"tencentAppId", "123456",
|
||||
"tencentSecretId", "secret-id",
|
||||
"tencentSecretKey", "secret-key"
|
||||
));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> request = (Map<String, Object>) ReflectionTestUtils.invokeMethod(
|
||||
service,
|
||||
"buildTencentOfflineCreateRequest",
|
||||
meeting,
|
||||
task,
|
||||
model
|
||||
);
|
||||
|
||||
assertEquals("https://cdn.example.com/audio/demo.m4a", request.get("url"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void processTencentOfflineAsrShouldPollUntilSuccessAndReturnTranscriptText() throws Exception {
|
||||
MeetingPointsService meetingPointsService = mock(MeetingPointsService.class);
|
||||
AiTaskServiceImpl service = spy(createService(meetingPointsService));
|
||||
|
||||
Meeting meeting = new Meeting();
|
||||
meeting.setId(3L);
|
||||
meeting.setAudioUrl("https://cdn.example.com/audio/demo.m4a");
|
||||
|
||||
AiTask task = new AiTask();
|
||||
task.setId(31L);
|
||||
task.setMeetingId(3L);
|
||||
task.setTaskType("ASR");
|
||||
task.setTaskConfig(new HashMap<>(Map.of(
|
||||
"asrModelId", 101L,
|
||||
"useSpkId", 1
|
||||
)));
|
||||
|
||||
AiModelVO model = new AiModelVO();
|
||||
model.setProvider("tencent");
|
||||
model.setModelCode("16k_zh");
|
||||
model.setMediaConfig(Map.of(
|
||||
"tencentAppId", "123456",
|
||||
"tencentSecretId", "secret-id",
|
||||
"tencentSecretKey", "secret-key"
|
||||
));
|
||||
|
||||
TaskStatus doing = new TaskStatus();
|
||||
doing.setStatusStr("doing");
|
||||
|
||||
SentenceDetail sentence = new SentenceDetail();
|
||||
sentence.setSpeakerId(3L);
|
||||
sentence.setFinalSentence("测试文本");
|
||||
sentence.setStartMs(1000L);
|
||||
sentence.setEndMs(2500L);
|
||||
|
||||
TaskStatus success = new TaskStatus();
|
||||
success.setStatusStr("success");
|
||||
success.setResultDetail(new SentenceDetail[]{sentence});
|
||||
|
||||
doReturn(90001L).when(service).submitTencentOfflineTask(meeting, task, model);
|
||||
doReturn(doing).doReturn(success).when(service).queryTencentOfflineTask(model, 90001L);
|
||||
doReturn("未知说话人3: 测试文本\n").when(service).saveTencentOfflineTranscripts(meeting, success.getResultDetail());
|
||||
doReturn(true).when(service).updateById(any(AiTask.class));
|
||||
|
||||
String text = service.processTencentOfflineAsr(meeting, task, model);
|
||||
|
||||
assertEquals("未知说话人3: 测试文本\n", text);
|
||||
}
|
||||
|
||||
@Test
|
||||
void saveTencentOfflineTranscriptsShouldOverwriteExistingRowsAndUseUnknownSpeakerName() {
|
||||
MeetingPointsService meetingPointsService = mock(MeetingPointsService.class);
|
||||
MeetingTranscriptMapper transcriptMapper = mock(MeetingTranscriptMapper.class);
|
||||
MeetingTranscriptFileService transcriptFileService = mock(MeetingTranscriptFileService.class);
|
||||
AiTaskServiceImpl service = createService(meetingPointsService, transcriptMapper, transcriptFileService);
|
||||
|
||||
Meeting meeting = new Meeting();
|
||||
meeting.setId(5L);
|
||||
|
||||
SentenceDetail first = new SentenceDetail();
|
||||
first.setSpeakerId(7L);
|
||||
first.setFinalSentence("第一句");
|
||||
first.setStartMs(0L);
|
||||
first.setEndMs(1000L);
|
||||
|
||||
SentenceDetail second = new SentenceDetail();
|
||||
second.setSpeakerId(7L);
|
||||
second.setFinalSentence("第二句");
|
||||
second.setStartMs(1000L);
|
||||
second.setEndMs(2000L);
|
||||
|
||||
String text = service.saveTencentOfflineTranscripts(meeting, new SentenceDetail[]{first, second});
|
||||
|
||||
assertEquals("未知说话人7: 第一句\n未知说话人7: 第二句\n", text);
|
||||
verify(transcriptMapper).delete(any());
|
||||
verify(transcriptMapper, org.mockito.Mockito.times(2)).insert(any());
|
||||
verify(transcriptFileService).initializeTranscriptFileIfAbsent(5L);
|
||||
}
|
||||
|
||||
private AiTaskServiceImpl createService(MeetingPointsService meetingPointsService) {
|
||||
return createService(meetingPointsService, mock(MeetingTranscriptMapper.class), mock(MeetingTranscriptFileService.class));
|
||||
}
|
||||
|
||||
private AiTaskServiceImpl createService(MeetingPointsService meetingPointsService,
|
||||
MeetingTranscriptMapper transcriptMapper,
|
||||
MeetingTranscriptFileService transcriptFileService) {
|
||||
AiTaskServiceImpl service = new AiTaskServiceImpl(
|
||||
mock(MeetingMapper.class),
|
||||
transcriptMapper,
|
||||
mock(AiModelService.class),
|
||||
new ObjectMapper(),
|
||||
mock(SysUserMapper.class),
|
||||
mock(HotWordService.class),
|
||||
mock(MeetingLockCache.class),
|
||||
mock(MeetingAsrPermitCache.class),
|
||||
mock(MeetingProgressService.class),
|
||||
meetingPointsService,
|
||||
mock(MeetingSummaryFileService.class),
|
||||
transcriptFileService,
|
||||
mock(MeetingTranscriptChapterService.class),
|
||||
mock(MeetingSummaryPromptAssembler.class),
|
||||
mock(TaskSecurityContextRunner.class),
|
||||
mock(MeetingExternalSummaryWebhookTrigger.class),
|
||||
mock(SysParamService.class)
|
||||
);
|
||||
ReflectionTestUtils.setField(service, "baseMapper", mock(AiTaskMapper.class));
|
||||
ReflectionTestUtils.setField(service, "androidMeetingPushService", mock(AndroidMeetingPushService.class));
|
||||
return service;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue