Nominal Semantics for First-class Automatic Differentiation
Automatic differentiation (AD) is a suite of techniques for mechanically computing the derivatives of functions represented as programs. First-class AD exposes differentiation as a higher-order construct within a programming language, enabling nested applications of the derivative operator. Beyond supporting the computation of higher and mixed partial derivatives, first-class AD also enables differentiation through optimizers, solvers, and inference engines that themselves rely on automated derivatives. This has applications to meta-learning, hyperparameter tuning, and Bayesian “reasoning-about-reasoning” in cognitive science.
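To make the nesting concrete, the following sketch (our illustration, not part of the formal development; the names f and inner_opt are hypothetical) uses JAX's grad operator, which exposes differentiation as a higher-order function:

    import jax

    # Nested application of the derivative operator: the second
    # derivative of f(x) = x**3 is 6x, so d2f(2.0) evaluates to 12.0.
    f = lambda x: x ** 3
    d2f = jax.grad(jax.grad(f))
    print(d2f(2.0))

    # Differentiating through an optimizer: the inner loop minimizes
    # (x - theta)**2 by gradient descent, and we differentiate the
    # optimum it finds with respect to the hyperparameter theta.
    def inner_opt(theta):
        x = 0.0
        for _ in range(100):
            x = x - 0.1 * jax.grad(lambda x: (x - theta) ** 2)(x)
        return x

    print(jax.grad(inner_opt)(3.0))  # approximately 1.0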
This abstract studies a popular implementation strategy for first-class forward-mode AD, developed by Siskind and Pearlmutter and implemented in widely used systems for scientific computing and machine learning, including JAX and Julia. The strategy, tagged forward-mode, generates unique tags at runtime to identify distinct applications of the derivative operator; tagging avoids a problem known as perturbation confusion, which can compromise correctness when derivatives are nested. The soundness of this approach has usually been justified by informal algebraic arguments that liken the tags to distinct infinitesimals. In this work, we precisely formulate tagged forward-mode as a meaning-preserving compiler from a high-level source language with first-class differentiation to a low-level target language with fresh tag generation as a computational effect. We then establish its correctness via logical relations.
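For intuition, the following minimal Python sketch (ours, and not the compiler formalized in this work) illustrates the tagging idea: each application of the derivative operator draws a fresh tag at runtime, and tangent extraction collects only the perturbation carrying that tag. On the classic perturbation-confusion example d/dx (x * d/dy (x + y)) at x = y = 1, untagged dual numbers return 2, whereas the tagged version returns the correct answer 1.

    import itertools

    _tags = itertools.count()  # runtime supply of fresh tags

    class Dual:
        # Represents primal + eps_tag * tangent, where eps_tag is the
        # perturbation introduced by one application of the derivative operator.
        def __init__(self, tag, primal, tangent):
            self.tag, self.primal, self.tangent = tag, primal, tangent

    def prim(x, tag):  # x with its eps_tag component dropped
        if not isinstance(x, Dual):
            return x
        if x.tag == tag:
            return x.primal
        return Dual(x.tag, prim(x.primal, tag), prim(x.tangent, tag))

    def tang(x, tag):  # the coefficient of eps_tag in x
        if not isinstance(x, Dual):
            return 0.0
        if x.tag == tag:
            return x.tangent
        return Dual(x.tag, tang(x.primal, tag), tang(x.tangent, tag))

    def add(x, y):
        if isinstance(x, Dual) or isinstance(y, Dual):
            t = x.tag if isinstance(x, Dual) else y.tag
            return Dual(t, add(prim(x, t), prim(y, t)),
                           add(tang(x, t), tang(y, t)))
        return x + y

    def mul(x, y):  # product rule on the eps_t components
        if isinstance(x, Dual) or isinstance(y, Dual):
            t = x.tag if isinstance(x, Dual) else y.tag
            return Dual(t, mul(prim(x, t), prim(y, t)),
                           add(mul(prim(x, t), tang(y, t)),
                               mul(tang(x, t), prim(y, t))))
        return x * y

    def deriv(f):
        # Each application of the derivative operator generates a fresh tag,
        # so an inner derivative never captures an outer perturbation.
        def df(x):
            t = next(_tags)
            return tang(f(Dual(t, x, 1.0)), t)
        return df

    # d/dx (x * d/dy (x + y)) at x = y = 1: prints 1.0, not the confused 2.0.
    print(deriv(lambda x: mul(x, deriv(lambda y: add(x, y))(1.0)))(1.0))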