Elixir and Machine Learning: Nx v0.1 released!

We are glad to announce Nx (Numerical Elixir) v0.1 has been released!

For those unfamiliar, Elixir is a dynamic, functional language for building scalable and maintainable applications. Elixir leverages the Erlang VM, known for running low-latency, distributed, and fault-tolerant systems.

Numerical Elixir is an effort, publicly unveiled almost a year ago, to bring Elixir to the world of numerical computing and machine learning. The foundation of this effort is a library called Nx, which brings multi-dimensional arrays (tensors) and just-in-time compilation of numerical Elixir code to both CPU and GPU. As we will see, the mixture of functional programming and tensor compilers provides an elegant and powerful abstraction for emitting highly specialized code.

In this blog post, we will discuss the current state of Nx, some of its upcoming features, and take a look at its growing ecosystem.

Nx’s mascot is the Numbat, a marsupial native to southern Australia. Unfortunately, the Numbat is endangered and it is estimated that fewer than 1000 are left. If you are excited about Nx, consider donating to Numbat conservation efforts, such as Project Numbat and the Australian Wildlife Conservancy.

Nx 101

Let’s start with a quick introduction to Nx by creating a two-dimensional tensor:

iex> t = Nx.tensor([[1, 2], [3, 4]])
#Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>

Tensors can hold unsigned integers (u8, u16, u32, u64), signed integers (s8, s16, s32, s64), floats (f16, f32, f64), and brain floats (bf16). Each dimension of a tensor can optionally be named.
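
For example, here is how to create a tensor with an explicit type and named dimensions (a small illustration of our own, with the output Nx’s inspect protocol prints for it):

iex> Nx.tensor([[1, 2], [3, 4]], type: {:u, 8}, names: [:rows, :cols])
#Nx.Tensor<
  u8[rows: 2][cols: 2]
  [
    [1, 2],
    [3, 4]
  ]
>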

Here is how to implement a numerically stable version of the softmax function using Nx. Subtracting the maximum entry before exponentiating keeps Nx.exp/1 from overflowing, without changing the result:

iex> t = Nx.tensor([[1, 2], [3, 4]])
iex> normalized = Nx.subtract(t, Nx.reduce_max(t))
iex> Nx.divide(Nx.exp(normalized), Nx.sum(Nx.exp(normalized)))
#Nx.Tensor<
  f32[2][2]
  [
    [0.032058604061603546, 0.08714432269334793],
    [0.23688282072544098, 0.6439142227172852]
  ]
>

The computations above happen in pure Elixir. However, you can plug in a custom backend, such as Torchx, and have the computation performed by state-of-the-art libraries such as LibTorch, on both CPU and GPU:

iex> Nx.default_backend(Torchx.Backend)
iex> t = Nx.tensor([[1, 2], [3, 4]])
iex> normalized = Nx.subtract(t, Nx.reduce_max(t))
iex> Nx.divide(Nx.exp(normalized), Nx.sum(Nx.exp(normalized)))
#Nx.Tensor<
  Torchx.Backend
  f32[2][2]
  [
    [0.032058604061603546, 0.08714432269334793],
    [0.23688282072544098, 0.6439142227172852]
  ]
>
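
Tensors can also be moved between backends explicitly. As a small illustration of our own, Nx.backend_transfer/2 copies the data out of the current backend (here, LibTorch’s memory) and back into Elixir’s default binary backend:

iex> Nx.backend_transfer(t, Nx.BinaryBackend)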

The full power of Nx comes from defn, which stands for numerical definitions. Numerical definitions are a subset of Elixir tailored for numerical computing:

defmodule MyModule do
  import Nx.Defn

  defn softmax(t) do
    normalized = t - Nx.reduce_max(t)
    Nx.exp(normalized) / Nx.sum(Nx.exp(normalized))
  end
end

Inside defn we can use Elixir’s regular operators and they are all translated to their equivalent tensor operations. You also have access to many of the language’s features and data types, such as macros, the beloved pipe operator, pattern matching, maps, and more.
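
For instance, here is a small sketch of our own (the dense/2 function is an illustration, not part of Nx) combining pattern matching on a tuple with the pipe operator inside defn:

defmodule MyLayers do
  import Nx.Defn

  # Pattern-match the {weights, bias} tuple in the function head,
  # then chain tensor operations with the pipe operator.
  defn dense({w, b}, x) do
    x
    |> Nx.dot(w)
    |> Nx.add(b)
  end
end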

When invoked, the code above takes the types and shapes of the arguments and compiles them into highly optimized code that runs on the CPU, the GPU, or even Cloud TPUs. For example, we can use Google’s XLA compiler via EXLA:

iex> Nx.Defn.default_options(compiler: EXLA, client: :cuda)
iex> MyModule.softmax(Nx.tensor([[1, 2], [3, 4]]))
#Nx.Tensor<
  f32[2][2]
  EXLA.DeviceBackend(cuda)
  [
    [0.032058604061603546, 0.08714432269334793],
    [0.23688282072544098, 0.6439142227172852]
  ]
>

For reference, here are some benchmarks of the function above when called with a tensor of one million random float values:

Name                       ips        average  deviation         median         99th %
xla gpu f32 keep      15308.14      0.0653 ms    ±29.01%      0.0638 ms      0.0758 ms
xla gpu f64 keep       4550.59        0.22 ms     ±7.54%        0.22 ms        0.33 ms
xla cpu f32             434.21        2.30 ms     ±7.04%        2.26 ms        2.69 ms
xla gpu f32             398.45        2.51 ms     ±2.28%        2.50 ms        2.69 ms
xla gpu f64             190.27        5.26 ms     ±2.16%        5.23 ms        5.56 ms
xla cpu f64             168.25        5.94 ms     ±5.64%        5.88 ms        7.35 ms
elixir f32                3.22      311.01 ms     ±1.88%      309.69 ms      340.27 ms
elixir f64                3.11      321.70 ms     ±1.44%      322.10 ms      328.98 ms

Comparison:
xla gpu f32 keep      15308.14
xla gpu f64 keep       4550.59 - 3.36x slower +0.154 ms
xla cpu f32             434.21 - 35.26x slower +2.24 ms
xla gpu f32             398.45 - 38.42x slower +2.44 ms
xla gpu f64             190.27 - 80.46x slower +5.19 ms
xla cpu f64             168.25 - 90.98x slower +5.88 ms
elixir f32                3.22 - 4760.93x slower +310.94 ms
elixir f64                3.11 - 4924.56x slower +321.63 ms
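
The output above is in Benchee’s format. For reference, here is a sketch of our own showing how such a comparison could be set up (the compiler options follow the examples earlier in this post):

# One million random 32-bit floats.
t = Nx.random_uniform({1_000_000}, type: {:f, 32})

# Pure Elixir run, on Nx's default binary backend.
Nx.Defn.default_options([])
Benchee.run(%{"elixir f32" => fn -> MyModule.softmax(t) end})

# EXLA-compiled run.
Nx.Defn.default_options(compiler: EXLA)
Benchee.run(%{"xla cpu f32" => fn -> MyModule.softmax(t) end})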

Nx and Machine Learning

We have spent the last few months maturing Nx towards Machine Learning and production use cases. Sean Moriarity has developed Axon, which we used to battle-test Nx and its automatic differentiation engine against several traditional and non-traditional neural networks.

For example, here is a Convolutional Neural Network, implemented with Axon, to train on and classify the CIFAR-10 dataset:

Axon.input(input_shape)
|> Axon.conv(32, kernel_size: {3, 3}, activation: :relu)
|> Axon.batch_norm()
|> Axon.max_pool(kernel_size: {2, 2})
|> Axon.conv(64, kernel_size: {3, 3}, activation: :relu)
|> Axon.batch_norm()
|> Axon.max_pool(kernel_size: {2, 2})
|> Axon.flatten()
|> Axon.dense(64, activation: :relu)
|> Axon.dropout(rate: 0.5)
|> Axon.dense(10, activation: :softmax)

You can find the whole example, including dataset download, training, and inference, here. You can also find examples for generative, structured, and other vision-related neural networks.
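
To give a feel for how such a model is trained, here is a hedged sketch of a training loop (train_data stands for a hypothetical stream of {input, label} batches; exact function signatures may vary across Axon versions):

model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.001))
|> Axon.Loop.run(train_data, epochs: 10, compiler: EXLA)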

To power the existing and upcoming functionality, we have brought many features to Nx. In particular:

  • We implemented streaming capabilities, which allow a program to be loaded into GPUs/TPUs while we stream batches of inputs to it. This can be useful for distributed learning and also for running inference efficiently in production.

  • We started working on a series of functions related to Linear Algebra under the Nx.LinAlg module, which are relevant for models that rely on matrix factorization.

  • We introduced while loops into numerical definitions, to support both static and dynamic unrolling of loops, which are handy in recurrent models (speech recognition, semantic parsing, sign language translation, etc.).

  • We added hooks to numerical definitions, which allow developers to stream data out of GPUs/TPUs as the computation happens. With this, you can debug systems, monitor the performance of models during training (think TensorBoard integration) and inference, and more. A short sketch of while loops and hooks follows this list.
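
To make the last two items concrete, here is a small sketch based on the examples in the Nx documentation, showing a while loop and a hook inside numerical definitions:

defmodule MyDefnFeatures do
  import Nx.Defn

  # A while loop inside defn: the loop state is the {factorial, x}
  # tuple and the loop body runs while x > 1.
  defn factorial(x) do
    {factorial, _} =
      while {factorial = 1, x}, Nx.greater(x, 1) do
        {factorial * x, x - 1}
      end

    factorial
  end

  # A hook streams intermediate values out of the device as the
  # computation runs; this one simply prints the intermediate sum.
  defn add_and_multiply(a, b) do
    sum = hook(a + b, fn t -> IO.inspect(t, label: "sum") end)
    sum * b
  end
end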

There is still a lot of work ahead of us and you can follow the issue trackers of both the Nx and Axon projects for more information.

The future of Nx

Over the last 10 months we have put a huge amount of work into making Nx the building block for numerical computing and machine learning in Elixir. The path we chose was not the only option available to us: we could, for example, have bound Elixir directly to an existing framework, such as LibTorch, instead of building our own foundation.

Such options are extremely useful, especially if you want to quickly put a system in production. However, our goals are also to:

  • make Elixir a suitable platform for new Machine Learning developments

  • fully leverage the power provided by the platform Elixir runs on, the Erlang VM

  • provide consistency and stability, especially when working on a domain that is still actively evolving

For those reasons, we chose to invest in Nx as its own foundation, agnostic to any particular framework. The road is definitely longer, but we believe the pay-off will be higher too!

Plus, we are not alone! Many folks have joined the Machine Learning Working Group from the Erlang Ecosystem Foundation to bring other important projects to life, such as:

  • Axon - Nx-powered Neural Networks for Elixir, shown in the previous section

  • Explorer - dataframes (series and tabular data) for Elixir. It runs on Rust’s Polars for amazing performance

  • Livebook - interactive and collaborative code notebooks for Elixir. Once you install Livebook, there are several example notebooks available. We are also planning to port many of Axon’s examples to notebooks; you can track them in the notebooks directory

  • Scidata - download and normalize datasets related to science

There are also exciting projects being developed outside of the working group, such as OpenCV bindings via evision and others.

Here is a peek at what we expect to see in the near future, within Elixir’s Machine Learning ecosystem:

  • Integration between ONNX and Axon, allowing developers to bring trained models from other platforms into Elixir and vice versa

  • Precompiled Explorer bindings, so developers can get started with dataframes in Elixir without needing the Rust toolchain installed on their machines

  • Desktop app versions of Livebook, making it easier than ever for any developer to get Elixir code up and running on their machines

  • Support for checkpointing in Nx’s automatic differentiation system. Checkpoints reduce memory usage at the cost of increased computation when calculating gradients, which is helpful when training large models

This is barely scratching the surface of what is possible. Here are some ideas to explore in the long term:

  • Support for other compilers and backends. Our bindings for Google XLA are quite complete and there is work in progress on LibTorch (contributions are welcome). We are also interested in exploring other options, such as Apache TVM.

  • Distributed training: in Machine Learning, “distributed” often stands for training across multiple GPUs. With Nx, we can mix the “distributed” meaning of Machine Learning with the “distributed” meaning of the Erlang VM, which is across multiple nodes.

  • Federated learning is a technique for training an algorithm across multiple edge devices. Federated training comes in different flavors, such as centralized - when a central server is responsible for aggregating and coordinating devices - and decentralized. Elixir and the Erlang VM can shine in several of these scenarios, thanks to orchestration capabilities born from telecommunications and thanks to projects like Nerves.

And there are definitely other possibilities we haven’t even considered yet. I hope this shares some of our vision, ideas, and goals. If you are excited about these new possibilities, we welcome you to use, enjoy, and contribute to many of the projects above, or perhaps even start your own!

Happy coding!