当前位置: 首页 > news >正文

生成式AI系列 —— DCGAN生成手写数字

1、模型构建

1.1 构建生成器

# 导入软件包
import torch
import torch.nn as nnclass Generator(nn.Module):def __init__(self, z_dim=20, image_size=256):super(Generator, self).__init__()self.layer1 = nn.Sequential(nn.ConvTranspose2d(z_dim, image_size * 32,kernel_size=4, stride=1),nn.BatchNorm2d(image_size * 32),nn.ReLU(inplace=True))self.layer2 = nn.Sequential(nn.ConvTranspose2d(image_size * 32, image_size * 16,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 16),nn.ReLU(inplace=True))self.layer3 = nn.Sequential(nn.ConvTranspose2d(image_size * 16, image_size * 8,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 8),nn.ReLU(inplace=True))self.layer4 = nn.Sequential(nn.ConvTranspose2d(image_size * 8, image_size *4,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 4),nn.ReLU(inplace=True))self.layer5 = nn.Sequential(nn.ConvTranspose2d(image_size * 4, image_size * 2,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 2),nn.ReLU(inplace=True))self.layer6 = nn.Sequential(nn.ConvTranspose2d(image_size * 2, image_size,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size),nn.ReLU(inplace=True))self.last = nn.Sequential(nn.ConvTranspose2d(image_size, 3, kernel_size=4,stride=2, padding=1),nn.Tanh())# 注意:因为是黑白图像,所以只有一个输出通道def forward(self, z):out = self.layer1(z)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)out = self.layer6(out)out = self.last(out)return outif __name__ == "__main__":import matplotlib.pyplot as pltG = Generator(z_dim=20, image_size=256)# 输入的随机数input_z = torch.randn(1, 20)# 将张量尺寸变形为(1,20,1,1)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)#输出假图像fake_images = G(input_z)print(fake_images.shape)img_transformed = fake_images[0].detach().numpy().transpose(1, 2, 0)plt.imshow(img_transformed)plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-M0oWDbXr-1692468683782)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\Figure_1.png)]

1.1 构建判别器

class Discriminator(nn.Module):def __init__(self, z_dim=20, image_size=256):super(Discriminator, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(3, image_size, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))#注意:因为是黑白图像,所以输入通道只有一个self.layer2 = nn.Sequential(nn.Conv2d(image_size, image_size*2, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer3 = nn.Sequential(nn.Conv2d(image_size*2, image_size*4, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer4 = nn.Sequential(nn.Conv2d(image_size*4, image_size*8, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer5 = nn.Sequential(nn.Conv2d(image_size*8, image_size*16, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer6 = nn.Sequential(nn.Conv2d(image_size*16, image_size*32, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.last = nn.Conv2d(image_size*32, 1, kernel_size=4, stride=1)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)out = self.layer6(out)out = self.last(out)return outif __name__ == "__main__":#确认程序执行D = Discriminator(z_dim=20, image_size=64)#生成伪造图像input_z = torch.randn(1, 20)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)#将伪造的图像输入判别器D中d_out = D(fake_images)#将输出值d_out乘以Sigmoid函数,将其转换成0~1的值print(torch.sigmoid(d_out))

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SEGvF8xh-1692468683784)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\image-20230817224333376.png)]

