当前位置: 首页 > news >正文

搭建全连接网络进行分类(糖尿病为例)

拿来练手,大神请绕道。

1.网上的代码大多都写在一个函数里,但是其实很多好论文都是把网络,数据训练等分开写的。

2.分开写就是有一个需要注意的事情,就是要import 要用到的文件中的模型或者变量等。

3.全连接的回归也写了,有空再上传吧。

4.一般都是先写data或者model

import torch
import torch.nn as nn
import torch.nn.functional as F
#nn.func这个里面很多功能其实nn里就有,可以不导入,而且后面新的版本的torch也取消了cc.functional里面的部分函数#定义网络,需要定义两部分,一部分就是初始化,另一部分就是数据流
class FCNet(nn.Module):def __init__(self):super(FCNet,self).__init__()self.fc1 = nn.Linear(8,16)#初始的这个8,要和你的数据的特征数一样才行,后面的数可以随意设置,但是不要太多,容易过拟合# self.fc2 = nn.Linear(50,20)self.fc3 = nn.Linear(16,2)#二分类,输出2,其实1也可以的#最后的就是分类数,因为用的sigmod和交叉熵损失,就不用额外加softmax了,多分类要用softmaxself.sig = nn.Sigmoid()# self.drop = nn.Dropout(0.3)#可以把用到的放在这里,也可以用nn.Sequential()放在一起,这样后面的话就可以直接用这个,不用写那么多了def forward(self,x):x = self.sig(self.fc1(x))# x = self.sig(self.fc2(x))x = self.sig(self.fc3(x))return x#就是x要怎么在网络中走,要写一遍#可以自己输出测试一下看看网络是不是自己想的那样,在真的调用的时候再屏蔽掉
# net= FCNet()
# print(net)

首先看看数据是是啥样,outcome就是有没有糖尿病

其实可以手动把csv分成train和test

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
#导入pands是为了读数据,当然使用numpy也可以读得,sklearn是为了把训练数据分为训练和验证集data = pd.read_csv('./train.csv')
#就是把对应的数据哪出来,x代表的是feature上的data,y代表的是label,因为pd可以读到最上面的标签,所以从第2行(i=1)开始读就行
x = data.iloc[1:,:-1]
y = data.iloc[1:,[-1]]
#可以输出看看数据对不对,x中不应该包含labels
# print(x)
# print(y)
#test_size就是划分的比例,后面的是种子,意思是每次运行这个函数时候,0.8就是那些,0.2也还是每次一样,如果想要不一样,只要每次运行这个函数时候换个值就行
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=0)
#print(x_train,y_test)
# print(x_test,y_test)
#给数据进行归一化,可以用很多方法,我用最简单的归一到-1到1
x_train = x_train.apply(lambda x: (x - x.mean()) / (x.std()))
x_test = x_test.apply(lambda x: (x - x.mean()) / (x.std()))#写dataset可以用两种方法,第一种就是 每一个数据自己单独处理,第二个就是要自己重写dataset类
#1.
# 可以使用分别的处理,把数据(首先转换为tensor,或者把dataframe.valus拿出来才能转换为tensor)转换为tensor并且数据类型转换为float32,如果测试没有真值,需要单独转换
# x_train = torch.tensor(np.array(x_train),dtype=torch.float32)
# y_train = torch.tensor(np.array(y_train),dtype=torch.float32)
# x_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# y_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# train_dataset = torch.utils.data.TensorDataset(x_train,y_train)
# test_dataset = torch.utils.data.TensorDataset(x_test,y_test)#2.也可以直接重写datasetclass dataset(Dataset):def __init__(self, x, y):#把值拿出来或者变为np类型才能转换为tensor# self.data = torch.tensor(x.values,dtype=torch.float32)# self.labels = torch.tensor(y.values,dtype=torch.float32)self.data = torch.tensor(np.array(x),dtype=torch.float32)self.labels = torch.tensor(np.array(y),dtype=torch.float32)def __len__(self):return len(self.data)def __getitem__(self,idx):return self.data[idx],self.labels[idx]#应该返回的是list类型,不是字典也不是setBATCH_SIZE = 64#验证集一般不用shuffle
train_dataset = dataset(x_train,y_train)
test_dataset = dataset(x_test,y_test)
# print(train_dataset)
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_lodaer = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False)
# print(train_loader)

然后就可以写train或者test了,其实test和train一样

