当前位置: 首页 > news >正文

GAN的原理分析与实例

为了便于理解,可以先玩一玩这个网站:GAN Lab: Play with Generative Adversarial Networks in Your Browser!

GAN的本质:枯叶蝶和鸟。生成器的目标:让枯叶蝶进化,变得像枯叶,不被鸟准确识别。判别器的目标:准确判别是枯叶还是鸟

伪代码: 

案例:

原始数据:

案例结果: 

 案例完整代码:

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文标签
plt.rcParams['axes.unicode_minus'] = False# dir = '... your path/faces/'
dir = './data/train_data'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223noiseSize = 100     # 噪声维度
n_generator_feature = 64        # 生成器feature map数
n_discriminator_feature = 64        # 判别器feature map数
batch_size = 50
d_every = 1     # 每一个batch训练一次discriminator
g_every = 5     # 每五个batch训练一次generatorclass NetGenerator(nn.Module):def __init__(self):super(NetGenerator,self).__init__()self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),#转置卷积层:输入特征映射的尺寸会放大,通道数可能会减小,普通卷积层:输入特征映射的尺寸会缩小,但通道数可能会增加nn.BatchNorm2d(n_generator_feature * 8),nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 4),nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 2),nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature),nn.ReLU(True),      # (n_generator_feature) × 32 × 32nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),nn.Tanh()       # 3 * 96 * 96)def forward(self, input):return self.main(input)class NetDiscriminator(nn.Module):def __init__(self):super(NetDiscriminator,self).__init__()self.main = nn.Sequential(nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 2),nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 4),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 8),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),nn.Sigmoid()        # 输出一个概率)def forward(self, input):return self.main(input).view(-1)def train():for i, (image,_) in tqdm.tqdm(enumerate(dataloader)):       # type((image,_)) = <class 'list'>, len((image,_)) = 2 * 256 * 3 * 96 * 96real_image = Variable(image)#real_image = real_image.cuda()if (i + 1) % d_every == 0:  #d_every = 1,每一个batch训练一次discriminatoroptimizer_d.zero_grad()output = Discriminator(real_image)      # 尽可能把真图片判为Trueerror_d_real = criterion(output, true_labels)error_d_real.backward()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises).detach()       # 根据噪声生成假图fake_output = Discriminator(fake_img)       # 尽可能把假图片判为Falseerror_d_fake = criterion(fake_output, fake_labels)error_d_fake.backward()optimizer_d.step()if (i + 1) % g_every == 0:optimizer_g.zero_grad()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises)        # 这里没有detachfake_output = Discriminator(fake_img)       # 尽可能让Discriminator把假图片判为Trueerror_g = criterion(fake_output, true_labels)error_g.backward()optimizer_g.step()def show(num):fix_fake_imags = Generator(fix_noises)fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5# x = torch.rand(64, 3, 96, 96)fig = plt.figure(1)i = 1for image in fix_fake_imags:ax = fig.add_subplot(8, 8, eval('%d' % i)) #将Figure划分为8行8列的子图网格,并将当前的子图添加到第i个位置。# plt.xticks([]), plt.yticks([])  # 去除坐标轴plt.axis('off')plt.imshow(image.permute(1, 2, 0)) #permute()函数可以对维度进行重排,Matplotlib期望的图像格式是(H, W, C),即高度、宽度、通道i += 1plt.subplots_adjust(left=None,  # the left side of the subplots of the figureright=None,  # the right side of the subplots of the figurebottom=None,  # the bottom of the subplots of the figuretop=None,  # the top of the subplots of the figurewspace=0.05,  # the amount of width reserved for blank space between subplotshspace=0.05)  # the amount of height reserved for white space between subplots)plt.suptitle('第%d迭代结果' % num, y=0.91, fontsize=15)plt.savefig("images/%dcgan.png" % num)if __name__ == '__main__':transform = tv.transforms.Compose([tv.transforms.Resize(96),     # 图片尺寸, transforms.Scale transform is deprecatedtv.transforms.CenterCrop(96),tv.transforms.ToTensor(),tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # 变成[-1,1]的数])dataset = tv.datasets.ImageFolder(dir, transform=transform)dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # module 'torch.utils.data' has no attribute 'DataLoder'print('数据加载完毕!')Generator = NetGenerator()Discriminator = NetDiscriminator()optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))criterion = torch.nn.BCELoss()true_labels = Variable(torch.ones(batch_size))     # batch_sizefake_labels = Variable(torch.zeros(batch_size))fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # 均值为0,方差为1的正态分布# if torch.cuda.is_available() == True:#     print('Cuda is available!')#     Generator.cuda()#     Discriminator.cuda()#     criterion.cuda()#     true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()#     fix_noises, noises = fix_noises.cuda(), noises.cuda()#plot_epoch = [1,5,10,50,100,200,500,800,1000,1500,2000,2500,3000]plot_epoch = [1,5,10,50,100,200,500,800,1000,1200,1500]for i in range(1500):        # 最大迭代次数train()print('迭代次数:{}'.format(i))if i in plot_epoch:show(i)

