for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
input_b.split(batch_size),
input_br.split(batch_size)):
loss = - (model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log())
for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
input_b.split(batch_size),
input_br.split(batch_size)):
loss = - (model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log())