当前位置: 首页 > news >正文

第G9周:ACGAN理论与实战

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

上一周已经给出代码,需要可以跳转上一周的任务
第G8周:ACGAN任务

import argparse
import os
import numpy as npimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch# 创建用于存储生成图像的目录
os.makedirs("images", exist_ok=True)# 解析命令行参数
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="训练的总轮数")
parser.add_argument("--batch_size", type=int, default=64, help="每个批次的大小")
parser.add_argument("--lr", type=float, default=0.0002, help="Adam优化器的学习率")
parser.add_argument("--b1", type=float, default=0.5, help="Adam优化器的一阶动量衰减")
parser.add_argument("--b2", type=float, default=0.999, help="Adam优化器的二阶动量衰减")
parser.add_argument("--n_cpu", type=int, default=4, help="用于批次生成的CPU线程数")
parser.add_argument("--latent_dim", type=int, default=100, help="潜在空间的维度")
parser.add_argument("--n_classes", type=int, default=10, help="数据集的类别数")
parser.add_argument("--img_size", type=int, default=32, help="每个图像的尺寸")
parser.add_argument("--channels", type=int, default=1, help="图像通道数")
parser.add_argument("--sample_interval", type=int, default=400, help="图像采样间隔")
opt = parser.parse_args()
print(opt)# 检查是否支持GPU加速
cuda = True if torch.cuda.is_available() else False# 初始化神经网络权重的函数
def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)# 生成器网络类
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 为类别标签创建嵌入层self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)# 计算上采样前的初始大小self.init_size = opt.img_size // 4  # Initial size before upsampling# 第一层线性层self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))# 卷积层块self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise, labels):# 将标签嵌入到噪声中gen_input = torch.mul(self.label_emb(labels), noise)# 通过第一层线性层out = self.l1(gen_input)# 重新整形为合适的形状out = out.view(out.shape[0], 128, self.init_size, self.init_size)# 通过卷积层块生成图像img = self.conv_blocks(out)return img# 判别器网络类
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 定义判别器块的函数def discriminator_block(in_filters, out_filters, bn=True):"""返回每个判别器块的层"""block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return block# 判别器的卷积层块self.conv_blocks = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# 下采样后图像的高度和宽度ds_size = opt.img_size // 2 ** 4# 输出层self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())def forward(self, img):out = self.conv_blocks(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)label = self.aux_layer(out)return validity, label# 损失函数
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()auxiliary_loss.cuda()# 初始化权重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 配置数据加载器
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor# 保存生成图像的函数
def sample_image(n_row, batches_done):"""保存从0到n_classes的生成数字的图像网格"""# 采样噪声z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))# 为n行生成标签从0到n_classeslabels = np.array([num for _ in range(n_row) for num in range(n_row)])labels = Variable(LongTensor(labels))gen_imgs = generator(z, labels)save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)# ----------
# 训练
# ----------for epoch in range(opt.n_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# 真实数据的标签valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)# 生成数据的标签fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)# 配置输入real_imgs = Variable(imgs.type(FloatTensor))labels = Variable(labels.type(LongTensor))# -----------------# 训练生成器# -----------------optimizer_G.zero_grad()# 采样噪声和标签作为生成器的输入z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))# 生成一批图像gen_imgs = generator(z, gen_labels)# 损失度量生成器的欺骗判别器的能力validity, pred_label = discriminator(gen_imgs)g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))g_loss.backward()optimizer_G.step()# ---------------------# 训练判别器# ---------------------optimizer_D.zero_grad()# 真实图像的损失real_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2# 生成图像的损失fake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2# 判别器的总损失d_loss = (d_real_loss + d_fake_loss) / 2# 计算判别器的准确率pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:sample_image(n_row=10, batches_done=batches_done)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

相关文章:

第G9周:ACGAN理论与实战

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制🚀 文章来源:K同学的学习圈子 上一周已经给出代码,需要可以跳转上一周的任务 第G8周:ACGAN任…...

Linux网络部分——DNS域名解析服务

目录 1. 域名结构 2. 系统根据域名查找IP地址的过程 3.DNS域名解析方式 4.DNS域名解析的工作原理【☆】 5.域名解析查询方式 6.搭建主从DNS域名服务器 ①初始化操作主服务器和从服务器,安装BIND软件 ②修改主服务器的主配置文件、区域配置文件、区域数…...

预处理详解

