搭建全连接网络进行分类(糖尿病为例)
拿来练手,大神请绕道。
1.网上的代码大多都写在一个函数里,但是其实很多好论文都是把网络,数据训练等分开写的。
2.分开写就是有一个需要注意的事情,就是要import 要用到的文件中的模型或者变量等。
3.全连接的回归也写了,有空再上传吧。
4.一般都是先写data或者model
import torch
import torch.nn as nn
import torch.nn.functional as F
#nn.func这个里面很多功能其实nn里就有,可以不导入,而且后面新的版本的torch也取消了cc.functional里面的部分函数#定义网络,需要定义两部分,一部分就是初始化,另一部分就是数据流
class FCNet(nn.Module):def __init__(self):super(FCNet,self).__init__()self.fc1 = nn.Linear(8,16)#初始的这个8,要和你的数据的特征数一样才行,后面的数可以随意设置,但是不要太多,容易过拟合# self.fc2 = nn.Linear(50,20)self.fc3 = nn.Linear(16,2)#二分类,输出2,其实1也可以的#最后的就是分类数,因为用的sigmod和交叉熵损失,就不用额外加softmax了,多分类要用softmaxself.sig = nn.Sigmoid()# self.drop = nn.Dropout(0.3)#可以把用到的放在这里,也可以用nn.Sequential()放在一起,这样后面的话就可以直接用这个,不用写那么多了def forward(self,x):x = self.sig(self.fc1(x))# x = self.sig(self.fc2(x))x = self.sig(self.fc3(x))return x#就是x要怎么在网络中走,要写一遍#可以自己输出测试一下看看网络是不是自己想的那样,在真的调用的时候再屏蔽掉
# net= FCNet()
# print(net)
首先看看数据是是啥样,outcome就是有没有糖尿病

