当前位置: 首页 > news >正文

李宏毅2023机器学习作业1--homework1

一、前期准备

下载训练数据和测试数据

# dropbox link
!wget -O covid_train.csv https://www.dropbox.com/s/lmy1riadzoy0ahw/covid.train.csv?dl=0
!wget -O covid_test.csv https://www.dropbox.com/s/zalbw42lu4nmhr2/covid.test.csv?dl=0

导入包

# Numerical Operations
import math
import numpy as np        # numpy操作数据,增加删除查找修改# Reading/Writing Data
import pandas as pd       # pandas读取csv文件
import os                 # 进行文件夹操作
import csv# For Progress Bar
from tqdm import tqdm     # 可视化# Pytorch
import torch              # pytorch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split# For plotting learning curve
from torch.utils.tensorboard import SummaryWriter

定义一些功能函数

def same_seed(seed):'''Fixes random number generator seeds for reproducibility.'''torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsenp.random.seed(seed)torch.manual_seed(seed)if torch.cuda.is_available():torch.cuda.manual_seed_all(seed)# 划分训练数据集和验证数据集
def train_valid_split(data_set, valid_ratio, seed):'''Split provided training data into training set and validation set'''valid_set_size = int(valid_ratio * len(data_set))train_set_size = len(data_set) - valid_set_sizetrain_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(seed))return np.array(train_set), np.array(valid_set)

配置项

device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = {'seed': 5201314,      # Your seed number, you can pick your lucky number. :)'select_all': False,   # Whether to use all features.'valid_ratio': 0.2,   # validation_size = train_size * valid_ratio'n_epochs': 5000,     # Number of epochs.'batch_size': 256,'learning_rate': 1e-5,'early_stop': 600,    # If model has not improved for this many consecutive epochs, stop training.'save_path': './models/model.ckpt'  # Your model will be saved here.
}

二、创建数据

创建Dataset

class COVID19Dataset(Dataset):'''x: Features.y: Targets, if none, do prediction.'''def __init__(self, x, y=None):if y is None:self.y = yelse:self.y = torch.FloatTensor(y)self.x = torch.FloatTensor(x)def __getitem__(self, idx):if self.y is None:return self.x[idx]else:return self.x[idx], self.y[idx]def __len__(self):return len(self.x)

特征选择

删除了belife和mental 的特征,belife和mental都是心理上精神上的特征,感觉可能和阳性率的偏差较大,就删去了这两类的特征

def select_feat(train_data, valid_data, test_data, select_all=True):'''Selects useful features to perform regression'''# [:,-1]第一个维度选择所有,选取所有行,第二个维度选择-1,-1是倒数第一个元素,也就是标签labely_train, y_valid = train_data[:,-1], valid_data[:,-1]   # 选择标签元素# [:,:-1]第一个维度选择所有,所有行,第二个维度从开始元素到倒数第一个元素(不包含倒数第一个元素)raw_x_train, raw_x_valid, raw_x_test = train_data[:,:-1], valid_data[:,:-1], test_dataif select_all:feat_idx = list(range(raw_x_train.shape[1]))else:# feat_idx = list(range(35, raw_x_train.shape[1])) # TODO: Select suitable feature columns."""删除了belife和mental 的特征[0, 38, 39, 46, 51, 56, 57, 64, 69, 74, 75, 82, 87]是belife和mental所在列"""del_col = [0, 38, 39, 46, 51, 56, 57, 64, 69, 74, 75, 82, 87]  raw_x_train = np.delete(raw_x_train, del_col, axis=1) # numpy数组增删查改方法raw_x_valid = np.delete(raw_x_valid, del_col, axis=1)raw_x_test = np.delete(raw_x_test, del_col, axis=1)return raw_x_train, raw_x_valid, raw_x_test, y_train, y_validreturn raw_x_train[:,feat_idx], raw_x_valid[:,feat_idx], raw_x_test[:,feat_idx], y_train, y_valid

 创建 Dataloader

读取文件,设置训练,验证和测试数据集

