当前位置: 首页 > news >正文

【我的创作纪念日】使用pix2pixgan实现barts2020数据集的处理(完整版本)

使用pix2pixgan (pytorch)实现T1 -> T2的基本代码

使用 https://github.com/eriklindernoren/PyTorch-GAN/ 这里面的pix2pixgan代码进行实现。

进去之后我们需要重新处理数据集,并且源代码里面先训练的生成器,后训练鉴别器。

一般情况下,先训练判别器而后训练生成器是因为这种训练顺序在理论和实践上更加稳定和有效。我们需要改变顺序以及一些代码:

以下是一些原因:

  1. 判别器的任务相对简单:判别器的任务是将真实样本与生成样本区分开来。这相对于生成器而言是一个相对简单的分类任务,因为它只需要区分两种类型的样本。通过先训练判别器,我们可以确保其具有足够的能力来准确识别真实和生成的样本。
  2. 生成器依赖于判别器的反馈:生成器的目标是生成逼真的样本,以尽可能地欺骗判别器。通过先训练判别器,我们可以得到关于生成样本质量的反馈信息。生成器可以根据判别器的反馈进行调整,并逐渐提高生成样本的质量。
  3. 训练稳定性:在GAN的早期训练阶段,生成器产生的样本可能会非常不真实。如果首先训练生成器,那么判别器可能会很容易辨别这些低质量的生成样本,导致梯度更新不稳定。通过先训练判别器,我们可以使生成器更好地适应判别器的反馈,从而增加训练的稳定性。
  4. 避免模式崩溃:在GAN训练过程中,存在模式坍塌的问题,即生成器只学会生成少数几种样本而不是整个数据分布。通过先训练判别器,我们可以提供更多样本的多样性,帮助生成器避免陷入模式崩溃现象。

尽管先训练鉴别器再训练生成器是一种常见的做法,但并不意味着这是唯一正确的方式。根据特定的问题和数据集,有时候也可以尝试其他训练策略,例如逆向训练(先训练生成器)。选择何种顺序取决于具体情况和实验结果。

数据集使用的是BraTs2020数据集,他的介绍和处理方法在我的知识链接里面。目前使用的是个人电脑的GPU跑的。然后数据也只取了前200个训练集,并且20%分出来作为测试集。

并且我们在训练的时候,每隔一定的batch使用matplotlib将T1,生成的T1,真实的T2进行展示,并且将生成器和鉴别器的loss进行展示。

通过比较可以发现使用了逐像素的L1 LOSS可以让生成的结果更好。

在这里插入图片描述

训练10个epoch时的结果图:

在这里插入图片描述

此时的测试结果:

PSNR mean: 21.1621928375993 PSNR std: 1.1501189362634836
NMSE mean: 0.14920212 NMSE std: 0.03501928
SSIM mean: 0.5401535398016223 SSIM std: 0.019281408927679166

代码:

dataloader.py