from Model import FCNet
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import data
#导入要调用的net和data,也可以from data import xxx 这样可以直接用xxx,现在的这个需要用data.xxx#看自己的设备,最好用gpu来跑
if (torch.cuda.is_available()):my_device = torch.device('cuda')
else:my_device = torch.device('cpu')print(my_device)
#实例化一个net,并且放到gpu上,需要放到gpu上的有inputs,labels,net,loss
net = FCNet().to(my_device)
# print(net)
#定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
#一开始是不需要weight_decay(也就是l2正则化),可以等出现过拟合在用,也可以先用上
optimizer = optim.Adam(net.parameters(),lr=0.001,weight_decay=0.01)epochs = 600
#定义train,因为一边训练一边验证,所有就把两个loader都放进去了,不过写法很多,也可以不放dataloader,放epoches也可以
def train(dataloader,valloader):losses = []acces = []losses_val = []for epoch in range(epochs):loss_batch = 0for i,data in enumerate(dataloader):#需要注意的,这里的inputs和labels和之前定义的dataset相关,需要是list类型才可以inputs,labels = data#print(data)可以打印出来查看一下inputs,labels = inputs.to(my_device),labels.to(my_device)optimizer.zero_grad()#每次要梯度清零outputs = net(inputs)#print(outputs)#model的最后一层是sigmod#labels的格式需要注意,因为现在是[[1],[0],[1],[1]..]这样得格式,无法放到交叉熵了,需要时[0,1,1,1...]这样得格式才行loss = criterion(outputs,labels.squeeze(1).long()).to(my_device)#print(labels.squeeze(1).long())loss.backward()optimizer.step()loss_batch += loss.item()length = i#验证的时候不用反向传播和梯度下降这些net.eval()count = 0right = 0loss_batch_val =0with torch.no_grad():for j,data2 in enumerate(valloader):val_inputs,val_labels = data2val_inputs,val_labels = val_inputs.to(my_device),val_labels.squeeze(1).long().to(my_device)val_outputs = net(val_inputs)loss_val = criterion(val_outputs,val_labels)#因为net的最后一层是2,所以输出的是2维的【0.6,0.4】这种,但是这个可以直接放到交叉熵中#——中放的是概率,pred中放的是预测的类别,算损失还是要用outputs,但是算准确率就是用pred和真实labels相比了_,pred = torch.max(val_outputs,1)#print(pred)right = (pred == val_labels).sum().item()count = len(val_labels)acc = right/countloss_batch_val += loss_val.item()length2 = jif epoch % 10 == 9:print('train_epoch:',epoch+1,'train_loss:',loss_batch/length,'val_loss:',loss_batch_val/length2,'acc:',acc)losses.append(loss_batch/length)acces.append(acc)losses_val.append(loss_batch_val/length2)#可以画一些曲线,输出一些值plt.plot(range(60),losses,color ='blue',label ='train_loss')plt.plot(range(60),acces, color ='red',label ='val_acc')plt.plot(range(60),losses_val,color ='yellow',label ='val_loss')plt.legend()plt.show()torch.save(net.state_dict(),'./weights_epoch1000.pth')#保存参数train(data.train_loader,data.test_lodaer)

最后看一下结果,最后的准确率在85%左右,还可以,毕竟数据不多,也是简单的全连接。

在这个结果之前出现了很多问题,比如波动很大,损失先降后升等问题,找个有问题的图

下面是一些总结:

1.跳跃很大,波动:增大batch_size,减小lr。

2.降低过拟合:

        a.降低模型的复杂程度,但是修改具体的神经元个数,因为这个网络本身就不大,所有没啥用,模型非常大没准会有用。

        b.batchsize增大,lr减小是有效的。

        c.输入数据进行归一化是有用的,归一化之后lr可以调大一点,收敛变快了。

        d.L2正则化是有用的,很有用。dropout应该也有用,但是模型本来就很小,我试了试没啥差别。而且有正则化之后可以加速收敛,lr可以稍微调大一点,较少的epoches也可以收敛了,而已acc也会更高一点,稳定一点。

相关文章:

搭建全连接网络进行分类(糖尿病为例)

拿来练手,大神请绕道。 1.网上的代码大多都写在一个函数里,但是其实很多好论文都是把网络,数据训练等分开写的。 2.分开写就是有一个需要注意的事情,就是要import 要用到的文件中的模型或者变量等。 3.全连接的回归也写了&#…...

【小沐学前端】Node.js实现基于Protobuf协议的UDP通信(UDP/TCP)