# Set seed for reproducibility
same_seed(config['seed'])# train_data size: 3009 x 89 (35 states + 18 features x 3 days)  
# train_data共3009条数据,每条数据89个维度
# test_data size: 997 x 88 (without last day's positive rate)
# test_data共997条数据,每条数据88个维度,没有最后一天的最后一列数据positive rate# pands读取csv数据
train_data, test_data = pd.read_csv('./covid_train.csv').values, pd.read_csv('./covid_test.csv').values     # train_valid_split切分训练集和验证集
train_data, valid_data = train_valid_split(train_data, config['valid_ratio'], config['seed'])# Print out the data size.打印数据尺寸
print(f"""train_data size: {train_data.shape}
valid_data size: {valid_data.shape}
test_data size: {test_data.shape}""")# Select features 选择特征
x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data, valid_data, test_data, config['select_all'])# Print out the number of features. 打印特征数
print(f'number of features: {x_train.shape[1]}')# 生成dataset
train_dataset, valid_dataset, test_dataset = COVID19Dataset(x_train, y_train), \COVID19Dataset(x_valid, y_valid), \COVID19Dataset(x_test)# Pytorch data loader loads pytorch dataset into batches.
# pytorch的dataloder加载dataset
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True)

 三、创建神经网络模型

class My_Model(nn.Module):         def __init__(self, input_dim):super(My_Model, self).__init__()# TODO: modify model's structure, be aware of dimensions.self.layers = nn.Sequential(nn.Linear(input_dim, 16),nn.ReLU(),nn.Linear(16, 8),nn.ReLU(),nn.Linear(8, 1))def forward(self, x):x = self.layers(x)x = x.squeeze(1) # (B, 1) -> (B)return x

四、模型训练和模型测试

模型训练

def trainer(train_loader, valid_loader, model, config, device):criterion = nn.MSELoss(reduction='mean') # Define your loss function, do not modify this.# Define your optimization algorithm.# TODO: Please check https://pytorch.org/docs/stable/optim.html to get more available algorithms.# TODO: L2 regularization (optimizer(weight decay...) or implement by your self).optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=0.9)writer = SummaryWriter() # Writer of tensoboard.# 如果没有models文件夹,创建名称为models的文件夹,保存模型if not os.path.isdir('./models'):    os.mkdir('./models') # Create directory of saving models.# math.inf为无限大n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0for epoch in range(n_epochs):model.train() # Set your model to train mode.loss_record = []    # 记录损失# tqdm is a package to visualize your training progress.train_pbar = tqdm(train_loader, position=0, leave=True)for x, y in train_pbar:optimizer.zero_grad()               # Set gradient to zero.x, y = x.to(device), y.to(device)   # Move your data to device.pred = model(x)                     # 数据传入模型model,生成预测值predloss = criterion(pred, y)           # 预测值pred和真实值y计算损失loss  loss.backward()                     # Compute gradient(backpropagation).optimizer.step()                    # Update parameters.step += 1loss_record.append(loss.detach().item())   # 当前步骤的loss加到loss_record[]# Display current epoch number and loss on tqdm progress bar.train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')train_pbar.set_postfix({'loss': loss.detach().item()})mean_train_loss = sum(loss_record)/len(loss_record)      # 计算训练集上平均损失writer.add_scalar('Loss/train', mean_train_loss, step)   model.eval() # Set your model to evaluation mode.loss_record = []for x, y in valid_loader:x, y = x.to(device), y.to(device)with torch.no_grad():pred = model(x)loss = criterion(pred, y)loss_record.append(loss.item())mean_valid_loss = sum(loss_record)/len(loss_record)      # 计算验证集上平均损失     print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')writer.add_scalar('Loss/valid', mean_valid_loss, step)# 保存验证集上平均损失最小的模型if mean_valid_loss < best_loss:         best_loss = mean_valid_losstorch.save(model.state_dict(), config['save_path']) # Save your best modelprint('Saving model with loss {:.3f}...'.format(best_loss))early_stop_count = 0else:early_stop_count += 1# 设置早停early_stop_count# 如果early_stop_count次数,验证集上的平均损失没有变化,模型性能没有提升,停止训练if early_stop_count >= config['early_stop']:   print('\nModel is not improving, so we halt the training session.')return

