当前位置: 首页 > news >正文

前馈神经网络dropout实例

直接看代码。

(一)手动实现


import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt#下载MNIST手写数据集  
mnist_train = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  #读取数据  
batch_size = 256 
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  #初始化参数  
num_inputs,num_hiddens,num_outputs =784, 256,10num_epochs=30lr = 0.001def init_param():W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  b1 = torch.zeros(1, dtype=torch.float32)  W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  b2 = torch.zeros(1, dtype=torch.float32)  params =[W1,b1,W2,b2]for param in params:  param.requires_grad_(requires_grad=True)  return W1,b1,W2,b2def dropout(X, drop_prob):X = X.float()assert 0 <= drop_prob <= 1keep_prob = 1 - drop_probif keep_prob == 0:return torch.zeros_like(X)mask = (torch.rand(X.shape) < keep_prob).float()print(mask)return mask * X / keep_probdef net(X, is_training=True):X = X.view(-1, num_inputs)H1 = (torch.matmul(X, W1.t()) + b1).relu()if is_training:H1 = dropout(H1, drop_prob)return (torch.matmul(H1,W2.t()) + b2).relu()def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None):train_ls, test_ls = [], []for epoch in range(num_epochs):ls, count = 0, 0for X,y in train_iter:l=loss(net(X),y)optimizer.zero_grad()l.backward()optimizer.step()ls += l.item()count += y.shape[0]train_ls.append(ls)ls, count = 0, 0for X,y in test_iter:l=loss(net(X,is_training=False),y)ls += l.item()count += y.shape[0]test_ls.append(ls)if(epoch+1)%10==0:print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))return train_ls,test_lsdrop_probs = np.arange(0,1.1,0.1)Train_ls, Test_ls = [], []for drop_prob in drop_probs:W1,b1,W2,b2 = init_param()loss = nn.CrossEntropyLoss()optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)train_ls, test_ls =  train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer)   Train_ls.append(train_ls)Test_ls.append(test_ls)x = np.linspace(0,len(train_ls),len(train_ls))plt.figure(figsize=(10,8))for i in range(0,len(drop_probs)):plt.plot(x,Train_ls[i],label= 'drop_prob=%.1f'%(drop_probs[i]),linewidth=1.5)plt.xlabel('epoch')plt.ylabel('loss')# plt.legend()
plt.legend(loc=2, bbox_to_anchor=(1.05,1.0),borderaxespad = 0.)
plt.title('train loss with dropout')
plt.show()

运行结果:

在这里插入图片描述

(二)torch.nn实现

import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as pltmnist_train = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  
batch_size = 256 
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  class LinearNet(nn.Module):def __init__(self,num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob1,drop_prob2):super(LinearNet,self).__init__()self.linear1 = nn.Linear(num_inputs,num_hiddens1)self.relu = nn.ReLU()self.drop1 = nn.Dropout(drop_prob1)self.linear2 = nn.Linear(num_hiddens1,num_hiddens2)self.drop2 = nn.Dropout(drop_prob2)self.linear3 = nn.Linear(num_hiddens2,num_outputs)self.flatten  = nn.Flatten()def forward(self,x):x = self.flatten(x)x = self.linear1(x)x = self.relu(x)x = self.drop1(x)x = self.linear2(x)x = self.relu(x)x = self.drop2(x)x = self.linear3(x)y = self.relu(x)return ydef train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,optimizer=None):train_ls, test_ls = [], []for epoch in range(num_epochs):ls, count = 0, 0for X,y in train_iter:l=loss(net(X),y)optimizer.zero_grad()l.backward()optimizer.step()ls += l.item()count += y.shape[0]train_ls.append(ls)ls, count = 0, 0for X,y in test_iter:l=loss(net(X),y)ls += l.item()count += y.shape[0]test_ls.append(ls)if(epoch+1)%5==0:print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))return train_ls,test_ls    num_inputs,num_hiddens1,num_hiddens2,num_outputs =784, 256,256,10
num_epochs=20
lr = 0.001
drop_probs = np.arange(0,1.1,0.1)
Train_ls, Test_ls = [], []for drop_prob in drop_probs:net = LinearNet(num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob,drop_prob)for param in net.parameters():nn.init.normal_(param,mean=0, std= 0.01)loss = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(net.parameters(),lr)train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,net.parameters,lr,optimizer)Train_ls.append(train_ls)Test_ls.append(test_ls)x = np.linspace(0,len(train_ls),len(train_ls))
plt.figure(figsize=(10,8))
for i in range(0,len(drop_probs)):plt.plot(x,Train_ls[i],label= 'drop_prob=%.1f'%(drop_probs[i]),linewidth=1.5)plt.xlabel('epoch')plt.ylabel('loss')
plt.legend(loc=2, bbox_to_anchor=(1.05,1.0),borderaxespad = 0.)
plt.title('train loss with dropout')
plt.show()input = torch.randn(2, 5, 5)
m = nn.Sequential(
nn.Flatten()
)
output = m(input)
output.size()

