当前位置: 首页 > news >正文

PyTorch训练深度卷积生成对抗网络DCGAN

文章目录

    • DCGAN介绍
    • 代码
    • 结果
    • 参考

DCGAN介绍

将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN)

DCGAN的生成器结构:
在这里插入图片描述
图片来源:https://arxiv.org/abs/1511.06434

代码

model.py

import torch
import torch.nn as nnclass Discriminator(nn.Module):def __init__(self, channels_img, features_d):super(Discriminator, self).__init__()self.disc = nn.Sequential(# Input: N x channels_img x 64 x 64nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), # 32 x 32nn.LeakyReLU(0.2),self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1nn.Sigmoid(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2),)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, z_dim, channels_img, features_g):super(Generator, self).__init__()self.gen = nn.Sequential(# Input: N x z_dim x 1 x 1self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1,),nn.Tanh(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False,),nn.BatchNorm2d(out_channels),nn.ReLU(),)def forward(self, x):return self.gen(x)def initialize_weights(model):for m in model.modules():if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):nn.init.normal_(m.weight.data, 0.0, 0.02)def test():N, in_channels, H, W = 8, 3, 64, 64z_dim = 100x = torch.randn((N, in_channels, H, W))disc = Discriminator(in_channels, 8)initialize_weights(disc)assert disc(x).shape == (N, 1, 1, 1)gen = Generator(z_dim, in_channels, 8)initialize_weights(gen)z = torch.randn((N, z_dim, 1, 1))assert gen(z).shape == (N, in_channels, H, W)print("success")if __name__ == "__main__":test()

训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录

图片如下图所示:
在这里插入图片描述

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # 1 if MNIST dataset; 3 if celeb dataset
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64transforms = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),transforms.ToTensor(),transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),]
)# If you train on MNIST, remember to set channels_img to 1
# dataset = datasets.MNIST(
#     root="dataset/", train=True, transform=transforms, download=True
# )# comment mnist above and uncomment below if train on CelebA# If you train on celeb dataset, remember to set channels_img to 3
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0gen.train()
disc.train()for epoch in range(NUM_EPOCHS):# Target labels not needed! <3 unsupervisedfor batch_idx, (real, _) in enumerate(dataloader):real = real.to(device)noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)fake = gen(noise)### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))disc_real = disc(real).reshape(-1)loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake.detach()).reshape(-1)loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))loss_disc = (loss_disc_real + loss_disc_fake) / 2disc.zero_grad()loss_disc.backward()opt_disc.step()### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))output = disc(fake).reshape(-1)loss_gen = criterion(output, torch.ones_like(output))gen.zero_grad()loss_gen.backward()opt_gen.step()# Print losses occasionally and print to tensorboardif batch_idx % 100 == 0:print(f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}")with torch.no_grad():fake = gen(fixed_noise)# take out (up to) 32 examplesimg_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)writer_real.add_image("Real", img_grid_real, global_step=step)writer_fake.add_image("Fake", img_grid_fake, global_step=step)step += 1

结果

训练5个epoch,部分结果如下:

Epoch [3/5] Batch 1500/1583                   Loss D: 0.4996, loss G: 1.1738
Epoch [4/5] Batch 0/1583                   Loss D: 0.4268, loss G: 1.6633
Epoch [4/5] Batch 100/1583                   Loss D: 0.4841, loss G: 1.7475
Epoch [4/5] Batch 200/1583                   Loss D: 0.5094, loss G: 1.2376
Epoch [4/5] Batch 300/1583                   Loss D: 0.4376, loss G: 2.1271
Epoch [4/5] Batch 400/1583                   Loss D: 0.4173, loss G: 1.4380
Epoch [4/5] Batch 500/1583                   Loss D: 0.5213, loss G: 2.1665
Epoch [4/5] Batch 600/1583                   Loss D: 0.5036, loss G: 2.1079
Epoch [4/5] Batch 700/1583                   Loss D: 0.5158, loss G: 1.0579
Epoch [4/5] Batch 800/1583                   Loss D: 0.5426, loss G: 1.9427
Epoch [4/5] Batch 900/1583                   Loss D: 0.4721, loss G: 1.2659
Epoch [4/5] Batch 1000/1583                   Loss D: 0.5662, loss G: 2.4537
Epoch [4/5] Batch 1100/1583                   Loss D: 0.5604, loss G: 0.8978
Epoch [4/5] Batch 1200/1583                   Loss D: 0.4085, loss G: 2.0747
Epoch [4/5] Batch 1300/1583                   Loss D: 1.1894, loss G: 0.1825
Epoch [4/5] Batch 1400/1583                   Loss D: 0.4518, loss G: 2.1509
Epoch [4/5] Batch 1500/1583                   Loss D: 0.3814, loss G: 1.9391

