Py-jax

Jul 20, 2023

Differentiate, compile, and transform Numpy code

JAX is Autograd and XLA, brought together for high-performance machine learning research.

With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation a.k.a. backpropagation via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.

What’s new is that JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX also lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API, jit. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without leaving Python. You can even program multiple GPUs or TPU cores at once using pmap, and differentiate through the whole thing.

Dig a little deeper, and you’ll see that JAX is really an extensible system for composable function transformations. Both grad and jit are instances of such transformations. Others are vmap for automatic vectorization and pmap for single-program multiple-data SPMD parallel programming of multiple accelerators, with more to come.



Checkout these related ports:
  • Zn_poly - C library for polynomial arithmetic
  • Zimpl - Language to translate the LP models into .lp or .mps
  • Zegrapher - Software for plotting mathematical objects
  • Zarray - Dynamically typed N-D expression system based on xtensor
  • Z3 - Z3 Theorem Prover
  • Yices - SMT solver
  • Yacas - Yet Another Computer Algebra System
  • Xtensor - Multi-dimensional arrays with broadcasting and lazy computing
  • Xtensor-python - Python bindings for xtensor
  • Xtensor-io - Xtensor plugin to read/write images, audio files, numpy npz and HDF5
  • Xtensor-blas - BLAS extension to xtensor
  • Xspread - Spreadsheet program for X and terminals
  • Xppaut - Graphical tool for solving differential equations, etc
  • Xplot - X11 plotting package
  • Xlife++ - XLiFE++ eXtended Library of Finite Elements in C++