# dataloader for fine-tuning
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
import torch.utils.data as data
import numpy as np
from PIL import ImageEnhance, Image
import random
import osdef cv_random_flip(img, label):# left right flipflip_flag = random.randint(0, 2)if flip_flag == 1:img = np.flip(img, 0).copy()label = np.flip(label, 0).copy()if flip_flag == 2:img = np.flip(img, 1).copy()label = np.flip(label, 1).copy()return img, labeldef randomCrop(image, label):border = 30image_width = image.size[0]image_height = image.size[1]crop_win_width = np.random.randint(image_width - border, image_width)crop_win_height = np.random.randint(image_height - border, image_height)random_region = ((image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,(image_height + crop_win_height) >> 1)return image.crop(random_region), label.crop(random_region)def randomRotation(image, label):rotate = random.randint(0, 1)if rotate == 1:rotate_time = random.randint(1, 3)image = np.rot90(image, rotate_time).copy()label = np.rot90(label, rotate_time).copy()return image, labeldef colorEnhance(image):bright_intensity = random.randint(7, 13) / 10.0image = ImageEnhance.Brightness(image).enhance(bright_intensity)contrast_intensity = random.randint(4, 11) / 10.0image = ImageEnhance.Contrast(image).enhance(contrast_intensity)color_intensity = random.randint(7, 13) / 10.0image = ImageEnhance.Color(image).enhance(color_intensity)sharp_intensity = random.randint(7, 13) / 10.0image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)return imagedef randomGaussian(img, mean=0.002, sigma=0.002):def gaussianNoisy(im, mean=mean, sigma=sigma):for _i in range(len(im)):im[_i] += random.gauss(mean, sigma)return imflag = random.randint(0, 3)if flag == 1:width, height = img.shapeimg = gaussianNoisy(img[:].flatten(), mean, sigma)img = img.reshape([width, height])return imgdef randomPeper(img):flag = random.randint(0, 3)if flag == 1:noiseNum = int(0.0015 * img.shape[0] * img.shape[1])for i in range(noiseNum):randX = random.randint(0, img.shape[0] - 1)randY = random.randint(0, img.shape[1] - 1)if random.randint(0, 1) == 0:img[randX, randY] = 0else:img[randX, randY] = 1return imgclass BraTS_Train_Dataset(data.Dataset):def __init__(self, source_modal, target_modal, img_size,image_root, data_rate, sort=False, argument=False, random=False):self.source = source_modalself.target = target_modalself.modal_list = ['t1', 't2']self.image_root = image_rootself.data_rate = data_rateself.images = [self.image_root + f for f in os.listdir(self.image_root) if f.endswith('.npy')]self.images.sort(key=lambda x: int(x.split(image_root)[1].split(".npy")[0]))self.img_transform = transforms.Compose([transforms.ToTensor(),transforms.Resize(img_size)])self.gt_transform = transforms.Compose([transforms.ToTensor(),transforms.Resize(img_size, Image.NEAREST)])self.sort = sortself.argument = argumentself.random = randomself.subject_num = len(self.images) // 60if self.random == True:subject = np.arange(self.subject_num)np.random.shuffle(subject)self.LUT = []for i in subject:for j in range(60):self.LUT.append(i * 60 + j)# print('slice number:', self.__len__())def __getitem__(self, index):if self.random == True:index = self.LUT[index]npy = np.load(self.images[index])img = npy[self.modal_list.index(self.source), :, :]gt = npy[self.modal_list.index(self.target), :, :]if self.argument == True:img, gt = cv_random_flip(img, gt)img, gt = randomRotation(img, gt)img = img * 255img = Image.fromarray(img.astype(np.uint8))img = colorEnhance(img)img = img.convert('L')img = self.img_transform(img)gt = self.img_transform(gt)return img, gtdef __len__(self):return int(len(self.images) * self.data_rate)def get_loader(batchsize, shuffle, pin_memory=True, source_modal='t1', target_modal='t2',img_size=256, img_root='data/train/', data_rate=0.1, num_workers=8, sort=False, argument=False,random=False):dataset = BraTS_Train_Dataset(source_modal=source_modal, target_modal=target_modal,img_size=img_size, image_root=img_root, data_rate=data_rate, sort=sort,argument=argument, random=random)data_loader = data.DataLoader(dataset=dataset, batch_size=batchsize, shuffle=shuffle,pin_memory=pin_memory, num_workers=num_workers)return data_loader# if __name__=='__main__':
#     data_loader = get_loader(batchsize=1, shuffle=True, pin_memory=True, source_modal='t1',
#                              target_modal='t2', img_size=256, num_workers=8,
#                              img_root='data/train/', data_rate=0.1, argument=True, random=False)
#     length = len(data_loader)
#     print("data_loader的长度为:", length)
#     # 将 data_loader 转换为迭代器
#     data_iter = iter(data_loader)
#
#     # 获取第一批数据
#     batch = next(data_iter)
#
#     # 打印第一批数据的大小
#     print("第一批数据的大小:", batch[0].shape)  # 输入图像的张量
#     print("第一批数据的大小:", batch[1].shape)  # 目标图像的张量
#     print(batch.shape)

models.py

