当前位置: 首页 > news >正文

PhysioNet2017分类的代码实现

PhysioNet2017数据集介绍可参考文章:https://wendy.blog.csdn.net/article/details/128686196。本文主要介绍利用PhysioNet2017数据集对其进行分类的代码实现。

目录

    • 一、数据集预处理
    • 二、训练
      • 2.1 导入数据集并进行数据裁剪
      • 2.2 划分训练集、验证集和测试集
      • 2.3 设置训练网络和结构
      • 2.4 开始训练
      • 2.5 查看训练结果
    • 三、测试

一、数据集预处理

首先需要进行数据集预处理。

train2017文件夹中存放相应的训练集,其中REFERENCE.csv文件存放分类结果。分类结果有四种,分别是:N(Normal,正常),A(AF,心房颤动),O(Other,其他节律),~(Noisy,噪声记录)

首先需要划分训练集、验证集和测试集:

# 加载数据集,默认80%训练集和20%测试集
def load_physionet(dir_path, test=0.2,vali=0, shuffle=True):"return train_X, train_y, test_X, test_y, valid_X, valid_y"if dir_path[-1]!='/': dir_path = dir_path+'/'ref = pd.read_csv(dir_path+'REFERENCE.csv',header=None) # 分类结果label_id = {'N':0, 'A':1, 'O':2, '~':3 }#Normal, AF, Other, NoisyX = []y = []test_X = Nonetest_y = Nonevalid_X = Nonevalid_y = Nonefor index, row in ref.iterrows():file_prefix = row[0]mat_file = dir_path+file_prefix+'.mat'hea_file = dir_path+file_prefix+'.hea'data = loadmat(mat_file)['val']data = data.squeeze()data = np.nan_to_num(data)data = data-np.mean(data)data = data/np.std(data)X.append( data )y.append( label_id[row[1]] )data_n = len(y)print(data_n)X = np.array(X)y = np.array(y)if shuffle:shuffle_idx = list(range(data_n))random.shuffle(shuffle_idx)X = X[shuffle_idx]y = y[shuffle_idx]valid_n = int(vali*data_n)  test_n = int(test*data_n)assert (valid_n+test_n <= data_n) , "Dataset has no enough samples!"if vali>0:valid_X = X[0:valid_n]valid_y = y[0:valid_n]if test>0:test_X = X[valid_n: valid_n+test_n]test_y = y[valid_n: valid_n+test_n]if vali>0 or test>0:X = X[valid_n+test_n: ]y = y[valid_n+test_n: ]#print('Train: %d, Test: %d, Validation: %d   (%s)'%((data_n-valid_n-test_n), test_n, valid_n, 'shuffled' if shuffle else 'unshuffled'))return np.squeeze(X), np.squeeze(y), np.squeeze(test_X), np.squeeze(test_y), np.squeeze(valid_X), np.squeeze(valid_y)

加载数据集并将其保存为mat文件:

def merge_data(dir_path, test=0.2, train_file='train',test_file='test',shuffle=True):train_X, train_y, test_X, test_y, _, _ = load_physionet(dir_path=dir_path, test=test, vali=0, shuffle=True) # 划分训练集、验证集和测试集# 数据集8528个记录  8528*0.8=6823,8528*0.2=1705train_data = {'data': train_X, 'label':train_y} # 6823test_data = {'data': test_X, 'label':test_y}    # 1705# 保存训练集和测试集为mat文件savemat(train_file,train_data)savemat(test_file, test_data)print("[!] Train set saved as %s"%(train_file))print("[!] Test set saved as %s"%(test_file))def main():parser = argparse.ArgumentParser()parser.add_argument('--dir',type=str,default='training2017',help='the directory of dataset')parser.add_argument('--test_set',type=float,default=0.2,help='The percentage of test set')args = parser.parse_args()merge_data(args.dir, test=args.test_set)if __name__=='__main__':main()

运行之后将PhysioNet2017心电图数据集保存为train.mat和test.mat。
在这里插入图片描述

二、训练

2.1 导入数据集并进行数据裁剪

时序数据都需要进行相应的数据裁剪。裁剪函数如下:

