当前位置: 首页 > news >正文

NeuralCF 模型:神经网络协同过滤模型

实验和完整代码

完整代码实现和jupyter运行:https://github.com/Myolive-Lin/RecSys--deep-learning-recommendation-system/tree/main

引言

NeuralCF 模型由新加坡国立大学研究人员于 2017 年提出,其核心思想在于将传统协同过滤方法与深度学习技术相结合,从而更为有效地捕捉用户与物品之间的复杂交互关系。该模型利用神经网络自动学习用户和物品的低维表示,并通过这些表示实现对用户评分的精准预测。

1. NeuralCF模型简介

NeuralCF 模型融合了矩阵分解与深度学习两种方法的优势,采用基于神经网络的结构来建模用户与物品间的非线性交互。传统矩阵分解方法通过计算用户与物品隐向量的内积来进行评分预测,而 NeuralCF 则利用多层感知机(MLP)对用户与物品隐向量进行联合建模。具体而言,模型首先为每个用户和物品分配低维嵌入向量,然后将这些向量进行拼接(concatenate),再输入到深层神经网络中以提取潜在交互特征,最后通过输出层得到预测评分。

2. NeuralCF的模型架构

NeuralCF 模型的架构主要包括以下关键组件:

  1. 用户与物品嵌入(Embedding)
    与传统矩阵分解方法类似,NeuralCF 为每个用户与物品分配低维嵌入向量,分别表征用户兴趣和物品特征。
  2. 嵌入向量拼接(Concatenation)
    在模型中,用户与物品的嵌入向量被拼接为一个更高维度的向量,作为神经网络的输入。这种拼接不仅保留了各自的特征信息,同时为网络提供了学习复杂交互模式的可能性。
  3. 多层感知机(MLP)
    拼接后的向量经过多个全连接层(MLP)的处理,每一层均采用激活函数(通常为 ReLU)引入非线性变换,以便捕捉用户与物品之间更高阶的特征交互。
  4. 输出层
    多层感知机的输出经过一个线性层转换后,最终得到评分预测。在实际应用中,该预测值可以代表二分类问题(例如点击与否)或回归问题(例如具体评分)

在这里插入图片描述

2.1 数学模型

1. 用户和物品嵌入(Embedding)

NeuralCF模型首先为每个用户和每个物品分配一个低维度的隐向量。假设有 M M M 个用户和 N N N 个物品,用户 u u u 的隐向量为 p u ∈ R d \mathbf{p_u} \in \mathbb{R}^d puRd,物品 i i i 的隐向量为 q i ∈ R d \mathbf{q_i} \in \mathbb{R}^d qiRd,其中 d d d 是隐向量的维度。

2. 嵌入向量的拼接

传统的矩阵分解方法直接计算用户和物品隐向量的内积来预测评分,而NeuralCF通过将用户和物品的隐向量拼接(concatenate)在一起,构成一个新的向量:

z = concat ( p u , q i ) ∈ R 2 d \mathbf{z} = \text{concat}(\mathbf{p_u}, \mathbf{q_i}) \in \mathbb{R}^{2d} z=concat(pu,qi)R2d

3. 多层感知机(MLP)

将拼接后的向量z输入到包含L层的多层感知机中,每一层的变换公式为:

h l = ReLU ( W l h l − 1 + b l ) , l = 1 , 2 , … , L \mathbf{h_l} = \text{ReLU}(\mathbf{W_l} \mathbf{h_{l-1}} + \mathbf{b_l}), \quad l = 1, 2, \dots, L hl=ReLU(Wlhl1+bl),l=1,2,,L

其中,初始输入为 h 0 = z \mathbf{h_0} = \mathbf{z} h0=z W l \mathbf{W_l} Wl 和偏置 b l \mathbf{b_l} bl分别为第 l层的权重和偏置参数

4. 输出层

经过多层感知机后,最终输出层采用线性变换:

r u i ^ = σ ( W L h L + b L ) \hat{r_{ui}} = \sigma(\mathbf{W_L} \mathbf{h_L} + \mathbf{b_L}) rui^=σ(WLhL+bL)

其中, σ \sigma σ​ 表示sigmoid激活函数,输出值位于 0 与 1 之间,适用于二分类任务;对于回归任务,则可去除 Sigmoid 激活。

5. 损失函数

针对不同任务,NeuralCF 可采用不同的损失函数:

  • 回归问题:通常使用均方误差(MSE):

