• R语言
  • 怎样提升自定义函数的执行效率?

楼主给 AI 的描述如下:

有这样一个需求:同一投保人多保单累计应交保费50万及以上,即不同保单分别累计达50万可以分别获得确认函; 例如:投保人A名下有4张保单,分别为10万、20万、30万、50万,10+20+30可以获得1份,50万可获得1份,累计2份。这个问题会有一些复杂的特殊情况需要解决,比如投保人B名下有4张保单,累计应交保费分别为15万、40万、10万、35万,如果在计算函件数时,按照15万+35万、40万+10万来计算就会得到2张函件,可是如果按照15万+40万、10万+35万来计算就只能得到1张函件。现在我的表中有的数据字段分别是投保人id、保单id、累计应交保费,那么应该怎么计算每个投保人id分别获得多少确认函数量呢?并且知道每个投保人id的哪些保单id的累计应交保费合起来超过50万得到了1个函件。具体方法应该更加通用,就是每个投保人id不限保单数量的。请使用R中 data.table 来实现。

AI 给的答案是这样的:

# 加载必要的库
library(data.table)

# 创建示例数据框
data <- data.table(
  投保人id = c(rep(100,4), rep(200,4)),
  保单id = 1:8,
  累计应交保费 = c(10, 20, 30, 50, 15, 40, 10, 35)
)

# 函数来计算每个投保人所获的确认函数量及保单组合
calculate_confirmations <- function(data, threshold) {
  results <- data.table(投保人id = numeric(), 保单组合 = character(), 确认函数 = numeric())
  
  unique_ids <- unique(data$投保人id)
  
  for (id in unique_ids) {
    policies <- data[投保人id == id]
    n <- nrow(policies)
    
    used_policies <- rep(FALSE, n)
    
    find_combinations <- function(start, remaining_threshold, current_comb) {
      if (remaining_threshold <= 0) {
        return(list(current_comb))
      }
      if (start > n) {
        return(NULL)
      }
      combinations <- list()
      for (i in start:n) {
        if (!used_policies[i] && policies$累计应交保费[i] <= remaining_threshold) {
          used_policies[i] <- TRUE
          new_comb <- find_combinations(i + 1, remaining_threshold - policies$累计应交保费[i], c(current_comb, policies$保单id[i]))
          if (!is.null(new_comb)) {
            combinations <- c(combinations, new_comb)
          }
          used_policies[i] <- FALSE
        }
      }
      return(combinations)
    }
    
    while (sum(!used_policies) > 0) {
      comb <- find_combinations(1, threshold, c())
      if (!is.null(comb) && length(comb) > 0) {
        best_comb <- comb[[1]]
        results <- rbind(results, data.table(投保人id = id,
                                             保单组合 = paste(best_comb, collapse = ","),
                                             确认函数 = 1))
        for (policy in best_comb) {
          used_policies[which(policies$保单id == policy)] <- TRUE
        }
      } else {
        break
      }
    }
  }
  
  return(results)
}

# 设置累计应交保费阈值
threshold <- 50

# 计算结果
results <- calculate_confirmations(data, threshold)
print(results)

这个答案可以用,但遇到的新问题是,如果投保人id对应特别多保单id,执行效率就会很慢。本楼主今年太依赖 AI,已经不会自己写代码了,而 AI 又无法满足新的要求,请各位路过的坛友帮忙瞅瞅。