def cut_and_pad(X, cut_size):n = len(X)X_cut = np.zeros(shape=(n, cut_size))   # (6823,300*30)for i in range(n):data_len = X[i].squeeze().shape[0]  # 每个数据的长度# cut if too long / padd if too shortX_cut[i, :min(cut_size, data_len)] = X[i][0,  :min(cut_size, data_len)] # 每个长度裁剪为cut_size=9000个return X_cut

首先需要将处理后的数据集导入并进行数据裁剪。
训练集的数据尺寸为:(1, 6823);训练集的标签尺寸为:(1, 6823);【总数据量为8528个数据,训练集数据占比80%,即8528*80%=6823】
加载训练集train.mat,进行数据裁剪,裁剪长度为300x30=9000,即前9000个数据。代码如下:

training_set = loadmat('train.mat') # 加载训练集
X = training_set['data'][0]
y = training_set['label'][0].astype('int32')#cut_size_start = 300 * 3
cut_size = 300 * 30X = cut_and_pad(X, cut_size) 

裁剪后可以查看第一个数据的图像:
代码如下:

import matplotlib.pyplot as plt
plt.plot(range(cut_size),X[0])
plt.show()

效果图如下:
在这里插入图片描述

2.2 划分训练集、验证集和测试集

首先需要判断是否进行k折交叉验证,若进行k折交叉验证,下界为0上界为5(5折);若不进行k折交叉验证则下界为0上界为1(默认不进行交叉验证)。

# k-fold / train
if args.k_folder:low_border = 0high_border = 5F1_valid = np.zeros(5)
else:low_border = 0high_border = 1

然后利用get_sub_set函数根据是否进行交叉验证划分训练集和验证集,90%为训练集,10%为验证集。

# 划分训练集和验证集
def get_sub_set(X, y, k, K_folder_or_not):if not K_folder_or_not:     # Falsek_dataset_len = int(len(X) * 0.9)   # 6823*0.9=6140train_X = X[ : k_dataset_len]   # 6140train_y = y[ : k_dataset_len]valid_X = X[ k_dataset_len:]    # 683valid_y = y[ k_dataset_len:]else:k_dataset_len = int(len(X) / 5)if k == 0:valid_X = X[ : k_dataset_len ]valid_y = y[ : k_dataset_len ]train_X = X[ k_dataset_len :]train_y = y[ k_dataset_len :]else:print(k*k_dataset_len)valid_X = X[ k*k_dataset_len : (k+1)*k_dataset_len ]valid_y = y[ k*k_dataset_len : (k+1)*k_dataset_len ]train_X = np.concatenate((X[ : k*k_dataset_len] , X[(k+1)*k_dataset_len: ]), axis=0)train_y = np.concatenate((y[ : k*k_dataset_len] , y[(k+1)*k_dataset_len: ]), axis=0)return train_X, train_y, valid_X, valid_y

输出训练集长度和验证集长度查看信息。
在这里插入图片描述

2.3 设置训练网络和结构

网络架构利用ResNet实现,损失函数使用交叉熵损失函数softmax_cross_entropy,优化器利用Adam优化器实现。

加载模型时,如果有已经训练好的模型,则恢复模型:Model restored from checkpoints;否则,重新训练模型:Restore failed, training new model!

2.4 开始训练