文章目录 1、简介1.1 node1.2 Protobuf 2、下载和安装2.1 node2.2 Protobuf2.2.1 安装2.2.2 工具 3、node 代码示例3.1 HTTP3.2 UDP单播3.4 UDP广播 4、Protobuf 代码示例4.1 例子: awesome.proto4.1.1 加载.proto文件方式4.1.2 加载.json文件方式4.1.3 加载.js文件方式 4.2 例…...

Verasity Tokenomics — 社区讨论总结与下一步计划

Verasity 代币经济学的社区讨论已结束。 本次讨论从 8 月 4 日持续到 9 月 29 日,是区块链领域规模最大的讨论之一,超过 500,000 名 VRA 持有者和社区成员参与讨论,并收到了数千份回复。 首先,我们要感谢所有参与讨论并提出详细建…...

JUC第十三讲:JUC锁: ReentrantLock详解

JUC第十三讲:JUC锁: ReentrantLock详解 本文是JUC第十三讲,JUC锁:ReentrantLock详解。可重入锁 ReentrantLock 的底层是通过 AbstractQueuedSynchronizer 实现,所以先要学习上一章节 AbstractQueuedSynchronizer 详解。 文章目录 …...

WSL2安装历程

WLS2安装 1、系统检查 安装WSL2必须运行 Windows 10 版本 2004 及更高版本(内部版本 19041 及更高版本)或 Windows 11。 查看 Windows 版本及内部版本号,选择 Win R,然后键入winver。 2、家庭版升级企业版 下载HEU_KMS_Activ…...

Ubuntu20配置Mysql常用操作

文章目录 版权声明ubuntu更换软件源Ubuntu设置静态ipUbuntu防火墙ubuntu安装ssh服务Ubuntu安装vmtoolsUbuntu安装mysql5.7Ubuntu安装mysql8.0Ubuntu卸载mysql 版权声明 本博客的内容基于我个人学习黑马程序员课程的学习笔记整理而成。我特此声明,所有版权属于黑马程…...

【解决方案】‘create’ is not a member of ‘cv::aruco::DetectorParameters’

‘create’ is not a member of ‘cv::aruco::DetectorParameters’ 在构建AruCo标定板标定位姿代码的过程中,发现代码中认为create并不是aruco::DetectorParameters的成员函数,这是因为在4.7.0及以上的OpenCV版本中,对ArUco的代码做调整&…...

门牌制作(蓝桥杯)

门牌制作 题目描述 本题为填空题,只需要算出结果后,在代码中使用输出语句将所填结果输出即可。 小蓝要为一条街的住户制作门牌号。 这条街一共有 2020 位住户,门牌号从 1 到 2020 编号。 小蓝制作门牌的方法是先制作 0 到 9 这几个数字字…...

支付宝支付模块开发

生成二维码 使用Hutool工具类生成二维码 引入对应的依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.7.5</version> </dependency><dependency><groupId>com.go…...

12、Kubernetes中KubeProxy实现之iptables和ipvs

目录 一、概述 二、iptables 代理模式 三、iptables案例分析 四、ipvs案例分析 一、概述 iptables和ipvs其实都是依赖的一个共同的Linux内核模块&#xff1a;Netfilter。Netfilter是Linux 2.4.x引入的一个子系统&#xff0c;它作为一个通用的、抽象的框架&#xff0c;提供…...

从0开始python学习-29.selenium 通过cookie信息进行登录

1. 手动输入cookie信息保持登录状态 url https://test.com/login driver.get(url) # 手动将cookie信息写入&#xff08;有多个的情况需要分开写入&#xff09;--弊端为需要每次都手动输入&#xff0c;很麻烦不适用 driver.add_cookie({"name": "SIAM_IMAGE_…...

CentOS安装OpenNebula(二)

被控端部署&#xff1a; 先要配置好yum源&#xff1a; [rootmaster yum.repos.d]# vim opennebula.repo[rootmaster yum.repos.d]# cat opennebula.repo [opennebula] nameopennebula baseurlhttps://downloads.opennebula.org/repo/5.6/CentOS/7/x86_64 enabled1 gpgkeyhttps…...

力扣第239题 c++滑动窗口经典题 单调队列

题目 239. 滑动窗口最大值 困难 提示 队列 数组 滑动窗口 单调队列 堆(优先队列) 给你一个整数数组 nums&#xff0c;有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只向右移动一位。 返回 滑动窗口中的…...

华为云云耀云服务器L实例评测|华为云云耀云服务器docker部署srs,可使用HLS协议

华为云云耀云服务器L实例评测&#xff5c;华为云云耀云服务器docker部署srs&#xff0c;可使用HLS协议 什么是华为云云耀云L实例 云耀云服务器L实例&#xff0c;面向初创企业和开发者打造的全新轻量应用云服务器。提供丰富严选的应用镜像&#xff0c;实现应用一键部署&#x…...