import torch.nn as nn
import torch.nn.functional as F
import torchdef weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)##############################
#           U-NET
##############################class UNetDown(nn.Module):def __init__(self, in_size, out_size, normalize=True, dropout=0.0):super(UNetDown, self).__init__()layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]if normalize:layers.append(nn.InstanceNorm2d(out_size))layers.append(nn.LeakyReLU(0.2))if dropout:layers.append(nn.Dropout(dropout))self.model = nn.Sequential(*layers)def forward(self, x):return self.model(x)class UNetUp(nn.Module):def __init__(self, in_size, out_size, dropout=0.0):super(UNetUp, self).__init__()layers = [nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),nn.InstanceNorm2d(out_size),nn.ReLU(inplace=True),]if dropout:layers.append(nn.Dropout(dropout))self.model = nn.Sequential(*layers)def forward(self, x, skip_input):x = self.model(x)x = torch.cat((x, skip_input), 1)return xclass GeneratorUNet(nn.Module):def __init__(self, in_channels=3, out_channels=3):super(GeneratorUNet, self).__init__()self.down1 = UNetDown(in_channels, 64, normalize=False)self.down2 = UNetDown(64, 128)self.down3 = UNetDown(128, 256)self.down4 = UNetDown(256, 512, dropout=0.5)self.down5 = UNetDown(512, 512, dropout=0.5)self.down6 = UNetDown(512, 512, dropout=0.5)self.down7 = UNetDown(512, 512, dropout=0.5)self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)self.up1 = UNetUp(512, 512, dropout=0.5)self.up2 = UNetUp(1024, 512, dropout=0.5)self.up3 = UNetUp(1024, 512, dropout=0.5)self.up4 = UNetUp(1024, 512, dropout=0.5)self.up5 = UNetUp(1024, 256)self.up6 = UNetUp(512, 128)self.up7 = UNetUp(256, 64)self.final = nn.Sequential(nn.Upsample(scale_factor=2),nn.ZeroPad2d((1, 0, 1, 0)),nn.Conv2d(128, out_channels, 4, padding=1),nn.Tanh(),)def forward(self, x):# U-Net generator with skip connections from encoder to decoderd1 = self.down1(x)d2 = self.down2(d1)d3 = self.down3(d2)d4 = self.down4(d3)d5 = self.down5(d4)d6 = self.down6(d5)d7 = self.down7(d6)d8 = self.down8(d7)u1 = self.up1(d8, d7)u2 = self.up2(u1, d6)u3 = self.up3(u2, d5)u4 = self.up4(u3, d4)u5 = self.up5(u4, d3)u6 = self.up6(u5, d2)u7 = self.up7(u6, d1)return self.final(u7)##############################
#        Discriminator
##############################class Discriminator(nn.Module):def __init__(self, in_channels=3):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, normalization=True):"""Returns downsampling layers of each discriminator block"""layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]if normalization:layers.append(nn.InstanceNorm2d(out_filters))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*discriminator_block(in_channels * 2, 64, normalization=False),*discriminator_block(64, 128),*discriminator_block(128, 256),*discriminator_block(256, 512),nn.ZeroPad2d((1, 0, 1, 0)),nn.Conv2d(512, 1, 4, padding=1, bias=False))def forward(self, img_A, img_B):# Concatenate image and condition image by channels to produce inputimg_input = torch.cat((img_A, img_B), 1)return self.model(img_input)

pix2pix.py