http://t.csdnimg.cn/FTSriicon-default.png?t=N7T8http://t.csdnimg.cn/FTSri

相关文章:

GAN的原理分析与实例

为了便于理解&#xff0c;可以先玩一玩这个网站&#xff1a;GAN Lab: Play with Generative Adversarial Networks in Your Browser! GAN的本质&#xff1a;枯叶蝶和鸟。生成器的目标&#xff1a;让枯叶蝶进化&#xff0c;变得像枯叶&#xff0c;不被鸟准确识别。判别器的目标&…...

什么是POM设计模式?

为什么要用POM设计模式 前期&#xff0c;我们学会了使用PythonSelenium编写Web UI自动化测试线性脚本 线性脚本&#xff08;以快递100网站登录举栗&#xff09;&#xff1a; import timefrom selenium import webdriver from selenium.webdriver.common.by import Bydriver …...

没有数据线,在手机上查看电脑备忘录怎么操作

在工作中&#xff0c;电脑和手机是我最常用的工具。我经常需要在电脑上记录一些重要的工作事项&#xff0c;然后又需要在手机上查看这些记录&#xff0c;以便随时了解工作进展。但是&#xff0c;每次都需要通过数据线来传输数据&#xff0c;实在是太麻烦了。 有一次&#xff0…...

Elasitcsearch--解决CPU使用率升高

原文网址&#xff1a;Elasitcsearch--解决CPU使用率升高_IT利刃出鞘的博客-CSDN博客 简介 本文介绍如何解决ES导致的CPU使用率升高的问题。 问题描述 线上环境 Elasticsearch CPU 使用率飙升常见问题如下&#xff1a; Elasticsearch 使用线程池来管理并发操作的 CPU 资源。…...

vue和jQuery有什么区别

Vue 和 jQuery 是两种不同类型的前端工具&#xff0c;它们有一些显著的区别&#xff1a; Vue 响应式数据绑定&#xff1a;Vue 提供了双向数据绑定和响应式更新的能力&#xff0c;使得数据与视图之间的关系更加直观和易于维护。组件化开发&#xff1a;Vue 鼓励使用组件化的方式…...

[Android] Binder all-in-all

前言&#xff1a; Binder 是一种 IPC 机制&#xff0c;使用共享内存实现进程间通讯&#xff0c;既可以传递消息&#xff0c;也可以传递创建在共享内存中的对象&#xff0c;而Binder本身就是用共享内存实现的&#xff0c;因此遵循Binder写法的类是可以实例化后在进程间传递的。…...

无人零售柜:快捷舒适购物体验

无人零售柜&#xff1a;快捷舒适购物体验 通过无人零售柜和人工智能技术&#xff0c;消费者在购物过程中可以自由选择商品&#xff0c;根据个人需求和喜好查询商品清单。这种自主选择的购物环境能够为消费者提供更加舒适和满意的体验。此外&#xff0c;无人零售柜还具有节约时间…...

Bash script进阶笔记

数组类型 arr(1 2 3) # 最基础的方式声明数组&#xff0c;用小括号()&#xff0c;元素之间逗号分隔 arr([1]10 [2]20 [3]30) # 初始化时指定index declare -a arr(1 2 3) # 用declare -a声明数组&#xff0c;小括号外面可选使用单引号、双引号 declare -a arr‘(1 2 3)’…...

OpenCV图像处理——Python开发中OpenCV视频流的多线程处理方式

