// 三点二次插值法代码 (three-point quadratic interpolation method: source code)
#include <iostream>
#include <cmath>
#include <random>
#include <ctime>
// Global seed counter: get_alpha() increments it before every draw, so
// successive draws are generated from different seeds.
int SEED{0};
// 课本P137第6题函数
double f(double t) {
return 1 - t * exp(- t * t);
}
// Test function from Example 3.3.2 on p.114 of the textbook:
//     f1(t) = t^3 - 3t + 2   (local minimum at t = 1, where f1(1) = 0)
double f1(double t) {
    const double cubic = t * t * t;
    const double linear = 3 * t;
    return cubic - linear + 2;
}
// 产生区间内的随机数
double get_alpha(double a, double b) {
SEED++; // 每调用一次函数改变一次种子,防止产生相同的随机情况
std::default_random_engine e; // 新建随机数引擎对象
e.seed(SEED); // 撒下种子
// 在此确定随机数区间
std::uniform_real_distribution<double> u(a,b); // 左闭右闭区间[a,b]
double alpha = u(e);
return alpha;
}
// Step 0 of the three-point quadratic interpolation method: pick a starting
// bracket.
//
// Repeatedly draws alpha1 from [a, b), alpha2 at or after alpha1, and alpha3
// at or after alpha2 (each via get_alpha). Each point is evaluated with f.
// The loop stops once f(alpha1) > f(alpha2) and f(alpha3) > f(alpha2).
// The points and their function values are written through the six output
// pointers.
//
// NOTE: there is no retry limit. If no such "high-low-high" triple exists in
// [a, b) (e.g. f is monotone there), this never returns.
void get3alpha(double(*f)(double), double a, double b, double *alpha1, double *alpha2, double *alpha3, double *f1, double *f2, double *f3) {
    double &left = *alpha1;
    double &mid = *alpha2;
    double &right = *alpha3;
    double &f_left = *f1;
    double &f_mid = *f2;
    double &f_right = *f3;
    do {
        left = get_alpha(a, b);
        mid = get_alpha(left, b);
        right = get_alpha(mid, b);
        f_left = f(left);
        f_mid = f(mid);
        f_right = f(right);
    } while (!(f_left > f_mid && f_right > f_mid));
}
/*!
 * Three-point quadratic interpolation for minimising a one-dimensional
 * function on [a, b].
 *
 * Starts from a random bracket alpha1 < alpha2 < alpha3 (chosen by get3alpha)
 * with f(alpha1) > f(alpha2) < f(alpha3). Each iteration moves to the vertex
 * of the parabola through the three points. It then keeps the three points
 * that surround the lowest value found so far.
 *
 * @param f         objective function
 * @param a         left end of the initial search interval
 * @param b         right end of the initial search interval
 * @param epsilon1  stopping tolerance on |f(alpha2) - f(alpha_hat)|; relative
 *                  to |f(alpha2)| when |f(alpha2)| >= epsilon2, absolute
 *                  otherwise
 * @param epsilon2  magnitude of f(alpha2) below which the absolute test is used
 * @param max_iter  upper bound on the number of interpolation steps
 * @return          the abscissa of the best point found, i.e. the estimate of
 *                  the minimiser x* (not its function value)
 */
double three_point_interpolation(double(*f)(double), double a, double b,
                                 double epsilon1 = 0.001,
                                 double epsilon2 = 0.00001,
                                 int max_iter = 1000) {
    // Step 0: random bracketing triple (get3alpha retries until f1 > f2 < f3).
    double alpha1 = 0.0, alpha2 = 0.0, alpha3 = 0.0;
    double f1 = 0.0, f2 = 0.0, f3 = 0.0;
    get3alpha(f, a, b, &alpha1, &alpha2, &alpha3, &f1, &f2, &f3);

    for (int iteration = 0; iteration < max_iter; ++iteration) {
        // Step 1: vertex of the parabola through (alpha_i, f_i), i = 1..3.
        const double numerator = (alpha2 * alpha2 - alpha3 * alpha3) * f1
                               + (alpha3 * alpha3 - alpha1 * alpha1) * f2
                               + (alpha1 * alpha1 - alpha2 * alpha2) * f3;
        const double denominator = (alpha2 - alpha3) * f1
                                 + (alpha3 - alpha1) * f2
                                 + (alpha1 - alpha2) * f3;
        if (denominator == 0.0) {
            // Degenerate (collinear or coincident) points: the parabola has
            // no vertex, so alpha2 is already the best available estimate.
            break;
        }
        const double alpha_hat = 0.5 * numerator / denominator;
        const double f_hat = f(alpha_hat);

        // Step 5 must compare the new point with the best value BEFORE this
        // update. After an improving step f2 becomes f_hat, so testing against
        // the updated f2 would always report convergence immediately.
        const double f2_before = f2;

        // Steps 2-4: keep the three points that still bracket the minimum.
        if (alpha_hat > alpha2) {
            // Step 3: the new point lies to the right of alpha2.
            if (f_hat <= f2) {
                alpha1 = alpha2;
                f1 = f2;
                alpha2 = alpha_hat;
                f2 = f_hat;
            } else {
                alpha3 = alpha_hat;
                f3 = f_hat;
            }
        } else {
            // Step 4: the new point lies at or to the left of alpha2.
            if (f_hat <= f2) {
                alpha3 = alpha2;
                f3 = f2;
                alpha2 = alpha_hat;
                f2 = f_hat;
            } else {
                alpha1 = alpha_hat;
                f1 = f_hat;
            }
        }

        // Step 5: stop when the function value has (relatively) settled.
        const double change = std::abs(f2_before - f_hat);
        const double tolerance = std::abs(f2_before) >= epsilon2
                                     ? epsilon1 * std::abs(f2_before)
                                     : epsilon1;
        if (change <= tolerance) break;
    }

    // alpha2 holds the lowest function value evaluated since the bracket was
    // chosen (f2 only ever decreases, and replaced points have larger values),
    // so it is the estimate of x*.
    return alpha2;
}
// Entry point: run the interpolation search for f on [0, 1] and print the
// returned point as x* together with f evaluated there.
int main() {
    const double x_star = three_point_interpolation(f, 0, 1);
    std::cout << "x* = " << x_star << '\n'
              << "f(x*) = " << f(x_star) << std::endl;
    return 0;
}
// 结果 (program output):
|