JAX
JAX is Autograd and XLA, brought together for high-performance machine learning research.
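As a minimal illustration of that combination (a generic sketch, not something specific to the CSC modules), the example below uses jax.grad for automatic differentiation and jax.jit to compile the resulting gradient function with XLA. The one-parameter linear model and its data are made up for the example.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a one-parameter linear model y ~ w * x
    return jnp.mean((w * x - y) ** 2)


# jax.grad differentiates loss with respect to its first argument (w);
# jax.jit compiles the gradient function with XLA on its first call.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

g = grad_loss(1.0, x, y)
# Analytically the gradient is mean(2 * (w*x - y) * x), i.e. -28/3 at w = 1.0
print(g)
```

Subsequent calls to grad_loss with arrays of the same shape reuse the compiled XLA program instead of tracing the Python function again.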
News
5.10.2022: Due to Puhti's update to Red Hat Enterprise Linux 8 (RHEL8), the number of supported JAX versions has been reduced. Please contact our servicedesk if you really need access to older versions.
Available
Currently supported JAX versions:
Version | Module | Puhti | Mahti | LUMI | Notes |
---|---|---|---|---|---|
0.4.20 | jax/0.4.20 | X | X | - | default version |
0.4.18 | jax/0.4.18 | - | - | X* | |
0.4.14 | jax/0.4.14 | X | X | - | |
0.4.13 | jax/0.4.13 | X | X | - | |
0.4.1 | jax/0.4.1 | X | X | - | |
0.3.13 | jax/0.3.13 | X | X | - | |
The modules contain JAX for Python with GPU support via CUDA (Puhti, Mahti) or ROCm (LUMI).
Versions on LUMI marked with "X*" are still experimental and have limited support. They may change at any time without notice. Note that JAX is also available in the LUMI Software Library.
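To confirm that JAX actually sees the GPUs in a job, a quick check along these lines can be run with the module's python. It is a sketch that only uses jax.default_backend() and jax.devices(); on a node without usable GPUs (such as a login node) JAX falls back to the CPU, so the printed backend and device list depend on where it runs.

```python
import jax
import jax.numpy as jnp

# Backend JAX selected for this process: typically "gpu" when a CUDA/ROCm
# device is usable, otherwise "cpu"
backend = jax.default_backend()
devices = jax.devices()

print("backend:", backend)
for d in devices:
    print(d.id, d.platform, d.device_kind)

# Run a small computation on the default device to make sure it works
x = jnp.ones((100, 100))
total = float((x @ x).sum())  # each entry is 100, so the sum is 100 * 100 * 100
print("result:", total)
```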
If you find that some package is missing, you can often install it yourself with pip install --user. See our Python documentation for more information on how to install packages yourself. If you think that some important JAX-related package should be included in the modules provided by CSC, please contact our servicedesk.
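Packages installed with pip install --user go to Python's per-user site-packages directory. The standard-library site module can report where that directory is for the interpreter you are running, which helps when checking whether a user-installed package is actually picked up. This is a generic Python sketch; the module environment may redirect the user base (for example through the PYTHONUSERBASE variable), so the printed path is what counts.

```python
import site
import sys

# Per-user site-packages directory used by `pip install --user`
user_site = site.getusersitepackages()

print("user site-packages:", user_site)
# ENABLE_USER_SITE is True when enabled, False or None when disabled
print("user site enabled:", site.ENABLE_USER_SITE)
print("on sys.path:", user_site in sys.path)
```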
All modules are based on containers using Apptainer (previously known as Singularity). Wrapper scripts have been provided so that common commands such as python, python3, pip and pip3 should work as normal. For other commands, you need to prefix them with apptainer_wrapper exec. For more information, see CSC's general instructions on how to run Apptainer containers.
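When debugging a command that seems to run outside the container, it can help to check from Python whether the process is inside one. The sketch below assumes that Apptainer exports APPTAINER_CONTAINER inside a running container (older Singularity releases used SINGULARITY_CONTAINER); verify these variable names against the Apptainer version in use. The helper name container_image is hypothetical.

```python
import os
from typing import Optional


def container_image() -> Optional[str]:
    """Return the container image path if running inside Apptainer/Singularity."""
    # Assumed variables: APPTAINER_CONTAINER (Apptainer) and
    # SINGULARITY_CONTAINER (older Singularity releases)
    for var in ("APPTAINER_CONTAINER", "SINGULARITY_CONTAINER"):
        value = os.environ.get(var)
        if value:
            return value
    return None


image = container_image()
print("inside container:", image is not None, image or "")
```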
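With recent modules it is also possible to use Python virtual environments. To create a virtual environment, use the command:
python3 -m venv --system-site-packages venv
The --system-site-packages option keeps the packages included in the module visible inside the virtual environment. Activate the environment with:
source venv/bin/activate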
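The same kind of environment can also be created from Python with the standard-library venv module, where system_site_packages=True corresponds to the --system-site-packages flag. The sketch also shows a common way to check whether the current interpreter runs inside a virtual environment. The helper names make_jax_venv and in_virtualenv are hypothetical, not part of the CSC modules.

```python
import sys
import venv


def make_jax_venv(path: str) -> None:
    # Roughly equivalent to: python3 -m venv --system-site-packages <path>
    # pip is not installed into the environment here (with_pip=False) to keep
    # the example fast; pass with_pip=True to match the command-line default.
    venv.create(path, system_site_packages=True, symlinks=True, with_pip=False)


def in_virtualenv() -> bool:
    # Inside a venv, sys.prefix points at the environment while
    # sys.base_prefix still points at the base installation
    return sys.prefix != sys.base_prefix


print("running in a virtual environment:", in_virtualenv())
```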
License
JAX is licensed under Apache License 2.0.
Usage
To use the default version on Puhti or Mahti, initialize it with:
module load jax
To access CSC-installed JAX on LUMI:
module use /appl/local/csc/modulefiles/
module load jax
Please note that the JAX modules already include the corresponding CUDA and cuDNN or ROCm libraries, so there is no need to load any cuda, cudnn, or rocm modules separately!
To list all available versions of JAX:
module avail jax
The JAX modules include several libraries from the JAX ecosystem (e.g. Haiku, Flax, and Objax). To check the exact packages and versions included in the loaded module you can run:
list-packages
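list-packages is the CSC-provided way to inspect the module. From inside a Python script, the installed versions of specific packages can also be queried with the standard-library importlib.metadata, for example to log them alongside training results. The distribution names below (jax, jaxlib, flax, dm-haiku, objax) are the usual PyPI names and are assumptions about what a given module contains; packages that are not installed are reported as None.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    # Look up each distribution's installed version; None if it is not installed
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for pkg, ver in installed_versions(["jax", "jaxlib", "flax", "dm-haiku", "objax"]).items():
    print(f"{pkg}: {ver or 'not installed'}")
```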
Note
The login nodes are not intended for heavy computing; please use Slurm batch jobs instead. See our instructions on how to use the batch job system.
Note
Please do not read a huge number of files from the shared file system. Use the fast local disk or package your data into larger files instead. See the Data storage section in our machine learning guide for more details.
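One simple way to package many small files into a single larger file is an uncompressed tar archive, which can then be read member by member without unpacking it onto the shared file system. The sketch below uses only the standard-library tarfile module; the helper names pack_directory and iter_members are hypothetical, and the machine learning guide may recommend other formats for specific frameworks.

```python
import os
import tarfile
from typing import Iterator, Tuple


def pack_directory(src_dir: str, archive_path: str) -> None:
    # Bundle every file under src_dir (recursively) into one uncompressed tar archive
    with tarfile.open(archive_path, "w") as tar:
        tar.add(src_dir, arcname=os.path.basename(src_dir.rstrip(os.sep)))


def iter_members(archive_path: str) -> Iterator[Tuple[str, bytes]]:
    # Stream (name, contents) pairs sequentially without extracting to disk
    with tarfile.open(archive_path, "r") as tar:
        for member in tar:
            if member.isfile():
                fh = tar.extractfile(member)
                yield member.name, fh.read()
```

Reading the archive sequentially means the shared file system sees one large file being read instead of many small metadata-heavy file opens.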