2、数据集构建

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".")))
import time
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn as nnfrom torchvision import transforms
from model.DCGAN import Generator, Discriminator
from matplotlib import pyplot as pltdef make_datapath_list(root):"""创建用于学习和验证的图像数据及标注数据的文件路径列表。 """train_img_list = list() #保存图像文件的路径for img_idx in range(200):img_path = f"{root}/img_7_{str(img_idx)}.jpg"train_img_list.append(img_path)img_path = f"{root}/img_8_{str(img_idx)}.jpg"train_img_list.append(img_path)return train_img_listclass ImageTransform:"""图像的预处理类"""def __init__(self, mean, std):self.data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])def __call__(self, img):return self.data_transform(img)class GAN_Img_Dataset(data.Dataset):"""图像的 Dataset 类,继承自 PyTorchd 的 Dataset 类"""def __init__(self, file_list, transform):self.file_list = file_listself.transform = transformdef __len__(self):'''返回图像的张数'''return len(self.file_list)def __getitem__(self, index):'''获取经过预处理后的图像的张量格式的数据'''img_path = self.file_list[index]img = Image.open(img_path)  # [ 高度 ][ 宽度 ] 黑白# 图像的预处理img_transformed = self.transform(img)return img_transformed# 创建DataLoader并确认执行结果# 创建文件列表
root = "./img_78"
train_img_list = make_datapath_list(root)# 创建Dataset
mean = (0.5)
std = (0.5)
train_dataset = GAN_Img_Dataset(file_list=train_img_list, transform=ImageTransform(mean, std)
)# 创建DataLoader
batch_size = 2
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True
)# 确认执行结果
batch_iterator = iter(train_dataloader)  # 转换为迭代器
imges = next(batch_iterator)  # 取出位于第一位的元素
print(imges.size())  # torch.Size([64, 1, 64, 64])

数据请在访问链接获取:

3、train接口实现


def train_model(G, D, dataloader, num_epochs):# 确认是否能够使用GPU加速device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("使用设备:", device)# 设置最优化算法g_lr, d_lr = 0.0001, 0.0004beta1, beta2 = 0.0, 0.9g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [beta1, beta2])d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [beta1, beta2])# 定义误差函数criterion = nn.BCEWithLogitsLoss(reduction='mean')# 使用硬编码的参数z_dim = 20mini_batch_size = 8# 将网络载入GPU中G.to(device)D.to(device)G.train()  # 将模式设置为训练模式D.train()  # 将模式设置为训练模式# 如果网络相对固定,则开启加速torch.backends.cudnn.benchmark = True# 图像张数num_train_imgs = len(dataloader.dataset)batch_size = dataloader.batch_size# 设置迭代计数器iteration = 1logs = []# epoch循环for epoch in range(num_epochs):# 保存开始时间t_epoch_start = time.time()epoch_g_loss = 0.0  # epoch的损失总和epoch_d_loss = 0.0  # epoch的损失总和print('-------------')print('Epoch {}/{}'.format(epoch, num_epochs))print('-------------')print('(train)')# 以minibatch为单位从数据加载器中读取数据的循环for imges in dataloader:# --------------------# 1.判别器D的学习# --------------------# 如果小批次的尺寸设置为1,会导致批次归一化处理产生错误,因此需要避免if imges.size()[0] == 1:continue# 如果能使用GPU,则将数据送入GPU中imges = imges.to(device)# 创建正确答案标签和伪造数据标签# 在epoch最后的迭代中,小批次的数量会减少mini_batch_size = imges.size()[0]label_real = torch.full((mini_batch_size,), 1).to(device)label_fake = torch.full((mini_batch_size,), 0).to(device)# 对真正的图像进行判定d_out_real = D(imges)# 生成伪造图像并进行判定input_z = torch.randn(mini_batch_size, z_dim).to(device)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)d_out_fake = D(fake_images)# 计算误差d_loss_real = criterion(d_out_real.view(-1), label_real.to(torch.float))d_loss_fake = criterion(d_out_fake.view(-1), label_fake.to(torch.float))d_loss = d_loss_real + d_loss_fake# 反向传播处理g_optimizer.zero_grad()d_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# --------------------# 2.生成器G的学习# --------------------# 生成伪造图像并进行判定input_z = torch.randn(mini_batch_size, z_dim).to(device)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)d_out_fake = D(fake_images)# 计算误差g_loss = criterion(d_out_fake.view(-1), label_real.to(torch.float))# 反向传播处理g_optimizer.zero_grad()d_optimizer.zero_grad()g_loss.backward()g_optimizer.step()# --------------------# 3.记录结果# --------------------epoch_d_loss += d_loss.item()epoch_g_loss += g_loss.item()iteration += 1# epoch的每个phase的loss和准确率t_epoch_finish = time.time()print('-------------')print('epoch {} || Epoch_D_Loss:{:.4f} ||Epoch_G_Loss:{:.4f}'.format(epoch, epoch_d_loss / batch_size, epoch_g_loss / batch_size))print('timer:  {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))t_epoch_start = time.time()return G, D

