如何在 GridSearchCV 中正确配置多个评估指标

来源:编程学习作者:厦门程序员头衔:程序员
导读:本期聚焦于小伙伴创作的《如何在 GridSearchCV 中正确配置多个评估指标》,敬请观看详情,探索知识的价值。以下视频、文章将为您系统阐述其核心内容与价值。如果您觉得《如何在 GridSearchCV 中正确配置多个评估指标》有用,将其分享出去将是对创作者最好的鼓励。

GridSearchCV是sklearn中用于自动化超参数调优的核心工具,默认情况下仅支持单评估指标,只能根据单个指标的数值大小筛选最优参数组合。但在实际业务场景中,我们往往需要同时参考准确率、召回率、F1值等多个指标来评估模型表现,此时就需要对GridSearchCV进行多评估指标的配置。

如何在 GridSearchCV 中正确配置多个评估指标

多评估指标配置的两种核心方式

方式一:通过scoring参数传入字典

这种方式是最常用的多指标配置方法,直接将多个评估指标以字典形式传给GridSearchCV的scoring参数,字典的键为指标名称,值为对应的评估指标函数或者字符串别名。

from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, recall_score, f1_score

# 加载数据集
data = load_iris()
X = data.data
y = data.target

# 定义模型
model = RandomForestClassifier(random_state=42)

# 定义参数网格
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [3, 5, 7]
}

# 定义多个评估指标
scoring = {
    'accuracy': 'accuracy',
    'recall': 'recall_macro',
    'f1': 'f1_macro'
}

# 初始化GridSearchCV,传入多指标配置
grid_search = GridSearchCV(
    estimator=model,
    param_grid=param_grid,
    scoring=scoring,
    cv=5,
    refit='accuracy'  # 指定最终选择最优参数的参考指标
)

# 执行调参
grid_search.fit(X, y)

方式二:自定义多指标评分函数

如果sklearn内置的指标字符串别名无法满足需求,比如需要自定义加权评估指标,可以通过make_scorer构造自定义评分函数,再传入scoring字典中。

from sklearn.metrics import make_scorer

# 自定义加权准确率评分函数,假设正样本权重是负样本的2倍
def weighted_accuracy(y_true, y_pred):
    # 计算正样本和负样本的数量
    pos_count = (y_true == 1).sum()
    neg_count = (y_true != 1).sum()
    # 计算正样本和负样本的准确率
    pos_acc = accuracy_score(y_true[y_true==1], y_pred[y_true==1]) if pos_count > 0 else 0
    neg_acc = accuracy_score(y_true[y_true!=1], y_pred[y_true!=1]) if neg_count > 0 else 0
    # 加权求和
    return (pos_acc * pos_count * 2 + neg_acc * neg_count) / (pos_count * 2 + neg_count)

# 构造scorer
weighted_acc_scorer = make_scorer(weighted_accuracy)

# 配置多指标
scoring = {
    'accuracy': 'accuracy',
    'weighted_accuracy': weighted_acc_scorer,
    'f1': 'f1_macro'
}

refit参数的配置逻辑

当配置了多个评估指标后,必须通过refit参数指定最终筛选最优参数的参考指标,否则GridSearchCV无法自动确定最优参数组合,也无法调用refit()方法重新训练模型。

  • 如果refit传入的是字符串,对应scoring字典中的某个指标名称,GridSearchCV会按照该指标的数值大小选择最优参数,数值越大(如果是损失类指标则越小)的参数组合会被选中。
  • 如果refit传入的是可调用对象,该对象需要接收GridSearchCV的cv_results_作为参数,返回最优参数的索引,适合需要根据多个指标综合判断最优参数的场景。
import numpy as np

# 自定义refit函数,选择准确率大于0.95且F1值最大的参数组合
def custom_refit(cv_results):
    # 筛选准确率大于0.95的索引
    valid_indices = np.where(cv_results['mean_test_accuracy'] > 0.95)[0]
    if len(valid_indices) == 0:
        # 如果没有符合条件的,选择准确率最高的
        return np.argmax(cv_results['mean_test_accuracy'])
    # 在符合条件的索引中选择F1值最大的
    best_idx = valid_indices[np.argmax(cv_results['mean_test_f1'][valid_indices])]
    return best_idx

grid_search = GridSearchCV(
    estimator=model,
    param_grid=param_grid,
    scoring=scoring,
    cv=5,
    refit=custom_refit
)

多指标结果的提取方法

GridSearchCV训练完成后,所有交叉验证的评估结果都会存储在cv_results_属性中,这是一个字典结构的对象,可以通过对应的键提取不同指标的结果。

# 提取所有参数组合的平均准确率
mean_accuracy = grid_search.cv_results_['mean_test_accuracy']
# 提取所有参数组合的平均F1值
mean_f1 = grid_search.cv_results_['mean_test_f1']
# 提取最优参数对应的所有指标结果
best_accuracy = grid_search.best_score_
best_f1 = grid_search.cv_results_['mean_test_f1'][grid_search.best_index_]

print(f"最优参数: {grid_search.best_params_}")
print(f"最优参数对应的准确率: {best_accuracy:.4f}")
print(f"最优参数对应的F1值: {best_f1:.4f}")

常见问题说明

  • 多个指标的方向需要一致,比如准确率和F1值都是越大越好,如果同时传入准确率和对数损失(越小越好),需要统一指标方向,比如将损失类指标取负数,保证所有指标都是越大越优。
  • 如果不需要GridSearchCV自动选择最优参数,只想获取所有参数组合的多指标评估结果,可以将refit设置为False,此时grid_search.best_params_等属性会返回None,仅能提取cv_results_中的结果。
  • 当使用多指标配置时,GridSearchCV的n_jobs参数可以并行计算不同参数组合的交叉验证结果,不会受到多指标配置的影响,适当调大n_jobs可以提升调参效率。

GridSearchCVsklearn多评估指标模型调参交叉验证修改时间:2026-07-01 16:36:37

免责声明:​ 已尽一切努力确保本网站所含信息的准确性。网站内容多为原创整理与精心编撰,观点力求客观中立。本站旨在免费分享,内容仅供个人学习、研究或参考使用。若引用了第三方作品,版权归原作者所有。如内容涉及您的权益,请联系我们处理。
内容垂直聚焦
专注技术核心技术栏目,确保每篇文章深度聚焦于实用技能。从代码技巧到架构设计,为用户提供无干扰的纯技术知识沉淀,精准满足专业提升需求。
知识结构清晰
覆盖从开发到部署的全链路。AI、前端、编程、数据库、服务器、建站、系统层层递进,构建清晰学习路径,帮助用户系统化掌握开发与运维所需的核心技术。
深度技术解析
拒绝泛泛而谈,深入技术细节与实践难点。无论是数据库优化还是服务器配置,均结合真实场景与代码示例进行剖析,致力于提供可直接应用于工作的解决方案。
专业领域覆盖
精准对应开发生命周期。从前端界面到后端编程,从数据库操作到服务器运维,形成完整闭环,一站式满足全栈工程师和运维人员的技术需求。
即学即用高效
内容强调实操性,步骤清晰、代码完整。用户可根据教程直接复现和应用于自身项目,显著缩短从学习到实践的距离,快速解决开发中的具体问题。
持续更新保障
专注既定技术方向进行长期、稳定的内容输出。确保各栏目技术文章持续更新迭代,紧跟主流技术发展趋势,为用户提供经久不衰的学习价值。