伏羲0.07(文生图)
为了使0.06代码能够有效运行并输出项目目录及所有文件,我们在代码中添加一些额外的功能。
- 项目目录结构
项目目录结构如下:
text_to_image_project/
│
├── config.yaml
├── data/
│ ├── train_data.csv
│ └── test_data.txt
├── models/
│ └── text_to_image_model.pth
├── main.py
└── README.md
- 示例配置文件
config.yaml
model:path: models/text_to_image_model.pthtext_encoder_model_name: bert-base-uncaseddata:input_file: data/test_data.txtoutput_dir: data/generated_imagesdataset_path: data/train_data.csvtraining:batch_size: 64learning_rate: 0.0002epochs: 100
- 示例训练数据文件
data/train_data.csv
text,image_path
a beautiful sunset,data/images/sunset.jpg
a cute puppy,data/images/puppy.jpg
a red rose,data/images/rose.jpg
- 示例测试数据文件
data/test_data.txt
a beautiful sunset
a cute puppy
a red rose
- 完善后的代码
main.py
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
import yaml
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import random
import numpy as np# 配置文件加载
def load_config(config_path):with open(config_path, 'r', encoding='utf-8') as file:config = yaml.safe_load(file)return config# 数据加载
def load_text_data(file_path):with open(file_path, 'r', encoding='utf-8') as file:text_data = file.readlines()return [line.strip() for line in text_data]# 数据清洗
def clean_data(data):return data.dropna().drop_duplicates()# 数据增强
def augment_data(image, mode):if mode == 'train':transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomRotation(10),transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])else:transform = transforms.Compose([transforms.Resize((64, 64)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])return transform(image)# 文本编码器
class TextEncoder(nn.Module):def __init__(self, model_name):super(TextEncoder, self).__init__()self.tokenizer = AutoTokenizer.from_pretrained(model_name)self.model = AutoModel.from_pretrained(model_name)def forward(self, text):inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)outputs = self.model(**inputs)return outputs.last_hidden_state.mean(dim=1)# 图像生成器
class ImageGenerator(nn.Module):def __init__(self, in_channels):super(ImageGenerator, self).__init__()self.decoder = nn.Sequential(nn.ConvTranspose2d(in_channels, 512, kernel_size=4, stride=1, padding=0),nn.BatchNorm2d(512),nn.ReLU(True),nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(256),nn.ReLU(True),nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(128),nn.ReLU(True),nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(64),nn.ReLU(True),nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),nn.Tanh())def forward(self, x):x = x.view(-1, x.size(1), 1, 1)return self.decoder(x)# 判别器
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),nn.Sigmoid())def forward(self, x):return self.main(x)# 模型定义
class TextToImageModel(nn.Module):def __init__(self, text_encoder_model_name):super(TextToImageModel, self).__init__()self.text_encoder = TextEncoder(text_encoder_model_name)self.image_generator = ImageGenerator(768) # 768 is the hidden size of BERTdef forward(self, text):text_features = self.text_encoder(text)return self.image_generator(text_features)# 模型加载
def load_model(model_path, text_encoder_model_name):model = TextToImageModel(text_encoder_model_name)if os.path.exists(model_path):model.load_state_dict(torch.load(model_path))model.eval()return model# 图像保存
def save_image(image, path):if not os.path.exists(os.path.dirname(path)):os.makedirs(os.path.dirname(path))image.save(path)# 数据集类
class TextToImageDataset(Dataset):def __init__(self, csv_file, transform=None, mode='train'):self.data = pd.read_csv(csv_file)self.data = clean_data(self.data)self.transform = transformself.mode = modedef __len__(self):return len(self.data)def __getitem__(self, idx):text = self.data.iloc[idx]['text']image_path = self.data.iloc[idx]['image_path']image = Image.open(image_path).convert('RGB')if self.transform:image = self.transform(image, self.mode)return text, image# 模型训练
def train_model(config):transform = transforms.Compose([transforms.Resize((64, 64)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])dataset = TextToImageDataset(config['training']['dataset_path'], transform=augment_data, mode='train')dataloader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=True)model = TextToImageModel(config['model']['text_encoder_model_name'])discriminator = Discriminator()optimizer_g = optim.Adam(model.parameters(), lr=config['training']['learning_rate'])optimizer_d = optim.Adam(discriminator.parameters(), lr=config['training']['learning_rate'])criterion_gan = nn.BCELoss()criterion_l1 = nn.L1Loss()for epoch in range(config['training']['epochs']):model.train()discriminator.train()running_loss_g = 0.0running_loss_d = 0.0for i, (text, images) in enumerate(dataloader):real_labels = torch.ones(images.size(0), 1)fake_labels = torch.zeros(images.size(0), 1)# Train Discriminatoroptimizer_d.zero_grad()real_outputs = discriminator(images)d_loss_real = criterion_gan(real_outputs, real_labels)generated_images = model(text)fake_outputs = discriminator(generated_images.detach())d_loss_fake = criterion_gan(fake_outputs, fake_labels)d_loss = (d_loss_real + d_loss_fake) / 2d_loss.backward()optimizer_d.step()# Train Generatoroptimizer_g.zero_grad()generated_images = model(text)g_outputs = discriminator(generated_images)g_loss_gan = criterion_gan(g_outputs, real_labels)g_loss_l1 = criterion_l1(generated_images, images)g_loss = g_loss_gan + 100 * g_loss_l1 # Weighted sum of GAN loss and L1 lossg_loss.backward()optimizer_g.step()running_loss_g += g_loss.item()running_loss_d += d_loss.item()print(f"Epoch {epoch + 1}, Generator Loss: {running_loss_g / len(dataloader)}, Discriminator Loss: {running_loss_d / len(dataloader)}")# 保存训练好的模型torch.save(model.state_dict(), config['model']['path'])# 图像生成
def generate_images(model, text_data, output_dir):for text in text_data:input_tensor = model.text_encoder([text])image = model.image_generator(input_tensor)image = image.squeeze(0).detach().cpu().numpy()image = (image * 127.5 + 127.5).astype('uint8')image = Image.fromarray(image.transpose(1, 2, 0))# 保存图像save_image(image, f"{output_dir}/{text}.png")# 图形用户界面
class TextToImageGUI:def __init__(self, root):self.root = rootself.root.title("文本生成图像")self.config = load_config('config.yaml')self.model = load_model(self.config['model']['path'], self.config['model']['text_encoder_model_name'])self.text_input = tk.Text(root, height=10, width=50)self.text_input.pack(pady=10)self.train_button = tk.Button(root, text="训练模型", command=self.train_model)self.train_button.pack(pady=10)self.generate_button = tk.Button(root, text="生成图像", command=self.generate_image)self.generate_button.pack(pady=10)self.image_label = tk.Label(root)self.image_label.pack(pady=10)def train_model(self):train_model(self.config)self.model = load_model(self.config['model']['path'], self.config['model']['text_encoder_model_name'])messagebox.showinfo("成功", "模型训练完成")def generate_image(self):text = self.text_input.get("1.0", tk.END).strip()if not text:messagebox.showwarning("警告", "请输入文本")returninput_tensor = self.model.text_encoder([text])image = self.model.image_generator(input_tensor)image = image.squeeze(0).detach().cpu().numpy()image = (image * 127.5 + 127.5).astype('uint8')image = Image.fromarray(image.transpose(1, 2, 0))# 显示图像img_tk = ImageTk.PhotoImage(image)self.image_label.config(image=img_tk)self.image_label.image = img_tk# 保存图像save_image(image, f"{self.config['data']['output_dir']}/{text}.png")messagebox.showinfo("成功", "图像已生成并保存")# 输出项目目录及所有文件
def list_files(startpath):for root, dirs, files in os.walk(startpath):level = root.replace(startpath, '').count(os.sep)indent = ' ' * 4 * (level)print('{}{}/'.format(indent, os.path.basename(root)))subindent = ' ' * 4 * (level + 1)for f in files:print('{}{}'.format(subindent, f))if __name__ == "__main__":config = load_config('config.yaml')# 输出项目目录及所有文件project_root = os.path.dirname(os.path.abspath(__file__))print("项目目录及所有文件:")list_files(project_root)# 加载模型model = load_model(config['model']['path'], config['model']['text_encoder_model_name'])# 加载文本数据text_data = load_text_data(config['data']['input_file'])# 生成图像generate_images(model, text_data, config['data']['output_dir'])# 启动图形用户界面root = tk.Tk()app = TextToImageGUI(root)root.mainloop()
-
项目目录及所有文件输出
在 main.py 中添加了一个 list_files 函数,用于输出项目目录及所有文件。这个函数会在程序启动时自动调用,输出当前项目的目录结构和所有文件。 -
说明文档
README.md
在项目根目录下创建一个 README.md 文件,内容如下:
# 文本生成图像项目## 目录结构
text_to_image_project/
│
├── config.yaml
├── data/
│ ├── train_data.csv
│ └── test_data.txt
├── models/
│ └── text_to_image_model.pth
├── main.py
└── README.mdcode## 配置文件 `config.yaml`
```yaml
model:path: models/text_to_image_model.pthtext_encoder_model_name: bert-base-uncaseddata:input_file: data/test_data.txtoutput_dir: data/generated_imagesdataset_path: data/train_data.csvtraining:batch_size: 64learning_rate: 0.0002epochs: 100
训练数据文件
data/train_data.csv
csv
text,image_path
a beautiful sunset,data/images/sunset.jpg
a cute puppy,data/images/puppy.jpg
a red rose,data/images/rose.jpg
测试数据文件
data/test_data.txt
code
a beautiful sunset
a cute puppy
a red rose
运行项目
确保安装了所需的依赖库:bash
pip install torch torchvision transformers pillow tkinter
运行 main.py:bash
python main.py
功能
训练模型:点击“训练模型”按钮,开始训练模型。
生成图像:在文本框中输入文本,点击“生成图像”按钮,生成相应的图像并显示在界面上,同时保存到指定目录。
项目目录及所有文件
项目启动时会自动输出项目目录及所有文件。
希望这些改进和示例文件能帮助你更好地理解和运行项目。如果有任何问题或需要进一步的帮助,请随时告诉我!
相关文章:
伏羲0.07(文生图)
为了使0.06代码能够有效运行并输出项目目录及所有文件,我们在代码中添加一些额外的功能。 项目目录结构 项目目录结构如下: text_to_image_project/ │ ├── config.yaml ├── data/ │ ├── train_data.csv │ └── test_data.txt ├── mod…...
scala的泛型特质的应用场景
//泛型特质的应用场景 //作比较找出最大值 //定义一个函数,用来求List元素中的最大值参考代码:object Test4 {def getMax[T](list:List[T])(implicit ev:T > Ordered[T]): T {list.reduce((a:T,b:T)> if(a>b) a else b)}def main(args: Array…...
Win10环境vscode+latex+中文快速配置
安装vscodelatex workshop 配置: {"liveServer.settings.donotVerifyTags": true,"liveServer.settings.donotShowInfoMsg": true,"explorer.confirmDelete": false,"files.autoSave": "afterDelay","exp…...
【vue2】el-select,虚拟滚动(vue-virtual-scroller)
需求背景 vue2+element-ui项目中,当el-select中数据量较大时(超出5000个dom节点),会导致页面加载和渲染卡顿、el-select下拉列表延迟展开。 在现在的el-select的基础上使用分页或者虚拟列表的形式去处理大量的下拉菜单,可以保证页面的正常渲染及el-select的…...
【ETCD】[源码阅读]深度解析 EtcdServer 的 processInternalRaftRequestOnce 方法
在分布式系统中,etcd 的一致性与高效性得益于其强大的 Raft 协议模块。而 processInternalRaftRequestOnce 是 etcd 服务器处理内部 Raft 请求的核心方法之一。本文将从源码角度解析这个方法的逻辑流程,帮助读者更好地理解 etcd 的内部实现。 方法源码 …...
【RabbitMQ】RabbitMQ中核心概念交换机(Exchange)、队列(Queue)和路由键(Routing Key)等详细介绍
博主介绍:✌全网粉丝21W,CSDN博客专家、Java领域优质创作者,掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围:SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…...
【AI知识】过拟合、欠拟合和正则化
一句话总结: 过拟合和欠拟合是机器学习中的两个相对的概念,正则化是用于解决过拟合的方法。 1. 欠拟合: 指模型在训练数据上表现不佳,不能充分捕捉数据的潜在规律,导致在训练集和测试集上的误差都很高。欠拟合意味着模…...
计算机毕设-基于springboot的航空散货调度系统的设计与实现(附源码+lw+ppt+开题报告)
博主介绍:✌多个项目实战经验、多个大型网购商城开发经验、在某机构指导学员上千名、专注于本行业领域✌ 技术范围:Java实战项目、Python实战项目、微信小程序/安卓实战项目、爬虫大数据实战项目、Nodejs实战项目、PHP实战项目、.NET实战项目、Golang实战…...
视图、转发与重定向、静态资源处理
目录 视图 默认视图 视图机制原理 自定义视图 请求转发与重定向 静态资源处理 视图 每个视图解析器都实现了 Ordered 接口并开放出一个 order 属性 可以通过 order 属性指定解析器的优先顺序,order 越小优先级越高 默认是最低优先级,Integer.MAX_…...
优选算法——分治(快排)
1. 颜色分类 题目链接:75. 颜色分类 - 力扣(LeetCode) 题目展示: 题目分析:本题其实就要将数组最终分成3块儿,这也是后面快排的优化思路,具体大家来看下图。 这里我们上来先定义了3个指针&…...
【Linux系统】文件系统
Windows 和 Linux 的文件系统: windows:NTFS —> NTFS:磁盘大于目录:目录是磁盘的一部分。ubuntu :EXT4 —> EXT4: 目录大于磁盘:磁盘是目录的一部分。 Windows文件系统的特点 基于分区的文件系统: Windows…...
javaweb的基础
文章的简介: 页面的展示(HTML)页面的修改、绑定、弹窗(js的dom、bom等)页面的请求(Ajax) 1、在HTML中用标签和css样式实现了浏览器页面。 2、用JS实现页面内容(图片,复选框、文本颜色内容)的修改和弹框&…...
家里养几条金鱼比较好?
金鱼,作为备受喜爱的家庭水族宠物,其饲养数量一直是众多养鱼爱好者关注的焦点。究竟养几条金鱼最为适宜,实则需要综合考量多方面因素,方能达到美观、健康与和谐的理想养鱼境界。 从风水文化的视角来看,金鱼数量有着诸…...
写作词汇积累:差池、一体两面、切实可行极简理解
差池 【差池】可以是名词,是指意外的事或错误。 【差池】也可以是形容词,是指参差不齐、差劲或不行。 1. 由于操作不当,导致这次实验出现了【差池】,我们需要重新分析原因并调整方案。(名词,表示意外的事…...
移远EC200A-CN的OPENCPU使用GO开发嵌入式程序TBOX
演示地址: http://134.175.123.194:8811 admin admin 演示视频: https://www.bilibili.com/video/BV196q2YQEDP 主要功能 WatchDog 1. 守护进程 2. OTA远程升级 TBOX 1. 数据采集、数据可视化、数据上报(内置Modbus TCP/RTU/ASCII,GPS协…...
LEED绿色建筑认证最新消息
关于LEED绿色建筑认证的最新消息,可以从以下几个方面进行概述: 一、认证体系更新与发展 LEED认证体系不断更新和完善,以更好地适应全球绿色建筑的发展趋势。例如,LEED v4能源更新已通过投票,并于2024年3月1日全面启用…...
SpringBoot中集成常见邮箱中容易出现的问题
本来也没打算想写得。不过也是遇到一些坑,就记录一下吧,也折腾了小半天 1.maven配置 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-mail</artifactId></dependency>2…...
webstorm开发uniapp(从安装到项目运行)
1、下载uniapp插件 下载连接:Uniapp Tool - IntelliJ IDEs Plugin | Marketplace (结合自己的webstorm版本下载,不然解析不了) 将下载到的zip文件防在webstorm安装路径下,本文的地址为: 2、安装uniapp插…...
C# 探险之旅:第七节 - 条件判断(三元判断符):? : 的奇妙冒险
嘿,勇敢的探险家们!欢迎来到 C# 编程世界的奇妙之旅的第七节。今天,我们要探索的是一个神秘而强大的宝藏——三元判断符 ? :。别怕,它听起来复杂,但实际上比找宝藏还简单! 场景设定:宝藏的选择…...
FlinkCDC实战:将 MySQL 数据同步至 ES
📌 当前需要处理的业务场景: 将订单表和相关联的表(比如: 商品表、子订单表、物流信息表)组织成宽表, 放入到 ES 中, 加速订单数据的查询. 同步数据到 es. 概述 1. 什么是 CDC 2. 什么是 Flink CDC 3. Flink CDC Connectors 和 Flink 的版本映射 实战 1. 宽表查…...
挑战杯推荐项目
“人工智能”创意赛 - 智能艺术创作助手:借助大模型技术,开发能根据用户输入的主题、风格等要求,生成绘画、音乐、文学作品等多种形式艺术创作灵感或初稿的应用,帮助艺术家和创意爱好者激发创意、提高创作效率。 - 个性化梦境…...
【人工智能】神经网络的优化器optimizer(二):Adagrad自适应学习率优化器
一.自适应梯度算法Adagrad概述 Adagrad(Adaptive Gradient Algorithm)是一种自适应学习率的优化算法,由Duchi等人在2011年提出。其核心思想是针对不同参数自动调整学习率,适合处理稀疏数据和不同参数梯度差异较大的场景。Adagrad通…...
镜像里切换为普通用户
如果你登录远程虚拟机默认就是 root 用户,但你不希望用 root 权限运行 ns-3(这是对的,ns3 工具会拒绝 root),你可以按以下方法创建一个 非 root 用户账号 并切换到它运行 ns-3。 一次性解决方案:创建非 roo…...
如何理解 IP 数据报中的 TTL?
目录 前言理解 前言 面试灵魂一问:说说对 IP 数据报中 TTL 的理解?我们都知道,IP 数据报由首部和数据两部分组成,首部又分为两部分:固定部分和可变部分,共占 20 字节,而即将讨论的 TTL 就位于首…...
3-11单元格区域边界定位(End属性)学习笔记
返回一个Range 对象,只读。该对象代表包含源区域的区域上端下端左端右端的最后一个单元格。等同于按键 End 向上键(End(xlUp))、End向下键(End(xlDown))、End向左键(End(xlToLeft)End向右键(End(xlToRight)) 注意:它移动的位置必须是相连的有内容的单元格…...
九天毕昇深度学习平台 | 如何安装库?
pip install 库名 -i https://pypi.tuna.tsinghua.edu.cn/simple --user 举个例子: 报错 ModuleNotFoundError: No module named torch 那么我需要安装 torch pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple --user pip install 库名&#x…...
【Redis】笔记|第8节|大厂高并发缓存架构实战与优化
缓存架构 代码结构 代码详情 功能点: 多级缓存,先查本地缓存,再查Redis,最后才查数据库热点数据重建逻辑使用分布式锁,二次查询更新缓存采用读写锁提升性能采用Redis的发布订阅机制通知所有实例更新本地缓存适用读多…...
DingDing机器人群消息推送
文章目录 1 新建机器人2 API文档说明3 代码编写 1 新建机器人 点击群设置 下滑到群管理的机器人,点击进入 添加机器人 选择自定义Webhook服务 点击添加 设置安全设置,详见说明文档 成功后,记录Webhook 2 API文档说明 点击设置说明 查看自…...
华为OD机考-机房布局
import java.util.*;public class DemoTest5 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseSystem.out.println(solve(in.nextLine()));}}priv…...
计算机基础知识解析:从应用到架构的全面拆解
目录 前言 1、 计算机的应用领域:无处不在的数字助手 2、 计算机的进化史:从算盘到量子计算 3、计算机的分类:不止 “台式机和笔记本” 4、计算机的组件:硬件与软件的协同 4.1 硬件:五大核心部件 4.2 软件&#…...
