【人工智能】Python常用库-PyTorch常用方法教程
PyTorch 是一个强大的开源深度学习框架,以其灵活性和动态计算图而广受欢迎。以下是 PyTorch 的详细教程,涵盖从基础到实际应用的使用方法。
1. 安装与导入
1.1 安装 PyTorch
访问 PyTorch 官方网站,根据系统、Python 版本和 CUDA 支持选择安装命令。
常用安装命令:
pip install torch torchvision torchaudio
1.2 导入库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
2. PyTorch 基础
2.1 张量(Tensor)
张量是 PyTorch 的核心数据结构,可以看作是一个高维数组。
# 创建张量
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])# 基本运算
c = a + b
print(c) # 输出 tensor([5., 7., 9.])# 随机张量
random_tensor = torch.rand((2, 3)) # 2行3列随机数
print(random_tensor)
输出结果
tensor([5., 7., 9.])
tensor([[0.9980, 0.2970, 0.5257],[0.8807, 0.0471, 0.7896]])
2.2 自动求导
PyTorch 提供动态计算图支持自动求导。
x = torch.tensor(2.0, requires_grad=True)
y = x**2 + 3*x + 4y.backward() # 自动求导
print(x.grad) # 输出 dy/dx = 2*x + 3 = 7.0
输出结果
tensor(7.)
3. 数据加载
PyTorch 提供强大的数据加载功能。
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader# 下载并加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
4. 构建神经网络
4.1 使用 nn.Module 构建模型
import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = x.view(-1, 28 * 28) # 展平输入x = self.relu(self.fc1(x))x = self.softmax(self.fc2(x))return xmodel = SimpleNN()print(model)
输出结果
SimpleNN((fc1): Linear(in_features=784, out_features=128, bias=True)(relu): ReLU()(fc2): Linear(in_features=128, out_features=10, bias=True)(softmax): Softmax(dim=1)
)
5. 模型训练
5.1 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)
5.2 训练循环
for epoch in range(5):for images, labels in train_loader:optimizer.zero_grad() # 梯度清零outputs = model(images)loss = criterion(outputs, labels) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新权重print(f"Epoch {epoch+1}, Loss: {loss.item()}")
完整代码
from torch import nn, optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoaderclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = x.view(-1, 28 * 28) # 展平输入x = self.relu(self.fc1(x))x = self.softmax(self.fc2(x))return xmodel = SimpleNN()# 下载并加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)criterion = nn.CrossEntropyLoss() # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(5):for images, labels in train_loader:optimizer.zero_grad() # 梯度清零outputs = model(images)loss = criterion(outputs, labels) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新权重print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
输出结果
Epoch 1, Loss: 1.482284665107727
Epoch 2, Loss: 1.4968496561050415
Epoch 3, Loss: 1.5289227962493896
Epoch 4, Loss: 1.4832825660705566
Epoch 5, Loss: 1.5070817470550537
6. 模型评估
6.1 在测试集上评估
test_data = MNIST(root='./data', train=False, transform=transform)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)correct = 0
total = 0
with torch.no_grad(): # 禁用梯度计算for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Test Accuracy: {correct / total * 100:.2f}%")
输出结果
Test Accuracy: 10.32%
7. GPU 加速
PyTorch 支持使用 GPU 加速。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)# 将数据也移动到 GPU
for images, labels in train_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)
8. 保存与加载模型
8.1 保存模型
torch.save(model.state_dict(), 'model.pth')
8.2 加载模型
model = SimpleNN()
model.load_state_dict(torch.load('model.pth'))
model.eval() # 切换到评估模式
9. 实际案例
9.1 CIFAR-10 图像分类
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms# CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(16 * 16 * 16, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = x.view(-1, 16 * 16 * 16)x = self.fc1(x)return xmodel = CNN()
# 后续训练步骤类似
10. PyTorch 优势总结
- 动态计算图:支持动态构建与修改模型。
- 灵活性:适合研究和开发,易于调试。
- 强大的社区支持:广泛的教程、示例和扩展工具。
通过实践,PyTorch 能够帮助用户更好地理解和实现深度学习算法!
相关文章:
【人工智能】Python常用库-PyTorch常用方法教程
PyTorch 是一个强大的开源深度学习框架,以其灵活性和动态计算图而广受欢迎。以下是 PyTorch 的详细教程,涵盖从基础到实际应用的使用方法。 1. 安装与导入 1.1 安装 PyTorch 访问 PyTorch 官方网站,根据系统、Python 版本和 CUDA 支持选择安…...
Android Studio安装TalkX AI编程助手
文章目录 TalkX简介编程场景 TalkX安装TalkX编程使用ai编程助手相关文章 TalkX简介 TalkX是一款将OpenAI的GPT 3.5/4模型集成到IDE的AI编程插件。它免费提供特定场景的AI编程指导,帮助开发人员提高工作效率约38%,甚至在解决编程问题的效率上提升超过2倍…...
#渗透测试#红蓝攻防#HW#漏洞挖掘#漏洞复现02-永恒之蓝漏洞
免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停…...
gitlab自动打包python项目
现在新版的gitlab可以不用自己配置runner什么的了 直接写.gitlab-ci.yml文件就行,这里给出一个简单的依靠setup把python项目打包成whl文件的方法 首先写.gitlab-ci.yml文件,放到项目根目录里 stages: # List of stages for jobs, and their or…...
残差神经网络
目录 1. 梯度消失问题 2. 残差学习的引入 3. 跳跃连接(Shortcut Connections) 4. 恒等映射与维度匹配 5. 反向传播与梯度流 6. 网络深度与性能 总结 残差神经网络的原理是基于“残差学习”的概念,它旨在解决深度神经网络训练中的梯度消…...
mini-spring源码分析
IOC模块 关键解释 beanFactory:beanFactory是一个hashMap, key为beanName, Value为 beanDefination beanDefination: BeanDefinitionRegistry,BeanDefinition注册表接口,定义注册BeanDefinition的方法 beanReference:增加Bean…...
黑马程序员Java项目实战《苍穹外卖》Day01
苍穹外卖-day01 课程内容 软件开发整体介绍苍穹外卖项目介绍开发环境搭建导入接口文档Swagger 项目整体效果展示: 管理端-外卖商家使用 用户端-点餐用户使用 当我们完成该项目的学习,可以培养以下能力: 1. 软件开发整体介绍 作为一…...
uniapp开发支付宝小程序自定义tabbar样式异常
解决方案: 这个问题应该是支付宝基础库的问题,除了依赖于官方更新之外,开发者可以利用《自定义 tabBar》曲线救国 也就是创建一个空内容的自定义tabBar,这样即使 tabBar 被渲染出来,但从视觉上也不会有问题 1.官方文…...
python+django5.1+docker实现CICD自动化部署springboot 项目前后端分离vue-element
一、开发环境搭建和配置 # channels是一个用于在Django中实现WebSocket、HTTP/2和其他异步协议的库。 pip install channels#channels-redis是一个用于在Django Channels中使用Redis作为后台存储的库。它可以用于处理#WebSocket连接的持久化和消息传递。 pip install channels…...
python代码示例(读取excel文件,自动播放音频)
目录 python 操作excel 表结构 安装第三方库 代码 自动播放音频 介绍 安装第三方库 代码 python 操作excel 表结构 求出100班同学的平均分 安装第三方库 因为这里的表结构是.xlsx文件,需要使用openpyxl库 如果是.xls格式文件,需要使用xlrd库 pip install openpyxl /…...
【第十课】Rust并发编程(一)
目录 前言 Fork和Join 前言 本节会介绍Rust中的并发编程,并发编程在编程中是提升cpu使用率的一大利器,通过多线程技术提升效率,Rust的并发和其他编程语言的并发不同的地方在于,Rust号称无畏并发。更重要的一点是安全。Rust中所有…...
图形渲染性能优化
variable rate shading conditional render 设置可见性等, 不需要重新build command buffer indirect draw glMultiDraw* - 直接支持多次绘制glMultiDrawIndirect - 间接多次绘制multithreading 多线程录制 实例化渲染 lod texture array 小对象剔除 投影到…...
elasticsearch的索引模版使用方法
5 索引模版⭐️⭐️⭐️⭐️⭐️ 索引模板就是创建索引时要遵循的模板规则索引模板仅对新创建的索引有效,已经创建的索引并不受索引模板的影响 5.1 索引模版的基本使用 1.查看所有的索引模板 GET 10.0.0.91:9200/_index_template2.创建自定义索引模板 xixi &…...
论文学习——进化动态约束多目标优化:测试集和算法
论文题目:Evolutionary Dynamic Constrained Multiobjective Optimization: Test Suite and Algorithm 进化动态约束多目标优化:测试集和算法(Guoyu Chen ,YinanGuo , Member, IEEE, Yong Wang , Senior Member, IEEE, Jing Liang , Senior …...
C++中的volatile关键字
作用: 1.它用于修饰变量,告知编译器该变量的值可能会在程序的外部被改变,编译器不能对这个变量的访问进行优化。这是因为编译器通常会对代码进行优化,例如把变量的值缓存到寄存器中,但对于 volatile 变量,…...
linux桌面qt应用程序UI自动化实现之dogtail
1. 前言 Dogtail适用于Linux 系统上进行 GUI 自动化测试,利用 Accessibility 技术与桌面程序通信;Dogtail 包含一个名为 sniff 的组件,这是一个嗅探器,用于 GUI 程序追踪; 源码下载:dogtail PyPI 可通过sudo python setup.py install安装或sudo pip install dogt…...
Hello World C#
using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; using System; 引入了System命名空间,基本输入输出。一般只用这个,后面的不用 using System.Collections.Generic; 包含了定…...
SAP开发语言ABAP开发入门
1. 了解ABAP开发环境和基础知识 - ABAP简介 - ABAP(Advanced Business Application Programming)是SAP系统中的编程语言,主要用于开发企业级的业务应用程序,如财务、物流、人力资源等模块的定制开发。 - 开发环境搭建 - 首先需…...
应急响应靶机——easy溯源
载入虚拟机,开启虚拟机: (账户密码:zgsfsys/zgsfsys) 解题程序.exe是额外下载解压得到的: 1. 攻击者内网跳板机IP地址 2. 攻击者服务器地址 3. 存在漏洞的服务(提示:7个字符) 4. 攻击者留下的flag(格式…...
【前端】vscode报错: 无法加载文件 D:\nodejs\node_global\yarn.ps1,因为在此系统上禁止运行脚本。
vscode运行前端代码时候,执行yarn install时候报错 问题: 无法加载文件 D:\nodejs\node_global\yarn.ps1,因为在此系统上禁止运行脚本。 解决方式: 首先用管理员身份运行vscode 查看 get-ExecutionPolicy,Restrict…...
MFC内存泄露
1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…...
跨链模式:多链互操作架构与性能扩展方案
跨链模式:多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈:模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展(H2Cross架构): 适配层…...
Java 加密常用的各种算法及其选择
在数字化时代,数据安全至关重要,Java 作为广泛应用的编程语言,提供了丰富的加密算法来保障数据的保密性、完整性和真实性。了解这些常用加密算法及其适用场景,有助于开发者在不同的业务需求中做出正确的选择。 一、对称加密算法…...
AI编程--插件对比分析:CodeRider、GitHub Copilot及其他
AI编程插件对比分析:CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展,AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者,分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...
企业如何增强终端安全?
在数字化转型加速的今天,企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机,到工厂里的物联网设备、智能传感器,这些终端构成了企业与外部世界连接的 “神经末梢”。然而,随着远程办公的常态化和设备接入的爆炸式…...
基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解
JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用,结合SQLite数据库实现联系人管理功能,并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能,同时可以最小化到系统…...
LRU 缓存机制详解与实现(Java版) + 力扣解决
📌 LRU 缓存机制详解与实现(Java版) 一、📖 问题背景 在日常开发中,我们经常会使用 缓存(Cache) 来提升性能。但由于内存有限,缓存不可能无限增长,于是需要策略决定&am…...
C++中vector类型的介绍和使用
文章目录 一、vector 类型的简介1.1 基本介绍1.2 常见用法示例1.3 常见成员函数简表 二、vector 数据的插入2.1 push_back() —— 在尾部插入一个元素2.2 emplace_back() —— 在尾部“就地”构造对象2.3 insert() —— 在任意位置插入一个或多个元素2.4 emplace() —— 在任意…...
GB/T 43887-2024 核级柔性石墨板材检测
核级柔性石墨板材是指以可膨胀石墨为原料、未经改性和增强、用于核工业的核级柔性石墨板材。 GB/T 43887-2024核级柔性石墨板材检测检测指标: 测试项目 测试标准 外观 GB/T 43887 尺寸偏差 GB/T 43887 化学成分 GB/T 43887 密度偏差 GB/T 43887 拉伸强度…...
mcts蒙特卡洛模拟树思想
您这个观察非常敏锐,而且在很大程度上是正确的!您已经洞察到了MCTS算法在不同阶段的两种不同行为模式。我们来把这个关系理得更清楚一些,您的理解其实离真相只有一步之遥。 您说的“select是在二次选择的时候起作用”,这个观察非…...
