找回密码
 立即注册
首页 业界区 业界 ESMM学习笔记:如何解决CVR预估中的样本选择偏差与数据 ...

ESMM学习笔记:如何解决CVR预估中的样本选择偏差与数据稀疏难题

忿媚饱 前天 15:07
ESMM模型精解:如何解决CVR预估中的样本选择偏差与数据稀疏难题

引言

在现代推荐系统与计算广告中,对点击后转化率(Post-Click Conversion Rate, CVR)的精准预估是优化平台收益与用户体验的核心环节。然而,传统的 CVR 预估模型在工业实践中普遍面临两大技术瓶颈:样本选择偏差(Sample Selection Bias, SSB)与数据稀疏性(Data Sparsity, DS)。这两个问题严重制约了模型的泛化能力与学习效率。
为了应对这些挑战,阿里巴巴的研究团队提出了“全体空间多任务模型”(Entire Space Multi-Task Model, ESMM),通过一种巧妙的建模思路,从根本上解决了上述难题。本文将从模型背景、核心原理、梯度更新机制、优缺点及适用场景等角度,对 ESMM 模型进行系统性剖析。
一、 CVR 预估的核心挑战:问题的根源

在深入了解 ESMM 之前,我们必须先理解它所要解决的问题根源。

  • 1.1 样本选择偏差 (SSB)
    SSB 问题源于模型训练与预测阶段的数据分布不一致 。具体来说,CVR 模型的目标是预估用户在“点击”行为发生后产生“转化”的概率。因此,模型的训练样本天然地被限定在已发生点击行为的样本空间内。然而,在实际的线上推理阶段,模型需要对所有曝光的物品进行 CVR 预估,以便与点击率(CTR)等指标结合进行排序。
    如下图所示,训练空间(Click Space)仅仅是推理空间(Impression Space)的一个很小的子集 。在这种有偏的样本上训练出的模型,直接应用于全体样本时,其泛化能力会受到严重损害。
    1.png

  • 1.2 数据稀疏性 (DS)
    数据稀疏性是另一个严峻的挑战。在真实的业务场景中,用户从曝光到点击、再到转化的行为是一个发生率逐级递减的“漏斗”。用于 CVR 任务训练的“点击且转化”的样本,数量远远少于 CTR 任务的“点击”样本,更不用说海量的“曝光”样本了。例如,在淘宝的生产数据集中,用于 CVR 任务的样本量仅为 CTR 任务的 4% 左右。这种极端的正样本稀疏性,使得模型,尤其是需要海量数据进行学习的深度网络,难以充分学习到有效的特征表示,拟合过程非常困难。
二、 ESMM 模型核心原理:巧妙的全局建模

面对上述挑战,ESMM 并没有在模型结构上进行复杂的改造,而是从问题的定义本身出发,提出了一种全新的建模范式。
2.png


  • 2.1 核心思想概述
    ESMM 的核心思想是:不再直接建模有偏的 pCVR,而是通过利用用户行为的“曝光 → 点击 → 转化”这一序列依赖关系,在完整的、无偏的曝光样本空间上进行多任务联合建模
  • 2.2 关键公式与模型架构

    • 公式拆解
      ESMM 框架的理论基石是以下概率链式法则公式:

      \[p(y=1, z=1 | x) = p(y=1 | x) \times p(z=1 | y=1, x)\]
      其中,x 代表曝光的特征,y=1 代表点击事件,z=1 代表转化事件。这个公式可以被解读为:

      \[pCTCVR = pCTR \times pCVR\]

      • pCTR (Post-View Click-Through Rate): 曝光后点击率,即 p(y=1|x)。
      • pCVR (Post-Click Conversion Rate): 点击后转化率,即 p(z=1|y=1,x)。这是我们最终想要求解的目标。
      • pCTCVR (Post-View Click-Through & Conversion Rate): 曝光后点击且转化率,即 p(y=1,z=1|x)。

    • 架构解析
      基于上述公式,ESMM 设计了一个由两个子网络(任务塔)构成的多任务学习架构。

      • CTR Tower: 一个独立的子网络,用于根据输入特征 x 预测 pCTR。
      • CVR Tower: 另一个独立的子网络,用于根据相同的输入特征 x 预测 pCVR。
      • 共享 Embedding 层: 两个任务塔共享底层的特征嵌入层(Embedding Layer)。这是缓解数据稀疏性问题的关键。
      • *最终输出**: CTR 塔和 CVR 塔的输出 pCTR 和 pCVR 相乘,得到 pCTCVR 的预测值。


  • 2.3 手段总结
    ESMM 解决两大难题的手段可以清晰地归纳为:

    • 针对 SSB:模型不直接使用有偏的点击样本来监督 CVR 塔。而是通过监督在全体曝光样本上都具有明确标签的 pCTR 和 pCTCVR 两个任务,让 CVR 塔作为中间变量被隐式地学习。由于 pCTR 和 pCTCVR 都是在无偏的全空间上建模的,因此间接得到的 pCVR 也自然地适用于全空间,从根本上解决了样本选择偏差问题。
    • 针对 DS:通过让 CVR 任务塔与拥有海量训练样本的 CTR 任务塔共享底层的 Embedding 表示,实现了参数的迁移学习。CVR 任务可以从 CTR 任务中“借力”,学习到更鲁棒的特征表示,从而有效缓解了自身训练数据稀疏的问题。