运行结果:

在这里插入图片描述

关于dropout的原理,网上资料很多,一般都是用一个正态分布的矩阵,比较矩阵元素和(1-dropout),大于(1-dropout)的矩阵元素值的修正为1,小于(1-dropout)的改为1,将输入的值乘以修改后的矩阵,再除以(1-dropout)。

疑问:

  1. 数值经过正态分布矩阵的筛选后,还要除以 (1-dropout),这样做的原因是什么?
  2. Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。

相关文章:

前馈神经网络dropout实例

直接看代码。 &#xff08;一&#xff09;手动实现 import torch import torch.nn as nn import numpy as np import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt#下载MNIST手写数据集 mnist_train torchvision.datasets.MN…...

Android DataStore:安全存储和轻松管理数据

关于作者&#xff1a;CSDN内容合伙人、技术专家&#xff0c; 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 &#xff0c;擅长java后端、移动开发、人工智能等&#xff0c;希望大家多多支持。 目录 一、导读二、概览三、使用3.1 Preferences DataStore添加依赖数据读…...

opencv进阶12-EigenFaces 人脸识别

EigenFaces 通常也被称为 特征脸&#xff0c;它使用主成分分析&#xff08;Principal Component Analysis&#xff0c;PCA&#xff09; 方法将高维的人脸数据处理为低维数据后&#xff08;降维&#xff09;&#xff0c;再进行数据分析和处理&#xff0c;获取识别结果。 基本原理…...

The internal rate of return (IRR)

内部收益率 NPV(Net Present Value)_spencer_tseng的博客-CSDN博客...

半导体自动化专用静电消除器主要由哪些部分组成

半导体自动化专用静电消除器是一种用于消除半导体生产过程中的静电问题的设备。由于半导体制造过程中对静电的敏感性&#xff0c;静电可能会对半导体器件的质量和可靠性产生很大的影响&#xff0c;甚至造成元件损坏。因此&#xff0c;半导体生产中采用专用的静电消除器是非常重…...

【C++入门到精通】C++入门 —— deque(STL)

阅读导航 前言一、deque简介1. 概念2. 特点 二、deque使用1. 基本操作&#xff08;增、删、查、改&#xff09;2. 底层结构 三、deque的缺陷四、 为什么选择deque作为stack和queue的底层默认容器总结温馨提示 前言 文章绑定了VS平台下std::deque的源码&#xff0c;大家可以下载…...

Codeforces Round 893 (Div. 2) D.Trees and Segments

原题链接&#xff1a;Problem - D - Codeforces 题面&#xff1a; 大概意思就是让你在翻转01串不超过k次的情况下&#xff0c;使得a*&#xff08;0的最大连续长度&#xff09;&#xff08;1的最大连续长度&#xff09;最大&#xff08;1<a<n&#xff09;。输出n个数&…...