L = 1 N ∑ ( u , i ) ( r u i − r u i ^ ) 2 \mathcal{L} = \frac{1}{N} \sum_{(u,i)} \left( r_{ui} - \hat{r_{ui}} \right)^2 L=N1(u,i)(ruirui^)2

  • 二分类问题, 损失函数为交叉熵:

L = − 1 N ∑ ( u , i ) ( r u i log ⁡ ( r u i ^ ) + ( 1 − r u i ) log ⁡ ( 1 − r u i ^ ) ) \mathcal{L} = -\frac{1}{N} \sum_{(u,i)} \left( r_{ui} \log(\hat{r_{ui}}) + (1 - r_{ui}) \log(1 - \hat{r_{ui}}) \right) L=N1(u,i)(ruilog(rui^)+(1rui)log(1rui^))

3 NeuralCF混合模型

为进一步提升特征组合能力和非线性表达能力,NeuralCF 在原有架构基础上引入了广义矩阵分解(Generalized Matrix Factorization, GMF)模块。需要指出的是,GMF 与 MLP 部分分别采用独立的嵌入层,这一设计有效提升了模型的灵活性和表现力。

在这里插入图片描述

3.2 GMF广义矩阵分解

广义矩阵分解模型扩展了传统矩阵分解方法,通过引入不同的用户与物品交互方式来建模。与经典矩阵分解方法通过内积计算用户与物品之间的相似性不同,GMF 采用元素积(Hadamard 乘积)来刻画二者间的交互关系:

ϕ 1 ( p u , q i ) = p u ⊙ q i \phi_1(p_u, q_i) = p_u \odot q_i ϕ1(pu,qi)=puqi

其中, p u p_u pu q i q_i qi 是用户和物品的嵌入向量, ⊙ \odot 是元素积操作。

3.4 GMF和MLP的融合

为了解决共享嵌入层的限制,本方法提出了让GMF和MLP分别学习独立的嵌入层,并通过连接它们的最后一层隐藏层进行融合。具体而言,GMF和MLP的输出通过以下公式进行联合建模:

  1. GMF 部分

ϕ G M F = p u G ⊙ q i G , \phi_{GMF} = p_u^G \odot q_i^G, ϕGMF=puGqiG,

其中, p u G p_u^G puG q i G q_i^G qiG 分别表示GMF部分的用户和物品嵌入向量。

  1. MLP 部分

通过多层非线性变换,MLP 部分的用户与物品嵌入向量先进行拼接,再逐层传递,形式上可描述为:

ϕ M L P = a L ( W L T ( a L − 1 ( . . . a 2 ( W 2 T [ p u M q i M ] + b 2 ) . . . ) ) + b L ) , \phi_{MLP} = a_L(W_L^T (a_{L-1}(...a_2(W_2^T [p_u^M \quad q_i^M] + b_2)...)) + b_L), ϕMLP=aL(WLT(aL1(...a2(W2T[puMqiM]+b2)...))+bL),

其中, p u M p_u^M puM q i M q_i^M qiM 分别表示MLP部分的用户和物品嵌入向量; a L ( ⋅ ) a_L(\cdot) aL() 是激活函数, W L W_L WL b L b_L bL 是MLP的权重和偏置参数。

  1. 融合与预测

最后,GMF和MLP的输出通过全连接层进行融合并计算最终预测:

y ^ u i = σ ( h T [ ϕ G M F ϕ M L P ] ) \hat{y}_{ui} = \sigma(h^T [\phi_{GMF} \quad \phi_{MLP}]) y^ui=σ(hT[ϕGMFϕMLP])

其中, σ ( ⋅ ) \sigma(\cdot) σ() 是Sigmoid激活函数, h T h^T hT 是融合层的权重。

该融合策略使得模型可以分别从不同角度捕捉用户与物品的特征,并通过联合表示进一步提升预测准确性与模型灵活性。

4.代码实现

以下代码段展示了基于 PyTorch 的 NeuralCF 模型实现,包括模型配置、数据集构建与模型定义。

模型配置与数据集构建

class Config:num_users = 1000num_items = 2000embed_dim = 16hidden_dims = [64, 32, 16]batch_size = 32lr = 0.001num_epochs = 30# 自定义数据集类
class CFDataset(Dataset):def __init__(self, num_samples=10000):# 生成示例数据(实际使用时替换为真实数据)self.user_ids = np.random.randint(0, Config.num_users, size=num_samples)self.item_ids = np.random.randint(0, Config.num_items, size=num_samples)self.labels = np.random.randint(0, 2, size=num_samples).astype(np.float32)def __len__(self):return len(self.user_ids)def __getitem__(self, idx):return (torch.tensor(self.user_ids[idx], dtype=torch.long),torch.tensor(self.item_ids[idx], dtype=torch.long),torch.tensor(self.labels[idx], dtype=torch.float))

