【域适应】基于域分离网络的MNIST数据10分类典型方法实现
关于
大规模数据收集和注释的成本通常使得将机器学习算法应用于新任务或数据集变得异常昂贵。规避这一成本的一种方法是在合成数据上训练模型,其中自动提供注释。尽管它们很有吸引力,但此类模型通常无法从合成图像推广到真实图像,因此需要域适应算法来操纵这些模型,然后才能成功应用。现有的方法要么侧重于将表示从一个域映射到另一个域,要么侧重于学习提取对于提取它们的域而言不变的特征。然而,通过只关注在两个域之间创建映射或共享表示,他们忽略了每个域的单独特征。域分离网络可以实现对每个域的独特之处进行特征建模,,同时进行模型域不变特征的提取。
参考文章: https://arxiv.org/abs/1608.06019
工具
方法实现
数据集定义
import torch.utils.data as data
from PIL import Image
import osclass GetLoader(data.Dataset):def __init__(self, data_root, data_list, transform=None):self.root = data_rootself.transform = transformf = open(data_list, 'r')data_list = f.readlines()f.close()self.n_data = len(data_list)self.img_paths = []self.img_labels = []for data in data_list:self.img_paths.append(data[:-3])self.img_labels.append(data[-2])def __getitem__(self, item):img_paths, labels = self.img_paths[item], self.img_labels[item]imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB')if self.transform is not None:imgs = self.transform(imgs)labels = int(labels)return imgs, labelsdef __len__(self):return self.n_data
模型搭建
import torch.nn as nn
from functions import ReverseLayerFclass DSN(nn.Module):def __init__(self, code_size=100, n_class=10):super(DSN, self).__init__()self.code_size = code_size########################################### private source encoder##########################################self.source_encoder_conv = nn.Sequential()self.source_encoder_conv.add_module('conv_pse1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,padding=2))self.source_encoder_conv.add_module('ac_pse1', nn.ReLU(True))self.source_encoder_conv.add_module('pool_pse1', nn.MaxPool2d(kernel_size=2, stride=2))self.source_encoder_conv.add_module('conv_pse2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5,padding=2))self.source_encoder_conv.add_module('ac_pse2', nn.ReLU(True))self.source_encoder_conv.add_module('pool_pse2', nn.MaxPool2d(kernel_size=2, stride=2))self.source_encoder_fc = nn.Sequential()self.source_encoder_fc.add_module('fc_pse3', nn.Linear(in_features=7 * 7 * 64, out_features=code_size))self.source_encoder_fc.add_module('ac_pse3', nn.ReLU(True))########################################## private target encoder#########################################self.target_encoder_conv = nn.Sequential()self.target_encoder_conv.add_module('conv_pte1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,padding=2))self.target_encoder_conv.add_module('ac_pte1', nn.ReLU(True))self.target_encoder_conv.add_module('pool_pte1', nn.MaxPool2d(kernel_size=2, stride=2))self.target_encoder_conv.add_module('conv_pte2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5,padding=2))self.target_encoder_conv.add_module('ac_pte2', nn.ReLU(True))self.target_encoder_conv.add_module('pool_pte2', nn.MaxPool2d(kernel_size=2, stride=2))self.target_encoder_fc = nn.Sequential()self.target_encoder_fc.add_module('fc_pte3', nn.Linear(in_features=7 * 7 * 64, out_features=code_size))self.target_encoder_fc.add_module('ac_pte3', nn.ReLU(True))################################# shared encoder (dann_mnist)################################self.shared_encoder_conv = nn.Sequential()self.shared_encoder_conv.add_module('conv_se1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,padding=2))self.shared_encoder_conv.add_module('ac_se1', nn.ReLU(True))self.shared_encoder_conv.add_module('pool_se1', nn.MaxPool2d(kernel_size=2, stride=2))self.shared_encoder_conv.add_module('conv_se2', nn.Conv2d(in_channels=32, out_channels=48, kernel_size=5,padding=2))self.shared_encoder_conv.add_module('ac_se2', nn.ReLU(True))self.shared_encoder_conv.add_module('pool_se2', nn.MaxPool2d(kernel_size=2, stride=2))self.shared_encoder_fc = nn.Sequential()self.shared_encoder_fc.add_module('fc_se3', nn.Linear(in_features=7 * 7 * 48, out_features=code_size))self.shared_encoder_fc.add_module('ac_se3', nn.ReLU(True))# classify 10 numbersself.shared_encoder_pred_class = nn.Sequential()self.shared_encoder_pred_class.add_module('fc_se4', nn.Linear(in_features=code_size, out_features=100))self.shared_encoder_pred_class.add_module('relu_se4', nn.ReLU(True))self.shared_encoder_pred_class.add_module('fc_se5', nn.Linear(in_features=100, out_features=n_class))self.shared_encoder_pred_domain = nn.Sequential()self.shared_encoder_pred_domain.add_module('fc_se6', nn.Linear(in_features=100, out_features=100))self.shared_encoder_pred_domain.add_module('relu_se6', nn.ReLU(True))# classify two domainself.shared_encoder_pred_domain.add_module('fc_se7', nn.Linear(in_features=100, out_features=2))####################################### shared decoder (small decoder)######################################self.shared_decoder_fc = nn.Sequential()self.shared_decoder_fc.add_module('fc_sd1', nn.Linear(in_features=code_size, out_features=588))self.shared_decoder_fc.add_module('relu_sd1', nn.ReLU(True))self.shared_decoder_conv = nn.Sequential()self.shared_decoder_conv.add_module('conv_sd2', nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5,padding=2))self.shared_decoder_conv.add_module('relu_sd2', nn.ReLU())self.shared_decoder_conv.add_module('conv_sd3', nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5,padding=2))self.shared_decoder_conv.add_module('relu_sd3', nn.ReLU())self.shared_decoder_conv.add_module('us_sd4', nn.Upsample(scale_factor=2))self.shared_decoder_conv.add_module('conv_sd5', nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3,padding=1))self.shared_decoder_conv.add_module('relu_sd5', nn.ReLU(True))self.shared_decoder_conv.add_module('conv_sd6', nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3,padding=1))def forward(self, input_data, mode, rec_scheme, p=0.0):result = []if mode == 'source':# source private encoderprivate_feat = self.source_encoder_conv(input_data)private_feat = private_feat.view(-1, 64 * 7 * 7)private_code = self.source_encoder_fc(private_feat)elif mode == 'target':# target private encoderprivate_feat = self.target_encoder_conv(input_data)private_feat = private_feat.view(-1, 64 * 7 * 7)private_code = self.target_encoder_fc(private_feat)result.append(private_code)# shared encodershared_feat = self.shared_encoder_conv(input_data)shared_feat = shared_feat.view(-1, 48 * 7 * 7)shared_code = self.shared_encoder_fc(shared_feat)result.append(shared_code)reversed_shared_code = ReverseLayerF.apply(shared_code, p)domain_label = self.shared_encoder_pred_domain(reversed_shared_code)result.append(domain_label)if mode == 'source':class_label = self.shared_encoder_pred_class(shared_code)result.append(class_label)# shared decoderif rec_scheme == 'share':union_code = shared_codeelif rec_scheme == 'all':union_code = private_code + shared_codeelif rec_scheme == 'private':union_code = private_coderec_vec = self.shared_decoder_fc(union_code)rec_vec = rec_vec.view(-1, 3, 14, 14)rec_code = self.shared_decoder_conv(rec_vec)result.append(rec_code)return result
模型训练
import random
import os
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import numpy as np
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms
from model_compat import DSN
from data_loader import GetLoader
from functions import SIMSE, DiffLoss, MSE
from test import test######################
# params #
######################source_image_root = os.path.join('.', 'dataset', 'mnist')
target_image_root = os.path.join('.', 'dataset', 'mnist_m')
model_root = 'model'
cuda = True
cudnn.benchmark = True
lr = 1e-2
batch_size = 32
image_size = 28
n_epoch = 100
step_decay_weight = 0.95
lr_decay_step = 20000
active_domain_loss_step = 10000
weight_decay = 1e-6
alpha_weight = 0.01
beta_weight = 0.075
gamma_weight = 0.25
momentum = 0.9manual_seed = random.randint(1, 10000)
random.seed(manual_seed)
torch.manual_seed(manual_seed)#######################
# load data #
#######################img_transform = transforms.Compose([transforms.Resize(image_size),transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])dataset_source = datasets.MNIST(root=source_image_root,train=True,transform=img_transform
)dataloader_source = torch.utils.data.DataLoader(dataset=dataset_source,batch_size=batch_size,shuffle=True,num_workers=8
)train_list = os.path.join(target_image_root, 'mnist_m_train_labels.txt')dataset_target = GetLoader(data_root=os.path.join(target_image_root, 'mnist_m_train'),data_list=train_list,transform=img_transform
)dataloader_target = torch.utils.data.DataLoader(dataset=dataset_target,batch_size=batch_size,shuffle=True,num_workers=8
)#####################
# load model #
#####################my_net = DSN()#####################
# setup optimizer #
#####################def exp_lr_scheduler(optimizer, step, init_lr=lr, lr_decay_step=lr_decay_step, step_decay_weight=step_decay_weight):# Decay learning rate by a factor of step_decay_weight every lr_decay_stepcurrent_lr = init_lr * (step_decay_weight ** (step / lr_decay_step))if step % lr_decay_step == 0:print 'learning rate is set to %f' % current_lrfor param_group in optimizer.param_groups:param_group['lr'] = current_lrreturn optimizeroptimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)loss_classification = torch.nn.CrossEntropyLoss()
loss_recon1 = MSE()
loss_recon2 = SIMSE()
loss_diff = DiffLoss()
loss_similarity = torch.nn.CrossEntropyLoss()if cuda:my_net = my_net.cuda()loss_classification = loss_classification.cuda()loss_recon1 = loss_recon1.cuda()loss_recon2 = loss_recon2.cuda()loss_diff = loss_diff.cuda()loss_similarity = loss_similarity.cuda()for p in my_net.parameters():p.requires_grad = True#############################
# training network #
#############################
MNIST数据重建/共有部分特征/私有数据特征可视化
MNIST_m数据重建/共有部分特征/私有数据特征可视化
代码获取
相关问题和项目开发,欢迎私信交流和沟通。
相关文章:
【域适应】基于域分离网络的MNIST数据10分类典型方法实现
关于 大规模数据收集和注释的成本通常使得将机器学习算法应用于新任务或数据集变得异常昂贵。规避这一成本的一种方法是在合成数据上训练模型,其中自动提供注释。尽管它们很有吸引力,但此类模型通常无法从合成图像推广到真实图像,因此需要域…...
从零实现诗词GPT大模型:pytorch框架介绍
专栏规划: https://qibin.blog.csdn.net/article/details/137728228 因为咱们本系列文章主要基于深度学习框架pytorch进行,所以在正式开始之前,现对pytorch框架进行一个简单的介绍,主要面对深度学习或者pytorch还不熟悉的朋友。 一、安装pytorch 这一步很简单,主要通过p…...
[目标检测] OCR: 文字检测、文字识别、text spotter
概述 OCR技术存在两个步骤:文字检测和文字识别,而end-to-end完成这两个步骤的方法就是text spotter。 文字检测数据集摘要 daaset语言体量特色MTWI中英文20k源于网络图像,主要由合成图像,产品描述,网络广告(淘宝)MS…...
Windows环境下删除MySQL
文章目录 一、关闭MySQL服务1、winR打开运行,输入services.msc回车2、服务里找到MySQL并停止 二、卸载MySQL软件1、打开控制模板--卸载程序--卸载MySQL相关的所有组件 三、删除MySQL在物理硬盘上的所有文件1、删除MySQL的安装目录(默认在C盘下的Program …...
uniapp:uview-plus的一些记录
customStyle 并不是所有的组件都有customStyle属性来设置自定义属性,有的还是需要通过::v-deep来修改内置样式 form表单 labelStyle 需要的是一个对象 :labelStyle"{color: #333333,fontSize: 32rpx,fontWeight: 500}"dateTimePicker选择器设置默认值…...
OLTP 与 OLAP 系统说明对比和大数据经典架构 Lambda 和 Kappa 说明对比——解读大数据架构(五)
文章目录 前言OLTP 和 OLAPSMP 和 MPPlambda 架构Kappa 架构 前言 本文我们将研究不同类型的大数据架构设计,将讨论 OLTP 和 OLAP 的系统设计,以及有效处理数据的策略包括 SMP 和 MPP 等概念。然后我们将了解经典的 Lambda 架构和 Kappa 架构。 OLTP …...
步骤大全:网站建设3个基本流程详解
一.领取一个免费域名和SSL证书,和CDN 1.打开网站链接:https://www.rainyun.com/z22_ 2.在网站主页上,您会看到一个"登陆/注册"的选项。 3.点击"登陆/注册",然后选择"微信登录"选项。 4.使用您的…...
利用Sentinel解决雪崩问题(二)隔离和降级
前言: 虽然限流可以尽量避免因高并发而引起的服务故障,但服务还会因为其它原因而故障。而要将这些故障控制在一定范围避免雪崩,就要靠线程隔离(舱壁模式)和熔断降级手段了,不管是线程隔离还是熔断降级,都是对客户端(调…...
基于springboot的房产销售系统源码数据库
基于springboot的房产销售系统源码数据库 摘 要 随着科学技术的飞速发展,各行各业都在努力与现代先进技术接轨,通过科技手段提高自身的优势;对于房产销售系统当然也不能排除在外,随着网络技术的不断成熟,带动了房产…...
【MATLAB】基于Wi-Fi指纹匹配的室内定位-仿真获取WiFi RSSI数据(附代码)
基于Wi-Fi指纹匹配的室内定位-仿真获取WiFi RSSI数据 WiFi指纹匹配是室内定位最为基础和常见的研究,但是WiFi指纹的采集可以称得上是labor-intensive和time-consuming。现在,给大家分享一下我们课题组之前在做WiFi指纹定位时的基于射线跟踪技术仿真WiFi…...
深圳晶彩智能ESP32-3248S035R使用LovyanGFX实现手写板
深圳晶彩智能ESP32-3248S035R介绍 深圳晶彩智能出品ESP32-3248S035R为3.5寸彩色屏采用分辨率480x320彩色液晶屏,驱动芯片是ST7796。板载乐鑫公司出品ESP-WROOM-32,Flash 4M。型号尾部“R”标识电阻膜的感压式触摸屏,驱动芯片是XPT2046。 Lo…...
【Spring Boot】深入解密Spring Boot日志:最佳实践与策略解析
💓 博客主页:从零开始的-CodeNinja之路 ⏩ 收录文章:【Spring Boot】深入解密Spring Boot日志:最佳实践与策略解析 🎉欢迎大家点赞👍评论📝收藏⭐文章 目录 Spring Boot 日志一. 日志的概念?…...
ISTQB选择国内版,还是国际版呢
1, ISTQB简介 ISTQB(International Software Testing Qualifications Board)是一个国际软件测试资格认证机构,旨在提供一个统一的软件测试认证标准。ISTQB成立于2002年,是非盈利性的组织,由世界各地的国家或地区软件测…...
头歌-机器学习 第11次实验 softmax回归
第1关:softmax回归原理 任务描述 本关任务:使用Python实现softmax函数。 相关知识 为了完成本关任务,你需要掌握:1.softmax回归原理,2.softmax函数。 softmax回归原理 与逻辑回归一样,softmax回归同样…...
Qt for MCUs 2.7正式发布
本文翻译自:Qt for MCUs 2.7 released 原文作者:Qt Group高级产品经理Yoann Lopes 翻译:Macsen Wang Qt for MCUs的新版本已发布,为Qt Quick Ultralite引擎带来了新功能,增加了更多MCU平台的支持,并且我们…...
共享IP和独享IP如何选择,两者有何区别?
有跨境用户在选择共享IP和独享IP时会有疑问,不知道该如何进行选择,共享IP和独享IP各有其特点和应用场景,选择哪种方式主要取决于具体需求和预算。以下是对两者的详细比较: 首先两者的主要区别在于使用方式和安全性:共…...
文心一言VSchatGPT4
文心一言和GPT-4各有优势,具体表现在不同的测试场景下。 在某些测试场景中心一言的表现优于GPT-4,例如在故事的完整度和情节吸引力方面,文心一言表现得更加符合指令,情节更吸引人。这可能得益于其模型在训练时对中文语境的深入理…...
Linux 目录结构与基础查看命令
介绍 目录结构如下 /bin:存放着用户最经常使用的二进制可执行命令,如cp、ls、cat等。这些命令是系统管理员和普通用户进行日常操作所必需的。 /boot:存放启动系统使用的一些核心文件,如引导加载器(bootstrap loader…...
【matlab】如何解决打开缓慢问题(如何让matlab在十几秒内打开)
【matlab】如何解决打开缓慢问题(如何让matlab在十几秒内打开) 找到我们解压缩时Crack中的license_standalone.lic文件,将其拷贝 在安装matlab的路径下新建一个文件,粘贴上面的license_standalone.lic文件 在桌面鼠标移动到matl…...
【stata】求滚动波动情况
0.计算对象 计算 t t t、 t 1 t1 t1、 t 2 t2 t2 这三起滚动波动情况 V o l i , t l n ( ∑ n t n t 2 ( g n − g ˉ ) 2 3 ) Vol_{i,t} ln(\sqrt{\frac{\sum_{nt}^{nt2}(g_{n}-\bar{g})^2}{3}}) Voli,tln(3∑ntnt2(gn−gˉ)2 ) e . g e.g e.g: 假设 200…...
The C programming language (second edition,KR) exercise(CHAPTER 2)
E x c e r c i s e 2 − 1 Excercise\quad 2-1 Excercise2−1:输出结果如图1和图2所示,这道练习题需要文章1和文章2的知识。 #include <stdio.h> #include <limits.h>float getFloat(char sign, unsigned char exp, unsigned mantissa); do…...
rust实现循环链表
作为一个计算机技术专家,针对你的问题,我将首先解释如何使用Rust语言实现循环链表,并提供相应的代码示例。然后,我将解释一个可能的报错问题及其解决方法。 循环链表的实现 在Rust中实现循环链表,首先需要定义链表节…...
2. Spring的创建和Bean的存取
经过前面的学习我们已经大体明白了 IOC 思想以及它的实现方式 DI ,本节要讲的是如何Spring框架实现实现DI。 本节目标: Spring(Core) 项目创建将对象存储到 Spring 中将对象(bean)从 Spring 中取出 1. 创建 Spring 项目 与开篇演示的 Spring Boot 项目不…...
策略模式【行为模式C++】
1.概述 策略模式是一种行为设计模式, 它能让你定义一系列算法, 并将每种算法分别放入独立的类中, 以使算法的对象能够相互替换。 策略模式通常应用于需要多种算法进行操作的场景,如排序、搜索、数据压缩等。在这些情况下&#x…...
php中session相关知识(目前了解部分)
#记录学习知识 一.ini_set() 在PHP中,ini_set() 函数用于在脚本运行时设置指定的配置选项的值。这些配置选项可以是PHP的核心设置,例如文件上传的最大大小、脚本的最大执行时间、错误报告级别等。使用 ini_set() 可以临时改变PHP.ini文件中的设置&am…...
从零实现诗词GPT大模型:GPT是怎么生成内容的?
专栏规划: https://qibin.blog.csdn.net/article/details/137728228 再开始编写GPT之前,我们得对GPT是怎么生成内容的有一个大致的了解。目前的神经网络我们大多都可以看成是一个黑盒,即我们把数据输送给网络后,网络给我我们输出,我们可以不用关心这个黑盒里到底是怎么实现…...
8路HDMI+8路AV高清视频流媒体编码器JR-3218HD
产品简介: JR-3218HD高清音视频编码产品支持8路高清HDMI音视频采集功能,8路AV视频采集功能,8路3.5MM独独立音频接口采集功能。编码输出双码流H.264格式,音频MP3/AAC格式。编码码率可调,画面质量可控制。支持HTTP/RTSP…...
LangChain入门:14.LLMChain:最简单的链的使用
摘要 本文将介绍LangChain库中LLMChain工具的使用方法。LLMChain将提示模板、语言模型(LLM)和输出解析器整合在一起,形成一个连贯的处理链,简化了与语言模型的交互过程。我们将探讨LLMChain的技术特点、应用场景以及它解决的问题…...
深入理解k8s kube-proxy
1、概述 我觉得只要大家知道kube-proxy是用来配置网络规则的而不是转发流量的,真正的流量由iptables/ipvs来转发就可以了。 网络是k8s的一个关键部分。理解k8s中网络组件如何工作可以帮助更好的设计和配置我们的应用。 kube-proxy就是K8s网络的核心组件。它把我们…...
Spark-机器学习(1)什么是机器学习与MLlib算法库的认识
从这一系列开始,我会带着大家一起了解我们的机器学习,了解我们spark机器学习中的MLIib算法库,知道它大概的模型,熟悉并认识它。同时,本篇文章为个人spark免费专栏的系列文章,有兴趣的可以收藏关注一下&…...
做棋牌网站建设多少钱/360seo优化
咱们先定义一个简单的类:class Vehicle {int passengers;int fuelcap;int mpg;}有了这个模板,就能够用它来建立对象: ---若对对象与类概念模糊的能够看: 对象与类详解Vehicle veh1 new Vehicle();一般把这条语句的动做称之为建立…...
自己做的网站发布到网上/外贸seo公司
效果 步骤...
北京礼品网站建设/网络营销的六大功能
入门三问: 组的概念是什么?为什么引入它?有什么用? 答:通过组可以更加方便的管理用户,组的概念应用于各行行业,例如企业会使用部门、职能或地理区域的分类方式来管理成员,映射在Li…...
医院网站建设的好处/app运营方案
LightDB支持存储过程,除了支持Postgres的plpgsql存储过程,还兼容Oracle的存储过程,新增了plorasql过程语言。上一篇中我们介绍了存储过程中的控制语句,这一篇主要讲述存储过程中的静态SQL语句。 什么是静态SQL语句呢?…...
wordpress调用单页面/seo研究中心道一老师
主要内容:函数指针一、函数指针定义int maxValue(int a,int b){return a > b ?a : b;}函数名和数组名一样是地址,存在在代码区int maxValue(int a。int b)int (*p)(int。int) NULL函数指针定义。p是变量,其它是类型…...
asp.net 网站后台管理系统制作/今日热点新闻视频
结合前两篇文章: 小试Flex框架Fabrication Flex多国语言示例 加上Fabrication自身支持的元标签,可简化一些代码,但简化后也付出了一定的代码,那就是变量需要声明为public,而之前虽然繁琐,但却可以将其声明为…...