当前位置: 首页 > news >正文

人名分类器(nlp)

# coding: utf-8
import osos.environ['KMP_DUPLICATE_LIB_OK'] = 'True'# 导入torch工具
import jsonimport torch
# 导入nn准备构建模型
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# 导入torch的数据源 数据迭代器工具包
from torch.utils.data import Dataset, DataLoader
# 用于获得常见字母及字符规范化
import string
# 导入时间工具包
import time
# 引入制图工具包
import matplotlib.pyplot as plt
# 从io中导入文件打开方法
from io import open# 1 获取常用的字符 标点,把每个char字符作为一个token,用onehot编码表示token
# 因此我们的词表就是 char表 (字符表) 57个char
all_letters = string.ascii_letters + " ,.;'"
print(all_letters)n_letter = len(all_letters)  # 词表的大小
print('字符表的长度:', n_letter)# 2 获取国家的类别种数
# 国家名 种类数
categorys = ['Italian', 'English', 'Arabic', 'Spanish', 'Scottish', 'Irish', 'Chinese', 'Vietnamese', 'Japanese','French', 'Greek', 'Dutch', 'Korean', 'Polish', 'Portuguese', 'Russian', 'Czech', 'German']
# 国家名 个数,就是模型的 (linear输出维度) 分类数
categorynum = len(categorys)
print('categorys--->', categorys)# 3 读取数据
def read_data(filename):# 3.1 初始化空列表两个my_list_x, my_list_y = [], []# 3.2 读取文件内容with open(filename, 'r', encoding='utf-8') as fr:for line in fr.readlines():# 异常点判断:改行长度<=5,说明这是异常样本,直接跳到下一行if len(line) <= 5:continuex, y = line.strip().split('\t')my_list_x.append(x)my_list_y.append(y)# 3.3 返回两个列表return my_list_x, my_list_y# 4 构建数据集
class NameClsDataset(Dataset):def __init__(self, mylist_x, mylist_y):self.mylist_x = mylist_xself.mylist_y = mylist_ydef __len__(self):return len(self.mylist_x)def __getitem__(self, item):# 01 item 异常值出处理index = min(max(item, 0), len(self.mylist_x) - 1)# 02 根据idx拿到人名 国家名x = self.mylist_x[index]y = self.mylist_y[index]# 03 完成onehottensor_x = torch.zeros(len(x), n_letter)for idx, letter in enumerate(x):tensor_x[idx][all_letters.find(letter)] = 1# 04 获得标签tensor_y = torch.tensor(categorys.index(y), dtype=torch.long)return tensor_x, tensor_y# 5 构建dataloader
def get_dataloader():filename = './data/name_classfication.txt'my_list_x, my_list_y = read_data(filename)mydataset = NameClsDataset(my_list_x, my_list_y)my_dataloader = DataLoader(mydataset,batch_size=1,shuffle=True,  # 打乱顺序# drop_last=True,  # 是否丢弃最后那个不足一个batch_size的数据组# collate_fn=collate_fn,  # 处理一个batch的数据为整齐的维度)x, y = next(iter(my_dataloader))# print(x)# print(x.shape)# print(y)return my_dataloader# 6 创建rnn模型
class MyRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super().__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_sizeself.num_layers = num_layersself.rnn = nn.RNN(self.input_size, self.hidden_size,self.num_layers, batch_first=True)# self.linear = nn.Linear(self.hidden_size, self.hidden_size)self.linear = nn.Linear(self.hidden_size, self.output_size)self.softmax = nn.LogSoftmax(dim=-1)def forward(self, input):# input.shape = (1, 9, 57)# hidden.shape = (1, 1, 128)# rnn_output.shape = (1, 9, 128)# rnn_hn.shape = (1, 1, 128)# rnn_output, _ = self.rnn(input)rnn_output, rnn_hn = self.rnn(input)# temp.shape = (1, 128)# temp = rnn_output[0][-1].unsqueeze(0)temp = rnn_hn[0]# output.shape=(1,18)# self.softmax(output) (2, 18)output = self.linear(temp)  # 可以接受三维数据return self.softmax(output), rnn_hn# 7 测试RNN
def ceshiRNN():# 1 拿到数据my_dataloader = get_dataloader()# 2 实例化模型input_size = n_letter  # 字符表的大小 (词表的大小)hidden_size = 128  # 超参数 768,rnn输出维度output_size = len(categorys)  # 18,分类总数my_rnn = MyRNN(input_size, hidden_size, output_size)# 3 将数据送入到模型x, y = next(iter(my_dataloader))output, hn = my_rnn(x)  # output.shape = (1, 18)print(output.shape)print(hn.shape)# 8 训练RNN
def train_my_rnn():epochs = 1my_lr = 1e-3# 1 读取数据my_list_x, my_list_y = read_data('./data/name_classfication.txt')# 2 定义datasetmyDataset = NameClsDataset(my_list_x, my_list_y)# 3 实例化dataloadermy_dataloader = DataLoader(myDataset, batch_size=1, shuffle=True)# 4 实例化RNN模型input_size = 57hidden_size = 128output_size = 18my_rnn = MyRNN(input_size, hidden_size, output_size)# 5 损失函数my_crossentropy = nn.NLLLoss()# 6 优化器my_optimizer = optim.Adam(my_rnn.parameters(), lr=my_lr)# 7 日志start_time = time.time()total_iter_num = 0  # 已经训练好的样本数total_loss = 0  # 总的losstotal_loss_list = []  # 每隔多少步存储loss-avgtotal_acc_num = 0total_acc_list = []  # 存储间隔准确率acc-avg# 8 开始训练# 8.1 外部循环for epoch_idx in range(epochs):# 8.2 batch循环for i, (x, y) in enumerate(my_dataloader):# 8.3 将x送入到模型 一轮模型训练output, hn = my_rnn(x)my_loss = my_crossentropy(output, y)my_optimizer.zero_grad()my_loss.backward()my_optimizer.step()total_iter_num += 1total_loss += my_loss.item()item1 = 1 if torch.argmax(output, dim=-1).item() == y.item() else 0total_acc_num += item1# 每隔 100 步存储avg-loss acc-avgif total_iter_num % 100 == 0:# 保存一下平均损失loss_avg = total_loss / total_iter_numtotal_loss_list.append(loss_avg)# acc-avgacc_avg = total_acc_num / total_iter_numtotal_acc_list.append(acc_avg)if total_iter_num % 1000 == 0:loss_avg = total_loss / total_iter_numacc_avg = total_acc_num / total_iter_numend_time = time.time()use_time = end_time - start_timeprint('当前的训练批次:%d, 平均损失:%.5f, 训练时间:%.3f, 准确率:%.2f' % (epoch_idx + 1,loss_avg,use_time,acc_avg))# 9 保存模型torch.save(my_rnn.state_dict(), './model/my_rnn.bin')# 10 结束all_time = time.time() - start_timeprint('总耗时:', all_time)return total_loss_list, total_acc_list, all_time# 9 将模型结果进行保存,方便进行读取
def save_rnn_res():# 1 训练模型,得到需要的结果total_loss_list, total_acc_list, all_time = train_my_rnn()# 2 定义一个字典dict1 = {'loss': total_loss_list,'time': all_time,'acc': total_acc_list}# 3 保存成jsonwith open('./data/rnn_result.json', 'w') as fw:fw.write(json.dumps(dict1))# 10 读取模型结果json
def read_json(json_path):with open(json_path, 'r') as fr:# '{a:1, b:2,,,}'  --> json.loads()# json.load() 加载json文件res = json.load(fr)return res# 11 绘图
def plt_RNN():# 1 拿到数据rnn_results = read_json('./data/rnn_result-epoch3.json')total_loss_list_rnn, all_time_rnn, total_acc_list_rnn = rnn_results['loss'], rnn_results['time'], rnn_results['acc']lstm_results = read_json('./data/lstm_result-epoch3.json')total_loss_list_lstm, all_time_lstm, total_acc_list_lstm = lstm_results['loss'], lstm_results['time'], lstm_results['acc']gru_results = read_json('./data/gru_result-epoch3.json')total_loss_list_gru, all_time_gru, total_acc_list_gru = gru_results['loss'], gru_results['time'], gru_results['acc']# 2 绘制loss对比曲线图plt.figure(0)plt.plot(total_loss_list_rnn, label='RNN')plt.plot(total_loss_list_lstm, label='LSTM', color='red')plt.plot(total_loss_list_gru, label='GRU', color='orange')plt.legend(loc='upper right')plt.savefig('./picture/loss.png')plt.show()# 3 绘制耗时柱状图plt.figure(1)x_data = ['RNN', 'LSTM', 'GRU']y_data = [all_time_rnn, all_time_lstm, all_time_gru]plt.bar(range(len(x_data)), y_data, tick_label=x_data)plt.savefig('./picture/use_time.png')plt.show()# 4 绘制acc曲线图plt.figure(2)plt.plot(total_acc_list_rnn, label='RNN')plt.plot(total_acc_list_lstm, label='LSTM', color='red')plt.plot(total_acc_list_gru, label='GRU', color='orange')plt.legend(loc='upper right')plt.savefig('./picture/acc.png')plt.show()# 12 定义预测输入的x --》 tensor_x
def line2tensor(x):tensor_x = torch.zeros(len(x), n_letter)for li, letter in enumerate(x):tensor_x[li][all_letters.find(letter)] = 1return tensor_x# 13 预测主函数
def rnn_predict(x):# 1 x --》 tensor_xtensor_x = line2tensor(x)# 2 实力化模型my_rnn = MyRNN(input_size=57, hidden_size=128, output_size=18)my_rnn.load_state_dict(torch.load('./model/my_rnn.bin'))# 3 预测with torch.no_grad():  # 预测时不去计算梯度input0 = tensor_x.unsqueeze(0)  # input0 是三维的,rnn需要output, hn = my_rnn(input0)topv, topi = output.topk(3, 1, True)print('人名是', x)# 4 打印topk个for i in range(3):value = topv[0][i]index = topi[0][i]cate = categorys[index]print('国家名是:', cate)if __name__ == '__main__':# filename = './data/name_classfication.txt'# x, y = read_data(filename)# print(x)# print(y)# get_dataloader()# ceshiRNN()# train_my_rnn()# plt_RNN()rnn_predict('zhang')

相关文章:

人名分类器(nlp)

# coding: utf-8 import osos.environ[KMP_DUPLICATE_LIB_OK] True# 导入torch工具 import jsonimport torch # 导入nn准备构建模型 import torch.nn as nn import torch.nn.functional as F import torch.optim as optim # 导入torch的数据源 数据迭代器工具包 from torch.ut…...

斐波那契数列 相关问题 详解

斐波那契数列相关问题详解 斐波那契数列及其相关问题是算法学习中的经典主题&#xff0c;变形与应用非常广泛&#xff0c;涵盖了递推关系、动态规划、组合数学、数论等多个领域。以下是斐波那契数列的相关问题及其解法的详解。 1. 经典斐波那契数列 定义 初始条件&#xff1…...

Pytorch微调深度学习模型

在公开数据训练了模型&#xff0c;有时候需要拿到自己的数据上微调。今天正好做了一下微调&#xff0c;在此记录一下微调的方法。用Pytorch还是比较容易实现的。 网上找了很多方法&#xff0c;以及Chatgpt也给了很多方法&#xff0c;但是不够简洁和容易理解。 大体步骤是&…...

springboot 使用笔记

1.springboot 快速启动项目 注意&#xff1a;该启动只是临时启动&#xff0c;不能关闭终端面板 cd /www/wwwroot java -jar admin.jar2.脚本启动 linux shell脚本启动springboot服务 3.java一键部署springboot 第5条 https://blog.csdn.net/qq_30272167/article/details/1…...

网络安全基础——网络安全法

填空题 1.根据**《中华人民共和国网络安全法》**第二十条(第二款)&#xff0c;任何组织和个人试用网路应当遵守宪法法律&#xff0c;遵守公共秩序&#xff0c;遵守社会公德&#xff0c;不危害网络安全&#xff0c;不得利用网络从事危害国家安全、荣誉和利益&#xff0c;煽动颠…...

SCAU软件体系结构实验四 组合模式

目录 一、题目 二、源码 一、题目 个人(Person)与团队(Team)可以形成一个组织(Organization)&#xff1a;组织有两种&#xff1a;个人组织和团队组织&#xff0c;多个个人可以组合成一个团队&#xff0c;不同的个人与团队可以组合成一个更大的团队。 使用控制台或者JavaFx界面…...

Amazon商品详情API接口:电商创新与用户体验的驱动力

在电子商务蓬勃发展的今天&#xff0c;作为全球最大的电商平台之一&#xff0c;亚马逊&#xff08;Amazon&#xff09;凭借其强大的技术实力和丰富的商品资源&#xff0c;为全球用户提供了优质的购物体验。其中&#xff0c;Amazon商品详情API接口在电商创新与用户体验提升方面扮…...

手机无法连接服务器1302什么意思?

你有没有遇到过手机无法连接服务器&#xff0c;屏幕上显示“1302”这样的错误代码&#xff1f;尤其是在急需使用手机进行工作或联系朋友时&#xff0c;突然出现的连接问题无疑会带来不少麻烦。那么&#xff0c;什么是1302错误&#xff0c;它又意味着什么呢&#xff1f; 1302错…...

Android adb shell dumpsys audio 信息查看分析详解

Android adb shell dumpsys audio 信息查看分析详解 一、前言 Android 如果要分析当前设备的声音通道相关日志&#xff0c; 仅仅看AudioService的日志是看不到啥日志的&#xff0c;但是看整个audio关键字的日志又太多太乱了&#xff0c; 所以可以看一下系统提供的一个调试指令…...

Python 网络爬虫操作指南

网络爬虫是自动化获取互联网上信息的一种工具。它广泛应用于数据采集、分析以及实现信息聚合等众多领域。本文将为你提供一个完整的Python网络爬虫操作指南&#xff0c;帮助你从零开始学习并实现简单的网络爬虫。我们将涵盖基本的爬虫概念、Python环境配置、常用库介绍。 上传…...

基于FPGA的2FSK调制-串口收发-带tb仿真文件-实际上板验证成功

基于FPGA的2FSK调制 前言一、2FSK储备知识二、代码分析1.模块分析2.波形分析 总结 前言 设计实现连续相位 2FSK 调制器&#xff0c;2FSK 的两个频率为:fI15KHz&#xff0c;f23KHz&#xff0c;波特率为 1500 bps,比特0映射为f 载波&#xff0c;比特1映射为 载波。 1&#xff09…...

JavaScript的基础数据类型

一、JavaScript中的数组 定义 数组是一种特殊的对象&#xff0c;用于存储多个值。在JavaScript中&#xff0c;数组可以包含不同的数据类型&#xff0c;如数字、字符串、对象、甚至其他数组。数组的创建有两种常见方式&#xff1a; 字面量表示法&#xff1a;let fruits [apple…...

第三讲 架构详解:“隐语”可信隐私计算开源框架

目录 隐语架构 隐语架构拆解 产品层 算法层 计算层 资源层 互联互通 跨域管控 本文主要是记录参加隐语开源社区推出的第四期隐私计算实训营学习到的相关内容。 隐语架构 隐语架构拆解 产品层 产品定位&#xff1a; 通过可视化产品&#xff0c;降低终端用户的体验和演…...

JDBC编程---Java

目录 一、数据库编程的前置 二、Java的数据库编程----JDBC 1.概念 2.JDBC编程的优点 三.导入MySQL驱动包 四、JDBC编程的实战 1.创造数据源&#xff0c;并设置数据库所在的位置&#xff0c;三条固定写法 2.建立和数据库服务器之间的连接&#xff0c;连接好了后&#xff…...

Python绘制太极八卦

文章目录 系列目录写在前面技术需求1. 图形绘制库的支持2. 图形绘制功能3. 参数化设计4. 绘制控制5. 数据处理6. 用户界面 完整代码代码分析1. rset() 函数2. offset() 函数3. taiji() 函数4. bagua() 函数5. 绘制过程6. 技术亮点 写在后面 系列目录 序号直达链接爱心系列1Pyth…...

Spring框架特性及包下载(Java EE 学习笔记04)

1 Spring 5的新特性 Spring 5是Spring当前最新的版本&#xff0c;与历史版本对比&#xff0c;Spring 5对Spring核心框架进行了修订和更新&#xff0c;增加了很多新特性&#xff0c;如支持响应式编程等。 更新JDK基线 因为Spring 5代码库运行于JDK 8之上&#xff0c;所以Spri…...

Linux关于vim的笔记

Linux关于vim的笔记&#xff1a;(vimtutor打开vim 教程) --------------------------------------------------------------------------------------------------------------------------------- 1. 光标在屏幕文本中的移动既可以用箭头键&#xff0c;也可以使用 hjkl 字母键…...

linux mount nfs开机自动挂载远程目录

要在Linux系统中实现开机自动挂载NFS共享目录&#xff0c;你需要编辑/etc/fstab文件。以下是具体步骤和示例&#xff1a; 确保你的系统已经安装了NFS客户端。如果没有安装&#xff0c;可以使用以下命令安装&#xff1a; sudo apt-install nfs-common 编辑/etc/fstab文件&#…...

【vue】导航守卫

什么是导航守卫 在vue路由切换过程中对行为做个限制 全局前置守卫 route.beforeEach((to, from, next)) > {// to是切换到的路由// from是正要离开的路由// next控制是否允许进入目标路由next(false); //不允许 }路由级别的导航守卫 const routes [{path: /User,name: U…...

基于Matlab实现LDPC编码

在无线通信和数据存储领域&#xff0c;LDPC&#xff08;低密度奇偶校验码&#xff09;编码是一种高效、纠错能力强大的错误校正技术。本MATLAB仿真程序全面地展示了如何在AWGN&#xff08;加性高斯白噪声&#xff09;信道下应用LDPC编码与BPSK&#xff08;二进制相移键控&#…...

PostgreSQL 中约束Constraints

在 PostgreSQL 中&#xff0c;约束&#xff08;Constraints&#xff09;是用于限制进入数据库表中数据的规则。它们确保数据的准确性和可靠性&#xff0c;通过定义规则来防止无效数据的插入或更新。PostgreSQL 支持多种类型的约束&#xff0c;每种约束都有特定的用途和语法。以…...

✨系统设计时应时刻考虑设计模式基础原则

目录 &#x1f4ab;单一职责原则 (Single Responsibility Principle, SRP)&#x1f4ab;开放-封闭原则 (Open-Closed Principle, OCP)&#x1f4ab;依赖倒转原则 (Dependency Inversion Principle, DIP)&#x1f4ab;里氏代换原则 (Liskov Substitution Principle, LSP)&#x…...

【Linux】多线程(下)

目录 一、生产者消费者模型 1.1 概念 1.2 基于阻塞队列 1.3 POSIX信号量 初始化信号量 销毁信号量 等待信号量 发布信号量 1.4 基于环形队列和POSIX信号量 二、线程池 2.1 概念 2.2 代码 三、封装Linux线程库 四、单例模式 4.1 概念 4.2 单例模式的实现方式 4…...

Element-Plus如何修改日期选择器输入框el-date-picker的圆角

使用 el-date-picker 的 style 属性 :style"{ --el-border-radius-base: 10px }"<!-- 日期 --> <el-form-item label"日期" prop"establishmentDate"><el-date-picker v-model"form.establishmentDate" type"dat…...

skywalking es查询整理

索引介绍 sw_records-all 这个索引用于存储所有的采样记录&#xff0c;包括但不限于慢SQL查询、Agent分析得到的数据等。这些记录数据包括Traces、Logs、TopN采样语句和告警信息。它们被用于性能分析和故障排查&#xff0c;帮助开发者和运维团队理解服务的行为和性能特点。 …...

故障排除-------K8s挂载集群外NFS异常

故障排除-------K8s挂载集群外NFS异常 1. 故障现象2. 原因梳理2.1 排查思路2.2 确认yaml内容2.3 创建k8s内的nfs测试2.3.1 创建nfs和svc2.3.2 测试创建pvc2.3.3 测试结果 2.4 NFS服务端故障排除2.4.1 网络阻断排除2.4.2 排除服务状态问题2.4.3 排查NFS权限问题 3. 故障排除 1. …...

Easyexcel(6-单元格合并)

相关文章链接 Easyexcel&#xff08;1-注解使用&#xff09;Easyexcel&#xff08;2-文件读取&#xff09;Easyexcel&#xff08;3-文件导出&#xff09;Easyexcel&#xff08;4-模板文件&#xff09;Easyexcel&#xff08;5-自定义列宽&#xff09;Easyexcel&#xff08;6-单…...

解决登录Google账号遇到手机上Google账号无法验证的问题

文章目录 场景小插曲解决方案总结 场景 Google账号在新的设备上登录的时候&#xff0c;会要求在手机的Google上进行确认验证&#xff0c;而如果没有安装Google play就可能出现像我一样没有任何弹框&#xff0c;无法实现验证 小插曲 去年&#xff0c;我在笔记本上登录了Googl…...

【Redis_Day5】String类型

【Redis_Day5】String类型 String操作String的命令set和get&#xff1a;设置、获取键值对mset和mget&#xff1a;批量设置、获取键值对setnx/setex/psetexincr和incrby&#xff1a;对字符串进行加操作decr/decrby&#xff1a;对字符串进行减操作incrbyfloat&#xff1a;浮点数加…...

Python MySQL SQLServer操作

Python MySQL SQLServer操作 Python 可以通过 pymysql 连接 MySQL&#xff0c;通过 pymssql 连接 SQL Server。以下是基础操作和代码实战示例&#xff1a; 一、操作 MySQL&#xff1a;使用 pymysql python 操作数据库流程 1. 安装库 pip install pymysql2. 连接 MySQL 示例 …...

400网站建设电话/网络营销策划方案论文

控件 -属性&#xff1a; --id:每一个的唯一标识 --layout_width,layout_height:宽度&#xff0c;高度(match_parent,fill_parent,wrap_content) --text:指定显示内容 --gravity:指定文字的对齐方式(top,bottom,left,right,center) --textSize:文字大小 --textColor:文本颜色 --…...

余姚哪里有做淘宝网站的/推广网站的方法

将CentOS系统光盘放入光驱&#xff0c;如果是VM虚拟机&#xff0c;选择加载ISO镜像在VM选项卡--设置--硬件-CD/DVD/--使用ISO映像文件&#xff0c;选择你ISO的路径1.手动挂载光驱先建立挂载目录&#xff0c;#mkdir /media/cdrom挂载光驱#mount /dev/cdrom /media/cdrom2.开机自…...

做网站的数据库/搜索引擎优化要考虑哪些方面?

&#xfeff;&#xfeff;1 新建一个项目&#xff1a;TCPServer.pro A 改动TCPServer.pro,注意&#xff1a;假设是想使用网络库。须要加上network SOURCES \ TcpServer.cpp \ TcpClient.cpp HEADERS \ TcpServer.h \ TcpClient.h QT gui widgets network CONFIG C11 B 新…...

vr超市门户网站建设/淘宝交易指数换算工具

ZMap 类 功能介绍 ZMap 是学习百度地图 api 接口&#xff0c;开发基本功能后整的一个脚本类&#xff0c;本类方法功能大多使用 prototype 原型 实现&#xff1b; 包含的功能有&#xff1a;轨迹回放&#xff0c;圈画区域可编辑&#xff0c;判断几个坐标是否在一个圆圈内&#xf…...

有哪些免费做外贸的网站/window优化大师

安装了 64 位 Windows 系统&#xff0c;它其实包含了 32 位系统兼容库&#xff0c;并且有 32 位单独的文件夹&#xff0c;可以运行大部分 32 位的软件。 但 32 位却不能使用 64 位的软件。 现阶段使用的是32位的 因为大部分库支持32位 HOperatorSet.OpenFramegrabber error…...

滁州网站开发/互联网运营推广

知识点1----ALTER 下列代码意义&#xff1a;向已存在的表my_foods中新增自动排列的列 作为主键 ALTER TABLE my_contacts  --表名称ADD COLUMN id INT NOT NULL AUTO_INCREMENT FIRST,   --新的 列 id&#xff0c;自动排列&#xff0c;该列于第一位 ADD PRIMARY KEY (id);…...