当前位置: 首页 > news >正文

昇思25天学习打卡营第19天 | RNN实现情感分类

RNN实现情感分类

概述

情感分类是自然语言处理中的经典任务,是典型的分类问题。本节使用MindSpore实现一个基于RNN网络的情感分类模型,实现如下的效果:

输入: This film is terrible
正确标签: Negative
预测标签: Negative输入: This film is great
正确标签: Positive
预测标签: Positive

数据准备

本节使用情感分类的经典数据集IMDB影评数据集,数据集包含Positive和Negative两类,下面为其样例:

ReviewLabel
“Quitting” may be as much about exiting a pre-ordained identity as about drug withdrawal. As a rural guy coming to Beijing, class and success must have struck this young artist face on as an appeal to separate from his roots and far surpass his peasant parents’ acting success. Troubles arise, however, when the new man is too new, when it demands too big a departure from family, history, nature, and personal identity. The ensuing splits, and confusion between the imaginary and the real and the dissonance between the ordinary and the heroic are the stuff of a gut check on the one hand or a complete escape from self on the other.Negative
This movie is amazing because the fact that the real people portray themselves and their real life experience and do such a good job it’s like they’re almost living the past over again. Jia Hongsheng plays himself an actor who quit everything except music and drugs struggling with depression and searching for the meaning of life while being angry at everyone especially the people who care for him most.Positive

此外,需要使用预训练词向量对自然语言单词进行编码,以获取文本的语义特征,本节选取Glove词向量作为Embedding。

数据下载模块

为了方便数据集和预训练词向量的下载,首先设计数据下载模块,实现可视化下载流程,并保存至指定路径。数据下载模块使用requests库进行http请求,并通过tqdm库对下载百分比进行可视化。此外针对下载安全性,使用IO的方式下载临时文件,而后保存至指定的路径并返回。

tqdmrequests库需手动安装,命令如下:pip install tqdm requests

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
# !pip uninstall mindspore -y
# !pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 
import os
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path# 指定保存路径为 `home_path/.mindspore_examples`
cache_dir = Path.home() / '.mindspore_examples'def http_get(url: str, temp_file: IO):"""使用requests库下载数据,并使用tqdm库进行流程可视化"""req = requests.get(url, stream=True)content_length = req.headers.get('Content-Length')total = int(content_length) if content_length is not None else Noneprogress = tqdm(unit='B', total=total)for chunk in req.iter_content(chunk_size=1024):if chunk:progress.update(len(chunk))temp_file.write(chunk)progress.close()def download(file_name: str, url: str):"""下载数据并存为指定名称"""if not os.path.exists(cache_dir):os.makedirs(cache_dir)cache_path = os.path.join(cache_dir, file_name)cache_exist = os.path.exists(cache_path)if not cache_exist:with tempfile.NamedTemporaryFile() as temp_file:http_get(url, temp_file)temp_file.flush()temp_file.seek(0)with open(cache_path, 'wb') as cache_file:shutil.copyfileobj(temp_file, cache_file)return cache_path

完成数据下载模块后,下载IMDB数据集进行测试(此处使用华为云的镜像用于提升下载速度)。下载过程及保存的路径如下:

