This week I’ve been setting up a coding environment for deep learning development for my project on Generative Audio. I’m using the NYU HPC for compute resources and for the flexibility to set up custom development environments.
I am comfortable with PyTorch and have wanted to use JAX in my personal projects for some time. As a brief introduction, JAX is a numerical computing framework for machine learning research and scientific computing. At the highest level, one might view JAX as GPU-accelerated NumPy, but it is also more than that. Here are the essential highlights of the JAX features that pulled me towards using it (a short sketch follows the list):
- GPU utilization for computations with multi-dimensional arrays
- Automatic differentiation
- Just-In-Time (JIT) compilation for dynamic code generation at run-time
- Vectorization for processing mini-batches
- Functional programming patterns, i.e. eliminating side effects via pure functions
- Deep Learning and Optimization libraries like Flax and Optax
- Backend for the probabilistic programming language NumPyro for Bayesian ML
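To make a few of these concrete, here is a minimal sketch of my own (a toy loss, not code from this project) composing `jax.grad`, `jax.jit`, and `jax.vmap` over a pure function:

```python
import jax
import jax.numpy as jnp

# A pure function: its output depends only on its inputs, with no side effects.
def mse(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

grad_mse = jax.jit(jax.grad(mse))                  # autodiff, then JIT-compile
batched_mse = jax.vmap(mse, in_axes=(None, 0, 0))  # vectorize over a mini-batch axis

key = jax.random.PRNGKey(0)
w = jnp.zeros(3)
x = jax.random.normal(key, (32, 3))   # mini-batch of 32 examples
y = jax.random.normal(key, (32,))

print(grad_mse(w, x, y))     # gradient of the scalar loss w.r.t. w
print(batched_mse(w, x, y))  # 32 per-example losses
```

The transformations compose freely, which is what makes the NumPy-like surface of JAX feel so flexible in practice.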
That last point about NumPyro is the main selling point for me: ideally I would have a unified computing environment for both my deep learning and probabilistic modeling experiments.
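For instance, a NumPyro model is itself just JAX code and runs on the same accelerator stack; here is a minimal sketch (a toy Bayesian regression of my own, not part of the project):

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# A toy Bayesian linear regression; inference runs on the same JAX backend.
def model(x, y=None):
    w = numpyro.sample("w", dist.Normal(0.0, 1.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    numpyro.sample("obs", dist.Normal(w * x, sigma), obs=y)

x = jnp.arange(5.0)
y = 2.0 * x + 0.1

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), x=x, y=y)
mcmc.print_summary()
```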
Unfortunately, this unification is not fully possible in the current world of deep learning frameworks, and researchers are forced to make various tradeoffs in their environments. As I am still learning and mastering my tools, JAX/Flax/NumPyro is my current choice, but it is by no means a perfect solution.
Indeed, PyTorch and TensorFlow both contain tried-and-true modules for data processing pipelines that can easily use multiprocessing for data and model parallelism when training on multiple GPUs. However, users have faced issues trying to use the PyTorch dataloader’s multiprocessing capabilities with multiple GPUs.
For this reason, I chose NVIDIA’s DALI for data loading and processing, which supports audio processing pipelines and has native support for JAX and Flax. Below is a summary of the packages I will be using for this project:
- `jax`: autodiff backend
- `flax`: NN library
- `optax`: optimization and gradient computations
- `dali`: data loading and pre-processing
My compute environment is Ubuntu 22.04.4 with miniconda using Python `3.11.9`, CUDA `12.3.2`, and cuDNN `9.0.0`. The first hurdle came when installing the most recent version of JAX, `0.4.31`, which requires `cudnn >= 9.1`; I was faced with the following error:
```
CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8907
Minimum supported: 9100
Installed version: 8907
The local installation version must be no lower than 9100..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
```
The solution (as in the above link) is to downgrade to `jax <= 0.4.28`, but for some reason my conda environment could not resolve that version and required that I find the correct CUDA build of `jaxlib` on the `conda-forge` channel by running:

```bash
conda search jaxlib
```

and selecting the respective build for `cuda==12.0`.
I also could not resolve these issues using the `classic` conda solver and switched to `libmamba`. Finally, I needed to set the `CONDA_OVERRIDE_CUDA` environment variable in order to use the MPI binaries available on NYU HPC, per JAX’s conda installation instructions:
```bash
CONDA_OVERRIDE_CUDA="12.0" conda install jaxlib=0.4.28=cuda120py311had15d7a_200 jax cuda-nvcc -c conda-forge -c nvidia --solver=libmamba
```
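With those pins in place, a quick sanity check that JAX actually sees the GPU; the expected version below reflects my pin above:

```python
import jax

print(jax.__version__)  # expect 0.4.28 after the pinned install
print(jax.devices())    # expect one or more CUDA devices, not just CPU
```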
The next hurdle arose due to the following error thrown by `chex` (a testing framework utilized by `optax`):

```
AttributeError: module 'jax.interpreters.xla' has no attribute 'DeviceArray'
```
When installing `optax` naively using `conda install optax -c conda-forge`, conda downgraded `chex` to a version that is incompatible with the selected `jax==0.4.28`. My workaround was to run the following command, which upgrades `chex` and uses `--freeze-installed` to ensure that installing `optax` doesn’t downgrade `chex` automatically:

```bash
conda install chex=0.1.86 optax=0.2.2 flax=0.8.5 -c conda-forge --override-channels --freeze-installed
```
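As a smoke test that the pinned `chex`/`optax`/`flax` versions really do cooperate with `jax==0.4.28`, a single training step exercises all three; this is a throwaway sketch of my own, not project code:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

# A tiny MLP defined with flax's linen API.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(nn.relu(nn.Dense(16)(x)))

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

opt = optax.adam(1e-3)
opt_state = opt.init(params)

def loss_fn(p, x, y):
    return jnp.mean((model.apply(p, x) - y) ** 2)

# One optimization step: grad -> optax update -> apply updates.
grads = jax.grad(loss_fn)(params, jnp.ones((4, 8)), jnp.zeros((4, 1)))
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```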
There is no NVIDIA-supported conda channel with pre-built binaries for DALI, so I reverted to the pip installation:

```bash
pip install --extra-index-url https://pypi.nvidia.com --upgrade nvidia-dali-cuda120
```
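To close the loop, here is a rough sketch of what the DALI-to-JAX handoff looks like, as I understand DALI’s JAX plugin; the `data/audio` path is a placeholder for my dataset, and the pipeline parameters are illustrative rather than final:

```python
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.jax import DALIGenericIterator

# Read and decode audio files; "data/audio" is a placeholder dataset path.
@pipeline_def(batch_size=32, num_threads=4, device_id=0)
def audio_pipeline():
    encoded, label = fn.readers.file(file_root="data/audio", name="Reader")
    audio, sample_rate = fn.decoders.audio(encoded, dtype=types.FLOAT, downmix=True)
    audio = fn.pad(audio)  # pad variable-length clips so they batch together
    return audio, label

# The JAX plugin yields batches as dicts of JAX arrays.
iterator = DALIGenericIterator([audio_pipeline()], ["audio", "label"], reader_name="Reader")
batch = next(iterator)
print(batch["audio"].shape)
```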