开始训练代码如下:

    # 开始训练while True:total_loss = []ep = ep + 1for itr in range(0,len(train_X),batch_size):# prepare data batchif itr+batch_size>=len(train_X):cat_n = itr+batch_size-len(train_X)cat_idx = random.sample(range(len(train_X)),cat_n)batch_inputs = np.concatenate((train_X[itr:],train_X[cat_idx]),axis=0)batch_labels = np.concatenate((y_onehot[itr:],y_onehot[cat_idx]),axis=0)else:batch_inputs = train_X[itr:itr+batch_size]        batch_labels = y_onehot[itr:itr+batch_size]_, summary, cur_loss = sess.run([opt, merge, loss], {data_input: batch_inputs, label_input: batch_labels})total_loss.append(cur_loss)#if itr % 10==0:#    print('   iter %d, loss = %f'%(itr, cur_loss))#    saver.save(sess, args.ckpt)# 将所有日志写入文件summary_writer.add_summary(summary, global_step=ep)  # 将训练过程数据保存在summary中[train_loss]print('[*] epoch %d, average loss = %f'%(ep, np.mean(total_loss)))if not args.k_folder:saver.save(sess, 'checkpoints/model')# validationif ep % 5 ==0: #and ep!=0:err = 0n = np.zeros(class_num)N = np.zeros(class_num)correct = np.zeros(class_num)valid_n = len(valid_X)for i in range(valid_n):res = sess.run([logits], {data_input: valid_X[i].reshape(-1, cut_size,1)})# print(valid_y[i])# print(res)predicts  = np.argmax(res[0],axis=1)n[predicts] = n[predicts] + 1   N[valid_y[i]] = N[valid_y[i]] + 1if predicts[0]!= valid_y[i]:err+=1else:correct[predicts] = correct[predicts] + 1print("[!] %d validation data, accuracy = %f"%(valid_n, 1.0 * (valid_n - err)/valid_n))res = 2.0 * correct / (N + n)print("[!] Normal = %f, Af = %f, Other = %f, Noisy = %f" % (res[0], res[1], res[2], res[3]))print("[!] F1 accuracy = %f" % np.mean(2.0 * correct / (N + n)))if args.k_folder:F1_valid[k] = np.mean(res)if np.mean(total_loss) < 0.2 and ep % 5 == 0:# 保存内容summary_writer.close()# 将total_loss保存为csvtl = pd.DataFrame(data=total_loss)tl.to_csv('loss.csv')break

2.5 查看训练结果

利用tensorboard可以查看训练的loss损失,损失图像如下:
在这里插入图片描述
loss阈值设置为0.2,最后的准确率如下:
在这里插入图片描述

三、测试

训练完成后,开始测试。
首先需要将处理后的测试集导入并进行数据裁剪。
测试集的数据尺寸为:(1, 1705);测试集的标签尺寸为:(1, 1705);【总数据量为8528个数据,测试集数据占比20%,即8528*20%=1705】
加载测试集test.mat,进行数据裁剪,裁剪长度为300x30=9000,即前9000个数据。代码如下:

training_set = loadmat('test.mat')
X = training_set['data'][0]     # (1705,)
y = training_set['label'][0].astype('int32')    # (1705,)cut_size = 300 * 30
n = len(X)
X_cut = np.zeros(shape=(n, cut_size))
for i in range(n):data_len = X[i].squeeze().shape[0]X_cut[i, :min(cut_size, data_len)] = X[i][0, :min(cut_size, data_len)]
X = X_cut

然后将数据输入训练好的网络进行测试:

# reconstruct model
test_input = tf.placeholder(dtype='float32',shape=(None,cut_size,1))
res_net = ResNet(test_input, class_num=class_num)tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
sess = tf.Session(config=tf_config)sess.run(tf.global_variables_initializer())
saver =  tf.train.Saver(tf.global_variables())# restore model
if os.path.exists(args.check_point_folder + '/'):saver.restore(sess, args.check_point_folder + '/model')print('Model successfully restore from ' + args.check_point_folder + '/model')
else: print('Restore failed. No model found!')

测试结束后,需要查看测试准确率,F1-score等诸多指标,这里首先需要定义三个变量:

PreCount = np.zeros(class_num)  # 每种类型的预测数量
RealCount = np.zeros(class_num) # 每种类型的数量
CorrectCount = np.zeros(class_num)  # 每种类型预测正确数量

PreCount用于存放每种类型的预测结果,RealCount用于存放每种类型的数量,CorrectCount用于存放每种类型预测正确的数量。

最后查看所有结果,F1-score、Accuracy,Precision,Recall,Time结果如下:(这是loss为0.2时的结果)
在这里插入图片描述


ok,以上便是本文的全部内容了,如果想要获取完整代码,可以参考资源:https://download.csdn.net/download/didi_ya/87444631