这个答案是哪个AI给出来的

    感觉代码里还是有点问题……这个代码是瞄着刚刚好50去设计的?如果是达到51这种是不能被判定出来的。还是说这就是你的需求?另外好像你也没说如果一个人的保单能有不同的方式达到累积50W,那这个算法是不是要找到那个使得确认函最多,或者最少的保单组合方式?

    data <- data.table(
      投保人id = c(rep(100,4), rep(200,4)),
      保单id = 1:8,
      累计应交保费 = c(31, 11, 21, 51, 41, 16, 11, 36)
    )
    
    results <- calculate_confirmations(data, threshold)
    print(results)
    # Empty data.table (0 rows and 3 cols): 投保人id,保单组合,确认函数

      fenguoerbian

      是要找到使得确认函最多的保单组合方式,需求应该是达到51也要被判定出来的。

      我看了你的栗子后,瞬间明白,还是要锻炼下自己的代码水平,不然如果自己不能辨析 AI 的答案,就会被 AI 坑了。

      fenguoerbian

      其实我也记得好像是发过帖子问的,不过一想到坛子的搜索功能很难用,就懒得翻帖子了,居然是2022年的事了,我记得当时我是写 sql 解决的,主要是分门别类去算。这次又觉得执行 sql 太麻烦了,而且我的服务器上的 Rstudio Server 连接数据库也很方便了,所以想着让 AI 给我写段 R 代码来算。

      这个问题的本质是一个离散数学问题而不是一个编程问题,用推理模型效果肯定更好。以下结果是 o1 给出的。我把问题的最后一个条件换成了 Python,因为做算法题 Python 表达会更准确一些。


      下面给出一种思路和示例代码,帮助你实现「对于同一投保人,给定多张保单的累计应交保费,如何尽量多地凑出若干组,每组的总额达到或超过 50 万,并且同时要知道是哪几张保单凑在一起满足了这 50 万」,以及最终计算每个投保人可以获得多少个“函件”。

      需要先说明的是,这个问题本质上是一个「组合优化」问题,和常见的「子集求和(Subset Sum)」「多子集划分(Multiple Subset Partition)」等问题比较接近,都属于 NP-难问题。也就是说,如果保单数量很多,要想获得“绝对最优解”(即最大可能的组数),往往需要使用回溯 / 动态规划 / 分支限界 等手段进行全局搜索,但其时间复杂度会非常大。

      在实际业务场景中,如果保单数量有限,或者能够接受近似最优(启发式)解法,那么可以考虑一些贪心策略,往往就能取得不错的效果。

      下文主要分成两部分:

      1. 思路与算法描述
      2. 示例 Python 实现

      1. 思路与算法描述

      1.1 问题抽象

      给定一个投保人名下的一组保单 {p1,p2,...,pn}\{p_1, p_2, ..., p_n\},每张保单有一个累计应交保费 viv_i。我们希望从这些保单中,尽量多地划分若干子集(每个子集中可能包含若干张保单),使得每个子集的累计应交保费之和 50\ge 50万。并且我们还需要知道具体哪些保单被分到同一子集。

      这里的目标有两个:

      1. 最大化子集个数(最大化“函件”数)。
      2. 获取每个子集对应的保单列表(方便后续知道具体哪几张保单合并在一起)。

      1.2 常见算法与优缺点

      1. 回溯 / 全局搜索(Optimal 解)

        • 原理:穷举所有可能的分组方法,找到能组成 50\ge 50 万 的最多子集个数。
        • 缺点:随着保单数增加,计算量呈指数级增长,实际业务中往往不可用。
        • 优点:可以得到最优解。
      2. 贪心策略(Heuristic 近似解)

        • 原理:常见做法是将保单按累计应交保费从大到小排序,然后依次尝试把保单加入某个“未到达 50 万的组”中。如果本身已经超过 50 万,就直接“锁定”为一组,开始下一组。
        • 缺点:贪心方法不一定得到最优解,可能存在局部最优。
        • 优点:实现简单、计算速度快,对于很多实际案例也能有较好的效果。
        • 示例:
          • 投保人 B 的保单为 [15 万, 40 万, 10 万, 35 万]。
          • 从大到小排序: [40, 35, 15, 10]。
          • 取 40 万,尚未达 50 万,则再加一个 10 万,总和 = 50 万,刚好凑够一组 -> [40,10][40, 10]
          • 剩余 [35, 15],35 + 15 = 50 万 -> 又是一组 -> [35,15][35, 15]
          • 得到 2 组。
          • 这样“碰巧”就拿到了最优结果。
        • 但你也可以发现,如果数据量更大、更复杂,贪心可能会错过一些更优组合。

      在实际业务中,如果你的表规模不算太大,可以考虑做一个“回溯 + 剪枝”或“分支限界”来找最优解;如果规模比较大,或者数据更新很频繁,则用贪心或其它近似算法可能更合适。


      2. 示例 Python 实现

      下面给出两个示例实现:

      1. 贪心实现(适合保单数量较多、追求速度):
      2. 回溯 / 搜索实现(适合保单数量相对少、追求绝对最优):

      这两种实现的输入思路相同:

      • 我们先按投保人 ID 对数据进行分组,比如 data 的每一行包含 [投保人id, 保单id, 累计应交保费]
      • 对每个投保人,取出所有的 (保单id, 累计应交保费) 放到一起,然后进行相应的算法处理,得到分组结果和最终可获得的函件数量。

      下面都给出示例代码,你可以根据自己的实际数据结构做进一步调整。

      2.1 贪心策略示例

      def greedy_group_policies(policy_list, threshold=50):
          """
          参数:
              policy_list: [(policy_id, premium), (policy_id, premium), ...]
                           已按 premium 从大到小排序
              threshold:   分组的门槛,这里默认 50(表示 50万)
          返回:
              groups:  列表,形如 [ [ (policy_id, premium), (policy_id, premium) ], [ ... ], ... ]
                       表示分出来的每一组里有哪些保单
          """
          groups = []  # 最终结果
          current_group = []
          current_sum = 0
      
          for (pid, val) in policy_list:
              # 如果加入当前组还能让当前组 < threshold,则加入
              # 否则如果本身就 >= threshold,就单独成为一组
              # 否则就把这个当前组先收起来,新建一组
      
              # 尝试把当前保单加入 current_group
              if current_sum + val < threshold:
                  current_group.append((pid, val))
                  current_sum += val
              else:
                  # 分两种情况:
                  # 1. 如果 current_sum + val >= threshold,但 current_sum 还没到 threshold,
                  #    那么把这个保单也放进当前组后,整个组就到达或超过门槛了,可以锁定这一组。
                  current_group.append((pid, val))
                  current_sum += val
                  # 锁定这一组
                  groups.append(current_group)
                  # 开始新的一组
                  current_group = []
                  current_sum = 0
      
          # 如果最后还有未锁定的 current_group,需要判断一下是否 >= threshold?
          # 如果不达标,看业务需求是否要保留或丢弃/合并等等,按需求来。
          # 这里为了演示,我们只保留“达到标准”的组
          # 也可以改成只要剩下的不空就也当一组返回,但通常这会导致这组并不满足 >= 50万的条件
          # 如果你需要“凑一凑”再看看能不能达到 threshold,就要改成更复杂的逻辑了。
          if current_group and current_sum >= threshold:
              groups.append(current_group)
      
          return groups
      
      def run_greedy_solution(data):
          """
          data: List of (投保人id, 保单id, 累计应交保费) 的数据
          """
          from collections import defaultdict
      
          # 1) 按投保人 ID 分组
          dict_by_insurer = defaultdict(list)   # {投保人id: [(保单id, 累计应交保费), ...], ...}
          for insurer_id, policy_id, premium in data:
              dict_by_insurer[insurer_id].append((policy_id, premium))
      
          # 2) 对每个投保人处理
          result = {}
          for insurer_id, policy_list in dict_by_insurer.items():
              # 按应交保费从大到小排序
              policy_list.sort(key=lambda x: x[1], reverse=True)
      
              # 调用贪心算法
              groups = greedy_group_policies(policy_list, threshold=50)
      
              # groups 就是所有有效分组,分组数即“函件”数
              result[insurer_id] = {
                  "groups": groups,
                  "函件数": len(groups)
              }
      
          return result
      
      if __name__ == "__main__":
          # 示例数据
          data = [
              # 投保人A
              ("A", "p1", 10),
              ("A", "p2", 20),
              ("A", "p3", 30),
              ("A", "p4", 50),
              # 投保人B
              ("B", "p1", 15),
              ("B", "p2", 40),
              ("B", "p3", 10),
              ("B", "p4", 35),
          ]
      
          results = run_greedy_solution(data)
          for insurer_id, info in results.items():
              print(f"投保人 {insurer_id} 可获得函件数:{info['函件数']}")
              for idx, group in enumerate(info['groups'], start=1):
                  group_ids = [g[0] for g in group]
                  group_sum = sum(g[1] for g in group)
                  print(f"  第 {idx} 组保单: {group_ids}, 总额 = {group_sum}万")
              print("------")
      • 在这个示例里,阈值(threshold=50)就是 50 万。
      • 这里的贪心方法是一个简单版本:从大到小依次加入,当前组只要达到 >= 50 就收手,开始下一组。
      • 这个实现速度很快,代码比较直观,但是不一定得到全局最优解;在示例给出的 A/B 测试数据中能凑出不错的结果。

      2.2 回溯 / 搜索示例

      如果你的数据量不大(比如每个投保人只有几十张保单以内),想要找「尽量最优」的划分方法,可以用回溯 / 搜索来尝试。思路伪代码如下:

      1. 对保单从大到小排序(有助于更快地剪枝)。
      2. 递归函数:
        • 当前处理到第 i 张保单;
        • 维护一个列表 groups,其中每个元素是一个子集(子集的和要尽量 >= 50);
        • 对于第 i 张保单,你可以尝试把它放到已有的某一个子集中(前提是不会“浪费”太多),或者新开一个子集;
        • 过程中,当某个子集达到 50\ge 50 时,可以“锁定”这一组,继续处理下一张保单……
      3. 记录分组数最多时的方案;
      4. 使用一些剪枝策略(比如一旦当前所有剩余保单加起来也无法多凑出一份 50 万,或已经无法超过当前已知的最优分组数,就提前停止)。

      代码示例(演示思路,未做复杂剪枝):

      def backtrack_policies(policy_list, threshold=50):
          """
          policy_list: [(policy_id, premium), ...],已排序(一般从大到小)
          threshold:   目标值,默认 50
      
          返回:
              (max_count, best_groups)
              max_count: 最大可分组数
              best_groups: 对应的保单分组方案
          """
      
          best_result = {
              "count": 0,
              "groups": []
          }
      
          # groups 用于保存若干子集,每个子集是 [(pid, val), ...]
          def backtrack(index, groups):
              nonlocal best_result
      
              # 如果处理完所有保单
              if index == len(policy_list):
                  # 这里可以统计一下 >= threshold 的子集个数
                  valid_count = sum(1 for g in groups if sum(x[1] for x in g) >= threshold)
                  if valid_count > best_result["count"]:
                      best_result["count"] = valid_count
                      # 这里要深拷贝一下分组结果
                      best_result["groups"] = [list(g) for g in groups]
                  return
      
              (pid, val) = policy_list[index]
      
              # 尝试放到当前已有的某个子集中
              for i in range(len(groups)):
                  groups[i].append((pid, val))
                  backtrack(index + 1, groups)
                  groups[i].pop()
      
              # 或者另起一个子集
              groups.append([(pid, val)])
              backtrack(index + 1, groups)
              groups.pop()
      
          backtrack(0, [])
      
          return best_result["count"], best_result["groups"]
      
      
      def run_backtrack_solution(data):
          """
          data: List of (投保人id, 保单id, 累计应交保费)
          返回: {投保人id: {"groups": ..., "函件数": ...}, ...}
          """
          from collections import defaultdict
      
          dict_by_insurer = defaultdict(list)
          for insurer_id, policy_id, premium in data:
              dict_by_insurer[insurer_id].append((policy_id, premium))
      
          result = {}
          for insurer_id, policy_list in dict_by_insurer.items():
              # 从大到小排序
              policy_list.sort(key=lambda x: x[1], reverse=True)
      
              max_count, best_groups = backtrack_policies(policy_list, threshold=50)
      
              # 注意:best_groups 里面还可能有些子集和 < 50 的(回溯时也会生成),
              # 我们只关心最终 >= 50 的子集
              final_groups = []
              for g in best_groups:
                  s = sum(x[1] for x in g)
                  if s >= 50:
                      final_groups.append(g)
      
              result[insurer_id] = {
                  "groups": final_groups,
                  "函件数": len(final_groups)
              }
      
          return result
      
      
      if __name__ == "__main__":
          data = [
              ("A", "p1", 10),
              ("A", "p2", 20),
              ("A", "p3", 30),
              ("A", "p4", 50),
              ("B", "p1", 15),
              ("B", "p2", 40),
              ("B", "p3", 10),
              ("B", "p4", 35),
          ]
      
          results_bt = run_backtrack_solution(data)
          for insurer_id, info in results_bt.items():
              print(f"[回溯] 投保人 {insurer_id} 可获得函件数:{info['函件数']}")
              for idx, group in enumerate(info['groups'], start=1):
                  group_ids = [g[0] for g in group]
                  group_sum = sum(g[1] for g in group)
                  print(f"  第 {idx} 组保单: {group_ids}, 总额 = {group_sum}万")
              print("------")
      • 在这个回溯过程中,你会看到它枚举了所有可能的分法,并记录「达到 50 万及以上的子集个数最多」时的方案。
      • 为了避免输出过多的冗余子集,需要在最终结果里再次过滤,确保只保留那些 50\ge 50 的子集。
      • 如果数据量一旦变大(比如几十上百张保单),这个回溯就会很慢,需要用更多剪枝策略做优化,或者干脆换用近似算法。

      3. 总结

      1. 如果数据量不大、求尽量精确:可以考虑上面展示的「回溯 / 搜索」,配合合理的剪枝,得到最优结果。
      2. 如果数据量较大、需快速计算:可以考虑先排序后“贪心”来凑,虽然不保证最优,但实现与执行都更简单快速,往往也能取得较好结果。

      无论是贪心还是回溯,最终都能得到两部分信息:

      • 最大或近似最大能分出的函件(子集)数量
      • 每份函件对应到的具体保单(哪些 policy_id 被放到一个组里)。

      这样就可以满足你的需求:既知道每个投保人能获得多少函件数,也能知道每张函件由哪些保单所凑成。

        可以看到这些解法都涉及到大量无法向量化的 for 循环…… 这种情况要达到最优的运行速度和内存使用,让 Claude 3.5 Sonnet 之类编程水平较高的模型用 C/C++ 实现一下就可以了。

        nan.xiao
        哇,这真的是人写的嘛,真厉害。

        我刚又想了一下,最初处理数据的时候,先把一个投保人id一个保单id的累计应交保费大于等于50万的单独拿出来,然后先计算多个保单组合后累计应交保费等于50万的,最后计算多个保单组合后累计应交保费大于50万的。也许这样分类处理后还能再快点。

          yuanfan 这个也是昨天我想的思路,但是写成一个可用的 demo 还是比较费时间。所以ai要是可以很快就给一个可行的答案,真是厉害,AI已经发展到一个临界点了。

          手动点赞上面xiaonan的回答,简直就是一篇论文了

          9 天 后

          yuanfan
          由于前面回溯搜索的 PYTHON 脚本执行起来很慢,大概200条数据就不出结果了,最后用把 AI 给的多次不同的结果整合到一起用了。

          library(data.table)
          
          data <- data.table(
              投保人id = c(rep(100, 4), rep(200, 4), rep(300, 4), rep(400, 4)),
              保单id = 1:16,
              累计应交保费  = c(10, 20, 30, 50, 15, 40, 10, 35, 31, 11, 21, 51, 41, 16, 11, 36)
          )
          # 函数来计算每个投保人所获的确认函数量及保单组合
          calculate_confirmations1 <- function(data, threshold) {
              results <-
                  data.table(投保人id = numeric(),
                                保单组合  = character(),
                                确认函数  = numeric())
              
              unique_ids <- unique(data$投保人id)
              
              for (id in unique_ids) {
                  policies <- data[投保人id == id]
                  n <- nrow(policies)
                  
                  used_policies <- rep(FALSE, n)
                  
                  find_combinations <-
                      function(start,
                               remaining_threshold,
                               current_comb) {
                          if (remaining_threshold <= 0) {
                              return(list(current_comb))
                          }
                          if (start > n) {
                              return(NULL)
                          }
                          combinations <- list()
                          for (i in start:n) {
                              if (!used_policies[i] &&
                                  policies$累计应交保费[i] <= remaining_threshold) {
                                  used_policies[i] <- TRUE
                                  new_comb <-
                                      find_combinations(
                                          i + 1,
                                          remaining_threshold - policies$累计应交保费[i],
                                          c(current_comb, policies$保单id[i])
                                      )
                                  if (!is.null(new_comb)) {
                                      combinations <- c(combinations, new_comb)
                                  }
                                  used_policies[i] <- FALSE
                              }
                          }
                          return(combinations)
                      }
                  
                  while (sum(!used_policies) > 0) {
                      comb <- find_combinations(1, threshold, c())
                      if (!is.null(comb) && length(comb) > 0) {
                          best_comb <- comb[[1]]
                          results <- rbind(results,
                                           data.table(
                                               投保人id = id,
                                                  保单组合  = paste(best_comb, collapse = ","),
                                                  确认函数  = 1
                                           ))
                          for (policy in best_comb) {
                              used_policies[which(policies$保单id == policy)] <- TRUE
                          }
                      } else {
                          break
                      }
                  }
              }
              
              return(results)
          }
          
          
          # 函数来计算每个投保人所获的确认函数量及保单组合
          calculate_confirmations2 <- function(data, threshold) {
              results <-
                  data.table(投保人id = numeric(),
                                保单组合  = character(),
                                确认函数  = numeric())
              
              unique_ids <- unique(data$投保人id)
              
              for (id in unique_ids) {
                  policies <- data[投保人id == id]
                  n <- nrow(policies)
                  
                  used_policies <- rep(FALSE, n)
                  
                  find_combinations <-
                      function(start,
                               remaining_threshold,
                               current_comb) {
                          if (remaining_threshold <= 0) {
                              return(list(current_comb))
                          }
                          if (start > n) {
                              return(NULL)
                          }
                          combinations <- list()
                          for (i in start:n) {
                              if (!used_policies[i]) {
                                  used_policies[i] <- TRUE
                                  new_comb <-
                                      find_combinations(
                                          i + 1,
                                          remaining_threshold - policies$累计应交保费[i],
                                          c(current_comb, policies$保单id[i])
                                      )
                                  if (!is.null(new_comb)) {
                                      combinations <- c(combinations, new_comb)
                                  }
                                  used_policies[i] <- FALSE
                              }
                          }
                          return(combinations)
                      }
                  
                  best_combinations <- list()
                  
                  while (sum(!used_policies) > 0) {
                      comb <- find_combinations(1, threshold, c())
                      if (!is.null(comb) && length(comb) > 0) {
                          best_comb <- comb[[1]]
                          for (c in comb) {
                              if (sum(policies$累计应交保费[policies$保单id %in% c]) >= threshold) {
                                  best_comb <- c
                              }
                          }
                          results <- rbind(results,
                                           data.table(
                                               投保人id = id,
                                                  保单组合  = paste(best_comb, collapse = ","),
                                                  确认函数  = 1
                                           ))
                          for (policy in best_comb) {
                              used_policies[which(policies$保单id == policy)] <- TRUE
                          }
                          best_combinations <- c(best_combinations, list(best_comb))
                      } else {
                          break
                      }
                  }
              }
              
              return(results)
          }
          
          # 先用第一个函数
          results1 <- calculate_confirmations1(data, 50)
          
          # 提取 results1 中的保单 ID 保单
          ID_list <-
              unlist(strsplit(results1$保单组合, ","))
          #筛选出 data 表中不包含这些保单 ID 的行
          filtered_data <- data[!(保单id %in% as.integer(ID_list))]
          
          # 再用第二个函数
          results2 <- calculate_confirmations2(filtered_data, 50)
          
          results_end <- rbind(results1, results2)
          print(results_end)