Elixir’s Nx (Numerical Elixir) ecosystem brings machine learning capabilities to the BEAM. You can now train neural networks, run inference, and deploy models—all within Elixir’s fault-tolerant, concurrent runtime.

The Nx Ecosystem

  • Nx: Numerical computing library (like NumPy)
  • EXLA: XLA compiler backend (GPU/TPU acceleration)
  • Axon: Neural network library (like Keras/PyTorch)
  • Bumblebee: Pre-trained transformer models
  • Scholar: Traditional ML algorithms

Getting Started

Add dependencies to mix.exs:

# Project dependencies for the examples in this article.
defp deps do
  [
    {:nx, "~> 0.7"},
    {:exla, "~> 0.7"},
    {:axon, "~> 0.6"},
    # The custom training loop calls Polaris.Optimizers / Polaris.Updates directly,
    # so Polaris is declared explicitly rather than relied on transitively.
    {:polaris, "~> 0.1"},
    {:bumblebee, "~> 0.5"},
    {:scholar, "~> 0.3"}
  ]
end

Configure EXLA as the default backend:

# config/config.exs
# Allocate tensors on EXLA by default, and JIT-compile `defn` functions with the
# EXLA compiler (otherwise defn uses Nx's default evaluator).
config :nx,
  default_backend: EXLA.Backend,
  default_defn_options: [compiler: EXLA]

Nx Basics: Tensor Operations

# Create tensors: a rank-1 vector and a 2x2 matrix
t1 = Nx.tensor([1, 2, 3, 4])
t2 = Nx.tensor([[1, 2], [3, 4]])

# Element-wise operations (the scalar 10 is broadcast across every element)
Nx.add(t1, 10)
# => #Nx.Tensor<s64[4] [11, 12, 13, 14]>

# Matrix multiplication
b = Nx.tensor([[5, 6], [7, 8]])
Nx.dot(t2, b)
# => #Nx.Tensor<s64[2][2] [[19, 22], [43, 50]]>

# Reductions (the mean of an integer tensor comes back as a float)
Nx.mean(t1)
# => #Nx.Tensor<f32 2.5>

# Broadcasting: a {3, 1} column plus a {3} row produces a {3, 3} matrix
Nx.add(Nx.tensor([[1], [2], [3]]), Nx.tensor([10, 20, 30]))
# => #Nx.Tensor<s64[3][3] [[11, 21, 31], [12, 22, 32], [13, 23, 33]]>

Defn: Compiled Numerical Functions

Use defn for JIT-compiled numerical functions:

defmodule MyMath do
  @moduledoc """
  Small numerical helpers written with `defn`, so they are compiled by whichever
  defn compiler is configured (for example EXLA).
  """
  import Nx.Defn

  @doc """
  Numerically stable softmax.

  The maximum is subtracted before exponentiating so large inputs do not overflow.
  By default the normalization runs over every element of `t`. Pass `axes: [-1]`
  to normalize each row of a batch independently.
  """
  defn softmax(t, opts \\ []) do
    opts = keyword!(opts, axes: nil)
    # keep_axes: true leaves size-1 dimensions in place so the reductions
    # broadcast back against `t` for any choice of axes.
    t_max = Nx.reduce_max(t, axes: opts[:axes], keep_axes: true)
    exp_t = Nx.exp(t - t_max)
    exp_t / Nx.sum(exp_t, axes: opts[:axes], keep_axes: true)
  end

  @doc "Linear model prediction: `x · w + b`."
  defn linear_regression_predict(x, w, b) do
    Nx.dot(x, w) + b
  end

  @doc "Mean squared error between `y_true` and `y_pred`."
  defn mse_loss(y_true, y_pred) do
    Nx.mean(Nx.pow(y_true - y_pred, 2))
  end
end

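When EXLA is configured as the defn compiler (for example with `default_defn_options: [compiler: EXLA]`), these functions are JIT-compiled through XLA and can run on CPU, GPU, or TPU.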

Building Neural Networks with Axon

Define a Model

defmodule MyModel do
  @moduledoc """
  Feed-forward classifier built with Axon.
  """

  @doc """
  Builds a multilayer perceptron with an input named `"input"` of shape `input_shape`.

  Each hidden layer is a ReLU dense layer followed by dropout. The output layer is
  a softmax over `num_classes`.

  ## Options

    * `:hidden_units` - sizes of the hidden dense layers, in order (default `[128, 64]`)
    * `:dropout_rate` - dropout rate applied after each hidden layer (default `0.3`)
  """
  def build_model(input_shape, num_classes, opts \\ []) do
    opts = Keyword.validate!(opts, hidden_units: [128, 64], dropout_rate: 0.3)
    input = Axon.input("input", shape: input_shape)

    hidden =
      Enum.reduce(opts[:hidden_units], input, fn units, acc ->
        acc
        |> Axon.dense(units, activation: :relu)
        |> Axon.dropout(rate: opts[:dropout_rate])
      end)

    Axon.dense(hidden, num_classes, activation: :softmax)
  end
end

# nil leaves the batch dimension unspecified
model = MyModel.build_model({nil, 784}, 10)

Training Loop

defmodule Trainer do
  @moduledoc """
  Supervised training using Axon's built-in training loop.
  """

  @doc """
  Trains `model` with categorical cross-entropy and Adam, tracking accuracy.

  `train_data` must be an enumerable of `{input, target}` tuples, which is the batch
  shape the loop built by `Axon.Loop.trainer/3` consumes. Returns the loop's final
  output (for a trainer loop, the trained model state).
  """
  def train(model, train_data, epochs \\ 10) do
    model
    |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
    |> Axon.Loop.metric(:accuracy)
    |> Axon.Loop.run(train_data, %{}, epochs: epochs)
  end
end

# Prepare data as a stream of {input, target} batches. The input map is keyed by
# the name given to Axon.input/2 ("input"); targets are passed through as-is.
train_data =
  Stream.zip(x_train_batches, y_train_batches)
  |> Stream.map(fn {x, y} -> {%{"input" => x}, y} end)

# Train
trained_state = Trainer.train(model, train_data, 20)

Custom Training Loops

defmodule CustomTrainer do
  @moduledoc """
  A hand-written training loop using `Axon.build/1` and a Polaris Adam optimizer.

  `data` is enumerated once per epoch, so it must be re-enumerable. Each batch is
  expected to be a map with `:input` (model input) and `:labels` (one-hot targets).
  """
  import Nx.Defn

  @doc """
  Trains `model` on `data` and returns `{params, optimizer_state}`.

  ## Options

    * `:epochs` - number of passes over `data` (default `10`)
    * `:learning_rate` - Adam learning rate (default `0.001`)
    * `:input_shape` - shape of the `:f32` template used to initialize parameters
      (default `{1, 784}`)

  Prints the mean batch loss after every epoch.
  """
  def train(model, data, opts \\ []) do
    epochs = Keyword.get(opts, :epochs, 10)
    lr = Keyword.get(opts, :learning_rate, 0.001)
    input_shape = Keyword.get(opts, :input_shape, {1, 784})

    {init_fn, predict_fn} = Axon.build(model)
    params = init_fn.(Nx.template(input_shape, :f32), %{})

    # A Polaris optimizer is an {init_fn, update_fn} pair of functions.
    {optimizer_init_fn, optimizer_update_fn} = Polaris.Optimizers.adam(learning_rate: lr)
    optimizer_state = optimizer_init_fn.(params)

    # Gradients are computed on traced numerical expressions, so the whole
    # optimization step is JIT-compiled once and reused for every batch.
    step_fn =
      Nx.Defn.jit(fn params, opt_state, input, labels ->
        train_step(params, opt_state, input, labels, predict_fn, optimizer_update_fn)
      end)

    # `//1` keeps the range empty when epochs is 0 instead of counting down.
    Enum.reduce(1..epochs//1, {params, optimizer_state}, fn epoch, {params, opt_state} ->
      {loss, params, opt_state} = train_epoch(data, params, opt_state, step_fn)

      IO.puts("Epoch #{epoch}, Loss: #{loss}")
      {params, opt_state}
    end)
  end

  # One pass over `data`. Returns {mean_batch_loss, params, opt_state};
  # the mean loss is 0.0 when `data` yields no batches.
  defp train_epoch(data, params, opt_state, step_fn) do
    {loss_sum, batch_count, params, opt_state} =
      Enum.reduce(data, {0.0, 0, params, opt_state}, fn batch,
                                                         {loss_sum, count, params, opt_state} ->
        {loss, params, opt_state} = step_fn.(params, opt_state, batch.input, batch.labels)
        {loss_sum + Nx.to_number(loss), count + 1, params, opt_state}
      end)

    mean_loss = if batch_count == 0, do: 0.0, else: loss_sum / batch_count
    {mean_loss, params, opt_state}
  end

  # A single optimization step. Only invoked while being traced by Nx.Defn.jit/2.
  defp train_step(params, opt_state, input, labels, predict_fn, update_fn) do
    {loss, grads} =
      Nx.Defn.Kernel.value_and_grad(params, fn p ->
        preds = predict_fn.(p, input)
        cross_entropy_loss(labels, preds)
      end)

    {updates, new_opt_state} = update_fn.(grads, opt_state, params)
    {loss, Polaris.Updates.apply_updates(params, updates), new_opt_state}
  end

  # Mean categorical cross-entropy over the batch; the 1.0e-7 offset avoids log(0).
  defnp cross_entropy_loss(labels, predictions) do
    -Nx.mean(Nx.sum(labels * Nx.log(predictions + 1.0e-7), axes: [-1]))
  end
end

Using Pre-trained Models with Bumblebee

Load and use transformer models:

# Load a text classification model and a tokenizer
{:ok, model_info} =
  Bumblebee.load_model({:hf, "distilbert-base-uncased-finetuned-sst-2-english"})

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "distilbert-base-uncased"})

serving =
  Bumblebee.Text.text_classification(model_info, tokenizer,
    defn_options: [compiler: EXLA]
  )

# Run the serving inline in the current process. Nx.Serving.batched_run/2 takes the
# name of a started serving process (see "Serving ML Models in Production"), not a
# serving struct.
Nx.Serving.run(serving, "This product is amazing!")
# => %{predictions: [%{label: "POSITIVE", score: 0.9998...}, ...]}

Text Generation

{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

# stream: true makes the serving emit text chunks as they are generated
serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
    defn_options: [compiler: EXLA],
    stream: true
  )

# Nx.Serving.run/2 executes the serving struct directly (batched_run/2 requires
# the name of a started serving process). With streaming enabled, the result is
# consumed chunk by chunk.
for chunk <- Nx.Serving.run(serving, "The future of AI is") do
  IO.write(chunk)
end

Serving ML Models in Production

With Nx.Serving

defmodule MyApp.MLServing do
  @moduledoc """
  Text-classification model served by an `Nx.Serving` process registered under
  this module's name, so callers can use `Nx.Serving.batched_run(MyApp.MLServing, text)`.
  """

  @doc """
  Returns a child spec that starts the named `Nx.Serving` process.

  Concurrent `batched_run/2` calls are grouped into batches of up to 10 inputs,
  waiting at most 100 ms for a batch to fill.
  """
  def child_spec(_opts) do
    {model_info, tokenizer} = load_model()

    # Bumblebee already returns a complete %Nx.Serving{} (tokenization and
    # post-processing included), so it is handed to the serving process as-is.
    serving =
      Bumblebee.Text.text_classification(model_info, tokenizer,
        defn_options: [compiler: EXLA]
      )

    # Batching is a property of the serving process, not of the Bumblebee serving.
    Supervisor.child_spec(
      {Nx.Serving, serving: serving, name: __MODULE__, batch_size: 10, batch_timeout: 100},
      id: __MODULE__
    )
  end

  # Loads model weights and tokenizer from the Hugging Face Hub.
  # "model-name" is a placeholder for the actual repository id.
  defp load_model do
    {:ok, model_info} = Bumblebee.load_model({:hf, "model-name"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "model-name"})
    {model_info, tokenizer}
  end
end

# Add to supervision tree. The supervisor calls MyApp.MLServing.child_spec/1,
# which starts an Nx.Serving process registered as MyApp.MLServing.
children = [
  {MyApp.MLServing, []}
]

In Phoenix Controller

defmodule MyAppWeb.ClassifyController do
  use MyAppWeb, :controller

  @doc """
  Classifies the `"text"` request param through the `MyApp.MLServing` process and
  renders the first prediction as JSON with `label` and `confidence` fields.
  """
  def classify(conn, %{"text" => text}) do
    result = Nx.Serving.batched_run(MyApp.MLServing, text)
    # Only the first entry of the predictions list is reported.
    best = hd(result.predictions)

    json(conn, %{label: best.label, confidence: best.score})
  end
end

Traditional ML with Scholar

alias Scholar.Linear.LogisticRegression
alias Scholar.Preprocessing.StandardScaler

# Split into train/test sets (split_data/2 is assumed to be defined elsewhere)
{x_train, x_test, y_train, y_test} = split_data(features, labels)

# Fit scaling statistics on the training split only, then apply them to both splits
scaler = StandardScaler.fit(x_train)

[x_train_scaled, x_test_scaled] =
  Enum.map([x_train, x_test], &StandardScaler.transform(scaler, &1))

# Train a 3-class logistic regression for 1000 iterations
model =
  LogisticRegression.fit(x_train_scaled, y_train, num_classes: 3, iterations: 1000)

# Predict on the held-out split and score against the true labels
predictions = LogisticRegression.predict(model, x_test_scaled)
accuracy = Scholar.Metrics.Classification.accuracy(y_test, predictions)

Conclusion

Elixir’s ML ecosystem enables:

  • Native numerical computing with Nx
  • GPU/TPU acceleration via EXLA
  • Deep learning with Axon
  • Pre-trained transformers through Bumblebee
  • Production serving integrated with OTP

At Sajima Solutions, we leverage Elixir’s ML capabilities for applications requiring both AI and robust concurrent systems. Contact us to explore ML solutions for your business.