IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 用python实现经验模态分解+小波软阈值去噪 -> 正文阅读

[人工智能]用python实现经验模态分解+小波软阈值去噪

PyEmd模块安装

试过很多博主说的pip insyall PyEmd都失败了,偶然间运气好发现正确的安装方式是pip install PyEmd-signal。如果找不到相关的库或者模块,直接去github上去搜索,上面有很详细的安装教程,不要被误导

pywt模块安装

pywt可以实现小波分解与重构,小波阈值降噪,小波包分解等功能,同样安装也是用相应的pip instal pywt来进行安装,如果找不到还是去github上寻找。

特别说明

关于EMD类方法和小波阈值降噪的相关理论知识可直接百度或者在知网找几篇硕博论文来看,里面有详细的推导过程。

不要祈求能推导明白相关公式,这些公式很复杂一边人理解不了,个人建议了解相关算法流程就可以了,相关的优化也是不同方法之间的排列组合,还是要多多尝试。

以上只是个人看法,不喜欢可以划走,别在这体现自己的优越感,我写这文章的目的就是单纯的记录!!!!

代码实现

由于写代码的时候,个人的理论了解程度仅仅停留在入门阶段,IMF分量(EMD类方法分解得到的分量)的选择是凭借个人感觉来选择的,正确的做法是计算多尺度排列熵(github上可以找到相关的模块,个人正在研究相关代码)、相关系数等来进行选择。

个人的玩具代码,有很多不严谨的地方。

EMD类方法实现

import numpy as np
from PyEMD import EMD,EEMD,CEEMDAN,Visualisation
from threshold import Threshold
from matplotlib import pyplot as plt
import pywt
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei']

def read_txt_file(input_file_path):
    """该函数主要用来从txt文件中读取所需数据,并转换数据类型
    输入为待处理文件的路径
    输出为一个存放txt文件数据的列表"""
    file_list = []
    file = open(input_file_path)
    file_lines = file.readline()
    file_lines = list(file_lines)
    file_lines.pop(0)
    file_lines.pop(-1)
    file_lines = ''.join(file_lines)
    cur = file_lines.strip().split(",")
    for i in range(0,len(cur)):
        file_list.append(float(cur[i]))
    #print(file_list)
    return file_list

class EmdFunction:
    """
    emd方法的调用
    """
    def __init__(self,data,function_name,sym,level,imfs_start_step,thr_select,thr_way):

        X = np.array(data)
        self.signal = (X - np.mean(X)) / np.std(X)
        self.function_name = function_name
        self.soft_threshold = Threshold(data,thr_select,sym,level,thr_way)
        self.imfs_start_step = imfs_start_step


    def emd_completed(self):

        if self.function_name == 'EMD':
            emd = EMD()
            emd.emd(self.signal)
            ims,res = emd.get_imfs_and_residue()
            return ims,res
        elif self.function_name == 'EEMD':
            eemd = EEMD()
            eemd.eemd(self.signal)
            ims,res = eemd.get_imfs_and_residue()
            return ims,res
        elif self.function_name == 'CEEMDAN':
            ceemdan = CEEMDAN()
            ceemdan.ceemdan(self.signal)
            ims,res = ceemdan.get_imfs_and_residue()
            return ims,res

    def plot_imfs_and_res(self,imfs,res):
        t = np.arange(0,len(self.signal),1)
        vis = Visualisation()
        vis.plot_imfs(imfs=imfs,residue=res,t=t,include_residue=True)
        vis.show()

    def wavelet_and_emd(self):

        useful_imfs_add = np.zeros(len(self.signal)).tolist()
        imfs,res = self.emd_completed()
        for i in range(self.imfs_start_step,len(imfs)):
            useful_imfs_add += imfs[i]

        for j in range(0,self.imfs_start_step):
            data = imfs[j]
            #mid_param = self.soft_threshold.wavelet_dec_rec(data)
            mid_param = self.soft_threshold.wavelet_dec_rec(data)
            useful_imfs_add  += mid_param
        return useful_imfs_add

    def plot_org_sotfthreshold(self):

        end_signal = self.wavelet_and_emd()
        snr = self.soft_threshold.compute_snr(self.signal,end_signal)
        rmse = self.soft_threshold.compute_mse(self.signal,end_signal)
        print('snr:{} , rmse:{}'.format(snr, rmse))
        figure,(ax1,ax2) = plt.subplots(nrows=2,ncols=1)
        ax1.plot(self.signal, label='org signal')
        ax1.set_title('降噪前的信号')
        ax1.legend()

        ax2.plot(end_signal, 'g',label='after wavele')
        ax2.set_title('降噪后的信号')
        ax2.legend()
        plt.show()



if __name__ == "__main__":
    path = 'D:\桌面文件夹\数据文件/0001001228.txt'
    #path = 'D:\桌面文件夹\新建文件夹 (3)/1101000158.txt'

    data = read_txt_file(path)

    function_name = 'CEEMDAN'

    sym = 'sym8'
    level = 3
    imfs_start_step = 4
    thr_select = 'sqtwolog'
    thr_way = 'soft'

    emdfunction = EmdFunction(data,function_name,sym,level,imfs_start_step,thr_select,thr_way)
    # ims,res=emdfunction.emd_completed()
    # emdfunction.plot_imfs_and_res(ims,res)
    emdfunction.plot_org_sotfthreshold()


小波软阈值代码实现