SpringBoot + Vue 前后端分离项目 微人事(九)

职位管理后端接口设计 在controller包里面新建system包&#xff0c;再在system包里面新建basic包&#xff0c;再在basic包里面创建PositionController类&#xff0c;在定义PositionController类的接口的时候&#xff0c;一定要与数据库的menu中的url地址到一致&#xff0c;不然…...

【业务功能篇71】Cglib的BeanCopier进行Bean对象拷贝

选择Cglib的BeanCopier进行Bean拷贝的理由是&#xff0c; 其性能要比Spring的BeanUtils&#xff0c;Apache的BeanUtils和PropertyUtils要好很多&#xff0c; 尤其是数据量比较大的情况下。 BeanCopier的主要作用是将数据库层面的Entity转化成service层的POJO。BeanCopier其实已…...

让eslint的错误信息显示在项目界面上

1.需求描述 效果如下 让eslint中的错误&#xff0c;显示在项目界面上 2.问题解决 1.安装 vite-plugin-eslint 插件 npm install vite-plugin-eslint --save-dev2.配置插件 // vite.config.js import { defineConfig } from vite import vue from vitejs/plugin-vue import e…...

手摸手带你实现一个开箱即用的Node邮件推送服务

目录 ​编辑 前言 准备工作 邮箱配置 代码实现 服务部署 使用效果 题外话 写在最后 相关代码&#xff1a; 前言 由于邮箱账号和手机号的唯一性&#xff0c;通常实现验证码的校验时比较常用的两种方式是手机短信推送和邮箱推送&#xff0c;此外&#xff0c;邮件推送服…...

【Linux网络】网络编程套接字 -- 基于socket实现一个简单UDP网络程序

认识端口号网络字节序处理字节序函数 htonl、htons、ntohl、ntohs socketsocket编程接口sockaddr结构结尾实现UDP程序的socket接口使用解析socket处理 IP 地址的函数初始化sockaddr_inbindrecvfromsendto 实现一个简单的UDP网络程序封装服务器相关代码封装客户端相关代码实验结…...

Python学习笔记第六十四天(Matplotlib 网格线)

Python学习笔记第六十四天 Matplotlib 网格线普通网格线样式网格线 后记 Matplotlib 网格线 我们可以使用 pyplot 中的 grid() 方法来设置图表中的网格线。 grid() 方法语法格式如下&#xff1a; matplotlib.pyplot.grid(bNone, whichmajor, axisboth, )参数说明&#xff1a…...

机器学习与模式识别3(线性回归与逻辑回归)

一、线性回归与逻辑回归简介 线性回归主要功能是拟合数据&#xff0c;常用平方误差函数。 逻辑回归主要功能是区分数据&#xff0c;找到决策边界&#xff0c;常用交叉熵。 二、线性回归与逻辑回归的实现 1.线性回归 利用回归方程对一个或多个特征值和目标值之间的关系进行建模…...

vue启动配置npm run serve,动态环境变量,根据不同环境访问不同域名

首先创建不同环境的配置文件&#xff0c;比如域名和一些常量&#xff0c;创建一个env文件,先看看文件目录 env.dev就是dev环境的域名&#xff0c;.test就是test环境域名&#xff0c;其他同理&#xff0c;然后配置package.json文件 {"name": "require-admin&qu…...

HTML <strike> 标签

HTML5 中不支持 <strike> 标签在 HTML 4 中用于定义删除线文本。 定义和用法 <strike> 标签可定义加删除线文本定义。 浏览器支持 元素ChromeIEFirefoxSafariOpera<strike>YesYesYesYesYes 所有浏览器都支持 <strike> 标签。 HTML 与 XHTML 之间…...

数学建模-模型详解(1)

