0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

基于KerasConv1D心电图检测开源教程

WpOh_rgznai100 来源:YXQ 2019-06-10 15:48 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

本实战内容取自笔者参加的首届中国心电智能大赛项目,初赛要求为设计一个自动识别心电图波形算法。笔者使用Keras框架设计了基于Conv1D结构的模型,并且开源了代码作为Baseline。内容包括数据预处理,模型搭建,网络训练,模型应用等,此Baseline采用最简单的一维卷积达到了88%测试准确率。有多支队伍在笔者基线代码基础上调优取得了优异成绩,顺利进入复赛。

数据介绍

下载完整的训练集和测试集,共1000例常规心电图,其中训练集中包含600例,测试集中共400例。该数据是从多个公开数据集中获取。参赛团队需要利用有正常/异常两类标签的训练集数据设计和实现算法,并在没有标签的测试集上做出预测。

该心电数据的采样率为500 Hz。为了方便参赛团队用不同编程语言都能读取数据,所有心电数据的存储格式为MAT格式。该文件中存储了12个导联的电压信号。训练数据对应的标签存储在txt文件中,其中0代表正常,1代表异常。

赛题分析

简单分析一下,初赛的数据集共有1000个样本,其中训练集中包含600例,测试集中共400例。其中训练集中包含600例是具有label的,可以用于我们训练模型;测试集中共400例没有标签,需要我们使用训练好的模型进行预测。

赛题就是一个二分类预测问题,解题思路应该包括以下内容

数据读取与处理

网络模型搭建

模型的训练

模型应用与提交预测结果

实战应用

经过对赛题的分析,我们把任务分成四个小任务,首先第一步是:

1.数据读取与处理

该心电数据的采样率为500 Hz。为了方便参赛团队用不同编程语言都能读取数据,所有心电数据的存储格式为MAT格式。该文件中存储了12个导联的电压信号。训练数据对应的标签存储在txt文件中,其中0代表正常,1代表异常。

我们由上述描述可以得知,

我们的数据保存在MAT格式文件中(这决定了后面我们要如何读取数据)

采样率为500 Hz(这个信息并没有怎么用到,大家可以简单了解一下,就是1秒采集500个点,由后面我们得知每个数据都是5000个点,也就是10秒的心电图片)

12个导联的电压信号(这个是指采用12种导联方式,大家可以简单理解为用12个体温计量体温,从而得到更加准确的信息,下图为导联方式简单介绍,大家了解下即可。要注意的是,既然提供了12种导联,我们应该全部都用到,虽然我们仅使用一种导联方式也可以进行训练与预测,但是经验告诉我们,采取多个特征会取得更优效果)

数据处理函数定义:

import kerasfrom scipy.io import loadmatimport matplotlib.pyplot as pltimport globimport numpy as npimport pandas as pdimport mathimport osfrom keras.layers import *from keras.models import *from keras.objectives import *BASE_DIR = “preliminary/TRAIN/”#进行归一化def normalize(v): return (v - v.mean(axis=1).reshape((v.shape[0],1))) / (v.max(axis=1).reshape((v.shape[0],1)) + 2e-12)loadmat打开文件def get_feature(wav_file,Lens = 12,BASE_DIR=BASE_DIR): mat = loadmat(BASE_DIR+wav_file) dat = mat[“data”] feature = dat[0:12] return(normalize(feature).transopse())#把标签转成oneHot形式def convert2oneHot(index,Lens): hot = np.zeros((Lens,)) hot[index] = 1 return(hot)TXT_DIR = “preliminary/reference.txt”MANIFEST_DIR = “preliminary/reference.csv”

读取一条数据进行显示

if name__ == “__main”: dat1 = get_feature(“preliminary/TRAIN/TRAIN101.mat”) print(dat1.shape) #one data shape is (12, 5000) plt.plt(dat1[:,0]) plt.show()