乐观学习,乐观生活,才能不断前进啊!!! 我的主页:optimistic_chen 我的专栏:c语言 点击主页:optimistic_chen和专栏:c语言, 创作不易,大佬们点赞鼓…...

Python的创建和使用自定义模块

Python 的模块是组织代码的基本单元,它可以包含变量、函数、类等,并且可以被其他 Python 程序引用和重用。除了使用 Python 提供的标准库和第三方库外,开发者还可以创建自定义模块,用于组织和管理自己的代码。本文将详细介绍如何创…...

Python根据预设txt生成“你画我猜”题目PPT(素拓活动小工具)

Python根据预设txt生成“你画我猜”题目PPT(素拓活动小工具) 场景来源 去年单位内部的一次素拓活动,分工负责策划设置其中的“你画我猜”环节,网络上搜集到题目文字后,想着如何快速做成对应一页一页的PPT。第一时间想…...

小程序地理位置接口权限直接抄作业

小程序地理位置接口有什么功能? 随着小程序生态的发展,越来越多的小程序开发者会通过官方提供的自带接口来给用户提供便捷的服务。但是当涉及到地理位置接口时,却经常遇到申请驳回的问题,反复修改也无法通过,给的理由也…...

【Osek网络管理测试】[TG3_TC6]等待总线睡眠状态_2

🙋‍♂️ 【Osek网络管理测试】系列💁‍♂️点击跳转 文章目录 1.环境搭建2.测试目的3.测试步骤4.预期结果5.测试结果 1.环境搭建 硬件:VN1630 软件:CANoe 2.测试目的 验证DUT在满足进入等待睡眠状态的条件时是否进入该状态 …...

BEV下统一的多传感器融合框架 - FUTR3D

BEV下统一的多传感器融合框架 - FUTR3D 引言 在自动驾驶汽车或者移动机器人上,通常会配备许多种传感器,比如:光学相机、激光雷达、毫米波雷达等。由于不同传感器的数据形式不同,如RGB图像,点云等,不同模态…...

c#和python的flask接口的交互

一、灰度图像的传输 c#端的传输 //读入文件夹中的图像 Mat img2 new Mat(file, ImreadModes.AnyColor); //将图像的数据转换成和相机相同的buffer数据 byte[] image_buffer new byte[img2.Width * img2.Height]; int cn img2.Channels(); //通道数 if (cn 1){//将图像的数…...

Python测试框架Pytest的参数化详解

上篇博文介绍过,Pytest是目前比较成熟功能齐全的测试框架,使用率肯定也不断攀升。 在实际工作中,许多测试用例都是类似的重复,一个个写最后代码会显得很冗余。这里,我们来了解一下pytest.mark.parametrize装饰器&…...

KernelSU 如何不通过模块,直接修改系统分区

刚刚看了术哥发的视频,发现kernelSU通过挂载OverlayFS实现无需模块,即可直接修改系统分区,很是方便,并且安全性也很高,于是便有了这篇文章。 下面的教程与原视频存在差异,建议观看原视频后再结合本文章进行操作。 在未进行修改前,我们打开/system/文件夹,并在里面创建…...

红日靶场ATTCK 1通关攻略

环境 拓扑图 VM1 web服务器 win7(192.168.22.129,10.10.10.140) VM2 win2003(10.10.10.135) VM3 DC win2008(10.10.10.138) 环境搭建 win7: 设置内网两张网卡,开启…...

CellMarker | 人骨骼肌组织细胞Marker大全!~(强烈建议火速收藏!)

1写在前面 分享一下最近看到的2篇paper关于骨骼肌组织的细胞Marker&#xff0c;绝对的Atlas级好东西。&#x1f44d; 希望做单细胞的小伙伴觉得有用哦。&#x1f60f; 2常用marker&#xff08;一&#xff09; general_mrkrs <- c( MYH7, TNNT1, TNNT3, MYH1, MYH2, "C…...

游戏名台词大赏

文章目录 原神&#xff08;圈内&#xff09; 崩坏&#xff1a;星穹铁道&#xff08;圈内&#xff09; 崩坏3&#xff08;圈内&#xff09; 原神 只要不失去你的崇高&#xff0c;整个世界都会为你敞开。 总会有地上的生灵&#xff0c;敢于直面雷霆的威光。 谁也没有见过风&…...

OpenCV如何在图像中寻找轮廓(60)