规划模型 线性规划模型&#xff1a; 当涉及到线性规划模型实例时&#xff0c;以下是一个简单的示例&#xff1a; 假设我们有两个变量 x 和 y&#xff0c;并且我们希望最大化目标函数 Z 5x 3y&#xff0c;同时满足以下约束条件&#xff1a; x > 0y > 02x y < 10…...

MySQL 数据库表的基本操作

一、数据库表概述 在数据库中&#xff0c;数据表是数据库中最重要、最基本的操作对象&#xff0c;是数据存储的基本单位。数据表被定义为列的集合&#xff0c;数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录&#xff0c;每一列代表记录中的一个域。 二、数…...

企业微信电脑端开启chrome调试

首先&#xff1a; Mac端调试开启的快捷键&#xff1a;control shift command d Window端调试开启的快捷键: control shift alt d 这边以Mac为例&#xff0c;我们可以在电脑顶部看到调试的入口&#xff1a; 然后我们点击 『浏览器、webView相关』菜单&#xff0c;勾选上…...

Maven官网下载配置新仓库

1.Maven的下载 Maven的官网地址&#xff1a;Maven – Download Apache Maven 点击Download&#xff0c;查找 Files下的版本并下载如下图&#xff1a; 2.Maven的配置 自己在D盘或者E盘创建一个文件夹&#xff0c;作为本地仓库&#xff0c;存放项目依赖。 将下载好的zip文件进行解…...

银河麒麟V10 达梦安装教程

安装前先准备要安装包&#xff0c;包需要需要区分X86和arm架构。 版本为&#xff1a;dm8_20230419_FTarm_kylin10_sp1_64.iso 达梦数据库下载地址&#xff1a; https://www.aliyundrive.com/s/Qm7Es5BQM5U 第一步创建用户 su - root 1. 创建安装用户组 dminstall。 groupad…...

Python操作MongoDB数据库

安装MongoDB库 pip install pymongopython 代码 Author: tkhywang 2810248865qq.com Date: 2023-08-21 10:22:30 LastEditors: tkhywang 2810248865qq.com LastEditTime: 2023-08-21 11:17:45 FilePath: \PythonProject02\MongoDB 数据库.py Description: 这是默认设置,请设置…...

《HeadFirst设计模式(第二版)》第十一章代码——代理模式

代码文件目录&#xff1a; RMI&#xff1a; MyRemote package Chapter11_ProxyPattern.RMI;import java.rmi.Remote; import java.rmi.RemoteException;public interface MyRemote extends Remote {public String sayHello() throws RemoteException; }MyRemoteClient packa…...

QT的工程文件认识

目录 1、QT介绍 2、QT的特点 3、QT模块 3.1基本模块 3.2扩展模块 4、QT工程创建 1.选择应用的窗体格式 2.设置工程的名称与路径 3.设置类名 4.选择编译器 5、QT 工程解析 xxx.pro 工程配置 xxx.h 头文件 main.cpp 主函数 xxx.cpp 文件 6、纯手工创建一个QT 工程…...

typeScript安装及TypeScript tsc 不是内部或外部命令,也不是可运行的程序或批处理文件解决办法

一、typeScript安装&#xff1a; 1、首先确定系统中已安装node, winr 输入cmd 打开命令行&#xff0c;得到版本号证明系统中已经安装node node -v //v18.17.0 2、使用npm 全局安装typeScript # 全局安装 TypeScript npm i -g typescript 二、检查是否安装成功ts #检查t…...

SWUST 派森练习题:P111. 摩斯密码翻译器

描述 摩斯密码&#xff08;morse code)&#xff0c;又称摩斯电码、摩尔斯电码&#xff08;莫尔斯电码&#xff09;&#xff0c;是一种时通时断的信号代码&#xff0c;通过不同的信号排列顺序来表达不同的英文字母、数字和标点符号&#xff1b;通信时&#xff0c;将英文字母等内…...

如何在控制台查看excel内容