imdb_path = download('aclImdb_v1.tar.gz', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz')
imdb_path
'/home/nginx/.mindspore_examples/aclImdb_v1.tar.gz'

加载IMDB数据集

下载好的IMDB数据集为tar.gz文件,我们使用Python的tarfile库对其进行读取,并将所有数据和标签分别进行存放。原始的IMDB数据集解压目录如下:

    ├── aclImdb│   ├── imdbEr.txt│   ├── imdb.vocab│   ├── README│   ├── test│   └── train│         ├── neg│         ├── pos...

数据集已分割为train和test两部分,且每部分包含neg和pos两个分类的文件夹,因此需分别train和test进行读取并处理数据和标签。

import re
import six
import string
import tarfileclass IMDBData():"""IMDB数据集加载器加载IMDB数据集并处理为一个Python迭代对象。"""label_map = {"pos": 1,"neg": 0}def __init__(self, path, mode="train"):self.mode = modeself.path = pathself.docs, self.labels = [], []self._load("pos")self._load("neg")def _load(self, label):pattern = re.compile(r"aclImdb/{}/{}/.*\.txt$".format(self.mode, label))# 将数据加载至内存with tarfile.open(self.path) as tarf:tf = tarf.next()while tf is not None:if bool(pattern.match(tf.name)):# 对文本进行分词、去除标点和特殊字符、小写处理self.docs.append(str(tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate(None, six.b(string.punctuation)).lower()).split())self.labels.append([self.label_map[label]])tf = tarf.next()def __getitem__(self, idx):return self.docs[idx], self.labels[idx]def __len__(self):return len(self.docs)

完成IMDB数据加载器后,加载训练数据集进行测试,输出数据集数量:

imdb_train = IMDBData(imdb_path, 'train')
len(imdb_train)
25000

将IMDB数据集加载至内存并构造为迭代对象后,可以使用mindspore.dataset提供的Generatordataset接口加载数据集迭代对象,并进行下一步的数据处理,下面封装一个函数将train和test分别使用Generatordataset进行加载,并指定数据集中文本和标签的column_name分别为textlabel:

import mindspore.dataset as dsdef load_imdb(imdb_path):imdb_train = ds.GeneratorDataset(IMDBData(imdb_path, "train"), column_names=["text", "label"], shuffle=True, num_samples=10000)imdb_test = ds.GeneratorDataset(IMDBData(imdb_path, "test"), column_names=["text", "label"], shuffle=False)return imdb_train, imdb_test

加载IMDB数据集,可以看到imdb_train是一个GeneratorDataset对象。

imdb_train, imdb_test = load_imdb(imdb_path)
imdb_train
<mindspore.dataset.engine.datasets_user_defined.GeneratorDataset at 0xffff8c3d6310>

加载预训练词向量

预训练词向量是对输入单词的数值化表示,通过nn.Embedding层,采用查表的方式,输入单词对应词表中的index,获得对应的表达向量。
因此进行模型构造前,需要将Embedding层所需的词向量和词表进行构造。这里我们使用Glove(Global Vectors for Word Representation)这种经典的预训练词向量,
其数据格式如下:

WordVector
the0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 …
,0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -0.42852 -0.55641 -0.364 …

我们直接使用第一列的单词作为词表,使用dataset.text.Vocab将其按顺序加载;同时读取每一行的Vector并转为numpy.array,用于nn.Embedding加载权重使用。具体实现如下:

import zipfile
import numpy as npdef load_glove(glove_path):glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt')if not os.path.exists(glove_100d_path):glove_zip = zipfile.ZipFile(glove_path)glove_zip.extractall(cache_dir)embeddings = []tokens = []with open(glove_100d_path, encoding='utf-8') as gf:for glove in gf:word, embedding = glove.split(maxsplit=1)tokens.append(word)embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))# 添加 <unk>, <pad> 两个特殊占位符对应的embeddingembeddings.append(np.random.rand(100))embeddings.append(np.zeros((100,), np.float32))vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)embeddings = np.array(embeddings).astype(np.float32)return vocab, embeddings

由于数据集中可能存在词表没有覆盖的单词,因此需要加入<unk>标记符;同时由于输入长度的不一致,在打包为一个batch时需要将短的文本进行填充,因此需要加入<pad>标记符。完成后的词表长度为原词表长度+2。

下面下载Glove词向量,并加载生成词表和词向量权重矩阵。

glove_path = download('glove.6B.zip', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/glove.6B.zip')
vocab, embeddings = load_glove(glove_path)
len(vocab.vocab())
400002

使用词表将the转换为index id,并查询词向量矩阵对应的词向量:

idx = vocab.tokens_to_ids('the')
embedding = embeddings[idx]
idx, embedding
(0,array([-0.038194, -0.24487 ,  0.72812 , -0.39961 ,  0.083172,  0.043953,-0.39141 ,  0.3344  , -0.57545 ,  0.087459,  0.28787 , -0.06731 ,0.30906 , -0.26384 , -0.13231 , -0.20757 ,  0.33395 , -0.33848 ,-0.31743 , -0.48336 ,  0.1464  , -0.37304 ,  0.34577 ,  0.052041,0.44946 , -0.46971 ,  0.02628 , -0.54155 , -0.15518 , -0.14107 ,-0.039722,  0.28277 ,  0.14393 ,  0.23464 , -0.31021 ,  0.086173,0.20397 ,  0.52624 ,  0.17164 , -0.082378, -0.71787 , -0.41531 ,0.20335 , -0.12763 ,  0.41367 ,  0.55187 ,  0.57908 , -0.33477 ,-0.36559 , -0.54857 , -0.062892,  0.26584 ,  0.30205 ,  0.99775 ,-0.80481 , -3.0243  ,  0.01254 , -0.36942 ,  2.2167  ,  0.72201 ,-0.24978 ,  0.92136 ,  0.034514,  0.46745 ,  1.1079  , -0.19358 ,-0.074575,  0.23353 , -0.052062, -0.22044 ,  0.057162, -0.15806 ,-0.30798 , -0.41625 ,  0.37972 ,  0.15006 , -0.53212 , -0.2055  ,-1.2526  ,  0.071624,  0.70565 ,  0.49744 , -0.42063 ,  0.26148 ,-1.538   , -0.30223 , -0.073438, -0.28312 ,  0.37104 , -0.25217 ,0.016215, -0.017099, -0.38984 ,  0.87424 , -0.72569 , -0.51058 ,-0.52028 , -0.1459  ,  0.8278  ,  0.27062 ], dtype=float32))

数据集预处理

通过加载器加载的IMDB数据集进行了分词处理,但不满足构造训练数据的需要,因此要对其进行额外的预处理。其中包含的预处理如下:

  • 通过Vocab将所有的Token处理为index id。
  • 将文本序列统一长度,不足的使用<pad>补齐,超出的进行截断。

这里我们使用mindspore.dataset中提供的接口进行预处理操作。这里使用到的接口均为MindSpore的高性能数据引擎设计,每个接口对应操作视作数据流水线的一部分,详情请参考MindSpore数据引擎。
首先针对token到index id的查表操作,使用text.Lookup接口,将前文构造的词表加载,并指定unknown_token。其次为文本序列统一长度操作,使用PadEnd接口,此接口定义最大长度和补齐值(pad_value),这里我们取最大长度为500,填充值对应词表中<pad>的index id。

除了对数据集中text进行预处理外,由于后续模型训练的需要,要将label数据转为float32格式。

import mindspore as mslookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')
pad_op = ds.transforms.PadEnd([500], pad_value=vocab.tokens_to_ids('<pad>'))
type_cast_op = ds.transforms.TypeCast(ms.float32)

完成预处理操作后,需将其加入到数据集处理流水线中,使用map接口对指定的column添加操作。

imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])
imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])
imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])