import argparse
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variablefrom models import *
from dataloader import *import torch.nn as nn
import torch.nn.functional as F
import torch
if __name__=='__main__':parser = argparse.ArgumentParser()parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")parser.add_argument("--dataset_name", type=str, default="basta2020", help="name of the dataset")parser.add_argument("--batch_size", type=int, default=2, help="size of the batches")parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")parser.add_argument("--img_height", type=int, default=256, help="size of image height")parser.add_argument("--img_width", type=int, default=256, help="size of image width")parser.add_argument("--channels", type=int, default=3, help="number of image channels")parser.add_argument("--sample_interval", type=int, default=500, help="interval between sampling of images from generators")parser.add_argument("--checkpoint_interval", type=int, default=10, help="interval between model checkpoints")opt = parser.parse_args()print(opt)os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)cuda = True if torch.cuda.is_available() else False# Loss functionscriterion_GAN = torch.nn.MSELoss()criterion_pixelwise = torch.nn.L1Loss()# Loss weight of L1 pixel-wise loss between translated image and real imagelambda_pixel = 100# Calculate output of image discriminator (PatchGAN)patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4)# Initialize generator and discriminatorgenerator = GeneratorUNet(in_channels=1, out_channels=1)discriminator = Discriminator(in_channels=1)if cuda:generator = generator.cuda()discriminator = discriminator.cuda()criterion_GAN.cuda()criterion_pixelwise.cuda()if opt.epoch != 0:# Load pretrained modelsgenerator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch)))discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch)))else:# Initialize weightsgenerator.apply(weights_init_normal)discriminator.apply(weights_init_normal)# Optimizersoptimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))# Configure dataloaderstransforms_ = [transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]dataloader = get_loader(batchsize=4, shuffle=True, pin_memory=True, source_modal='t1',target_modal='t2', img_size=256, num_workers=8,img_root='data/train/', data_rate=0.1, argument=True, random=False)# dataloader = DataLoader(#     ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_),#     batch_size=opt.batch_size,#     shuffle=True,#     num_workers=opt.n_cpu,# )# val_dataloader = DataLoader(#     ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, mode="val"),#     batch_size=10,#     shuffle=True,#     num_workers=1,# )# Tensor typeTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor# def sample_images(batches_done):#     """Saves a generated sample from the validation set"""#     imgs = next(iter(val_dataloader))#     real_A = Variable(imgs["B"].type(Tensor))#     real_B = Variable(imgs["A"].type(Tensor))#     fake_B = generator(real_A)#     img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)#     save_image(img_sample, "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5, normalize=True)# ----------#  Training# ----------prev_time = time.time()# 创建空列表用于保存损失值losses_G = []losses_D = []for epoch in range(opt.epoch, opt.n_epochs):for i, batch in enumerate(dataloader):# Model inputsreal_A = Variable(batch[0].type(Tensor))real_B = Variable(batch[1].type(Tensor))# print(real_A == real_B)# Adversarial ground truthsvalid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Real losspred_real = discriminator(real_B, real_A)loss_real = criterion_GAN(pred_real, valid)# Fake lossfake_B = generator(real_A)pred_fake = discriminator(fake_B.detach(), real_A)loss_fake = criterion_GAN(pred_fake, fake)# Total lossloss_D = 0.5 * (loss_real + loss_fake)loss_D.backward()optimizer_D.step()# ------------------#  Train Generators# ------------------optimizer_G.zero_grad()# GAN losspred_fake = discriminator(fake_B, real_A)loss_GAN = criterion_GAN(pred_fake, valid)# Pixel-wise lossloss_pixel = criterion_pixelwise(fake_B, real_B)# Total lossloss_G = loss_GAN + lambda_pixel * loss_pixel   # 希望生成的接近1loss_G.backward()optimizer_G.step()# --------------#  Log Progress# --------------# Determine approximate time leftbatches_done = epoch * len(dataloader) + ibatches_left = opt.n_epochs * len(dataloader) - batches_donetime_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))prev_time = time.time()# Print logsys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s"% (epoch,opt.n_epochs,i,len(dataloader),loss_D.item(),loss_G.item(),loss_pixel.item(),loss_GAN.item(),time_left,))mat = [real_A, fake_B, real_B]if (batches_done + 1) % 200 == 0:plt.figure(dpi=400)ax = plt.subplot(131)for i, img in enumerate(mat):ax = plt.subplot(1, 3, i + 1)  #get positionimg = img.permute([0, 2, 3, 1])  # b c h w ->b h w cif img.shape[0] != 1:   # 有多个就只取第一个img = img[1]img = img.squeeze(0)   # b h w c -> h w cif img.shape[2] == 1:img = img.repeat(1, 1, 3)  # process gray imgimg = img.cpu()ax.imshow(img.data)ax.set_xticks([])ax.set_yticks([])plt.show()if (batches_done + 1) % 20 ==0:losses_G.append(loss_G.item())losses_D.append(loss_D.item())if (batches_done + 1) % 200 == 0:  # 每20个batch添加一次损失# 保存损失值plt.figure(figsize=(10, 5))plt.plot(range(int((batches_done + 1) / 20)), losses_G, label="Generator Loss")plt.plot(range(int((batches_done + 1) / 20)), losses_D, label="Discriminator Loss")plt.xlabel("Epoch")plt.ylabel("Loss")plt.title("GAN Training Loss Curve")plt.legend()plt.show()# # If at sample interval save image# if batches_done % opt.sample_interval == 0:#     sample_images(batches_done)if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:# Save model checkpointstorch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (opt.dataset_name, epoch))torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, epoch))

processing.py 数据预处理

