搭建全连接网络进行分类(糖尿病为例)
拿来练手,大神请绕道。
1.网上的代码大多都写在一个函数里,但是其实很多好论文都是把网络,数据训练等分开写的。
2.分开写就是有一个需要注意的事情,就是要import 要用到的文件中的模型或者变量等。
3.全连接的回归也写了,有空再上传吧。
4.一般都是先写data或者model
import torch
import torch.nn as nn
import torch.nn.functional as F
#nn.func这个里面很多功能其实nn里就有,可以不导入,而且后面新的版本的torch也取消了cc.functional里面的部分函数#定义网络,需要定义两部分,一部分就是初始化,另一部分就是数据流
class FCNet(nn.Module):def __init__(self):super(FCNet,self).__init__()self.fc1 = nn.Linear(8,16)#初始的这个8,要和你的数据的特征数一样才行,后面的数可以随意设置,但是不要太多,容易过拟合# self.fc2 = nn.Linear(50,20)self.fc3 = nn.Linear(16,2)#二分类,输出2,其实1也可以的#最后的就是分类数,因为用的sigmod和交叉熵损失,就不用额外加softmax了,多分类要用softmaxself.sig = nn.Sigmoid()# self.drop = nn.Dropout(0.3)#可以把用到的放在这里,也可以用nn.Sequential()放在一起,这样后面的话就可以直接用这个,不用写那么多了def forward(self,x):x = self.sig(self.fc1(x))# x = self.sig(self.fc2(x))x = self.sig(self.fc3(x))return x#就是x要怎么在网络中走,要写一遍#可以自己输出测试一下看看网络是不是自己想的那样,在真的调用的时候再屏蔽掉
# net= FCNet()
# print(net)
首先看看数据是是啥样,outcome就是有没有糖尿病
其实可以手动把csv分成train和test
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
#导入pands是为了读数据,当然使用numpy也可以读得,sklearn是为了把训练数据分为训练和验证集data = pd.read_csv('./train.csv')
#就是把对应的数据哪出来,x代表的是feature上的data,y代表的是label,因为pd可以读到最上面的标签,所以从第2行(i=1)开始读就行
x = data.iloc[1:,:-1]
y = data.iloc[1:,[-1]]
#可以输出看看数据对不对,x中不应该包含labels
# print(x)
# print(y)
#test_size就是划分的比例,后面的是种子,意思是每次运行这个函数时候,0.8就是那些,0.2也还是每次一样,如果想要不一样,只要每次运行这个函数时候换个值就行
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=0)
#print(x_train,y_test)
# print(x_test,y_test)
#给数据进行归一化,可以用很多方法,我用最简单的归一到-1到1
x_train = x_train.apply(lambda x: (x - x.mean()) / (x.std()))
x_test = x_test.apply(lambda x: (x - x.mean()) / (x.std()))#写dataset可以用两种方法,第一种就是 每一个数据自己单独处理,第二个就是要自己重写dataset类
#1.
# 可以使用分别的处理,把数据(首先转换为tensor,或者把dataframe.valus拿出来才能转换为tensor)转换为tensor并且数据类型转换为float32,如果测试没有真值,需要单独转换
# x_train = torch.tensor(np.array(x_train),dtype=torch.float32)
# y_train = torch.tensor(np.array(y_train),dtype=torch.float32)
# x_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# y_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# train_dataset = torch.utils.data.TensorDataset(x_train,y_train)
# test_dataset = torch.utils.data.TensorDataset(x_test,y_test)#2.也可以直接重写datasetclass dataset(Dataset):def __init__(self, x, y):#把值拿出来或者变为np类型才能转换为tensor# self.data = torch.tensor(x.values,dtype=torch.float32)# self.labels = torch.tensor(y.values,dtype=torch.float32)self.data = torch.tensor(np.array(x),dtype=torch.float32)self.labels = torch.tensor(np.array(y),dtype=torch.float32)def __len__(self):return len(self.data)def __getitem__(self,idx):return self.data[idx],self.labels[idx]#应该返回的是list类型,不是字典也不是setBATCH_SIZE = 64#验证集一般不用shuffle
train_dataset = dataset(x_train,y_train)
test_dataset = dataset(x_test,y_test)
# print(train_dataset)
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_lodaer = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False)
# print(train_loader)
然后就可以写train或者test了,其实test和train一样
from Model import FCNet
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import data
#导入要调用的net和data,也可以from data import xxx 这样可以直接用xxx,现在的这个需要用data.xxx#看自己的设备,最好用gpu来跑
if (torch.cuda.is_available()):my_device = torch.device('cuda')
else:my_device = torch.device('cpu')print(my_device)
#实例化一个net,并且放到gpu上,需要放到gpu上的有inputs,labels,net,loss
net = FCNet().to(my_device)
# print(net)
#定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
#一开始是不需要weight_decay(也就是l2正则化),可以等出现过拟合在用,也可以先用上
optimizer = optim.Adam(net.parameters(),lr=0.001,weight_decay=0.01)epochs = 600
#定义train,因为一边训练一边验证,所有就把两个loader都放进去了,不过写法很多,也可以不放dataloader,放epoches也可以
def train(dataloader,valloader):losses = []acces = []losses_val = []for epoch in range(epochs):loss_batch = 0for i,data in enumerate(dataloader):#需要注意的,这里的inputs和labels和之前定义的dataset相关,需要是list类型才可以inputs,labels = data#print(data)可以打印出来查看一下inputs,labels = inputs.to(my_device),labels.to(my_device)optimizer.zero_grad()#每次要梯度清零outputs = net(inputs)#print(outputs)#model的最后一层是sigmod#labels的格式需要注意,因为现在是[[1],[0],[1],[1]..]这样得格式,无法放到交叉熵了,需要时[0,1,1,1...]这样得格式才行loss = criterion(outputs,labels.squeeze(1).long()).to(my_device)#print(labels.squeeze(1).long())loss.backward()optimizer.step()loss_batch += loss.item()length = i#验证的时候不用反向传播和梯度下降这些net.eval()count = 0right = 0loss_batch_val =0with torch.no_grad():for j,data2 in enumerate(valloader):val_inputs,val_labels = data2val_inputs,val_labels = val_inputs.to(my_device),val_labels.squeeze(1).long().to(my_device)val_outputs = net(val_inputs)loss_val = criterion(val_outputs,val_labels)#因为net的最后一层是2,所以输出的是2维的【0.6,0.4】这种,但是这个可以直接放到交叉熵中#——中放的是概率,pred中放的是预测的类别,算损失还是要用outputs,但是算准确率就是用pred和真实labels相比了_,pred = torch.max(val_outputs,1)#print(pred)right = (pred == val_labels).sum().item()count = len(val_labels)acc = right/countloss_batch_val += loss_val.item()length2 = jif epoch % 10 == 9:print('train_epoch:',epoch+1,'train_loss:',loss_batch/length,'val_loss:',loss_batch_val/length2,'acc:',acc)losses.append(loss_batch/length)acces.append(acc)losses_val.append(loss_batch_val/length2)#可以画一些曲线,输出一些值plt.plot(range(60),losses,color ='blue',label ='train_loss')plt.plot(range(60),acces, color ='red',label ='val_acc')plt.plot(range(60),losses_val,color ='yellow',label ='val_loss')plt.legend()plt.show()torch.save(net.state_dict(),'./weights_epoch1000.pth')#保存参数train(data.train_loader,data.test_lodaer)
最后看一下结果,最后的准确率在85%左右,还可以,毕竟数据不多,也是简单的全连接。
在这个结果之前出现了很多问题,比如波动很大,损失先降后升等问题,找个有问题的图
下面是一些总结:
1.跳跃很大,波动:增大batch_size,减小lr。
2.降低过拟合:
a.降低模型的复杂程度,但是修改具体的神经元个数,因为这个网络本身就不大,所有没啥用,模型非常大没准会有用。
b.batchsize增大,lr减小是有效的。
c.输入数据进行归一化是有用的,归一化之后lr可以调大一点,收敛变快了。
d.L2正则化是有用的,很有用。dropout应该也有用,但是模型本来就很小,我试了试没啥差别。而且有正则化之后可以加速收敛,lr可以稍微调大一点,较少的epoches也可以收敛了,而已acc也会更高一点,稳定一点。
相关文章:

