任务
用C++和GDAL实现K均值分类
原理
关于K均值算法的帖子很多,原理也不复杂(参见:KMeans 算法):先选定 K 个初始聚类中心,把每个像元划分到距离最近的中心,再用各类像元的均值更新聚类中心,如此反复,直到聚类中心不再变化或达到最大迭代次数。
代码
程序结构:
主函数 main.cpp
#include <iostream>
#include <gdal.h>
#include "gdal_priv.h"
#include "func.h"
using namespace std;
int main()
{
GDALAllRegister();
char ImgPath[] = "D:\\Practice\\org\\exp8_Img.tif";
char SavePath[] = "D:\\Practice\\res\\exp8_classification.tif";
GDALDataset* ImgSet = (GDALDataset*)GDALOpen(ImgPath, GA_ReadOnly);
int mX = ImgSet->GetRasterXSize();
int mY = ImgSet->GetRasterYSize();
GDALDataType mDataType = ImgSet->GetRasterBand(1)->GetRasterDataType();
unsigned char* Reslut = new unsigned char[mX*mY];
int Num = 3;
mKmeans(ImgSet, Reslut, Num);
GDALDriver* hDriver = GetGDALDriverManager()->GetDriverByName("GTiff");
GDALDataset* mSaveSet = hDriver->Create(SavePath, mX, mY, 1, GDT_Byte, NULL);
double geoInformation[6];
ImgSet->GetGeoTransform(geoInformation);
const char* gdalProjection = ImgSet->GetProjectionRef();
mSaveSet->SetGeoTransform(geoInformation);
mSaveSet->SetProjection(gdalProjection);
mSaveSet->GetRasterBand(1)->RasterIO(GF_Write, 0, 0, mX, mY, Reslut, mX, mY, GDT_Byte, 0, 0);
delete[] Reslut;
GDALClose(ImgSet);
GDALClose(mSaveSet);
GDALDestroyDriverManager();
cout << "程序运行完毕!!!" << endl;
getchar();
return 0;
}
头文件 func.h
#pragma once
#include "gdal.h"
#include "gdal_priv.h"
#include <iostream>
// NOTE(review): a using-directive in a header leaks `std` into every file that
// includes it; remove only after checking that no includer relies on
// unqualified std names (cout, endl, ...).
using namespace std;
/// K-means classification of raster band 1 of `mImgSet` into `Num` classes.
/// Writes one byte per pixel into `Reslut`; class k is stored as the gray
/// level 255*k/(Num-1).
/// @param mImgSet  opened GDAL dataset; band 1 is classified
/// @param Reslut   caller-allocated output buffer of at least width*height
///                 bytes (the name is a misspelling of "Result")
/// @param Num      number of clusters
/// @return true once the classification has finished
bool mKmeans(GDALDataset* mImgSet, unsigned char* Reslut, int Num);
函数文件func.cpp
#include "func.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>
namespace
{
// Output gray level (0-255) for cluster index `tag` out of `Num` clusters:
// clusters are spread evenly from 0 to 255, truncated to an integer.
unsigned char clusterGray(int tag, int Num)
{
    if (Num <= 1)
    {
        return 0;  // single cluster: avoids the 0/0 of 255.0 * tag / (Num - 1)
    }
    return static_cast<unsigned char>(255.0 * tag / (Num - 1));
}

// Prints the current centers on one line as "Center[0]= a Center[1]= b ...".
void printCenters(const std::vector<double>& Center)
{
    for (std::size_t j = 0; j < Center.size(); ++j)
    {
        if (j > 0)
        {
            std::cout << " ";
        }
        std::cout << "Center[" << j << "]= " << Center[j];
    }
    std::cout << std::endl;
}
}  // namespace