三、 ESMM 的参数更新机制:梯度的流动

理解 ESMM 的参数更新机制,有助于我们深入其联合训练的本质。

  • 3.1 联合损失函数
    ESMM 的总损失函数由 CTR 任务和 CTCVR 任务的损失加权构成,不包含 CVR 任务的直接损失项。

    \[L(\theta_{ctr}, \theta_{cvr}) = \sum_{i=1}^{N} L_{ctr}(y_i, f(x_i; \theta_{ctr})) + \sum_{i=1}^{N} L_{ctcvr}(y_i\&z_i, f(x_i; \theta_{ctr}) \times f(x_i; \theta_{cvr}))\]
    其中 θ_ctr 和 θ_cvr 分别是 CTR 和 CVR 网络的参数。
3.png


  • 3.2 梯度流向分析
    在反向传播过程中,梯度从两个损失源头出发:

    • 来自 L_ctr 的梯度路径:这部分梯度只与 pCTR 相关。因此,它会反向传播更新 CTR 塔 的参数,并继续向下更新 共享 Embedding 层 的参数。
    • 来自 L_ctcvr 的梯度路径:这部分梯度是联合训练的关键。根据链式求导法则,对于乘法操作 pCTCVR = pCTR * pCVR,梯度会“兵分两路”:

      • 一路流向 pCVR,更新 CVR 塔 的参数,并继续向下更新 共享 Embedding 层
      • 另一路流向 pCTR,更新 CTR 塔 的参数,并继续向下更新 共享 Embedding 层


  • 3.3 参数更新归属总结
    综合两条路径,各个模块参数的最终梯度来源如下:

    • CVR 塔参数仅由 L_ctcvr 更新。
    • CTR 塔参数:由 L_ctr 和 L_ctcvr 共同更新。
    • 共享 Embedding 参数:由 L_ctr 和 L_ctcvr 共同更新。

四、 优缺点及适用场景分析


  • 4.1 主要优点

    • 根本解决 SSB:通过在全空间上建模,从理论上彻底消除了传统 CVR 预估的样本选择偏差。
    • 有效缓解 DS:利用迁移学习的思想,共享 Embedding,极大地缓解了 CVR 任务数据稀疏的问题。
    • 结构优雅:模型设计巧妙,乘法形式不仅避免了除法可能带来的数值不稳定问题,还能保证 pCVR 的值在 [0,1] 范围内。

  • 4.2 潜在局限性

    • 强依赖于行为序列:模型的设计强依赖于“曝光 → 点击 → 转化”这样清晰、固定的序列依赖关系。如果任务之间是平行的,或者不遵循这种漏斗模式,ESMM 则不适用。
    • 任务关系固定:模型假设了 pCTR 和 pCVR 之间是简单的乘法关系,这可能无法捕捉现实世界中更复杂的任务关联。

  • 4.3 适用场景
    ESMM 极其适用于具有明显序列依赖关系的多阶段用户行为预测场景。

    • 电商平台:经典的“曝光 → 点击 → 购买”转化漏斗。
    • 信息流推荐:“曝光 → 点击 → 关注/收藏/分享”等多步转化行为。
    • 在线广告:“广告展示 → 点击 → 表单提交/App下载”等转化链路。

