社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

针对深度学习的“失忆症”,科学家提出基于相似性加权交错学习,登上PNAS

AI科技评论 • 2 年前 • 292 次点击  

与人类不同,人工神经网络在学习新事物时会迅速遗忘先前学到的信息,必须通过新旧信息的交错来重新训练;但是,交错全部旧信息非常耗时,并且可能没有必要。只交错与新信息有实质相似性的旧信息可能就足够了。

近日,美国科学院院报(PNAS)刊登了一篇论文,“Learning in deep neural networks and brains with similarity-weighted interleaved learning”,由加拿大皇家学会会士、知名神经科学家 Bruce McNaughton 的团队发表。他们的工作发现,通过将旧信息与新信息进行相似性加权交错训练,深度网络可以快速学习新事物,不仅降低了遗忘率,而且使用的数据量大幅减少。


论文作者还作出一个假设:通过跟踪最近活跃的神经元和神经动力学吸引子(attractor dynamics)的持续兴奋性轨迹,可以在大脑中实现相似性加权交错。这些发现可能会促进神经科学和机器学习的进一步发展。

作者 | Rajat Saxena et al.
编译 | bluemin
编辑 | 陈彩娴



1

研究背景

了解大脑如何终身学习仍然是一项长期挑战。
在人工神经网络(ANN)中,过快地整合新信息会产生灾难性干扰,即先前获得的知识突然丢失。互补学习系统理论 (Complementary Learning Systems Theory,CLST) 表明,通过将新记忆与现有知识交错,新记忆可以逐渐融入新皮质。
CLST指出,大脑依赖于互补的学习系统:海马体 (HC) 用于快速获取新记忆,新皮层 (NC) 用于将新数据逐渐整合到与上下文无关的结构化知识中。在“离线期间”,例如睡眠和安静的清醒休息期间,HC触发回放最近在NC中的经历,而NC自发地检索和交错现有类别的表征。交错回放允许以梯度下降的方式逐步调整NC突触权重,以创建与上下文无关的类别表征,从而优雅地整合新记忆并克服灾难性干扰。许多研究已经成功地使用交错回放实现了神经网络的终身学习。
然而,在实践中应用CLST时,有两个重要问题亟待解决。首先,当大脑无法访问所有旧数据时,如何进行全面的信息交错呢?一种可能的解决方案是“伪排练”,其中随机输入可以引发内部表征的生成式回放,而无需显式访问先前学习的示例。类吸引子动力学可能使大脑完成“伪排练”,但“伪排练”的内容尚未明确。因此,第二个问题是,每进行新的学习活动之后,大脑是否有充足的时间交织所有先前学习的信息。
相似性加权交错学习(Similarity-Weighted Interleaved Learning,SWIL)算法被认为是第二个问题的解决方案,这表明仅交错与新信息具有实质表征相似性的旧信息可能就足够了。实证行为研究表明,高度一致的新信息可以快速整合到NC结构化知识中,几乎没有干扰。这表明整合新信息的速度取决于其与先验知识的一致性。受此行为结果的启发,并通过重新检查先前获得的类别之间的灾难性干扰分布,McClelland等人证明SWIL可以在具有两个上义词类别(例如,“水果”是“苹果”和“香蕉”的上义词)的简单数据集中,每个epoch使用少于2.5倍的数据量学习新信息,实现了与在全部数据上训练网络相同的性能。然而,研究人员在使用更复杂的数据集时并没有发现类似的效果,这引发了对该算法可扩展性的担忧。
实验表明,深度非线性人工神经网络可以通过仅交错与新信息共享大量表征相似性的旧信息子集来学习新信息。通过使用SWIL算法,ANN能够以相似的精度水平和最小的干扰快速学习新信息,同时使用的每个时期呈现的旧信息量少之又少,这意味着数据利用率高且可以快速学习。
同时,SWIL也可应用于序列学习框架。此外,学习一种新类别可以极大地提高数据利用率 。如果旧信息与之前学习过的类别有着非常少的相似性,那么呈现的旧信息数量就会少得多,这很可能是人类学习的实际情况。
最后,作者提出了一个关于SWIL如何在大脑中实现的理论模型,其兴奋性偏差与新信息的重叠成正比。



