生成对抗网络入门案例
前言
生成对抗网络(Generative Adversarial Networks,简称GANs)是一种用于生成新样本的机器学习模型。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。生成器尝试生成与训练数据相似的新样本,而判别器则试图区分生成器生成的样本和真实训练数据。
下面是一个简单的对抗生成网络的入门例子,用于生成手写数字图像:
实现过程
1、导入必要的库和模块
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam
2、加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=3)
3、定义生成器模型
generator = Sequential()
generator.add(Dense(7*7*128, input_shape=(100,), activation='relu'))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))
generator.add(Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='sigmoid'))
4、定义判别器模型
discriminator = Sequential()
discriminator.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), activation='relu'))
discriminator.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))
5、编译判别器模型
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
6、冻结判别器模型的权重
discriminator.trainable = False
7、定义GAN模型
gan = Sequential()
gan.add(generator)
gan.add(discriminator)
8、编译GAN模型
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
9、定义训练函数
def train_gan(epochs, batch_size, sample_interval):for epoch in range(epochs):# 生成随机噪声作为输入noise = np.random.normal(0, 1, (batch_size, 100))# 生成假样本generated_images = generator.predict(noise)# 从真实样本中随机选择一批样本real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]# 训练判别器discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, 100))generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))# 打印损失if epoch % sample_interval == 0:print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")# 保存生成的图像save_images(epoch)
10、保存生成的图像
def save_images(epoch):rows, cols = 5, 5noise = np.random.normal(0, 1, (rows * cols, 100))generated_images = generator.predict(noise)generated_images = 0.5 * generated_images + 0.5fig, axs = plt.subplots(rows, cols)idx = 0for i in range(rows):for j in range(cols):axs[i, j].imshow(generated_images[idx, :, :, 0], cmap='gray')axs[i, j].axis('off')idx += 1fig.savefig(f"gan_images/mnist_{epoch}.png")plt.close()
11、训练GAN模型
epochs = 10000
batch_size = 128
sample_interval = 1000
完整代码
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam# 加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=3)# 定义生成器模型
generator = Sequential()
generator.add(Dense(7*7*128, input_shape=(100,), activation='relu'))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))
generator.add(Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='sigmoid'))# 定义判别器模型
discriminator = Sequential()
discriminator.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), activation='relu'))
discriminator.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))# 编译判别器模型
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])# 冻结判别器模型的权重
discriminator.trainable = False# 定义GAN模型
gan = Sequential()
gan.add(generator)
gan.add(discriminator)# 编译GAN模型
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))# 定义训练函数
def train_gan(epochs, batch_size, sample_interval):for epoch in range(epochs):# 生成随机噪声作为输入noise = np.random.normal(0, 1, (batch_size, 100))# 生成假样本generated_images = generator.predict(noise)# 从真实样本中随机选择一批样本real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]# 训练判别器discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, 100))generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))# 打印损失if epoch % sample_interval == 0:print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")# 保存生成的图像save_images(epoch)# 保存生成的图像
def save_images(epoch):rows, cols = 5, 5noise = np.random.normal(0, 1, (rows * cols, 100))generated_images = generator.predict(noise)generated_images = 0.5 * generated_images + 0.5fig, axs = plt.subplots(rows, cols)idx = 0for i in range(rows):for j in range(cols):axs[i, j].imshow(generated_images[idx, :, :, 0], cmap='gray')axs[i, j].axis('off')idx += 1fig.savefig(f"gan_images/mnist_{epoch}.png")plt.close()# 训练GAN模型
epochs = 10000
batch_size = 128
sample_interval = 1000train_gan(epochs, batch_size, sample_interval)
训练结果:

