huliguo
2 天以前 d8143b9121bbe941f116230eaa5524ab2cc12a66
src/main/java/com/linghu/controller/CollectController.java
@@ -47,6 +47,7 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@RestController
@RequestMapping("/collect")
@@ -206,10 +207,11 @@
       int maxConcurrentUsers = searchTaskRequest.getConfig() != null ?
               searchTaskRequest.getConfig().getMax_concurrent_users() : 3;
       List<List<UserDto>> userBatches = splitUsersIntoBatches(searchTaskRequest.getUsers(), maxConcurrentUsers);
       // 获取 keywordId
       Integer keywordId = searchTaskRequest.getKeyword_id();
       //分割
       List<List<UserDto>> userBatches = splitUsersIntoBatches(searchTaskRequest.getUsers(), maxConcurrentUsers,keywordId);
       return Flux.fromIterable(userBatches)
               .flatMap(batch -> {
@@ -255,7 +257,15 @@
                .then();
    }
    private List<List<UserDto>> splitUsersIntoBatches(List<UserDto> users, int batchSize) {
    private List<List<UserDto>> splitUsersIntoBatches(List<UserDto> users, int batchSize,Integer keywordId) {
        Keyword keyword = keywordService.getById(keywordId);
        if (null==keyword.getNum()){
            keyword.setNum(0);
        }
        keyword.setNum(keyword.getNum()+1);
        keywordService.updateById(keyword);
        List<List<UserDto>> batches = new ArrayList<>();
        for (int i = 0; i < users.size(); i += batchSize) {
            batches.add(users.subList(i, Math.min(i + batchSize, users.size())));
@@ -485,7 +495,7 @@
                // 3. 收集所有需要更新的问题和引用
                List<Question> questionsToUpdate = new ArrayList<>();
                List<Reference> allReferences = new ArrayList<>();
                List<Reference> resultList = new ArrayList<>();
                // 遍历结果
                for (UserResult userResult : result.getResults()) {
                    for (QuestionResult questionResult : userResult.getQuestions_results()) {
@@ -511,17 +521,6 @@
                                questionsToUpdate.add(question);
                                //如果查询结果不为空查询num
                                Integer maxNumByKeywordId = referenceService.getMaxNumByKeywordId(keyword.getKeyword_id());
                               if (maxNumByKeywordId != null){
                                   maxNumByKeywordId++;
                               }else {
                                   maxNumByKeywordId = 1;
                               }
                                // 收集引用数据,处理空集合情况
                                Integer finalMaxNumByKeywordId = maxNumByKeywordId;
                                List<Reference> references =
                                        Optional.ofNullable(questionResult.getReferences())
                                                .orElse(Collections.emptyList())
@@ -532,30 +531,38 @@
                                                    reference.setTitle(ref.getTitle());
                                                    reference.setUrl(ref.getUrl());
                                                    reference.setDomain(ref.getDomain());
                                                    reference.setNum(finalMaxNumByKeywordId);
                                                    reference.setNum(keyword.getNum());
                                                    reference.setTask_id(result.getTask_id());
                                                    reference.setKeyword_id(keyword.getKeyword_id());
                                                    //域名和平台id映射
                                                    reference.setCreate_time(LocalDateTime.now());
                                                    Platform platform = platformService.getPlatformByDomain(reference.getDomain());
//                                                    if (platform == null) {
//                                                        throw new RuntimeException("未找到对应的平台: " + reference.getDomain());
//                                                    }
                                                    if (platform != null){
                                                    if (platform == null) {
                                                        //平台为空 创建平台 类型为“默认”
                                                        Type type = typeService.getOne(new LambdaQueryWrapper<Type>().eq(Type::getType_name,"默认"));
                                                        if (type == null) {
                                                            Type newType = new Type();
                                                            newType.setType_name("默认");
                                                            typeService.save(newType);
                                                            type = newType;
                                                        }
                                                        Platform platform1 = new Platform();
                                                        platform1.setDomain(reference.getDomain());
                                                        platform1.setPlatform_name(reference.getDomain());
                                                        platform1.setType_id(type.getType_id());
                                                        platformService.save(platform1);
                                                        reference.setType_id(type.getType_id());
                                                        reference.setPlatform_id(platform1.getPlatform_id());
                                                    }
                                                    else {
                                                        reference.setPlatform_id(platform.getPlatform_id());
                                                        Type type = typeService.getById(platform.getType_id());
//                                                    if (type == null) {
//                                                        throw new RuntimeException("未找到对应的类型: " + reference.getDomain());
//                                                    }
                                                        if (type != null){
                                                            reference.setType_id(type.getType_id());
                                                        }
                                                    }
                                                    // 根据 domain 查询类型
                                                    return reference;
                                                })
                                                .collect(Collectors.toList());
@@ -564,6 +571,53 @@
                                if (!references.isEmpty()) {
                                    allReferences.addAll(references);
                                }
                                //取数据库中当前关键词的当前轮次的当前问题id结果拿出来
                                List<Reference> dbList = referenceService.list(new LambdaQueryWrapper<Reference>().eq(Reference::getKeyword_id, keyword.getKeyword_id())
                                        .eq(Reference::getNum, keyword.getNum())
                                        .eq(Reference::getQuestion_id, question.getQuestion_id())
                                );
                                // 1. 合并两个列表
                                List<Reference> combinedList = new ArrayList<>();
                                combinedList.addAll(allReferences);
                                combinedList.addAll(dbList);
                                // 2. 创建复合键的Map,用于统计完全匹配的记录
                                Map<String, List<Reference>> compositeKeyMap = combinedList.stream()
                                        .collect(Collectors.groupingBy(
                                                ref -> ref.getTitle() + "|" + ref.getUrl() + "|" + ref.getDomain()
                                        ));
                                // 3. 处理每组重复记录
                                compositeKeyMap.forEach((key, refGroup) -> {
                                    // 3.1 找出组内有ID的记录(优先从dbList中获取)
                                    Optional<Reference> existingRecord = refGroup.stream()
                                            .filter(ref -> ref.getReference_id() != null)
                                            .findFirst();
                                    // 3.2 统计该组的重复次数(总数-1)
                                    int repetitionCount = refGroup.size() - 1;
                                    // 3.3 决定最终保留的记录
                                    Reference recordToSave;
                                    if (existingRecord.isPresent()) {
                                        // 使用已有ID的记录并更新重复次数
                                        recordToSave = existingRecord.get();
                                        recordToSave.setRepetition_num(
                                                (recordToSave.getRepetition_num() == null ? 0 : recordToSave.getRepetition_num())
                                                        + repetitionCount
                                        );
                                    } else {
                                        // 没有ID记录则取第一条并设置重复次数
                                        recordToSave = refGroup.get(0);
                                        recordToSave.setRepetition_num(repetitionCount);
                                    }
                                    resultList.add(recordToSave);
                                });
                                referenceService.saveOrUpdateBatch(resultList);
                            }
                        } catch (Exception e) {
                            log.error(e.getMessage(), e);
@@ -578,7 +632,7 @@
                    questionService.updateBatchById(questionsToUpdate);
                    System.out.println("成功批量更新 " + questionsToUpdate.size() + " 个问题");
                }
                referenceService.saveBatch(allReferences);
                // 5. 批量插入引用,使用流式分批处理
//                if (!allReferences.isEmpty()) {
//                    int batchSize = 1000;