背景 最近发现打开电脑的excel很慢&#xff0c;而且使用到的场景很少&#xff0c;也因为mac自带了预览的功能。但是shigen就是闲不住&#xff0c;想自己搞一个excel预览软件&#xff0c;于是在一番技术选型之后&#xff0c;我决定使用python在控制台显示excel的内容。 具体的需…...

Echarts、js编写“中国主要城市空气质量对比”散点图 【亲测】

本次实验通过可视化工具Echarts来对全国主要城市的&#xff30;&#xff2d;2.5的值进行直观的展示&#xff0c;使人们可以快速的发现信息的关键点&#xff0c;从而对各个城市的空气质量情况有直观的了解。 先看效果 上代码&#xff1a; <!DOCTYPE html> <html>&…...

linux不分区直接在文件系统根上开swap

root下&#xff0c;直接创swapfile dd if/dev/zero of/swapfile bs1M count8192然后 mkswap swapfile swapon swapfile修改fstab # /etc/fstab: static file system information. # # Use blkid to print the universally unique identifier for a # device; this may be us…...

React请求机制优化思路 | 京东云技术团队

说起数据加载的机制&#xff0c;有一个绕不开的话题就是前端性能&#xff0c;很多电商门户的首页其实都会做一些垂直的定制优化&#xff0c;比如让请求在页面最早加载&#xff0c;或者在前一个页面就进行预加载等等。随着react18的发布&#xff0c;请求机制这一块也是被不断谈起…...

搭建网站挣钱/网易疫情实时最新数据

负责控制信号的输入和输出叫做使能&#xff0c;是一个动词&#xff0c;英文‘Enable’。英文Enable&#xff0c;前缀en-就是使的意思&#xff0c;able就是能够。合起来就是使能。使能通俗点说就是一个“允许”信号&#xff0c;进给使能也就是允许进给的信号&#xff0c;也就是说…...

炫酷业务网站/网络营销属于什么专业类型

原文链接&#xff1a;http://tecdat.cn/?p19839​tecdat.cn机器学习算法可用于找到最佳值来交易您的指标。相对强弱指标(RSI)是最常见的技术指标之一。它用于识别超卖和超买情况。传统上&#xff0c;交易者希望RSI值超过70代表超买市场状况&#xff0c;而低于30则代表超卖市场…...

wordpress改wp admin/软文营销名词解释

题目&#xff1a; 0,1&#xff0c;...n-1这n个数字排成一个圆圈&#xff0c;从数字0开始每次从这个圆圈里删除第m个数字&#xff0c;求出这个圆圈里剩下的最后一个数字。 思路&#xff1a; 1、环形链表模拟圆圈 创建一个n个节点的环形链表&#xff0c;然后每次在这个链表中删除…...

做厂房的网站/市场调研报告万能模板

React Native的版本升级插件&#xff08;仅是android&#xff09;, react-native版本需要0.17.0及以上 如何安装 1.首先安装npm包 npm install react-native-upgrade-android --save 2.link 自动link方法~ npm requires node version 4.1 or higher npm link link成功命令行会提…...

哪个网站能在百度做推广/今日国际新闻头条新闻

文章目录前言概述二、正文forEach2.mapfiltereverysomereducefindfindIndex总结前言 我个人而言数组的遍历方法是常用到的&#xff0c;每次用我都现搜&#xff0c;所以还是自己整理一下吧。若有术语不当之处&#xff0c;欢迎指出。 概述 我主要列出的是封装好的数组遍历方法&a…...

网站制作怎么做搜索栏/媒体公关

http://acm.hdu.edu.cn/showproblem.php?pid5917 即世界上任意6个人中&#xff0c;总有3个人相互认识&#xff0c;或互相皆不认识。 所以子集 > 6的一定是合法的。 然后总的子集数目是2^n&#xff0c;减去不合法的&#xff0c;暴力枚举即可。 选了1个肯定不合法&#xff0c…...