当前位置: 首页 > news >正文

生成对抗网络——GAN深度卷积实现(代码+理解)

        本篇博客为 上篇博客的 另一个实现版本,训练流程相同,所以只实现代码,感兴趣可以跳转看一下。

  生成对抗网络—GAN(代码+理解)

http://t.csdnimg.cn/HDfLOicon-default.png?t=N7T8http://t.csdnimg.cn/HDfLO


目录

一、GAN深度卷积实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码实现

3. 运行结果展示

二、学习中产生的疑问,及文心一言回答

1. 模型初始化

2. 模型训练时

3. 优化器定义

4. 训练数据

5. 模型结构

(1)生成器        

(2)判别器


一、GAN深度卷积实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码实现

import torch
import torch.nn as nn
import argparse
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as npparser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=20, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)# 加载数据
dataloader = torch.utils.data.DataLoader(datasets.MNIST("./others/",train=False,download=False,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02) # 给定均值和标准差的正态分布N(mean,std)中生成值torch.nn.init.constant_(m.bias.data, 0.0)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.init_size = opt.img_size // 4  # 原为28*28,现为32*32,两边各多了2self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),    # 调整数据的分布,使其 更适合于 下一层的 激活函数或学习nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, z):out = self.l1(z)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, bn=True):block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return blockself.model = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# 下采样(图片进行 4次卷积操作,变为ds_size * ds_size尺寸大小)ds_size = opt.img_size // 2 ** 4self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1),nn.Sigmoid())def forward(self, img):out = self.model(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)return validity# 实例化
generator = Generator()
discriminator = Discriminator()# 初始化参数
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))# 交叉熵损失函数
adversarial_loss = torch.nn.BCELoss()def gen_img_plot(model, epoch, text_input):prediction = np.squeeze(model(text_input).detach().cpu().numpy()[:16])plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow((prediction[i] + 1) / 2)plt.axis('off')plt.show()# ----------
#  Training
# ----------
D_loss_ = []  # 记录训练过程中判别器的损失
G_loss_ = []  # 记录训练过程中生成器的损失
for epoch in range(opt.n_epochs):# 初始化损失值D_epoch_loss = 0G_epoch_loss = 0count = len(dataloader)  # 返回批次数for i, (imgs, _) in enumerate(dataloader):valid = torch.ones(imgs.shape[0], 1)fake = torch.zeros(imgs.shape[0], 1)# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()z = torch.randn(imgs.shape[0], opt.latent_dim)gen_imgs = generator(z)g_loss = adversarial_loss(discriminator(gen_imgs), valid)g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()real_loss = adversarial_loss(discriminator(imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))# batches_done = epoch * len(dataloader) + i# if batches_done % opt.sample_interval == 0:#     save_image(gen_imgs.data[:25], "others/images/%d.png" % batches_done, nrow=5, normalize=True)# 累计每一个批次的losswith torch.no_grad():D_epoch_loss += d_lossG_epoch_loss += g_loss# 求平均损失with torch.no_grad():D_epoch_loss /= countG_epoch_loss /= countD_loss_.append(D_epoch_loss.item())G_loss_.append(G_epoch_loss.item())text_input = torch.randn(opt.batch_size, opt.latent_dim)gen_img_plot(generator, epoch, text_input)x = [epoch + 1 for epoch in range(opt.n_epochs)]
plt.figure()
plt.plot(x, G_loss_, 'r')
plt.plot(x, D_loss_, 'b')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['G_loss','D_loss'])
plt.show()

3. 运行结果展示

二、学习中产生的疑问,及文心一言回答

1. 模型初始化

        函数 weights_init_normal 用于初始化 模型参数,为什么要 以 均值和标准差 的正态分布中采样的数 为标准?

2. 模型训练时

        这里“d_loss = (real_loss + fake_loss) / 2” 中的 “/ 2” 操作,在 实际训练中 有什么作用?

        由(real_loss + fake_loss) / 2的 得到 的 d_loss 与(real_loss+fake_loss)得到的 d_loss 进行 回溯,两者结果会 有什么不同吗?

3. 优化器定义

        设置 betas=(opt.b1, opt.b2) 有什么 实际的作用?通俗易懂的讲一下

        betas=(opt.b1, opt.b2) 是怎样 更新学习率的?

4. 训练数据

        这里我们用的data为 MNIST,为什么img_size设置为 32,不是 28?

5. 模型结构

(1)生成器        

        解释一下为什么是“Upsample, Conv2d, BatchNorm2d, LeakyReLU ”这种顺序?

(2)判别器

        模型的 基本 运算步骤是什么?其中为什么需要 “Dropout2d( p=0.25, inplace=False)”这一步?

        关于“ds_size” 和 “128 * ds_size ** 2”的实际意义?


                                后续更新 GAN的其他模型结构。

