cafecomtech
Assinar
FERRAMENTAS · NVIDIA · 20 ABR 2026

NVIDIA NeMo RL agora roda treino de Reinforcement Learning com FP8 ponta a ponta

A NVIDIA detalhou uma receita de FP8 ponta a ponta no NeMo RL que acelera treino de RL (Reinforcement Learning) com GRPO em 15-25% nas camadas lineares e até ~48% quando estende pra KV cache e atenção. Testado em Llama 3.1 8B Instruct e Qwen3-30B, bate a acurácia do baseline BF16 usando importance sampling pra fechar o gap numérico entre vLLM (geração) e Megatron Core (treino).

NVIDIA NeMo RL agora roda treino de Reinforcement Learning com FP8 ponta a ponta
NVIDIA NeMo RL agora roda treino de Reinforcement Learning com FP8 ponta a ponta foi anunciado em 20 de abril às 22:52, horário de Brasília. fonte original →
por que importa

FP8 em RL é terreno minado por causa do desalinhamento entre engines (vLLM vs Megatron). A receita da NVIDIA resolve na marra com importance sampling e recalibração dinâmica de escala QKV. Quem treina modelo de raciocínio open source deve testar antes de assumir que BF16 é o único caminho seguro.

Conforme os LLMs migram de geração de texto simples pra raciocínio complexo, RL (Reinforcement Learning) virou peça central. Algoritmos tipo GRPO (Group Relative Policy Optimization) puxam essa transição, deixando modelos de raciocínio melhorarem continuamente via feedback iterativo. Diferente do fine-tuning supervisionado padrão, o loop de RL parte em duas fases bem distintas e intensas: geração (com requisito rígido de latência) e treino (que pede throughput alto).

Pra deixar isso viável, pesquisadores estão apelando pra tipos de baixa precisão tipo FP8. Em cenários onde a geração é limitada por memory bandwidth da GPU, usar parâmetros em baixa precisão ainda ajuda por causa do menor número de bytes por parâmetro.

O post da NVIDIA mergulha nos desafios sistêmicos de RL em baixa precisão e mostra como o NeMo RL (biblioteca open source do framework NeMo) acelera o workload sem perder acurácia.

FP8 nas camadas lineares

A receita usa FP8 com quantização por blocos, mesma abordagem do DeepSeek-V3 Technical Report. Formato dos tensores nas camadas lineares:

  • Pesos: FP8 (E4M3), granularidade [128, 128], escala FP32 por bloco
  • Ativações de entrada: FP8 (E4M3), granularidade [1, 128], escala FP32 por bloco
  • Gradientes de saída: FP8 (E4M3), granularidade [1, 128], escala FP32 por bloco

Com isso, as lineares rodam em math FP8, que tem 2x o throughput de pico vs BF16. Attention, normalização, funções não-lineares e projeções de saída continuam em BF16.

O problema do desalinhamento numérico

Pipelines de RL usam engines separadas: vLLM pros rollouts e NVIDIA Megatron Core pro treino. Cada uma tem kernels CUDA customizados. Isso introduz diferenças numéricas que se amplificam em baixa precisão por causa da lógica extra de quantização/dequantização.

A NVIDIA mede esse desvio como token multiplicative probability error: alinhamento perfeito dá 1, e valores <1.03-1.05 são considerados aceitáveis sem técnicas adicionais.

Três receitas testadas

  • Baseline: BF16 em geração e treino
  • Candidata 1: FP8 só na geração, treino em BF16
  • Final: FP8 ponta a ponta (geração + treino)

O achado interessante: a receita final (FP8 em tudo) mostra menor desalinhamento numérico entre geração e treino do que a candidata 1 (só geração em FP8). O baseline BF16 sempre tem o desvio mais baixo, como esperado.

A intuição aqui faz sentido: quando as duas engines erram do mesmo jeito, se alinham melhor do que quando uma erra e a outra não.

Importance sampling fecha o gap

Importance sampling corrige o descompasso de distribuição entre o modelo que gera os dados e o modelo em treinamento. É um peso por token multiplicado pela loss.