由于IMDB数据集本身不包含验证集,我们手动将其分割为训练和验证两部分,比例取0.7, 0.3。

imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])
[WARNING] ME(49995:281473341991216,MainProcess):2024-07-05-10:40:25.898.932 [mindspore/dataset/engine/datasets.py:1203] Dataset is shuffled before split.

最后指定数据集的batch大小,通过batch接口指定,并设置是否丢弃无法被batch size整除的剩余数据。

调用数据集的mapsplitbatch为数据集处理流水线增加对应操作,返回值为新的Dataset类型。现在仅定义流水线操作,在执行时开始执行数据处理流水线,获取最终处理好的数据并送入模型进行训练。

imdb_train = imdb_train.batch(64, drop_remainder=True)
imdb_valid = imdb_valid.batch(64, drop_remainder=True)

模型构建

完成数据集的处理后,我们设计用于情感分类的模型结构。首先需要将输入文本(即序列化后的index id列表)通过查表转为向量化表示,此时需要使用nn.Embedding层加载Glove词向量;然后使用RNN循环神经网络做特征提取;最后将RNN连接至一个全连接层,即nn.Dense,将特征转化为与分类数量相同的size,用于后续进行模型优化训练。整体模型结构如下:

nn.Embedding -> nn.RNN -> nn.Dense

这里我们使用能够一定程度规避RNN梯度消失问题的变种LSTM(Long short-term memory)做特征提取层。下面对模型进行详解:

Embedding

Embedding层又可称为EmbeddingLookup层,其作用是使用index id对权重矩阵对应id的向量进行查找,当输入为一个由index id组成的序列时,则查找并返回一个相同长度的矩阵,例如:

embedding = nn.Embedding(1000, 100) # 词表大小(index的取值范围)为1000,表示向量的size为100
input shape: (1, 16)                # 序列长度为16
output shape: (1, 16, 100)

这里我们使用前文处理好的Glove词向量矩阵,设置nn.Embeddingembedding_table为预训练词向量矩阵。对应的vocab_size为词表大小400002,embedding_size为选用的glove.6B.100d向量大小,即100。

RNN(循环神经网络)

循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的神经网络。下图为RNN的一般结构:

RNN-0

图示左侧为一个RNN Cell循环,右侧为RNN的链式连接平铺。实际上不管是单个RNN Cell还是一个RNN网络,都只有一个Cell的参数,在不断进行循环计算中更新。

由于RNN的循环特性,和自然语言文本的序列特性(句子是由单词组成的序列)十分匹配,因此被大量应用于自然语言处理研究中。下图为RNN的结构拆解:

RNN

RNN单个Cell的结构简单,因此也造成了梯度消失(Gradient Vanishing)问题,具体表现为RNN网络在序列较长时,在序列尾部已经基本丢失了序列首部的信息。为了克服这一问题,LSTM(Long short-term memory)被提出,通过门控机制(Gating Mechanism)来控制信息流在每个循环步中的留存和丢弃。下图为LSTM的结构拆解:

LSTM

本节我们选择LSTM变种而不是经典的RNN做特征提取,来规避梯度消失问题,并获得更好的模型效果。下面来看MindSpore中nn.LSTM对应的公式:

h 0 : t , ( h t , c t ) = LSTM ( x 0 : t , ( h 0 , c 0 ) ) h_{0:t}, (h_t, c_t) = \text{LSTM}(x_{0:t}, (h_0, c_0)) h0:t,(ht,ct)=LSTM(x0:t,(h0,c0))

这里nn.LSTM隐藏了整个循环神经网络在序列时间步(Time step)上的循环,送入输入序列、初始状态,即可获得每个时间步的隐状态(hidden state)拼接而成的矩阵,以及最后一个时间步对应的隐状态。我们使用最后的一个时间步的隐状态作为输入句子的编码特征,送入下一层。

Time step:在循环神经网络计算的每一次循环,成为一个Time step。在送入文本序列时,一个Time step对应一个单词。因此在本例中,LSTM的输出 h 0 : t h_{0:t} h0:t对应每个单词的隐状态集合, h t h_t ht c t c_t ct对应最后一个单词对应的隐状态。

Dense

在经过LSTM编码获取句子特征后,将其送入一个全连接层,即nn.Dense,将特征维度变换为二分类所需的维度1,经过Dense层后的输出即为模型预测结果。

import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniformclass RNN(nn.Cell):def __init__(self, embeddings, hidden_dim, output_dim, n_layers,bidirectional, pad_idx):super().__init__()vocab_size, embedding_dim = embeddings.shapeself.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)self.rnn = nn.LSTM(embedding_dim,hidden_dim,num_layers=n_layers,bidirectional=bidirectional,batch_first=True)weight_init = HeUniform(math.sqrt(5))bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)def construct(self, inputs):embedded = self.embedding(inputs)_, (hidden, _) = self.rnn(embedded)hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)output = self.fc(hidden)return output

损失函数与优化器

