社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

Nature主刊文章分享-可应用于临床结构化小型数据的深度学习框架精读笔记(一)

灵活胖子的科研进步之路 • 2 月前 • 96 次点击  
文章首页
文章首页

全文地址:https://pmc.ncbi.nlm.nih.gov/articles/PMC11711098/

代码地址:https://priorlabs.ai/tabpfn-nature/

数据地址:https://zenodo.org/records/13981285

思维导图
思维导图

使用表格基础模型对小数据进行精确预测

表格数据,即以行和列组织的电子表格,在从生物医学到粒子物理、经济学和气候科学等科学领域中无处不在。基于其余列填充标签列缺失值这一基本预测任务对于生物医学风险模型、药物研发和材料科学等各种应用至关重要。 尽管深度学习已经彻底改变了从原始数据中学习的方式,并带来了众多备受瞩目的成功案例,但在过去 20 年中,梯度提升决策树- gradient-boosted decision trees-在表格数据处理方面占据主导地位。在此,我们介绍表格先验数据拟合网络(TabPFN),这是一种表格基础模型,在样本数量多达 10,000 的数据集上大幅超越此前所有方法,且训练时间大幅缩短。在 2.8 秒内,TabPFN 在分类任务中的表现优于经过 4 小时调优的最强基线集成模型。作为一种基于生成式变换器的基础模型,该模型还允许进行微调、数据生成、密度估计和学习可重复使用的嵌入。TabPFN 是一种在数百万个合成数据集上学习得到的学习算法,展示了这种方法在算法开发方面的强大能力。通过提高不同领域的建模能力,TabPFN 有潜力加速科学发现并改善各个领域的重要决策。

在人工智能的历史进程中,手动创建的算法组件已被性能更优的端到端学习组件所取代。计算机视觉中手工设计的特征,如 SIFT(尺度不变特征变换)和 HOG(方向梯度直方图),已被学习型卷积所替代;自然语言处理中基于语法的方法已被学习型变换器所取代;游戏中定制的开局和终局库设计已被端到端学习策略所超越。 在此,我们将这种端到端学习扩展到无处不在的表格数据领域。

深度学习方法在处理表格数据方面一直存在困难,这是由于数据集之间的异质性以及原始数据本身的异质性:表格包含列(也称为特征),这些列具有各种规模和类型(布尔型、分类型、有序型、整数型、浮点型)不平衡或缺失的数据、不重要的特征、异常值等等。这使得非深度学习方法,如基于树的模型,成为迄今为止最有力的竞争者。

然而,这些传统机器学习模型存在一些缺点。如果不进行大量修改,它们在进行分布外预测时表现不佳,并且在知识从一个数据集转移到另一个数据集方面也表现较差。最后,它们难以与神经网络相结合,因为它们不传播梯度。

作为一种补救措施,我们引入了 TabPFN,这是一种适用于中小规模表格数据的基础模型。这种新的有监督表格学习方法可以应用于任何中小规模数据集,并且在包含多达 10,000 个样本和 500 个特征的数据集上表现出卓越的性能。在单次前向传播中,TabPFN 在我们的基准测试上显著优于最先进的基线方法,包括梯度提升决策树,即使这些基线方法被允许进行 4 小时的调优,其在分类任务上的速度提升达到 5140 倍,在回归任务上达到 3000 倍。最后,我们展示了 TabPFN 的各种基础模型特性,包括微调、生成能力和密度估计。

Principled in-context learning-原则性的上下文学习

TabPFN利用上下文学习(ICL),即促使大型语言模型取得惊人性能的相同机制,来生成一种经过充分学习的强大表格预测算法。虽然上下文学习(ICL)最初是在大型语言模型中被发现的,但近期的研究表明,transformers可以通过 ICL 学习简单算法,如逻辑回归。先验数据拟合网络(PFNs)- Prior-data Fitted Net- works 已经证明,即使是复杂算法,如高斯过程和贝叶斯神经网络,也可以用 ICL 来近似。ICL 使我们能够学习更广泛的可能算法空间,包括那些不存在闭式解的情况。


