We present GenJAX, a new language and compiler for vectorized
programmable probabilistic inference.
GenJAX integrates the vectorizing map (vmap) operation from
array programming frameworks such as JAX
into the programmable inference paradigm, enabling compositional
vectorization of features such as probabilistic program traces,
stochastic branching
(for expressing mixture models), and programmable inference interfaces
for writing custom probabilistic inference algorithms.
We formalize vectorization as a source-to-source program
transformation on a core calculus for probabilistic programming ($\gen$), and
prove that it correctly vectorizes both modeling and inference operations.
We have implemented our approach in
\href{https://github.com/probcomp/genjax}{the GenJAX language and
compiler}, and have empirically evaluated this implementation on
several benchmarks and case studies. Our results show that GenJAX
supports a wide range of expressive programmable inference
patterns and delivers
performance comparable to hand-optimized JAX code.