2

应用于图像分类数据集的

DNN动力学模型

McClelland等人的实验表明,在具有一个隐藏层的深度线性网络中,SWIL可以学习一个新类别,类似于完全交错学习 (Fully Interleaved Learning,FIL),即将整个旧类别与新类别交错,但使用的数据量减少了40%。
然而,网络是在一个非常简单的数据集上训练的,只有两个上义词类别,这就对算法的可扩展性提出了疑问。
首先针对更复杂的数据集(如Fashion-MNIST),探索不同类别的学习在具有一个隐藏层的深度线性神经网络中如何演变。移出了“boot”(“靴子”)和“bag”(“纸袋”)类别后,该模型在剩余的8个类别上的测试准确率达到了87%。然后作者团队重新训练模型,在两种不同的条件下学习(新的)“boot”类,每个条件重复10次:
1)集中学习(Focused Learning ,FoL),即仅呈现新的“boot”类;
2)完全交错学习 (FIL),即所有类别(新类别+以前学过的类别)以相等的概率呈现。在这两种情况下,每个epoch总共呈现180张图像,每个epoch中的图像相同。
该网络在总共9000张从未见过的图像上进行了测试,其中测试数据集由每类1000张图像组成,不包括“bag”类别。当网络的性能达到渐近线时,训练停止。
不出所料,FoL对旧类别造成了干扰,而FIL克服了这一点(图1第2列)。如上所述,FoL对旧数据的干扰因类别而异,这是SWIL最初灵感的一部分,并表明新“boot”类别和旧类别之间存在分级相似关系。例如,“sneaker”(“运动鞋”)和“sandals”(“凉鞋”)的召回率比“trouser”(“裤子”)下降得更快(图1第2列),可能是因为整合新的“boot”类会选择性地改变代表“sneaker”和“sandals”类的突触权重,从而造成更多的干扰。

图1:预训练网络在两种情况下学习新“boot”类的性能对比分析:FoL(上)和 FIL(下)。从左到右依次为预测新“boot”类别的召回率(橄榄色)、现有类别的召回率(用不同颜色绘制)、总准确度(高分意味着低误差)和交叉熵损失(总误差的度量)曲线,是保留的测试数据集上与epoch数有关的函数。



3

计算不同类别之间的相似度

FoL在学习新类别的时候,在相似的旧类别上的分类性能会大幅下降。
之前已经探讨了多类别属性相似度和学习之间的关系,并且表明深度线性网络可以快速获取已知的一致属性。相比之下,在现有类别层次结构中添加新分支的不一致属性,需要缓慢、渐进、交错的学习。
在当前的工作中,作者团队使用已提出的方法在特征级别计算相似度。简言之,计算目标隐藏层(通常是倒数第二层)现有类别和新类别的平均每类激活向量之间的余弦相似度。图2A显示了基于Fashion MNIST数据集的新“boot”类别和旧类别,作者团队根据预训练网络的倒数第二层激活函数计算的相似度矩阵。
类别之间的相似性与我们对物体的视觉感知一致。例如,在层次聚类图(图2B)中,我们可以观察到“boot”类与“sneaker”和“sandal”类之间、以及“shirt”(“衬衫”)和“t-shirt”(“T恤”)类之间具有较高的相似性。相似度矩阵(图2A)与混淆矩阵(图2C)完全对应。相似度越高,越容易混淆,例如,“衬衫”类与“T恤”、“套头衫”和“外套”类图像容易混淆,这表明相似性度量预测了神经网络的学习动态。
在上一节的FoL结果图(图1)中,旧类别的召回率曲线中存在相近的类相似度曲线。与不同的旧类别(“trouser”等)相比,FoL学习新“boot”类的时候会快速遗忘相似的旧类别(“sneaker” 和 “sandal”)。

图2:( A ) 作者团队根据预训练网络的倒数第二层激活函数,计算的现有类别和新“boot”类的相似度矩阵,其中对角线值(同一类别的相似性绘制为白色)被删除。( B ) 对A中的相似矩阵进行层次聚类。( C ) FIL算法在训练学习“boot”类后生成的混淆矩阵。为了缩放清晰,删除了对角线值。