import numpy as np
from matplotlib import pylab as plt
import nibabel as nib
import random
import glob
import os
from PIL import Image
import imageiodef normalize(image, mask=None, percentile_lower=0.2, percentile_upper=99.8):if mask is None:mask = image != image[0, 0, 0]cut_off_lower = np.percentile(image[mask != 0].ravel(), percentile_lower)cut_off_upper = np.percentile(image[mask != 0].ravel(), percentile_upper)res = np.copy(image)res[(res < cut_off_lower) & (mask != 0)] = cut_off_lowerres[(res > cut_off_upper) & (mask != 0)] = cut_off_upperres = res / res.max()  # 0-1return resdef visualize(t1_data, t2_data, flair_data, t1ce_data, gt_data):plt.figure(figsize=(8, 8))plt.subplot(231)plt.imshow(t1_data[:, :], cmap='gray')plt.title('Image t1')plt.subplot(232)plt.imshow(t2_data[:, :], cmap='gray')plt.title('Image t2')plt.subplot(233)plt.imshow(flair_data[:, :], cmap='gray')plt.title('Image flair')plt.subplot(234)plt.imshow(t1ce_data[:, :], cmap='gray')plt.title('Image t1ce')plt.subplot(235)plt.imshow(gt_data[:, :])plt.title('GT')plt.show()def visualize_to_gif(t1_data, t2_data, t1ce_data, flair_data):transversal = []coronal = []sagittal = []slice_num = t1_data.shape[2]for i in range(slice_num):sagittal_plane = np.concatenate((t1_data[:, :, i], t2_data[:, :, i],t1ce_data[:, :, i], flair_data[:, :, i]), axis=1)coronal_plane = np.concatenate((t1_data[i, :, :], t2_data[i, :, :],t1ce_data[i, :, :], flair_data[i, :, :]), axis=1)transversal_plane = np.concatenate((t1_data[:, i, :], t2_data[:, i, :],t1ce_data[:, i, :], flair_data[:, i, :]), axis=1)transversal.append(transversal_plane)coronal.append(coronal_plane)sagittal.append(sagittal_plane)imageio.mimsave("./transversal_plane.gif", transversal, duration=0.01)imageio.mimsave("./coronal_plane.gif", coronal, duration=0.01)imageio.mimsave("./sagittal_plane.gif", sagittal, duration=0.01)returnif __name__ == '__main__':t1_list = sorted(glob.glob('../data/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/*/*t1.*'))t2_list = sorted(glob.glob('../data/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/*/*t2.*'))data_len = len(t1_list)train_len = int(data_len * 0.8)test_len = data_len - train_lentrain_path = '../data/train/'test_path = '../data/test/'os.makedirs(train_path, exist_ok=True)os.makedirs(test_path, exist_ok=True)for i, (t1_path, t2_path) in enumerate(zip(t1_list, t2_list)):print('preprocessing the', i + 1, 'th subject')t1_img = nib.load(t1_path)  # (240,140,155)t2_img = nib.load(t2_path)# to numpyt1_data = t1_img.get_fdata()t2_data = t2_img.get_fdata()t1_data = normalize(t1_data)  # normalize to [0,1]t2_data = normalize(t2_data)tensor = np.stack([t1_data, t2_data])  # (2, 240, 240, 155)if i < train_len:for j in range(60):Tensor = tensor[:, 10:210, 25:225, 50 + j]np.save(train_path + str(60 * i + j + 1) + '.npy', Tensor)else:for j in range(60):Tensor = tensor[:, 10:210, 25:225, 50 + j]np.save(test_path + str(60 * (i - train_len) + j + 1) + '.npy', Tensor)

testutil.py

#-*- codeing = utf-8 -*-
#@Time : 2023/9/23 0023 17:21
#@Author : Tom
#@File : testutil.py.py
#@Software : PyCharm
import argparsefrom math import log10, sqrt
import numpy as np
from skimage.metrics import structural_similarity as ssimdef psnr(res,gt):mse = np.mean((res - gt) ** 2)if(mse == 0):return 100max_pixel = 1psnr = 20 * log10(max_pixel / sqrt(mse))return psnrdef nmse(res,gt):Norm = np.linalg.norm((gt * gt),ord=2)if np.all(Norm == 0):return 0else:nmse = np.linalg.norm(((res - gt) * (res - gt)),ord=2) / Normreturn nmse

test.py

#-*- codeing = utf-8 -*-
#@Time : 2023/9/23 0023 16:14
#@Author : Tom
#@File : test.py.py
#@Software : PyCharmimport torch
from models import *
from dataloader import *
from testutil import *if __name__ == '__main__':images_save = "images_save/"slice_num = 4os.makedirs(images_save, exist_ok=True)device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = GeneratorUNet(in_channels=1, out_channels=1)data_loader = get_loader(batchsize=4, shuffle=True, pin_memory=True, source_modal='t1',target_modal='t2', img_size=256, num_workers=8,img_root='data/test/', data_rate=1, argument=True, random=False)model = model.to(device)model.load_state_dict(torch.load("saved_models/basta2020/generator_0.pth", map_location=torch.device(device)), strict=False)PSNR = []NMSE = []SSIM = []for i, (img, gt) in enumerate(data_loader):batch_size = img.size()[0]img = img.to(device, dtype=torch.float)gt = gt.to(device, dtype=torch.float)with torch.no_grad():pred = model(img)for j in range(batch_size):a = pred[j]save_image([pred[j]], images_save + str(i * batch_size + j + 1) + '.png', normalize=True)print(images_save + str(i * batch_size + j + 1) + '.png')pred, gt = pred.cpu().detach().numpy().squeeze(), gt.cpu().detach().numpy().squeeze()for j in range(batch_size):PSNR.append(psnr(pred[j], gt[j]))NMSE.append(nmse(pred[j], gt[j]))SSIM.append(ssim(pred[j], gt[j]))PSNR = np.asarray(PSNR)NMSE = np.asarray(NMSE)SSIM = np.asarray(SSIM)PSNR = PSNR.reshape(-1, slice_num)NMSE = NMSE.reshape(-1, slice_num)SSIM = SSIM.reshape(-1, slice_num)PSNR = np.mean(PSNR, axis=1)print(PSNR.size)NMSE = np.mean(NMSE, axis=1)SSIM = np.mean(SSIM, axis=1)print("PSNR mean:", np.mean(PSNR), "PSNR std:", np.std(PSNR))print("NMSE mean:", np.mean(NMSE), "NMSE std:", np.std(NMSE))print("SSIM mean:", np.mean(SSIM), "SSIM std:", np.std(SSIM))