返回:OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇&#xff1a;OpenCV如何模板匹配(59) 下一篇 :OpenCV检测凸包(61) 目标 在本教程中&#xff0c;您将学习如何&#xff1a; 使用 OpenCV 函数 cv::findContours使用 OpenCV 函数 cv::d rawContours …...

java 泛型题目讲解

泛型的知识点 泛型仅存在于编译时期&#xff0c;编译期间JAVA将会使用Object类型代替泛型类型&#xff0c;在运行时期不存在泛型&#xff1b;且所有泛型实例共享一个泛型类 public class Main{public static void main(String[] args){ArrayList<String> list1new Arra…...

pptx 文件版面分析-- python-pptx(python 文档解析提取)

安装 pip install python-pptx -i https://pypi.tuna.tsinghua.edu.cn/simple --ignore-installedpptx 解析代码实现 from pptx import Presentation file_name "rag_pptx/test1.pptx" # 打开.pptx文件 ppt Presentation(file_name) for slide in ppt.slides:#pr…...

http的basic 认证方式

写在前面 本文看下http的basic auth认证方式。 1&#xff1a;什么是basic auth认证 basic auth是一种http协议规范中的一种认证方式&#xff0c;即一种证明你就是你的方式。更进一步的它是一种规范&#xff0c;这种规范是这样子&#xff0c;如果是服务端使用了basic auth认证…...

【信息系统项目管理师练习题】信息系统治理

IT治理的核心是关注以下哪项内容? a) 人员培训和发展计划 b) IT定位和信息化建设与数字化转型的责权利划分 c) 业务流程的绩效管理 d) IT基础设施的优化利用 答案: b) IT定位和信息化建设与数字化转型的责权利划分 IT治理体系框架的组成部分包括以下哪些? a) IT战略目标、IT治…...

RabbitMQ之顺序消费

什么是顺序消费 例如&#xff1a;业务上产生者发送三条消息&#xff0c; 分别是对同一条数据的增加、修改、删除操作&#xff0c; 如果没有保证顺序消费&#xff0c;执行顺序可能变成删除、修改、增加&#xff0c;这就乱了。 如何保证顺序性 一般我们讨论如何保证消息的顺序性&…...

轻松上手的LangChain学习说明书

一、Langchain是什么&#xff1f; 如今各类AI模型层出不穷&#xff0c;百花齐放&#xff0c;大佬们开发的速度永远遥遥领先于学习者的学习速度。。为了解放生产力&#xff0c;不让应用层开发人员受限于各语言模型的生产部署中…LangChain横空出世界。 Langchain可以说是现阶段…...

【论文笔记】Training language models to follow instructions with human feedback A部分

Training language models to follow instructions with human feedback A 部分 回顾一下第一代 GPT-1 &#xff1a; 设计思路是 “海量无标记文本进行无监督预训练少量有标签文本有监督微调” 范式&#xff1b;模型架构是基于 Transformer 的叠加解码器&#xff08;掩码自注意…...

嵌入式交叉编译:x265

