如果一个用户选择了JAX,那基本上只有一个原因:速度。例如同一段函数功能,可以看到numpy的实现需要大概851毫秒。而如果换成JAX,结果缩短到5.54ms,实现了超过numpy 150倍的性能提升!如果画成直方图,性能优势就显得更明显了。而JAX计算更快的原因是使用了TPU,而NumPy只使用CPU。虽然实际使用中并不是「用上JAX,你的程序就会加速150倍」那么简单,但仍然有很多理由来使用它。JAX为科学计算提供了一个通用的基础,它对不同领域的人有不同的用途。从根本上说,如果你在任何与科学计算有关的领域,你都应该了解JAX。作者列出了6个应该使用JAX原因:1. 加速NumPy。NumPy是用Python进行科学计算的基本软件包之一,但它只与CPU兼容。JAX提供了一个NumPy的实现(具有近乎相同的API),可以非常容易地在GPU和TPU上工作。对于许多用户来说,仅仅这一点就足以证明使用JAX的合理性。2. XLA,即加速线性代数(Accelerated Linear Algebra),是一个全程序优化编译器,专门为线性代数设计。JAX是建立在XLA之上的,大大提升了计算速度的上限。3. JIT。JAX允许用户使用XLA将函数转化为JIT(just in time)编译的版本。这意味用户可以通过给计算函数添加一个简单的函数装饰器来提高计算速度,可能是几个数量级的性能提升。4. 自动求导。JAX文档将JAX称为Autograd和XLA的结合体。自动求导的能力在科学计算的许多领域都是至关重要的,而JAX提供了几个强大的自动求导工具。5. 深度学习。虽然JAX本身不是一个深度学习框架,但它肯定为深度学习提供了一个更充分的基础。现在有许多建立在JAX之上的深度学习库,例如Flax、Haiku和Elegy。甚至有研究人员在PyTorch vs TensorFlow文章中强调JAX也是一个值得关注的「框架」,推荐其用于基于TPU的深度学习研究。JAX对Hessians的高效计算也与深度学习有关,因为它们使高阶优化技术更进一步。6. 通用可微分编程范式。虽然可以使用JAX来构建和训练深度学习模型,但它也为通用可微分编程提供了一个框架。这意味着JAX可以通过使用基于模型的机器学习方法来解决实际问题。 2022年,该学JAX吗?