生成式AI系列 —— DCGAN生成手写数字
1、模型构建
1.1 构建生成器
# 导入软件包
import torch
import torch.nn as nnclass Generator(nn.Module):def __init__(self, z_dim=20, image_size=256):super(Generator, self).__init__()self.layer1 = nn.Sequential(nn.ConvTranspose2d(z_dim, image_size * 32,kernel_size=4, stride=1),nn.BatchNorm2d(image_size * 32),nn.ReLU(inplace=True))self.layer2 = nn.Sequential(nn.ConvTranspose2d(image_size * 32, image_size * 16,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 16),nn.ReLU(inplace=True))self.layer3 = nn.Sequential(nn.ConvTranspose2d(image_size * 16, image_size * 8,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 8),nn.ReLU(inplace=True))self.layer4 = nn.Sequential(nn.ConvTranspose2d(image_size * 8, image_size *4,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 4),nn.ReLU(inplace=True))self.layer5 = nn.Sequential(nn.ConvTranspose2d(image_size * 4, image_size * 2,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 2),nn.ReLU(inplace=True))self.layer6 = nn.Sequential(nn.ConvTranspose2d(image_size * 2, image_size,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size),nn.ReLU(inplace=True))self.last = nn.Sequential(nn.ConvTranspose2d(image_size, 3, kernel_size=4,stride=2, padding=1),nn.Tanh())# 注意:因为是黑白图像,所以只有一个输出通道def forward(self, z):out = self.layer1(z)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)out = self.layer6(out)out = self.last(out)return outif __name__ == "__main__":import matplotlib.pyplot as pltG = Generator(z_dim=20, image_size=256)# 输入的随机数input_z = torch.randn(1, 20)# 将张量尺寸变形为(1,20,1,1)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)#输出假图像fake_images = G(input_z)print(fake_images.shape)img_transformed = fake_images[0].detach().numpy().transpose(1, 2, 0)plt.imshow(img_transformed)plt.show()
1.1 构建判别器
class Discriminator(nn.Module):def __init__(self, z_dim=20, image_size=256):super(Discriminator, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(3, image_size, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))#注意:因为是黑白图像,所以输入通道只有一个self.layer2 = nn.Sequential(nn.Conv2d(image_size, image_size*2, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer3 = nn.Sequential(nn.Conv2d(image_size*2, image_size*4, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer4 = nn.Sequential(nn.Conv2d(image_size*4, image_size*8, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer5 = nn.Sequential(nn.Conv2d(image_size*8, image_size*16, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer6 = nn.Sequential(nn.Conv2d(image_size*16, image_size*32, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.last = nn.Conv2d(image_size*32, 1, kernel_size=4, stride=1)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)out = self.layer6(out)out = self.last(out)return outif __name__ == "__main__":#确认程序执行D = Discriminator(z_dim=20, image_size=64)#生成伪造图像input_z = torch.randn(1, 20)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)#将伪造的图像输入判别器D中d_out = D(fake_images)#将输出值d_out乘以Sigmoid函数,将其转换成0~1的值print(torch.sigmoid(d_out))
2、数据集构建
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".")))
import time
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn as nnfrom torchvision import transforms
from model.DCGAN import Generator, Discriminator
from matplotlib import pyplot as pltdef make_datapath_list(root):"""创建用于学习和验证的图像数据及标注数据的文件路径列表。 """train_img_list = list() #保存图像文件的路径for img_idx in range(200):img_path = f"{root}/img_7_{str(img_idx)}.jpg"train_img_list.append(img_path)img_path = f"{root}/img_8_{str(img_idx)}.jpg"train_img_list.append(img_path)return train_img_listclass ImageTransform:"""图像的预处理类"""def __init__(self, mean, std):self.data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])def __call__(self, img):return self.data_transform(img)class GAN_Img_Dataset(data.Dataset):"""图像的 Dataset 类,继承自 PyTorchd 的 Dataset 类"""def __init__(self, file_list, transform):self.file_list = file_listself.transform = transformdef __len__(self):'''返回图像的张数'''return len(self.file_list)def __getitem__(self, index):'''获取经过预处理后的图像的张量格式的数据'''img_path = self.file_list[index]img = Image.open(img_path) # [ 高度 ][ 宽度 ] 黑白# 图像的预处理img_transformed = self.transform(img)return img_transformed# 创建DataLoader并确认执行结果# 创建文件列表
root = "./img_78"
train_img_list = make_datapath_list(root)# 创建Dataset
mean = (0.5)
std = (0.5)
train_dataset = GAN_Img_Dataset(file_list=train_img_list, transform=ImageTransform(mean, std)
)# 创建DataLoader
batch_size = 2
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True
)# 确认执行结果
batch_iterator = iter(train_dataloader) # 转换为迭代器
imges = next(batch_iterator) # 取出位于第一位的元素
print(imges.size()) # torch.Size([64, 1, 64, 64])
数据请在访问链接获取:
3、train接口实现
def train_model(G, D, dataloader, num_epochs):# 确认是否能够使用GPU加速device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("使用设备:", device)# 设置最优化算法g_lr, d_lr = 0.0001, 0.0004beta1, beta2 = 0.0, 0.9g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [beta1, beta2])d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [beta1, beta2])# 定义误差函数criterion = nn.BCEWithLogitsLoss(reduction='mean')# 使用硬编码的参数z_dim = 20mini_batch_size = 8# 将网络载入GPU中G.to(device)D.to(device)G.train() # 将模式设置为训练模式D.train() # 将模式设置为训练模式# 如果网络相对固定,则开启加速torch.backends.cudnn.benchmark = True# 图像张数num_train_imgs = len(dataloader.dataset)batch_size = dataloader.batch_size# 设置迭代计数器iteration = 1logs = []# epoch循环for epoch in range(num_epochs):# 保存开始时间t_epoch_start = time.time()epoch_g_loss = 0.0 # epoch的损失总和epoch_d_loss = 0.0 # epoch的损失总和print('-------------')print('Epoch {}/{}'.format(epoch, num_epochs))print('-------------')print('(train)')# 以minibatch为单位从数据加载器中读取数据的循环for imges in dataloader:# --------------------# 1.判别器D的学习# --------------------# 如果小批次的尺寸设置为1,会导致批次归一化处理产生错误,因此需要避免if imges.size()[0] == 1:continue# 如果能使用GPU,则将数据送入GPU中imges = imges.to(device)# 创建正确答案标签和伪造数据标签# 在epoch最后的迭代中,小批次的数量会减少mini_batch_size = imges.size()[0]label_real = torch.full((mini_batch_size,), 1).to(device)label_fake = torch.full((mini_batch_size,), 0).to(device)# 对真正的图像进行判定d_out_real = D(imges)# 生成伪造图像并进行判定input_z = torch.randn(mini_batch_size, z_dim).to(device)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)d_out_fake = D(fake_images)# 计算误差d_loss_real = criterion(d_out_real.view(-1), label_real.to(torch.float))d_loss_fake = criterion(d_out_fake.view(-1), label_fake.to(torch.float))d_loss = d_loss_real + d_loss_fake# 反向传播处理g_optimizer.zero_grad()d_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# --------------------# 2.生成器G的学习# --------------------# 生成伪造图像并进行判定input_z = torch.randn(mini_batch_size, z_dim).to(device)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)d_out_fake = D(fake_images)# 计算误差g_loss = criterion(d_out_fake.view(-1), label_real.to(torch.float))# 反向传播处理g_optimizer.zero_grad()d_optimizer.zero_grad()g_loss.backward()g_optimizer.step()# --------------------# 3.记录结果# --------------------epoch_d_loss += d_loss.item()epoch_g_loss += g_loss.item()iteration += 1# epoch的每个phase的loss和准确率t_epoch_finish = time.time()print('-------------')print('epoch {} || Epoch_D_Loss:{:.4f} ||Epoch_G_Loss:{:.4f}'.format(epoch, epoch_d_loss / batch_size, epoch_g_loss / batch_size))print('timer: {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))t_epoch_start = time.time()return G, D
4、训练
G = Generator(z_dim=20, image_size=64)
D = Discriminator(z_dim=20, image_size=64)
# 定义误差函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')
num_epochs = 200
G_update, D_update = train_model(G, D, dataloader=train_dataloader, num_epochs=num_epochs
)
5、测试
# 将生成的图像和训练数据可视化
# 反复执行本单元中的代码,直到生成感觉良好的图像为止device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 生成用于输入的随机数
batch_size = 8
z_dim = 20
fixed_z = torch.randn(batch_size, z_dim)
fixed_z = fixed_z.view(fixed_z.size(0), fixed_z.size(1), 1, 1)# 生成图像
fake_images = G_update(fixed_z.to(device))# 训练数据
imges = next(iter(train_dataloader)) # 取出位于第一位的元素# 输出结果
fig = plt.figure(figsize=(15, 6))
for i in range(0, 5):# 将训练数据放入上层plt.subplot(2, 5, i + 1)plt.imshow(imges[i][0].cpu().detach().numpy(), 'gray')# 将生成数据放入下层plt.subplot(2, 5, 5 + i + 1)plt.imshow(fake_images[i][0].cpu().detach().numpy(), 'gray')
相关文章:
生成式AI系列 —— DCGAN生成手写数字
1、模型构建 1.1 构建生成器 # 导入软件包 import torch import torch.nn as nnclass Generator(nn.Module):def __init__(self, z_dim20, image_size256):super(Generator, self).__init__()self.layer1 nn.Sequential(nn.ConvTranspose2d(z_dim, image_size * 32,kernel_s…...
vscode-vue项目格式化+语法检验-草稿
Vue学习笔记7 - 在Vscode中配置Vetur,ESlint,Prettier_vetur规则_Myron.Maoyz的博客-CSDN博客...
【Java从0到1学习】10 Java常用类汇总
1. System类 System类对读者来说并不陌生,因为在之前所学知识中,需要打印结果时,使用的都是“System.out.println();”语句,这句代码中就使用了System类。System类定义了一些与系统相关的属性和方法,它所提供的属性和…...
第三届人工智能与智能制造国际研讨会(AIIM 2023)
第三届人工智能与智能制造国际研讨会(AIIM 2023) The 3rd International Symposium on Artificial Intelligence and Intelligent Manufacturing 第三届人工智能与智能制造国际研讨会(AIIM 2023)将于2023年10月27-29日在成都召开…...
层次分析法
目录 一:问题的引入 二:模型的建立 1.分析系统中各因素之间的关系,建立系统的递阶层次结构。 2.对于同一层次的各元素关于上一层次中某一准则的重要性进行两两比较,构造两两比较矩阵(判断矩阵)。 3.由判…...
Error Handling
有几个特定的异常类允许用户代码对与CAN总线相关的特定场景做出反应: Exception (Python standard library)+-- ...+-- CanError (python-can)+-- CanInterfaceNotImplementedError+-- CanInitializationError...
leetcode:字符串相乘(两种方法)
题目: 给定两个以字符串形式表示的非负整数 num1 和 num2,返回 num1 和 num2 的乘积,它们的乘积也表示为字符串形式。 注意:不能使用任何内置的 BigInteger 库或直接将输入转换为整数。 示例 1: 输入: num1 "2", nu…...
【爬虫练习之glidedsky】爬虫-基础2
题目 链接 爬虫往往不能在一个页面里面获取全部想要的数据,需要访问大量的网页才能够完成任务。 这里有一个网站,还是求所有数字的和,只是这次分了1000页。 思路 找到调用接口 可以看到后面有个参数page来控制页码 代码实现 import reques…...
03.有监督算法——决策树
1.决策树算法 决策树算法可以做分类,也可以做回归 决策树的训练与测试: 训练阶段:从给定的训练集构造出一棵树(从根节点开始选择特征,如何进行特征切分) 测试阶段:根据构造出来的树模型从上…...
网络协议详解之STP
目录 一、STP协议(生成树) 1.1 生成树协议核心知识点: 1.2 生成树协议与导致问题: 生成树含义: 1.3 802.1D 规则: 802.1D 缺点: 1.4 PVST cisco私有 1.5 PVST 1.6 快速生成树 快速的原…...
Eltima USB Network Gate 10.0 Crack
USB Network Gate -通过网络共享USB 设备 USB Network Gate (前身为以太网USB控制器USB) 轻松的通过网络(Internet/LAN/WAN)分享您的一个或者多个连接到您计算机的USB设备。 无论您身处异国还是近在隔壁办公室,您都可以轻松使用远程扫描仪、打印机、摄像头、调制解…...
SpringCloudGateway网关实战(一)
SpringCloudGateway网关实战(一) 目前对cloud的gateway功能还是不太熟悉,因此特意新建了对应的应用来尝试网关功能。 网关模块搭建 首先我们新建一个父模块用于添加对应的springboot依赖和cloud依赖。本模块我们的配置读取使用的是nacos&a…...
django中使用ajax发送请求
1、ajax简单介绍 浏览器向网站发送请求时 是以URL和表单的形式提交的post 或get 请求,特点是:页面刷新 除此之外,也可以基于ajax向后台发送请求(异步) 依赖jQuery 编写ajax代码 $.ajax({url: "发送的地址"…...
C++之std::list<string>::iterator迭代器应用实例(一百七十九)
简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…...
VSCode好用的插件
文章目录 前言1.Snippet Creator & easy snippet(自定义代码)2.Indent Rainbow(代码缩进)3.Chinese (Simplified) Language Pack(中文包)4.Path Intellisense(路径提示)5.Beauti…...
js实现滚轮滑动到底部自动加载(完整版)
这里我们用vue实现(原生js相似), 这里我们用一个div当作一个容器; <div class="JL" @scroll="onScroll" ref="inin"> <div v-for="(item,index) in this.list" :key="index" > ....…...
如何限制PDF打印?限制清晰度?
想要限制PDF文件的打印功能,想要限制PDF文件打印清晰度,都可以通过设置限制编辑来达到目的。 打开PDF编辑器,找到设置限制编辑的界面,切换到加密状态,然后我们就看到 有印刷许可。勾选【权限密码】输入一个PDF密码&am…...
python计算模板图像与原图像各区域的相似度
目录 1、解释说明: 2、使用示例: 3、注意事项: 1、解释说明: 在Python中,我们可以使用OpenCV库进行图像处理和计算机视觉任务。其中,模板匹配是一种常见的方法,用于在一幅图像中识别出与给定…...
阿里云云解析DNS核心概念与应用
文章目录 1.DNS解析基本概念1.1.DNS基本介绍1.2.域名的分层结构1.3.DNS解析原理1.4.DNS递归查询和迭代查询的区别1.5.DNS常用的解析记录 2.使用DNS云解析将域名与SLB公网IP进行绑定2.1.进入云解析DNS控制台2.2.添加域名解析记录2.3.验证解析是否生效 1.DNS解析基本概念 DNS官方…...
计算机竞赛 垃圾邮件(短信)分类算法实现 机器学习 深度学习
文章目录 0 前言2 垃圾短信/邮件 分类算法 原理2.1 常用的分类器 - 贝叶斯分类器 3 数据集介绍4 数据预处理5 特征提取6 训练分类器7 综合测试结果8 其他模型方法9 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 垃圾邮件(短信)分类算…...
compositionAPI
面试题:composition api相比于option api有哪些优势? 不同于reactivity api,composition api提供的函数很多是与组件深度绑定的,不能脱离组件而存在。 1. setup // component export default {setup(props, context){// 该函数在…...
vscode配置调试环境-windows系统
1. 下载Vscode 下载网址code.visualstudio.com 2. 安装vscode 直打开下载好的.exe文件进行安装即可 3.安装插件 4下载mingw编译器 4.1下载 下载网址sourceforge.net/projects/mingw-w64/files/ 下拉找到该位置,下载圈中的版本。下载速度有点慢 临时下载地址 htt…...
智慧城市能实现嘛?数字孪生又在其中扮演什么角色?
数字孪生智慧城市是将数字孪生技术与城市智能化相结合的新兴概念,旨在通过实时数字模拟城市运行,优化城市管理与服务,创造更智能、高效、可持续的城市环境。 在智慧城市中,数字孪生技术可以实时收集、分析城市各个方面的数据&…...
【置顶帖】关于博主/关于博客/博客大事记
关于博主 ● 信息安全从业者 ● 注册信息安全认证专家资质 ● CSDN认证业界专家、安全博客专家 、全栈安全领域优质创作者 ● 中国信通院【2021-GOLF IT新治理领导力论坛】演讲嘉宾 ● 安世加【2021-EISS企业信息安全峰会-上海】演讲嘉宾 ● CSDN【2022-隐私计算论坛】演讲嘉宾…...
华为数通方向HCIP-DataCom H12-821题库(单选题:01-20)
第01题 下面关于OSPF邻居关系和邻接关系描述正确的是 A、邻接关系由 OSPF的 DD 报文维护 B、OSPF 路由器在交换 Hello 报文之前必须建立邻接关系 C、邻居关系是从邻接关系中选出的为了交换路由信息而形成的关系 D、并非所有的邻居关系都可以成为邻接关系 答案:D 解析…...
Java【手撕双指针】LeetCode 11. “盛水最多的容器“, 图文详解思路分析 + 代码
文章目录 前言一、盛水最多的容器1, 题目2, 思路分析3, 代码展示 前言 各位读者好, 我是小陈, 这是我的个人主页, 希望我的专栏能够帮助到你: 📕 JavaSE基础: 基础语法, 类和对象, 封装继承多态, 接口, 综合小练习图书管理系统等 📗 Java数据结构: 顺序表…...
vue3——递归组件的使用
该文章是在学习 小满vue3 课程的随堂记录示例均采用 <script setup>,且包含 typescript 的基础用法 一、使用场景 递归组件 的使用场景,如 无限级的菜单 ,接下来就用菜单的例子来学习 二、具体使用 先把菜单的基础内容写出来再说 父…...
【爬虫练习之glidedsky】爬虫-基础1
题目 链接 爬虫的目标很简单,就是拿到想要的数据。 这里有一个网站,里面有一些数字。把这些数字的总和,输入到答案框里面,即可通过本关。 思路 找到调用接口 分析response 代码实现 import re import requestsurl http://www.…...
计算机视觉入门 1)卷积分类器
目录 一、卷积分类器(The Convolutional Classifer)训练分类器 二、【代码示例】汽车卡车图片分类器步骤1. 导入数据步骤2 - 定义预训练模型步骤3 - 连接头部步骤4 - 训练模型 一、卷积分类器(The Convolutional Classifer) 卷积…...
SpringBoot 配置优先级
一般而言,SpringBoot支持配置文件进行配置,即在resources下的application.properties或application.yml。 关于配置优先级而言, application.properties>application.yml>application.yaml 另外JAVA程序程序还支持java系统配置和命令行…...
如果我的网站被百度收录了_以后如何做更新争取更多收录/网址导航该如何推广
父元素添加 font-size:0px; line-height:0px; overflow:hidden; 转载于:https://www.cnblogs.com/lweiruil/p/4505099.html...
写网站建设的软文/关键词推广方式
中新网1月16日电 16日,清华大学公益慈善研究院联合字节跳动发布《互联网生态公益方法论——2018年头条公益年度数据》(以下简称《2018年头条公益年度数据》),共同探讨移动互联网时代公益内容生产和传播的新方向,让信息更具公益价值。抖音#橙子…...
天津网站开发tjniu/百度快照怎么用
用Linux守护进程检测某个程序是否运行 本文博客链接:http://blog.csdn.net/jdh99,作者:jdh,转载请注明. 环境: 主机:Fedora12 目标板:SC6410 目标板LINUX内核版本:2.6.36 实现功能: 做的一个嵌入式板子开机会自启动一个程序&am…...
wordpress语音插件下载/seo技术培训广东
Springboot(以1.5.21版本为例)项目中,项目启动除了jvm的经典过程外,以下是Spring boot项目启动过程: org.springframework.boot.loader.JarLauncher中的main函数即为上一步jvm加载并执行的函数编写有SpringApplication…...
ftontpage如何做网站/搜索引擎优化好做吗
Google I/O大会上,Google向 Android 引入了新 App 动态化框架(Android App Bundle, AAB),被看作是对Android未来发展具有颠覆性的动态化解决方案。在给Android App带来便利的同时,也给移动安全领域带来了新的挑战&…...
中国建设业管理协会网站/整站seo定制
Linux LVM逻辑卷配置过程详解 许多Linux使用者安装操作系统时都会遇到这样的困境:如何精确评估和分配各个硬盘分区的容量,如果当初评估不准确,一旦系统分区不够用时可能不得不备份、删除相关数据,甚至被迫重新规划分区并重装操作系…...