Nx (Numerical Elixir) is now publicly available

  • José Valim
  • February 18th, 2021
  • nx, gpu, exla

Sean Moriarity and I are glad to announce that the project we have been working on for the last 3 months, Nx, is finally publicly available on GitHub. Our goal with Nx is to provide the foundation for Numerical Elixir.

In this blog post, I am going to outline the work we have done so far, some of the design decisions, and what we are planning to explore next. If you are looking for other resources to learn about Nx, you can hear me unveiling Nx on the ThinkingElixir podcast.


Nx is a multi-dimensional tensor library for Elixir with multi-staged compilation to the CPU/GPU. Let’s see an example:

iex> t = Nx.tensor([[1, 2], [3, 4]])
#Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>

As you can see, tensors have a type (s64) and a shape (2x2). Tensor operations are also done with the Nx module. For example, here is the softmax function:

iex> t = Nx.tensor([[1, 2], [3, 4]])
iex> Nx.divide(Nx.exp(t), Nx.sum(Nx.exp(t)))
#Nx.Tensor<
  f64[2][2]
  [
    [0.03205860328008499, 0.08714431874203257],
    [0.23688281808991013, 0.6439142598879722]
  ]
>
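
To see where these numbers come from, the same arithmetic can be sketched in plain Elixir over a flat list, without Nx. The SoftmaxCheck module is a hypothetical helper written only for this illustration and is not part of Nx:

```elixir
defmodule SoftmaxCheck do
  # Exponentiate every value, then divide each exponential by their sum.
  def softmax(values) do
    exps = Enum.map(values, &:math.exp/1)
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end
end

# The 2x2 tensor from the example, flattened row by row.
SoftmaxCheck.softmax([1, 2, 3, 4])
```

Read row by row, the four results agree with the tensor output up to floating-point rounding, and they sum to 1.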

The high-level features in Nx are:

  • Typed multi-dimensional tensors, where the tensors can be unsigned integers (u8, u16, u32, u64), signed integers (s8, s16, s32, s64), floats (f32, f64) and brain floats (bf16);

  • Named tensors, allowing developers to give names to each dimension, leading to more readable and less error-prone codebases;

  • Automatic differentiation, also known as autograd. The grad function provides reverse-mode differentiation, useful for simulations, training probabilistic models, etc;

  • Tensor backends, which enable the main Nx API to be used to manipulate binary tensors, GPU-backed tensors, sparse matrices, and more;

  • Numerical definitions, known as defn, which provide multi-stage compilation of tensor operations to multiple targets, such as highly specialized CPU code or the GPU. The compilation can happen either ahead-of-time (AOT) or just-in-time (JIT) with a compiler of your choice.
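
As a small sketch of the first two features, the snippet below creates a tensor with an explicit type and named dimensions, then reduces along a dimension by name. It assumes Nx can be installed through Mix.install; the version requirement and the dimension names :rows and :cols are choices made for this illustration only.

```elixir
# Assumption: any reasonably recent Nx release works here.
Mix.install([{:nx, "~> 0.7"}])

# Request an explicit element type and give each dimension a name.
t = Nx.tensor([[1, 2], [3, 4]], type: {:f, 32}, names: [:rows, :cols])

Nx.type(t)   # {:f, 32}
Nx.shape(t)  # {2, 2}
Nx.names(t)  # [:rows, :cols]

# Reduce along a dimension by name rather than by position.
column_sums = Nx.sum(t, axes: [:rows])
Nx.to_flat_list(column_sums)
```

Referring to :rows instead of axis 0 keeps the reduction correct even if the dimensions are later reordered.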

For Python developers, Nx currently takes its main inspirations from NumPy and JAX, packaged into a single unified library.

Our initial efforts have focused on the underlying abstractions. For example, while Nx implements dense tensors out-of-the-box, we also want the same high-level API to be valid for sparse tensors. You should also be able to use all functions in the Nx module with tensors that are backed by Elixir binaries and with tensors that are stored directly in the GPU.

By ensuring the underlying tensor backend is ultimately replaceable, we can build an ecosystem of libraries on top of Nx, and allow end-users to experiment with different backends, hardware, and approaches to run their software on.

Nx’s mascot is the Numbat, a marsupial native to southern Australia. Unfortunately, numbats are endangered, with fewer than 1000 estimated to be left. If you are excited about Nx, consider donating to Numbat conservation efforts, such as Project Numbat and Australian Wildlife Conservancy.

Numerical definitions

One of the most important features in Nx is the numerical definition, called defn. Numerical definitions are a subset of Elixir tailored for numerical computing. Here is the softmax formula above, now written with defn:

defmodule Formula do
  import Nx.Defn

  defn softmax(t) do
    Nx.exp(t) / Nx.sum(Nx.exp(t))
  end
end

The first difference we see with defn is that Elixir’s built-in operators have been augmented to also work with tensors. Effectively, defn replaces Elixir’s Kernel with Nx.Defn.Kernel.
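
A minimal sketch of that operator rewriting, assuming Nx can be installed through Mix.install (the version requirement is an assumption) and using a hypothetical Ops module. The defn version uses plain operators, while the regular def version spells out the Nx calls that express the same computation:

```elixir
# Assumption: any reasonably recent Nx release works here.
Mix.install([{:nx, "~> 0.7"}])

defmodule Ops do
  import Nx.Defn

  # Inside defn, * and + come from Nx.Defn.Kernel and accept tensors.
  defn scale_and_shift(t) do
    t * 2 + 1
  end

  # The same computation in a regular function, with explicit Nx calls.
  def scale_and_shift_explicit(t) do
    Nx.add(Nx.multiply(t, 2), 1)
  end
end

input = Nx.tensor([1, 2, 3])
Nx.to_flat_list(Ops.scale_and_shift(input))
Nx.to_flat_list(Ops.scale_and_shift_explicit(input))
```

Both calls return the same values.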

However, defn goes even further. When using defn, Nx builds a computation with all of your tensor operations. Let’s inspect it:

defn softmax(t) do
  inspect_expr(Nx.exp(t) / Nx.sum(Nx.exp(t)))
end

Now when invoked, you will see this printed:

iex(3)> Formula.softmax(Nx.tensor([[1, 2], [3, 4]]))
#Nx.Tensor<
  f64[2][2]

  Nx.Defn.Expr
  parameter a                                 s64[2][2]
  b = exp [ a ]                               f64[2][2]
  c = exp [ a ]                               f64[2][2]
  d = sum [ c, axes: nil, keep_axes: false ]  f64
  e = divide [ b, d ]                         f64[2][2]
>
#Nx.Tensor<
  f64[2][2]
  [
    [0.03205860328008499, 0.08714431874203257],
    [0.23688281808991013, 0.6439142598879722]
  ]
>

This computation graph can also be transformed programmatically. The transformation is precisely how we implement automatic differentiation, also known as autograd, by traversing each node and computing its derivative:

defn grad_softmax(t) do
  grad(t, Nx.exp(t) / Nx.sum(Nx.exp(t)))
end
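
To give a feel for what "traversing each node" means, here is a toy sketch in plain Elixir. It is not how Nx implements grad: it differentiates a tiny tuple-based expression tree symbolically, and the ToyGrad module and its node shapes are hypothetical names made up for this illustration.

```elixir
defmodule ToyGrad do
  # Expression nodes are plain tuples:
  #   {:const, number}, {:param, name}, {:add, a, b}, {:mul, a, b}, {:exp, a}

  # d/2 returns the derivative of an expression with respect to parameter x.
  def d({:const, _}, _x), do: {:const, 0}
  def d({:param, x}, x), do: {:const, 1}
  def d({:param, _}, _x), do: {:const, 0}
  def d({:add, a, b}, x), do: {:add, d(a, x), d(b, x)}
  # Product rule.
  def d({:mul, a, b}, x), do: {:add, {:mul, d(a, x), b}, {:mul, a, d(b, x)}}
  # Chain rule for exp.
  def d({:exp, a} = node, x), do: {:mul, node, d(a, x)}

  # eval/2 computes the value of an expression given parameter values.
  def eval({:const, c}, _env), do: c
  def eval({:param, name}, env), do: Map.fetch!(env, name)
  def eval({:add, a, b}, env), do: eval(a, env) + eval(b, env)
  def eval({:mul, a, b}, env), do: eval(a, env) * eval(b, env)
  def eval({:exp, a}, env), do: :math.exp(eval(a, env))
end

# f(t) = t * exp(t), so f'(t) = exp(t) + t * exp(t).
expr = {:mul, {:param, :t}, {:exp, {:param, :t}}}
derivative = ToyGrad.d(expr, :t)
ToyGrad.eval(derivative, %{t: 2.0})
```

Because the derivative is itself an expression tree, it can be evaluated, inspected, or transformed again like any other graph.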

Finally, this computation graph can also be handed out to different compilers. As an example, we have implemented bindings for Google’s XLA compiler, called EXLA. We can ask the softmax function to use this new compiler with a module attribute:

@defn_compiler {EXLA, client: :host}
defn softmax(t) do
  Nx.exp(t) / Nx.sum(Nx.exp(t))
end

Once softmax is called, Nx.Defn will invoke EXLA to just-in-time compile a highly specialized version of the code, tailored to the tensor type and shape. By passing client: :cuda or client: :rocm, the code can be compiled for the GPU. For reference, here are some benchmarks of the function above when called with a tensor of one million random float values on different clients:

Name                       ips        average  deviation         median         99th %
xla gpu f32 keep      15308.14      0.0653 ms    ±29.01%      0.0638 ms      0.0758 ms
xla gpu f64 keep       4550.59        0.22 ms     ±7.54%        0.22 ms        0.33 ms
xla cpu f32             434.21        2.30 ms     ±7.04%        2.26 ms        2.69 ms
xla gpu f32             398.45        2.51 ms     ±2.28%        2.50 ms        2.69 ms
xla gpu f64             190.27        5.26 ms     ±2.16%        5.23 ms        5.56 ms
xla cpu f64             168.25        5.94 ms     ±5.64%        5.88 ms        7.35 ms
elixir f32                3.22      311.01 ms     ±1.88%      309.69 ms      340.27 ms
elixir f64                3.11      321.70 ms     ±1.44%      322.10 ms      328.98 ms

Comparison:
xla gpu f32 keep      15308.14
xla gpu f64 keep       4550.59 - 3.36x slower +0.154 ms
xla cpu f32             434.21 - 35.26x slower +2.24 ms
xla gpu f32             398.45 - 38.42x slower +2.44 ms
xla gpu f64             190.27 - 80.46x slower +5.19 ms
xla cpu f64             168.25 - 90.98x slower +5.88 ms
elixir f32                3.22 - 4760.93x slower +310.94 ms
elixir f64                3.11 - 4924.56x slower +321.63 ms

Here, keep indicates the tensor was kept on the device instead of being transferred back to Elixir. You can see the benchmark in the bench directory and find some examples in the examples directory of the EXLA project.

Compiling numerical definitions

Before moving forward, it is important for us to take a look at how numerical definitions are compiled. For example, take the softmax function:

defn softmax(t) do
  Nx.exp(t) / Nx.sum(Nx.exp(t))
end

One might think that Elixir takes the AST of the softmax function above and compiles it directly to the GPU. However, that’s not the case! Numerical definitions are first compiled to Elixir code that will emit the computation graph and this computation graph is then compiled to the GPU. The multiple stages go like this:

Elixir AST
-> compiles to .beam (Erlang VM bytecode)
   -> executes into defn AST
      -> compiles to GPU
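
The staging idea can be sketched with a toy in plain Elixir. This is only an analogy under stated assumptions, not Nx's implementation: the ToyStaging module, its tuple-based graph, and the closure-based "compiler" are hypothetical stand-ins, and the data is a flat list rather than a tensor.

```elixir
defmodule ToyStaging do
  # Stage 1: ordinary Elixir code. Running it computes no numbers;
  # it returns a description of the computation (the graph).
  def softmax_expr(t) do
    {:divide, {:exp, t}, {:sum, {:exp, t}}}
  end

  # Stage 2: a stand-in "compiler" that turns the graph into an
  # executable function over lists of numbers.
  def compile({:param, name}) do
    fn env -> Map.fetch!(env, name) end
  end

  def compile({:exp, a}) do
    fa = compile(a)
    fn env -> Enum.map(fa.(env), &:math.exp/1) end
  end

  def compile({:sum, a}) do
    fa = compile(a)
    fn env -> Enum.sum(fa.(env)) end
  end

  def compile({:divide, a, b}) do
    fa = compile(a)
    fb = compile(b)

    fn env ->
      denominator = fb.(env)
      Enum.map(fa.(env), &(&1 / denominator))
    end
  end
end

# Executing the definition yields a graph, not a result.
graph = ToyStaging.softmax_expr({:param, :t})

# The graph is compiled once and can then be run on concrete data.
run = ToyStaging.compile(graph)
run.(%{t: [1, 2, 3, 4]})
```

In this toy the second stage produces Elixir closures; in Nx that role is played by a compiler such as EXLA, which targets the CPU or GPU instead.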

This multi-stage programming is made possible thanks to Elixir macros. For example, when you see a conditional inside defn, that conditional looks exactly like Elixir conditionals, but it will be compiled to an accelerator:

defn softmax(t) do
  if Nx.any?(t) do
    ...
  else
    ...
  end
end

In a nutshell, defn provides us with a subset of Elixir for numerical computations that can be compiled to specific hardware, such as CPU, GPU, and other accelerators. All of this was possible without making changes or forking the language.

And while defn is a subset of the language, it is a considerable one. You will find support for:

  • Mathematical operators
  • Pipes (|>), module attributes, the access syntax (e.g. tensor[1][1..-1]), etc
  • Elixir macros constructs (imports, aliases, etc)
  • Control-flow with conditionals (both if and cond), loops (coming soon), etc
  • Transformations, an explicit mechanism to invoke Elixir code from a defn (which enables constructs such as grad)

And more coming down the road.

Why functional programming?

At this point, you may be wondering: is functional programming a good fit for numerical computing? One of the main concerns is that immutability can be expensive when working with large blobs of memory. And that’s a valid concern! In fact, when using the default tensor backend, tensors will be backed by Elixir binaries which are copied on every operation. That’s why it was critical for us to design Nx with pluggable backends from day one.
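
A short sketch of what that looks like in practice, assuming Nx can be installed through Mix.install (the version requirement is an assumption) and that the default binary-backed backend is in use:

```elixir
# Assumption: any reasonably recent Nx release works here.
Mix.install([{:nx, "~> 0.7"}])

t = Nx.tensor([1, 2, 3], type: {:u, 8})

# With the default backend, the tensor data is an ordinary Elixir binary.
Nx.to_binary(t)

# Operations return a new tensor; the original is never modified in place.
u = Nx.add(t, 1)

Nx.to_flat_list(t)  # [1, 2, 3]
Nx.to_flat_list(u)  # [2, 3, 4]
```

For three elements the copy is negligible; for tensors with millions of elements it is exactly the cost that other backends and compilers are meant to avoid.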

As we move to higher-level abstractions, such as numerical definitions, we will start to reap the benefits of functional programming.

For example, in order to build computation graphs, immutability becomes an indispensable tool both in terms of implementation and reasoning. The JAX library for Python, which has been one of the guiding lights for Nx design, also promotes functional and immutable principles:

JAX is intended to be used with a functional style of programming

JAX Docs

Unlike NumPy arrays, JAX arrays are always immutable

JAX Docs

Similarly, existing frameworks like Thinc.ai argue that functional programming can provide better abstractions and more composable building blocks for deep learning libraries.

We hope that, by exploring these concepts in a language that is functional by design, Elixir can bring new ideas and insights at a higher level.

What is next?

There is a lot of work ahead of us and we definitely cannot tackle all of it alone. Generally speaking, here are some broad areas the numerical computing community in Elixir should investigate in the long term:

  • Visual tools: such as plotting libraries and integration with notebooks for interactive programming

  • Machine learning tools: while Sean is already exploring some designs for neural networks, we will likely also see interest in tools for supervised learning (classification/regression), dimensionality reduction, clustering, etc. My hope is that those libraries can be implemented with defn, allowing them to benefit from custom backends and custom compilers

  • Nx: there is a lot to explore inside Nx itself, such as better support for linear algebra operations and perhaps even FFT. I am also looking forward to seeing how folks will experiment with backends that are optimized to work with tensors that exhibit certain properties, such as sparse tensors and Hermitian matrices

  • defn: while defn already supports grad, that’s just one of many transformations we can automatically perform. We could also support auto-batching (also known as vmap), inverses, Jacobian/Hessian matrices, etc

  • Integration: there are two ways we can speed up Nx tensors, either by using custom backends (eager) or by using custom compilers (lazy). There are many options we can consider here, such as libtorch and eigen as backends, and a growing list of tensor compilers. Since we aim to put Nx as the building block of the ecosystem, we hope that by integrating new compilers and backends, developers and researchers will have the option to experiment with many different performance and usage profiles

For now, we have created an Nx-related mailing list where we can coordinate those ideas and hold general discussions.

In the short term, Sean and I are working on features like tensor streaming, communication across devices, as well as AOT compilation. The latter might be particularly useful for Nerves. We are also investigating how to integrate dataframes directly into Nx, including defn support. By supporting dataframes, we hope to have a single library to tackle different steps of a machine learning pipeline, where everything can be inlined and compiled into a single GPU executable. For this, we are looking into xarray’s datasets and TensorFlow feature columns.

Given there is a lot to explore, we are also interested in feedback and experiences, especially missing features we should prioritize. You can find a list of other planned features in our issue tracker.

Happy computing!