NeuralCF 模型实现

class NeuralCF(nn.Module):def __init__(self, Config):super().__init__()# 定义用户和物品的隐向量self.user_embed_gmf = nn.Embedding(Config.num_users, Config.embed_dim)  # GMF用户隐向量self.item_embed_gmf = nn.Embedding(Config.num_items, Config.embed_dim)  # GMF物品隐向量self.user_embed_mlp = nn.Embedding(Config.num_users, Config.embed_dim)  # MLP用户隐向量self.item_embed_mlp = nn.Embedding(Config.num_items, Config.embed_dim)  # MLP物品隐向量# MLP层input_dim = 2 * Config.embed_dimmlp_layers = []for output_dim in Config.hidden_dims:mlp_layers.append(nn.Linear(input_dim, output_dim))mlp_layers.append(nn.ReLU())input_dim = output_dimself.mlp = nn.Sequential(*mlp_layers)# 输出层total_dim = Config.embed_dim + Config.hidden_dims[-1]  # GMF + MLP层维度self.fc = nn.Sequential(nn.Linear(total_dim, 1),nn.Sigmoid())def forward(self, user_ids, item_ids):# 获取用户和物品的隐向量user_emb_gmf = self.user_embed_gmf(user_ids)item_emb_gmf = self.item_embed_gmf(item_ids)user_emb_mlp = self.user_embed_mlp(user_ids)item_emb_mlp = self.item_embed_mlp(item_ids)# GMF: 逐元素乘积gmf = user_emb_gmf * item_emb_gmf# MLP: 拼接并通过多层感知机concat_emb = torch.cat([user_emb_mlp, item_emb_mlp], dim=1)mlp = self.mlp(concat_emb)# 拼接GMF和MLP的结果neuralcf_emb = torch.cat([mlp, gmf], dim=1)# 输出层output = self.fc(neuralcf_emb).squeeze()return output

5. NeuralCF的优势

NeuralCF 模型通过引入深度神经网络,有效突破了传统矩阵分解方法的线性限制,能够捕捉用户与物品之间的复杂非线性交互。其主要优势包括:

  • 非线性建模能力:利用多层神经网络对用户与物品的隐向量进行非线性组合,充分发掘潜在高阶交互信息。
  • 架构灵活性:模型结构可以根据实际问题需求灵活调整隐藏层层数和神经元数量,适应不同数据规模与复杂度。
  • 优异的泛化性能:深度学习框架使得 NeuralCF 在处理稀疏数据时能够更好地防止过拟合,提升了模型的泛化能力。

Reference

[1]. He, X., Liao, L., Zhang, H., Nie, L., Hu, X., & Chua, T.-S. (2017). Neural Collaborative Filtering. In Proceedings of the 26th International Conference on World Wide Web (WWW ’17), 173–182. ACM.

[2]. 王喆 《深度学习推荐系统》

相关文章:

NeuralCF 模型:神经网络协同过滤模型

实验和完整代码 完整代码实现和jupyter运行:https://github.com/Myolive-Lin/RecSys--deep-learning-recommendation-system/tree/main 引言 NeuralCF 模型由新加坡国立大学研究人员于 2017 年提出,其核心思想在于将传统协同过滤方法与深度学习技术相结…...

【前端】【Ts】【知识点总结】TypeScript知识总结

一、总体概述 TypeScript 是 JavaScript 的超集,主要通过静态类型检查和丰富的类型系统来提高代码的健壮性和可维护性。它涵盖了从基础数据类型到高级类型、从函数与对象的类型定义到类、接口、泛型、模块化及装饰器等众多知识点。掌握这些内容有助于编写更清晰、结…...

JAVA架构师进阶之路

JAVA架构师进阶之路 前言 苦于网络上充斥的各种java知识,多半是互相抄袭,导致很多后来者在学习java知识中味同嚼蜡,本人闲暇之余整理了进阶成为java架构师所必须掌握的核心知识点,后续会不断扩充。 废话少说,直接上正…...

掌握@PostConstruct与@PreDestroy,优化Spring Bean的初始化和销毁

在Spring中,PostConstruct和PreDestroy注解就像是对象的“入职”和“离职”仪式。 1. PostConstruct注解:这个注解标记的方法就像是员工入职后的“岗前培训”。当一个对象(比如一个Bean)被Spring容器创建并注入依赖后,…...