其实可以手动把csv分成train和test
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
#导入pands是为了读数据,当然使用numpy也可以读得,sklearn是为了把训练数据分为训练和验证集data = pd.read_csv('./train.csv')
#就是把对应的数据哪出来,x代表的是feature上的data,y代表的是label,因为pd可以读到最上面的标签,所以从第2行(i=1)开始读就行
x = data.iloc[1:,:-1]
y = data.iloc[1:,[-1]]
#可以输出看看数据对不对,x中不应该包含labels
# print(x)
# print(y)
#test_size就是划分的比例,后面的是种子,意思是每次运行这个函数时候,0.8就是那些,0.2也还是每次一样,如果想要不一样,只要每次运行这个函数时候换个值就行
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=0)
#print(x_train,y_test)
# print(x_test,y_test)
#给数据进行归一化,可以用很多方法,我用最简单的归一到-1到1
x_train = x_train.apply(lambda x: (x - x.mean()) / (x.std()))
x_test = x_test.apply(lambda x: (x - x.mean()) / (x.std()))#写dataset可以用两种方法,第一种就是 每一个数据自己单独处理,第二个就是要自己重写dataset类
#1.
# 可以使用分别的处理,把数据(首先转换为tensor,或者把dataframe.valus拿出来才能转换为tensor)转换为tensor并且数据类型转换为float32,如果测试没有真值,需要单独转换
# x_train = torch.tensor(np.array(x_train),dtype=torch.float32)
# y_train = torch.tensor(np.array(y_train),dtype=torch.float32)
# x_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# y_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# train_dataset = torch.utils.data.TensorDataset(x_train,y_train)
# test_dataset = torch.utils.data.TensorDataset(x_test,y_test)#2.也可以直接重写datasetclass dataset(Dataset):def __init__(self, x, y):#把值拿出来或者变为np类型才能转换为tensor# self.data = torch.tensor(x.values,dtype=torch.float32)# self.labels = torch.tensor(y.values,dtype=torch.float32)self.data = torch.tensor(np.array(x),dtype=torch.float32)self.labels = torch.tensor(np.array(y),dtype=torch.float32)def __len__(self):return len(self.data)def __getitem__(self,idx):return self.data[idx],self.labels[idx]#应该返回的是list类型,不是字典也不是setBATCH_SIZE = 64#验证集一般不用shuffle
train_dataset = dataset(x_train,y_train)
test_dataset = dataset(x_test,y_test)
# print(train_dataset)
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_lodaer = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False)
# print(train_loader)
然后就可以写train或者test了,其实test和train一样
from Model import FCNet
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import data
#导入要调用的net和data,也可以from data import xxx 这样可以直接用xxx,现在的这个需要用data.xxx#看自己的设备,最好用gpu来跑
if (torch.cuda.is_available()):my_device = torch.device('cuda')
else:my_device = torch.device('cpu')print(my_device)
#实例化一个net,并且放到gpu上,需要放到gpu上的有inputs,labels,net,loss
net = FCNet().to(my_device)
# print(net)
#定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
#一开始是不需要weight_decay(也就是l2正则化),可以等出现过拟合在用,也可以先用上
optimizer = optim.Adam(net.parameters(),lr=0.001,weight_decay=0.01)epochs = 600
#定义train,因为一边训练一边验证,所有就把两个loader都放进去了,不过写法很多,也可以不放dataloader,放epoches也可以
def train(dataloader,valloader):losses = []acces = []losses_val = []for epoch in range(epochs):loss_batch = 0for i,data in enumerate(dataloader):#需要注意的,这里的inputs和labels和之前定义的dataset相关,需要是list类型才可以inputs,labels = data#print(data)可以打印出来查看一下inputs,labels = inputs.to(my_device),labels.to(my_device)optimizer.zero_grad()#每次要梯度清零outputs = net(inputs)#print(outputs)#model的最后一层是sigmod#labels的格式需要注意,因为现在是[[1],[0],[1],[1]..]这样得格式,无法放到交叉熵了,需要时[0,1,1,1...]这样得格式才行loss = criterion(outputs,labels.squeeze(1).long()).to(my_device)#print(labels.squeeze(1).long())loss.backward()optimizer.step()loss_batch += loss.item()length = i#验证的时候不用反向传播和梯度下降这些net.eval()count = 0right = 0loss_batch_val =0with torch.no_grad():for j,data2 in enumerate(valloader):val_inputs,val_labels = data2val_inputs,val_labels = val_inputs.to(my_device),val_labels.squeeze(1).long().to(my_device)val_outputs = net(val_inputs)loss_val = criterion(val_outputs,val_labels)#因为net的最后一层是2,所以输出的是2维的【0.6,0.4】这种,但是这个可以直接放到交叉熵中#——中放的是概率,pred中放的是预测的类别,算损失还是要用outputs,但是算准确率就是用pred和真实labels相比了_,pred = torch.max(val_outputs,1)#print(pred)right = (pred == val_labels).sum().item()count = len(val_labels)acc = right/countloss_batch_val += loss_val.item()length2 = jif epoch % 10 == 9:print('train_epoch:',epoch+1,'train_loss:',loss_batch/length,'val_loss:',loss_batch_val/length2,'acc:',acc)losses.append(loss_batch/length)acces.append(acc)losses_val.append(loss_batch_val/length2)#可以画一些曲线,输出一些值plt.plot(range(60),losses,color ='blue',label ='train_loss')plt.plot(range(60),acces, color ='red',label ='val_acc')plt.plot(range(60),losses_val,color ='yellow',label ='val_loss')plt.legend()plt.show()torch.save(net.state_dict(),'./weights_epoch1000.pth')#保存参数train(data.train_loader,data.test_lodaer)
最后看一下结果,最后的准确率在85%左右,还可以,毕竟数据不多,也是简单的全连接。

在这个结果之前出现了很多问题,比如波动很大,损失先降后升等问题,找个有问题的图

