智能聊天机器人:使用PyTorch构建多轮对话系统
使用PyTorch构建多轮对话系统的示例代码。这个示例项目包括一个简单的Seq2Seq模型用于对话生成,并使用GRU作为RNN的变体。以下是代码的主要部分,包括数据预处理、模型定义和训练循环。
数据预处理
首先,准备数据并进行预处理。这部分代码假定你有一个对话数据集,格式为成对的问答句子。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random# 假设数据集是一个成对的问答列表
pairs = [["Hi, how are you?", "I'm good, thank you! How about you?"],["What is your name?", "My name is Chatbot."],# 添加更多对话数据
]# 简单的词汇表和索引映射
word2index = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
index2word = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
vocab_size = len(word2index)def tokenize(sentence):return sentence.lower().split()def build_vocab(pairs):global word2index, index2word, vocab_sizefor pair in pairs:for sentence in pair:for word in tokenize(sentence):if word not in word2index:word2index[word] = vocab_sizeindex2word[vocab_size] = wordvocab_size += 1def sentence_to_tensor(sentence):tokens = tokenize(sentence)indices = [word2index.get(word, word2index["<UNK>"]) for word in tokens]return torch.tensor(indices + [word2index["<EOS>"]], dtype=torch.long)build_vocab(pairs)
数据集和数据加载
定义一个Dataset类和DataLoader来加载数据。
class ChatDataset(Dataset):def __init__(self, pairs):self.pairs = pairsdef __len__(self):return len(self.pairs)def __getitem__(self, idx):input_tensor = sentence_to_tensor(self.pairs[idx][0])target_tensor = sentence_to_tensor(self.pairs[idx][1])return input_tensor, target_tensordef collate_fn(batch):inputs, targets = zip(*batch)input_lengths = [len(seq) for seq in inputs]target_lengths = [len(seq) for seq in targets]inputs = nn.utils.rnn.pad_sequence(inputs, padding_value=word2index["<PAD>"])targets = nn.utils.rnn.pad_sequence(targets, padding_value=word2index["<PAD>"])return inputs, targets, input_lengths, target_lengthsdataset = ChatDataset(pairs)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn, shuffle=True)
模型定义
定义一个简单的Seq2Seq模型,包括编码器和解码器。
class Encoder(nn.Module):def __init__(self, input_size, hidden_size, num_layers=1):super(Encoder, self).__init__()self.embedding = nn.Embedding(input_size, hidden_size)self.gru = nn.GRU(hidden_size, hidden_size, num_layers)def forward(self, input_seq, input_lengths, hidden=None):embedded = self.embedding(input_seq)packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, enforce_sorted=False)outputs, hidden = self.gru(packed, hidden)outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)return outputs, hiddenclass Decoder(nn.Module):def __init__(self, output_size, hidden_size, num_layers=1):super(Decoder, self).__init__()self.embedding = nn.Embedding(output_size, hidden_size)self.gru = nn.GRU(hidden_size, hidden_size, num_layers)self.out = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, input_step, hidden, encoder_outputs):embedded = self.embedding(input_step)gru_output, hidden = self.gru(embedded, hidden)output = self.softmax(self.out(gru_output.squeeze(0)))return output, hiddenclass Seq2Seq(nn.Module):def __init__(self, encoder, decoder, device):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderself.device = devicedef forward(self, input_tensor, target_tensor, input_lengths, target_lengths, teacher_forcing_ratio=0.5):batch_size = input_tensor.size(1)max_target_len = max(target_lengths)vocab_size = self.decoder.out.out_featuresoutputs = torch.zeros(max_target_len, batch_size, vocab_size).to(self.device)encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths)decoder_input = torch.tensor([[word2index["<SOS>"]] * batch_size]).to(self.device)decoder_hidden = encoder_hiddenfor t in range(max_target_len):decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)outputs[t] = decoder_outputtop1 = decoder_output.argmax(1)decoder_input = target_tensor[t].unsqueeze(0) if random.random() < teacher_forcing_ratio else top1.unsqueeze(0)return outputsdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = Encoder(vocab_size, hidden_size=256).to(device)
decoder = Decoder(vocab_size, hidden_size=256).to(device)
model = Seq2Seq(encoder, decoder, device).to(device)
训练循环
定义训练循环并进行模型训练。
def train(model, dataloader, num_epochs, learning_rate=0.001):criterion = nn.CrossEntropyLoss(ignore_index=word2index["<PAD>"])optimizer = optim.Adam(model.parameters(), lr=learning_rate)for epoch in range(num_epochs):model.train()total_loss = 0for inputs, targets, input_lengths, target_lengths in dataloader:inputs, targets = inputs.to(device), targets.to(device)optimizer.zero_grad()outputs = model(inputs, targets, input_lengths, target_lengths)loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader)}")train(model, dataloader, num_epochs=10)
测试与推理
定义一个简单的推理函数来进行对话生成。
def evaluate(model, sentence, max_length=10):model.eval()with torch.no_grad():input_tensor = sentence_to_tensor(sentence).unsqueeze(1).to(device)input_length = [input_tensor.size(0)]encoder_outputs, encoder_hidden = model.encoder(input_tensor, input_length)decoder_input = torch.tensor([[word2index["<SOS>"]]]).to(device)decoder_hidden = encoder_hiddendecoded_words = []for _ in range(max_length):decoder_output, decoder_hidden = model.decoder(decoder_input, decoder_hidden, encoder_outputs)top1 = decoder_output.argmax(1).item()if top1 == word2index["<EOS>"]:breakelse:decoded_words.append(index2word[top1])decoder_input = torch.tensor([[top1]]).to(device)return ' '.join(decoded_words)print(evaluate(model, "Hi, how are you?"))
总结
这只是一个简单的示例,用于展示如何使用PyTorch构建一个基本的多轮对话系统。实际应用中,可能需要更多的数据预处理、更复杂的模型(如Transformer)、更细致的训练策略和优化技术,以及更丰富的对话数据集。希望这个示例对你有所帮助!
相关文章:
智能聊天机器人:使用PyTorch构建多轮对话系统
使用PyTorch构建多轮对话系统的示例代码。这个示例项目包括一个简单的Seq2Seq模型用于对话生成,并使用GRU作为RNN的变体。以下是代码的主要部分,包括数据预处理、模型定义和训练循环。 数据预处理 首先,准备数据并进行预处理。这部分代码假…...
昇思25天学习打卡营第16天 | 文本解码原理-以MindNLP为例
基于 MindSpore 实现 BERT 对话情绪识别 上几章我们学习过了基于MindSpore来实现计算机视觉的一些应用,那么从这期开始要开始一个新的领域——LLM 首先了解一下什么是LLM LLM 是 “大型语言模型”(Large Language Model)的缩写。LLM 是一种…...
Unity之Text组件换行\n没有实现+动态中英互换
前因:文本中的换行 \n没有换行而是打印出来了,解决方式 因为unity会默认把\n替换成\\n 面板中使用富文本这个选项啊 没有用 m_text.text m_text.text.Replace("\\n", "\n"); ###动态中英文互译 using System.Collections; using…...
vue3+ el-tree 展开和折叠,默认展开第一项
默认第一项展开: 展开所有项: 折叠所有项: <template><el-treestyle"max-width: 600px":data"treeData"node-key"id":default-expanded-keys"defaultExpandedKey":props"defaultProps"…...
ProFormList --复杂数据联动ProFormDependency
需求: (1)数据联动:测试数据1、2互相依赖,测试数据1<测试数据2,测试数据2>测试数据1。 (2)点击添加按钮,添加一行。 (3)自定义操作按钮。 ࿰…...
Git、Github、tortoiseGit下载安装调试全套教程
一、Git 1.下载安装Git 编辑器可默认Vim,可换成别的,此处换成VScode,换成VScode或别的都需要单独下载和调用 (1)Git安装:https://www.cnblogs.com/xiuxingzhe/p/9300905.html 超级完整的 Git的下载、安…...
老师怎么快速发布成绩?
期末考试的钟声刚刚敲响,成绩单的发放却成了老师们的一大难题。每当期末成绩揭晓,老师们便要开始一项繁琐的任务——将每一份成绩单逐一私信给家长。这不仅耗费了大量的时间和精力,也让本就忙碌的期末工作变得更加繁重。然而,随着…...
央视揭露:上百元的AI填报高考志愿真的靠谱吗?阿里云新增两位AI圈“代言人”!|AI日报
文章推荐 MiniMax闫俊杰:国内模型远不及GPT-4;OpenAI隐瞒黑客曾入侵其内部系统|AI日报 今日热点 月之暗面、智联招聘成为阿里云新“代言人”,使用阿里云强大算力和大模型服务平台提升模型推理效率 7月8日,阿里云官…...
TPM管理咨询公司甄选指南
在竞争激烈的市场环境中,TPM(全面生产维护)管理咨询公司的重要性日益凸显。然而,如何在众多咨询公司中筛选出最适合自己企业的合作伙伴,成为了许多企业决策者面临的难题。本文将从专业度、行业经验、服务质量和性价比等…...
探索 Scikit-Learn:机器学习的强大工具库
Scikit-Learn 探索 Scikit-Learn:机器学习的强大工具库主要功能模块分类(Classification)回归(Regression)聚类(Clustering)降维(Dimensionality Reduction)模型选择&…...
音视频质量评判标准
一、实时通信延时指标 通过图中表格可以看到,如果端到端延迟在200ms以内,说明整个通话是优质的,通话效果就像大家在同一个房间里聊天一样;300ms以内,大多数人很满意,400ms以内,有小部分人可以感…...
如何在vue3中使用scss
一 要使用scss首先需要下载相关的包 可以在终端使用下面的命令下载相关包 npm install -D sass 二 在src文件下新建一个文件夹叫做styles 在文件夹下创建三个文件 index.scss主要用来引用其他文件 reset.scss用来清除默认的样式 variable.scss用来配置全局属性 三 需要在v…...
Gartner发布采用美国防部模型实施零信任的方法指南:七大支柱落地方法
零信任是网络安全计划的关键要素,但制定策略可能会很困难。安全和风险管理领导者应使用美国国防部模型的七大支柱以及 Gartner 研究来设计零信任策略。 战略规划假设 到 2026 年,10% 的大型企业将拥有全面、成熟且可衡量的零信任计划,而 202…...
Flutter——最详细(Badge)使用教程
背景 主要常用于组件叠加上圆点提示; 使用场景,消息数量提示,消息红点提示 属性作用backgroundColor红点背景色smallSize设置红点大小isLabelVisible是否显示offset设置红点位置alignment设置红点位置child设置底部组件 代码块 class Badge…...
SQLServer的系统数据库用别的服务器上的系统数据库替换后做跨服务器连接时出现凭证、非对称金钥或私密金钥的资料无效
出错作业背景: 公司的某个sqlserver服务器要做迁移,由于该sqlserver服务器上数据库很多,并且做了很多的job和维护计划,重新安装的sqlserver这些都是空的,于是就想到了把系统4个系统数据库进行替换,然后也把…...
vue前端面试
一 .v-if和v-show的区别 v-if 和 v-show 是 Vue.js 中两个常用的条件渲染指令,它们都可以根据条件决定是否渲染某个元素。但是它们之间存在一些区别。 语法:v-if 和 v-show 的语法相同,都接收一个布尔值作为参数。 <div v-if"show…...
【网络安全】Host碰撞漏洞原理+工具+脚本
文章目录 漏洞原理虚拟主机配置Host头部字段Host碰撞漏洞漏洞场景工具漏洞原理 Host 碰撞漏洞,也称为主机名冲突漏洞,是一种网络攻击手段。常见危害有:绕过访问控制,通过公网访问一些未经授权的资源等。 虚拟主机配置 在Web服务器(如Nginx或Apache)上,多个网站可以共…...
unattended-upgrade进程介绍
unattended-upgrade 是一个用于自动更新 Debian 和 Ubuntu 系统的软件包。这个进程通常用于定期下载并安装安全更新,以保持系统的安全性和稳定性。 具体来说,这个命令 /usr/bin/python3 /usr/bin/unattended-upgrade --download-only 表示运行 unattend…...
SpringBoot 中多例模式的神秘世界:用法区别以及应用场景,最后的灵魂拷问会吗?- 第519篇
历史文章(文章累计500) 《国内最全的Spring Boot系列之一》 《国内最全的Spring Boot系列之二》 《国内最全的Spring Boot系列之三》 《国内最全的Spring Boot系列之四》 《国内最全的Spring Boot系列之五》 《国内最全的Spring Boot系列之六》 《…...
基于STM32设计的智能婴儿床(ESP8266局域网)_2024升级版_180
基于STM32设计的智能婴儿床(采用STM32F103C8T6)(180) 文章目录 一、设计需求【1】项目功能介绍【2】程序最终的运行逻辑【3】硬件模块组成【4】ESP8266模块配置【5】上位机开发思路【6】系统功能模块划分1.2 项目开发背景1.3 开发工具的选择1.4 系统框架图1.5 系统原理图1.6 硬…...
可靠性+灵活性:电力载波技术在楼宇自控中的核心价值
可靠性灵活性:电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中,电力载波技术(PLC)凭借其独特的优势,正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据,无需额外布…...
【git】把本地更改提交远程新分支feature_g
创建并切换新分支 git checkout -b feature_g 添加并提交更改 git add . git commit -m “实现图片上传功能” 推送到远程 git push -u origin feature_g...
IT供电系统绝缘监测及故障定位解决方案
随着新能源的快速发展,光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域,IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选,但在长期运行中,例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...
蓝桥杯3498 01串的熵
问题描述 对于一个长度为 23333333的 01 串, 如果其信息熵为 11625907.5798, 且 0 出现次数比 1 少, 那么这个 01 串中 0 出现了多少次? #include<iostream> #include<cmath> using namespace std;int n 23333333;int main() {//枚举 0 出现的次数//因…...
智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制
在数字化浪潮席卷全球的今天,数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具,在大规模数据获取中发挥着关键作用。然而,传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时,常出现数据质…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...
Ubuntu系统复制(U盘-电脑硬盘)
所需环境 电脑自带硬盘:1块 (1T) U盘1:Ubuntu系统引导盘(用于“U盘2”复制到“电脑自带硬盘”) U盘2:Ubuntu系统盘(1T,用于被复制) !!!建议“电脑…...
快速排序算法改进:随机快排-荷兰国旗划分详解
随机快速排序-荷兰国旗划分算法详解 一、基础知识回顾1.1 快速排序简介1.2 荷兰国旗问题 二、随机快排 - 荷兰国旗划分原理2.1 随机化枢轴选择2.2 荷兰国旗划分过程2.3 结合随机快排与荷兰国旗划分 三、代码实现3.1 Python实现3.2 Java实现3.3 C实现 四、性能分析4.1 时间复杂度…...
内窥镜检查中基于提示的息肉分割|文献速递-深度学习医疗AI最新文献
Title 题目 Prompt-based polyp segmentation during endoscopy 内窥镜检查中基于提示的息肉分割 01 文献速递介绍 以下是对这段英文内容的中文翻译: ### 胃肠道癌症的发病率呈上升趋势,且有年轻化倾向(Bray等人,2018&#x…...