Transformers是一种深度学习架构,在自然语言处理及其他领域有重要应用。

  • 核心机制:基于注意力机制运作。与传统的循环神经网络(RNN)或卷积神经网络(CNN)不同,它不依赖于序列的顺序处理或固定的局部感受野。注意力机制能让模型动态地关注输入序列的不同部分,从而更好地捕捉长距离依赖关系。例如在文本处理中,它可以根据上下文语义,为每个单词分配不同的注意力权重,以确定该单词与其他单词的关联程度。
  • 结构特点:通常由多个编码器和解码器层组成。编码器负责对输入进行编码,将其转换为一种包含丰富语义信息的表示形式;解码器则根据编码器的输出和其他相关信息生成目标序列。在机器翻译任务中,编码器处理源语言文本,解码器生成目标语言文本。每层内部包含多头注意力机制、前馈神经网络等组件,多头注意力机制通过多个并行的注意力头从不同角度捕捉信息,进一步增强了模型的表达能力。
  • 应用领域:在自然语言处理方面,广泛应用于机器翻译、文本生成、问答系统等任务,如著名的 GPT 系列模型就是基于 Transformers 架构,在语言生成任务上取得了显著成果;在计算机视觉领域,也有研究将其应用于图像分类、目标检测等任务,通过将图像转换为序列形式进行处理;在语音识别领域,同样有潜力改善语音信号的处理和转录效果。

先验数据拟合网络(Prior-data Fitted Networks,PFNs)是一种机器学习模型架构。它的核心特点在于能够利用先验知识和大量数据进行训练和学习。通过在合成数据上进行训练,PFNs 可以学习到数据的潜在模式和规律。在 TabPFN 中,就利用了 PFNs 的思想,先在数百万个具有不同特征和目标关系的合成表格数据集上训练模型,让模型学习如何根据输入特征预测目标值。这种训练方式使得模型能够近似复杂的算法,如高斯过程和贝叶斯神经网络等,从而在面对实际的表格数据预测任务时,能够快速准确地进行预测,并且具有一定的泛化能力,适用于不同的数据集和任务场景。


闭式解(a closed - form solution)是指能够用有限次的基本运算(如加、减、乘、除、指数、对数、三角函数等)表示的精确解。

在数学和工程问题中,如果一个方程或问题存在闭式解,那么就可以通过直接计算得出精确的结果,而不需要借助迭代、数值逼近等方法。

在机器学习算法中,有些简单的模型或问题可能存在闭式解,比如线性回归在某些条件下可以通过最小二乘法得到闭式解来确定模型参数。但在很多复杂的情况下,如深度学习中的一些模型训练和优化问题,由于数据的复杂性和模型的非线性等因素,往往不存在闭式解,需要使用基于梯度下降等迭代方法来逐步逼近最优解。


梯度下降是一种常用的优化算法,用于寻找函数的最小值。

在机器学习中,通常将损失函数看作是需要优化的目标函数。其原理如下:

  • 首先,对目标函数求导,得到函数在某一点的梯度。梯度是一个向量,它指向函数增长最快的方向。
  • 然后,在每次迭代中,沿着梯度的反方向更新参数。这是因为沿着梯度的反方向是函数值下降最快的方向。
  • 不断重复这个过程,直到满足停止条件,如达到一定的迭代次数、函数值的变化小于某个阈值或者梯度的范数小于某个值等。

通过这种方式,梯度下降算法能够逐步调整参数,使得目标函数的值不断减小,最终找到一个局部最小值或者在一定程度上接近全局最小值。在实际应用中,如线性回归、神经网络等模型的训练中广泛使用梯度下降或其变体(如随机梯度下降、小批量梯度下降等)来优化模型的参数,以提高模型的性能。


我们基于 TabPFN的一个初步版本进行研究,该版本在原则上展示了上下文学习 对表格数据的适用性,但存在诸多限制,致使其在大多数情况下并不适用。基于一系列的改进,新的 TabPFN 能够扩展到 50 倍大的数据集,支持回归任务、分类数据和缺失值,并且对不重要的特征和异常值具有鲁棒性。