Experimentos mostram:

  • Candidata 1 (FP8 só geração): importance sampling diminui o gap de acurácia vs BF16, mas não fecha.
  • Final (FP8 ponta a ponta): importance sampling fecha completamente o gap vs treino BF16.

Resultados em modelos densos: Llama 3.1 8B Instruct

Acurácia de validação no treino GRPO com dataset de matemática, 4000 steps:

  • BF16: 0.616
  • FP8 só geração: 0.586
  • FP8 ponta a ponta: 0.613

Em throughput, a receita FP8 entrega >15% de ganho consistente vs BF16. O speedup teórico de 2x do FP8 não se realiza na prática porque só as lineares pegam o math mais rápido — attention e elementwise ficam iguais, e os kernels extras de quantização adicionam overhead. Os 15-25% batem com testes isolados do vLLM. Com otimizações adicionais tipo fusão de kernels de quantização, projeta-se chegar a 1.25x.

Em MoE: Qwen3-30B

Mesmo experimento em modelos Mixture of Experts. Qwen3-30B com dataset OpenMathInstruct-2, 8 nodes de H100: curvas de acurácia batem com BF16. Ganho de velocidade ainda em investigação.

Estendendo FP8 pra KV cache e atenção

Em transformer, as lineares não são o único gargalo. Crescimento de KV cache e computação de atenção dominam o tempo de rollout quando as sequências de saída são longas (OSL alto), saturando memory bandwidth.

Implementar FP8 pro KV cache em cenário de RL tem um problema único: os pesos da policy mudam a cada step. Diferente de inferência estática, onde calibração rola uma vez, RL exige tratamento dinâmico das escalas de quantização.

O NeMo RL resolve assim:

  • Recalibração: no fim de cada step de treino, o trainer recalibra as escalas de Query, Key e Value (QKV) usando os pesos atualizados.
  • Seleção de dados: calibração usa os dados de treino (prompts e respostas geradas) pra refletir a distribuição atual.
  • Sincronização: as novas escalas são sincronizadas com a engine de inferência (vLLM) pro próximo rollout.

O overhead de calibração é mínimo: ~2-3% do tempo total do step.

Resultados de FP8 em KV cache + atenção

No Qwen3-8B-Base com GRPO, FP8 aplicado no rollout e BF16 no treino: a divergência KL fica um pouco maior ao quantizar KV cache e atenção juntos (erros compostos), mas habilitar token-level truncated importance sampling alinha a acurácia de validação com o baseline BF16.

Os ganhos de velocidade:

  • FP8 em KV cache + atenção: +~30% de speedup no rollout vs config W8A8 só nas lineares
  • Total vs BF16 baseline: ~48% de speedup

Os ganhos são mais pronunciados em respostas longas, onde atenção pega uma fatia maior do workload.

Como habilitar

Pra ligar FP8 nas camadas lineares em geração e treino, configure via YAML. Pra KV cache e atenção, basta o kv_cache_dtype em vllm_cfg:

policy:
  generation:
    vllm_cfg:
      precision: fp8        # FP8 pras lineares
      kv_cache_dtype: fp8   # FP8 pro KV cache

A recalibração das escalas QKV e sincronização com vLLM são automáticas.

Opções avançadas

  • Manter as primeiras N e/ou últimas M camadas transformer em BF16 durante geração
  • Usar escala do tipo potência de 2 em vez de FP32 (pow2_weight_scaling_factors, pow2_activation_scaling_factors)
  • Escolher variantes da receita FP8 no backend Megatron Core (fp8_recipe: "blockwise" e afins)

Pra quem treina modelo de raciocínio em GPU própria, esse tipo de receita é o que separa rodar em 8 nodes ou precisar de 12. Os 48% de speedup no rollout em sequências longas é o número que interessa, já que RL com raciocínio tende a gerar respostas cada vez maiores.

Começando

As receitas llama-3.1-8b e moonlight-16b estão no GitHub do NeMo RL como ponto de partida.

0

☕ comentários · 0

Entra pra deixar um comentário. Magic link, sem senha.
Sem comentários ainda. Seja o primeiro.

Mateus Veloso

Tech lead. Mantém o cafecomtech quando não tá debugando sistema em produção.