4

深度线性神经网络实现快速和

高效学习新事物

接下来在前两个条件基础上增加了3种新条件,研究了新的分类学习动态,其中每个条件重复10次:
1)FoL(共计n=6000张图像/epoch);
2) FIL(共计n=54000张图像/epoch,6000张图像/类);
3) 部分交错学习 (Partial Interleaved Learning,PIL)使用了很小的图像子集(共计n=350张图像/epoch,大约39张图像/类),每一类别(新类别+现有类别)的图像以相等的概率呈现;
4) SWIL,每个epoch使用与PIL 相同的图像总数进行重新训练,但根据与(新)“boot”类别的相似性对现有类别图像进行加权;
5)等权交错学习(Equally Weighted Interleaved Learning,EqWIL),使用与SWIL相同数量的“boot”类图像重新训练,但现有类别图像的权重相同(图3A)。
作者团队使用了上述相同的测试数据集(共有n=9000张图像)。当在每种条件下神经网络的性能都达到渐近线时,停止训练。尽管每个epoch使用的训练数据较少,预测新“boot”类的准确率需要更长的时间达到渐近线,与FIL(H=7.27,P<0.05)相比,PIL的召回率更低(图3B第1列和表1“New class”列)。
对于SWIL,相似度计算用于确定要交错的现有旧类别图像的比例。在此基础上,作者团队从每个旧类别中随机抽取具有加权概率的输入图像。与其他类别相比,“sneaker”和“sandal”类最相似,从而导致被交错的比例更高(图3A)。
根据树状图(图2B),作者团队将“sneaker”和“sandal”类称为相似的旧类,其余则称为不同的旧类。与PIL(H=5.44,P<0.05)相比,使用SWIL时,模型学习新“boot”类的速度更快,对现有类别的干扰也相近。此外,SWIL(H=0.056,P>0.05)的新类别召回率(图3B第1列和表1“New class”列)、总准确率和损失与FIL相当。EqWIL(H=10.99,P<0.05)中新“boot”类的学习与SWIL相同,但对相近的旧类别有更大程度的干扰(图3B第2列和表1“Similar old class”列)。
作者团队使用以下两种方法比较SWIL和FIL:
1) 内存比,即FIL和SWIL中存储的图像数量之比,表示存储的数据量减少;
2) 加速比,即在FIL和SWIL中呈现的内容总数的比率,以达到新类别回忆的饱和精度,表明学习新类别所需的时间减少。
SWIL可以在数据需求减少的情况下学习新内容,内存比=154.3x (54000/350),并且速度更快,加速比=77.1x (54000/(350×2))。即使和新内容有关的图像数量较少,该模型也可以通过使用SWIL,利用模型先验知识的层次结构实现相同的性能。SWIL在PIL和EqWIL之间提供了一个中间缓冲区,允许集成一个新类别,并将对现有类别的干扰降到最低。

图3 ( A ) 作者团队在五种不同的学习条件下预训练神经网络学习新的“boot”类(橄榄绿),直到性能平稳:1)FoL(共计n=6000张图像/epoch);2)FIL(共计n=54000张图像/epoch);3) PIL(共计n=350张图像/epoch);4) SWIL(共计n=350张图像/epoch)和 5) EqWIL(共计n=350张图像/epoch)。(B)FoL(黑色)、FIL(蓝色)、PIL(棕色)、SWIL(洋红色)和 EqWIL(金色)预测新类别、相似旧类别(“sneaker”和“sandals”)和不同旧类别的召回率,预测所有类别的总准确率,以及在测试数据集上的交叉熵损失,其中横坐标都是epoch数。



5

基于CIFAR10使用SWIL

在CNN中学习新类别