下载 multicoreware / x265_git / Downloads — Bitbucket 解压编译 BUILD_DIR${HOME}/build_libs CROSS_NAMEaarch64-mix210-linuxcd build/aarch64-linuxmake cleancmake \-G "Unix Makefiles" \-DCMAKE_C_COMPILER${CROSS_NAME}-gcc \-DCMAKE_CXX_COMPILER${CR…...

一、Redis五种常用数据类型

Redis优势&#xff1a; 1、性能高—基于内存实现数据的存储 2、丰富的数据类型 5种常用&#xff0c;3种高级 3、原子—redis的所有单个操作都是原子性&#xff0c;即要么成功&#xff0c;要么失败。其多个操作也支持采用事务的方式实现原子性。 Redis特点&#xff1a; 1、支持…...

C语言动态内存管理malloc、calloc、realloc、free函数、内存泄漏、动态内存开辟的位置等的介绍

文章目录 前言一、为什么存在动态内存管理二、动态内存函数的介绍1. malloc函数2. 内存泄漏3. 动态内存开辟位置4. free函数5. calloc 函数6. realloc 函数7. realloc 传空指针 总结 前言 C语言动态内存管理malloc、calloc、realloc、free函数、内存泄漏、动态内存开辟的位置等…...

最近惊爆谷歌裁员

Python团队还没解散完&#xff0c;谷歌又对Flutter、Dart动手了。 什么原因呢&#xff0c;猜测啊。 谷歌裁员Python的具体原因可能是因为公司在进行技术栈的调整和优化。Python作为一种脚本语言&#xff0c;在某些情况下可能无法提供足够的性能或者扩展性&#xff0c;尤其是在…...

音频可视化:原生音频API为前端带来的全新可能!

音频API是一组提供给网页开发者的接口&#xff0c;允许他们直接在浏览器中处理音频内容。这些API使得在不依赖任何外部插件的情况下操作和控制音频成为可能。 Web Audio API 可以进行音频的播放、处理、合成以及分析等操作。借助于这些工具&#xff0c;开发者可以实现自定义的音…...

【中等】保研/考研408机试-动态规划1(01背包、完全背包、多重背包)

背包问题基本上都是模板题&#xff0c;重点&#xff1a;弄熟多重背包模板 dp[j]max(dp[j-v[i]]w[i],dp[j]) //核心思路代码&#xff08;一维数组版&#xff09; dp[i][j]max(dp[i-1][j], dp[i-1][j-v[i]]w[i])//二维数字版 一、 0-1背包 一般输入两个变量&#xff1a;体积&…...

[DEMO]给两个字符串取交集的词语

要求&#xff1a;2个英文字符串中&#xff0c;取相同的大于等于4个字母的词组 比如&#xff1a; 字符串1&#xff1a;" xingMeiLingabcdef WorldHello", 字符串2&#xff1a;"mnjqlup WorldLingLing xingMeiLingHello" 获取交接&#xff1a; [xingMeiLing…...

leetcode53-Maximum Subarray

题目 给你一个整数数组 nums &#xff0c;请你找出一个具有最大和的连续子数组&#xff08;子数组最少包含一个元素&#xff09;&#xff0c;返回其最大和。 子数组 是数组中的一个连续部分。 示例 1&#xff1a; 输入&#xff1a;nums [-2,1,-3,4,-1,2,1,-5,4] 输出&#xf…...

wordpress 首页图没了/石家庄最新消息

Java Web 知识点总结(一)(2012-03-10 18:57:59)转载▼标签&#xff1a; 服务器端sql语言客户端浏览器mysql数据库it 分类&#xff1a; java1、JavaEE的四层体系结构是什么&#xff1f;客户端层 web层(表现层) 业务逻辑层 数据层web层 请求、响应 ---- Servlet/JSP Struts业务层…...

asp.net做报名网站/公司网络推广营销

学习python一直是断断续续的&#xff0c;今天我们来介绍的是python的一个非常强大的模块---OS,我们来事例的时候不是用的标准的python&#xff0c;而是用的python的同胞兄弟Ipython&#xff0c;ipython 是一个 python 的交互式 shell&#xff0c;比默认的 python shell 好用得多…...

怎么自己做画册网站/长沙seo霜天

问题描述   旋转是图像处理的基本操作&#xff0c;在这个问题中&#xff0c;你需要将一个图像逆时针旋转90度。   计算机中的图像表示可以用一个矩阵来表示&#xff0c;为了旋转一个图像&#xff0c;只需要将对应的矩阵旋转即可。 输入格式   输入的第一行包含两个整数n,…...

wordpress 小蘑菇/效果最好的推广软件

机器学习的定义&#xff1a; 计算机程序从经验E中学习,解决某一任务T,进行某一性能度量P,通过P&#xff0c;测定在T上的表现因经验E而提高。 例如机器下棋和邮件分类。 &#xff08;1&#xff09;对于下棋而言&#xff0c; E&#xff1a;通过学习棋谱和模拟下棋&#xff0c…...

建立网站数据库实验报告/长沙网络推广服务

题目&#xff1a;原题链接&#xff08;中等&#xff09; 标签&#xff1a;回溯算法、位运算、递归 解法时间复杂度空间复杂度执行用时Ans 1 (Python)O(2N)O(2^N)O(2N)O(2N)O(2^N)O(2N)76ms (93.94%)Ans 2 (Python)Ans 3 (Python) 解法一&#xff1a; class Solution:def max…...

网络游戏有哪些/关键词排名优化工具有用吗

今天遇到一个问题就是ssh很久才能连上 于是打开命令行输入ssh rootip -v加上-v参数查看详细信息 发现在SSH2_MSG_SERVICE_ACCEPT received这一行卡了很久 经过百度是dns解析的问题 然后vim /etc/ssh/sshd 把这一行UseDNS yes的注释去掉&#xff0c;改成UseDNS no 然后保存就…...