4、训练


G = Generator(z_dim=20, image_size=64)
D = Discriminator(z_dim=20, image_size=64)
# 定义误差函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')
num_epochs = 200
G_update, D_update = train_model(G, D, dataloader=train_dataloader, num_epochs=num_epochs
)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EvqY6h3G-1692468683786)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\image-20230820020234209.png)]

5、测试

# 将生成的图像和训练数据可视化
# 反复执行本单元中的代码,直到生成感觉良好的图像为止device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 生成用于输入的随机数
batch_size = 8
z_dim = 20
fixed_z = torch.randn(batch_size, z_dim)
fixed_z = fixed_z.view(fixed_z.size(0), fixed_z.size(1), 1, 1)# 生成图像
fake_images = G_update(fixed_z.to(device))# 训练数据
imges = next(iter(train_dataloader))  # 取出位于第一位的元素# 输出结果
fig = plt.figure(figsize=(15, 6))
for i in range(0, 5):# 将训练数据放入上层plt.subplot(2, 5, i + 1)plt.imshow(imges[i][0].cpu().detach().numpy(), 'gray')# 将生成数据放入下层plt.subplot(2, 5, 5 + i + 1)plt.imshow(fake_images[i][0].cpu().detach().numpy(), 'gray')

在这里插入图片描述

相关文章:

生成式AI系列 —— DCGAN生成手写数字

1、模型构建 1.1 构建生成器 # 导入软件包 import torch import torch.nn as nnclass Generator(nn.Module):def __init__(self, z_dim20, image_size256):super(Generator, self).__init__()self.layer1 nn.Sequential(nn.ConvTranspose2d(z_dim, image_size * 32,kernel_s…...

vscode-vue项目格式化+语法检验-草稿

Vue学习笔记7 - 在Vscode中配置Vetur,ESlint,Prettier_vetur规则_Myron.Maoyz的博客-CSDN博客...

【Java从0到1学习】10 Java常用类汇总

1. System类 System类对读者来说并不陌生,因为在之前所学知识中,需要打印结果时,使用的都是“System.out.println();”语句,这句代码中就使用了System类。System类定义了一些与系统相关的属性和方法,它所提供的属性和…...

第三届人工智能与智能制造国际研讨会(AIIM 2023)

第三届人工智能与智能制造国际研讨会(AIIM 2023) The 3rd International Symposium on Artificial Intelligence and Intelligent Manufacturing 第三届人工智能与智能制造国际研讨会(AIIM 2023)将于2023年10月27-29日在成都召开…...

层次分析法

目录 一:问题的引入 二:模型的建立 1.分析系统中各因素之间的关系,建立系统的递阶层次结构。 2.对于同一层次的各元素关于上一层次中某一准则的重要性进行两两比较,构造两两比较矩阵(判断矩阵)。 3.由判…...

Error Handling

有几个特定的异常类允许用户代码对与CAN总线相关的特定场景做出反应: Exception (Python standard library)+-- ...+-- CanError (python-can)+-- CanInterfaceNotImplementedError+-- CanInitializationError...

leetcode:字符串相乘(两种方法)

题目: 给定两个以字符串形式表示的非负整数 num1 和 num2,返回 num1 和 num2 的乘积,它们的乘积也表示为字符串形式。 注意:不能使用任何内置的 BigInteger 库或直接将输入转换为整数。 示例 1: 输入: num1 "2", nu…...