接下来,为了测试SWIL是否可以在更复杂的环境中工作,作者团队训练了一个具有全连接输出层的6层非线性CNN(图4A),以识别CIFAR10数据集中剩余8个不同类别(“cat”和“car”除外)的图像。他们还对模型进行了重新训练,在之前定义的5种不同训练条件(FoL、FIL、PIL、SWIL和EqWIL)下学习“cat”(“猫”)类。图4C显示了5种情况下每类图像的分布。对于SWIL、PIL和EqWIL条件,每个epoch的总图像数为2400,而对于FIL和FoL,每个epoch的总图像数分别为45000和5000。作者团队针对每种情况对网络分别进行训练,直到性能趋于稳定。
他们在之前未见过的总共9000张图像(1000张图像/类,不包括“car”(“轿车”)类)上对该模型进行了测试。图4B是作者团队基于CIFAR10数据集计算的相似性矩阵。“cat”类和“dog”(“狗”)类更类似,而其他动物类属于同一分支(图4B左)。
根据树状图(图4B),将“truck” (“货车”)、“ship”(“轮船”) 和 “plane”(“飞机”) 类别称为不同的旧类别,除“cat”类外其余的动物类别称为相似的旧类别。对于FoL,模型学习了新的“cat”类,但遗忘了旧类别。与Fashion-MNIST数据集结果类似,“dog”类(与“cat”类相似性最大)和“truck”类(与“cat”类相似性最小)均存在干扰梯度,其中“dog”类的遗忘率最高,而“truck”类遗忘率最低。
如图4D所示,FIL算法学习新的“cat”类时克服了灾难性的干扰。对于PIL算法,模型在每个epoch使用18.75倍的数据量学习新的“cat”类,但“cat”类的召回率比FIL(H=5.72,P<0.05)低。对于SWIL,在新类别、相似和不同旧类别上的召回率、总准确率和损失与FIL相当(H=0.42,P>0.05;见表2和图4D)。SWIL对新“cat”类的召回率高于PIL(H=7.89,P<0.05)。使用EqWIL算法时,新“cat”类的学习情况与SWIL和FIL相似,但对相似旧类别的干扰较大(H=24.77,P<0.05;见表2)。
FIL、PIL、SWIL和EqWIL这4种算法预测不同旧类别的性能相当(H=0.6,P>0.05)。SWI比PIL更好地融合了新的“cat”类,并有助于克服EqWIL中的观测干扰。与FIL相比,使用SWIL学习新类别速度更快,加速比=31.25x (45000×10/(2400×6)),同时使用更少的数据量 (内存比=18.75x)。这些结果证明,即使在非线性CNN和更真实的数据集上,SWIL也可以有效学习新类别事物。

图4:( A ) 作者团队使用具有全连接输出层的6层非线性CNN学习CIFAR10数据集中的8类事物。( B ) 相似度矩阵 (右)是在呈现新的“cat”类之后,作者团队根据最后一个卷积层的激活函数计算获得。对相似矩阵应用层次聚类(左),在树状图中显示动物(橄榄绿)和交通工具(蓝色)两个上义词类别的分组情况。( C ) 作者团队在5种不同的条件下预训练CNN学习新的“cat”类(橄榄绿),直到性能平稳:1)FoL(共计n=5000张图像/epoch);2)FIL(共计n=45000张图像/epoch);3) PIL(共计n=2400张图像/epoch);4) SWIL(共计n=2400张图像/epoch);5) EqWIL(共计n=2400张图像/epoch)。每个条件重复10次。(D)FoL(黑色)、FIL(蓝色)、PIL(棕色)、SWIL(洋红色)和 EqWIL(金色)预测新类别、相似旧类别(CIFAR10数据集中的其他动物类)和不同旧类别(“plane” 、“ship” 和 “truck”)的召回率,预测所有类别的总准确率,以及在测试数据集上的交叉熵损失,其中横坐标都是epoch数。



6

新内容与旧类别的一致性

对学习时间和所需数据的影响