我们由上述信息可以看出每种导联都是由5000个点组成的列表,12种导联方式使每个样本都是12*5000的矩阵,类似于一张分辨率为12x5000的照片。

我们需要处理的就是把每个读取出来,归一化一下,送入网络进行训练可以了。

标签处理方式

def create_csv(TXT_DIR=TXT_DIR): lists = pd.read_csv(TXT_DIR,sep=r“\t”,header=None) lists = lists.sample(frac=1) lists.to_csv(MANIFEST_DIR,index=None) print(“Finish save csv”)

我这里是采用从reference.txt读取,然后打乱保存到reference.csv中,注意一定要进行数据打乱操作,不然训练效果很差。因为原始数据前面便签全部是1,后面全部是0

数据迭代方式

Batch_size = 20def xs_gen(path=MANIFEST_DIR,batch_size = Batch_size,train=True):img_list = pd.read_csv(path)if train : img_list = np.array(img_list)[:500] print(“Found %s train items.”%len(img_list)) print(“list 1 is”,img_list[0]) steps = math.ceil(len(img_list) / batch_size) # 确定每轮有多少个batchelse: img_list = np.array(img_list)[500:] print(“Found %s test items.”%len(img_list)) print(“list 1 is”,img_list[0]) steps = math.ceil(len(img_list) / batch_size) # 确定每轮有多少个batchwhile True: for i in range(steps): batch_list = img_list[i * batch_size : i * batch_size + batch_size] np.random.shuffle(batch_list) batch_x = np.array([get_feature(file) for file in batch_list[:,0]]) batch_y = np.array([convert2oneHot(label,2) for label in batch_list[:,1]]) yield batch_x, batch_y

数据读取的方式我采用的是生成器的方式,这样可以按batch读取,加快训练速度,大家也可以采用一下全部读取,看个人的习惯了

2.网络模型搭建

数据我们处理好了,后面就是模型的搭建了,我使用keras搭建的,操作简单便捷,tf,pytorch,sklearn大家可以按照自己喜好来。

网络模型可以选择CNN,RNN,Attention结构,或者多模型的融合,抛砖引玉,此Baseline采用的一维CNN方式,一维CNN学习地址

模型搭建

TIME_PERIODS = 5000num_sensors = 12def build_model(input_shape=(TIME_PERIODS,num_sensors),num_classes=2): model = Sequential() #model.add(Reshape((TIME_PERIODS, num_sensors), input_shape=input_shape)) model.add(Conv1D(16, 16,strides=2, activation=‘relu’,input_shape=input_shape)) model.add(Conv1D(16, 16,strides=2, activation=‘relu’,padding=“same”)) model.add(MaxPooling1D(2)) model.add(Conv1D(64, 8,strides=2, activation=‘relu’,padding=“same”)) model.add(Conv1D(64, 8,strides=2, activation=‘relu’,padding=“same”)) model.add(MaxPooling1D(2)) model.add(Conv1D(128, 4,strides=2, activation=‘relu’,padding=“same”)) model.add(Conv1D(128, 4,strides=2, activation=‘relu’,padding=“same”)) model.add(MaxPooling1D(2)) model.add(Conv1D(256, 2,strides=1, activation=‘relu’,padding=“same”)) model.add(Conv1D(256, 2,strides=1, activation=‘relu’,padding=“same”)) model.add(MaxPooling1D(2)) model.add(GlobalAveragePooling1D()) model.add(Dropout(0.3)) model.add(Dense(num_classes, activation=‘softmax’)) return(model)

用model.summary()输出的网络模型为

训练参数比较少,大家可以根据自己想法更改。

3.网络模型训练

模型训练