完成模型主体构建后,首先根据指定的参数实例化网络;然后选择损失函数和优化器。针对本节情感分类问题的特性,即预测Positive或Negative的二分类问题,我们选择nn.BCEWithLogitsLoss(二分类交叉熵损失函数)。

hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)

训练逻辑

在完成模型构建,进行训练逻辑的设计。一般训练逻辑分为一下步骤:

  1. 读取一个Batch的数据;
  2. 送入网络,进行正向计算和反向传播,更新权重;
  3. 返回loss。

下面按照此逻辑,使用tqdm库,设计训练一个epoch的函数,用于训练过程和loss的可视化。

def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return lossgrad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)def train_step(data, label):loss, grads = grad_fn(data, label)optimizer(grads)return lossdef train_one_epoch(model, train_dataset, epoch=0):model.set_train()total = train_dataset.get_dataset_size()loss_total = 0step_total = 0with tqdm(total=total) as t:t.set_description('Epoch %i' % epoch)for i in train_dataset.create_tuple_iterator():loss = train_step(*i)loss_total += loss.asnumpy()step_total += 1t.set_postfix(loss=loss_total/step_total)t.update(1)

评估指标和逻辑

训练逻辑完成后,需要对模型进行评估。即使用模型的预测结果和测试集的正确标签进行对比,求出预测的准确率。由于IMDB的情感分类为二分类问题,对预测值直接进行四舍五入即可获得分类标签(0或1),然后判断是否与正确标签相等即可。下面为二分类准确率计算函数实现:

def binary_accuracy(preds, y):"""计算每个batch的准确率"""# 对预测值进行四舍五入rounded_preds = np.around(ops.sigmoid(preds).asnumpy())correct = (rounded_preds == y).astype(np.float32)acc = correct.sum() / len(correct)return acc

有了准确率计算函数后,类似于训练逻辑,对评估逻辑进行设计, 分别为以下步骤:

  1. 读取一个Batch的数据;
  2. 送入网络,进行正向计算,获得预测结果;
  3. 计算准确率。

同训练逻辑一样,使用tqdm进行loss和过程的可视化。此外返回评估loss至供保存模型时作为模型优劣的判断依据。

在进行evaluate时,使用的模型是不包含损失函数和优化器的网络主体;
在进行evaluate前,需要通过model.set_train(False)将模型置为评估状态,此时Dropout不生效。

def evaluate(model, test_dataset, criterion, epoch=0):total = test_dataset.get_dataset_size()epoch_loss = 0epoch_acc = 0step_total = 0model.set_train(False)with tqdm(total=total) as t:t.set_description('Epoch %i' % epoch)for i in test_dataset.create_tuple_iterator():predictions = model(i[0])loss = criterion(predictions, i[1])epoch_loss += loss.asnumpy()acc = binary_accuracy(predictions, i[1])epoch_acc += accstep_total += 1t.set_postfix(loss=epoch_loss/step_total, acc=epoch_acc/step_total)t.update(1)return epoch_loss / total

模型训练与保存

前序完成了模型构建和训练、评估逻辑的设计,下面进行模型训练。这里我们设置训练轮数为5轮。同时维护一个用于保存最优模型的变量best_valid_loss,根据每一轮评估的loss值,取loss值最小的轮次,将模型进行保存。为节省用例运行时长,此处num_epochs设置为2,可根据需要自行修改。

num_epochs = 5
best_valid_loss = float('inf')
ckpt_file_name = os.path.join(cache_dir, 'sentiment-analysis.ckpt')for epoch in range(num_epochs):train_one_epoch(model, imdb_train, epoch)valid_loss = evaluate(model, imdb_valid, loss_fn, epoch)if valid_loss < best_valid_loss:best_valid_loss = valid_lossms.save_checkpoint(model, ckpt_file_name)
Epoch 0:   0%|          | 0/109 [00:00<?, ?it/s]-Epoch 0: 100%|██████████| 109/109 [10:13<00:00,  5.63s/it, loss=0.673]  
Epoch 0: 100%|██████████| 46/46 [00:23<00:00,  1.95it/s, acc=0.652, loss=0.626]
Epoch 1: 100%|██████████| 109/109 [01:27<00:00,  1.24it/s, loss=0.67] 
Epoch 1: 100%|██████████| 46/46 [00:14<00:00,  3.25it/s, acc=0.653, loss=0.633]
Epoch 2: 100%|██████████| 109/109 [01:30<00:00,  1.20it/s, loss=0.612]
Epoch 2: 100%|██████████| 46/46 [00:13<00:00,  3.33it/s, acc=0.74, loss=0.543] 
Epoch 3: 100%|██████████| 109/109 [01:29<00:00,  1.22it/s, loss=0.559]
Epoch 3: 100%|██████████| 46/46 [00:13<00:00,  3.37it/s, acc=0.749, loss=0.529]
Epoch 4: 100%|██████████| 109/109 [01:29<00:00,  1.22it/s, loss=0.518]
Epoch 4: 100%|██████████| 46/46 [00:13<00:00,  3.36it/s, acc=0.751, loss=0.523]

可以看到每轮Loss逐步下降,在验证集上的准确率逐步提升。

模型加载与测试

模型训练完成后,一般需要对模型进行测试或部署上线,此时需要加载已保存的最优模型(即checkpoint),供后续测试使用。这里我们直接使用MindSpore提供的Checkpoint加载和网络权重加载接口:1.将保存的模型Checkpoint加载到内存中,2.将Checkpoint加载至模型。

