当前位置: 首页 > news >正文

Pytorch Advanced(三) Neural Style Transfer

神经风格迁移在之前的博客中已经用keras实现过了,比较复杂,keras版本。

这里用pytorch重新实现一次,原理图如下:


from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as npdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

加载图像

def load_image(image_path, transform=None, max_size=None, shape=None):"""Load an image and convert it to a torch tensor."""image = Image.open(image_path)if max_size:scale = max_size / max(image.size)size = np.array(image.size) * scaleimage = image.resize(size.astype(int), Image.ANTIALIAS)if shape:image = image.resize(shape, Image.LANCZOS)if transform:image = transform(image).unsqueeze(0)return image.to(device)

这里用的模型是 VGG-19,所要用的是网络中的5个卷积层

class VGGNet(nn.Module):def __init__(self):"""Select conv1_1 ~ conv5_1 activation maps."""super(VGGNet, self).__init__()self.select = ['0', '5', '10', '19', '28'] self.vgg = models.vgg19(pretrained=True).featuresdef forward(self, x):"""Extract multiple convolutional feature maps."""features = []for name, layer in self.vgg._modules.items():x = layer(x)if name in self.select:features.append(x)return features

 模型结构如下,可以看到使用序列模型来写的VGG-NET,所以标号即层号,我们要保存的是['0', '5', '10', '19', '28'] 的输出结果。

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace)(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(17): ReLU(inplace)(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace)(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(24): ReLU(inplace)(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(26): ReLU(inplace)(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace)(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(31): ReLU(inplace)(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(33): ReLU(inplace)(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(35): ReLU(inplace)(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace)(2): Dropout(p=0.5)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace)(5): Dropout(p=0.5)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

 训练:

接下来对训练过程进行解释:

1、加载风格图像和内容图像,我们在之前的博客中使用的一幅加噪图进行训练,这里是用的内容图像的拷贝。

2、我们需要优化的就是作为目标的内容图像拷贝,可以看到target需要求导。

3、VGGnet参数是不需要优化的,所以设置为验证状态。

4、将3幅图像输入网络,得到总共15个输出(每个图像有5层的输出)

5、内容损失:这里是遍历5个层的输出来计算损失,而在keras版本中只用了第4层的输出计算损失

6、风格损失:同样计算格拉姆风格矩阵,将每一层的风格损失叠加,得到总的风格损失,计算公式同样和keras版本有所不一样

7、反向传播

def main(config):# Image preprocessing# VGGNet was trained on ImageNet where images are normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].# We use the same normalization statistics here.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])# Load content and style images# Make the style image same size as the content imagecontent = load_image(config.content, transform, max_size=config.max_size)style = load_image(config.style, transform, shape=[content.size(2), content.size(3)])# Initialize a target image with the content imagetarget = content.clone().requires_grad_(True)optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])vgg = VGGNet().to(device).eval()for step in range(config.total_step):# Extract multiple(5) conv feature vectorstarget_features = vgg(target)content_features = vgg(content)style_features = vgg(style)style_loss = 0content_loss = 0for f1, f2, f3 in zip(target_features, content_features, style_features):# Compute content loss with target and content imagescontent_loss += torch.mean((f1 - f2)**2)# Reshape convolutional feature maps_, c, h, w = f1.size()f1 = f1.view(c, h * w)f3 = f3.view(c, h * w)# Compute gram matrixf1 = torch.mm(f1, f1.t())f3 = torch.mm(f3, f3.t())# Compute style loss with target and style imagesstyle_loss += torch.mean((f1 - f3)**2) / (c * h * w) # Compute total loss, backprop and optimizeloss = content_loss + config.style_weight * style_loss optimizer.zero_grad()loss.backward()optimizer.step()if (step+1) % config.log_step == 0:print ('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}' .format(step+1, config.total_step, content_loss.item(), style_loss.item()))if (step+1) % config.sample_step == 0:# Save the generated imagedenorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))img = target.clone().squeeze()img = denorm(img).clamp_(0, 1)torchvision.utils.save_image(img, 'output-{}.png'.format(step+1))

写在if __name__=="__main__"后面的语句只会在本脚本中才能被执行,被调用时是不会被执行的。 

python的命令行工具:argparse,很优雅的添加参数

但是由于jupyter不支持添加外部参数,所以使用了外部博客的方法来支持(记住更改读取图片的位置)

import sys
if __name__ == "__main__":#解决方案来自于博客if '-f' in sys.argv:sys.argv.remove('-f')parser = argparse.ArgumentParser()parser.add_argument('--content', type=str, default='png/content.png')parser.add_argument('--style', type=str, default='png/style.png')parser.add_argument('--max_size', type=int, default=400)parser.add_argument('--total_step', type=int, default=2000)parser.add_argument('--log_step', type=int, default=10)parser.add_argument('--sample_step', type=int, default=500)parser.add_argument('--style_weight', type=float, default=100)parser.add_argument('--lr', type=float, default=0.003)#config = parser.parse_args()config = parser.parse_known_args()[0]   #参考博客 https://blog.csdn.net/ken_for_learning/article/details/89675904print(config)main(config)

相关文章:

Pytorch Advanced(三) Neural Style Transfer

神经风格迁移在之前的博客中已经用keras实现过了,比较复杂,keras版本。 这里用pytorch重新实现一次,原理图如下: from __future__ import division from torchvision import models from torchvision import transforms from PIL…...

英飞凌TC3xx--深度手撕HSM安全启动(三)--TC3xx HSM系统架构

今天聊TC3xx HSM系统,包括所用内核、UCB相关信息、Host和HSM交互方式。 1、HSM系统架构 下图来源于英飞凌官网培训材料。 TC3xx的HSM内核是一颗32位的ARM Cortex M3,主频可达100MHz,支持对称算法AES128、非对称算法PKC(Public Key Crypto) ECC256、Hash SHA2,以及T…...

黑马JVM总结(五)

(1)方法区 它是所有java虚拟机 线程共享的区,存储着跟类的结构相关的信息,类的成员变量,方法数据,成员方法,构造器方法,特殊方法(类的构造器) 方法区在虚拟机…...

C语言入门Day_18 判断和循坏的小结

目录 前言: 1.判断 2.循环 3.课堂笔记 4.思维导图 前言: 判断语句和循环语句都可以大致分为三个部分,第一个部分是固定的语法格式;第二部分是代码的执行顺序,第三部分是判断和循环成立与否的判断条件。 1.判断 1…...

mac 好用的工具推荐

mac 好用的工具推荐 落雪:全网的音乐畅听,下载地址:https://github.com/lyswhut/lx-music-desktopMotrix: 免费下载工具,下载地址:https://xclient.info/s/motrix.html#versionsDownie:视频下载工具&#x…...

星际争霸之小霸王之小蜜蜂(十二)--猫有九条命

系列文章目录 星际争霸之小霸王之小蜜蜂(十一)--杀杀杀 星际争霸之小霸王之小蜜蜂(十)--鼠道 星际争霸之小霸王之小蜜蜂(九)--狂鼠之灾 星际争霸之小霸王之小蜜蜂(八)--蓝皮鼠和大…...

【软件分析/静态分析】chapter8 课程11/12 指针分析—上下文敏感(Pointer Analysis - Context Sensitivity)

🔗 课程链接:李樾老师和谭天老师的: 南京大学《软件分析》课程11(Pointer Analysis - Context Sensitivity I)_哔哩哔哩_bilibili 南京大学《软件分析》课程12(Pointer Analysis - Context Sensitivity II&…...

时间复杂度与空间复杂度详解

时间复杂度与空间复杂度详解🦖 一、算法效率1.1 如何衡量一个算法的好坏1.2 算法的复杂度 二、时间复杂度2.1 时间复杂度的定义2.2 大O的渐进表示法2.3 如何记录表示算法复杂度 三、空间复杂度3.1 空间复杂度的定义3.2 小试牛刀 一、算法效率 1.1 如何衡量一个算法…...

目录操作函数

mkdir函数 rmdir函数 删除空目录 rename函数 换名 chdir函数 修改当前的工作目录 getcwd函数 获取当前工作的路径...

PlantUML入门教程:画时序图

软件工程中会用到各种UML图,例如用例图、时序图等。那我们能不能像写代码一样去画图呢? 今天推荐一款软件工程师的作图利器--PlantUML,它能让你用写代码的方式快速画出UML图。 一、什么是PlantUML? PlantUML是一个允许你快速作出…...

C#范围运算符

C#8.0语法中,范围运算符是一种用于快速截取序列的运算符,其语法为 “start…end”,表示从序列的 “start” 索引处开始,一直截取到"end" 索引处为止(包括 “end” 索引处的元素)。范围运算符主要…...

云数据库知识学习——云数据库产品、云数据库系统架构

一、云数据库产品 1.1、云数据库厂商概述 云数据库供应商主要分为三类。 ① 传统的数据库厂商,如 Teradata、Oracle、IBM DB2 和 Microsoft SQL Server 等。 ② 涉足数据库市场的云供应商,如 Amazon、Google、Yahoo!、阿里、百度、腾讯…...

C++中引用详解!

前言: 本文旨在讲解C中引用的相关操作,以及引用的一些注意事项!搬好小板凳,干货来了! 引用的概念 何谓引用呢?引用其实很容易理解,比如李华这个同学,他因为很调皮,所以…...

VUE3+TS项目无法找到模块“../version/version.js”的声明文件

问题描述 在导入 ../version/version.js 文件时,提示无法找到模块 解决方法 将version.js改为version.ts可以正常导入 注意,因为version.js是我自己写的模块,我可以直接该没有关系,但是如果是引入的其他的第三方包&#xff0c…...

数据结构-堆的实现及应用(堆排序和TOP-K问题)

数据结构-堆的实现及应用[堆排序和TOP-K问题] 一.堆的基本知识点1.知识点 二.堆的实现1.堆的结构2.向上调整算法与堆的插入2.向下调整算法与堆的删除 三.整体代码四.利用回调函数避免对向上和向下调整算法的修改1.向上调整算法的修改2.向下调整算法的修改3.插入元素和删除元素函…...

Spring 条件注解没生效?咋回事

条件注解相信各位小伙伴都用过,Spring 中的多环境配置 profile 底层就是通过条件注解来实现的,松哥在之前的 Spring 视频中也有和大家详细介绍过条件注解的使用,感兴趣的小伙伴戳这里:Spring源码应该怎么学?。 从 Spr…...

96. 不同的二叉搜索树

class Solution { public:int numTrees(int n) {if (n0) {return 1;}vector<int> dp(n1, 0);dp[0] 1;dp[1] 0;for (int i 1; i < n; i) {for (int j 0; j < i; j) {dp[i] dp[j] * dp[i - 1 - j];}}return dp[n];} };...

Android Jetpack 中Hilt的使用

Hilt 是 Android 的依赖项注入库&#xff0c;可减少在项目中执行手动依赖项注入的样板代码。执行 手动依赖项注入 要求您手动构造每个类及其依赖项&#xff0c;并借助容器重复使用和管理依赖项。 Hilt 通过为项目中的每个 Android 类提供容器并自动管理其生命周期&#xff0c;…...

批量采集的时间管理与优化

在进行大规模数据采集时&#xff0c;如何合理安排和管理爬取任务的时间成为了每个专业程序员需要面对的挑战。本文将分享一些关于批量采集中时间管理和优化方面的实用技巧&#xff0c;帮助你提升爬虫工作效率。 1. 制定明确目标并设置合适频率 首先要明确自己所需获取数据的范…...

uniApp监听左右滑动事件

监听左右滑动事件的步骤 1. 添加需要监听滑动事件的元素 在你的页面中&#xff0c;添加需要监听滑动事件的元素。这可以是一个 view、swiper 或其他组件&#xff0c;取决于你的需求。例如&#xff1a; <template><view class"body" touchstart"touc…...

十八、MySQL添加外键?

1、外键 外键是用来让两张表的数据之间建立联系&#xff0c;从而保证数据的一致性和完整性。 注意&#xff0c;父表被关联的字段类型&#xff0c;必须和子表被关联的字段类型一致。 2、实际操作 &#xff08;1&#xff09;初始化两张表格&#xff1a; 子表&#xff1a; 父…...

图像文件的操作MATLAB基础函数使用

简介 MATLAB中的图像处理工具箱体统了一套全方位的标准算法和图形工具&#xff0c;用于进行图像处理、分析、可视化和算法开发。这里仅仅对常用的基础函数做个使用介绍。 查询图像文件的信息 使用如下函数 imfinfo(filename,fmt) 函数imfinfo返回一个结构体的info&#xff…...

【k8s】Kubernetes版本v1.17.3 kubesphere 3.1.1 默认用户登录失败

1.发帖&#xff1a; Kubernetes版本v1.17.3 kubesphere 3.11 默认用户登录失败 - KubeSphere 开发者社区 2. 问题日志&#xff1a; 2.1问题排查方法 &#xff1a; 用户无法登录 http://192.168.56.100:30880/ 2.2查看用户状态 kubectl get users [rootk8s-node1 ~]# k…...

Mysql加密功能

Mysql加密功能 InnoDB加密功能查询条件问题开启整个数据库加密 InnoDB加密功能 InnoDB是MySQL数据库引擎的一种&#xff0c;它提供了加密存储的功能。具体来说&#xff0c;InnoDB引擎支持以下两种方式的加密存储&#xff1a; 表级加密&#xff1a;InnoDB支持表级加密&#xff…...

redis-win10安装和解决清缓存报错“Error: Protocol error, got “H“ as reply type byte”

win10安装 https://github.com/microsoftarchive/redis/releases 下载最新的zip&#xff0c;解压&#xff0c;把路径加到Path里&#xff0c;每次直接在cmd里 redis-server.exeError: Protocol error, got “H” as reply type byte 这个报错是因为我端口写错了。。无语 D:…...

【视觉检测】电源线圈上的导线弯直与否视觉检测系统软硬件方案

 检测内容 线圈上的导线弯直与否检测系统。  检测要求 检测线圈上的导线有无弯曲&#xff0c;弯曲度由客户自己设定。检测速度5K/8H625PCS/H。  视觉可行性分析 对样品进行了光学实验&#xff0c;并进行图像处理&#xff0c;原则上可以使用机器视觉进行测试测量…...

Java elasticsearch scroll模板实现

一、scroll说明和使用场景 scroll的使用场景&#xff1a;大数据量的检索和操作 scroll顾名思义&#xff0c;就是游标的意思&#xff0c;核心的应用场景就是遍历 elasticsearch中的数据&#xff1b; 通常我们遍历数据采用的是分页&#xff0c;elastcisearch还支持from size的…...

嵌入式基础知识-信息安全与加密

本篇来介绍计算机领域的信息安全以及加密相关基础知识&#xff0c;这些在嵌入式软件开发中也同样会用到。 1 信息安全 1.1 信息安全的基本要素 保密性&#xff1a;确保信息不被泄露给未授权的实体。包括最小授权原则、防暴露、信息加密、物理加密。完整性&#xff1a;保证数…...

TCP的三次握手与四次挥手

首先&#xff0c;源端口号和目标端口号是不可少的&#xff0c;这一点和 UDP 是一样的。如果没有这两个端口号。数据就不知道应该发给哪个应用。 接下来是包的序号。为什么要给包编号呢&#xff1f;当然是为了解决乱序的问题。不编好号怎么确认哪个应该先来&#xff0c;哪个应该…...

【Face Swapping综述】Quick Overview of Face Swap Deep Fakes

【Face Swapping综述】Quick Overview of Face Swap Deep Fakes 0、前言Abstract1. Introduction2. Face Swapping Process2.1. Preprocessing2.2. Identity Extraction2.3. Attributes Extractor2.4. Generator2.5. Postprocessing2.6. Evaluation Methods3. Challenges4. Con…...

专业做家具的网站/如何做网站优化seo

Word是我们常用的的办公软件&#xff0c;广泛被运用&#xff0c;那么我们怎么把Word转换为网页html格式&#xff1f; 文件&#xff1a;url80.ctfile.com/f/25127180-734987515-e5f15e?p551685 (访问密码: 551685) 需要软件&#xff1a; word2003 或 wps 个人建议用wps更方便…...

sae wordpress 3.9/建网站的详细步骤

注意&#xff1a;本文是在乌班图和Windows10环境下配置&#xff0c;Centos与乌班图略有不同&#xff0c;就是Centos的MySQL配置文件路径为/etc/my.cnf,其他操作一致1. 主从同步的定义主从同步使得数据可以从一个数据库服务器复制到其他服务器上&#xff0c;在复制数据时&#x…...

微网站后台录入/今日热点新闻事件2022

一、类加载器 ClassLoader 能根据需要将 class 文件加载到 JVM 中&#xff0c;它使用双亲委托模型&#xff0c;在加载类的时候会判断如果类未被自己加载过&#xff0c;就优先让父加载器加载。另外在使用 instanceof 关键字、equals()方法、isAssignableFrom()方法、isInstance(…...

wordpress在线客服插件/网络营销服务商有哪些

钩子方法 pytest_runtest_makereport 可以清晰的了解用例的执行过程&#xff0c;并获取到每个用例的执行结果。 钩子方法 pytest_runtest_makereport 源码&#xff1a; 按照执行顺序&#xff0c;具体过程如下&#xff1a; 1、先判断&#xff0c;当 report.when setup 时&…...

张店党风廉政建设网站/百度收录刷排名

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2020年美容师&#xff08;中级&#xff09;考试题及美容师&#xff08;中级&#xff09;多少分及格&#xff0c;包含美容师&#xff08;中级&#xff09;考试题答案和解析及美容师&#xff08;中级&#xff09;多少分…...

济南做网站哪好/郑州网络推广方案

CPU&#xff08;Central Processing Unit&#xff0c;中央处理器&#xff09;发展出来三个分枝&#xff0c;一个是DSP&#xff08;Digital Signal Processing/Processor&#xff0c;数字信号处理&#xff09;&#xff0c;另外两个是MCU&#xff08;Micro Control Unit&#xff…...