if name__ == “__main”: “”“dat1 = get_feature(”TRAIN101.mat“) print(”one data shape is“,dat1.shape) #one data shape is (12, 5000) plt.plot(dat1[0]) plt.show()”“” if (os.path.exists(MANIFEST_DIR)==False): create_csv() train_iter = xs_gen(train=True) test_iter = xs_gen(train=False) model = build_model() print(model.summary()) ckpt = keras.callbacks.ModelCheckpoint( filepath=‘best_model.{epoch:02d}-{val_acc:.2f}.h5’, monitor=‘val_acc’, save_best_only=True,verbose=1) model.compile(loss=‘categorical_crossentropy’, optimizer=‘adam’, metrics=[‘accuracy’]) model.fit_generator( generator=train_iter, steps_per_epoch=500//Batch_size, epochs=20, initial_epoch=0, validation_data = test_iter, nb_val_samples = 100//Batch_size, callbacks=[ckpt], )

训练过程输出(最优结果:loss: 0.0565 - acc: 0.9820 - val_loss: 0.8307 - val_acc: 0.8800)

Epoch 10/2025/25 [==============================] - 1s 37ms/step - loss: 0.2329 - acc: 0.9040 - val_loss: 0.4041 - val_acc: 0.8700Epoch 00010: val_acc improved from 0.85000 to 0.87000, saving model to best_model.10-0.87.h5Epoch 11/2025/25 [==============================] - 1s 38ms/step - loss: 0.1633 - acc: 0.9380 - val_loss: 0.5277 - val_acc: 0.8300Epoch 00011: val_acc did not improve from 0.87000Epoch 12/2025/25 [==============================] - 1s 40ms/step - loss: 0.1394 - acc: 0.9500 - val_loss: 0.4916 - val_acc: 0.7400Epoch 00012: val_acc did not improve from 0.87000Epoch 13/2025/25 [==============================] - 1s 38ms/step - loss: 0.1746 - acc: 0.9220 - val_loss: 0.5208 - val_acc: 0.8100Epoch 00013: val_acc did not improve from 0.87000Epoch 14/2025/25 [==============================] - 1s 38ms/step - loss: 0.1009 - acc: 0.9720 - val_loss: 0.5513 - val_acc: 0.8000Epoch 00014: val_acc did not improve from 0.87000Epoch 15/2025/25 [==============================] - 1s 38ms/step - loss: 0.0565 - acc: 0.9820 - val_loss: 0.8307 - val_acc: 0.8800Epoch 00015: val_acc improved from 0.87000 to 0.88000, saving model to best_model.15-0.88.h5Epoch 16/2025/25 [==============================] - 1s 38ms/step - loss: 0.0261 - acc: 0.9920 - val_loss: 0.6443 - val_acc: 0.8400Epoch 00016: val_acc did not improve from 0.88000Epoch 17/2025/25 [==============================] - 1s 38ms/step - loss: 0.0178 - acc: 0.9960 - val_loss: 0.7773 - val_acc: 0.8700Epoch 00017: val_acc did not improve from 0.88000Epoch 18/2025/25 [==============================] - 1s 38ms/step - loss: 0.0082 - acc: 0.9980 - val_loss: 0.8875 - val_acc: 0.8600Epoch 00018: val_acc did not improve from 0.88000Epoch 19/2025/25 [==============================] - 1s 37ms/step - loss: 0.0045 - acc: 1.0000 - val_loss: 1.0057 - val_acc: 0.8600Epoch 00019: val_acc did not improve from 0.88000Epoch 20/2025/25 [==============================] - 1s 37ms/step - loss: 0.0012 - acc: 1.0000 - val_loss: 1.1088 - val_acc: 0.8600Epoch 00020: val_acc did not improve from 0.88000

4.模型应用预测结果

预测数据

