伏羲0.07(文生图)
为了使0.06代码能够有效运行并输出项目目录及所有文件,我们在代码中添加一些额外的功能。
- 项目目录结构
项目目录结构如下:
text_to_image_project/
│
├── config.yaml
├── data/
│ ├── train_data.csv
│ └── test_data.txt
├── models/
│ └── text_to_image_model.pth
├── main.py
└── README.md
- 示例配置文件
config.yaml
model:path: models/text_to_image_model.pthtext_encoder_model_name: bert-base-uncaseddata:input_file: data/test_data.txtoutput_dir: data/generated_imagesdataset_path: data/train_data.csvtraining:batch_size: 64learning_rate: 0.0002epochs: 100
- 示例训练数据文件
data/train_data.csv
text,image_path
a beautiful sunset,data/images/sunset.jpg
a cute puppy,data/images/puppy.jpg
a red rose,data/images/rose.jpg
- 示例测试数据文件
data/test_data.txt
a beautiful sunset
a cute puppy
a red rose
- 完善后的代码
main.py
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
import yaml
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import random
import numpy as np# 配置文件加载
def load_config(config_path):with open(config_path, 'r', encoding='utf-8') as file:config = yaml.safe_load(file)return config# 数据加载
def load_text_data(file_path):with open(file_path, 'r', encoding='utf-8') as file:text_data = file.readlines()return [line.strip() for line in text_data]# 数据清洗
def clean_data(data):return data.dropna().drop_duplicates()# 数据增强
def augment_data(image, mode):if mode == 'train':transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomRotation(10),transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])else:transform = transforms.Compose([transforms.Resize((64, 64)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])return transform(image)# 文本编码器
class TextEncoder(nn.Module):def __init__(self, model_name):super(TextEncoder, self).__init__()self.tokenizer = AutoTokenizer.from_pretrained(model_name)self.model = AutoModel.from_pretrained(model_name)def forward(self, text):inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)outputs = self.model(**inputs)return outputs.last_hidden_state.mean(dim=1)# 图像生成器
class ImageGenerator(nn.Module):def __init__(self, in_channels):super(ImageGenerator, self).__init__()self.decoder = nn.Sequential(nn.ConvTranspose2d(in_channels, 512, kernel_size=4, stride=1, padding=0),nn.BatchNorm2d(512),nn.ReLU(True),nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(256),nn.ReLU(True),nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(128),nn.ReLU(True),nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(64),nn.ReLU(True),nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),nn.Tanh())def forward(self, x):x = x.view(-1, x.size(1), 1, 1)return self.decoder(x)# 判别器
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),nn.Sigmoid())def forward(self, x):return self.main(x)# 模型定义
class TextToImageModel(nn.Module):def __init__(self, text_encoder_model_name):super(TextToImageModel, self).__init__()self.text_encoder = TextEncoder(text_encoder_model_name)self.image_generator = ImageGenerator(768) # 768 is the hidden size of BERTdef forward(self, text):text_features = self.text_encoder(text)return self.image_generator(text_features)# 模型加载
def load_model(model_path, text_encoder_model_name):model = TextToImageModel(text_encoder_model_name)if os.path.exists(model_path):model.load_state_dict(torch.load(model_path))model.eval()return model# 图像保存
def save_image(image, path):if not os.path.exists(os.path.dirname(path)):os.makedirs(os.path.dirname(path))image.save(path)# 数据集类
class TextToImageDataset(Dataset):def __init__(self, csv_file, transform=None, mode='train'):self.data = pd.read_csv(csv_file)self.data = clean_data(self.data)self.transform = transformself.mode = modedef __len__(self):return len(self.data)def __getitem__(self, idx):text = self.data.iloc[idx]['text']image_path = self.data.iloc[idx]['image_path']image = Image.open(image_path).convert('RGB')if self.transform:image = self.transform(image, self.mode)return text, image# 模型训练
def train_model(config):transform = transforms.Compose([transforms.Resize((64, 64)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])dataset = TextToImageDataset(config['training']['dataset_path'], transform=augment_data, mode='train')dataloader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=True)model = TextToImageModel(config['model']['text_encoder_model_name'])discriminator = Discriminator()optimizer_g = optim.Adam(model.parameters(), lr=config['training']['learning_rate'])optimizer_d = optim.Adam(discriminator.parameters(), lr=config['training']['learning_rate'])criterion_gan = nn.BCELoss()criterion_l1 = nn.L1Loss()for epoch in range(config['training']['epochs']):model.train()discriminator.train()running_loss_g = 0.0running_loss_d = 0.0for i, (text, images) in enumerate(dataloader):real_labels = torch.ones(images.size(0), 1)fake_labels = torch.zeros(images.size(0), 1)# Train Discriminatoroptimizer_d.zero_grad()real_outputs = discriminator(images)d_loss_real = criterion_gan(real_outputs, real_labels)generated_images = model(text)fake_outputs = discriminator(generated_images.detach())d_loss_fake = criterion_gan(fake_outputs, fake_labels)d_loss = (d_loss_real + d_loss_fake) / 2d_loss.backward()optimizer_d.step()# Train Generatoroptimizer_g.zero_grad()generated_images = model(text)g_outputs = discriminator(generated_images)g_loss_gan = criterion_gan(g_outputs, real_labels)g_loss_l1 = criterion_l1(generated_images, images)g_loss = g_loss_gan + 100 * g_loss_l1 # Weighted sum of GAN loss and L1 lossg_loss.backward()optimizer_g.step()running_loss_g += g_loss.item()running_loss_d += d_loss.item()print(f"Epoch {epoch + 1}, Generator Loss: {running_loss_g / len(dataloader)}, Discriminator Loss: {running_loss_d / len(dataloader)}")# 保存训练好的模型torch.save(model.state_dict(), config['model']['path'])# 图像生成
def generate_images(model, text_data, output_dir):for text in text_data:input_tensor = model.text_encoder([text])image = model.image_generator(input_tensor)image = image.squeeze(0).detach().cpu().numpy()image = (image * 127.5 + 127.5).astype('uint8')image = Image.fromarray(image.transpose(1, 2, 0))# 保存图像save_image(image, f"{output_dir}/{text}.png")# 图形用户界面
class TextToImageGUI:def __init__(self, root):self.root = rootself.root.title("文本生成图像")self.config = load_config('config.yaml')self.model = load_model(self.config['model']['path'], self.config['model']['text_encoder_model_name'])self.text_input = tk.Text(root, height=10, width=50)self.text_input.pack(pady=10)self.train_button = tk.Button(root, text="训练模型", command=self.train_model)self.train_button.pack(pady=10)self.generate_button = tk.Button(root, text="生成图像", command=self.generate_image)self.generate_button.pack(pady=10)self.image_label = tk.Label(root)self.image_label.pack(pady=10)def train_model(self):train_model(self.config)self.model = load_model(self.config['model']['path'], self.config['model']['text_encoder_model_name'])messagebox.showinfo("成功", "模型训练完成")def generate_image(self):text = self.text_input.get("1.0", tk.END).strip()if not text:messagebox.showwarning("警告", "请输入文本")returninput_tensor = self.model.text_encoder([text])image = self.model.image_generator(input_tensor)image = image.squeeze(0).detach().cpu().numpy()image = (image * 127.5 + 127.5).astype('uint8')image = Image.fromarray(image.transpose(1, 2, 0))# 显示图像img_tk = ImageTk.PhotoImage(image)self.image_label.config(image=img_tk)self.image_label.image = img_tk# 保存图像save_image(image, f"{self.config['data']['output_dir']}/{text}.png")messagebox.showinfo("成功", "图像已生成并保存")# 输出项目目录及所有文件
def list_files(startpath):for root, dirs, files in os.walk(startpath):level = root.replace(startpath, '').count(os.sep)indent = ' ' * 4 * (level)print('{}{}/'.format(indent, os.path.basename(root)))subindent = ' ' * 4 * (level + 1)for f in files:print('{}{}'.format(subindent, f))if __name__ == "__main__":config = load_config('config.yaml')# 输出项目目录及所有文件project_root = os.path.dirname(os.path.abspath(__file__))print("项目目录及所有文件:")list_files(project_root)# 加载模型model = load_model(config['model']['path'], config['model']['text_encoder_model_name'])# 加载文本数据text_data = load_text_data(config['data']['input_file'])# 生成图像generate_images(model, text_data, config['data']['output_dir'])# 启动图形用户界面root = tk.Tk()app = TextToImageGUI(root)root.mainloop()
-
项目目录及所有文件输出
在 main.py 中添加了一个 list_files 函数,用于输出项目目录及所有文件。这个函数会在程序启动时自动调用,输出当前项目的目录结构和所有文件。 -
说明文档
README.md
在项目根目录下创建一个 README.md 文件,内容如下:
# 文本生成图像项目## 目录结构
text_to_image_project/
│
├── config.yaml
├── data/
│ ├── train_data.csv
│ └── test_data.txt
├── models/
│ └── text_to_image_model.pth
├── main.py
└── README.mdcode## 配置文件 `config.yaml`
```yaml
model:path: models/text_to_image_model.pthtext_encoder_model_name: bert-base-uncaseddata:input_file: data/test_data.txtoutput_dir: data/generated_imagesdataset_path: data/train_data.csvtraining:batch_size: 64learning_rate: 0.0002epochs: 100
训练数据文件
data/train_data.csv
csv
text,image_path
a beautiful sunset,data/images/sunset.jpg
a cute puppy,data/images/puppy.jpg
a red rose,data/images/rose.jpg
测试数据文件
data/test_data.txt
code
a beautiful sunset
a cute puppy
a red rose
运行项目
确保安装了所需的依赖库:bash
pip install torch torchvision transformers pillow tkinter
运行 main.py:bash
python main.py
功能
训练模型:点击“训练模型”按钮,开始训练模型。
生成图像:在文本框中输入文本,点击“生成图像”按钮,生成相应的图像并显示在界面上,同时保存到指定目录。
项目目录及所有文件
项目启动时会自动输出项目目录及所有文件。
希望这些改进和示例文件能帮助你更好地理解和运行项目。如果有任何问题或需要进一步的帮助,请随时告诉我!
相关文章:
伏羲0.07(文生图)
为了使0.06代码能够有效运行并输出项目目录及所有文件,我们在代码中添加一些额外的功能。 项目目录结构 项目目录结构如下: text_to_image_project/ │ ├── config.yaml ├── data/ │ ├── train_data.csv │ └── test_data.txt ├── mod…...
scala的泛型特质的应用场景
//泛型特质的应用场景 //作比较找出最大值 //定义一个函数,用来求List元素中的最大值参考代码:object Test4 {def getMax[T](list:List[T])(implicit ev:T > Ordered[T]): T {list.reduce((a:T,b:T)> if(a>b) a else b)}def main(args: Array…...
Win10环境vscode+latex+中文快速配置
安装vscodelatex workshop 配置: {"liveServer.settings.donotVerifyTags": true,"liveServer.settings.donotShowInfoMsg": true,"explorer.confirmDelete": false,"files.autoSave": "afterDelay","exp…...
【vue2】el-select,虚拟滚动(vue-virtual-scroller)
需求背景 vue2+element-ui项目中,当el-select中数据量较大时(超出5000个dom节点),会导致页面加载和渲染卡顿、el-select下拉列表延迟展开。 在现在的el-select的基础上使用分页或者虚拟列表的形式去处理大量的下拉菜单,可以保证页面的正常渲染及el-select的…...
【ETCD】[源码阅读]深度解析 EtcdServer 的 processInternalRaftRequestOnce 方法
在分布式系统中,etcd 的一致性与高效性得益于其强大的 Raft 协议模块。而 processInternalRaftRequestOnce 是 etcd 服务器处理内部 Raft 请求的核心方法之一。本文将从源码角度解析这个方法的逻辑流程,帮助读者更好地理解 etcd 的内部实现。 方法源码 …...
【RabbitMQ】RabbitMQ中核心概念交换机(Exchange)、队列(Queue)和路由键(Routing Key)等详细介绍
博主介绍:✌全网粉丝21W,CSDN博客专家、Java领域优质创作者,掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围:SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…...
【AI知识】过拟合、欠拟合和正则化
一句话总结: 过拟合和欠拟合是机器学习中的两个相对的概念,正则化是用于解决过拟合的方法。 1. 欠拟合: 指模型在训练数据上表现不佳,不能充分捕捉数据的潜在规律,导致在训练集和测试集上的误差都很高。欠拟合意味着模…...
计算机毕设-基于springboot的航空散货调度系统的设计与实现(附源码+lw+ppt+开题报告)
博主介绍:✌多个项目实战经验、多个大型网购商城开发经验、在某机构指导学员上千名、专注于本行业领域✌ 技术范围:Java实战项目、Python实战项目、微信小程序/安卓实战项目、爬虫大数据实战项目、Nodejs实战项目、PHP实战项目、.NET实战项目、Golang实战…...
视图、转发与重定向、静态资源处理
目录 视图 默认视图 视图机制原理 自定义视图 请求转发与重定向 静态资源处理 视图 每个视图解析器都实现了 Ordered 接口并开放出一个 order 属性 可以通过 order 属性指定解析器的优先顺序,order 越小优先级越高 默认是最低优先级,Integer.MAX_…...
优选算法——分治(快排)
1. 颜色分类 题目链接:75. 颜色分类 - 力扣(LeetCode) 题目展示: 题目分析:本题其实就要将数组最终分成3块儿,这也是后面快排的优化思路,具体大家来看下图。 这里我们上来先定义了3个指针&…...
【Linux系统】文件系统
Windows 和 Linux 的文件系统: windows:NTFS —> NTFS:磁盘大于目录:目录是磁盘的一部分。ubuntu :EXT4 —> EXT4: 目录大于磁盘:磁盘是目录的一部分。 Windows文件系统的特点 基于分区的文件系统: Windows…...
javaweb的基础
文章的简介: 页面的展示(HTML)页面的修改、绑定、弹窗(js的dom、bom等)页面的请求(Ajax) 1、在HTML中用标签和css样式实现了浏览器页面。 2、用JS实现页面内容(图片,复选框、文本颜色内容)的修改和弹框&…...
家里养几条金鱼比较好?
金鱼,作为备受喜爱的家庭水族宠物,其饲养数量一直是众多养鱼爱好者关注的焦点。究竟养几条金鱼最为适宜,实则需要综合考量多方面因素,方能达到美观、健康与和谐的理想养鱼境界。 从风水文化的视角来看,金鱼数量有着诸…...
写作词汇积累:差池、一体两面、切实可行极简理解
差池 【差池】可以是名词,是指意外的事或错误。 【差池】也可以是形容词,是指参差不齐、差劲或不行。 1. 由于操作不当,导致这次实验出现了【差池】,我们需要重新分析原因并调整方案。(名词,表示意外的事…...
移远EC200A-CN的OPENCPU使用GO开发嵌入式程序TBOX
演示地址: http://134.175.123.194:8811 admin admin 演示视频: https://www.bilibili.com/video/BV196q2YQEDP 主要功能 WatchDog 1. 守护进程 2. OTA远程升级 TBOX 1. 数据采集、数据可视化、数据上报(内置Modbus TCP/RTU/ASCII,GPS协…...
LEED绿色建筑认证最新消息
关于LEED绿色建筑认证的最新消息,可以从以下几个方面进行概述: 一、认证体系更新与发展 LEED认证体系不断更新和完善,以更好地适应全球绿色建筑的发展趋势。例如,LEED v4能源更新已通过投票,并于2024年3月1日全面启用…...
SpringBoot中集成常见邮箱中容易出现的问题
本来也没打算想写得。不过也是遇到一些坑,就记录一下吧,也折腾了小半天 1.maven配置 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-mail</artifactId></dependency>2…...
webstorm开发uniapp(从安装到项目运行)
1、下载uniapp插件 下载连接:Uniapp Tool - IntelliJ IDEs Plugin | Marketplace (结合自己的webstorm版本下载,不然解析不了) 将下载到的zip文件防在webstorm安装路径下,本文的地址为: 2、安装uniapp插…...
C# 探险之旅:第七节 - 条件判断(三元判断符):? : 的奇妙冒险
嘿,勇敢的探险家们!欢迎来到 C# 编程世界的奇妙之旅的第七节。今天,我们要探索的是一个神秘而强大的宝藏——三元判断符 ? :。别怕,它听起来复杂,但实际上比找宝藏还简单! 场景设定:宝藏的选择…...
FlinkCDC实战:将 MySQL 数据同步至 ES
📌 当前需要处理的业务场景: 将订单表和相关联的表(比如: 商品表、子订单表、物流信息表)组织成宽表, 放入到 ES 中, 加速订单数据的查询. 同步数据到 es. 概述 1. 什么是 CDC 2. 什么是 Flink CDC 3. Flink CDC Connectors 和 Flink 的版本映射 实战 1. 宽表查…...
debug小记
红框: 步过:遇到方法不想进入方法 绿框:代码跑在第几行也可以看见 蓝框:可以显示变量的值,三种方式都可以看变量的值...
Qt C++ 显示多级结构体,包括结构体名、变量名和值
文章目录 mainwindow.hmainwindow.cppstructures.hmain.cpp QTreeView 和 QStandardItemModel 来实现。以下是实现这一功能的步骤和示例代码: 定义多级结构体: 假设你有一个多级结构体,如下所示: struct SubStruct {int subValue…...
【JAVA】旅游行业中大数据的使用
一、应用场景 数据采集与整合:全面收集旅游数据,如客流量、游客满意度等,整合形成统一数据集,为后续分析提供便利。 舆情监测与分析:实时监测旅游目的地的舆情信息,运用NLP算法进行智能处理,及…...
【AI+网络/仿真数据集】1分钟搭建云原生端到端5G网络
导语: 近期智慧网络开放创新平台上线了端到端网络仿真能力,区别于传统的网络仿真工具需要复杂的领域知识可界面操作,该平台的网络仿真能力主打一个小白友好和功能专业。 https://jiutian.10086.cn/open/jiutian.10086.cn/open/ 端到端仿…...
微服务-01【续】
1.OpenFeign 上篇文章我们利用Nacos实现了服务的治理,利用利用RestTemplate实现了服务的远程调用。但是远程调用的代码太复杂了: 而且这种调用方式,与原本的本地方法调用差异太大,编程时的体验也不统一,一会儿远程调用…...
测试工程师八股文01|Linux系统操作
一、Linux系统操作 1、gzip tar和gzip结合使用 $ tar czf b.tar.gz *txt 以gzip方式打包并且压缩 $ tar xzf b.tar.gz -C btar 以gzip方式解压并解包,如果 btar 目录不存在,则需要先手动创建该目录。 代码第二行:如果没有指定 -C …...
【Qt】qt基础
目录 一、使用Qt Creator创建qt项目 二、项目文件解析 三、Qt中创建图形化界面的程序的两种方法 四、对象树 五、Qt中处理打印乱码问题的利器:qDebug() 一、使用Qt Creator创建qt项目 1.选择项目模板 选中第一类模板Application(Qt应用程序,包含普…...
UniScene:Video、LiDAR 和Occupancy全面SOTA
论文: https://arxiv.org/pdf/2412.05435 项目页面:https://arlo0o.github.io/uniscene/ 0. 摘要 生成高保真度、可控制且带有标注的训练数据对于自动驾驶至关重要。现有方法通常直接从粗糙的场景布局生成单一形式的数据,这不仅无法输出多样化下游任务…...
TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(1)——神经网络与模型训练过程详解 0. 前言1. 神经网络基础1.1 神经网络简介1.2 神经网络的训练1.3 神经网络的应用 2. 从零开始构建前向传播2.1 计算隐藏层节点值2.2 应用激活函数2.3 计算输出层值2.4 计算损失值2.4.1 在连续变…...
03篇--二值化与自适应二值化
二值化 定义 何为二值化?顾名思义,就是将图像中的像素值改为只有两种值,黑与白。此为二值化。 二值化操作的图像只能是灰度图,意思就是二值化也是一个二维数组,它与灰度图都属于单信道,仅能表示一种色调…...
温岭网站开发/朋友圈信息流广告投放价格
Exchange Server 2016 RTM 可以在 Windows Server 2012、Windows Server 2012 R2 和 Windows Server 2016 的标准版和企业版中进行部署安装,但 Exchange Server 2016 只支持完整 GUI 的 Windows Server 而不支持核心安装的系统,本文我们将对快速部署方式…...
wordpress文章关联/网盘搜索神器
4个类就可体验HADOOP-RPC的简单、实用。 一,writable-rpc 一,协议 所谓协议,实际是一个接口,用来定义该RPC提供的功能 package rpc.writeable;public interface MyProtocol {void mkdir(String path);String getName(String na…...
云南专业网站建设/外贸网站制作推广
算是狗年上班的最后一天吧,想想还是略略总结一下近半年来的概况。 这段时间比较懒得去总结更新发表新的博客,一来是生活和工作的琐碎让自己有些懈怠,二是对自己写的东西缺乏深度感到困扰,自己大概也带着些完美型人格吧。最近的工…...
烟台开发区人才市场招聘信息/合肥网络公司seo建站
奔小康赚大钱Time Limit: 1000/1000 MS (Java/Others) Memory Limit: 32768/32768 K (Java/Others)Total Submission(s): 2325 Accepted Submission(s): 1020Problem Description传说在遥远的地方有一个非常富裕的村落,有一天,村长决定进行制度改革:重新分配房…...
谷歌网站推广/广州最新消息今天
git diff 检查更新 git fetch #需要先 fetch git diff master..origin/master --name-only -- [path] #path:指定检查 可以是文件或者文件夹,--name-only:只列出有变化的文件名 git diff master..origin/master --name-only -- [path] …...
河南省建设信息管理协会/seog
ORACLE下删除当前用户下所有对象的SQLSql代码--删除某个用户下的对象set heading off;set feedback off;spool c:\dropobj.sql;prompt --Drop constraintselect alter table ||table_name|| drop constraint ||constraint_name|| ; from user_constraints where constraint_typ…...