模型测试

# 测试数据集的预测
def predict(test_loader, model, device):model.eval() # Set your model to evaluation mode.preds = []for x in tqdm(test_loader):x = x.to(device)with torch.no_grad():   # 关闭梯度pred = model(x)preds.append(pred.detach().cpu())preds = torch.cat(preds, dim=0).numpy()return preds


 

五、训练模型

model = My_Model(input_dim=x_train.shape[1]).to(device) # put your model and data on the same computation device.trainer(train_loader, valid_loader, model, config, device)

六、测试模型,生成预测值

def save_pred(preds, file):''' Save predictions to specified file '''with open(file, 'w') as fp:writer = csv.writer(fp)writer.writerow(['id', 'tested_positive'])for i, p in enumerate(preds):writer.writerow([i, p])model = My_Model(input_dim=x_train.shape[1]).to(device)
model.load_state_dict(torch.load(config['save_path']))    # 加载模型
preds = predict(test_loader, model, device)               # 生成预测结果preds
save_pred(preds, 'pred.csv')                              # 保存preds到pred.csv   

tensorboard可视化训练和验证损失图像


%reload_ext tensorboard
%tensorboard --logdir=./runs/

参考:

李宏毅_机器学习_作业1(详解)_COVID-19 Cases Prediction (Regression)-物联沃-IOTWORD物联网

【深度学习】2023李宏毅homework1作业一代码详解_李宏毅作业1-CSDN博客

np.delete详解-CSDN博客

相关文章:

李宏毅2023机器学习作业1--homework1

一、前期准备 下载训练数据和测试数据 # dropbox link !wget -O covid_train.csv https://www.dropbox.com/s/lmy1riadzoy0ahw/covid.train.csv?dl0 !wget -O covid_test.csv https://www.dropbox.com/s/zalbw42lu4nmhr2/covid.test.csv?dl0 导入包 # Numerical Operation…...

Mysql的SQL调优-面试

面试SQL优化的具体操作&#xff1a; 1、在表中建立索引&#xff0c;优先考虑where、group by使用到的字段。 2、尽量避免使用select *&#xff0c;返回无用的字段会降低查询效率。错误如下&#xff1a; SELECT * FROM table 优化方式&#xff1a;使用具体的字段代替 *&#xf…...

Unity 2021.3发布WebGL设置以及nginx的配置

使用unity2021.3发布webgl 使用Unity制作好项目之后建议进行代码清理&#xff0c;这样会即将不用的命名空间去除&#xff0c;不然一会在发布的时候有些命名空间webgl会报错。 平台转换 将平台设置为webgl 设置色彩空间压缩方式 Compression Format 设置为DisabledDecompre…...

【鸿蒙 HarmonyOS 4.0】数据持久化

一、数据持久化介绍 数据持久化是将内存数据(内存是临时的存储空间)&#xff0c;通过文件或数据库的形式保存在设备中。 HarmonyOS提供两种数据持久化方案&#xff1a; 1.1、用户首选项&#xff08;Preferences&#xff09;&#xff1a; 通常用于保存应用的配置信息。数据通…...

mysql mgr集群多主部署

一、前言 mgr多主集群是将集群中的所有节点都设为可写&#xff0c;减轻了单主节点的写压力&#xff0c;从而提高了mysql的写入性能 二、部署 基础部署与mgr集群单主部署一致&#xff0c;只是在创建mgr集群时有所不同 基础部署参考&#xff1a;mysql mgr集群部署-CSDN博客 设置…...

【开源】JAVA+Vue.js实现医院门诊预约挂号系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 功能性需求2.1.1 数据中心模块2.1.2 科室医生档案模块2.1.3 预约挂号模块2.1.4 医院时政模块 2.2 可行性分析2.2.1 可靠性2.2.2 易用性2.2.3 维护性 三、数据库设计3.1 用户表3.2 科室档案表3.3 医生档案表3.4 医生放号…...