这个例子使用了MNIST数据集,生成手写数字图像。生成器和判别器模型使用了卷积神经网络的结构。在训练过程中,生成器试图生成逼真的手写数字图像,而判别器则试图区分真实图像和生成图像。通过反复迭代训练生成器和判别器,GAN模型能够逐渐生成更逼真的手写数字图像。生成的图像会保存在gan_images文件夹中。
相关文章:
生成对抗网络入门案例
前言 生成对抗网络(Generative Adversarial Networks,简称GANs)是一种用于生成新样本的机器学习模型。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。生成器尝试…...
多头注意力机制
1、什么是多头注意力机制 从多头注意力的结构图中,貌似这个所谓的多个头就是指多组线性变换,但是并不是,只使用了一组线性变换层,即三个变换张量对 Q、K、V 分别进行线性变换,这些变化不会改变原有张量的尺寸…...
Qt + FFmpeg 搭建 Windows 开发环境
Qt FFmpeg 搭建 Windows 开发环境 Qt FFmpeg 搭建 Windows 开发环境安装 Qt Creator下载 FFmpeg 编译包测试 Qt FFmpeg踩坑解决方法1:换一个 FFmpeg 库解决方法2:把项目改成 64 位 后记 官方博客:https://www.yafeilinux.com/ Qt开源社区…...
[网鼎杯 2020 白虎组]PicDown python反弹shell proc/self目录的信息
[网鼎杯 2020 白虎组]PicDown - 知乎 这里确实完全不会 第一次遇到一个只有文件读取思路的题目 这里也确实说明还是要学学一些其他的东西了 首先打开环境 只存在一个框框 我们通过 目录扫描 抓包 注入 发现没有用 我们测试能不能任意文件读取 ?url../../../../etc/passwd …...
SDL2绘制ffmpeg解析的mp4文件
文章目录 1.FFMPEG利用命令行将mp4转yuv4202.ffmpeg将mp4解析为yuv数据2.1 核心api: 3.SDL2进行yuv绘制到屏幕3.1 核心api 4.完整代码5.效果展示6.SDL2事件响应补充6.1 处理方式-016.2 处理方式-02 本项目采用生产者消费者模型,生产者线程:使用ffmpeg将m…...
决策树C4.5算法的技术深度剖析、实战解读
目录 一、简介决策树(Decision Tree)例子: 信息熵(Information Entropy)与信息增益(Information Gain)例子: 信息增益比(Gain Ratio)例子: 二、算…...
LLMs Python解释器程序辅助语言模型(PAL)Program-aided language models (PAL)
正如您在本课程早期看到的,LLM执行算术和其他数学运算的能力是有限的。虽然您可以尝试使用链式思维提示来克服这一问题,但它只能帮助您走得更远。即使模型正确地通过了问题的推理,对于较大的数字或复杂的运算,它仍可能在个别数学操…...
【12】c++设计模式——>单例模式练习(任务队列)
属性: (1)存储任务的容器,这个容器可以选择使用STL中的队列(queue) (2)互斥锁,多线程访问的时候用于保护任务队列中的数据 方法:主要是对任务队列中的任务进行操作 &…...
Python之函数、模块、包库
函数、模块、包库基础概念和作用 A、函数 减少代码重复 将复杂问题代码分解成简单模块 提高代码可读性 复用老代码 """ 函数 """# 定义一个函数 def my_fuvtion():# 函数执行部分print(这是一个函数)# 定义带有参数的函数 def say_hello(n…...
SQL创建与删除索引
索引创建、删除与使用: 1.1 create方式创建索引:CREATE [UNIQUE – 唯一索引 | FULLTEXT – 全文索引 ] INDEX index_name ON table_name – 不指定唯一或全文时默认普通索引 (column1[(length) [DESC|ASC]] [,column2,…]) – 可以对多列建立组合索引 …...
网络协议--链路层
2.1 引言 从图1-4中可以看出,在TCP/IP协议族中,链路层主要有三个目的: (1)为IP模块发送和接收IP数据报; (2)为ARP模块发送ARP请求和接收ARP应答; (3…...
HDLbits: Count clock
目前写过最长的verilog代码,用了将近三个小时,编写12h显示的时钟,改来改去,估计只有我自己看得懂(吐血) module top_module(input clk,input reset,input ena,output pm,output [7:0] hh,output [7:0] mm,…...
【1day】用友移动管理系统任意文件上传漏洞学习
注:该文章来自作者日常学习笔记,请勿利用文章内的相关技术从事非法测试,如因此产生的一切不良后果与作者无关。 目录 一、漏洞描述 二、影响版本 三、资产测绘 四、漏洞复现...
【c++】向webrtc学习容器操作
std::map的key为std::pair 时的查找 std::map<RemoteAndLocalNetworkId, size_t> in_flight_bytes_RTC_GUARDED_BY(&lock_);private:using RemoteAndLocalNetworkId = std::pair<uint16_t, uint16_t...
SpringBoot+Vue3外卖项目构思
SpringBoot的学习: SpringBoot的学习_明里灰的博客-CSDN博客 实现功能 前台 用户注册,邮箱登录,地址管理,历史订单,菜品规格,购物车,下单,菜品浏览,评价,…...
【AI视野·今日NLP 自然语言处理论文速览 第四十七期】Wed, 4 Oct 2023
AI视野今日CS.NLP 自然语言处理论文速览 Wed, 4 Oct 2023 Totally 73 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Contrastive Post-training Large Language Models on Data Curriculum Authors Canwen Xu, Corby Rosset, Luc…...
c++的lambda表达式
文章目录 1 lambda表达式2 捕捉列表 vs 参数列表3 lambda表达式的传递3.1 函数作为形参3.2 场景1:条件表达式3.3 场景2:线程的运行表达式 1 lambda表达式 lambda表达式可以理解为匿名函数,也就是没有名字的函数,既然是函数&#…...
电梯安全监测丨S271W无线水浸传感器用于电梯机房/电梯基坑水浸监测
城市化进程中,电梯与我们的生活息息相关。高层住宅、医院、商场、学校、车站等各种商业体建筑、公共建筑中电梯为我们生活工作提供了诸多便利。 保障电梯系统的安全至关重要!特别是电梯机房和电梯基坑可通过智能化改造提高其安全性和稳定性。例如在暴风…...
Java异常:基本概念、分类和处理
Java异常:基本概念、分类和处理 在Java编程中,异常处理是一个非常重要的部分。了解如何识别、处理和避免异常对于编写健壮、可维护的代码至关重要。本文将介绍Java异常的基本概念、分类和处理方法,并通过简单的代码示例进行说明。 一、什么…...
小谈设计模式(19)—备忘录模式
小谈设计模式(19)—备忘录模式 专栏介绍专栏地址专栏介绍 备忘录模式主要角色发起人(Originator)备忘录(Memento)管理者(Caretaker) 应用场景结构实现步骤Java程序实现首先ÿ…...
为什么Python社区推荐用pipx替代pip?以virtualenv安装为例演示工作流
为什么Python开发者应该用pipx替代pip?以virtualenv为例的完整隔离方案 当你在Ubuntu终端输入pip install virtualenv时,那个刺眼的externally-managed-environment错误提示就像一堵墙——这不是技术故障,而是Python生态进化的重要路标。传统…...
Bootstrap4 导航栏详解
Bootstrap4 导航栏详解 引言 Bootstrap 是一个流行的前端框架,它为开发者提供了丰富的组件和工具,以快速构建响应式、移动优先的网站和应用程序。导航栏是网站的重要组成部分,它能够帮助用户轻松地在网站的不同页面之间导航。Bootstrap4 提供…...
手把手教你用VSCode快速定位并修改RuoYi框架的页面标题和图标(避坑指南)
高效定制RuoYi前端界面:VSCode全局搜索实战指南 刚接触RuoYi框架的开发者常会遇到这样的困扰:想修改浏览器标签页标题或系统Logo,却不知从何下手。前后端分离的项目结构让配置文件散落在各处,而手动翻找无异于大海捞针。本文将带你…...
Z-Image-GGUF模型Java后端集成指南:SpringBoot微服务实战
Z-Image-GGUF模型Java后端集成指南:SpringBoot微服务实战 最近在做一个内容创作平台的后台重构,产品经理提了个需求,想给用户加个“AI一键生成文章配图”的功能。团队评估了几个方案,最终决定用Z-Image-GGUF这个模型,…...
开源工具优化Cursor API调用:突破限制提升开发效率的完整方案
开源工具优化Cursor API调用:突破限制提升开发效率的完整方案 【免费下载链接】cursor-free-vip [Support 0.45](Multi Language 多语言)自动注册 Cursor Ai ,自动重置机器ID , 免费升级使用Pro 功能: Youve reached y…...
SpringBoot微服务架构:集成AnythingtoRealCharacters2511实现分布式转换服务
SpringBoot微服务架构:集成AnythingtoRealCharacters2511实现分布式转换服务 1. 引言 想象一下,一个电商平台每天需要处理成千上万的动漫风格商品图片,想要将它们转换为真实人像风格来提升商品吸引力。传统方案要么依赖人工设计效率低下&am…...
G-Helper完全手册:华硕笔记本终极性能调优指南
G-Helper完全手册:华硕笔记本终极性能调优指南 【免费下载链接】g-helper Lightweight Armoury Crate alternative for Asus laptops. Control tool for ROG Zephyrus G14, G15, G16, M16, Flow X13, Flow X16, TUF, Strix, Scar and other models 项目地址: http…...
Polars 2.0大规模清洗性能翻倍的7个底层优化技巧:基于真实金融风控流水线压测数据
第一章:Polars 2.0大规模数据清洗性能跃迁的工程意义Polars 2.0 的发布标志着 Rust 原生 DataFrame 库在工程落地层面实现关键突破——其基于 Arrow 2.0 和全新查询优化器(QOv2)重构的执行引擎,将典型 ETL 清洗任务的吞吐量提升达…...
硬件工程师成长指南:从理论到实战的完整路径
1. 硬件工程师的成长路线:从理论到实践的完整规划作为一名从业十年的硬件工程师,我见过太多新人一上来就埋头焊板子、调电路,结果浪费大量时间在低水平重复。硬件设计就像下围棋,没有全局思维的人永远只能当个业余爱好者。今天我想…...
嵌入式系统内存泄漏防护与实战解决方案
1. 内存泄漏的危害与现状分析在嵌入式系统开发中,内存泄漏堪称"隐形杀手"。我经历过一个真实案例:某通信设备在现网运行三个月后频繁重启,最终定位是某个看似无害的日志处理函数每次调用泄漏512字节内存。这个案例让我深刻认识到&a…...