【爬虫练习之glidedsky】爬虫-基础2

题目 链接 爬虫往往不能在一个页面里面获取全部想要的数据,需要访问大量的网页才能够完成任务。 这里有一个网站,还是求所有数字的和,只是这次分了1000页。 思路 找到调用接口 可以看到后面有个参数page来控制页码 代码实现 import reques…...

03.有监督算法——决策树

1.决策树算法 决策树算法可以做分类,也可以做回归 决策树的训练与测试: 训练阶段:从给定的训练集构造出一棵树(从根节点开始选择特征,如何进行特征切分) 测试阶段:根据构造出来的树模型从上…...

网络协议详解之STP

目录 一、STP协议(生成树) 1.1 生成树协议核心知识点: 1.2 生成树协议与导致问题: 生成树含义: 1.3 802.1D 规则: 802.1D 缺点: 1.4 PVST cisco私有 1.5 PVST 1.6 快速生成树 快速的原…...

Eltima USB Network Gate 10.0 Crack

USB Network Gate -通过网络共享USB 设备 USB Network Gate (前身为以太网USB控制器USB) 轻松的通过网络(Internet/LAN/WAN)分享您的一个或者多个连接到您计算机的USB设备。 无论您身处异国还是近在隔壁办公室,您都可以轻松使用远程扫描仪、打印机、摄像头、调制解…...

SpringCloudGateway网关实战(一)

SpringCloudGateway网关实战(一) 目前对cloud的gateway功能还是不太熟悉,因此特意新建了对应的应用来尝试网关功能。 网关模块搭建 首先我们新建一个父模块用于添加对应的springboot依赖和cloud依赖。本模块我们的配置读取使用的是nacos&a…...

django中使用ajax发送请求