《图解设计模式》笔记(一)适应设计模式

图灵社区 - 图解设计模式 - 随书下载 评论区 雨帆 2017-01-11 16:14:04 对于设计模式&#xff0c;我个人认为&#xff0c;其实代码和设计原则才是最好的老师。理解了 SOLID&#xff0c;如何 SOLID&#xff0c;自然而然地就用起来设计模式了。Github 上有一个 tdd-training&…...

图文说明Linux云服务器如何更改实例镜像

一、应用场景举例 在学习Linux的vim时&#xff0c;我们难免要对vim进行一些配置&#xff0c;这里我们提供一个vim插件的安装包&#xff1a; curl -sLf https://gitee.com/HGtz2222/VimForCpp/raw/master/install.sh -o./install.sh && bash ./install.sh 但是此安装包…...

RabbitMQ学习整理————基于RabbitMQ实现RPC

基于RabbitMQ实现RPC 前言什么是RPCRabbitMQ如何实现RPCRPC简单示例通过Spring AMQP实现RPC 前言 这边参考了RabbitMQ的官网&#xff0c;想整理一篇关于RabbitMQ实现RPC调用的博客&#xff0c;打算把两种实现RPC调用的都整理一下&#xff0c;一个是使用官方提供的一个Java cli…...

Linux-基础知识(黑马学习笔记)

硬件和软件 我们所熟知的计算机是由&#xff1a;硬件和软件组成。 硬件&#xff1a;计算机系统中电子&#xff0c;机械和光电元件等组成的各种物理装置的总称。 软件&#xff1a;是用户和计算机硬件之间的接口和桥梁&#xff0c;用户通过软件与计算机进行交流。 而操作系统…...

SpringBoot项目启动报java.nio.charset.MalformedInputException Input length = 1解决方案

报错详情 SpringBoot启动报错java.nio.charset.MalformedInputException: Input length 1 报错原因 出现这个的原因&#xff0c;就是解析yml文件时&#xff0c;中文字符集不是utf-8的原因&#xff0c;这是maven在项目编译时&#xff0c;默认字符集编码是GBK。 解决方式 检…...

【Unity2019.4.35f1】配置JDK、NDK、SDK、Gradle

目录 JDK NDK SDK 环境变量 Gradle JDK JDK&#xff1a;jdk-1.8版本Java Downloads | Oracle 下载要登录&#xff0c;搜索JDK下载公用账号&#xff1a;Oracle官网 JDK下载 注册登录公共账号和密码_oracle下载账号-CSDN博客 路径&#xff1a;C:\Program Files\Java\jd…...

MySQL中的高级查询

通过条件查询可以查询到符合条件的数据&#xff0c;但如同要实现对字段的值进行计算、根据一个或多个字段对查询结果进行分组等操作时&#xff0c;就需要使用更高级的查询&#xff0c;MySQL提供了聚合函数、分组查询、排序查询、限量查询、内置函数以实现更复杂的查询需求。接下…...

leetcode383赎金信

用字符数组ch来记录magazine每个字母出现频率&#xff0c;用ransomNote的字母减去字符数组ch对应的字符出现频率&#xff0c;如果该字符对应的频率小于0&#xff0c;则不够&#xff0c;无法组成ransomNote&#xff01; class Solution { public:bool canConstruct(string rans…...

【Unity3D】ASE制作天空盒

找到官方shader并分析 下载对应资源包找到\DefaultResourcesExtra\Skybox-Cubed.shader找到\CGIncludes\UnityCG.cginc观察变量, 观察tag, 观察代码 需要注意的内容 ASE要处理的内容 核心修改 添加一个Custom Expression节点 code内容为: return DecodeHDR(In0, In1);outp…...

MyBatisPlus常用注解

目录 一、TableName 二、TableId 三、TableField 四、TableLogic 一、TableName 在使用MyBatis-Plus实现基本的CRUD时&#xff0c;我们并没有指定要操作的表&#xff0c;只是在Mapper接口继承BaseMapper时&#xff0c;设置了泛型User&#xff0c;而操作的表为user表 由此得出…...