使用

tensorboard --logdir=logs

打开tensorboard

在这里插入图片描述

参考

[1] DCGAN implementation from scratch
[2] https://arxiv.org/abs/1511.06434

相关文章:

PyTorch训练深度卷积生成对抗网络DCGAN

文章目录 DCGAN介绍代码结果参考 DCGAN介绍 将CNN和GAN结合起来&#xff0c;把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN) DCGAN的生成器结构&#xff1a; 图片来源&#xff1a;https://arxiv.org/abs/1511.06434 代码 model.py impor…...

Spring-4-掌握Spring事务传播机制

今日目标 能够掌握Spring事务配置 Spring事务管理 1 Spring事务简介【重点】 1.1 Spring事务作用 事务作用&#xff1a;在数据层保障一系列的数据库操作同成功同失败 Spring事务作用&#xff1a;在数据层或业务层保障一系列的数据库操作同成功同失败 1.2 案例分析Spring…...

[PyTorch][chapter 49][创建自己的数据集 1]

前言&#xff1a; 后面几章主要利用DataSet 创建自己的数据集&#xff0c;实现建模&#xff0c; 训练&#xff0c;迁移等功能。 目录: pokemon 数据集深度学习工程步骤 一 pokemon 数据集介绍 1.1 pokemon: 数据集地址&#xff1a; 百度网盘路径: https://pan.baidu.com/s/1…...

中间件(二)dubbo负载均衡介绍

一、负载均衡概述 支持轮询、随机、一致性hash和最小活跃数等。 1、轮询 ① sequences&#xff1a;内部的序列计数器 ② 服务器接口方法权重一样&#xff1a;&#xff08;sequences1&#xff09;%服务器的数量&#xff08;决定调用&#xff09;哪个服务器的服务。 ③ 服务器…...

springboot异步文件上传获取输入流提示找不到文件java.io.FileNotFoundException

springboot上传文件&#xff0c;使用异步操作处理上传的文件数据&#xff0c;出现异常如下&#xff1a; 这个是在异步之后使用传过来的MultipartFile对象尝试调用getInputStream方法发生的异常。 java.io.FileNotFoundException: C:\Users\Administrator\AppData\Local\Temp\to…...

安装jenkins-cli

1、要在 Linux 操作系统上安装 jcli curl -L https://github.com/jenkins-zh/jenkins-cli/releases/latest/download/jcli-linux-amd64.tar.gz|tar xzv sudo mv jcli /usr/local/bin/ 在用户根目录下&#xff0c;增加 jcli 的配置文件&#xff1a; jcli config gen -ifalse …...

linux通过NC工具启动临时端口监听

1.安装nc工具 yum install nc -y2. 启动监听指定端口 #例如监听8080端口 nc -lk 8080#后台监听 nc -lk 8080 &3. 验证 #通过另外一台网络能通的机器&#xff0c;telnet 该机器ip 监听端口能通&#xff0c;并且能接手数据 telnet 192.xxx.xxx.xx 8080...

开源语音聊天软件Mumble

网友 大气 告诉我&#xff0c;Openblocks在国内还有个版本叫 码匠&#xff0c;更贴合国内软件开发的需求&#xff0c;如接入了国内常用的身份认证&#xff0c;接入了国内的数据库和云服务&#xff0c;也对小程序、企微 sdk 等场景做了适配。 在 https://majiang.co/docs/docke…...

JDK 1.6与JDK 1.8的区别

ArrayList使用默认的构造方式实例 jdk1.6默认初始值为10jdk1.8为0,第一次放入值才初始化&#xff0c;属于懒加载 Hashmap底层 jdk1.6与jdk1.8都是数组链表 jdk1.8是链表超过8时&#xff0c;自动转为红黑树 静态方式不同 jdk1.6是先初始化static后执行main方法。 jdk1.8是懒加…...