相关文章:

【我的创作纪念日】使用pix2pixgan实现barts2020数据集的处理(完整版本)

使用pix2pixgan &#xff08;pytorch)实现T1 -> T2的基本代码 使用 https://github.com/eriklindernoren/PyTorch-GAN/ 这里面的pix2pixgan代码进行实现。 进去之后我们需要重新处理数据集&#xff0c;并且源代码里面先训练的生成器&#xff0c;后训练鉴别器。 一般情况下…...

背包算法(Knapsack problem)

背包算法&#xff08;Knapsack problem&#xff09;是一种常见的动态规划问题&#xff0c;它的基本思想是利用动态规划思想求解给定重量和价值下的最优解。具体来说&#xff0c;背包算法用于解决一个整数背包问题&#xff0c;即给定一组物品&#xff0c;每个物品有自己的重量和…...

“童”趣迎国庆 安全“童”行-柿铺梁坡社区开展迎国庆活动

“金秋十月好心境&#xff0c;举国欢腾迎国庆。”国庆节来临之际&#xff0c;为进一步加强梁坡社区未成年人爱国主义教育&#xff0c;丰富文化生活&#xff0c;营造热烈喜庆、文明和谐的节日氛围。9月24日上午&#xff0c;樊城区柿铺街道梁坡社区新时代文明实践站联合襄阳市和时…...

常用压缩解压缩命令

在Linux中常见的压缩格式有.zip、.rar、.tar.gz.、tar.bz2等压缩格式。不同的压缩格式需要用不同的压缩命令和工具。须知&#xff0c;在Linux系统中.tar.gz为标准格式的压缩和解压缩格式&#xff0c;因此本文也会着重讲解tar.gz格式压缩包的压缩和解压缩命令。须知&#xff0c;…...

第四十一章 持久对象和SQL - Storage

文章目录 第四十一章 持久对象和SQL - StorageStorage存储定义概览持久类使用的Globals注意 第四十一章 持久对象和SQL - Storage Storage 每个持久类定义都包含描述类属性如何映射到实际存储它们的Global的信息。类编译器为类生成此信息&#xff0c;并在修改和重新编译时更新…...

【Java接口性能优化】skywalking使用

skywalking使用 提示&#xff1a;微服务中-skywalking使用 文章目录 skywalking使用一、进入skywalking主页二、进入具体服务1.查看接口 一、进入skywalking主页 二、进入具体服务 可以点击列表或搜索后&#xff0c;点击进入具体服务 依次选择日期、小时、分钟 1.查看接口 依次…...

大学各个专业介绍

计算机类 五米高考-计算机类 注&#xff1a;此处平均薪酬为毕业五年平均薪酬&#xff0c;薪酬数据仅供参考 来源&#xff1a; 掌上高考 电气类 五米高考-电气类 机械类 五米高考-机械类 电子信息类 五米高考-电子信息类 土木类 五米高考-土木类...

linux 列出网络上所有活动的主机

列出网络上所有活动的主机 #!/bin/bash# {start..end}会由shell对其进行扩展生成一组ip地址for ip in 192.168.0.{1..255} ;do ping $ip -c 2 &> /dev/null ; # $?获取退出状态&#xff0c;顺利退出则为0 if [ $? -eq 0 ]; then echo $ip is alive fidone https://zh…...

基于vue+Element Table Popover 弹出框内置表格的封装

文章目录 项目场景&#xff1a;实现效果认识组件代码效果分析 封装&#xff1a;代码封装思路页面中使用 项目场景&#xff1a; 在选择数据的时候需要在已选择的数据中对比选择&#xff0c;具体就是点击一个按钮&#xff0c;弹出一个小的弹出框&#xff0c;但不像对话框那样还需…...

机器人过程自动化(RPA)入门 4. 数据处理

到目前为止,我们已经了解了RPA的基本知识,以及如何使用流程图或序列来组织工作流中的步骤。我们现在了解了UiPath组件,并对UiPath Studio有了全面的了解。我们用几个简单的例子制作了我们的第一个机器人。在我们继续之前,我们应该了解UiPath中的变量和数据操作。它与其他编…...

java导出word(含图片、表格)

1.pom 引入 <!--word报告生成依赖--><dependency><groupId>org.apache.poi</groupId><artifactId>poi</artifactId><version>4.1.2</version></dependency><dependency><groupId>org.apache.poi</groupI…...

MySQL数据库记录的修改与更新

数据的修改和更新是数据库管理的核心任务之一,尤其是在动态和快速变化的环境下。本文将深入探讨如何在MySQL数据库中有效地进行记录的修改和更新。特别是将通过使用《三国志》游戏数据作为例子,来具体展示这些操作如何实施。文章主要面向具有基础数据库知识的读者。 文章目录…...