/// Unsupervised 1-D k-means classification of raster band 1 of `mImgSet`.
///
/// Pixels are read as double (GDAL converts any band data type). `Num`
/// centers are initialised evenly over the band's [min, max] range. The
/// assign-to-nearest / recompute-mean loop then runs until no center changes
/// or 50 iterations have passed. Ties go to the lower cluster index, and an
/// empty cluster keeps its previous center.
/// Each pixel's cluster index k is written to `Reslut` as the gray level
/// 255*k/(Num-1) (0 when Num == 1).
/// NOTE(review): nodata / NaN pixels are not excluded — TODO if needed.
///
/// @param mImgSet  opened dataset; only band 1 is used
/// @param Reslut   caller-owned buffer of at least width*height bytes
/// @param Num      number of clusters, must be >= 1
/// @return false on invalid arguments or a read failure; true otherwise
///         (also when the iteration limit is reached without convergence)
bool mKmeans(GDALDataset* mImgSet, unsigned char* Reslut, int Num)
{
    if (mImgSet == nullptr || Reslut == nullptr || Num < 1)
    {
        std::cout << "错误! 参数无效" << std::endl;
        return false;
    }
    GDALRasterBand* mBand = mImgSet->GetRasterBand(1);
    if (mBand == nullptr)
    {
        std::cout << "错误! 影像没有波段" << std::endl;
        return false;
    }

    const int mX = mImgSet->GetRasterXSize();
    const int mY = mImgSet->GetRasterYSize();
    const std::size_t nPixels = static_cast<std::size_t>(mX) * static_cast<std::size_t>(mY);
    if (nPixels == 0)
    {
        return false;
    }

    // Read as Float64 and let GDAL convert. The previous buffer was sized with
    // the GDALDataType enum value and every pixel was read back as one byte,
    // which is wrong for any non-Byte band.
    std::vector<double> pixels(nPixels);
    if (mBand->RasterIO(GF_Read, 0, 0, mX, mY, pixels.data(), mX, mY, GDT_Float64, 0, 0) != CE_None)
    {
        std::cout << "错误! 读取影像失败" << std::endl;
        return false;
    }

    // Spread the initial centers over the real data range rather than a fixed
    // 0..255; otherwise dark or non-8-bit images start with empty clusters.
    const auto range = std::minmax_element(pixels.begin(), pixels.end());
    const double lo = *range.first;
    const double hi = *range.second;
    const std::size_t k = static_cast<std::size_t>(Num);
    std::vector<double> Center(k);
    for (int j = 0; j < Num; j++)
    {
        Center[j] = (Num == 1) ? lo : lo + (hi - lo) * j / (Num - 1);
    }

    // Per-cluster accumulators, fully reset at the start of every iteration.
    std::vector<double> sum(k);
    std::vector<std::size_t> cnt(k);
    const int max_iter = 50;
    for (int iter = 0;; iter++)
    {
        printCenters(Center);
        std::fill(sum.begin(), sum.end(), 0.0);
        std::fill(cnt.begin(), cnt.end(), std::size_t{0});

        // Assignment step: nearest center; strict '<' keeps the lowest index on ties.
        for (std::size_t i = 0; i < nPixels; i++)
        {
            const double v = pixels[i];
            int tag = 0;
            double minDist = std::fabs(v - Center[0]);
            for (int j = 1; j < Num; j++)
            {
                const double d = std::fabs(v - Center[j]);
                if (d < minDist)
                {
                    minDist = d;
                    tag = j;
                }
            }
            sum[tag] += v;
            cnt[tag]++;
            Reslut[i] = clusterGray(tag, Num);
        }

        // Update step. Exact comparison is intentional: once assignments stop
        // changing, the means are recomputed identically.
        bool converged = true;
        for (int j = 0; j < Num; j++)
        {
            if (cnt[j] == 0)
            {
                std::cout << "错误! 第 " << j << " 类为空, 保留原聚类中心" << std::endl;
                continue;  // keep the previous center instead of an undefined value
            }
            const double updated = sum[j] / static_cast<double>(cnt[j]);
            if (updated != Center[j])
            {
                converged = false;
            }
            Center[j] = updated;
        }
        std::cout << std::endl;
        if (converged)
        {
            std::cout << "分类成功,迭代次数为 " << iter << std::endl;
            break;
        }
        if (iter == max_iter)
        {
            std::cout << "达到最大循环次数,看来是不收敛了!!!" << std::endl;
            break;
        }
    }
    return true;
}
结果展示
待分类影像: K均值聚类结果:
结语
K均值本身不复杂,编写过程中只需要注意循环体的设计就可以了。
至此,编程练习的全部内容结束了,总共八个实验,有难有易。虽说是为了学习GDAL的使用,但这些编程练习涉及到的GDAL内容主要还是数据的读写,很多函数都没涉及到,只能等以后需要用到的时候再临时学了。
|