如果想重新训练,请删除checkpoints文件夹内所有文件和logs文件夹内所有文件(不要删除logs文件夹)并重新运行train.py程序,若不删除,则继续使用之前模型训练,logs文件夹主要用于存放tensorboard可视化图像,若不删除重新运行程序,可能会重新生成可视化图像,影响效果。188行可以指定最终的loss,如果想精确度高,请将loss尽量调小。tensorflow版本:1.x。(我使用的是tensorflow1.15)
遇到任何问题欢迎私信咨询~

相关文章:

PhysioNet2017分类的代码实现

PhysioNet2017数据集介绍可参考文章&#xff1a;https://wendy.blog.csdn.net/article/details/128686196。本文主要介绍利用PhysioNet2017数据集对其进行分类的代码实现。 目录一、数据集预处理二、训练2.1 导入数据集并进行数据裁剪2.2 划分训练集、验证集和测试集2.3 设置训…...

正大期货本周财经大事抢先看

美国1月CPI、Fed 等央行官员谈话 美国1月超强劲的非农就业人口&#xff0c;让投资人开始上修对这波升息循环利率顶点的预测&#xff0c;也使本周二 (14 日) 的美国 1月 CPI 格外受关注。 介绍正大国际期货主账户对比国内期货的优势 ​第一点&#xff1a;权限都在主账户 例如…...

html+css综合练习一

文章目录一、小米注册页面1、要求2、案例图3、实现效果3.1、index.html3.2、style.css二、下午茶页面1、要求2、案例图3、index.html4、style.css三、法国巴黎页面1、要求2、案例图3、index.html4、style.css一、小米注册页面 1、要求 阅读下列说明、效果图&#xff0c;进行静…...

安装jdk8

目录标题一、下载地址&#xff08;一&#xff09;Linux下载&#xff08;二&#xff09;Win下载二、安装&#xff08;一&#xff09;Linux&#xff08;二&#xff09;Win三、卸载&#xff08;一&#xff09;Linux&#xff08;二&#xff09;Win一、下载地址 jdk8最新版 jdk8其他…...

二分法心得

原教程见labuladong 首先&#xff0c;我们建议左右区间全部用闭区间。那么第一个搜索区间&#xff1a;left0; rightlen-1; 进入while循环&#xff0c;结束条件是right<left。 然后求mid&#xff0c;如果nums[mid]的值比target大&#xff0c;说明target在左边&#xff0c;…...

Linux安装Docker完整教程

背景最近接手了几个项目&#xff0c;发现项目的部署基本上都是基于Docker的&#xff0c;幸亏在几年前已经熟悉的Docker的基本使用&#xff0c;没有抓瞎。这两年随着云原生的发展&#xff0c;Docker在云原生中的作用使得它也蓬勃发展起来。今天这篇文章就带大家一起实现一下在Li…...

备份基础知识

备份策略可包括&#xff1a;– 整个数据库&#xff08;整个&#xff09;– 部分数据库&#xff08;部分&#xff09;• 备份类型可指示包含以下项&#xff1a;– 所选文件中的所有数据块&#xff08;完全备份&#xff09;– 只限自以前某次备份以来更改过的信息&#xff08;增量…...

C++学习记录——팔 内存管理

文章目录1、动态内存管理2、内存管理方式operator new operator delete3、new和delete的实现原理1、动态内存管理 C兼容C语言关于内存分配的语法&#xff0c;而添加了C独有的东西。 //int* p1 (int*)malloc(sizeof(int));int* p1 new int;new是一个操作符&#xff0c;C不再需…...

Spring事务失效原因分析解决

文章目录1、方法内部调用2、修饰符3、非运行时异常4、try…catch捕获异常5、多线程调用6、同时使用Transactional和Async7、错误使用事务传播行为8、使用的数据库不支持事务9、是否开启事务支持在工作中&#xff0c;经常会碰到一些事务失效的坑&#xff0c;基于遇到的情况&…...

4个月的测试经验,来面试就开口要17K,面试完,我连5K都不想给他.....

2021年8月份我入职了深圳某家创业公司&#xff0c;刚入职还是很兴奋的&#xff0c;到公司一看我傻了&#xff0c;公司除了我一个测试&#xff0c;公司的开发人员就只有3个前端2个后端还有2个UI&#xff0c;在粗略了解公司的业务后才发现是一个从零开始的项目&#xff0c;目前啥…...