TabPFN背后的核心思想是生成大量的合成表格数据集,然后训练一个基于Transformer的神经网络,让其学习解决这些合成预测任务。尽管传统方法需要针对诸如缺失值等数据难题手动设计解决方案,但我们的方法通过解决包含这些难题的合成任务,自主学习有效的策略。这种方法将上下文学习(ICL)作为一种基于示例的算法声明式编程框架。我们通过生成展示所需行为的多样化合成数据集,来设计理想的算法行为,然后训练一个模型对满足该行为的算法进行编码。这就将算法设计过程从编写明确指令转变为定义输入 - 输出示例,为在各个领域创建算法开辟了可能性。在此,我们将这种方法应用于具有重大影响力的表格学习领域,生成一种强大的表格预测算法。

我们的上下文学习(ICL)方法与标准的有监督深度学习有着根本区别。通常情况下,模型是针对每个数据集进行训练的,依据诸如Adam这类手工设计的权重更新算法,在单个样本或批次上更新模型参数。

在推理阶段,已训练好的模型应用于测试样本。相比之下,我们的方法是跨数据集进行训练,并且在推理时应用于整个数据集,而非单个样本。在应用于实际数据集之前,该模型首先在代表不同预测任务的数百万个合成数据集上进行预训练。在推理阶段,模型接收一个包含有标记训练样本和无标记测试样本的全新数据集,并通过一次神经网络前向传播,在这个数据集上完成训练和预测。

图1和图2概述了我们的方法:

  1. 数据生成:我们定义一个生成过程(称为我们的先验),以合成各种表格数据集,这些数据集的特征与目标之间具有不同的关系,旨在涵盖我们模型可能遇到的广泛潜在场景。我们从该生成过程中采样数百万个数据集。对于每个数据集,一部分样本的目标值被屏蔽,以此模拟一个有监督的预测问题。我们先验设计的更多细节在“基于因果模型的合成数据”部分展示。

  2. 预训练:我们训练一个Transformer模型,即我们的先验数据拟合网络(PFN),在给定输入特征和未屏蔽样本作为上下文的情况下,预测所有合成数据集的屏蔽目标值。这一步在模型开发过程中仅执行一次,学习一种通用的学习算法,可用于预测任何数据集。

  3. 实际预测:训练好的模型现在可应用于任意未知的实际数据集。训练样本作为上下文提供给模型,模型通过上下文学习(ICL)预测这些未知数据集的标签。


图1
图1

这张图展示了TabPFN(一种表格预测算法)的训练和预测过程,以及其网络结构。

图a部分:训练和预测流程

  • 左侧(训练阶段):TabPFN在合成数据集上进行训练。合成数据集由训练数据( 和 )和测试数据(,其中  最初是未知的)组成。TabPFN是一个由参数  决定的神经网络,将整个数据集作为输入,并通过前向传播进行预测。训练损失()是在数百万个数据集上进行优化的目标,以调整模型参数。
  • 右侧(预测阶段):经过训练的TabPFN可以应用于任意未知的实际数据集。实际数据集同样包含训练数据( 和 )和测试数据( 待预测),模型通过前向传播对测试数据进行预测。

图b部分:TabPFN的网络结构

  • 输入数据集:展示了一个简单的输入表格,包含训练样本和一个待预测的测试样本。训练样本有特征值( 和 )和对应的目标值(),测试样本只有特征值,目标值未知。
  • 2D TabPFN层(12x):这是TabPFN的核心结构,由多个层组成。
    • 1D特征注意力(1D feature attention):对每个样本的特征之间进行注意力计算,以捕捉特征之间的关系。每个节点代表表格中的一个元素。
    • 1D样本注意力(1D sample attention):对不同样本之间进行注意力计算,以捕捉样本之间的关系。
    • 多层感知机(MLP):将经过注意力计算的特征向量转换为分段常数(黎曼)分布,最终输出预测的  分布。

通过这种方式,TabPFN能够处理表格数据,并在一次前向传播中对整个数据集进行预测,适用于各种实际应用场景。