Putty中运行matlab文件

首先使用命令 cd /home/ya/CodeTest/Matlab进入路径&#xff1a;到Matlab文件夹下 然后键入matlab&#xff0c;进入matlab环境&#xff0c;如果main.m文件在Matlab文件夹下&#xff0c;直接键入main即可运行该文件。细节代码如下&#xff1a; Unable to use key file "y…...

ES6 | (一)ES6 新特性(上) | 尚硅谷Web前端ES6教程

文章目录 &#x1f4da;ES6新特性&#x1f4da;let关键字&#x1f4da;const关键字&#x1f4da;变量的解构赋值&#x1f4da;模板字符串&#x1f4da;简化对象写法&#x1f4da;箭头函数&#x1f4da;函数参数默认值设定&#x1f4da;rest参数&#x1f4da;spread扩展运算符&a…...

生产环境下,应用模式部署flink任务,通过hdfs提交

前言 通过通过yarn.provided.lib.dirs配置选项指定位置&#xff0c;将flink的依赖上传到hdfs文件管理系统 1. 实践 &#xff08;1&#xff09;生产集群为cdh集群&#xff0c;从cm上下载配置文件&#xff0c;设置环境 export HADOOP_CONF_DIR/home/conf/auth export HADOOP_CL…...

【lesson59】线程池问题解答和读者写者问题

文章目录 线程池问题解答什么是单例模式什么是设计模式单例模式的特点饿汉和懒汉模式的理解STL中的容器是否是线程安全的?智能指针是否是线程安全的&#xff1f;其他常见的各种锁 读者写者问题 线程池问题解答 什么是单例模式 单例模式是一种 “经典的, 常用的, 常考的” 设…...

【LeetCode每日一题】单调栈316去除重复字母

题目&#xff1a;去除重复字母 给你一个字符串 s &#xff0c;请你去除字符串中重复的字母&#xff0c;使得每个字母只出现一次。需保证 返回结果的字典序最小&#xff08;要求不能打乱其他字符的相对位置&#xff09;。 示例 1&#xff1a; 输入&#xff1a;s “bcabc” 输…...

【Git】Gitbash使用ssh 上传本地项目到github

SSH Git上传项目到GitHub&#xff08;图文&#xff09;_git ssh上传github-CSDN博客 前提 ssh-keygen -t rsa -C “自己的github电子邮箱” 生成密钥&#xff0c;公钥保存到自己的github的ssh里 1.先创建一个仓库&#xff0c;复制ssh地址 git init git add . git commit -m …...

activeMq将mqtt发布订阅转成消息队列

1、activemq.xml置文件新增如下内容 2、mqttx测试发送&#xff1a; 主题&#xff08;配置的模糊匹配&#xff0c;为了并发&#xff09;&#xff1a;VirtualTopic/device/sendData/12312 3、mqtt接收的结果 4、程序处理 package comimport cn.hutool.core.date.DateUtil; imp…...

Go语言教程

一、引言 Go&#xff08;又称Golang&#xff09;是由Google开发的一种静态类型、编译型的开源编程语言。它旨在提供简单、快速和可靠的软件开发体验。Go语言结合了动态语言的开发效率和静态语言的安全性能&#xff0c;特别适用于网络编程、系统编程和并发编程。本教程将介绍Go…...

分布式锁的应用场景及实现

文章目录 分布式锁的应用场景及实现1. 应用场景2. 分布式锁原理3. 分布式锁的实现3.1 基于数据库 分布式锁的应用场景及实现 1. 应用场景 电商网站在进行秒杀、特价等大促活动时&#xff0c;面临访问量激增和高并发的挑战。由于活动商品通常是有限库存的&#xff0c;为了避免…...

嵌入式Linux中apt、apt-get命令用法汇总

在Linux环境开发过程中接触ubuntu虚拟机时&#xff0c;在安装软件或者更新软件时apt和apt-get命令使用相对较频繁&#xff0c;下面对这两个命令的用法进行汇总。 apt&#xff08;Advanced Package Tool&#xff09;和 apt-get 是用于在基于 Debian 的 Linux 发行版中进行软件包…...