开具数电票如何减少认证频次?

“数电票”开具需多次刷脸认证&#xff0c;如何减少认证频次&#xff1f; 法定代表人、财务负责人可以在“身份认证频次设置”功能自行设置身份认证时间间隔&#xff0c;方法如下&#xff1a; 第一步 登录电子税务局。企业法定代表人或财务负责人通过手机APP“扫一扫”&#x…...

【进阶C语言】动态内存分配

本章大致内容介绍&#xff1a; 1.malloc函数和free函数 2.calloc函数 3.realloc函数 4.常见错误案例 5.笔试题详解 6.柔性数组 一、malloc和free 1.malloc函数 &#xff08;1&#xff09;函数原型 函数参数&#xff1a;根据用户的需求需要开辟多大的字节空间&#xff…...

手机上记录的备忘录内容怎么分享到电脑上查看?

手机已经成为了我们生活中不可或缺的一部分&#xff0c;我们用它来处理琐碎事务&#xff0c;记录生活点滴&#xff0c;手机备忘录就是我们常用的工具之一。但随着工作的需要&#xff0c;我们往往会遇到一个问题&#xff1a;手机上记录的备忘录内容&#xff0c;如何方便地分享到…...

LeetCode 2251. 花期内花的数目:排序 + 二分

【LetMeFly】2251.花期内花的数目&#xff1a;排序 二分 力扣题目链接&#xff1a;https://leetcode.cn/problems/number-of-flowers-in-full-bloom/ 给你一个下标从 0 开始的二维整数数组 flowers &#xff0c;其中 flowers[i] [starti, endi] 表示第 i 朵花的 花期 从 st…...

【3】贪心算法-最优装载问题-加勒比海盗

算法背景 在北美洲东南部&#xff0c;有一片神秘的海域&#xff0c;那里碧海蓝天、阳光 明媚&#xff0c;这正是传说中海盗最活跃的加勒比海&#xff08;Caribbean Sea&#xff09;。 有一天&#xff0c;海盗们截获了一艘装满各种各样古董的货船&#xff0c;每一 件古董都价值连…...

JavaScript 的 for 循环应该如何学习?

JS for 循环语法 JS for 循环适合在已知循环次数时使用&#xff0c;语法格式如下&#xff1a; for(initialization; condition; increment) {// 要执行的代码 }for 循环中包含三个可选的表达式 initialization、condition 和 increment&#xff0c;其中&#xff1a; initial…...

C++核心编程--对象篇

4.2、对象 4.2.1、对象的初始化和清理 用于对对象进行初始化设置&#xff0c;以及对象销毁前的清理数据的设置。 构造函数和析构函数 防止对象初始化和清理也是非常重要的安全问题 一个对象或变量没有初始化状态&#xff0c;对其使用后果是未知的同样使用完一个对象或变量&…...

安装php扩展XLSXWriter,解决php导入excel表格时获取日期变成浮点数的方法

安装php扩展XLSXWriter 1、下载安装包 PECL :: Package :: xlswriter #例如选择下载1.3.6版本 2、解压下载包 tar -zxvf xlswriter-1.3.6.tgz 3、进入文件夹,编译 cd xlswriter-1.3.6 phpize ./configure --with-php-config=/usr/local/php7.1/bin/php-config make&am…...

Vue+element开发Simple Admin后端管理系统页面

最近看到各种admin&#xff0c;头大&#xff0c;内容太多&#xff0c;根本不知道怎么改。所以制作了这个项目&#xff0c;只包含框架、和开发中最常用的表格和表单&#xff0c;不用自己从头搭建架构&#xff0c;同时也容易上手二次开发。可以轻松从其他开源项目整合到本项目。项…...

源码编译安装pkg-config

安装环境&#xff1a;银河麒麟 1 到这个网址下载pkg-config源码&#xff1a; Index of /releases (pkg-config.freedesktop.org) 2 解压 3 进入解压后的目录。输入 ./configure 但是报错。 4 根据报错信息&#xff0c;将configure改为&#xff1a; ./configure --with-i…...

游览器找不到服务器上PHP文件的一种原因

最近在练习搭建网站&#xff0c;遇到游览器找不到服务器上的php文件的问题。后来查找发现&#xff0c;apache文档根目录跟apache虚拟主机文档根目录不同&#xff0c;服务器开启了虚拟主机功能。这导致游览器找不到php文件。使用的环境LAMP 里操作系统和软件版本如下&#xff1a…...

C++之std::function的介绍

