IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 数据结构与算法 -> ceres库 拟合曲线 (含高斯牛顿法) -> 正文阅读

[数据结构与算法]ceres库 拟合曲线 (含高斯牛顿法)

Ceres是一个广泛使用的最小二乘问题求解库。
Ceres求解的最小二乘问题最一般的形式如下
min ? x 1 2 ∑ i ρ i ( ∣ ∣ f i ( x i 1 , . . . , x i n ∣ ∣ 2 ) \min_{x}\frac{1}{2}\sum_{i}\rho_{i}(||f_{i}(x_{i1}, ... , x_{in}||^{2}) xmin?21?i?ρi?(fi?(xi1?,...,xin?2)

其中 f i f_{i} fi?为代价函数,在ceres中为残差块, ρ i \rho_{i} ρi?为核函数,如果不用的话,那么目标函数仍然是许多平方项的和。

Ceres求解步骤:

  1. 定义每个参数块,可以是向量,也可以是四元数,李代数。
    如果是向量,需要为每个参数块分配double数组来储存变量的值。
  2. 定义残差块的计算公式。
  3. 残差块需要定义雅可比的计算方式,也可以采用Ceres的自动求导功能。如果用自动求导,需要重载一个()操作符。
  4. 把参数块和残差块加入Ceres的problem对象中,调用Solve函数求解。

示例:
假设要拟合一个曲线 y = e x p ( a x 2 + b x + c ) + w y=exp(ax^{2}+bx+c)+w y=exp(ax2+bx+c)+w
其中w是高斯噪声,现有N个数据(x, y)
那么最小二乘问题如下,这里不用核函数
min ? a , b , c 1 2 ∑ i = 1 N ( ∣ ∣ y i ? e x p ( a x i 2 + b x i + c ) ∣ ∣ 2 ) \min_{a,b,c}\frac{1}{2}\sum_{i=1}^{N}(||y_{i}-exp(ax_{i}^{2}+bx_{i}+c)||^{2}) a,b,cmin?21?i=1N?(yi??exp(axi2?+bxi?+c)2)
误差函数(残差)定义为 e i = y i ? e x p ( a x i 2 + b x i + c ) e_{i}=y_{i}-exp(ax_{i}^{2}+bx_{i}+c) ei?=yi??exp(axi2?+bxi?+c)

(x, y)是观测数据,我们要求解的参数是(a,b,c)
得到拟合的曲线 e x p ( a x 2 + b x + c ) exp(ax^{2}+bx+c) exp(ax2+bx+c)

先 给 出 高 斯 牛 顿 法 的 解 法 , 帮 助 理 解 最 小 二 乘 优 化 问 题 \color{#00F}{先给出高斯牛顿法的解法,帮助理解最小二乘优化问题}

如果直接用高斯牛顿法解的话,需要求误差函数 e i e_{i} ei?分别对a, b, c的偏导,得到雅可比矩阵J = [e对a的偏导,e对b的偏导,e对c的偏导]T
用JJT近似Hessian矩阵,令b=Je (e是 e i e_{i} ei?
直接给出增量解为 H Δ x = b H\Delta x = b HΔx=b
这个 Δ x = [ Δ a , Δ b , Δ c ] \Delta x = [\Delta a, \Delta b, \Delta c] Δx=[Δa,Δb,Δc]
然后每次对a,b,c调整 Δ x \Delta x Δx,直到误差 e i e_{i} ei?收敛且不再增大。

int gaussNewton() {
    double ar = 1.0, br = 2.0, cr = 1.0;  //真实参数值
    double ae = 2.0, be = -1.0, ce = 5.0;   //估计参数值
    int N = 100;
    double w_sigma = 1.0;  //噪声sigma
    double inv_sigma = 1.0 / w_sigma;
    cv::RNG rng;

    vector<double> x_data, y_data;
    for(int i = 0; i < N; i++) {
        double x = i / 100.0;
        x_data.push_back((x));
        y_data.push_back(exp(ar*x*x + br*x + cr) + rng.gaussian(w_sigma)); //rng:随机数产生,sigma指标准差
    }

    int iteration = 100;
    double cost = 0, lastCost = 0;  //本次迭代和上一次迭代
    chrono::steady_clock::time_point t1 = chrono::steady_clock::now();
    for(int iter = 0; iter < iteration; iter++) {
        Matrix3d H = Matrix3d::Zero();   //Hessian = JJ^T
        Vector3d b = Vector3d::Zero();   //bias
        cost = 0;

        //每次迭代批量累加所有数据的误差
        for(int i = 0; i < N; i++) {
            double xi = x_data[i], yi = y_data[i];
            double error = yi - exp(ae*xi*xi + be*xi + ce);
            Vector3d J;
            J[0] = -xi*xi * exp(ae*xi*xi + be*xi + ce);  //de/da
            J[1] = -xi * exp(ae*xi*xi + be*xi + ce); //de/db
            J[2] = - exp(ae*xi*xi + be*xi + ce);  //de/dc

            H += inv_sigma * inv_sigma * J * J.transpose();
            b += -inv_sigma * inv_sigma * error * J;

            cost += error * error;
        }

        //批量求解方程Hx = b
        Vector3d dx = H.ldlt().solve(b); //用cholesky分解,避免求矩阵的逆
        if(isnan(dx[0])) {
            cout << "result is nan!" << endl;
            break;
        }
        if(iter > 0 && cost >= lastCost) {
            cout << "Cost : " << cost << ">=lastCost: " << lastCost << ", break: " << endl;
            break;  //理论上cost应该是逐渐变小的,变大了说明有问题
        }
        ae += dx[0];  //梯度下降法
        be += dx[1];
        ce += dx[2];

        lastCost = cost;
        cout << "total cost: " << cost << ", \t\tupdate: " << dx.transpose() <<
        "\t\testimated params: " << ae << ", " << be << ", " << ce << endl;
    }
    chrono::steady_clock::time_point t2 = chrono::steady_clock::now();
    chrono::duration<double> time_used = chrono::duration_cast<chrono::duration<double>>(t2 - t1);
    cout << "solve time cost = " << time_used.count() << " seconds." << endl;
    cout << "estimated abc = " << ae << ", " << be << ", " << ce << endl;
}

解出参数

abc = 0.890912, 2.1719, 0.943629

下面用Ceres解同一问题

//代价函数
struct CURVE_FITTING_COST {
    CURVE_FITTING_COST(double x, double y) : _x(x), _y(y) {}  //构造函数,使用初始化列表来初始化
    //残差的计算
    template<typename T> bool operator()(  //重载()运算符
            const T *const abc, //模型参数,3维
            T *residual) const{
        //y-exp(ax^2 + bx + c)
        residual[0] = T(_y) - ceres::exp(abc[0] * T(_x) * T(_x) + abc[1] * T(_x) + abc[2]);
        return true;
    }
    const double _x, _y;
};

int ceres_fitting(){
    double ar = 1.0, br = 2.0, cr = 1.0;  //真实参数值
    double ae = 2.0, be = -1.0, ce = 5.0;  //估计参数值
    int N = 100;  //数据点
    double w_sigma = 1.0;   //噪声sigma
    double inv_sigma = 1.0 / w_sigma;
    cv::RNG rng;  //随机数产生器

    vector<double> x_data, y_data;  //数据
    for(int i = 0; i < N; i++) {
        double x = i / 100.0;
        x_data.push_back((x));
        y_data.push_back(exp(ar*x*x + br*x + cr) + rng.gaussian(w_sigma)); //rng:随机数产生,sigma指标准差
    }

    double abc[3] = {ae, be, ce};

    //构建最小二乘问题
    ceres::Problem problem;
    //所有数据
    for(int i = 0; i < N; i++) {
        problem.AddResidualBlock(  //添加误差项
                //使用自动求导,模板参数:误差类型,输出维度,输入维度,维度要与struct中一致
                new ceres::AutoDiffCostFunction<CURVE_FITTING_COST, 1, 3> (
                        new CURVE_FITTING_COST(x_data[i], y_data[i])  //定义的残差函数
                        ),
                nullptr,   //核函数,这里不使用,为空
                abc     //待估计参数
                );
    }

    //配置求解器
    ceres::Solver::Options options;
    options.linear_solver_type = ceres::DENSE_NORMAL_CHOLESKY;  //增量方程如何求解
    options.minimizer_progress_to_stdout = true;  //输出到cout

    ceres::Solver::Summary summary;  //优化信息
    chrono::steady_clock::time_point t1 = chrono::steady_clock::now();
    ceres::Solve(options, &problem, &summary);  //开始优化
    chrono::steady_clock::time_point t2 = chrono::steady_clock::now();
    chrono::duration<double> time_used = chrono::duration_cast<chrono::duration<double>>(t2 - t1);
    cout << "solve time cost = " << time_used.count() << " seconds." << endl;

    //输出结果
    cout << summary.BriefReport() << endl;
    cout << "estimated abc = ";
    for(auto a:abc) cout << a << " ";
    cout << endl;

    return 0;
}

解出参数

estimated abc = 0.890908 2.1719 0.943628

参考链接

  数据结构与算法 最新文章
【力扣106】 从中序与后续遍历序列构造二叉
leetcode 322 零钱兑换
哈希的应用:海量数据处理
动态规划|最短Hamilton路径
华为机试_HJ41 称砝码【中等】【menset】【
【C与数据结构】——寒假提高每日练习Day1
基础算法——堆排序
2023王道数据结构线性表--单链表课后习题部
LeetCode 之 反转链表的一部分
【题解】lintcode必刷50题<有效的括号序列
上一篇文章      下一篇文章      查看所有文章
加:2022-03-04 15:50:05  更:2022-03-04 15:52:28 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 2:16:29-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码