如果一项新内容可以添加到先前学习过的类别中,而不需要对网络进行较大更改,则称二者具有一致性。基于此框架,与干扰多个现有类别(低一致性)的新类别相比,学习干扰更少现有类别(高一致性)的新类别可以更容易地集成到网络中。
为了测试上述推断,作者团队使用上一节中经过预训练的CNN,在前面描述的所有5种学习条件下,学习了一个新的“car”类别。图5A显示了“car”类别的相似性矩阵,与其他现有类别相比,“car”和“truck”、“ship”和“plane”在同一层次节点下,说明它们更相似。为了进一步确认,作者团队在用于相似性计算的激活层上进行了t-SNE降维可视化分析(图5B)。研究发现“car”类与其他交通工具类(“truck”、“ship”和“plane”)有显著重叠,而“cat”类与其他动物类(“dog”、 “frog”(“青蛙”)、“horse”(“马”)、“bird”(“鸟”)和“deer”(“鹿”))有重叠。
和作者团队预期相符,FoL学习“car”类别时会产生灾难性干扰,对相近的旧类别干扰性更强,而使用FIL克服了这一点(图5D)。对于PIL、SWIL和EqWIL,每个epoch总共有n=2000张图像(图5C)。使用SWIL算法,模型学习新的“car”类别可以达到和FIL(H=0.79,P>0.05)相近的精度,而对现有类别(包括相似和不同类别)的干扰最小。如图5D第2列所示,使用EqWIL,模型学习新“car”类的方式与SWIL相同,但对其他相似类别(例如“truck”)的干扰程度更高(H=53.81,P<0.05)。
与FIL相比,SWIL可以更快地学习新内容加速比=48.75x(45000×12/(2000×6)),内存需求减少,内存比=22.5x。与“cat”(48.75x vs.31.25x)相比,“car”可以通过交错更少的类(如“truck”、“ship”和“plane”)更快地学习,而“cat”与更多的类别(如“dog” 、“frog” 、“horse” 、“frog” 和“deer”)重叠。这些仿真实验表明,交叉和加速学习新类别所需的旧类别数据量,取决于新信息与先验知识的一致性。

图 5:( A ) 作者团队根据倒数第二层激活函数计算获得相似度矩阵(左),以及呈现新的“car”类别后对相似度矩阵进行层次聚类后的结果图(右)。( B ) 模型分别学习新的“car”类别和“cat”类别,经过最后一个卷积层过激活函数后,作者团队进行t-SNE降维可视化的结果图。( C ) 作者团队在5种不同的条件下预训练CNN学习新的“car”类(橄榄绿),直到性能平稳:1)FoL(共计n=5000张图像/epoch);2)FIL(共计n=45000张图像/epoch);3) PIL(共计n=2000张图像/epoch);4) SWIL(共计n=2000张图像/epoch);5) EqWIL(共计n=2000张图像/epoch)。(D)FoL(黑色)、FIL(蓝色)、PIL(棕色)、SWIL(洋红色)和 EqWIL(金色)预测新类别、相似旧类别(“plane” 、“ship” 和 “truck”)和不同旧类别(CIFAR10数据集中的其他动物类)的召回率,预测所有类别的总准确率,以及在测试数据集上的交叉熵损失,其中横坐标都是epoch数。每张图显示的是重复10次后的平均值,阴影区域为±1 SEM。



7

利用SWIL进行序列学习