Java设计模式:行为型模式→状态模式

Java 状态模式详解 1. 定义 状态模式(State Pattern)是一种行为型设计模式,它允许对象在内部状态改变时改变其行为。状态模式通过将状态需要的行为封装在不同的状态类中,实现对象行为的动态改变。该模式的核心思想是分离不同状态…...

景联文科技:专业数据采集标注公司 ,助力企业提升算法精度!

随着人工智能技术加速落地,高质量数据已成为驱动AI模型训练与优化的核心资源。据统计,全球AI数据服务市场规模预计2025年突破200亿美元,其中智能家居、智慧交通、医疗健康等数据需求占比超60%。作为国内领先的AI数据服务商,景联文…...

ES面试题

1、Elasticsearch的基本构成: (1)index 索引: 索引类似于mysql 中的数据库,Elasticesearch 中的索引是存在数据的地方,包含了一堆有相似结构的文档数据。 (2)type 类型&#xff1a…...

LabVIEW2025中文版软件安装包、工具包、安装教程下载

下载链接:LabVIEW及工具包大全-三易电子工作室http://blog.eeecontrol.com/labview6666 《LabVIEW2025安装图文教程》 1、解压后,双击install.exe安装 2、选中“我接受上述2条许可协议”,点击下一步 3、点击下一步,安装NI Packa…...

算法与数据结构(合并K个升序链表)

思路 有了合并两个链表的基础后,这个的一种方法就是可以进行顺序合并,我们可以先写一个函数用来合并两个链表,再在合并K个链表的的函数中循环调用它。 解题过程 解析这个函数 首先,可以先判断,如果a为空&#xff0c…...

洛谷 P4552 [Poetize6] IncDec Sequence C语言

P4552 [Poetize6] IncDec Sequence - 洛谷 | 计算机科学教育新生态 题目描述 给定一个长度为 n 的数列 a1​,a2​,…,an​,每次可以选择一个区间 [l,r],使这个区间内的数都加 1 或者都减 1。 请问至少需要多少次操作才能使数列中的所有数都一样&#…...

保姆级教程Docker部署Zookeeper官方镜像

目录 1、安装Docker及可视化工具 2、创建挂载目录 3、运行Zookeeper容器 4、Compose运行Zookeeper容器 5、查看Zookeeper运行状态 6、验证Zookeeper是否正常运行 1、安装Docker及可视化工具 Docker及可视化工具的安装可参考:Ubuntu上安装 Docker及可视化管理…...

javaEE-6.网络原理-http

目录 什么是http? http的工作原理: 抓包工具 fiddler的使用 HTTP请求数据: 1.首行:​编辑 2.请求头(header) 3.空行: 4.正文(body) HTTP响应数据 1.首行:​编辑 2.响应头 3.空行: 4.响应正文…...

【戒抖音系列】短视频戒除-1-对推荐算法进行干扰

如今推荐算法已经渗透到人们生活的方方面面,尤其是抖音等短视频核心就是推荐算法。 【短视频的危害】 1> 会让人变笨,慢慢让人丧失注意力与专注力 2> 让人丧失阅读长文的能力 3> 让人沉浸在一个又一个快感与嗨点当中。当我们刷短视频时&#x…...

9.建造者模式 (Builder Pattern)

定义 建造者模式(Builder Pattern)是一种创建型设计模式,旨在将复杂对象的构建过程与它的表示分离,使得同样的构建过程可以创建不同的表示。该模式的核心思想是通过一步步地构建一个复杂的对象,每个步骤独立且可扩展&…...

OpenCV:特征检测总结

目录 一、什么是特征检测? 二、OpenCV 中的常见特征检测方法 1. Harris 角点检测 2. Shi-Tomasi 角点检测 3. Canny 边缘检测 4. SIFT(尺度不变特征变换) 5. ORB 三、特征检测的应用场景 1. 图像匹配 2. 运动检测 3. 自动驾驶 4.…...

Clion开发STM32时使用stlink下载程序与Debug调试

一、下载程序 先创建一个文件夹: 命名:stlink.cfg 写入以下代码: # choose st-link/j-link/dap-link etc. #adapter driver cmsis-dap #transport select swdsource [find interface/stlink.cfg]transport select hla_swdsource [find target/stm32f4x.…...

电脑开机键一闪一闪打不开

家人们谁懂啊!本来打算愉快地开启游戏时光,或者高效处理工作任务,结果按下电脑开机键后,它就只是一闪一闪的,怎么都打不开。相信不少朋友都遭遇过这种令人崩溃的场景,满心的期待瞬间化为焦急与无奈。电脑在…...