python学习之pyecharts库的使用总结

pyecharts官方文档&#xff1a;https://pyecharts.org//#/zh-cn/ 【1】Timeline 其是一个时间轴组件&#xff0c;如下图红框所示&#xff0c;当点击红色箭头指向的“播放”按钮时&#xff0c;会呈现动画形式展示每一年的数据变化。 data格式为DataFrame&#xff0c;数据如下图…...

【taichi】利用 taichi 编写深度学习算子 —— 以提取右上三角阵为例

本文以取 (bs, n, n) 张量的右上三角阵并展平为向量 (bs, n*(n1)//2)) 为例&#xff0c;展示如何用 taichi 编写深度学习算子。 如图&#xff0c;要把形状为 (bs,n,n)(bs,n,n)(bs,n,n) 的张量&#xff0c;转化为 (bs,n(n1)2)(bs,\frac{n(n1)}{2})(bs,2n(n1)​) 的向量。我们先写…...

二进制 k8s 集群下线 worker 组件流程分析和实践

文章目录[toc]事出因果个人思路准备实践当前 worker 节点信息将节点标记为不可调度驱逐节点 pod将 worker 节点从 k8s 集群踢出下线 worker 节点相关组件事出因果 因为之前写了一篇 二进制 k8s 集群下线 master 组件流程分析和实践&#xff0c;所以索性再写一个 worker 节点的缩…...

Bean的六种作用域

限定程序中变量的可用范围叫做作用域&#xff0c;Bean对象的作用域是指Bean对象在Spring整个框架中的某种行为模式~~ Bean对象的六种作用域&#xff1a; singleton&#xff1a;单例作用域&#xff08;默认&#xff09; prototype&#xff1a;原型作用域&#xff08;多例作用域…...

Http发展历史

1 缘起 有一次&#xff0c;听到有人在议论招聘面试的人员&#xff0c; 谈及应聘人员的知识深度&#xff0c;说&#xff1a;问了一些关于Http的问题&#xff0c;如Http相关结构、网络结构等&#xff0c; 然后又说&#xff0c;问没问相关原理、来源&#xff1f; 我也是有些困惑了…...

高级Java程序员必备的技术点,你会了吗?

很多程序员在入行之后的前一两年&#xff0c;快速学习到了做项目常用的各种技术之后&#xff0c;便进入了技术很难寸进的平台期。反正手里掌握的一些技术对于应付普通项目来说&#xff0c;足够用了。因此也会缺入停滞&#xff0c;最终随着年龄的增长&#xff0c;竞争力不断下降…...

【暴力量化】查找最优均线

搜索逻辑 代码主要以支撑概率和压力概率来判断均线的优劣 判断为压力&#xff1a; 当日线与测试均线发生金叉或即将发生金叉后继续下行 判断为支撑&#xff1a; 当日线与测试均线发生死叉或即将发生死叉后继续上行 判断结果的天数&#xff1a; 小于6日均线&#xff0c;用金叉或…...

Java读取mysql导入的文件时中文字段出现�??的乱码如何解决

今天在写程序时遇到了一个乱码问题&#xff0c;困扰了好久&#xff0c;事情是这样的&#xff0c; 在Mapper层编写了查询语句&#xff0c;然后服务处调用&#xff0c;结果控制器返回一堆乱码 然后查看数据源头处&#xff1a; 由重新更改解码的字符集&#xff0c;在数据库中是正…...

k8s核心概念—Pod Controller Service介绍——20230213

文章目录一、Pod1. pod概述2. pod存在意义3. Pod实现机制4. pod镜像拉取策略5. pod资源限制6. pod重启机制7. pod健康检查8. 创建pod流程9. pod调度二、Controller1. 什么是Controller2. Pod和Controller关系3. deployment应用场景4. 使用deployment部署应用&#xff08;yaml&a…...

Tensorflow的数学基础

Tensorflow的数学基础 在构建一个基本的TensorFlow程序之前&#xff0c;关键是要掌握TensorFlow所需的数学思想。任何机器学习算法的核心都被认为是数学。某种机器学习算法的策略或解决方案是借助于关键的数学原理建立的。让我们深入了解一下TensorFlow的数学基础。 Scalar 标…...

树莓派超全系列教程文档--(62)使用rpicam-app通过网络流式传输视频

使用rpicam-app通过网络流式传输视频 使用 rpicam-app 通过网络流式传输视频UDPTCPRTSPlibavGStreamerRTPlibcamerasrc GStreamer 元素 文章来源&#xff1a; http://raspberry.dns8844.cn/documentation 原文网址 使用 rpicam-app 通过网络流式传输视频 本节介绍来自 rpica…...

【HarmonyOS 5.0】DevEco Testing:鸿蒙应用质量保障的终极武器

——全方位测试解决方案与代码实战 一、工具定位与核心能力 DevEco Testing是HarmonyOS官方推出的​​一体化测试平台​​&#xff0c;覆盖应用全生命周期测试需求&#xff0c;主要提供五大核心能力&#xff1a; ​​测试类型​​​​检测目标​​​​关键指标​​功能体验基…...

基于Uniapp开发HarmonyOS 5.0旅游应用技术实践

一、技术选型背景 1.跨平台优势 Uniapp采用Vue.js框架&#xff0c;支持"一次开发&#xff0c;多端部署"&#xff0c;可同步生成HarmonyOS、iOS、Android等多平台应用。 2.鸿蒙特性融合 HarmonyOS 5.0的分布式能力与原子化服务&#xff0c;为旅游应用带来&#xf…...

三体问题详解

从物理学角度&#xff0c;三体问题之所以不稳定&#xff0c;是因为三个天体在万有引力作用下相互作用&#xff0c;形成一个非线性耦合系统。我们可以从牛顿经典力学出发&#xff0c;列出具体的运动方程&#xff0c;并说明为何这个系统本质上是混沌的&#xff0c;无法得到一般解…...

前端开发面试题总结-JavaScript篇(一)

文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包&#xff08;Closure&#xff09;&#xff1f;闭包有什么应用场景和潜在问题&#xff1f;2.解释 JavaScript 的作用域链&#xff08;Scope Chain&#xff09; 二、原型与继承3.原型链是什么&#xff1f;如何实现继承&a…...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战

在现代战争中&#xff0c;电磁频谱已成为继陆、海、空、天之后的 “第五维战场”&#xff0c;雷达作为电磁频谱领域的关键装备&#xff0c;其干扰与抗干扰能力的较量&#xff0c;直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器&#xff0c;凭借数字射…...

SAP学习笔记 - 开发26 - 前端Fiori开发 OData V2 和 V4 的差异 (Deepseek整理)

上一章用到了V2 的概念&#xff0c;其实 Fiori当中还有 V4&#xff0c;咱们这一章来总结一下 V2 和 V4。 SAP学习笔记 - 开发25 - 前端Fiori开发 Remote OData Service(使用远端Odata服务)&#xff0c;代理中间件&#xff08;ui5-middleware-simpleproxy&#xff09;-CSDN博客…...

A2A JS SDK 完整教程:快速入门指南

目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库&#xff…...

【Redis】笔记|第8节|大厂高并发缓存架构实战与优化

缓存架构 代码结构 代码详情 功能点&#xff1a; 多级缓存&#xff0c;先查本地缓存&#xff0c;再查Redis&#xff0c;最后才查数据库热点数据重建逻辑使用分布式锁&#xff0c;二次查询更新缓存采用读写锁提升性能采用Redis的发布订阅机制通知所有实例更新本地缓存适用读多…...

MFC 抛体运动模拟:常见问题解决与界面美化

在 MFC 中开发抛体运动模拟程序时,我们常遇到 轨迹残留、无效刷新、视觉单调、物理逻辑瑕疵 等问题。本文将针对这些痛点,详细解析原因并提供解决方案,同时兼顾界面美化,让模拟效果更专业、更高效。 问题一:历史轨迹与小球残影残留 现象 小球运动后,历史位置的 “残影”…...