接下来,作者团队测试是否可以使用SWIL学习序列化形式呈现的新内容(序列学习框架)。为此他们采用了图4中经过训练的CNN模型,在FIL和SWIL条件下学习CIFAR10数据集中的“cat”类(任务1),只在CIFAR10的剩余9个类别上训练,然后在每个条件下训练模型学习新的“car”类(任务2)。图6第1列显示了SWIL条件下学习“car”类别时,其他各项类别的图像数量分布情况(共计n=2500张图像/epoch)。需要注意的是,预测“cat”类时也交叉学习新的“car”类。由于在FIL条件下模型性能最佳,SWIL仅与FIL进行了结果比较。
如图6所示,SWIL预测新、旧类别的能力与FIL相当(H=14.3,P>0.05)。模型使用SWIL算法可以更快地学习新的“car”类别,加速比为45x(50000×20/(2500×8)),每个epoch的内存占用比FIL少20倍。模型学习“cat”和“car”类别时,在SWIL条件下每个epoch使用的图像数量(内存比和加速比分别为18.75x 和 20x),少于在FIL条件下每个epoch使用的整个数据集(内存比和加速比分别为31.25x 和45x),并且仍然可以快速学习新类别。扩展这一思想,随着学过的类别数目不断增加,作者团队预期模型的学习时间和数据存储会成倍减少,从而更高效地学习新类别,这或许反映了人类大脑实际学习时的情况。
实验结果表明,SWIL可在序列学习框架中集成多个新类,使神经网络能够在不受干扰的情况下持续学习。
图6:作者团队训练6层CNN学习新的“cat”类(任务1),然后学习“car”类(任务2),直到性能在以下两种情况下趋于稳定:1)FIL:包含所有旧类别(以不同颜色绘制)和以相同概率呈现的新类别(“cat”/“car”)图像;2) SWIL:根据与新类别(“cat”/“car”)的相似性进行加权并按比例使用旧类别示例。同时将任务1中学习的“cat”类包括在内,并根据任务2中学习“car”类的相似性进行加权。第1张子图表示每个epoch使用的图像数量分布情况,其余各子图分别表示FIL(蓝色)和SWIL(洋红色)预测新类别、相似旧类别和不同旧类别的召回率,预测所有类别的总准确率,以及在测试数据集上的交叉熵损失,其中横坐标都是epoch数。



8

利用SWIL扩大类别间的距离,

减少学习时间和数据量

作者团队最后测试了SWIL算法的泛化性,验证其是否可以学习包括更多类别的数据集,以及是否适用于更复杂的网络架构。
他们在CIFAR100数据集(训练集500张图像/类,测试集100张图像/类)上训练了一个复杂的CNN模型-VGG19(共有19层),学习了其中的90个类别。然后对网络进行再训练,学习新类别。图7A显示了基于CIFAR100数据集,作者团队根据倒数第二层的激活函数计算的相似性矩阵。如图7B所示,新“train”(“火车”)类与许多现有的交通工具类别(如“bus” (“公共汽车”)、“streetcar” (“有轨电车”)和“tractor”(“拖拉机”)等)很相似。
与FIL相比,SWIL可以更快地学习新事物(加速比=95.45x (45500×6/(1430×2)))并且使用的数据量 (内存比=31.8x) 显著减少,而性能基本相同(H=8.21, P>0.05) 。如图7C所示,在PIL(H=10.34,P<0.05)和EqWIL(H=24.77,P<0.05)条件下,模型预测新类别的召回率较低并且产生的干扰较大,而SWIL克服了上述不足。
同时,为了探索不同类别表征之间的较大距离是否构成了加速模型学习的基本条件,作者团队另外训练了两种神经网络模型:
1)6层CNN(与基于CIFAR10的图4和图5相同);
2)VGG11(11层)学习CIFAR100数据集中的90个类别,仅在FIL和SWIL两个条件下对新的“train”类进行训练。
如图7B所示,对于上述两种网络模型,新的“train”类和交通工具类别之间的重叠度更高,但与VGG19模型相比,各类别的分离度较低。与FIL相比,SWIL学习新事物的速度与层数的增加大致呈线性关系(斜率=0.84)。该结果表明,类别间表征距离的增加可以加速学习并减少内存负载。