下面是一些总结:
1.跳跃很大,波动:增大batch_size,减小lr。
2.降低过拟合:
a.降低模型的复杂程度,但是修改具体的神经元个数,因为这个网络本身就不大,所有没啥用,模型非常大没准会有用。
b.batchsize增大,lr减小是有效的。
c.输入数据进行归一化是有用的,归一化之后lr可以调大一点,收敛变快了。
d.L2正则化是有用的,很有用。dropout应该也有用,但是模型本来就很小,我试了试没啥差别。而且有正则化之后可以加速收敛,lr可以稍微调大一点,较少的epoches也可以收敛了,而已acc也会更高一点,稳定一点。
相关文章:
搭建全连接网络进行分类(糖尿病为例)
拿来练手,大神请绕道。 1.网上的代码大多都写在一个函数里,但是其实很多好论文都是把网络,数据训练等分开写的。 2.分开写就是有一个需要注意的事情,就是要import 要用到的文件中的模型或者变量等。 3.全连接的回归也写了&#…...
【小沐学前端】Node.js实现基于Protobuf协议的UDP通信(UDP/TCP)
文章目录 1、简介1.1 node1.2 Protobuf 2、下载和安装2.1 node2.2 Protobuf2.2.1 安装2.2.2 工具 3、node 代码示例3.1 HTTP3.2 UDP单播3.4 UDP广播 4、Protobuf 代码示例4.1 例子: awesome.proto4.1.1 加载.proto文件方式4.1.2 加载.json文件方式4.1.3 加载.js文件方式 4.2 例…...
Verasity Tokenomics — 社区讨论总结与下一步计划
Verasity 代币经济学的社区讨论已结束。 本次讨论从 8 月 4 日持续到 9 月 29 日,是区块链领域规模最大的讨论之一,超过 500,000 名 VRA 持有者和社区成员参与讨论,并收到了数千份回复。 首先,我们要感谢所有参与讨论并提出详细建…...
JUC第十三讲:JUC锁: ReentrantLock详解
JUC第十三讲:JUC锁: ReentrantLock详解 本文是JUC第十三讲,JUC锁:ReentrantLock详解。可重入锁 ReentrantLock 的底层是通过 AbstractQueuedSynchronizer 实现,所以先要学习上一章节 AbstractQueuedSynchronizer 详解。 文章目录 …...
WSL2安装历程
WLS2安装 1、系统检查 安装WSL2必须运行 Windows 10 版本 2004 及更高版本(内部版本 19041 及更高版本)或 Windows 11。 查看 Windows 版本及内部版本号,选择 Win R,然后键入winver。 2、家庭版升级企业版 下载HEU_KMS_Activ…...
Ubuntu20配置Mysql常用操作
文章目录 版权声明ubuntu更换软件源Ubuntu设置静态ipUbuntu防火墙ubuntu安装ssh服务Ubuntu安装vmtoolsUbuntu安装mysql5.7Ubuntu安装mysql8.0Ubuntu卸载mysql 版权声明 本博客的内容基于我个人学习黑马程序员课程的学习笔记整理而成。我特此声明,所有版权属于黑马程…...
【解决方案】‘create’ is not a member of ‘cv::aruco::DetectorParameters’
‘create’ is not a member of ‘cv::aruco::DetectorParameters’ 在构建AruCo标定板标定位姿代码的过程中,发现代码中认为create并不是aruco::DetectorParameters的成员函数,这是因为在4.7.0及以上的OpenCV版本中,对ArUco的代码做调整&…...
门牌制作(蓝桥杯)
门牌制作 题目描述 本题为填空题,只需要算出结果后,在代码中使用输出语句将所填结果输出即可。 小蓝要为一条街的住户制作门牌号。 这条街一共有 2020 位住户,门牌号从 1 到 2020 编号。 小蓝制作门牌的方法是先制作 0 到 9 这几个数字字…...
支付宝支付模块开发
生成二维码 使用Hutool工具类生成二维码 引入对应的依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.7.5</version> </dependency><dependency><groupId>com.go…...
12、Kubernetes中KubeProxy实现之iptables和ipvs
目录 一、概述 二、iptables 代理模式 三、iptables案例分析 四、ipvs案例分析 一、概述 iptables和ipvs其实都是依赖的一个共同的Linux内核模块:Netfilter。Netfilter是Linux 2.4.x引入的一个子系统,它作为一个通用的、抽象的框架,提供…...
从0开始python学习-29.selenium 通过cookie信息进行登录
1. 手动输入cookie信息保持登录状态 url https://test.com/login driver.get(url) # 手动将cookie信息写入(有多个的情况需要分开写入)--弊端为需要每次都手动输入,很麻烦不适用 driver.add_cookie({"name": "SIAM_IMAGE_…...
CentOS安装OpenNebula(二)
被控端部署: 先要配置好yum源: [rootmaster yum.repos.d]# vim opennebula.repo[rootmaster yum.repos.d]# cat opennebula.repo [opennebula] nameopennebula baseurlhttps://downloads.opennebula.org/repo/5.6/CentOS/7/x86_64 enabled1 gpgkeyhttps…...
力扣第239题 c++滑动窗口经典题 单调队列
题目 239. 滑动窗口最大值 困难 提示 队列 数组 滑动窗口 单调队列 堆(优先队列) 给你一个整数数组 nums,有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只向右移动一位。 返回 滑动窗口中的…...
华为云云耀云服务器L实例评测|华为云云耀云服务器docker部署srs,可使用HLS协议
华为云云耀云服务器L实例评测|华为云云耀云服务器docker部署srs,可使用HLS协议 什么是华为云云耀云L实例 云耀云服务器L实例,面向初创企业和开发者打造的全新轻量应用云服务器。提供丰富严选的应用镜像,实现应用一键部署&#x…...
jira流转issue条目状态transitions的rest实用脚本,issue状态改变调整
官方文档链接地址: POST Transition issue Performs an issue transition and, if the transition has a screen, updates the fields from the transition screen. sortByCategory To update the fields on the transition screen, specify the fields in the fiel…...
JAVA 注解
1 概念 Annotation(注解)是 Java 提供的一种对元程序中元素关联信息和元数据(metadata)的途径和方法。Annatation(注解)是一个接口,程序可以通过反射来获取指定程序中元素的 Annotation 对象,然后通过该 An…...
C++面试题准备
文章目录 一、线程1.什么是进程,线程,彼此有什么区别?2.多进程、多线程的优缺点3.什么时候用进程,什么时候用线程4.多进程、多线程同步(通讯)的方法5.父进程、子进程的关系以及区别6.什么是进程上下文、中断上下文7.一…...
使用Java操作Redis
要在Java程序中操作Redis可以使用Jedis开源工具。 一、jedis的下载 如果使用Maven项目,可以把以下内容添加到pom中 <!-- https://mvnrepository.com/artifact/redis.clients/jedis --> <dependency> <groupId>redis.clients</groupId>…...
VRRP配置案例(路由走向分析,端口切换)
以下配置图为例 PC1的配置 acsw下行为access口,上行为trunk口, 将g0/0/3划分到vlan100中 <Huawei>sys Enter system view, return user view with CtrlZ. [Huawei]sysname acsw [acsw] Sep 11 2023 18:15:48-08:00 acsw DS/4/DATASYNC_CFGCHANGE:O…...
【图像处理】【应用程序设计】加载,编辑和保存图像数据、图像分割、色度键控研究(Matlab代码实现)
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...
手游刚开服就被攻击怎么办?如何防御DDoS?
开服初期是手游最脆弱的阶段,极易成为DDoS攻击的目标。一旦遭遇攻击,可能导致服务器瘫痪、玩家流失,甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案,帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...
Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具
文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...
虚拟电厂发展三大趋势:市场化、技术主导、车网互联
市场化:从政策驱动到多元盈利 政策全面赋能 2025年4月,国家发改委、能源局发布《关于加快推进虚拟电厂发展的指导意见》,首次明确虚拟电厂为“独立市场主体”,提出硬性目标:2027年全国调节能力≥2000万千瓦࿰…...
MySQL JOIN 表过多的优化思路
当 MySQL 查询涉及大量表 JOIN 时,性能会显著下降。以下是优化思路和简易实现方法: 一、核心优化思路 减少 JOIN 数量 数据冗余:添加必要的冗余字段(如订单表直接存储用户名)合并表:将频繁关联的小表合并成…...
pikachu靶场通关笔记19 SQL注入02-字符型注入(GET)
目录 一、SQL注入 二、字符型SQL注入 三、字符型注入与数字型注入 四、源码分析 五、渗透实战 1、渗透准备 2、SQL注入探测 (1)输入单引号 (2)万能注入语句 3、获取回显列orderby 4、获取数据库名database 5、获取表名…...
云原生周刊:k0s 成为 CNCF 沙箱项目
开源项目推荐 HAMi HAMi(原名 k8s‑vGPU‑scheduler)是一款 CNCF Sandbox 级别的开源 K8s 中间件,通过虚拟化 GPU/NPU 等异构设备并支持内存、计算核心时间片隔离及共享调度,为容器提供统一接口,实现细粒度资源配额…...
用递归算法解锁「子集」问题 —— LeetCode 78题解析
文章目录 一、题目介绍二、递归思路详解:从决策树开始理解三、解法一:二叉决策树 DFS四、解法二:组合式回溯写法(推荐)五、解法对比 递归算法是编程中一种非常强大且常见的思想,它能够优雅地解决很多复杂的…...
Mac flutter环境搭建
一、下载flutter sdk 制作 Android 应用 | Flutter 中文文档 - Flutter 中文开发者网站 - Flutter 1、查看mac电脑处理器选择sdk 2、解压 unzip ~/Downloads/flutter_macos_arm64_3.32.2-stable.zip \ -d ~/development/ 3、添加环境变量 命令行打开配置环境变量文件 ope…...
python打卡第47天
昨天代码中注意力热图的部分顺移至今天 知识点回顾: 热力图 作业:对比不同卷积层热图可视化的结果 def visualize_attention_map(model, test_loader, device, class_names, num_samples3):"""可视化模型的注意力热力图,展示模…...
【Java多线程从青铜到王者】单例设计模式(八)
wait和sleep的区别 我们的wait也是提供了一个还有超时时间的版本,sleep也是可以指定时间的,也就是说时间一到就会解除阻塞,继续执行 wait和sleep都能被提前唤醒(虽然时间还没有到也可以提前唤醒),wait能被notify提前唤醒…...
