怎样在C++中实现决策树机器学习算法

来源:我的博客作者:大卫头衔:程序员
导读:本期聚焦于小伙伴创作的《怎样在C++中实现决策树机器学习算法》,敬请观看详情,探索知识的价值。以下视频、文章将为您系统阐述其核心内容与价值。如果您觉得《怎样在C++中实现决策树机器学习算法》有用,将其分享出去将是对创作者最好的鼓励。

决策树是一种基于树结构进行决策的机器学习算法,核心是递归地将数据集按照最优特征划分,直到子集属于同一类别或满足停止条件。在C++中实现决策树需要结合数据结构设计与算法逻辑编码,下面逐步讲解完整实现过程。

怎样在C++中实现决策树机器学习算法

决策树核心原理回顾

决策树的构建依赖特征选择准则,常用的有信息增益、信息增益率和基尼系数。本文采用信息增益作为划分标准,信息增益基于熵的概念计算,熵越小代表数据集的纯度越高。

熵的计算公式为:对于数据集D,假设有K个类别,第k类样本占比为p_k,则熵H(D) = -Σ(p_k * log2(p_k))。特征A对数据集D的信息增益g(D,A) = H(D) - Σ( |D_v|/|D| * H(D_v) ),其中D_v是特征A取值为v的子集。

数据结构设计

首先需要定义决策树节点和数据集的结构,方便后续递归构建树。

#include <vector>
#include <string>
#include <map>
#include <cmath>
#include <algorithm>

// 样本结构,features存储特征取值,label存储类别标签
struct Sample {
    std::vector<double> features;
    int label;
};

// 决策树节点结构
struct TreeNode {
    // 如果是内部节点,split_feature是划分特征索引,split_value是划分值(连续特征用)
    // 如果是叶子节点,result是最终分类结果
    bool is_leaf;
    int split_feature;
    double split_value;
    int result;
    std::map<double, TreeNode*> children; // 子节点,key是特征取值
    TreeNode() : is_leaf(false), split_feature(-1), split_value(0.0), result(-1) {}
};

核心功能实现

计算熵

首先需要实现熵的计算函数,根据数据集的标签分布计算当前集合的熵值。

// 计算数据集的熵
double calculate_entropy(const std::vector<Sample>& dataset) {
    if (dataset.empty()) return 0.0;
    std::map<int, int> label_count;
    for (const auto& sample : dataset) {
        label_count[sample.label]++;
    }
    double entropy = 0.0;
    int total = dataset.size();
    for (const auto& pair : label_count) {
        double p = (double)pair.second / total;
        if (p > 0) {
            entropy -= p * log2(p);
        }
    }
    return entropy;
}

计算信息增益

遍历所有特征,计算每个特征的信息增益,选择信息增益最大的特征作为当前划分特征。

// 计算指定特征的信息增益
double calculate_info_gain(const std::vector<Sample>& dataset, int feature_idx) {
    double base_entropy = calculate_entropy(dataset);
    // 获取该特征的所有取值及对应的子集
    std::map<double, std::vector<Sample>> feature_subsets;
    for (const auto& sample : dataset) {
        double feature_val = sample.features[feature_idx];
        feature_subsets[feature_val].push_back(sample);
    }
    double weighted_entropy = 0.0;
    int total = dataset.size();
    for (const auto& pair : feature_subsets) {
        double weight = (double)pair.second.size() / total;
        weighted_entropy += weight * calculate_entropy(pair.second);
    }
    return base_entropy - weighted_entropy;
}

// 选择最优划分特征
int choose_best_feature(const std::vector<Sample>& dataset, const std::vector<int>& remain_features) {
    double max_gain = -1.0;
    int best_feature = -1;
    for (int feature_idx : remain_features) {
        double gain = calculate_info_gain(dataset, feature_idx);
        if (gain > max_gain) {
            max_gain = gain;
            best_feature = feature_idx;
        }
    }
    return best_feature;
}

递归构建决策树

基于选择的最优特征递归划分数据集,直到满足停止条件(如所有样本属于同一类别、没有剩余特征等)生成叶子节点。

// 判断数据集是否属于同一类别
bool is_same_label(const std::vector<Sample>& dataset, int& label) {
    if (dataset.empty()) return false;
    label = dataset[0].label;
    for (const auto& sample : dataset) {
        if (sample.label != label) return false;
    }
    return true;
}