图7:( A ) VGG19学习新的“train”类后,作者团队根据倒数第二层激活函数计算的相似性矩阵。“truck” 、“streetcar” 、“bus” 、“house” 和 “tractor”5种类别与“train”的相似性最大。从相似度矩阵中排除对角元素(相似度 =1)。(B,左)作者团队针对6层CNN、VGG11和VGG19网络,经过倒数第二层激活函数后,进行t-SNE降维可视化的结果图。(B,右)纵轴表示加速比(FIL/SWIL),横轴表示3个不同网络的层数相对于6层CNN的比率。黑色虚线、红色虚线和蓝色实线分别代表斜率 =1的标准线、最佳拟合线和仿真结果。( C ) VGG19模型的学习情况:FoL(黑色)、FIL(蓝色)、PIL(棕色)、SWIL(洋红色)和 EqWIL(金色)预测新“train”类、相似旧类别(交通工具类别)和不同旧类别(除了交通工具类别)的召回率,预测所有类别的总准确率,以及在测试数据集上的交叉熵损失,其中横坐标都是epoch数。每张图显示的是重复10次后的平均值,阴影区域为±1 SEM。( D ) 从左到右依次表示模型预测Fashion-MNIST“boot”类(图3)、CIFAR10“cat”类(图4)、CIFAR10“car”类(图5)和CIFAR100“train”类的召回率,是SWIL(洋红色)和FIL(蓝色)使用的图像总数(对数比例)的函数。“N”表示每种学习条件下每个epoch使用的图像总数(包括新、旧类别)。
如果在更多非重叠类上训练网络,并且各表征之间的距离更大,速度是否会进一步提升?
为此,作者团队采用了一个深度线性网络(用于图1-3中的Fashion-MNIST示例),并对其进行训练,以学习由8个Fashion-MNIST类别(不包括“bags”和“boot”类)和10个Digit-MNIST类别形成的组合数据集,然后训练网络学习新的“boot”类别。
和作者团队的预期相符,“boot”与旧类别“sandals”和“sneaker”相似度更高,其次是其余的Fashion-MNIST类(主要包括服饰类图像),最后Digit-MNIST类(主要包括数字类图像)。
基于此,作者团队首先交织了更多相似的旧类别样本,再交织Fashion-MNIST和Digit-MNIST类样本(共计n=350张图像/epoch)。实验结果表明,与FIL类似,SWIL可以快速学习新类别内容而不受干扰,但使用的数据子集要小得多,内存比为325.7x (114000/350) ,加速比为162.85x (228000/1400)。作者团队在当前结果中观察到的加速比为2.1x (162.85/77.1),与Fashion-MNIST数据集相比,类别数目增加了 2.25倍 (18/8)。
本节的实验结果有助于确定SWIL可以适用于更复杂的数据集 (CIFAR100) 和神经网络模型(VGG19),证明了该算法的泛化性。同时证明了扩大类别之间的内部距离或增加非重叠类别的数量,可能会进一步提高学习速度并降低内存负载。



9

总结

人工神经网络在持续学习方面面临重大挑战,通常表现出灾难性干扰。为了克服此问题,许多研究都使用了完全交错学习(FIL),即新旧内容交叉学习,联合训练网络。FIL需要在每次学新信息时交织所有现有信息,使其成为一个生物学意义上不可信且耗时的过程。最近,有研究表明FIL可能并非必需,仅交错与新内容具有实质表征相似性的旧内容,即采用相似性加权交错学习(SWIL)的方法可以达到相同的学习效果。然而,有人对SWIL的可扩展性表示了担忧。
本文扩展了SWIL算法,并基于不同的数据集(Fashion-MNIST、CIFAR10 和 CIFAR100)和神经网络模型(深度线性网络和CNN)对其进行了测试。在所有条件下,与部分交错学习(PIL)相比,相似性加权交错学习(SWIL)和等权交错学习(EqWIL)在学习新类别方面的表现更好。这和作者团队的预期相符,因为与旧类别相比,SWIL和EqWIL增加了新类别的相对频率。
本文同时还证明,与同等子抽样现有类别(即EqWIL方法)相比,仔细选择和交织相似内容减少了对相近旧类别的灾难性干扰。在预测新类别和现有类别方面,SWIL的性能与FIL类似,却显著加快了学习新内容的速度(图7D),同时大大减少了所需的训练数据。SWIL可以在序列学习框架中学习新类别,进一步证明了其泛化能力。
最后,与许多旧类别具有相似性的新类别相比,如果其与之前学过的类别重叠更少(距离更大),可以缩短集成时间,并且数据效率更高。总体来说,实验结果提供了一种可能的见解,即大脑事实上通过减少不切实际的训练时间,克服了原始CLST模型的一项主要弱点。
原文链接:
https://www.pnas.org/doi/10.1073/pnas.2115229119

更多内容,点击下方关注:

扫码添加 AI 科技评论 微信号,投稿&进群:
Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/136797
 
292 次点击