搭建全连接网络进行分类(糖尿病为例)
拿来练手,大神请绕道。 1.网上的代码大多都写在一个函数里,但是其实很多好论文都是把网络,数据训练等分开写的。 2.分开写就是有一个需要注意的事情,就是要import 要用到的文件中的模型或者变量等。 3.全连接的回归也写了&#…...

【小沐学前端】Node.js实现基于Protobuf协议的UDP通信(UDP/TCP)
文章目录 1、简介1.1 node1.2 Protobuf 2、下载和安装2.1 node2.2 Protobuf2.2.1 安装2.2.2 工具 3、node 代码示例3.1 HTTP3.2 UDP单播3.4 UDP广播 4、Protobuf 代码示例4.1 例子: awesome.proto4.1.1 加载.proto文件方式4.1.2 加载.json文件方式4.1.3 加载.js文件方式 4.2 例…...
Verasity Tokenomics — 社区讨论总结与下一步计划
Verasity 代币经济学的社区讨论已结束。 本次讨论从 8 月 4 日持续到 9 月 29 日,是区块链领域规模最大的讨论之一,超过 500,000 名 VRA 持有者和社区成员参与讨论,并收到了数千份回复。 首先,我们要感谢所有参与讨论并提出详细建…...

JUC第十三讲:JUC锁: ReentrantLock详解
JUC第十三讲:JUC锁: ReentrantLock详解 本文是JUC第十三讲,JUC锁:ReentrantLock详解。可重入锁 ReentrantLock 的底层是通过 AbstractQueuedSynchronizer 实现,所以先要学习上一章节 AbstractQueuedSynchronizer 详解。 文章目录 …...