在神经网络中,前向传播是一种基本的计算过程。以TabPFN为例,其运作如下:

  • 数据输入:将表格数据集(包含训练数据  以及测试数据 )输入到神经网络中。比如在图b中,输入表格里的训练样本特征值( 等)和测试样本特征值,这些数据就开始了前向传播的旅程。
  • 逐层计算:数据依次经过神经网络的各个层,如在TabPFN的2D TabPFN层中,先经过1D特征注意力层,在这一层会计算每个样本特征之间的关系,通过注意力机制为不同特征分配权重;接着进入1D样本注意力层,计算不同样本之间的关系;最后进入多层感知机(MLP)。
  • 输出结果:在经过一系列的计算和变换后,最终从神经网络输出预测结果。在TabPFN中,输出的是预测的  分布,像图b中右侧的预测  分布图表所展示的那样。

简单来说,前向传播就是让数据从神经网络的输入层流入,经过中间各层的处理,最终从输出层流出得到预测结果的过程,它是神经网络进行预测的基础步骤,后续的反向传播等过程也是基于前向传播的结果来进行参数调整等操作。


反向传播(Backpropagation,简称“backprop” )是训练人工神经网络的常用且有效算法,是人工智能领域核心概念。以下从几个方面详细介绍:

基本原理

其基本理论基于大学高等数学中的导数计算,尤其是链式法则。在神经网络中,它沿着从输出层到输入层的顺序,依次计算并存储目标函数有关神经网络各层的中间变量以及参数的梯度。

具体过程

  1. 前向传播:将训练集数据输入到人工神经网络(ANN)的输入层,数据经过隐藏层处理,最后到达输出层并输出结果。
  2. 反向传播:由于ANN的输出结果与实际值存在误差,计算估计值与实际值之间的误差,并将该误差从输出层向隐藏层反向传播,直至传播到输入层。
  3. 权重更新:在反向传播的过程中,根据误差调整各种参数(权重和偏置)的值。不断重复前向传播、反向传播和权重更新这三个过程,直至模型收敛(即误差达到可接受范围或满足停止条件)。

作用

通过帮助神经网络从错误中学习,调整网络权重和偏置,以降低预测误差,使网络的预测尽可能准确,是训练深度神经网络的主要方法。


正如参考文献22中所述,我们的方法也有一个理论基础。它可以被视为对由合成数据集定义的先验进行贝叶斯预测的近似。经过训练的先验数据拟合网络(PFN)将近似后验预测分布,从而针对在PFN预训练期间使用的人工数据集上的指定分布返回贝叶斯预测。

图2
图2

图2 | TabPFN先验概述。

  • a,对于每个数据集,我们首先采样高级超参数。
  • b,基于这些超参数,我们构建一个结构因果模型,该模型对生成数据集的计算函数进行编码。计算图中的每个节点持有一个向量,并且每条边根据连接类型之一实现一个函数。在步骤1中,我们使用随机噪声变量生成初始化数据,将其输入到图的根节点,并针对每个待生成的样本通过计算图进行传播。在步骤2中,我们在图中随机采样特征和目标节点的位置,分别标记为F和T。在步骤3中,我们在采样的特征和目标节点位置提取中间数据表示。在步骤4中,我们对提取的数据进行后处理。c,我们获取最终的数据集。我们绘制特征对之间的相互作用,并且节点颜色表示样本的类别。

这张图展示了生成合成表格数据集的过程,主要分为三个步骤:

a. 采样底层参数

  • 数据点数量采样:确定数据集中数据点的数量。
  • 特征数量采样:确定数据集中特征的数量。
  • 节点数量采样:确定计算图中节点的数量。
  • 图复杂度采样:对计算图的复杂度进行采样。
  • 图结构采样:采样得到具体的图结构,如图左下角展示的简单图结构示例。

b. 构建计算图和图结构

  1. 数据传播:对于每个生成的样本,将初始化数据通过计算图进行传播。
  2. 节点位置采样:随机采样特征(F)和目标(T)节点的位置。
  3. 数据读取:在采样得到的节点位置处读取数据。
  4. 后处理:进行后处理,包括量化和扭曲等操作。 此外,图中还展示了不同的连接类型,如神经网络(Neural network)、树(Tree)和离散化(Discretization),分别以不同的图示形式呈现。