相关文章:

生成对抗网络——GAN深度卷积实现(代码+理解)

本篇博客为 上篇博客的 另一个实现版本,训练流程相同,所以只实现代码,感兴趣可以跳转看一下。 生成对抗网络—GAN(代码理解) http://t.csdnimg.cn/HDfLOhttp://t.csdnimg.cn/HDfLO 目录 一、GAN深度卷积实现 1. 模型…...

gbase8s数据库阻塞检查点和非阻塞检查点的执行机制

1. 检查点的描述 为了便于数据库系统的复原和逻辑恢复,数据库服务器生成的一致性标志点,称为检查点,其是建立在数据库系统的已知和一致状态时日志中的某个时间点检查点的目的在于定期将逻辑日志中的重新启动点向前移动 如果存在检查点&#…...

ARM32开发--串口库封装(初级)

知不足而奋进望远山而前行 目录 文章目录 前言 目标 内容 开发流程 文件目录创建 分组创建 接口定义 完整代码 总结 前言 在嵌入式软件开发中,封装抽取流程和抽取封装策略是非常重要的技术,能够提高代码的复用性和可维护性。本文将介绍如何在文…...

统一管理:Vue公共组件/公共样式/全局自定义指令

main.js 引入存放公共文件的文件路径 import "./plugins";src/plugins文件夹下的index.js 在处理公共文件中分别引入 /* 公共引入,勿随意修改,修改时需经过确认 */ import Vue from "vue";import "/icons"; // 图标 import ByuiQueryForm fr…...

Linux之旅: 基础知识点的终极指南

文章目录 1、Linux的目录结构2、ls命令3、管理文件和目录4、linux命令使用细节和技巧5、权限管理基本命令6、搜索命令7、管道符与重定向8、压缩和解压命令9、用户及vim编辑器10、用户和用户组管理一、Linux系统用户账号的基本管理二、Linux系统用户组的管理 1、Linux的目录结构…...

C#部分方法有什么用处?和传统方法有什么区别?什么时候用合适?

在C#中,部分类(partial class)和部分方法(partial method)是两个不同的概念,但它们经常一起使用,特别是在代码生成和框架设计中。下面我将分别解释这两个概念,并讨论它们的用处、与传…...

elasticsearch hanlp插件远程词典配置