WSL2安装历程
WLS2安装 1、系统检查 安装WSL2必须运行 Windows 10 版本 2004 及更高版本(内部版本 19041 及更高版本)或 Windows 11。 查看 Windows 版本及内部版本号,选择 Win R,然后键入winver。 2、家庭版升级企业版 下载HEU_KMS_Activ…...

Ubuntu20配置Mysql常用操作
文章目录 版权声明ubuntu更换软件源Ubuntu设置静态ipUbuntu防火墙ubuntu安装ssh服务Ubuntu安装vmtoolsUbuntu安装mysql5.7Ubuntu安装mysql8.0Ubuntu卸载mysql 版权声明 本博客的内容基于我个人学习黑马程序员课程的学习笔记整理而成。我特此声明,所有版权属于黑马程…...

【解决方案】‘create’ is not a member of ‘cv::aruco::DetectorParameters’
‘create’ is not a member of ‘cv::aruco::DetectorParameters’ 在构建AruCo标定板标定位姿代码的过程中,发现代码中认为create并不是aruco::DetectorParameters的成员函数,这是因为在4.7.0及以上的OpenCV版本中,对ArUco的代码做调整&…...
门牌制作(蓝桥杯)
门牌制作 题目描述 本题为填空题,只需要算出结果后,在代码中使用输出语句将所填结果输出即可。 小蓝要为一条街的住户制作门牌号。 这条街一共有 2020 位住户,门牌号从 1 到 2020 编号。 小蓝制作门牌的方法是先制作 0 到 9 这几个数字字…...

支付宝支付模块开发
生成二维码 使用Hutool工具类生成二维码 引入对应的依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.7.5</version> </dependency><dependency><groupId>com.go…...

12、Kubernetes中KubeProxy实现之iptables和ipvs
目录 一、概述 二、iptables 代理模式 三、iptables案例分析 四、ipvs案例分析 一、概述 iptables和ipvs其实都是依赖的一个共同的Linux内核模块:Netfilter。Netfilter是Linux 2.4.x引入的一个子系统,它作为一个通用的、抽象的框架,提供…...
从0开始python学习-29.selenium 通过cookie信息进行登录
1. 手动输入cookie信息保持登录状态 url https://test.com/login driver.get(url) # 手动将cookie信息写入(有多个的情况需要分开写入)--弊端为需要每次都手动输入,很麻烦不适用 driver.add_cookie({"name": "SIAM_IMAGE_…...

CentOS安装OpenNebula(二)
被控端部署: 先要配置好yum源: [rootmaster yum.repos.d]# vim opennebula.repo[rootmaster yum.repos.d]# cat opennebula.repo [opennebula] nameopennebula baseurlhttps://downloads.opennebula.org/repo/5.6/CentOS/7/x86_64 enabled1 gpgkeyhttps…...
力扣第239题 c++滑动窗口经典题 单调队列
题目 239. 滑动窗口最大值 困难 提示 队列 数组 滑动窗口 单调队列 堆(优先队列) 给你一个整数数组 nums,有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只向右移动一位。 返回 滑动窗口中的…...

华为云云耀云服务器L实例评测|华为云云耀云服务器docker部署srs,可使用HLS协议
华为云云耀云服务器L实例评测|华为云云耀云服务器docker部署srs,可使用HLS协议 什么是华为云云耀云L实例 云耀云服务器L实例,面向初创企业和开发者打造的全新轻量应用云服务器。提供丰富严选的应用镜像,实现应用一键部署&#x…...