load_param_into_net接口会返回模型中没有和Checkpoint匹配的权重名,正确匹配时返回空列表。

param_dict = ms.load_checkpoint(ckpt_file_name)
ms.load_param_into_net(model, param_dict)
([], [])

对测试集打batch,然后使用evaluate方法进行评估,得到模型在测试集上的效果。

imdb_test = imdb_test.batch(64)
evaluate(model, imdb_test, loss_fn)
Epoch 0: 100%|█████████▉| 390/391 [01:29<00:00,  4.56it/s, acc=0.696, loss=0.575]\

\

Epoch 0: 100%|██████████| 391/391 [01:40<00:00,  3.88it/s, acc=0.696, loss=0.575]0.5750424911451462

自定义输入测试

最后我们设计一个预测函数,实现开头描述的效果,输入一句评价,获得评价的情感分类。具体包含以下步骤:

  1. 将输入句子进行分词;
  2. 使用词表获取对应的index id序列;
  3. index id序列转为Tensor;
  4. 送入模型获得预测结果;
  5. 打印输出预测结果。

具体实现如下:

score_map = {1: "Positive",0: "Negative"
}def predict_sentiment(model, vocab, sentence):model.set_train(False)tokenized = sentence.lower().split()indexed = vocab.tokens_to_ids(tokenized)tensor = ms.Tensor(indexed, ms.int32)tensor = tensor.expand_dims(0)prediction = model(tensor)return score_map[int(np.round(ops.sigmoid(prediction).asnumpy()))]

最后我们预测开头的样例,可以看到模型可以很好地将评价语句的情感进行分类。

predict_sentiment(model, vocab, "This film is terrible")
'Negative'
predict_sentiment(model, vocab, "This film is great")
'Positive'

相关文章:

昇思25天学习打卡营第19天 | RNN实现情感分类

RNN实现情感分类 概述 情感分类是自然语言处理中的经典任务&#xff0c;是典型的分类问题。本节使用MindSpore实现一个基于RNN网络的情感分类模型&#xff0c;实现如下的效果&#xff1a; 输入: This film is terrible 正确标签: Negative 预测标签: Negative输入: This fil…...

【VUE基础】VUE3第三节—核心语法之ref标签、props

ref标签 作用&#xff1a;用于注册模板引用。 用在普通DOM标签上&#xff0c;获取的是DOM节点。 用在组件标签上&#xff0c;获取的是组件实例对象。 用在普通DOM标签上&#xff1a; <template><div class"person"><h1 ref"title1">…...

生物化学笔记:电阻抗基础+电化学阻抗谱EIS+电化学系统频率响应分析

视频教程地址 引言 方法介绍 稳定&#xff1a;撤去扰动会到原始状态&#xff0c;反之不稳定&#xff0c;还有近似稳定的 阻抗谱图形&#xff08;Nyquist和Bode图&#xff09; 阻抗谱图形是用于分析电化学系统和材料的工具&#xff0c;主要有两种类型&#xff1a;Nyquist图和B…...

SQL使用join查询方式找出没有分类的电影id以及名称

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站&#xff0c;这篇文章男女通用&#xff0c;看懂了就去分享给你的码吧。 描述 现有电影信息…...

对MsgPack与JSON进行序列化的效率比较

序列化是将对象转换为字节流的过程&#xff0c;以便在内存或磁盘上存储。常见的序列化方法包括MsgPack和JSON。以下将详细探讨MsgPack和JSON在序列化效率方面的差异。 1. MsgPack的效率&#xff1a; 优点&#xff1a; 高压缩率&#xff1a; MsgPack采用高效的二进制编码格式&…...

Unix\Linux 执行shell报错:“$‘\r‘: 未找到命令” 解决

linux执行脚本sh xxx.sh报错&#xff1a;$xxx\r: 未找到命令 原因&#xff1a;shell脚本在Windows编写导致的换行问题&#xff1a; Windows 的换行符号为 CRLF&#xff08;\r\n&#xff09;&#xff0c;而 Unix\Linux 为 LF&#xff08;\n&#xff09;。 缩写全称ASCII转义说…...

动态路由--RIP配置(思科cisco)

一、简介 RIP协议&#xff08;Routing Information Protocol&#xff0c;路由信息协议&#xff09;是一种基于距离矢量的动态路由选择协议。 在RIP协议中&#xff0c;如果路由器A和网络B直接相连&#xff0c;那么路由器A到网络B的距离被定义为1跳。若从路由器A出发到达网络B需要…...

python - 函数 / 字典 / 集合

一.函数 形参和实参&#xff1a; >>> def MyFirstFunction(name): 函数定义过程中的name是叫形参 ... print(传递进来的 name 叫做实参&#xff0c;因为Ta是具体的参数值&#xff01;) print前面要加缩进tab&#xff0c;否则会出错。 >>> MyFirstFun…...

connect to github中personal access token生成token方法

一、问题 执行git push时弹出以下提示框 二、解决方法 去github官网生成Token&#xff0c;步骤如下 选择要授予此 令牌token 的 范围 或 权限 要使用 token 从命令行访问仓库&#xff0c;请选择 repo 。 要使用 token 从命令行删除仓库&#xff0c;请选择 delete_repo 其他根…...

Appium启动APP时报错Security exception: Permission Denial

