defmodule CustomTrainer do
  @moduledoc """
  A hand-rolled training loop for an Axon model, using Polaris for
  optimization and Nx.Defn's value_and_grad for gradient computation.
  Assumes the plain-map parameter representation (Axon ~> 0.6).
  """

  import Nx.Defn

  def train(model, data, opts \\ []) do
    epochs = Keyword.get(opts, :epochs, 10)
    lr = Keyword.get(opts, :learning_rate, 0.001)

    # Build the model into an {init_fn, predict_fn} pair.
    {init_fn, predict_fn} = Axon.build(model)

    # Initialize parameters from a template describing the expected input.
    params = init_fn.(Nx.template({1, 784}, :f32), %{})

    # Polaris optimizers are {init_fn, update_fn} pairs, not opaque structs.
    {opt_init_fn, opt_update_fn} = Polaris.Optimizers.adam(learning_rate: lr)
    opt_state = opt_init_fn.(params)

    Enum.reduce(1..epochs, {params, opt_state}, fn epoch, {params, opt_state} ->
      {loss, params, opt_state} =
        train_epoch(data, params, opt_state, predict_fn, opt_update_fn)

      # Note: this is the loss of the epoch's final batch, not an average.
      IO.puts("Epoch #{epoch}, Loss: #{Nx.to_number(loss)}")
      {params, opt_state}
    end)
  end

  defp train_epoch(data, params, opt_state, predict_fn, opt_update_fn) do
    Enum.reduce(data, {Nx.tensor(0.0), params, opt_state}, fn batch, {_, params, opt_state} ->
      # Differentiate the loss with respect to the parameters.
      {loss, grads} =
        value_and_grad(params, fn p ->
          preds = predict_fn.(p, batch.input)
          cross_entropy_loss(batch.labels, preds)
        end)

      # Apply the optimizer's update rule and advance its state.
      {updates, new_opt_state} = opt_update_fn.(grads, opt_state, params)
      new_params = Polaris.Updates.apply_updates(params, updates)

      {loss, new_params, new_opt_state}
    end)
  end

  # Categorical cross-entropy over one-hot labels; assumes `predictions`
  # are already probabilities (e.g. softmax outputs). The epsilon guards
  # against log(0).
  defnp cross_entropy_loss(labels, predictions) do
    -Nx.mean(Nx.sum(labels * Nx.log(predictions + 1.0e-7), axes: [-1]))
  end
end
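A minimal usage sketch, assuming a single-input MNIST-style classifier and an in-memory list of batches. The model layout, batch size, and synthetic data below are illustrative placeholders, not part of the trainer itself. Each batch is a map with :input and :labels keys, matching the batch.input / batch.labels access in train_epoch/5, and the input width matches the {1, 784} template hard-coded in train/3.

# Hypothetical model: sizes chosen only to match the {1, 784} template.
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

# Deterministic fake data just to exercise the loop: 32 examples of
# 784 "pixels" each, with one-hot labels over 10 classes.
inputs = Nx.divide(Nx.remainder(Nx.iota({32, 784}), 256), 255)
classes = Nx.remainder(Nx.iota({32}), 10)
labels = Nx.equal(Nx.new_axis(classes, -1), Nx.iota({1, 10}))

data = [%{input: inputs, labels: labels}]

{trained_params, _opt_state} =
  CustomTrainer.train(model, data, epochs: 5, learning_rate: 1.0e-3)

Because the parameter template is fixed at {1, 784}, any model passed in must accept that input shape; a more general version would take the template (or a sample batch) as an argument.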