jira流转issue条目状态transitions的rest实用脚本,issue状态改变调整
官方文档链接地址: POST Transition issue Performs an issue transition and, if the transition has a screen, updates the fields from the transition screen. sortByCategory To update the fields on the transition screen, specify the fields in the fiel…...

JAVA 注解
1 概念 Annotation(注解)是 Java 提供的一种对元程序中元素关联信息和元数据(metadata)的途径和方法。Annatation(注解)是一个接口,程序可以通过反射来获取指定程序中元素的 Annotation 对象,然后通过该 An…...

C++面试题准备
文章目录 一、线程1.什么是进程,线程,彼此有什么区别?2.多进程、多线程的优缺点3.什么时候用进程,什么时候用线程4.多进程、多线程同步(通讯)的方法5.父进程、子进程的关系以及区别6.什么是进程上下文、中断上下文7.一…...

使用Java操作Redis
要在Java程序中操作Redis可以使用Jedis开源工具。 一、jedis的下载 如果使用Maven项目,可以把以下内容添加到pom中 <!-- https://mvnrepository.com/artifact/redis.clients/jedis --> <dependency> <groupId>redis.clients</groupId>…...

VRRP配置案例(路由走向分析,端口切换)
以下配置图为例 PC1的配置 acsw下行为access口,上行为trunk口, 将g0/0/3划分到vlan100中 <Huawei>sys Enter system view, return user view with CtrlZ. [Huawei]sysname acsw [acsw] Sep 11 2023 18:15:48-08:00 acsw DS/4/DATASYNC_CFGCHANGE:O…...

【图像处理】【应用程序设计】加载,编辑和保存图像数据、图像分割、色度键控研究(Matlab代码实现)
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...
零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?
一、核心优势:专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发,是一款收费低廉但功能全面的Windows NAS工具,主打“无学习成本部署” 。与其他NAS软件相比,其优势在于: 无需硬件改造:将任意W…...

YSYX学习记录(八)
C语言,练习0: 先创建一个文件夹,我用的是物理机: 安装build-essential 练习1: 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件,随机修改或删除一部分,之后…...
条件运算符
C中的三目运算符(也称条件运算符,英文:ternary operator)是一种简洁的条件选择语句,语法如下: 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true,则整个表达式的结果为“表达式1”…...

苍穹外卖--缓存菜品
1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得,如果用户端访问量比较大,数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据,减少数据库查询操作。 缓存逻辑分析: ①每个分类下的菜品保持一份缓存数据…...
HTML前端开发:JavaScript 常用事件详解
作为前端开发的核心,JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例: 1. onclick - 点击事件 当元素被单击时触发(左键点击) button.onclick function() {alert("按钮被点击了!&…...
MySQL中【正则表达式】用法
MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现(两者等价),用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例: 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...
DeepSeek 技术赋能无人农场协同作业:用 AI 重构农田管理 “神经网”
目录 一、引言二、DeepSeek 技术大揭秘2.1 核心架构解析2.2 关键技术剖析 三、智能农业无人农场协同作业现状3.1 发展现状概述3.2 协同作业模式介绍 四、DeepSeek 的 “农场奇妙游”4.1 数据处理与分析4.2 作物生长监测与预测4.3 病虫害防治4.4 农机协同作业调度 五、实际案例大…...

LLMs 系列实操科普(1)
写在前面: 本期内容我们继续 Andrej Karpathy 的《How I use LLMs》讲座内容,原视频时长 ~130 分钟,以实操演示主流的一些 LLMs 的使用,由于涉及到实操,实际上并不适合以文字整理,但还是决定尽量整理一份笔…...
Java多线程实现之Runnable接口深度解析
Java多线程实现之Runnable接口深度解析 一、Runnable接口概述1.1 接口定义1.2 与Thread类的关系1.3 使用Runnable接口的优势 二、Runnable接口的基本实现方式2.1 传统方式实现Runnable接口2.2 使用匿名内部类实现Runnable接口2.3 使用Lambda表达式实现Runnable接口 三、Runnabl…...

构建Docker镜像的Dockerfile文件详解
文章目录 前言Dockerfile 案例docker build1. 基本构建2. 指定 Dockerfile 路径3. 设置构建时变量4. 不使用缓存5. 删除中间容器6. 拉取最新基础镜像7. 静默输出完整示例 docker runDockerFile 入门syntax指定构造器FROM基础镜像RUN命令注释COPY复制ENV设置环境变量EXPOSE暴露端…...