IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Java学习第52-53天:KNN分类器(续) -> 正文阅读

[人工智能]Java学习第52-53天:KNN分类器(续)

第52天

主要任务:

1.重新实现 computeNearests, 仅需要扫描一遍训练集, 即可获得 k kk 个邻居. 提示: 现代码与插入排序思想相结合.
2.增加 setDistanceMeasure() 方法.
3.增加 setNumNeighors() 方法.

代码分别如下:

 /**
     * @Description: 计算最近的k个邻居。在每一轮扫描中选择一个邻居
     * @Param: [paraIndex]
     * @return: int[]
     */
    public int[] computeNearests(int paraCurrent) {
        int[] resultNearests = new int[numNeighbors];
        boolean[] tempSelected = new boolean[trainingSet.length];
        double tempDistance;
        double tempMinimalDistance;
        int tempMinimalIndex = 0;

        /*//选择最近的k个索引
        for (int i = 0; i < numNeighbors; i++) {
            tempMinimalDistance = Double.MAX_VALUE;

            for (int j = 0; j < trainingSet.length; j++) {
                if (tempSelected[j]) {
                    continue;
                }

                tempDistance = distance(paraCurrent, trainingSet[j]);
                if (tempDistance < tempMinimalDistance) {
                    tempMinimalDistance = tempDistance;
                    tempMinimalIndex = j;
                }
            }

            resultNearests[i] = trainingSet[tempMinimalIndex];
            tempSelected[tempMinimalIndex] = true;
        }*/

        //使用直接插入排序
        //创建一个临时二维数组去存储距离
        double[][] tempDistanceArray = new double[trainingSet.length][2];
        tempDistanceArray[0][0] = 0;
        tempDistanceArray[0][1] = distance(paraCurrent, trainingSet[0]);
        int j;
        for (int i = 1; i < trainingSet.length; i++) {
            tempDistance = distance(paraCurrent, trainingSet[i]);
            for (j = i - 1; j >= 0; j--) {
                if (tempDistance < tempDistanceArray[j][1]) {
                    tempDistanceArray[j + 1] = tempDistanceArray[j];
                } else {
                    break;
                }
            }
            tempDistanceArray[j + 1][0] = i;
            tempDistanceArray[j + 1][1] = tempDistance;
        }

        for (int i = 0; i < numNeighbors; i++) {
            resultNearests[i] = trainingSet[(int)tempDistanceArray[i][0]];
        }

        System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
        return resultNearests;
    }
 /**
     * @Description: 选择距离计算方式
     * @Param: [paraType:0 or 1]
     * @return: void
     */
    public void setDistanceMeasure(int paraType) {
        if (paraType == 0) {
            distanceMeasure = MANHATTAN;
        } else if (paraType == 1) {
            distanceMeasure = EUCLIDEAN;
        } else {
            System.out.println("Wrong Distance Measure!!!");
        }
    }

    public static void main(String[] args) {
        KnnClassification tempClassifier = new KnnClassification("D:\\JAVA\\学习\\data_set\\iris.arff");
        tempClassifier.setDistanceMeasure(1);
        tempClassifier.splitTrainingTesting(0.8);
        tempClassifier.predict();
        System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
    }
/**
     * @Description: 设置邻居数量
     * @Param: [paraNumNeighbors]
     * @return: void
     */
    public void setNumNeighbors(int paraNumNeighbors) {
        if (paraNumNeighbors > dataset.numInstances()) {
            System.out.println("The number of neighbors is bigger than the number of dataset!!!");
            return;
        }

        numNeighbors = paraNumNeighbors;
    }

    public static void main(String[] args) {
        KnnClassification tempClassifier = new KnnClassification(":\\JAVA\\学习\\data_set\\iris.arff"");
        tempClassifier.setDistanceMeasure(1);
        tempClassifier.setNumNeighbors(8);
        tempClassifier.splitTrainingTesting(0.8);
        tempClassifier.predict();
        System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
    }

运行结果?

?

?

第53天?

主要任务:

  1. 增加 weightedVoting() 方法, 距离越短话语权越大. 支持两种以上的加权方式.
  2. 实现 leave-one-out 测试

?代码分别如下:

 /**
     * @Description: 距离权重投票
     * 距离越短,话语权越大,直接设置个权值根据顺序来增强其关联性
     * @Param: [paraCurrent, paraNeighbors]
     * @return: int
     */
    public int weightedVoting(int paraCurrent, int[] paraNeighbors) {
        //numClasses?? 计算每一个类型出现的次数
        double[] tempVotes = new double[dataset.numClasses()];

        double tempDistance;
        int a = 1, b = 1;
        //这样距离越短则值就越大
        for (int i = 0; i < paraNeighbors.length; i++) {
            tempDistance = distance(paraCurrent, paraNeighbors[i]);
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]
                    += getWeightedNum(a, b, tempDistance);
        }

        int tempMaximalVotingIndex = 0;
        double tempMaximalVoting = 0;
        for (int i = 0; i < dataset.numClasses(); i++) {
            if (tempVotes[i] > tempMaximalVoting) {
                tempMaximalVoting = tempVotes[i];
                tempMaximalVotingIndex = i;
            }
        }

        return tempMaximalVotingIndex;
    }

    /**
     * @Description: 获取权重,利用反函数
     * @Param: [a, b, paraDistance]
     * @return: double
     */
    public double getWeightedNum(int a, int b, double paraDistance) {
        return b / (paraDistance + a);
    }
/** 
    * @Description: 留一交叉验证
     * 留一法交叉验证是一种用来训练和测试分类器的方法,会用到图像数据集里所有的数据,假定数据集有N个样本(N1、N2、...Nn),
     * 将这个样本分为两份,第一份N-1个样本用来训练分类器,另一份1个样本用来测试,
     * 如此从N1到Nn迭代N次,所有的样本里所有对象都经历了测试和训练。
    * @Param: []
    * @return: void
    */
    public void leave_one_out() {
        int tempSize = dataset.numInstances();
        int[] tempIndices = getRandomIndices(tempSize);
        double tempCorrect = 0;
        for (int i = 0; i < tempSize; i++) {
            trainingSet = new int[tempSize - 1];
            testingSet = new int[1];

            int tempIndex = 0;
            for (int j = 0; j < tempSize; j++) {
                if (j == i) {
                    continue;
                }
                trainingSet[tempIndex++] = tempIndices[j];
            }

            testingSet[0] = tempIndices[i];

            this.predict();

            if (predictions[0] == dataset.instance(testingSet[0]).classValue()) {
                tempCorrect++;
            }
        }

        System.out.println("正确率为:" + tempCorrect / tempSize);
    }

    public static void main(String[] args) {
        KnnClassification tempClassifier = new KnnClassification(":\\JAVA\\学习\\data_set\\iris.arff"");
        tempClassifier.setDistanceMeasure(1);
        tempClassifier.setNumNeighbors(8);
        tempClassifier.splitTrainingTesting(0.8);
        tempClassifier.predict();
        System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());


        System.out.println("\r\n-------leave_one_out-------");
        tempClassifier.leave_one_out();
    }

?运行结果如下:

?

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-07-14 10:51:52  更:2021-07-14 10:57:02 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年4日历 -2024/4/28 16:56:45-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码