单片机实训报告

这周我们进行了单片机实训&#xff0c;一周中我们通过七个项目1&#xff1a;P1 口输入/输出 2&#xff1a;继电器控制 3 音频控制 4&#xff1a;子程序设计 5&#xff1a;字符碰头程序设计 6&#xff1a;外部中断 7&#xff1a; 急救车与交通信号灯&#xff0c;练习编写了子程…...

【编织时空四:探究顺序表与链表的数据之旅】

本章重点 链表的分类 带头双向循环链表接口实现 顺序表和链表的区别 缓存利用率参考存储体系结构 以及 局部原理性。 一、链表的分类 实际中链表的结构非常多样&#xff0c;以下情况组合起来就有8种链表结构&#xff1a; 1. 单向或者双向 2. 带头或者不带头 3. 循环或者非…...

PHP8的字符串操作1-PHP8知识详解

字符串是php中最重要的数据之一&#xff0c;字符串的操作在PHP编程占有重要的地位。在使用PHP语言开发web项目的过程中&#xff0c;为了实现某些功能&#xff0c;经常需要对某些字符串进行特殊的处理&#xff0c;比如字符串的格式化、字符串的连接与分割、字符串的比较、查找等…...

电脑提示msvcp140.dll丢失的解决方法,dll组件怎么处理

Windows系统有时在打开游戏或者软件时&#xff0c; 系统会弹窗提示缺少“msvcp140.dll.dll”文件 或者类似错误提示怎么办&#xff1f; 错误背景&#xff1a; msvcp140.dll是Microsoft Visual C Redistributable Package中的一个动态链接库文件&#xff0c;它在运行软件时提…...

stable diffusion基础

整合包下载&#xff1a;秋叶大佬 【AI绘画8月最新】Stable Diffusion整合包v4.2发布&#xff01; 参照&#xff1a;基础04】目前全网最贴心的Lora基础知识教程&#xff01; VAE 作用&#xff1a;滤镜微调 VAE下载地址&#xff1a;C站&#xff08;https://civitai.com/models…...

Greiner–Hormann裁剪算法深度探索:C++实现与应用案例

介绍 在计算几何中&#xff0c;裁剪是一个核心的主题。特别是&#xff0c;多边形裁剪已经被广泛地应用于计算机图形学&#xff0c;地理信息系统和许多其他领域。Greiner-Hormann裁剪算法是其中之一&#xff0c;提供了一个高效的方式来计算两个多边形的交集、并集等。在本文中&…...

Automatically Correcting Large Language Models

本文是大模型相关领域的系列文章&#xff0c;针对《Automatically Correcting Large Language Models: Surveying the landscape of diverse self-correction strategies》的翻译。 自动更正大型语言模型&#xff1a;综述各种自我更正策略的前景 摘要1 引言2 自动反馈校正LLM的…...

【学习FreeRTOS】第8章——FreeRTOS列表和列表项

1.列表和列表项的简介 列表是 FreeRTOS 中的一个数据结构&#xff0c;概念上和链表有点类似&#xff0c;列表被用来跟踪 FreeRTOS中的任务。列表项就是存放在列表中的项目。 列表相当于链表&#xff0c;列表项相当于节点&#xff0c;FreeRTOS 中的列表是一个双向环形链表列表的…...

分布式图数据库 NebulaGraph v3.6.0 正式发布,强化全文索引能力

本次 v3.6.0 版本&#xff0c;主要强化全文索引能力&#xff0c;以及优化部分场景下的 MATCH 性能。 强化 强化增强全文索引功能&#xff0c;具体 pr 参见&#xff1a;#5567、#5575、#5577、#5580、#5584、#5587 优化 支持使用 MATCH 子句检索 VID 或属性索引时使用变量&am…...

在 ubuntu 18.04 上使用源码升级 OpenSSH_7.6p1到 OpenSSH_9.3p1

1、检查系统已安装的当前 SSH 版本 使用命令 ssh -V 查看当前 ssh 版本&#xff0c;输出如下&#xff1a; OpenSSH_7.6p1 Ubuntu-4ubuntu0.7, OpenSSL 1.0.2n 7 Dec 20172、安装依赖&#xff0c;依次执行以下命令 sudo apt update sudo apt install build-essential zlib1g…...