Unity之ShaderGraph如何实现水面波浪

前言 这几天通过一个水的波浪数学公式,实现了一个波浪效果,感觉成就感满满,下面给大家分享一下 首先先给大家看一下公式; 把公式转为ShaderGraph 第一行公式:waveType = z*-1*Mathf.Cos(wave.WaveAngle/360*2*Mathf.PI)+x*Mathf.Sin(WaveAngle/360*-2*Mathf.PI) 转换…...

无线局域网(WLAN)简单概述

无线局域网 无线局域网概述 无限局域网&#xff08;Wireless Local Area Network&#xff0c;WLAN&#xff09;是一种短距离无线通信组网技术&#xff0c;它是以无线信道为传输媒质构成的计算机网络&#xff0c;通过无线电传播技术来实现在空间传输数据。 WLAN是传输范围在1…...

学习数仓工具 dbt

DBT 是一个有趣的工具&#xff0c;它通过一种结构化的方式定义了数仓中各种表、视图的构建和填充方式。 dbt 面相的对象是数据开发团队&#xff0c;提供了如下几个最有价值的能力&#xff1a; 支持多种数据库通过 select 来定义数据&#xff0c;无需编写 DML构建数据时&#…...

高录用快见刊【最快会后两个月左右见刊】第三届社会科学与人文艺术国际学术会议 (SSHA 2024)

第三届社会科学与人文艺术国际学术会议 (SSHA 2024) 2024 3rd International Conference on Social Sciences and Humanities and Arts *文章投稿均可免费参会 *高录用快见刊【最快会后两个月左右见刊】 重要信息 会议官网&#xff1a;icssha.com 大会时间&#xff1a;202…...

ps制作网站首页面教程/如何优化网站推广

之前谈到免货运&#xff0c;真的很难&#xff0c;你想如果是你在网上购买食品&#xff0c;你一般会购买多少&#xff1f;有的人会说&#xff0c;我会很多&#xff0c;但是你在看看食品的价格是多少&#xff0c;也许你购买了一大袋食品价格也许不会超过100块&#xff0c;但是从重…...

广州品牌网站开发/站长之家域名

/*Name: NYOJ--24--素数距离问题Author: shen_渊 Date: 17/04/17 16:42Description: 原来代码看不下去了o(╯□╰)o */ #include<iostream> using namespace std; int isPrime(int); int main(){ios::sync_with_stdio(false);int T;cin>>T;while(T--){int n;cin&…...

wordpress使用技巧/网销是什么工作好做吗

据悉&#xff0c;2022国际人工智能大会SAIL奖TOP30榜单近来正式出炉。作为国际人工智能大会的最高奖项&#xff0c;SAIL奖&#xff08;Superior AI Leader&#xff0c;卓越人工智能引领者&#xff09;坚持“追求卓越、引领未来”的理念&#xff0c;从全球规模发掘在人工智能范畴…...

网站 公司/朋友圈推广文案

spark能跑Python么&#xff1f;spark是可以跑Python程序的。python编写好的算法&#xff0c;或者扩展库的&#xff0c;比如sklearn都可以在spark上跑。直接使用spark的mllib也是可以的&#xff0c;大部分算法都有。Spark 是一个通用引擎&#xff0c;可用它来完成各种各样的运算…...

网站建设行业税率/一键生成网站

简介 Google 的 gflags 是一套命令行参数处理的开源库。比 getopt 更方便&#xff0c;更功能强大&#xff0c;从 C的库更好的支持 C&#xff08;如 C的 string 类型&#xff09;。 example 源代码先看 example 源代码&#xff0c;然后逐步介绍。 example.cc 1 2 3 4 5 6 7 8 9 …...

gps建站教程视频/好用吗

com.mysql.jdbc.Driver和mysql-connector-java 5搭配使用 com.mysql.cj.jdbc.Driver和mysql-connector-java 6搭配使用 只是这个是6具有的一个新特性&#xff0c;6添加了一个时区的概念...