前言 在做视觉类项目中&#xff0c;常常需要在Python环境下使用OpenCV读取本地的还是网络摄像头的视频流&#xff0c;之后再调入各种模型&#xff0c;如目标分类、目标检测&#xff0c;人脸识别等等。如果使用单线程处理&#xff0c;很多时候会出现比较严重的时延&#xff0c;…...

webGL开发智慧城市流程

开发智慧城市的WebGL应用程序涉及多个方面&#xff0c;包括城市模型、实时数据集成、用户界面设计等。以下是一个一般性的流程&#xff0c;您可以根据项目的具体需求进行调整&#xff0c;希望对大家有所帮助。 1.需求分析&#xff1a; 确定智慧城市应用程序的具体需求和功能。考…...

Django讲课笔记02:Django环境搭建

文章目录 一、学习目标二、相关概念&#xff08;一&#xff09;Python&#xff08;二&#xff09;Django 三、环境搭建&#xff08;一&#xff09;安装Python1. 从官方网站下载最新版本的Python2. 运行安装程序并按照安装向导进行操作3. 勾选添加到路径复选框4. 完成安装过程5.…...

黑豹程序员-原生JS拖动div到任何地方-自定义布局

效果图 代码html <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0 Transitional//EN"> <html xmlns"http://www.w3.org/1999/xhtml"> <head> <meta http-equiv"Content-Type" content"text/html; charsetutf-8" /…...

<软考高项备考>《论文专题 - 7 论文的项目背景之技术架构》

1 技术架构概况 ➢ 架构前端:HTML ➢ 后端:Java ➢ 数据库: Oracle ➢ 大数据:MapReduce ➢ 人工智能:Python ➢ 物联网:RFID识别&#xff0c;http传输&#xff0c;Java ➢ 开发APP: IOS、Android 2 常用开发语言 序号语言说明1JavaJava是一种跨平台的编程语言&#xff0c;广…...

6.3 C++11 原子操作与原子类型

一、原子类型 1.多线程下的问题 在C中&#xff0c;一个全局数据在多个线程中被同时使用时&#xff0c;如果不加任何处理&#xff0c;则会出现数据同步的问题。 #include <iostream> #include <thread> #include <chrono> long val 0;void test() {for (i…...

智能优化算法应用:基于狮群算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于狮群算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于狮群算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.狮群算法4.实验参数设定5.算法结果6.参考文献7.MA…...

BERT、GPT学习问题个人记录

目录 1. 为什么过去几年大家都在做BERT, 做GPT的人少。 2. 但最近做GPT的多了以及为什么GPT架构的scaling&#xff08;扩展性&#xff09;比BERT好。 3.BERT是否可以用来做生成&#xff0c;如果可以的话为什么大家都用GPT不用BERT. 4. BERT里的NSP后面被认为是没用的&#x…...

HeartBeat监控Mysql状态

目录 一、概述 二、 安装部署 三、配置 四、启动服务 五、查看数据 一、概述 使用heartbeat可以实现在kibana界面对 Mysql 服务存活状态进行观察&#xff0c;如有必要&#xff0c;也可在服务宕机后立即向相关人员发送邮件通知 二、 安装部署 参照章节&#xff1a;监控组件…...

软件开发经常出现的bug原因有哪些

软件开发中出现bug的原因是多方面的&#xff0c;这些原因可能涉及到开发流程、人为因素、设计问题以及其他一系列因素。以下是一些常见的导致bug的原因&#xff1a; 1. 错误的需求分析&#xff1a; 不正确、不完整或者模糊的需求分析可能导致开发人员误解客户的需求&#xff0…...

代码随想录27期|Python|Day15|二叉树|层序遍历|对称二叉树|翻转二叉树

本文图片来源&#xff1a;代码随想录 层序遍历&#xff08;图论中的广度优先遍历&#xff09; 这一部分有10道题&#xff0c;全部可以套用相同的层序遍历方法&#xff0c;但是需要在每一层进行处理或者修改。 102. 二叉树的层序遍历 - 力扣&#xff08;LeetCode&#xff09; 层…...

鸿蒙开发组件之Web