import numpy as np
import os
from matplotlib import pyplot as plt
import pywt
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei']


def read_txt_file(input_file_path):
    """该函数主要用来从txt文件中读取所需数据,并转换数据类型
    输入为待处理文件的路径
    输出为一个存放txt文件数据的列表"""
    file_list = []
    file = open(input_file_path)
    file_lines = file.readline()
    file_lines = list(file_lines)
    file_lines.pop(0)
    file_lines.pop(-1)
    file_lines = ''.join(file_lines)
    cur = file_lines.strip().split(",")
    for i in range(0,len(cur)):
        file_list.append(float(cur[i]))
    #print(file_list)
    return file_list

class Threshold:
    """小波阈值降噪"""
    def __init__(self,data,thr_select,wave_bais='sym8',level=3,thr_way='soft'):

        if type(data) == list:
            X = np.array(data)
            self.data = (X - np.mean(X)) / np.std(X)
        else:
            self.data = (data - np.mean(data)) / np.std(data)

        self.data = data

        self.wave_bais = wave_bais
        self.level = level
        self.thr_way = thr_way
        if thr_select in ['rigrsure','heursure','sqtwolog','minimaxi']:
            self.thr_select = thr_select
        else:
            raise print('取值计算函数名称错误,请重新输入')

    def thrselect(self,data):
        """
        阈值lambda的计算方式选择
        :return: 返回阈值
        """
        N = len(data)
        if self.thr_select == 'sqtwolog':
            #固定阈值
            thr = round(np.sqrt(2.0 * np.log(N)),4)
            return thr

        elif self.thr_select == 'minimaxi':
            #极大极小阈值
            if N<32:
                thr =0
            else:
                thr = 0.3936 + 0.1829*(np.log(N)/np.log(2))
            return thr

        elif self.thr_select == 'rigrsure':
            # #风险阈值
            # sx = np.sort(abs(self.data))
            # sx2 = np.square(sx)
            # N1 = np.repeat(N-2*[i for i in range(0,N)],1)
            pass
            return -1

        elif self.thr_select == 'heursure':
            pass
            return -1

    def wavelet_dec_rec(self,data):
        """小波分解"""
        coffe = pywt.wavedec(data,self.wave_bais,level=self.level)
        #低频分量分量
        ca = coffe[0]
        #高频分量
        cd_out_list = []
        cd_out_list.append(ca)
        #阈值
        thr = self.thrselect(data)
        for i in range(1,len(coffe)):
            cd = coffe[i]
            ysotf = pywt.threshold(cd,thr,self.thr_way)
            cd_out_list.append(ysotf)

        Y = pywt.waverec(cd_out_list,self.wave_bais)
        return Y

    def plot_signal(self,data):
        #获得降噪后的信号
        Y = self.wavelet_dec_rec(data)

        #绘制原始图像
        figure, axes = plt.subplots(2, 1)
        ax1 = axes[0]
        ax1.set_title('降噪前的信号')
        ax1.plot(self.data)
        #绘制降噪的图像
        ax2 = axes[1]
        ax2.set_title('降噪后的信号')
        ax2.plot(Y,color='g')
        plt.show()

    @staticmethod
    def compute_snr(org_signal, final_signal):
        """
        信噪比:信噪比越大越好
        均方根误差:均方根误差越小越好,越小去噪效果越好
        :param org_signal:原始信号
        :param final_signal:降噪后的信号
        :return: 信噪比,均方根误差
        """

        clean = np.array(final_signal)
        org_signal = np.array(org_signal)
        #est_noise = org_signal - clean
        # power_data = np.mean(np.square(data))
        # power_noise = np.mean(np.square(data - final_signal))
        #snr = 10 * np.log10((np.sum(clean ** 2)) / (np.sum(est_noise ** 2)))
        # snr = (math.log((power_data/power_noise),10) )* 10
        sigPower = sum(abs(clean) ** 2) / len(clean)  # 求出信号功率
        noisePower = sum(abs(org_signal - clean) ** 2) / len(org_signal - clean)  # 求出噪声功率
        SNR_10 = 10 * np.log10(sigPower / noisePower)
        #SNR_10 = (sigPower / noisePower)
        return SNR_10

    @staticmethod
    def compute_mse(org_signal, final_signal):
        """
        计算均方根误差:均方根误差越小越好,越小去噪效果越好
        :param org_signal:原始信号
        :param final_signal:降噪后的信号
        :return: 均方根误差
        """
        data = np.array(org_signal)
        final_signal = np.array(final_signal)
        rmse = np.sqrt(np.mean(np.square(data - final_signal)))
        return rmse




if __name__ == "__main__":
    #path = 'D:\桌面文件夹\新建文件夹 (3)/1101000158.txt'
    path = 'D:\桌面文件夹\数据文件/0001001228.txt'

    data = read_txt_file(path)
    data = (data - np.mean(data)) / np.std(data)

    thr_select = 'sqtwolog'
    wave_bais = 'sym8'
    level = 3
    thr_way = 'soft'

    wave = Threshold(data,thr_select,wave_bais,level,thr_way)
    Y = wave.wavelet_dec_rec(data)

    snr = wave.compute_snr(wave.data,Y)
    rmse = wave.compute_mse(wave.data,Y)
    print('snr:{},rmse:{}'.format(snr,rmse))

    wave.plot_signal(data)






  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-01 23:22:47  更:2022-04-01 23:26:24 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 0:22:21-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码