// 递归构建决策树
TreeNode* build_decision_tree(const std::vector<Sample>& dataset, std::vector<int> remain_features) {
    TreeNode* node = new TreeNode();
    int same_label;
    // 停止条件1:所有样本属于同一类别
    if (is_same_label(dataset, same_label)) {
        node->is_leaf = true;
        node->result = same_label;
        return node;
    }
    // 停止条件2:没有剩余特征
    if (remain_features.empty()) {
        node->is_leaf = true;
        // 取样本最多的类别作为结果
        std::map<int, int> label_count;
        for (const auto& sample : dataset) {
            label_count[sample.label]++;
        }
        int max_count = 0;
        for (const auto& pair : label_count) {
            if (pair.second > max_count) {
                max_count = pair.second;
                node->result = pair.first;
            }
        }
        return node;
    }
    // 选择最优划分特征
    int best_feature = choose_best_feature(dataset, remain_features);
    node->split_feature = best_feature;
    // 移除已使用的特征
    remain_features.erase(std::remove(remain_features.begin(), remain_features.end(), best_feature), remain_features.end());
    // 按特征取值划分子集,递归构建子节点
    std::map<double, std::vector<Sample>> feature_subsets;
    for (const auto& sample : dataset) {
        double feature_val = sample.features[best_feature];
        feature_subsets[feature_val].push_back(sample);
    }
    for (const auto& pair : feature_subsets) {
        TreeNode* child = build_decision_tree(pair.second, remain_features);
        node->children[pair.first] = child;
    }
    return node;
}

预测功能实现

构建完成的决策树可以用于新样本的分类预测,从根节点开始根据特征取值向下遍历,直到到达叶子节点得到结果。

// 使用决策树进行预测
int predict(TreeNode* root, const std::vector<double>& features) {
    TreeNode* curr = root;
    while (!curr->is_leaf) {
        double feature_val = features[curr->split_feature];
        if (curr->children.find(feature_val) != curr->children.end()) {
            curr = curr->children[feature_val];
        } else {
            // 如果特征取值未出现过,返回-1表示无法预测
            return -1;
        }
    }
    return curr->result;
}

完整测试示例

下面构造一个简单的测试数据集,验证决策树的构建和预测功能是否正常。

int main() {
    // 构造测试数据集,两个特征,两个类别
    std::vector<Sample> dataset = {
        {{1.0, 2.0}, 0},
        {{1.0, 3.0}, 0},
        {{2.0, 2.0}, 1},
        {{2.0, 3.0}, 1},
        {{1.0, 1.0}, 0},
        {{2.0, 1.0}, 1}
    };
    // 初始剩余特征为0和1
    std::vector<int> remain_features = {0, 1};
    // 构建决策树
    TreeNode* root = build_decision_tree(dataset, remain_features);
    // 测试预测
    std::vector<double> test_sample1 = {1.0, 2.5};
    std::vector<double> test_sample2 = {2.0, 1.5};
    std::cout << "Test sample 1 predict: " << predict(root, test_sample1) << std::endl;
    std::cout << "Test sample 2 predict: " << predict(root, test_sample2) << std::endl;
    // 实际项目中需要添加节点内存释放逻辑,这里省略
    return 0;
}

优化与扩展

上述实现是基础版本,实际使用中可以根据需求扩展:比如支持连续特征的划分,通过遍历特征值选择最优划分点;添加剪枝逻辑避免过拟合;支持回归任务,将叶子节点的结果改为连续值的平均值等。同时需要注意内存管理,递归构建的树节点在不需要时应及时释放,避免内存泄漏。

C++决策树机器学习信息增益修改时间:2026-07-01 23:12:22

免责声明:​ 已尽一切努力确保本网站所含信息的准确性。网站内容多为原创整理与精心编撰,观点力求客观中立。本站旨在免费分享,内容仅供个人学习、研究或参考使用。若引用了第三方作品,版权归原作者所有。如内容涉及您的权益,请联系我们处理。
内容垂直聚焦
专注技术核心技术栏目,确保每篇文章深度聚焦于实用技能。从代码技巧到架构设计,为用户提供无干扰的纯技术知识沉淀,精准满足专业提升需求。
知识结构清晰
覆盖从开发到部署的全链路。AI、前端、编程、数据库、服务器、建站、系统层层递进,构建清晰学习路径,帮助用户系统化掌握开发与运维所需的核心技术。
深度技术解析
拒绝泛泛而谈,深入技术细节与实践难点。无论是数据库优化还是服务器配置,均结合真实场景与代码示例进行剖析,致力于提供可直接应用于工作的解决方案。
专业领域覆盖
精准对应开发生命周期。从前端界面到后端编程,从数据库操作到服务器运维,形成完整闭环,一站式满足全栈工程师和运维人员的技术需求。
即学即用高效
内容强调实操性,步骤清晰、代码完整。用户可根据教程直接复现和应用于自身项目,显著缩短从学习到实践的距离,快速解决开发中的具体问题。
持续更新保障
专注既定技术方向进行长期、稳定的内容输出。确保各栏目技术文章持续更新迭代,紧跟主流技术发展趋势,为用户提供经久不衰的学习价值。