一、加载一个url myWebController: WebviewController new webview.WebviewControllerbuild() {Column() {Web({src: https://www.baidu.com,controller: this.myWebController})}.width(100%).height(100%)} 二、注意点 2.1 不能用Previewer预览 Web这个组件不能使用预览…...

成绩分析。

成绩分析 题目描述 小蓝给学生们组织了一场考试&#xff0c;卷面总分为 100分&#xff0c;每个学生的得分都是一个0到100的整数。 请计算这次考试的最高分、最低分和平均分 输入描述 输入的第一行包含一个整数n(1n104)&#xff0c;表示考试人数。 接下来n行&#xff0c;每行包含…...

Excel实现字母+数字拖拉自动递增,步长可更改

目录 1、带有字母的数字序列自增加&#xff08;步长可变&#xff09; 2、仅字母自增加 3、字母数字同时自增 1、带有字母的数字序列自增加&#xff08;步长可变&#xff09; 使用Excel通常可以直接通过拖拉的方式&#xff0c;实现自增数字&#xf…...

Java之Stream流

一、什么是Stream流 Stream是一种处理集合&#xff08;Collection&#xff09;数据的方式。Stream可以让我们以一种更简洁的方式对集合进行过滤、映射、排序等操作。 二、Stream流的使用步骤 先得到一条Stream流&#xff0c;并把数据放上去利用Stream流中的API进行各种操作 中间…...

vue中element-ui日期选择组件el-date-picker 清空所选时间,会将model绑定的值设置为null 问题 及 限制起止日期范围

一、问题 在Vue中使用Element UI的日期选择组件 <el-date-picker>&#xff0c;当你清空所选时间时&#xff0c;组件会将绑定的 v-model 值设置为 null。这是日期选择器的预设行为&#xff0c;它将清空所选日期后将其视为 null。但有时后端不允许日期传空。 因此&#xff…...

使用模方时,三维模型在su中显示不了怎么办?

答&#xff1a;可以借助截图功能截取模型影像在su中绘制白模。 模方是一款针对实景三维模型的冗余碎片、水面残缺、道路不平、标牌破损、纹理拉伸模糊等共性问题研发的实景三维模型修复编辑软件。模方4.1新增自动单体化建模功能&#xff0c;支持一键自动提取房屋结构&#xff…...

AR-LDM原理及代码分析

AR-LDM原理AR-LDM代码分析pytorch_lightning(pl)的hook流程main.py 具体分析TrainSampleLightningDatasetARLDM blip mm encoder AR-LDM原理 左边是模仿了自回归地从1, 2, ..., j-1来构造 j 时刻的 frame 的过程。 在普通Stable Diffusion的基础上&#xff0c;使用了1, 2, .…...

MySQL常见死锁的发生场景以及如何解决

死锁的产生是因为满足了四个条件&#xff1a; 互斥占有且等待不可强占用循环等待 这个网站收集了很多死锁场景 接下来介绍几种常见的死锁发生场景。其中&#xff0c;id 为主键&#xff0c;no&#xff08;学号&#xff09;为二级唯一索引&#xff0c;name&#xff08;姓名&am…...

Leetcode 47 全排列 II

题意理解&#xff1a; 首先理解全排列是什么&#xff1f;全排列&#xff1a;使用集合中所有元素按照不同元素进行排列&#xff0c;将所有的排列结果的集合称为全排列。 这里的全排列难度升级了&#xff0c;问题在于集合中的元素是可以重复的。 问题&#xff1a;相同的元素会导致…...

C# 图解教程 第5版 —— 第18章 泛型

文章目录 18.1 什么是泛型18.2 C# 中的泛型18.3 泛型类18.3.1 声明泛型类18.3.2 创建构造类型18.3.3 创建变量和实例18.3.4 使用泛型的示例18.3.5 比较泛型和非泛型栈 18.4 类型参数的约束18.4.1 Where 子句18.4.2 约束类型和次序 18.5 泛型方法18.5.1 声明泛型方法18.5.2 调用…...

保障事务隔离级别的关键措施

目录 引言 1. 锁机制的应用 2. 多版本并发控制&#xff08;MVCC&#xff09;的实现 3. 事务日志的记录与恢复 4. 数据库引擎的实现策略 结论 引言 事务隔离级别是数据库管理系统&#xff08;DBMS&#xff09;中的一个关键概念&#xff0c;用于控制并发事务之间的可见性。…...