C之std::function的介绍 std::function和函数指针的区别介绍std::function 的常见用法包括用法举例 std::function和函数指针的区别介绍 std::function 和函数指针在 C 中都可以用来存储和调用函数&#xff0c;但它们的使用方式和功能有所不同。 函数指针是一种指向函数的指针…...

卷积神经网络学习(一)

CNN应用对象是图像&#xff0c;CNN可被应用于的任务&#xff1a; 1、分类&#xff08;classification&#xff09;&#xff1a;对图像按其中的物体进行分类&#xff0c;如图像中有人与猫&#xff0c;则图像可分为两类。 2、目标检测&#xff08;object detection&#xff09;&a…...

使用KEIL自带的仿真器仿真遇到问题解决

*** error 65: access violation at 0x40021000 : no read permission 修改debug选项设置为下方内容。...

4700 万美元损失,Xn00d 合约漏洞攻击事件分析

4700 万美元损失&#xff0c;Xn00d 合约漏洞攻击事件分析 基础知识 ERC777 ERC777 是 ERC20 标准的高级代币标准&#xff0c;要提供了一些新的功能&#xff1a;运营商及钩子。 运营商功能。通过此功能能够允许第三方账户代表某一合约或者地址 进行代币的发送交易钩子功能。…...

第5讲:v-if与v-show的使用方法及区别

v-if条件判断 v-if是条件渲染指令&#xff0c;它根据表达式的真假来删除和插入元素&#xff0c;它的基本语法如下&#xff1a; v-if “expression” expression是一个返回bool值的表达式&#xff0c;表达式可以是一个bool属性&#xff0c;也可以是一个返回bool的运算式 &#…...

C理解(一):内存与位操作

本文主要探讨C语言的内存和为操作操作相关知识。 冯诺依曼结构和哈佛结构 冯诺依曼结构&#xff1a;数据和代码放在一起,便于读取和修改,安全性低 哈佛结构是&#xff1a;数据和代码分开存放,安全性高,读取和修麻烦 内存 内存是用来存储全局变量、局…...

ESP8266使用记录(四)

放上最终效果 ESP8266&Unity游戏 整合放进了坏玩具车遥控器里 最终只使用了mpu6050的yaw数据&#xff0c;因为roll值漂移…… 使用了https://github.com/ElectronicCats/mpu6050 整个流程 ESP8266取MPU6050数据&#xff0c;处理后通过udp发送给Unity显示出来 MPU6050_Z…...

网络科技公司网站建设策划/百度热搜榜排名今日p2p

Java技术栈www.javastack.cn关注阅读更多优质文章作者&#xff1a;稻草叔叔来源&#xff1a;juejin.im/post/6844903635533594632Git 是目前最流行的源代码管理工具。为规范开发&#xff0c;保持代码提交记录以及 git 分支结构清晰&#xff0c;方便后续维护&#xff0c;现规范 …...

dede网站地图位置/关键词排名代发

2019独角兽企业重金招聘Python工程师标准>>> 其实微博是个好东西&#xff0c;关注一些技术博主之后&#xff0c;你不用再逛好多论坛了&#xff0c;因为一些很好的文章微博会告诉你&#xff0c;最近看到酷勤网推荐的一篇文章《30个提 高Web程序执行效率的好经验》&am…...

阿里服务器可以做多少个网站/网站推广投放

DataTable da CommonBLL.GetList("*", "sys_dict", "IfState1 and DictTypeId14");string strField "CACCNUM as 账号账号,Loannumber as 借据号,BILLDATE as 借款时间,CAName as 借款人姓名,CPOSITION as 质押商品房位置,LoanAmount as …...

机械厂网站模板/国际购物网站平台有哪些

很多初学者问小编是如何学习Java的&#xff0c;有没有好的建议&#xff1f; 今天给大家来点干货&#xff0c;因此咱们就不说一些学习方法和技巧了&#xff0c;直接来谈每个阶段要学习的内容甚至是一些书籍。这一部分的内容&#xff0c;同样适用于一些希望转行到Java的同学。 对…...

河南省城乡和住房建设厅/班级优化大师是干什么用的

目录概念&#xff1a;通俗理解&#xff1a;可重入锁的工作原理&#xff1a;ReenTrantLock可重入锁和synchronized的区别&#xff1a;ReentrantLock源码分析:可重入锁代码演示&#xff1a;概念&#xff1a; Reentrant Re entrant&#xff0c;Re是重复、又、再的意思&#xff0…...

嘉兴商城网站开发设计/推广方案如何写

1号进程是什么 当我们使用 /bin/bash 启动一个centos的容器 docker run -it --rm centos:7 /bin/bash那么启动命令就是1号进程 [rootded49b74042c /]# ps aux USER PID %CPU %MEM VSZ RSS TTY STAT START TIME COMMAND root 1 0.2 0.0 11836 …...