JAX is Autograd and XLA, brought together for high-performance machine learning research.
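As a small illustration of those two halves, the sketch below applies `jax.grad` (automatic differentiation, the Autograd side) and `jax.jit` (compilation through XLA) to a toy function. The function `f` is only an example chosen for illustration and is not part of the modules.

```python
import jax


def f(x):
    # Toy scalar function: f(x) = x**2 + 3x, so f'(x) = 2x + 3
    return x ** 2 + 3.0 * x


# jax.grad returns a new function that computes df/dx (the Autograd part)
df = jax.grad(f)

# jax.jit compiles that derivative function with XLA
df_fast = jax.jit(df)

# Both evaluate f'(2) = 2*2 + 3
print(df(2.0))
print(df_fast(2.0))
```

The jitted version is traced and compiled on its first call, so it only pays off when the function is called repeatedly.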
Currently supported JAX versions:
The modules contain JAX for Python with GPU support via CUDA.
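To confirm that JAX actually sees a GPU, a quick check like the following can be run inside a GPU batch job. On a login node or a CPU-only node it is expected to fall back to the CPU backend. This is a minimal sketch using `jax.default_backend()` and `jax.devices()`.

```python
import jax

# Backend JAX selected by default: "gpu" when a CUDA device is visible, otherwise "cpu"
backend = jax.default_backend()

# All devices available to this process on the default backend
devices = jax.devices()

print("Default backend:", backend)
for d in devices:
    # id, platform and device_kind are standard attributes of a JAX device
    print(d.id, d.platform, d.device_kind)
```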
If you find that some package is missing, you can often install it yourself with pip install --user. See our Python documentation for more information on how to install packages yourself. If you think that some important JAX-related package should be included in the modules provided by CSC, please contact our Service Desk.
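Before installing anything, it may help to check whether a package is already importable from the loaded module. The following is a minimal sketch using only the standard library. The package name `some_missing_package` is a hypothetical placeholder, and `site.getusersitepackages()` is assumed to show where `pip install --user` would place packages for this Python (the default user base, or whatever `PYTHONUSERBASE` points to).

```python
import importlib.util
import site


def is_installed(name: str) -> bool:
    """Return True if a top-level module called `name` can be imported.

    Only works for top-level names: find_spec raises for dotted names whose
    parent package is missing.
    """
    return importlib.util.find_spec(name) is not None


# Location used for per-user installs by this interpreter
print("User site-packages:", site.getusersitepackages())

# "some_missing_package" is a hypothetical name used only for illustration
for pkg in ["jax", "flax", "some_missing_package"]:
    print(pkg, "found" if is_installed(pkg) else "missing")
```

If a package shows up as missing, that is the point where installing it with pip install --user makes sense.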
The JAX modules are Singularity-based, but wrapper scripts have been provided so that common commands such as pip3 work as normal. For more information, see CSC's general instructions on how to run Singularity containers.
JAX is licensed under Apache License 2.0.
To use this software on Puhti or Mahti, initialize it with:
module load jax
to access the default version.
Please note that the JAX modules already include the CUDA and cuDNN libraries, so there is no need to load the cuda and cudnn modules separately!
To see all available versions of JAX, run:
module avail jax
The JAX modules include several libraries from the JAX ecosystem (e.g. Haiku, Flax, Trax, Objax, and Elegy). To check the exact packages and versions included in the loaded module, you can run:
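As an alternative that works from any Python session, the installed versions can also be read with the standard library's `importlib.metadata`. The list of names below is an assumption based on the libraries mentioned above, using their PyPI distribution names (Haiku is published as `dm-haiku`). Anything not present in the loaded module is reported as not installed.

```python
from importlib.metadata import PackageNotFoundError, version

# PyPI distribution names (assumed); adjust the list to what you need to check
PACKAGES = ["jax", "jaxlib", "dm-haiku", "flax", "trax", "objax", "elegy"]


def installed_versions(names):
    """Map each distribution name to its installed version string, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:10s} {ver or 'not installed'}")
```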
Note that the login nodes are not intended for heavy computing; please use Slurm batch jobs instead. See our instructions on how to use the batch job system.
Please do not read a huge number of files from the shared file system. Use the fast local disk or package your data into larger files instead! See the Data storage section in our machine learning guide for more details.
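One simple way to package many small files into a single larger file is an uncompressed tar archive that is then read sequentially, so the shared file system sees one large file instead of thousands of small ones. The following is a minimal sketch using Python's standard `tarfile` module. The function names `pack_directory` and `iter_archive` are hypothetical, and other formats may suit your data better.

```python
import tarfile
import tempfile
from pathlib import Path


def pack_directory(src_dir, archive_path):
    """Bundle every regular file under src_dir into one uncompressed tar file."""
    src = Path(src_dir)
    with tarfile.open(str(archive_path), "w") as tar:
        for path in sorted(src.rglob("*")):
            if path.is_file():
                # Store paths relative to src_dir so the archive is relocatable
                tar.add(str(path), arcname=str(path.relative_to(src)))


def iter_archive(archive_path):
    """Yield (name, bytes) for each file, reading the archive front to back."""
    with tarfile.open(str(archive_path), "r") as tar:
        for member in tar:
            if member.isfile():
                yield member.name, tar.extractfile(member).read()


if __name__ == "__main__":
    # Demo: create a few small files, pack them, and read them back
    with tempfile.TemporaryDirectory() as tmp:
        data_dir = Path(tmp) / "data"
        data_dir.mkdir()
        for i in range(3):
            (data_dir / f"sample_{i}.txt").write_text(f"sample {i}")
        archive = Path(tmp) / "data.tar"
        pack_directory(data_dir, archive)
        for name, payload in iter_archive(archive):
            print(name, payload.decode())
```

The archive can be created once, then copied to the fast local disk at the start of a job or read directly, depending on its size.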
Last edited Fri Jun 10 2022