if name__ == “__main”: “”“dat1 = get_feature(”TRAIN101.mat“) print(”one data shape is“,dat1.shape) #one data shape is (12, 5000) plt.plot(dat1[0]) plt.show()”“” “”“if (os.path.exists(MANIFEST_DIR)==False): create_csv() train_iter = xs_gen(train=True) test_iter = xs_gen(train=False) model = build_model() print(model.summary()) ckpt = keras.callbacks.ModelCheckpoint( filepath=‘best_model.{epoch:02d}-{val_acc:.2f}.h5’, monitor=‘val_acc’, save_best_only=True,verbose=1) model.compile(loss=‘categorical_crossentropy’, optimizer=‘adam’, metrics=[‘accuracy’]) model.fit_generator( generator=train_iter, steps_per_epoch=500//Batch_size, epochs=20, initial_epoch=0, validation_data = test_iter, nb_val_samples = 100//Batch_size, callbacks=[ckpt], )”“” PRE_DIR = “sample_codes/answers.txt” model = load_model(“best_model.15-0.88.h5”) pre_lists = pd.read_csv(PRE_DIR,sep=r“ ”,header=None) print(pre_lists.head()) pre_datas = np.array([get_feature(item,BASE_DIR=“preliminary/TEST/”) for item in pre_lists[0]]) pre_result = model.predict_classes(pre_datas)#0-1概率预测 print(pre_result.shape) pre_lists[1] = pre_result pre_lists.to_csv(“sample_codes/answers1.txt”,index=None,header=None) print(“predict finish”)

下面是前十条预测结果:

TEST394,0TEST313,1TEST484,0TEST288,0TEST261,1TEST310,0TEST286,1TEST367,1TEST149,1TEST160,1

展望

此Baseline采用最简单的一维卷积达到了88%测试准确率(可能会因为随机初始化值上下波动),大家也可以多尝试GRU,Attention,和Resnet等结果,测试准确率会突破90+。


声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 心电图
    +关注

    关注

    1

    文章

    81

    浏览量

    25887
  • 开源
    +关注

    关注

    3

    文章

    4045

    浏览量

    45583

原文标题:实战 | 基于KerasConv1D心电图检测开源教程(附代码)

文章出处:【微信号:rgznai100,微信公众号:rgznai100】欢迎添加关注!文章转载请注明出处。