python中可以处理word文档的模块:docx模块

前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 话不多说&#xff0c;直接开搞&#xff0c;如果有什么疑惑/资料需要的可以点击文章末尾名片领取源码 一.docx模块 Python可以利用python-docx模块处理word文档&#xff0c;处理方式是面向对象的。 也就是说python-docx模块…...

测试微信模版消息推送

进入“开发接口管理”--“公众平台测试账号”&#xff0c;无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息&#xff1a; 关注测试号&#xff1a;扫二维码关注测试号。 发送模版消息&#xff1a; import requests da…...

《Playwright:微软的自动化测试工具详解》

Playwright 简介:声明内容来自网络&#xff0c;将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具&#xff0c;支持 Chrome、Firefox、Safari 等主流浏览器&#xff0c;提供多语言 API&#xff08;Python、JavaScript、Java、.NET&#xff09;。它的特点包括&a…...

【解密LSTM、GRU如何解决传统RNN梯度消失问题】

解密LSTM与GRU&#xff1a;如何让RNN变得更聪明&#xff1f; 在深度学习的世界里&#xff0c;循环神经网络&#xff08;RNN&#xff09;以其卓越的序列数据处理能力广泛应用于自然语言处理、时间序列预测等领域。然而&#xff0c;传统RNN存在的一个严重问题——梯度消失&#…...

屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!

5月28日&#xff0c;中天合创屋面分布式光伏发电项目顺利并网发电&#xff0c;该项目位于内蒙古自治区鄂尔多斯市乌审旗&#xff0c;项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站&#xff0c;总装机容量为9.96MWp。 项目投运后&#xff0c;每年可节约标煤3670…...

使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装

以下是基于 vant-ui&#xff08;适配 Vue2 版本 &#xff09;实现截图中照片上传预览、删除功能&#xff0c;并封装成可复用组件的完整代码&#xff0c;包含样式和逻辑实现&#xff0c;可直接在 Vue2 项目中使用&#xff1a; 1. 封装的图片上传组件 ImageUploader.vue <te…...

return this;返回的是谁

一个审批系统的示例来演示责任链模式的实现。假设公司需要处理不同金额的采购申请&#xff0c;不同级别的经理有不同的审批权限&#xff1a; // 抽象处理者&#xff1a;审批者 abstract class Approver {protected Approver successor; // 下一个处理者// 设置下一个处理者pub…...

GitHub 趋势日报 (2025年06月06日)

&#x1f4ca; 由 TrendForge 系统生成 | &#x1f310; https://trendforge.devlive.org/ &#x1f310; 本日报中的项目描述已自动翻译为中文 &#x1f4c8; 今日获星趋势图 今日获星趋势图 590 cognee 551 onlook 399 project-based-learning 348 build-your-own-x 320 ne…...

从“安全密码”到测试体系:Gitee Test 赋能关键领域软件质量保障

关键领域软件测试的"安全密码"&#xff1a;Gitee Test如何破解行业痛点 在数字化浪潮席卷全球的今天&#xff0c;软件系统已成为国家关键领域的"神经中枢"。从国防军工到能源电力&#xff0c;从金融交易到交通管控&#xff0c;这些关乎国计民生的关键领域…...

在 Visual Studio Code 中使用驭码 CodeRider 提升开发效率:以冒泡排序为例

目录 前言1 插件安装与配置1.1 安装驭码 CodeRider1.2 初始配置建议 2 示例代码&#xff1a;冒泡排序3 驭码 CodeRider 功能详解3.1 功能概览3.2 代码解释功能3.3 自动注释生成3.4 逻辑修改功能3.5 单元测试自动生成3.6 代码优化建议 4 驭码的实际应用建议5 常见问题与解决建议…...

FOPLP vs CoWoS

以下是 FOPLP&#xff08;Fan-out panel-level packaging 扇出型面板级封装&#xff09;与 CoWoS&#xff08;Chip on Wafer on Substrate&#xff09;两种先进封装技术的详细对比分析&#xff0c;涵盖技术原理、性能、成本、应用场景及市场趋势等维度&#xff1a; 一、技术原…...