elasticsearch hanlp插件远程词典配置 背景远程词典配置新增远程词典文件修改hanlp-remote.xml自动加载词典 远程词典测试 背景 在使用elasticsearch的过程中,总会遇到与分词相关的需求,这里将针对常用的elasticsearch hanlp(后面统称为 es …...

力扣每日一题 6/18 字符串/模拟

博客主页:誓则盟约系列专栏:IT竞赛 专栏关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ 2288.价格减免 【中等】 题目: 句子 是由若干个单词组成的字符…...

架构设计 - Nginx Proxy Cache 缓存配置

摘要: web 应用业务缓存通常3级: 一级缓存:JVM 本地缓存 二级缓存:Redis集中式缓存 三级缓存:Nginx Proxy Cache 缓存 或 Nginx Lua 缓存 四级缓存:静态资源CDN缓存 本文主要分享 Nginx Proxy Cache 缓…...

【前端】HTML5基础

目录 0 参考1 网页1.1 什么是网页1.2 什么是HTML1.3 网页的形成 2 浏览器2.1 常用的浏览器2.2 浏览器内核 3 Web标准3.1 为什么需要Web标准3.2 Web标准的构成 4 HTML 标签4.1 HTML语法规范4.1.1 基本语法概述4.1.2 标签关系4.1.2.1 包含关系4.1.2.2 并列关系 4.2 HTML基本结构标…...

9个最佳性能测试工具(2024)

1、前言 性能测试检查软件程序在预期工作负载下的速度、响应时间、可靠性、资源使用情况和可扩展性。性能测试的目的不是发现功能缺陷,而是消除软件或设备中的性能瓶颈。 性能测试为利益相关者提供有关其应用程序的速度、稳定性和可扩展性的信息。更重要的是&…...

RTthread+STM32F407ZGTx+烟雾报警检测+蜂鸣器报警+LED闪烁||使用RTthread Studio

目录 实验背景 1.安装环境 2.配置环境 3.先编译下载实例程序2,观察DS0是否闪烁 4.实验方法 5.实例代码 6.硬件连接 7.实验效果 8.关于这次开发遇到的问题 1.反应慢,都熄灭1分钟多了,才报的问题? 2.关于rt_pin_mode(KEY…...

k8s资源的基本操作

文章目录 一、Namespace1、概述2、预定义的k8s命名空间2.1、default2.2、kube-public2.3、kube-system2.4、kube-node-lease 3、命名空间基本操作3.1、查看3.1.1、查看所有的命名空间3.1.2、查看指定的命名空间3.1.3、指定输出格式3.1.4、查看ns详情 3.2、创建3.2.1、命令行创建…...

19.面包屑导航制作

面包屑导航制作 官网&#xff1a;组件 | Element 1. 在layout下新建BreadCrumb.vue BreadCrumb.vue <template><div class"bread-text"><el-breadcrumb class"bred"separator"/"><el-breadcrumb-item v-for"item in…...

做动画?Animatediff 和 ComfyUI 更配哦!

如果从工作流和内存利用率的角度来说&#xff0c;Animatediff 和 ComfyUI 可能更配一些&#xff0c;毕竟制作动画是一个很吃内存的操作。 首先&#xff0c;我们需要在管理器中下载 Animatediff 插件&#xff0c;当然也可以直接导入听雨的工作流&#xff0c;然后在管理器的安装…...

笔记-python里面的xlrd模块详解

那我就一下面积个问题对xlrd模块进行学习一下&#xff1a; 1.什么是xlrd模块&#xff1f; 2.为什么使用xlrd模块&#xff1f; 3.怎样使用xlrd模块&#xff1f; 1.什么是xlrd模块&#xff1f; ♦python操作excel主要用到xlrd和xlwt这两个库&#xff0c;即xlrd是读excel&…...

oracle将字符串中的字符和数字拆分开等功能

将字符串中的字符和数字拆分开 create or replace procedure F_GetNumber1( inString IN VARCHAR2,n_return1 out varchar2, n_return2 out varchar2) ISDCHAR VARCHAR2(1024); OUTCHAR VARCHAR2(1024); j number default 0; ulen number; BEGINOUTCHAR:;DCHAR:TRIM(inStr…...

汇编基础之使用vscode写hello world

汇编语言&#xff08;Assembly Language&#xff09; 概述 汇编语言&#xff08;Assembly Language&#xff09;是一种低级编程语言&#xff0c;它直接对应于计算机的机器代码&#xff08;machine code&#xff09;&#xff0c;但使用了更易读的文本符号。每台个人计算机都有…...

APS计划排程系统如何打破装备使用约束

APS计划排程系统是离散制造型企业在计划控制方向的重要支撑&#xff0c;它提供的是交期预测、订单排产计划、物料采购计划、人力分配计划等等。近些几年来&#xff0c;多品种、小批量、多订单的生产模式&#xff0c;让企业的计划员应接不暇、疲累不堪&#xff0c;传统的人工经验…...

gigachad - suid

gigachadeasyftp利用、google反图搜索、 suid提权、s-nail 提权 主机发现 ┌──(kali㉿kali)-[~/桌面/OSCP] └─$ sudo netdiscover -i eth0 -r 192.168.44.138/24服务探测 ┌──(kali㉿kali)-[~/桌面/OSCP] └─$ sudo nmap -sV -A -T 4 -p- 192.168.44.138 |_/kingchad…...

QtScript模块

在Qt中&#xff0c;可以使用Qt Script模块来将C类和方法绑定到Qt脚本引擎中&#xff0c;从而使得可以在Qt脚本中调用这些C类和方法。以下是一个简单的示例&#xff0c;演示了如何在Qt中将C类暴露给Qt Script引擎&#xff1a; 假设有一个名为 MyClass 的C类&#xff0c;其头文件…...

qt中for循环不要使用循环中会更改的变量

检查代码&#xff0c;发现始终会少了一位&#xff0c;最后发现我在使用for循环时&#xff0c;懒省事&#xff0c;判断条件中使用的变量是涉及到循环体中更改的变量&#xff0c;代码如下&#xff0c;更直观 for (int i 0; i < m_images.size(); i) {packageToDBList[0].imag…...

spark独立集群搭建

spark独立集群搭建(不依赖Hadoop) 1、上传spark-2.4.5-bin-hadoop2.7.tgz至 /usr/local/moudel &#xff0c;再解压到 /usr/local/soft tar -zxvf spark-2.4.5-bin-hadoop2.7.tgz -C /usr/local/soft/ 重命名 mv spark-2.4.5-bin-hadoop2.7/ spark-2.4.5 配…...

【BFS算法】广度搜索·由起点开始逐层向周围扩散求得最短路径(算法框架+题目)

0、前言 深度优先搜索是DFS&#xff08;Depth Frst Search)&#xff0c;其实就是前面所讲过的回溯算法&#xff0c;它的特点和它的名字一样&#xff0c;首先在一条路径上不断往下&#xff08;深度&#xff09;遍历&#xff0c;获得答案之后再返回&#xff0c;再继续往下遍历。…...

微信小程序---登录

手机号登录 手机号快速验证和手机号实时验证区别 手机号快速验证组件&#xff0c;平台会对号码进行验证&#xff0c;但不保证是实时验证&#xff1b;收费0.0.3元手机号实时验证组件&#xff0c;在每次请求时&#xff0c;平台均会对用户选择的手机号进行实时验证。收费0.0.4元…...

IPython大师课:提升数据科学工作效率的终极工具

IPython是一个增强的Python交互式shell&#xff0c;它提供了丰富的功能和易用性改进&#xff0c;特别适合进行数据分析、科学计算和一般的Python开发。本文将全面介绍IPython的基本概念、使用方法、主要作用以及注意事项。 一、IPython简介 1. IPython的起源 IPython最初由Fe…...

抖音素材网站平台有哪些?素材下载网站库分享

在这个视觉信息充斥的时代&#xff0c;抖音已经成为众多自媒体人展示才华的舞台。要在众多创作者中脱颖而出&#xff0c;不仅需要独特的创意&#xff0c;还需要优质的素材来支持你的内容制作。今天&#xff0c;我将介绍几个为抖音视频提供高品质素材的网站&#xff0c;包括国内…...

MODBUS TCP协议实例数据帧详细分析

MODBUS TCP协议实例数据帧详细分析 1.简介 2.ModbusTCP数据帧 2.1.报文头MBAP 2.2.帧结构PDU 3.ADU详细结构 3.1. 0x01&#xff1a;读线圈 3.2. 0x02&#xff1a;读离散量输入 3.3. 0x03&#xff1a;读保持寄存器 3.4. 0x04&#xff1a;读输入寄存器 3.5. 0x05&#xff1a;写单…...

Spring Boot启动与运行机制详解:初学者友好版

Spring Boot启动与运行机制详解&#xff1a;初学者友好版 随着微服务的兴起和容器化部署的流行&#xff0c;Spring Boot以其快速搭建、简单配置和自动化部署的特性&#xff0c;成为了众多开发者的首选。对于初学者而言&#xff0c;理解Spring Boot的启动与运行机制是掌握其精髓…...

Ubuntu 22.04 解决 firefox 中文界面乱码

问题复现 在为Ubuntu 22.04 Server安装完整的GNOME 42.01桌面后&#xff0c;将桌面语言设置为中文时&#xff0c;打开Firefox可能会出现中文乱码的问题。经过网上调查发现&#xff0c;这个问题是由Snap软件包引起的。 解决方案 为了避免在Ubuntu 22.04中文模式下的乱码问题…...

wordpress 自动 采集/电工培训内容

在父工程下创建了一个模块service_test&#xff0c;发现配置文件无效&#xff0c;没有变成小绿叶。 打开pom.xml文件后发现组织名错了&#xff0c;新建的模块放在了service下&#xff0c;但pom.xml文件中的组织名为parent 只需要将改成service即可。 如下所示&#xff0c;改后…...

怎样做网站变手机软件/谷歌google官网入口

进程及作业管理Uninterruptible sleep: 不可中断的睡眠Interruptible sleep:可中断睡眠COW: copy on write写时复制VSZ: 虚拟内存集RSS: 常驻内存集100-139&#xff1a;用户可控制 nice值&#xff1a;优雅的 -20 ~ -19 100 ~ 139 普通用户仅能调高进程的nice值 超级用户随…...

幼儿园主题网络图设计技巧/seo关键词布局

1.lseek 作用&#xff1a;移动文件指针&#xff0c;并且返回当前指针的值&#xff01; off_t lseek(int filedes, off_to ffset, int whence) ;对参数offset 的解释与参数whence的值有关。 • 若whence是SEEK_SET&#xff0c;则将该文件的位移量设置为距文件开始处offset 个字…...

网站开发如何下载服务器文档/提高工作效率8个方法

在C的类中&#xff0c;普通成员函数不能作为pthread_create的线程函数&#xff0c;如果要作为pthread_create中的线程函数&#xff0c;必须是static !在C语言中&#xff0c;我们使用pthread_create创建线程&#xff0c;线程函数是一个全局函数&#xff0c;所以在C中&#xff0c…...

简单网站建设软件有哪些/东莞关键词优化实力乐云seo

引用自&#xff1a;http://blog.csdn.net/jackyxu_2008/archive/2009/03/21/4009791.aspx 最近用powerDesinger遇到一些小问题&#xff0c;遇到好几次同样的问题了&#xff0c;写在这里&#xff0c;以备查用&#xff1a;-----------------------------------------------------…...

如何建设阿里巴巴网站/博客网站注册

有的时候我们需要集成ListActivity&#xff0c;注意点1&#xff0c;这个时候我们的xml中的<ListView>标签中的id属性不能够随便自己命名&#xff0c;而是要固定为android:id"id/android:list"&#xff0c;具体如下&#xff1a; main3.xml: 1 <?xml version…...