jira流转issue条目状态transitions的rest实用脚本,issue状态改变调整

官方文档链接地址&#xff1a; POST Transition issue Performs an issue transition and, if the transition has a screen, updates the fields from the transition screen. sortByCategory To update the fields on the transition screen, specify the fields in the fiel…...

JAVA 注解

1 概念 Annotation&#xff08;注解&#xff09;是 Java 提供的一种对元程序中元素关联信息和元数据&#xff08;metadata&#xff09;的途径和方法。Annatation(注解)是一个接口&#xff0c;程序可以通过反射来获取指定程序中元素的 Annotation 对象&#xff0c;然后通过该 An…...

C++面试题准备

文章目录 一、线程1.什么是进程&#xff0c;线程&#xff0c;彼此有什么区别?2.多进程、多线程的优缺点3.什么时候用进程&#xff0c;什么时候用线程4.多进程、多线程同步&#xff08;通讯&#xff09;的方法5.父进程、子进程的关系以及区别6.什么是进程上下文、中断上下文7.一…...

使用Java操作Redis

要在Java程序中操作Redis可以使用Jedis开源工具。 一、jedis的下载 如果使用Maven项目&#xff0c;可以把以下内容添加到pom中 <!-- https://mvnrepository.com/artifact/redis.clients/jedis --> <dependency> <groupId>redis.clients</groupId>…...

VRRP配置案例(路由走向分析,端口切换)

以下配置图为例 PC1的配置 acsw下行为access口&#xff0c;上行为trunk口&#xff0c; 将g0/0/3划分到vlan100中 <Huawei>sys Enter system view, return user view with CtrlZ. [Huawei]sysname acsw [acsw] Sep 11 2023 18:15:48-08:00 acsw DS/4/DATASYNC_CFGCHANGE:O…...

【图像处理】【应用程序设计】加载,编辑和保存图像数据、图像分割、色度键控研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…...

05. 机器学习入门 - 动态规划

文章目录 从一个案例开始动态规划 Hi, 你好。我是茶桁。 咱们之前的课程就给大家讲了什么是人工智能&#xff0c;也说了每个人的定义都不太一样。关于人工智能的不同观点和方法&#xff0c;其实是一个很复杂的领域&#xff0c;我们无法用一个或者两个概念确定什么是人工智能&a…...

【JVM】第五篇 垃圾收集器G1和ZGC详解

导航 一. G1垃圾收集算法详解1. 大对象Humongous说明2. G1收集器执行一次GC运行的过程步骤3. G1垃圾收集分类4. G1垃圾收集器参数设置5. G1垃圾收集器的优化建议6. 适合使用G1垃圾收集器的场景?二. ZGC垃圾收集器详解1. NUMA与UMA2. 颜色指针3. ZGC的运作过程4. ZGC垃圾收集器…...

嵌入式Linux应用开发-基础知识-第十九章驱动程序基石⑤

嵌入式Linux应用开发-基础知识-第十九章驱动程序基石⑤ 第十九章 驱动程序基石⑤19.9 mmap19.9.1 内存映射现象与数据结构19.9.2 ARM架构内存映射简介19.9.2.1 一级页表映射过程19.9.2.2 二级页表映射过程 19.9.3 怎么给APP新建一块内存映射19.9.3.1 mmap调用过程19.9.3.2 cach…...

数据分析技能点-独立性检验拟合优度检验

在这个数据驱动的时代,数据分析已经成为了一个不可或缺的工具,无论是在商业决策、医疗研究还是日常生活中。然而数据分析并不仅仅是一堆数字和图表;它是一个需要严谨的科学方法和逻辑推理的过程。 本文将重点介绍两种广泛应用于数据分析的统计检验方法:独立性检验和拟合优…...

了解汽车ecu组成

常用ecu框架组成&#xff1a; BCM(body control module)-车身控制模块: 如英飞凌tc265芯片&#xff1a; 车身控制单元&#xff08;BCM&#xff09;适合应用于12V和24V两种电压工作环境&#xff0c;可用于轿车、大客车和商用车的车身控制。输入模块通过采集电路采集各路开关量和…...

用AI原生向量数据库Milvus Cloud 搭建一个 AI 聊天机器人