c. 最终数据集

通过前面的步骤,生成了最终的合成表格数据集,图中以不同的可视化形式展示了多个最终数据集的样子,每个数据集可能具有不同的特征和分布。

总体而言,这张图详细描述了从底层参数采样到构建计算图,再到生成最终合成表格数据集的完整过程,目的是为TabPFN模型的训练提供多样化的合成数据。 这张图展示了生成合成表格数据集的过程,主要分为三个步骤:

a. 采样底层参数

  • 数据点数量采样:确定数据集中数据点的数量。
  • 特征数量采样:确定数据集中特征的数量。
  • 节点数量采样:确定计算图中节点的数量。
  • 图复杂度采样:对计算图的复杂度进行采样。
  • 图结构采样:采样得到具体的图结构,如图左下角展示的简单图结构示例。

b. 构建计算图和图结构

  1. 数据传播:对于每个生成的样本,将初始化数据通过计算图进行传播。
  2. 节点位置采样:随机采样特征(F)和目标(T)节点的位置。
  3. 数据读取:在采样得到的节点位置处读取数据。
  4. 后处理:进行后处理,包括量化和扭曲等操作。 此外,图中还展示了不同的连接类型,如神经网络(Neural network)、树(Tree)和离散化(Discretization),分别以不同的图示形式呈现。

c. 最终数据集

通过前面的步骤,生成了最终的合成表格数据集,图中以不同的可视化形式展示了多个最终数据集的样子,每个数据集可能具有不同的特征和分布。

总体而言,这张图详细描述了从底层参数采样到构建计算图,再到生成最终合成表格数据集的完整过程,目的是为TabPFN模型的训练提供多样化的合成数据。


结构因果模型(Structural Causal Model,简称SCM)是一种用于描述和分析因果关系的数学框架,在人工智能、统计学等领域有广泛应用,以下从几个方面介绍:

基本构成

  • 变量:模型包含多种变量,如外生变量(模型外部给定,不受其他变量影响,类似自变量)和内生变量(由模型内其他变量决定,类似因变量 )。例如在研究“降雨量(外生变量)”对“农作物产量(内生变量)”的影响时,降雨量是外部给定条件,而农作物产量受降雨量等因素影响。
  • 结构方程:通过一组方程来表示变量之间的因果关系。这些方程不是简单的相关关系描述,而是明确因果传递路径。比如在一个简单经济模型中,可能有方程“消费 = 固定消费 + 边际消费倾向×收入”,体现了收入对消费的因果决定关系。
  • 因果图:常以图形化方式呈现,节点代表变量,边代表变量间因果联系。箭头从原因变量指向结果变量,直观展示因果结构。例如在疾病研究中,因果图可能显示“吸烟”节点指向“患肺癌风险”节点,表示吸烟是导致患肺癌风险变化的原因之一。

作用

  • 因果推断:帮助分析因果效应,判断一个变量变化如何影响另一个变量。在医学试验中,可利用结构因果模型判断新药物(干预变量)对患者康复情况(结果变量)的因果影响,排除其他混杂因素干扰。
  • 反事实推理:基于模型可进行反事实思考,即想象在不同条件下会发生什么。例如分析“如果某公司没有进行某项投资决策,其利润会如何变化”,为决策评估提供参考。

与传统模型区别

相比传统只描述变量间相关性的统计模型,结构因果模型更强调因果关系,能更深入理解数据生成机制,在解决复杂因果问题上优势明显。

to be continued


如果您对真实世界研究/临床因果估计方法/生信分析/影像组学人工智能算法感兴趣可以 通过下方的微信加我的交流群

助教微信-程老师
助教微信-程老师
助教微信-金老师
助教微信-金老师

欢迎关注我的视频号-每周定期直播免费文献分享会

扫一扫,添加我的视频号
扫一扫,添加我的视频号

欢迎关注我的小红书

欢迎关注我的B站账号-公开课及文献分享视频会更新至此

我的B站
我的B站

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/179254
 
96 次点击