收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    面向心电医疗器械的低功耗模拟前端afe芯片解决方案

    的作用。特别是针对心电图机、监护仪、动态心电记录仪等设备,低功耗、高精度的AFE芯片已成为推动心电医疗器械性能提升的核心引擎。
    的头像 发表于 11-12 15:03 264次阅读

    AFE4960 技术文档总结

    该AFE4960可以配置为 2 通道心电图接收器或 1 通道心电图接收器和呼吸阻抗通道。AFE 信号链可以灵活地连接到多达 4 个电极。右腿驱动 (RLD) 放大器输出可用于设置身体偏置。AFE 具有用于引线开/关
    的头像 发表于 10-30 14:32 435次阅读
    AFE4960 技术文档总结

    850立式加工中心电路图资料

    电子发烧友网站提供《850立式加工中心电路图资料.pdf》资料免费下载
    发表于 09-12 17:08 0次下载

    便携式心电图机定制_医疗手持终端方案定制_联发科安卓主板方案商

    便携式心电图机作为心血管疾病诊断的核心设备,其设计与性能直接决定了诊断的精确性和便捷性。基于联发科 MT8768 平台研发的便携式心电图机方案,通过采用标准的 12导联同步采集模式,在确保诊断级精度
    的头像 发表于 07-30 20:30 491次阅读
    便携式<b class='flag-5'>心电图</b>机定制_医疗手持终端方案定制_联发科安卓主板方案商

    心电图示仪的静电和浪涌保护参考设计方案

    概述心电图示仪作为一种精密医疗电子设备,通过WiFi、Bluetooth和NFC等传输数据,其稳定性和安全性至关重要。在设计和使用过程中,必须考虑静电放电和浪涌等电磁干扰对其性能和安全
    的头像 发表于 06-03 11:08 626次阅读
    <b class='flag-5'>心电图</b>示仪的静电和浪涌保护参考设计方案

    STM32L431RCT6主芯片 搭配 SD NAND-动态心电图设备存储解决方案

    MKDV08GCL-STPA 贴片式TF卡存储解决方案,实现设备性能的全面优化。 动态心电图设备对于存储得要求:1)海量数据存储需求 动态心电图设备需要长时间、高频率地采集心脏电信号,生成的数据量庞大
    发表于 03-27 10:56

    动态心电图设备存储解决方案——STM32L431RCT6主芯片 与 贴片式TF卡MKDV08GCL-STPA

    在医疗健康领域,心电图设备是心脏疾病诊断的核心工具。随着医疗技术的不断进步,动态心电图设备逐渐成为临床诊断的主流选择。它不仅能够提供静态心电图数据,还能实时记录心脏活动的动态变化,为医生提供更全
    的头像 发表于 03-27 10:40 1506次阅读
    动态<b class='flag-5'>心电图</b>设备存储解决方案——STM32L431RCT6主芯片 与 贴片式TF卡MKDV08GCL-STPA

    参考设计# 支持边缘 AI 的无线 ECG 动态心电图监护仪

    该支持边缘人工智能 (AI) 可穿戴生物传感动态心电图监测仪参考设计提供了一个评估平台,用于评估持续监测心电图(ECG)、心率、呼吸、起搏脉冲、体温和运动等生命体征的最新产品。该设计利
    的头像 发表于 02-17 17:54 1240次阅读
    参考设计# 支持边缘 AI 的无线 ECG 动态<b class='flag-5'>心电图</b>监护仪

    ADS1292做心电放大的疑问谁来解答一下

    是硬件问题吗?我没有加右腿驱动。用的是ECG信号产生模块,产生出的是1Hz,1mV的模拟心电信号。 出现在任何导联电极间的差分心电图信号的幅度都限定在±5 mV大小,频率在0.05
    发表于 01-20 07:23

    ADS1298作为心电信号采集的模拟前端, 使用心电图机检定仪进行共模抑制比的测试时,结果不是很理想怎么改善?

    产品使用了ADS1298作为心电信号采集的模拟前端, 使用心电图机检定仪进行共模抑制比的测试时,结果不是很理想。右腿驱动和屏蔽驱动都使用了,还可能是哪些方面的原因,或者有什么方法可以改善。 1.是不是使用镀金的连接器会好一点。
    发表于 01-08 08:05

    ADS1293采集到了心电数据,如何画成心电图

    各位大神好,自己做的ADS1293的板子,然后采集到了下图所示的数据,请问对不对?然后这些数据该如何画成心电图?谢谢!!
    发表于 01-08 06:50

    请问使用ads1298R同一通道为什么呼吸检测信号不会对心电检测有影响?

    如题,同一电极对即采集呼吸波也采集心电信号,我的理解是心电采集通道通过低通滤波器滤除调制后的呼吸波反馈信号,呼吸波检测通道通过特定频率同步解调的方法解调出呼吸波,也会加上高通滤波,这样同一通道可以连接到
    发表于 12-20 16:10

    用ads1298做了一个12导联的心电图设备,寄存器在板子上配置出来一直有问题,为什么?

    你好! 我用1298做了一个12导联的心电图设备,但是按照开发板配置的寄存器在我的板子上配置出来一直有问题.我的电路跟参考板一样,想问一下这个寄存器配置应该怎么配置?输出的8chn数据是否就是LEAD1,LEDA2.v
    发表于 12-20 08:36

    ads1292r可以读出数据但画图不准确是什么原因导致的?

    尝试了 很多的配置 但是画的不像心电图 谁有数据可以参考一下吗
    发表于 12-18 08:15

    用ADS1198设计12导的心电检测仪遇到的信号采集问题求解

    请教TI工程师,我正在使用ADS1198设计12导的心电检测仪,原理根据sbau180.pdf设计,模拟电源和数字电源都是采用3.3V单电源供电,中间用磁珠联通.目前程序已经完成,用心电模拟仪
    发表于 12-17 08:01