Problem scenario
You want to test JAX. What should you do?
Answer
1. Activate a virtual environment (virtualenv). If you need one, see "How Do You Install and Create a Virtual Environment?"
2. Run this: $ pip3 install jax
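3. Run this: $ pip3 install jaxlib (recent jax releases install jaxlib automatically, so pip may simply report that it is already satisfied)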
4. Run this: $ python3
5. Run this (without the "> "): > import jax.numpy as contint
6. Run this (without the "> "): > foobar = contint.ones((4000, 4000))
7. You are done. If no GPU or TPU was found (e.g., your server has only regular CPUs), the commands above may print a warning, and JAX falls back to the CPU.
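If you would rather run a script than type into the interactive session, steps 4 to 6 can be sketched as the small Python file below. It is a minimal sketch that assumes jax and jaxlib were installed as in steps 2 and 3. The function name `smoke_test` is hypothetical. `jax.devices()` and `jax.default_backend()` are standard JAX calls that report which hardware JAX is using.

```python
# Script version of steps 4-6: import JAX, allocate the array,
# and report which backend JAX selected.
import importlib.metadata

import jax
import jax.numpy as contint  # same alias as in step 5


def smoke_test(n: int = 4000) -> dict:
    """Hypothetical helper: run a small JAX computation and return diagnostics."""
    foobar = contint.ones((n, n))
    # Wait for the computation to finish so any device error surfaces here
    total = float(foobar.sum().block_until_ready())
    return {
        "jax_version": importlib.metadata.version("jax"),
        "jaxlib_version": importlib.metadata.version("jaxlib"),
        "backend": jax.default_backend(),  # e.g. "cpu" on a CPU-only server
        "devices": [str(d) for d in jax.devices()],
        "shape": foobar.shape,
        "sum": total,
    }


if __name__ == "__main__":
    for key, value in smoke_test().items():
        print(f"{key}: {value}")
```

On a server with only regular CPUs, the backend line should read `cpu`. The sum is printed so that you can see the computation actually ran, not just that the import worked.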