1、ajax简单介绍 浏览器向网站发送请求时 是以URL和表单的形式提交的post 或get 请求,特点是:页面刷新 除此之外,也可以基于ajax向后台发送请求(异步) 依赖jQuery 编写ajax代码 $.ajax({url: "发送的地址"…...

C++之std::list<string>::iterator迭代器应用实例(一百七十九)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…...

VSCode好用的插件

文章目录 前言1.Snippet Creator & easy snippet(自定义代码)2.Indent Rainbow(代码缩进)3.Chinese (Simplified) Language Pack(中文包)4.Path Intellisense(路径提示)5.Beauti…...

js实现滚轮滑动到底部自动加载(完整版)

这里我们用vue实现(原生js相似), 这里我们用一个div当作一个容器; <div class="JL" @scroll="onScroll" ref="inin"> <div v-for="(item,index) in this.list" :key="index" > ....…...

如何限制PDF打印?限制清晰度?

想要限制PDF文件的打印功能&#xff0c;想要限制PDF文件打印清晰度&#xff0c;都可以通过设置限制编辑来达到目的。 打开PDF编辑器&#xff0c;找到设置限制编辑的界面&#xff0c;切换到加密状态&#xff0c;然后我们就看到 有印刷许可。勾选【权限密码】输入一个PDF密码&am…...

python计算模板图像与原图像各区域的相似度

目录 1、解释说明&#xff1a; 2、使用示例&#xff1a; 3、注意事项&#xff1a; 1、解释说明&#xff1a; 在Python中&#xff0c;我们可以使用OpenCV库进行图像处理和计算机视觉任务。其中&#xff0c;模板匹配是一种常见的方法&#xff0c;用于在一幅图像中识别出与给定…...

阿里云云解析DNS核心概念与应用

文章目录 1.DNS解析基本概念1.1.DNS基本介绍1.2.域名的分层结构1.3.DNS解析原理1.4.DNS递归查询和迭代查询的区别1.5.DNS常用的解析记录 2.使用DNS云解析将域名与SLB公网IP进行绑定2.1.进入云解析DNS控制台2.2.添加域名解析记录2.3.验证解析是否生效 1.DNS解析基本概念 DNS官方…...

计算机竞赛 垃圾邮件(短信)分类算法实现 机器学习 深度学习

文章目录 0 前言2 垃圾短信/邮件 分类算法 原理2.1 常用的分类器 - 贝叶斯分类器 3 数据集介绍4 数据预处理5 特征提取6 训练分类器7 综合测试结果8 其他模型方法9 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 垃圾邮件(短信)分类算…...

ESP32读取DHT11温湿度数据

芯片&#xff1a;ESP32 环境&#xff1a;Arduino 一、安装DHT11传感器库 红框的库&#xff0c;别安装错了 二、代码 注意&#xff0c;DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

【android bluetooth 框架分析 04】【bt-framework 层详解 1】【BluetoothProperties介绍】

1. BluetoothProperties介绍 libsysprop/srcs/android/sysprop/BluetoothProperties.sysprop BluetoothProperties.sysprop 是 Android AOSP 中的一种 系统属性定义文件&#xff08;System Property Definition File&#xff09;&#xff0c;用于声明和管理 Bluetooth 模块相…...

【HarmonyOS 5 开发速记】如何获取用户信息(头像/昵称/手机号)

1.获取 authorizationCode&#xff1a; 2.利用 authorizationCode 获取 accessToken&#xff1a;文档中心 3.获取手机&#xff1a;文档中心 4.获取昵称头像&#xff1a;文档中心 首先创建 request 若要获取手机号&#xff0c;scope必填 phone&#xff0c;permissions 必填 …...

稳定币的深度剖析与展望

一、引言 在当今数字化浪潮席卷全球的时代&#xff0c;加密货币作为一种新兴的金融现象&#xff0c;正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而&#xff0c;加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下&#xff0c;稳定…...

技术栈RabbitMq的介绍和使用

目录 1. 什么是消息队列&#xff1f;2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...

作为测试我们应该关注redis哪些方面

1、功能测试 数据结构操作&#xff1a;验证字符串、列表、哈希、集合和有序的基本操作是否正确 持久化&#xff1a;测试aof和aof持久化机制&#xff0c;确保数据在开启后正确恢复。 事务&#xff1a;检查事务的原子性和回滚机制。 发布订阅&#xff1a;确保消息正确传递。 2、性…...

Ubuntu系统复制(U盘-电脑硬盘)

所需环境 电脑自带硬盘&#xff1a;1块 (1T) U盘1&#xff1a;Ubuntu系统引导盘&#xff08;用于“U盘2”复制到“电脑自带硬盘”&#xff09; U盘2&#xff1a;Ubuntu系统盘&#xff08;1T&#xff0c;用于被复制&#xff09; &#xff01;&#xff01;&#xff01;建议“电脑…...

32单片机——基本定时器

STM32F103有众多的定时器&#xff0c;其中包括2个基本定时器&#xff08;TIM6和TIM7&#xff09;、4个通用定时器&#xff08;TIM2~TIM5&#xff09;、2个高级控制定时器&#xff08;TIM1和TIM8&#xff09;&#xff0c;这些定时器彼此完全独立&#xff0c;不共享任何资源 1、定…...

多元隐函数 偏导公式

我们来推导隐函数 z z ( x , y ) z z(x, y) zz(x,y) 的偏导公式&#xff0c;给定一个隐函数关系&#xff1a; F ( x , y , z ( x , y ) ) 0 F(x, y, z(x, y)) 0 F(x,y,z(x,y))0 &#x1f9e0; 目标&#xff1a; 求 ∂ z ∂ x \frac{\partial z}{\partial x} ∂x∂z​、 …...

云原生安全实战:API网关Envoy的鉴权与限流详解

&#x1f525;「炎码工坊」技术弹药已装填&#xff01; 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、基础概念 1. API网关 作为微服务架构的统一入口&#xff0c;负责路由转发、安全控制、流量管理等核心功能。 2. Envoy 由Lyft开源的高性能云原生…...