深度学习 Pytorch 基础网络手动搭建与快速实现

为了方便后续练习的展开,我们尝试自己创建一个数据生成器,用于自主生成一些符合某些条件、具备某些特性的数据集。 导入相关的包 # 随机模块 import random# 绘图模块 import matplotlib as mpl import matplotlib.pyplot as plt# 导入numpy import nu…...

Sqli-labs靶场实录(一):Basic Challenges

sqli-labs靶场实录:Basic Challenges sql手注基本流程Less-11.1探测注入点1.2判断字段数1.3判断回显位1.4提取数据库基本信息1.5拖取敏感数据 Less-2Less-3Less-4Less5爆表爆列名 Less6爆库爆表爆列名 Less7猜解数据库长度逐字符爆破数据库名 Less8爆库 Less9爆库 Less10Less11…...

2024最新版Node.js详细安装教程(含npm配置淘宝最新镜像地址)

一:Node.js安装 浏览器中搜索Nodejs,或直接用网址:Node.js — 在任何地方运行 JavaScript 建议此处下载长期支持版本(红框内): 开始下载,完成后打开文件: 进入安装界面,在此处勾选,再点击n…...

Go 语言接口详解

Go 语言接口详解 核心概念 接口定义 在 Go 语言中,接口是一种抽象类型,它定义了一组方法的集合: // 定义接口 type Shape interface {Area() float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的: // 矩形结构体…...

1.3 VSCode安装与环境配置

进入网址Visual Studio Code - Code Editing. Redefined下载.deb文件,然后打开终端,进入下载文件夹,键入命令 sudo dpkg -i code_1.100.3-1748872405_amd64.deb 在终端键入命令code即启动vscode 需要安装插件列表 1.Chinese简化 2.ros …...

JVM暂停(Stop-The-World,STW)的原因分类及对应排查方案

JVM暂停(Stop-The-World,STW)的完整原因分类及对应排查方案,结合JVM运行机制和常见故障场景整理而成: 一、GC相关暂停​​ 1. ​​安全点(Safepoint)阻塞​​ ​​现象​​:JVM暂停但无GC日志,日志显示No GCs detected。​​原因​​:JVM等待所有线程进入安全点(如…...

基于IDIG-GAN的小样本电机轴承故障诊断

目录 🔍 核心问题 一、IDIG-GAN模型原理 1. 整体架构 2. 核心创新点 (1) ​梯度归一化(Gradient Normalization)​​ (2) ​判别器梯度间隙正则化(Discriminator Gradient Gap Regularization)​​ (3) ​自注意力机制(Self-Attention)​​ 3. 完整损失函数 二…...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

怎么让Comfyui导出的图像不包含工作流信息,

为了数据安全,让Comfyui导出的图像不包含工作流信息,导出的图像就不会拖到comfyui中加载出来工作流。 ComfyUI的目录下node.py 直接移除 pnginfo(推荐)​​ 在 save_images 方法中,​​删除或注释掉所有与 metadata …...

从“安全密码”到测试体系:Gitee Test 赋能关键领域软件质量保障

关键领域软件测试的"安全密码":Gitee Test如何破解行业痛点 在数字化浪潮席卷全球的今天,软件系统已成为国家关键领域的"神经中枢"。从国防军工到能源电力,从金融交易到交通管控,这些关乎国计民生的关键领域…...

人工智能 - 在Dify、Coze、n8n、FastGPT和RAGFlow之间做出技术选型

在Dify、Coze、n8n、FastGPT和RAGFlow之间做出技术选型。这些平台各有侧重,适用场景差异显著。下面我将从核心功能定位、典型应用场景、真实体验痛点、选型决策关键点进行拆解,并提供具体场景下的推荐方案。 一、核心功能定位速览 平台核心定位技术栈亮…...

用js实现常见排序算法

以下是几种常见排序算法的 JS实现,包括选择排序、冒泡排序、插入排序、快速排序和归并排序,以及每种算法的特点和复杂度分析 1. 选择排序(Selection Sort) 核心思想:每次从未排序部分选择最小元素,与未排…...

在ubuntu等linux系统上申请https证书

使用 Certbot 自动申请 安装 Certbot Certbot 是 Let’s Encrypt 官方推荐的自动化工具,支持多种操作系统和服务器环境。 在 Ubuntu/Debian 上: sudo apt update sudo apt install certbot申请证书 纯手动方式(不自动配置)&…...