IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch处理CK+数据集 -> 正文阅读

[人工智能]pytorch处理CK+数据集

CK+数据集介绍

?CK+数据库是在 Cohn-Kanade Dataset 的基础上扩展来的,包含表情的label和Action Units 的label。这个数据库包括123个subjects, 593 个 image sequence,每个image sequence的最后一张 Frame 都有action units 的label,而在这593个image sequence中,有327个sequence 有 emotion的 label。这个数据库是人脸表情识别中比较流行的一个数据库,很多文章都会用到这个数据做测试。

?CK+数据库一共有4个压缩文件:

  • extended-cohn-kanade-images.zip:共有123个受试者的593个序列,在峰值帧进行FACS编码。所有序列都是从中性面到峰值表情。
  • Landmarks.zip:所有序列都是AAM跟踪的,每个图像有68个点的Landmarks。
  • FACS_labels.zip:对于每个序列(593),只有1个FACS文件,这是最后一帧(峰值帧)。文件的每一行对应一个特定的AU,然后是强度。
  • Emotion_labels.zip:593个序列中只有327个具有情感序列,情绪类别分别是0=neutral, 1=anger, 2=contempt, 3=disgust, 4=fear, 5=happy, 6=sadness, 7=surprise。

百度云链接:https://pan.baidu.com/s/182ZigVgfhmO-YnLy3ip1dQ
百度云密码:CKCK
在这里插入图片描述

CK+数据集存为.h5文件

?首先对CK+数据集切割后的图片文件保存为为.h5文件,具体代码如下所示:
在这里插入图片描述

#create data and label for CK+
#0=anger,1=disgust,2=fear,3=happy,4=sadness,5=surprise,6=contempt
#contain 135,177,75,207,84,249,54 images
import csv
import os
import numpy as np
import h5py
import skimage.io

ck_path = r'Dataset\CK+'

anger_path = os.path.join(ck_path, 'anger')
disgust_path = os.path.join(ck_path, 'disgust')
fear_path = os.path.join(ck_path, 'fear')
happy_path = os.path.join(ck_path, 'happy')
sadness_path = os.path.join(ck_path, 'sadness')
surprise_path = os.path.join(ck_path, 'surprise')
contempt_path = os.path.join(ck_path, 'contempt')

## Create the list to store the data and label information
data_x = []
data_y = []

datapath = os.path.join('H5File','CK+.h5')
if not os.path.exists(os.path.dirname(datapath)):
    os.makedirs(os.path.dirname(datapath))
    
## order the file, so the training set will not contain the test set (don't random)
files = os.listdir(anger_path)
files.sort()
for filename in files:
    I = skimage.io.imread(os.path.join(anger_path,filename))
    data_x.append(I.tolist())
    data_y.append(0)

files = os.listdir(disgust_path)
files.sort()
for filename in files:
    I = skimage.io.imread(os.path.join(disgust_path,filename))
    data_x.append(I.tolist())
    data_y.append(1)

files = os.listdir(fear_path)
files.sort()
for filename in files:
    I = skimage.io.imread(os.path.join(fear_path,filename))
    data_x.append(I.tolist())
    data_y.append(2)

files = os.listdir(happy_path)
files.sort()
for filename in files:
    I = skimage.io.imread(os.path.join(happy_path,filename))
    data_x.append(I.tolist())
    data_y.append(3)

files = os.listdir(sadness_path)
files.sort()
for filename in files:
    I = skimage.io.imread(os.path.join(sadness_path,filename))
    data_x.append(I.tolist())
    data_y.append(4)

files = os.listdir(surprise_path)
files.sort()
for filename in files:
    I = skimage.io.imread(os.path.join(surprise_path,filename))
    data_x.append(I.tolist())
    data_y.append(5)

files = os.listdir(contempt_path)
files.sort()
for filename in files:
    I = skimage.io.imread(os.path.join(contempt_path,filename))
    data_x.append(I.tolist())
    data_y.append(6)

print(np.shape(data_x))
print(np.shape(data_y))

datafile = h5py.File(datapath,'w')
datafile.create_dataset("data_pixel", dtype='uint8', data = data_x)
datafile.create_dataset("data_label", dtype='int64', data = data_y)
datafile.close()
print('Save data finish!!!')

数据存储为torch.utils.data类型

?将数据集转化为torch.utils.data类型的数据,具体的代码如下所示:

from __future__ import print_function
from PIL import Image
import numpy as np 
import h5py
import torch.utils.data as data

class CK(data.Dataset):
	'''

	Args:
		train (bool, optional): If True, creates dataset from training set, otherwise
			creates from test set.
		transform (callable, optional): A function/transforms that takes in an PIL image
			and returns a transformed version.
		there are 135, 177, 75, 207, 84, 249, 54 images in data;
		we choose 123, 159, 66, 186, 75, 225, 48 images for training;
		we choose 12, 8, 9, 21, 9, 24, 6 images for testing;
		the split are in order according to the fold number.

	'''		

	def __init__(self, split='Training', fold=1, transform=None):
		self.transform = transform
		self.split = split # training set or test set
		self.fold = fold # the k-fold cross validation
		self.data = h5py.File('./H5File/CK+.h5','r',driver='core')

		number = len(self.data['data_label']) # 981
		sum_number = [0, 135, 312, 387, 594, 678, 927, 981] # the sum of class number
		test_number = [12, 18, 9, 21, 9, 24, 6] # the number of each class

		test_index = []
		train_index = []

		for j in range(len(test_number)):
			for k in range(test_number[j]):
				if self.fold != 10: #the last fold start from the last element
					test_index.append(sum_number[j]+(self.fold-1)*test_number[j]+k)
				else:
					test_index.append(sum_number[j+1]-1-k)		

		for i in range(number):
			if i not in test_index:
				train_index.append(i)


		#now load the picked numpy arrays
		if self.split == 'Training':
			self.train_data = []
			self.train_labels = []
			for ind in range(len(train_index)):
				self.train_data.append(self.data['data_pixel'][train_index[ind]])
				self.train_labels.append(self.data['data_label'][train_index[ind]])


		elif self.split == 'Testing':
			self.test_data = []
			self.test_labels = []
			for ind in range(len(test_index)):
				self.test_data.append(self.data['data_pixel'][test_index[ind]])
				self.test_labels.append(self.data['data_label'][test_index[ind]])


	def __getitem__(self, index):
		if self.split == 'Training':
			img, target = self.train_data[index], self.train_labels[index]
		elif self.split == 'Testing':
			img, target = self.test_data[index], self.test_labels[index]
		img = img[:,:,np.newaxis]
		img = np.concatenate((img,img,img),axis=2)
		img = Image.fromarray(img)
		if self.transform is not None:
			img = self.transform(img)
		return img, target

	def __len__(self):
		if self.split == 'Training':
			return len(self.train_data)
		elif self.split == 'Testing':
			return len(self.test_data)

if __name__ == '__main__':
	import transforms as transforms
	transform_train = transforms.Compose([
		transforms.ToTensor(),
	])
	# transform_train = None
	data = CK(split = 'Training', transform = transform_train)
	for i in range(3):
		print(data.__getitem__(i))
	print(data.__len__())
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-12 23:26:02  更:2021-10-12 23:29:34 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 12:32:56-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码