搭建聊天机器人 一切准备就绪后,就可以搭建聊天机器人了。 文档存储 机器人需要存储文档块以及使用 Towhee 提取出的文档块向量。在这个步骤中,我们需要用到 Milvus。 安装轻量版 Milvus Lite,使用以下命令运行 Milvus 服务器: (chatbot_venv) [egoebelbecker@ares milvus_…...

【OpenCV-Torch-dlib-ubuntu】Vm虚拟机linux环境摄像头调用方法与dilb模型探究

前言 随着金秋时节的来临&#xff0c;国庆和中秋的双重喜庆汇聚成一片温暖的节日氛围。在这个美好的时刻&#xff0c;我们有幸共同迎来一次长达8天的假期&#xff0c;为心灵充电&#xff0c;为身体放松&#xff0c;为未来充实自己。今年的国庆不仅仅是家国团聚的时刻&#xff…...

(二)详解观察者模式

一.使用场景 当我们需要一个类&#xff0c;在他的内部元素发生变化的时候可以主动通知其他类的时候&#xff0c;同时要保持良好的可拓展性&#xff0c;可以采用观察者模式。 二.核心 观察者模式出版者订阅者 我们拥有一个主题对象&#xff0c;和一些其他对象&#xff0c;包…...

嵌入式Linux应用开发-基础知识-第十九章驱动程序基石④

嵌入式Linux应用开发-基础知识-第十九章驱动程序基石④ 第十九章 驱动程序基石④19.7 工作队列19.7.1 内核函数19.7.1.1 定义 work19.7.1.2 使用 work&#xff1a;schedule_work19.7.1.3 其他函数 19.7.2 编程、上机19.7.3 内部机制19.7.3.1 Linux 2.x的工作队列创建过程19.7.3…...

2023 彩虹全新 SUP 模板,卡卡云模板修复版

2023 彩虹全新 SUP 模板&#xff0c;卡卡云模板&#xff0c;首页美化&#xff0c;登陆页美化&#xff0c;修复了 PC 端购物车页面显示不正常的问题。 使用教程 将这俩个数据库文件导入数据库&#xff1b; 其他的直接导入网站根目录覆盖就好&#xff1b; 若首页显示不正常&a…...

网站建设易客/网店代运营诈骗

先看全站仪后方交会建站常见的操作步骤&#xff1a;首先仪器随便架在一个方便的地方。选择你所测量需要的那个坐标系&#xff0c;再进入新点功能。用后方交会法。就可以采点了。先照准你已知的第一个点&#xff0c;再照准已知第2个点。坐标系就建好了。然后就可以碎部测量或放样…...

wordpress微信分享带缩略图/活动推广方案策划

又一次通过点滴时间——吃晚饭&#xff0c;地铁上读完了《小就是大》&#xff0c;这本书的介绍请点击这里。 里面有一小节叫做“CD Baby的确认信”对于品牌的宣传起到很大的作用。 恭喜您&#xff0c;我们的工作人员刚刚戴着消过毒的手套将您预定的CD从货架上取下来&#xff0c…...

成品电影网站建设/怎么建立一个公司的网站

在写python代码的时候&#xff0c;我们常常会遇到BUG满天飞、代码跑不了或者项目结构没用等一系列问题。但是&#xff0c;我们又很难发现到底是其中的哪一个步骤&#xff0c;导致了这些问题的出现。导致这些问题的其中一个原因&#xff0c;就是我们没有养成良好的编程习惯。编程…...

深圳网站制作的/巩义网站推广优化

随着Log4j安全漏洞的出现&#xff0c;研究人员已经看到多个攻击者(主要是出于经济动机)立即将其添加到他们的武器库中。毫不奇怪&#xff0c;一些由国家支持的攻击者也将这个新漏洞视为在潜在目标&#xff0c;在受影响系统修复这个漏洞之前寻找发动攻击的机会。 APT35&#xf…...

天津百度优化公司/站长工具seo推广

点击上方“C语言入门到精通”&#xff0c;选择置顶第一时间关注程序猿身边的故事作者闫小林白天搬砖&#xff0c;晚上做梦。我有故事&#xff0c;你有酒么&#xff1f;C引用作函数参数C之所以增加引用类型&#xff0c;主要是把它作为函数参数&#xff0c;以扩充函数传递数据的功…...

安阳网站建设哪家正规/北京seo公司司

实验&#xff1a;在Redhat Enterprisr linux 5.4实现oracle 10g的集群 本实验所使用的虚拟机环境&#xff1a;VMware workstation8.0 一、准备工作 所谓工欲善必先利其器&#xff0c;要在vmware下做linux系统的oracle rac&#xff0c;我们也需要准备好相关的装备。 本实验使用到…...