报错内容Security exception: Permission Denial: starting Intent 直接通过am命令尝试也是同样的报错 查阅资料了解到&#xff1a;android:exported | App quality | Android Developers exported属性默认false&#xff0c;所以android:exported"false"修改为t…...

ubuntu22 使用ufw防火墙

专栏总目录 一、安装 sudo apt update sudo apt install ufw 二、启动防火墙 &#xff08;一&#xff09;启动命令 sudo ufw enable &#xff08;二&#xff09;重启命令 sudo ufw reload 三、配置规则 #允许SSH连接 sudo ufw allow ssh #如果sshd服务端口指定到了8888&a…...

初识STM32:开发方式及环境

STM32的编程模型 假如使用C语言的方式写了一段程序&#xff0c;这段程序首先会被烧录到芯片当中&#xff08;Flash存储器中&#xff09;&#xff0c;Flash存储器中的程序会逐条的进入CPU里面去执行。 CPU相当于人的一个大脑&#xff0c;虽然能执行运算和执行指令&#xff0c;…...

详解Amivest 流动性比率

详解Amivest 流动性比率 Claude-3.5-Sonnet Poe Amivest流动性比率是一个衡量证券市场流动性的重要指标。这个比率主要用于评估在不对价格造成重大影响的情况下,市场能够吸收多少交易量。以下是对Amivest流动性比率的详细解释: 定义: Amivest流动性比率是交易额与绝对收益率的…...

pycharm小游戏制作

以下是一个使用 Python 和 PyGame库在 PyCharm中创建一个简单的小游戏&#xff08;贪吃蛇游戏&#xff09;的示例代码&#xff0c;希望对您有所帮助&#xff1a; import pygame import random# 基础设置 # 屏幕高度 SCREEN_HEIGHT 480 # 屏幕宽度 SCREEN_WIDTH 600 # 小方格…...

昇思11天

基于 MindSpore 实现 BERT 对话情绪识别 BERT模型概述 BERT&#xff08;Bidirectional Encoder Representations from Transformers&#xff09;是由Google于2018年开发并发布的一种新型语言模型。BERT在许多自然语言处理&#xff08;NLP&#xff09;任务中发挥着重要作用&am…...

AI绘画Stable Diffusion【图生图教程】:图片高清修复的三种方案详解,你一定能用上!(附资料)

大家好&#xff0c;我是画画的小强 今天给大家分享一下用AI绘画Stable Diffusion 进行 高清修复&#xff08;Hi-Res Fix&#xff09;&#xff0c;这是用于提升图像分辨率和细节的技术。在生成图像时&#xff0c;初始的低分辨率图像会通过放大算法和细节增强技术被转换为高分辨…...

适用于Mac和Windows的最佳iPhone恢复软件

本文将指导您选择一款出色的iPhone数据恢复软件来检索您的宝贵数据。 市场上有许多所谓的iPhone恢复程序。各种程序很难选择并选择其中之一。一旦您做出了错误的选择&#xff0c;您的数据就会有风险。 最好的iPhone数据恢复软件应包含以下功能。 1.安全可靠。 2.恢复成功率高…...

64.ThreadLocal造成的内存泄漏

内存泄漏 程序中已动态分配的堆内存,由于某种原因程序为释放和无法释放,造成系统内存的浪费,导致程序运行速度减慢甚至系统崩溃等严重后果。内存泄漏的堆积终将导致内存溢出。 内存溢出 没有足够的内存提供申请者使用。 ThreadLocal出现内存泄漏的真实原因 内存泄漏的发…...

深入刨析Redis存储技术设计艺术(二)

