Article
11 Min Read

Google JAX: What It Is and How To Get Started

Discover what Google JAX is and how it improves machine learning tasks with JIT compilation, GPU/TPU accelerators, and automatic differentiation.


Google JAX lets developers build high-performance machine learning and deep learning applications. It works like NumPy but offers more features, including central processing unit (CPU), tensor processing unit (TPU), and graphics processing unit (GPU) support. And if you’re an avid NumPy user, JAX can compile your NumPy-style functions to run on hardware accelerators.

Another benefit of Google JAX is that it’s written in Python. Python has a simple syntax and is easy to learn and master, even for beginners. So you can install and import JAX in your projects, just like any other Python dependency. This usability, among its other features, makes JAX useful for scientific and machine learning research.

In this article, we’ll discuss the Google JAX library, including its core features, how it fits into the machine learning landscape, and how it compares to NumPy. Finally, we’ll highlight how JAX works in real-world applications.

What is Google JAX?

Google JAX is a Python library for numerical computing that helps you solve math problems and create algorithms for machine learning research. It’s open-source, so anyone can use it for free.

JAX is becoming more popular in the field of artificial intelligence (AI) because it has better features than older libraries like NumPy. For example, it supports both forward-mode and reverse-mode (backpropagation) automatic differentiation, which are important for training neural networks efficiently.

Plus, JAX works with XLA (Accelerated Linear Algebra), which compiles and optimizes models for CPUs, GPUs, and TPUs. Since JAX is made with Python, it’s compatible with many other AI ecosystem libraries, making it easier for developers to build full-scale machine learning applications.

JAX also has good documentation that you can use if you need help. There’s a growing community of JAX users, too, so you can connect with them to learn more about how to use it.

Besides machine learning research, JAX is also useful for scientific computing, reinforcement learning, processing big datasets, and making generative models. It has a lot of potential in different parts of the Big Data and AI ecosystems.

Core JAX features

JAX is a robust library that brings numerous benefits to your machine learning projects. Below, we discuss some of its core features.

1. Automatic differentiation

JAX supports automatic differentiation (autodiff) out of the box. Automatic differentiation computes the derivatives of numerical functions for you, rather than requiring you to derive and code them by hand. By handling gradients automatically, JAX facilitates gradient-based optimization, which underpins most machine learning algorithms and neural networks.

Previously, NumPy code could be differentiated automatically using the Autograd library, but Autograd is no longer actively developed. JAX provides a newer and more efficient way of computing gradients. Using grad(), you can quickly differentiate functions and return results, simplifying complex calculations.

To use the grad function, import JAX in your Python project. Next, access the grad() function from the JAX library as jax.grad(). You can then pass an input value (in this case, a float) to the resulting gradient function. The code for this step is below:

--CODE language-markup line-numbers--
import jax
grad_tan = jax.grad(jax.numpy.tanh)
print(grad_tan(0.8))

JAX’s grad() function is much more concise than the Autograd equivalent, which requires more boilerplate code:

--CODE language-markup line-numbers--
import autograd.numpy as np  # Thinly wrapped NumPy
from autograd import grad

def tanh(x):
    y = np.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)
print(grad_tanh(0.8))


In both cases, the output is approximately 0.5590551, although JAX prints fewer digits because it defaults to 32-bit floating point.
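Because grad() returns an ordinary Python function, it also composes with itself: a second derivative is just grad(grad(f)). A minimal sketch:

```python
import jax

def f(x):
    return x ** 3  # f'(x) = 3x^2, f''(x) = 6x

df = jax.grad(f)     # first derivative
d2f = jax.grad(df)   # second derivative: transformations compose

print(df(2.0))   # 3 * 2^2 = 12.0
print(d2f(2.0))  # 6 * 2 = 12.0
```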

2. JIT compilation

JAX supports just-in-time compilation (JIT), which enables faster code execution. This can boost the speed of your machine learning applications.

Specifically, JAX provides the @jit decorator, which traces your function and compiles it with XLA (Accelerated Linear Algebra) into optimized machine code, fusing operations for faster execution.

For example, we wrote a simple program to calculate a sum of squares and used the @jit decorator to have JAX compile it to machine code for faster execution:

--CODE language-markup line-numbers--
import jax.numpy as jnp
from jax import jit
import time

@jit  # Using the JIT decorator
def sum_of_squares_withjit(x):
    return jnp.sum(x**2)

# Example usage
in_array = jnp.arange(2_000_000)
starttime = time.time()
resultjit = sum_of_squares_withjit(in_array).block_until_ready()
endtime = time.time()
print("JIT compilation result:", resultjit)
print("JIT execution time:", endtime - starttime, "seconds")

Below, you’ll see the same program but without JIT implementation:

--CODE language-markup line-numbers--
import jax.numpy as jnp
import time

def sum_of_squares_withnojit(y):
    return jnp.sum(y**2)

in_array = jnp.arange(2_000_000)
starttime = time.time()
resultnojit = sum_of_squares_withnojit(in_array).block_until_ready()
endtime = time.time()
print("Result without JIT:", resultnojit)
print("Without JIT execution time:", endtime - starttime, "seconds")

When we ran the two programs, the execution time with JIT compilation was 0.051 seconds, versus 0.063 seconds without. Note that the first call to a jitted function includes compilation time; subsequent calls are typically much faster. While the difference is minimal for this small example, it showcases JIT’s advantage in speeding up larger programs.
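Timing jitted code accurately takes two precautions: warm the function up first (the initial call includes compilation), and block on the result, since JAX dispatches work asynchronously. A minimal sketch:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def sum_of_squares(x):
    return jnp.sum(x ** 2)

x = jnp.arange(100)

# Warm-up call: triggers tracing and XLA compilation
sum_of_squares(x).block_until_ready()

start = time.time()
# block_until_ready() forces JAX's asynchronous dispatch to finish
result = sum_of_squares(x).block_until_ready()
print("Compiled call took:", time.time() - start, "seconds")
print("Result:", result)  # sum of 0^2..99^2 = 328350
```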

3. Accelerators and high performance

JAX runs on GPUs and TPUs whenever they’re available, for enhanced performance. TPUs are optimized for large-scale matrix operations and support training and inference out of the box, while GPUs provide massive parallelism that accelerates task execution.

But if you don’t have hardware accelerators like GPUs and TPUs, JAX has a fallback feature, enabling it to run on CPUs. This compatibility makes JAX suitable for creating machine learning applications for different hardware configurations.

On CPUs, JAX relies on techniques like multithreading, where multiple threads execute tasks concurrently to keep the processor busy. This helps keep systems responsive under heavy loads. Even so, CPU performance lags well behind accelerators for intensive computations such as large matrix workloads.
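You can check at runtime which backend and devices JAX has picked up; a quick sketch:

```python
import jax

# List every device JAX can see; falls back to the CPU
# automatically when no GPU or TPU is present
print(jax.devices())

# Reports the active backend: "cpu", "gpu", or "tpu"
print(jax.default_backend())
```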

JAX and machine learning

Machine learning helps computers use data input to do tasks they weren’t originally programmed for. To do this, computers use numerical functions and math, like calculus and statistics, to look at data and find unique patterns, relationships, and features. This is where JAX comes in, making it easier to create efficient machine learning algorithms.

Next, we discuss how JAX is used in machine learning research and compare it to other deep learning frameworks like TensorFlow and PyTorch.

Neural networks and deep learning

JAX was made with large machine learning and deep learning projects in mind. It supports automatic differentiation right out of the box using the grad() function. This makes it faster to calculate gradients for backpropagation (reverse mode differentiation) and forward differentiation, which are important techniques for training neural network models.
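As an illustration, here's a hedged sketch of how grad() drives a basic gradient-descent loop for a toy linear model. The data, learning rate, and iteration count below are made up for demonstration purposes:

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Mean squared error of a linear model w * x + b
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)

# Toy data generated from y = 2x + 1
x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0

params = (0.0, 0.0)  # initial (w, b)
lr = 0.1
for _ in range(200):
    # jax.grad differentiates the loss with respect to params
    grads = jax.grad(loss)(params, x, y)
    params = tuple(p - lr * g for p, g in zip(params, grads))

print(params)  # should approach (2.0, 1.0)
```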

JAX’s JIT feature, along with the XLA compiler, also helps code compile and run more quickly. In the core features section, we saw how JIT compilation sped up a simple sum-of-squares program. This feature can bring even greater performance improvements in neural networks, where enormous numbers of calculations must happen for models to process data.

Additionally, JAX supports auto-vectorization through the vmap() function. Instead of writing explicit Python loops over arrays, vmap() transforms a function written for a single input into one that operates over a whole batch, pushing the loop down into JAX’s fast primitive operations for better performance.
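For instance, vmap() can batch a per-example function with no explicit loop; a minimal sketch:

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Written for a single pair of vectors
    return jnp.dot(a, b)

# vmap maps dot over the leading axis of both inputs
batched_dot = jax.vmap(dot)

a = jnp.ones((8, 3))
b = jnp.ones((8, 3))
out = batched_dot(a, b)
print(out.shape)  # (8,) - one dot product per row
```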

And since JAX works with CPUs and hardware accelerators like GPUs and TPUs, you can easily use neural networks on different hardware setups, including edge devices, mobile devices, and the cloud.

TensorFlow and PyTorch

TensorFlow and PyTorch are popular frameworks for developing deep learning applications. Like JAX, these platforms are also developed using Python, meaning they use simple syntax and offer a huge ecosystem of AI-related libraries and dependencies.

TensorFlow, PyTorch, and JAX all support hardware accelerators like TPUs and GPUs. However, TensorFlow has more mature TPU support, since TPUs were originally designed around TensorFlow workloads. It also supports a wide range of applications, including self-driving cars, search rankings, drug discovery, and computer vision.

However, TensorFlow and PyTorch are more comprehensive and may have a steeper learning curve compared to JAX. TensorFlow and PyTorch also have larger communities of users and developers due to their longer history of use.

JAX vs. NumPy

JAX and NumPy have many useful functions for machine learning and scientific computing. But NumPy came out first, so it helped shape many of JAX’s functions and features. Even though they share some functions, there are a few key differences between JAX and NumPy.

First, NumPy doesn’t have built-in support for automatic differentiation. This means machine learning engineers have to manually optimize gradients, which makes the development process more complicated. JAX, on the other hand, has automatic differentiation built-in, making gradient computation easier.

Second, JAX has built-in support for functional primitives like map, scan, and reduce (exposed through the jax.lax module). NumPy doesn’t offer direct equivalents, so you have to write more code to achieve the same results.
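For example, jax.lax.scan threads a carry value through a sequence; a small sketch computing a running sum:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # scan calls this once per element: returns (new carry, per-step output)
    carry = carry + x
    return carry, carry

total, cumulative = jax.lax.scan(step, 0.0, jnp.arange(5.0))
print(total)       # 10.0 (final carry)
print(cumulative)  # [ 0.  1.  3.  6. 10.] (stacked per-step outputs)
```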

Third, regular NumPy is mainly designed to run on CPUs. You’ll need to use other libraries to take advantage of hardware accelerators. But JAX works with XLA, allowing users to run their projects on CPUs, GPUs, and TPUs. By using hardware accelerators like GPUs, JAX makes it possible to create high-performing machine learning applications.
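Because jax.numpy mirrors the NumPy API, much NumPy code can be ported by changing an import. A small sketch comparing the two:

```python
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0.0, 1.0, 4)
x_jx = jnp.linspace(0.0, 1.0, 4)  # same API; runs on the default JAX device

# The results match (JAX defaults to 32-bit floats, NumPy to 64-bit)
print(float(np.sum(x_np ** 2)))
print(float(jnp.sum(x_jx ** 2)))
```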

Advanced JAX features

JAX also has some advanced features that make it work better overall:

  • Vectorization with vmap and pmap. You need sufficient hardware and resources to handle the large calculations involved in building and training neural networks. JAX improves performance through auto-vectorization with vmap(), which replaces explicit loops with batched primitive operations that are faster to process. The built-in pmap() function runs a computation in parallel across multiple XLA devices, such as TPU cores and GPUs, making large calculations more efficient.
  • XLA (Accelerated Linear Algebra). XLA compiles your NumPy-style JAX code so it can run on hardware accelerators like TPUs and GPUs. When combined with just-in-time (JIT) compilation, XLA can optimize models and speed up calculations.
  • Function transformations. JAX treats operations like grad(), jit(), vmap(), and pmap() as composable function transformations, making it easy to modify and combine numerical code efficiently.
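These transformations also compose with one another. For example, a sketch combining jit(), vmap(), and grad() to compile a batched derivative:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2  # scalar-in, scalar-out

# grad: derivative of f; vmap: apply it over a batch; jit: compile the result
batched_grad_f = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(0.0, 1.0, 5)
# Analytically, d/dx sin(x)^2 = 2*sin(x)*cos(x) = sin(2x)
print(batched_grad_f(xs))
```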

JAX in real-world applications

From scientific computing to machine learning research, JAX is used in many different areas, helping people and organizations create high-performing neural networks. Real-world uses of JAX include:

  • DeepMind. DeepMind uses JAX to improve its research, especially when training and testing AI systems. With JAX, DeepMind can also develop and scale AI models faster, reducing the time it takes to bring different AI products to market. JAX’s features, like automatic differentiation, auto-vectorization, and JIT compilation, allow DeepMind researchers to build efficient algorithms for better experimentation and improved performance.
  • Reinforcement learning. This type of machine learning needs a lot of computation to train neural agents to do specific tasks. Once again, JAX’s support for JIT, XLA, and automatic differentiation makes it possible to train neural models faster using methods like gradient descent.
  • High-performance computing. Whether it’s scientific or numerical computing, machine learning applications need efficient computation for better performance. JAX improves computation performance through auto differentiation, vectorization, XLA, and JIT compilation features.
  • Drug discovery. JAX’s efficient computation and ability to handle large-scale data make it valuable for drug discovery. Researchers can use JAX to develop and train machine learning models that help identify potential drug candidates, predict their properties, and optimize their structures, accelerating the drug discovery process and reducing costs.

Getting started with JAX

To get started with the JAX Python Library, you first need to install Python on your computer. If you have Python installed, you can run the following command in the terminal to install JAX:

--CODE language-markup--
pip install jax

Alternatively, you can access JAX through an online development environment like Google Colab, which comes with JAX pre-installed.

In Google Colab, you can create a notebook and import JAX into your project using the following command:

--CODE language-markup--
import jax

For example, we wrote the following program to calculate the sum of two numbers stored as JAX arrays. We imported JAX and jax.numpy, then used the add() function to compute the sum. See below:

--CODE language-markup line-numbers--
import jax.numpy as jnp

def calculate_sum(a, b):
    x = jnp.add(a, b)
    return x  # Returns the sum

array_one = jnp.array(20)
array_two = jnp.array(5)
result = calculate_sum(array_one, array_two)
print("The sum is:", result)

The program above shows you the basics of how to import JAX and use its functions in your project. But in typical machine learning projects, you might need more specialized tools, such as Flax for building neural networks or CUDA-enabled JAX builds for GPU support. Take a look at the official JAX documentation to learn more about how you can use other functions in your work.

Community and resources

JAX is open source, meaning you can access and use it for free in your projects. You can go through JAX source code and documentation on GitHub to see how it works behind the scenes. Following open issues and pull requests on GitHub can also allow you to understand issues other developers are experiencing while using JAX, and how to best navigate them.

And as JAX continues to grow in popularity, you’ll find it easier to connect with peers and mentors on platforms like Reddit and Stack Overflow for support and continuous learning.

Find top machine learning experts on Upwork

Machine learning lets you create apps that can do things like work with natural language, detect objects, analyze data, and more. JAX, a new tool in the deep learning ecosystem, provides better and more efficient ways to create high-performing neural networks and models. Techniques like JIT compilation and XLA allow your project to run faster and on different platforms.

Even though JAX is incredibly powerful, people need the right technical skills to make the most of what it can do. Understanding how machine learning works takes time, especially for beginners. Consider working with machine learning experts on Upwork to help you integrate AI into your workflow.

And if you’re an expert looking for work, Upwork can connect you with different machine learning jobs to help grow your portfolio. Get started today!


Author Spotlight

The Upwork Team

Upwork is the world’s work marketplace that connects businesses with independent talent from across the globe. We serve everyone from one-person startups to large, Fortune 100 enterprises with a powerful, trust-driven platform that enables companies and talent to work together in new ways that unlock their potential.
