昇思训练营打卡第二十一天(DCGAN生成漫画头像)
DCGAN,即深度卷积生成对抗网络(Deep Convolutional Generative Adversarial Network),是一种深度学习模型,由Ian Goodfellow等人在2014年提出。DCGAN在生成对抗网络(GAN)的基础上,引入了深度卷积神经网络(CNN)的结构,用于生成高质量、高分辨率的图像。
DCGAN的原理可以概括为两个主要部分:生成器(Generator)和判别器(Discriminator)。生成器和判别器都是深度卷积神经网络,它们通过对抗过程互相博弈,最终达到纳什均衡。
-
生成器(Generator):生成器的输入是一个随机的噪声向量,通过一系列的卷积、反卷积、批标准化(Batch Normalization)和激活函数(如ReLU、Tanh等)操作,生成一个与真实图像具有相同尺寸的图像。生成器的目标是生成尽可能逼真的图像,以欺骗判别器。
-
判别器(Discriminator):判别器的输入是一个图像,它通过一系列的卷积、批标准化和激活函数操作,判断输入图像是真实图像还是生成器生成的假图像。判别器的目标是能够准确地区分真实图像和假图像。
在训练过程中,生成器和判别器交替进行优化。生成器尝试生成逼真的图像,而判别器尝试更好地识别真实图像和假图像。这个过程可以看作是一种博弈,生成器和判别器在不断的迭代过程中提高自己的性能。最终,当生成器和判别器达到纳什均衡时,生成器能够生成高质量的逼真图像,判别器无法准确地区分真实图像和假图像。
DCGAN在计算机视觉领域有广泛的应用,如图像生成、图像修复、图像转换等。通过调整网络结构和训练策略,DCGAN还可以应用于其他领域,如自然语言处理、音频生成等。
数据准备与处理
%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
from download import downloadurl = "https://download.mindspore.cn/dataset/Faces/faces.zip"path = download(url, "./faces", kind="zip", replace=True)
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')
import matplotlib.pyplot as pltdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)
构造网络
生成器
import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()
判别器
class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()
模型训练
损失函数
# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')
优化器
# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')
训练模型
训练分为两个主要部分:训练判别器和训练生成器。
def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgs
import mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d] Loss_D:%7.4f Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()import matplotlib.pyplot as plt
import matplotlib.animation as animationdef showGif(image_list):show_list = []fig = plt.figure(figsize=(8, 3), dpi=120)for epoch in range(len(image_list)):images = []for i in range(3):row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)images.append(row)img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)plt.axis("off")show_list.append([plt.imshow(img)])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)ani.save('./dcgan.gif', writer='pillow', fps=1)showGif(image_list)
相关文章:
昇思训练营打卡第二十一天(DCGAN生成漫画头像)
DCGAN,即深度卷积生成对抗网络(Deep Convolutional Generative Adversarial Network),是一种深度学习模型,由Ian Goodfellow等人在2014年提出。DCGAN在生成对抗网络(GAN)的基础上,引…...
东方通Tongweb发布vue前端
一、前端包中添加文件 1、解压vue打包文件 以dist.zip为例,解压之后得到dist文件夹,进入dist文件夹,新建WEB-INF文件夹,进入WEB-INF文件夹,新建web.xml文件, 打开web.xml文件,输入以下内容 …...
spring xml实现bean对象(仅供自己参考)
对于spring xml来实现bean 具体代码: <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/beans"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaL…...
MiniGPT-Med 通用医学视觉大模型:生成医学报告 + 视觉问答 + 医学疾病识别
MiniGPT-Med 通用医学视觉大模型:生成医学报告 视觉问答 医学疾病识别 提出背景解法拆解 论文:https://arxiv.org/pdf/2407.04106 代码:https://github.com/Vision-CAIR/MiniGPT-Med 提出背景 近年来,人工智能(AI…...
如何判断ip地址在同一个网段:技术解析与实际应用
在网络世界中,IP地址就像每个人的身份证一样,是识别和定位网络设备的关键。然而,仅仅知道IP地址还不足以完全理解其背后的网络结构和通信方式。特别是当我们需要判断两个或多个IP地址是否位于同一网段时,就需要借助子网掩码这一概…...
linux高级编程(TCP)(传输控制协议)
TCP与UDP: TCP: TCP优点: 可靠,稳定 TCP的可靠体现在TCP在传递数据之前,会有三次握手来建立连接,而且在数据传递时,有确认、窗口、重传、拥塞控制机制,在数据传完后,还会断开连接用来节约系统…...
【常见开源库的二次开发】一文学懂CJSON
简介: JSON(JavaScript Object Notation)是一种轻量级的数据交换格式。它基于JavaScript的一个子集,但是JSON是独立于语言的,这意味着尽管JSON是由JavaScript语法衍生出来的,它可以被任何编程语言读取和生成…...
点云下采样有损压缩
转自本人博客:点云下采样有损压缩 点云下采样是通过一定规则对原点云数据进行再采样,减少点云个数,降低点云稀疏程度,减小点云数据大小。 1. 体素下采样(Voxel Down Sample) std::shared_ptr<PointClo…...
AutoHotKey自动热键(六)转义符号
转义符号 符号说明,, (原义的逗号). 注意: 在命令最后一个参数中的逗号不需要转义, 因为程序知道把它们作为原义处理. 对于 MsgBox 所有参数同样如此, 因为它会智能的处理逗号.%% (原义的百分号) (原义的重音符; 即两个连续的转义符产生单个原义字符);; (原义的分号). 注意: 仅…...
第16章 主成分分析:四个案例及课后习题
1.假设 x x x为 m m m 维随机变量,其均值为 μ \mu μ,协方差矩阵为 Σ \Sigma Σ。 考虑由 m m m维随机变量 x x x到 m m m维随机变量 y y y的线性变换 y i α i T x ∑ k 1 m α k i x k , i 1 , 2 , ⋯ , m y _ { i } \alpha _ { i } ^ { T } …...
股票分析系统设计方案大纲与细节
股票分析系统设计方案大纲与细节 一、引言 随着互联网和金融行业的迅猛发展,股票市场已成为重要的投资渠道。投资者在追求财富增值的过程中,对股票市场的分析和预测需求日益增加。因此,设计并实现一套高效、精准的股票分析系统显得尤为重要。本设计方案旨在提出一个基于大…...
.gitmodules文件
.gitmodules文件在Git仓库中的作用 .gitmodules 文件是 Git 版本控制系统中用来跟踪和管理子模块的配置文件。子模块允许你将一个 Git 仓库嵌套在另一个仓库中,这样可以方便地管理多个项目之间的依赖关系。 在 .gitmodules 文件中,通常会记录每个子模块…...
STM32 SPI世界:W25Q64 Flash存储器的硬件与软件集成策略
摘要 在嵌入式系统设计中,选择合适的存储解决方案对于确保数据的安全性和系统的可靠性至关重要。W25Q64 Flash存储器因其高性能和大容量成为STM32微控制器项目中的热门选择。本文将深入探讨STM32与W25Q64 Flash存储器的硬件连接、软件集成以及SPI通信的最佳实践。 …...
【计算机网络仿真】b站湖科大教书匠思科Packet Tracer——实验17 开放最短路径优先OSPF
一、实验目的 1.验证OSPF协议的作用; 二、实验要求 1.使用Cisco Packet Tracer仿真平台; 2.观看B站湖科大教书匠仿真实验视频,完成对应实验。 三、实验内容 1.构建网络拓扑; 2.验证OSPF协议的作用。 四、实验步骤 1.构建网…...
ChatGPT对话:python程序模拟操作网页弹出对话框
【编者按】单击一网页中的按钮,弹出对话框网页,再单击其中的“Yes”按钮,对话框关闭,请求并获取新网页。 可能ChatGPT第一次没有正确理解描述问题的含义,再次说明后,程序编写就正确了。 1问:pyt…...
利用亚马逊云科技云原生Serverless代码托管服务开发OpenAI ChatGPT-4o应用
今天小李哥继续介绍国际上主流云计算平台亚马逊云科技AWS上的热门生成式AI应用开发架构。上次小李哥分享了利用谷歌云serverless代码托管服务Cloud Functions构建Gemini Pro API,这次我将介绍如何利用亚马逊的云原生服务Lambda调用OpenAI的最新模型ChatGPT 4o。…...
Selenium 切换 frame/iframe
环境: Python 3.8 selenium3.141.0 urllib31.26.19说明: driver.switch_to.frame() # 将当前定位的主体切换为frame/iframe表单的内嵌页面中 driver.switch_to.default_content() # 跳回最外层的页面# 判断元素是否在 frame/ifame 中 # 126 邮箱为例 # …...
VOI(Virtual Operating System Infrastructure,虚拟操作系统基础架构)
VOI(Virtual Operating System Infrastructure,虚拟操作系统基础架构)架构在桌面虚拟化领域具有其独特的优势,使得它在某些场景下表现尤为出色。以下是几个具体场景: 1. 重载性能需求场景 表现: 高效利用…...
迭代器模式(大话设计模式)C/C++版本
迭代器模式 C #include <iostream> #include <string> #include <vector>using namespace std;// 迭代抽象类,用于定义得到开始对象、得到下一个对象、判断是否到结尾、当前对象等抽象方法,统一接口 class Iterator { public:Iterator(){};virtu…...
vue学习day04-计算属性、computed计算属性与methods方法、计算属性完整写法
10、计算属性 (1)概念: 基于现有的数据,计算出来的新属性。依赖于数据变化,自动重新计算。 (计算属性->可以将一段求值的代码进行封装) (2)语法: 1&a…...
Vue记事本应用实现教程
文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展:显示创建时间8. 功能扩展:记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...
ES6从入门到精通:前言
ES6简介 ES6(ECMAScript 2015)是JavaScript语言的重大更新,引入了许多新特性,包括语法糖、新数据类型、模块化支持等,显著提升了开发效率和代码可维护性。 核心知识点概览 变量声明 let 和 const 取代 var…...
盘古信息PCB行业解决方案:以全域场景重构,激活智造新未来
一、破局:PCB行业的时代之问 在数字经济蓬勃发展的浪潮中,PCB(印制电路板)作为 “电子产品之母”,其重要性愈发凸显。随着 5G、人工智能等新兴技术的加速渗透,PCB行业面临着前所未有的挑战与机遇。产品迭代…...
PHP和Node.js哪个更爽?
先说结论,rust完胜。 php:laravel,swoole,webman,最开始在苏宁的时候写了几年php,当时觉得php真的是世界上最好的语言,因为当初活在舒适圈里,不愿意跳出来,就好比当初活在…...
QMC5883L的驱动
简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...
Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务
通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输…...
(转)什么是DockerCompose?它有什么作用?
一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用,而无需手动一个个创建和运行容器。 Compose文件是一个文本文件,通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...
实现弹窗随键盘上移居中
实现弹窗随键盘上移的核心思路 在Android中,可以通过监听键盘的显示和隐藏事件,动态调整弹窗的位置。关键点在于获取键盘高度,并计算剩余屏幕空间以重新定位弹窗。 // 在Activity或Fragment中设置键盘监听 val rootView findViewById<V…...
【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统
目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索(基于物理空间 广播范围)2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...
Xen Server服务器释放磁盘空间
disk.sh #!/bin/bashcd /run/sr-mount/e54f0646-ae11-0457-b64f-eba4673b824c # 全部虚拟机物理磁盘文件存储 a$(ls -l | awk {print $NF} | cut -d. -f1) # 使用中的虚拟机物理磁盘文件 b$(xe vm-disk-list --multiple | grep uuid | awk {print $NF})printf "%s\n"…...