代码实现
  1. import torch
  2. import torch.nn as nn
  3. class ESMM(nn.Module):
  4.     """
  5.     ESMM: Entire Space Multi-Task Model PyTorch Implementation.
  6.     This class implements the ESMM model as described in the paper:
  7.     "Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate"
  8.     核心思想 Core Ideas:
  9.     1.  **全空间建模 (Entire Space Modeling)**: 通过 `pCTCVR = pCTR * pCVR` 的关系,在全体曝光样本上联合训练CTR和CTCVR任务,从而无偏地学习CVR
  10.     2.  **特征表示共享 (Shared Feature Representation)**: CVR任务和CTR任务共享底层的Embedding层,有效缓解CVR任务因样本稀疏导致的学习不充分问题
  11.     Args:
  12.         feature_columns (dict): 一个描述输入稀疏特征的字典,key为特征名,value为该特征的词表大小 (vocabulary size)。
  13.                                 Example: {'user_id': 10000, 'item_id': 5000}
  14.         embedding_dim (int): Embedding向量的维度。
  15.         tower_dims (list): 一个定义两个任务塔 (Tower) 隐藏层维度和结构的列表。
  16.                            Example: [256, 128]
  17.     """
  18.     def __init__(self, feature_columns, embedding_dim, tower_dims):
  19.         super(ESMM, self).__init__()
  20.         self.feature_columns = feature_columns
  21.         self.embedding_dim = embedding_dim
  22.         # --- 模块一: 共享的 Embedding 层 (Shared Embedding Layer) ---
  23.         # 这是ESMM解决数据稀疏性(Data Sparsity)问题的关键。
  24.         # CVR和CTR任务共享同一套Embedding参数,使得CVR任务可以从更丰富的CTR样本中学习特征表示
  25.         self.embedding_layer = nn.ModuleDict({
  26.             feat: nn.Embedding(vocab_size, self.embedding_dim)
  27.             for feat, vocab_size in self.feature_columns.items()
  28.         })
  29.         # 计算两个任务塔(MLP)的输入维度
  30.         # 输入维度 = 特征数量 * 每个特征的Embedding维度
  31.         self.tower_input_dim = len(self.feature_columns) * self.embedding_dim
  32.         # --- 模块二: CTR 任务塔 (CTR Tower) ---
  33.         # 这个子网络负责预测点击率 pCTR。
  34.         self.ctr_tower = self._build_mlp_tower(self.tower_input_dim, tower_dims)
  35.         # --- 模块三: CVR 任务塔 (CVR Tower) ---
  36.         # 这个子网络负责预测点击后转化率 pCVR。
  37.         # 注意它的输入与CTR Tower完全相同,都是来自共享的Embedding层。
  38.         self.cvr_tower = self._build_mlp_tower(self.tower_input_dim, tower_dims)
  39.     def _build_mlp_tower(self, input_dim, tower_dims):
  40.         """一个辅助函数,用于构建MLP网络(即任务塔)。"""
  41.         layers = []
  42.         for hidden_dim in tower_dims:
  43.             layers.append(nn.Linear(input_dim, hidden_dim))
  44.             layers.append(nn.ReLU())
  45.             input_dim = hidden_dim
  46.         # 输出层,得到一个logit
  47.         layers.append(nn.Linear(input_dim, 1))
  48.         return nn.Sequential(*layers)
  49.     def forward(self, x):
  50.         """
  51.         ESMM的前向传播逻辑。
  52.         Args:
  53.             x (dict): 输入的特征字典,key为特征名,value为特征值的tensor。
  54.                       Example: {'user_id': tensor([[1], [2]]), 'item_id': tensor([[10], [12]])}
  55.         Returns:
  56.             tuple: 包含三个预测值的元组 (pCTR, pCVR, pCTCVR)。
  57.                    - pCTR: 预测的点击率 (Predicted CTR)
  58.                    - pCVR: 预测的点击后转化率 (Predicted CVR)
  59.                    - pCTCVR: 预测的点击且转化率 (Predicted CTCVR)
  60.         """
  61.         # --- 流程 1: 特征 Embedding (共享) ---
  62.         # 将输入的稀疏特征ID通过共享的Embedding层转换为稠密向量
  63.         embedded_features = [
  64.             self.embedding_layer[feat](x[feat]) for feat in self.feature_columns
  65.         ]
  66.         # 将所有特征的Embedding向量拼接成一个长向量,作为两个任务塔的共同输入
  67.         concatenated_embeddings = torch.flatten(torch.cat(embedded_features, dim=1), start_dim=1)
  68.         # --- 流程 2: CTR 任务预测 ---
  69.         # 将拼接后的向量输入CTR塔,得到CTR的logit
  70.         ctr_logit = self.ctr_tower(concatenated_embeddings)
  71.         # 通过Sigmoid激活函数得到预测的点击率 pCTR
  72.         pCTR = torch.sigmoid(ctr_logit)
  73.         # --- 流程 3: CVR 任务预测 ---
  74.         # 将【相同】的拼接向量输入CVR塔,得到CVR的logit
  75.         cvr_logit = self.cvr_tower(concatenated_embeddings)
  76.         # 通过Sigmoid激活函数得到预测的点击后转化率 pCVR
  77.         pCVR = torch.sigmoid(cvr_logit)
  78.         # --- 流程 4: 计算 pCTCVR ---
  79.         # 这是ESMM解决样本选择偏差(SSB)问题的核心步骤。
  80.         # 通过pCTR和pCVR的乘积,将CVR的预测从“点击空间”隐式地转换到“全曝光空间”
  81.         pCTCVR = pCTR * pCVR
  82.         return pCTR, pCVR, pCTCVR
复制代码
总结

ESMM 模型为 CVR 预估领域提供了一个里程碑式的解决方案。它最大的贡献在于,没有将思路局限在设计更复杂的网络结构上,而是通过重构学习目标和建模空间,从根本上规避了 SSB 和 DS 这两个长期存在的业界难题。其优雅的设计、显著的效果以及在淘宝等大规模工业场景中的成功落地,使其成为推荐和广告领域从业者必学的经典模型之一。

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册