三、Redis主存储 3.1、存储相关结构体 redisServer:服务器 server.h struct redisServer { /* General */ pid_t pid; /* Main process pid. */ pthread_t main_thread_id; /* Main thread id */ char *configfile; /* Absolut…...

python读取写入txt文本文件

读取 txt 文件 def read_txt_file(file_path):"""读取文本文件的内容:param file_path: 文本文件的路径:return: 文件内容"""try:with open(file_path, r, encodingutf-8) as file:content file.read()return contentexcept FileNotFoundError…...

日期选取限制日期范围antdesign vue

限制选取的日期范围 效果图 <a-date-pickerv-model"dateTime"format"YYYY-MM-DD":disabled-date"disabledDate"valueFormat"YYYY-MM-DD"placeholder"请选择日期"allowClear />methods:{//回放日期选取范围限制&…...

【大模型】衡量巨兽:解读评估LLM性能的关键技术指标

衡量巨兽&#xff1a;解读评估LLM性能的关键技术指标 引言一、困惑度&#xff1a;语言模型的试金石1.1 定义与原理1.2 计算公式1.3 应用与意义 二、BLEU 分数&#xff1a;翻译质量的标尺2.1 定义与原理2.2 计算方法2.3 应用与意义 三、其他评估指标&#xff1a;综合考量下的多元…...

《优化接口设计的思路》系列:第2篇—小程序性能优化

优化Uniapp应用程序的性能可以从以下几个方面进行优化&#xff1a; 1.减少页面加载时间&#xff1a;避免页面过多和过大的组件&#xff0c;减少不必要的资源加载。可以使用懒加载的方式&#xff0c;根据用户的实际需求来加载页面和组件。 2.节流和防抖&#xff1a;对于频繁触发…...

prototype 和 __proto__的区别

prototype 和 __proto__ 在 JavaScript 中都与对象的原型链有关&#xff0c;但它们各自有不同的用途和含义。 prototype prototype 是函数对象的一个属性&#xff0c;它指向一个对象&#xff0c;这个对象包含了可以由特定类型的所有实例共享的属性和方法。当我们创建一个新的…...

网络中未授权访问漏洞(Rsync,PhpInfo)

Rsync未授权访问漏洞 Rsync未授权访问漏洞是指Rsync服务配置不当或存在漏洞&#xff0c;导致攻击者可以未经授权访问和操作Rsync服务。Rsync是一个用于文件同步和传输的开源工具&#xff0c;通常在Unix/Linux系统上使用。当Rsync服务未经正确配置时&#xff0c;攻击者可以利用…...

DataWhaleAI分子预测夏令营 学习笔记

AI分子预测夏令营学习笔记 一、直播概览 主持人介绍 姓名&#xff1a;徐翼萌角色&#xff1a;DataWhale助教活动目的&#xff1a;分享机器学习赛事经验&#xff0c;提升参赛者在分子预测领域的能力 嘉宾介绍 姓名&#xff1a;余老师背景&#xff1a;Data成员&#xff0c;腾…...

lnmp php7 安装ssh2扩展

安装ssh2扩展前必须安装libssh2包 下载地址: wget http://www.libssh2.org/download/libssh2-1.11.0.tar.gzwget http://pecl.php.net/get/ssh2-1.4.tgz &#xff08;这里要换成最新的版本&#xff09; 先安装 libssh2 再安装 SSH2: tar -zxvf libssh2-1.11.0.tar.gzcd libss…...

数据库概念题总结

1、 2、简述数据库设计过程中&#xff0c;每个设计阶段的任务 需求分析阶段&#xff1a;从现实业务中获取数据表单&#xff0c;报表等分析系统的数据特征&#xff0c;数据类型&#xff0c;数据约束描述系统的数据关系&#xff0c;数据处理要求建立系统的数据字典数据库设计…...

提升用户体验之requestAnimationFrame实现前端动画

1)requestAnimationFrame是什么? 1.MDN官方解释 2.解析这段话&#xff1a; 1、那么浏览器重绘是指什么呢&#xff1f; ——大多数电脑的显示器刷新频率是60Hz&#xff0c;1000ms/6016.66666667ms的时间刷新一次 2、重绘之前调用指定的回调函数更新动画&#xff1f; ——requ…...

Mysql慢日志、慢SQL

慢查询日志 查看执行慢的SQL语句&#xff0c;需要先开启慢查询日志。 MySQL 的慢查询日志&#xff0c;记录在 MySQL 中响应时间超过阀值的语句&#xff08;具体指运行时间超过 long_query_time 值的SQL。long_query_time 的默认值为10&#xff0c;意思是运行10秒以上(不含10秒…...

卫星网络——Walker星座简单介绍

一、星座构型介绍 近年来&#xff0c;随着卫星应用领的不断拓展&#xff0c;许多任务已经无法单纯依靠单颗卫星来完成。与单个卫星相比&#xff0c;卫星星座的覆盖范围显著增加&#xff0c;合理的星座构型可以使其达到全球连续覆盖或全球多重连续覆盖&#xff0c;这样的特性使得…...

C++ Lambda表达式第一篇, 闭合(Closuretype)

C Lambda表达式第一篇&#xff0c; 闭合Closuretype ClosureType::operator()(params)auto 模板参数类型显式模板参数类型其他 ClosureType::operator ret(*)(params)() lambda 表达式是唯一的未命名&#xff0c;非联合&#xff0c;非聚合类类型&#xff08;称为闭包类型&#…...

移动校园(3):处理全校课程数据excel文档,实现空闲教室查询与课程表查询

首先打开教学平台 然后导出为excel文档 import mathimport pandas as pd import pymssql serverName 127.0.0.1 userName sa passWord 123456 databaseuniSchool conn pymssql.connect(serverserverName,useruserName,passwordpassWord,databasedatabase) cursor conn.cur…...

【MySQL】1.初识MySQL

初识MySQL 一.MySQL 安装1.卸载已有的 MySQL2.获取官方 yum 源3.安装 MySQL4.登录 MySQL5.配置 my.cnf 二.MySQL 数据库基础1.MySQL 是什么&#xff1f;2.服务器&#xff0c;数据库和表3.mysqld 的层状结构4.SQL 语句分类 一.MySQL 安装 1.卸载已有的 MySQL //查询是否有相关…...

查看电脑显卡(NVIDIA)应该匹配什么版本的CUDA Toolkit

被串行计算逼到要吐时&#xff0c;决定重拾CUDa了&#xff0c;想想那光速般的处理感觉&#xff08;夸张了&#xff09;不要太爽&#xff0c;记下我的闯关记录。正好我的电脑配了NVIDIA独显&#xff0c;GTX1650&#xff0c;有菜可以炒呀&#xff0c;没有英伟达的要绕道了。回到正…...

优化:遍历List循环查找数据库导致接口过慢问题

前提&#xff1a; 我们在写查询的时候&#xff0c;有时候会遇到多表联查&#xff0c;一遇到多表联查大家就会直接写sql语句&#xff0c;不会使用较为方便的LambdaQueryWrapper去查询了。作为一个2024新进入码农世界的小白&#xff0c;我喜欢使用LambdaQueryWrapper&#xff0c;…...

NoSQL 之 Redis 配置与常用命令

一、关系型数据库与非关系型数据库 1、数据库概述 &#xff08;1&#xff09;关系型数据库 关系型数据库是一个结构化的数据库&#xff0c;创建在关系模型&#xff08;二维表格模型&#xff09;基础上&#xff0c;一般面向于记 录。 SQL 语句&#xff08;标准数据查询语言&am…...

用SpringBoot打造坚固防线:轻松实现XSS攻击防御

在这篇博客中&#xff0c;我们将深入探讨如何使用SpringBoot有效防御XSS攻击。通过结合注解和过滤器的方式&#xff0c;我们可以为应用程序构建一个强大的安全屏障&#xff0c;确保用户数据不被恶意脚本所侵害。 目录 什么是XSS攻击&#xff1f;SpringBoot中的XSS防御策略使用…...

2024机器人科研/研发领域最新研究方向岗位职责与要求

具身智能工程师 从事具身智能领域的技术研究或产品开发&#xff0c;制定具身智能技术标准&#xff0c;利用大模型技术来提高机器人的智能化水平&#xff0c;研究端云协同的机器人系统框架&#xff0c;并赋能人形/复合等各类形态的机器人。具体内容包括不限于&#xff1a; 1、负…...

笔记:Newtonsoft.Json 序列化接口集合

在使用 Newtonsoft.Json 序列化接口集合时&#xff0c;一个常见的挑战是如何处理接口的具体实现&#xff0c;因为接口本身并不包含关于要实例化哪个具体类的信息。为了正确序列化和反序列化接口集合&#xff0c;你需要提供一些额外的信息或使用自定义的转换器来指导 Newtonsoft…...

【Unity设计模式】✨使用 MVC 和 MVP 编程模式

前言 最近在学习Unity游戏设计模式&#xff0c;看到两本比较适合入门的书&#xff0c;一本是unity官方的 《Level up your programming with game programming patterns》 ,另一本是 《游戏编程模式》 这两本书介绍了大部分会使用到的设计模式&#xff0c;因此很值得学习 本…...

CDH安装和配置流程

这份文件是一份关于CDH&#xff08;Clouderas Distribution Including Apache Hadoop&#xff09;安装的详细手册&#xff0c;主要内容包括以下几个部分&#xff1a; 1. **前言**&#xff1a; - CDH是基于Apache Hadoop的发行版&#xff0c;由Cloudera公司开发。 - 相比…...

SpringMVC:SpringMVC执行流程

文章目录 一、介绍二、什么是MVC 一、介绍 Spring MVC 是一种基于Java的Web框架&#xff0c;它采用了MVC&#xff08;Model - View - Controller&#xff09;设计模式&#xff0c;通过吧Model、View和Controller分离&#xff0c;将Web层进行职责解耦&#xff0c;把复杂的Web应…...

如何在前端网页实现live2d的动态效果

React如何在前端网页实现live2d的动态效果 业务需求&#xff1a; 因为公司需要做机器人相关的业务&#xff0c;主要是聊天形式的内容&#xff0c;所以需要一个虚拟的卡通形象。而且为了更直观的展示用户和机器人对话的状态&#xff0c;该live2d动画的嘴型需要根据播放的内容来…...

昇思25天学习打卡营第15天|linchenfengxue

Pix2Pix实现图像转换 Pix2Pix概述 Pix2Pix是基于条件生成对抗网络&#xff08;cGAN, Condition Generative Adversarial Networks &#xff09;实现的一种深度学习图像转换模型&#xff0c;该模型是由Phillip Isola等作者在2017年CVPR上提出的&#xff0c;可以实现语义/标签到…...

软考中级数据库系统工程师备考经验分享

前几天软考成绩出了&#xff0c;赶紧查询了一下发现自己顺利通过啦&#xff08;上午63&#xff0c;下午67&#xff0c;开心&#xff09;&#xff0c;因此本文记录一下我的备考经验分享给大家。因为工作中项目管理类的知识没有系统学习过&#xff0c;本来想直接报名软考高级证书…...

Centos7删除MariaDB

在 CentOS 7 上删除 MariaDB 可以通过 yum 包管理器来完成。以下是一步一步的指导&#xff1a; 打开终端&#xff1a;首先&#xff0c;你需要打开你的 CentOS 7 系统的终端。 停止 MariaDB 服务&#xff08;如果正在运行&#xff09;&#xff1a;在卸载 MariaDB 之前&#xff…...

【Docker系列】Docker 镜像构建中的跨设备移动问题及解决方案

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...

C++友元函数和友元类的使用

1.友元介绍 在C++中,友元(friend)是一种机制,允许某个类或函数访问其他类的私有成员。通过友元,可以授予其他类或函数对该类的私有成员的访问权限。友元关系在一些特定的情况下很有用,例如在类之间共享数据或实现特定的功能。 友元可以分为两种类型:类友元和函数友元。…...

黑马苍穹外卖技术亮点 详情

1.使用工厂模式和策略模式实现布隆过滤器解决缓存穿透问题 Bitmap Bitmap是一种数据结构&#xff0c;它使用位图来表示数据。在处理大量数据时&#xff0c;Bitmap可以通过将每个数据元素映射到一个位&#xff0c;然后使用位运算来对数据进行操作。 通过使用Bitmap&#xff0c…...