Como treinar e fazer finetuning de embeddings multimodais e rerankers com Sentence Transformers
Tom Aarsen publicou um guia prático mostrando como fazer finetuning de modelos multimodais de embedding no Sentence Transformers. O exemplo usa Qwen3-VL-Embedding-2B pra Visual Document Retrieval e chega a NDCG@10 de 0.947 contra 0.888 do modelo base, batendo até o Qwen3-VL-Embedding-8B (4x maior). Receita completa: dataset, loss, training args, evaluator e trainer.
Tom Aarsen, mantenedor do Sentence Transformers, soltou um tutorial completo de como treinar modelos multimodais de embedding e reranker na biblioteca. Depois do post anterior sobre inferência multimodal, agora o foco é treino: como você pega seus dados e faz finetuning em cima de um modelo tipo Qwen/Qwen3-VL-Embedding-2B.
O exemplo prático é Visual Document Retrieval (VDR): dada uma query em texto, achar a página de documento (imagem, com gráficos, tabelas e layout preservados) mais relevante num corpus. O modelo resultante, tomaarsen/Qwen3-VL-Embedding-2B-vdr, chegou em NDCG@10 de 0.947 contra 0.888 do base.
Por que finetunar
Modelos multimodais genéricos tipo Qwen3-VL-Embedding-2B são treinados pra performar razoável em tudo: image-text matching, VQA (Visual Question Answering), document understanding. Só que genérico raramente é o melhor pra tarefa específica.
VDR exige entender layout de documento, gráfico, tabela e texto junto. Skill bem diferente de casar foto de tênis com descrição de produto. Finetunando no domínio, o salto foi de 0.888 pra 0.947, passando até o Qwen3-VL-Embedding-8B (4x maior).
Na prática: antes de pagar por modelo maior, roda finetuning no 2B. O custo/benefício aqui é óbvio.
Os componentes do treino
Mesma receita do treino text-only:
- Model: o modelo multimodal pra finetunar
- Dataset: dados de treino e avaliação
- Loss Function: função que guia a otimização
- Training Arguments (opcional): hiperparâmetros
- Evaluator (opcional): métricas antes, durante e depois
- Trainer: junta tudo
O pipeline usa o mesmo SentenceTransformerTrainer do text-only. Diferença: o dataset contém imagens (ou outras modalidades) junto com texto, e o processor do modelo cuida do preprocessing automaticamente.
Modelo
Duas abordagens. A primeira, finetunar um modelo de embedding multimodal que já existe:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(
"Qwen/Qwen3-VL-Embedding-2B",
model_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"},
processor_kwargs={"min_pixels": 28 * 28, "max_pixels": 600 * 600},
)
processor_kwargs controla preprocessing (max_pixels maior = qualidade maior, memória maior). model_kwargs vai pro AutoModel.from_pretrained (precisão, attention implementation).
A segunda é partir de um checkpoint VLM cru, sem treino de embedding:
model = SentenceTransformer("Qwen/Qwen3-VL-2B")
Sentence Transformers detecta arquitetura, infere modalidades do processor e monta forward method e pooling automaticamente. Se não bater 100%, dá pra editar o sentence_bert_config.json.
Alternativa: Router
Em vez de um VLM único, dá pra compor encoders separados por modalidade usando o módulo Router. Útil quando você quer encoders leves e especializados em vez de um VLM grandão. A pegadinha: os espaços de embedding começam desalinhados, então precisa treinar pra alinhar (por isso a Dense layer de projeção).
Dataset
O exemplo usa tomaarsen/llamaindex-vdr-en-train-preprocessed, subset em inglês do llamaindex/vdr-multilingual-train. São ~500k pares query-imagem de PDFs públicos, com queries geradas sinteticamente por VLMs (gemini-1.5-pro e Qwen2-VL-72B).
A versão pre-processada filtra 53.512 samples em inglês e resolve 4 dos 16 hard negatives por sample em screenshots reais. Pro treino ele pega 10k samples com colunas query, image e negative_0 (triplets anchor/positive/hard negative). Pra eval usa 300 samples com os 4 hard negatives cada, pra formar corpus desafiador.
Formato
Regra igual ao text-only: colunas casam com a loss escolhida. Os inputs multimodais aceitam:
- Text: strings
- Image: PIL, paths, URLs, arrays numpy/torch
- Audio: paths, arrays, dicts com
array/sampling_rate, outorchcodec.AudioDecoder - Video: paths, arrays, dicts com
array/video_metadata, outorchcodec.VideoDecoder - Dicts multimodais:
{"text": ..., "image": ...}
O data collator chama model.preprocess() que detecta modalidade e aplica o preprocessing. Zero tokenização manual.
Loss: CachedMultipleNegativesRankingLoss + MatryoshkaLoss
Pra VDR, CachedMultipleNegativesRankingLoss é a escolha padrão. Aceita (query, positive) com N hard negatives. Empurra similaridade da query com o positive pra cima, e com todos os negatives pra baixo. Os negatives vêm de dois lugares:
- Hard negatives: as colunas
negative_Xexplícitas - In-batch negatives: positives e hard negatives dos outros samples do mesmo batch, reaproveitados de graça
Batch maior = mais negatives = sinal de treino mais forte. A variante "cached" usa gradient caching pra viabilizar batch grande mesmo com GPU limitada. O mini_batch_size=1 é crítico pra VLM grande:
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=1)
loss = MatryoshkaLoss(model, loss, matryoshka_dims=[2048, 1536, 1024, 512, 256, 128, 64])
O MatryoshkaLoss faz o modelo funcionar bem em várias dimensionalidades. Útil pra multimodal porque o Qwen3-VL gera 2048 dims: em deploy, você trunca pra 256 ou 128 e busca mais rápido com perda mínima.
Training Arguments
Destaques da config:
bf16=True: bfloat16 tem estabilidade numérica melhor que float16 pra VLMbatch_sampler=BatchSamplers.NO_DUPLICATES: garante que in-batch negatives sejam realmente diferentesper_device_train_batch_size=64: parece alto pra VLM de 2B, mas o gradient caching commini_batch_size=1resolveeval_steps=0.1,save_steps=0.1: eval e save a cada 10% de época
Evaluator
InformationRetrievalEvaluator calcula NDCG@10, MAP e Recall@k. Queries são texto, corpus é imagem (positives + hard negatives com IDs offsetados pra não colidir). batch_size=1 pra não estourar memória.
Resultados
Depois de 1 época só, o modelo finetunado chega em NDCG@10 de 0.947 no eval set (300 queries, 1500 docs). Comparação com 20 modelos:
| Modelo | Params | NDCG@10 |
|---|---|---|
| tomaarsen/Qwen3-VL-Embedding-2B-vdr | 2.1B | 0.947 |
| Qwen/Qwen3-VL-Embedding-8B | 8.1B | 0.923 |
| nvidia/omni-embed-nemotron-3b | 4.7B | 0.915 |
| nvidia/llama-nemotron-embed-vl-1b-v2 | 1.7B | 0.912 |
| nomic-ai/nomic-embed-multimodal-7b | 8.3B | 0.912 |
| llamaindex/vdr-2b-multi-v1 | 2.2B | 0.912 |
| Qwen/Qwen3-VL-Embedding-2B (base) | 2.1B | 0.888 |
| sentence-transformers/clip-ViT-L-14 | 428M | 0.611 |
O 2B finetunado bate o 8B do mesmo fornecedor. Se você tem um caso de uso específico e dados minimamente decentes, finetuning em modelo menor é quase sempre melhor que pegar o maior genérico.
Matryoshka na prática
Pico em 2048 dims (0.948), mas fica dentro de 0.3% do pico até 512 (4x menor), e mantém 92% do pico até 64 (32x menor):
| Dims | Base | Finetunado |
|---|---|---|
| 2048 | 0.8961 (100%) | 0.9480 (100%) |
| 1024 | 0.8941 (99.8%) | 0.9464 (99.8%) |
| 512 | 0.8760 (97.8%) | 0.9451 (99.7%) |
| 256 | 0.8347 (93.2%) | 0.9372 (98.9%) |
| 128 | 0.7888 (88.0%) | 0.9058 (95.5%) |
| 64 | 0.6852 (76.5%) | 0.8758 (92.4%) |
Como o gap entre 1024 e 2048 é pequeno, ele salvou o modelo com truncate_dim=1024 default. Metade do storage, praticamente zero perda.
Reranker multimodal
Dá pra treinar Cross Encoder multimodal também, usando CrossEncoderTrainer e losses próprias. Duas arquiteturas válidas:
- Any-to-Any + LogitScore: usa o LM pra gerar token e calcula log-odds de "1" vs "0"
- Feature Extraction + Pooling + Dense: pega o hidden state do último token e projeta pra score via Dense, evitando o LM head
O script de exemplo (doodles) divide o dataset em duas direções (image-to-text e text-to-image) com prompt específico por direção.
Links úteis
- Visual Document Retrieval training script
- Multimodal Reranker (Any-to-Any) com LogitScore
- Multimodal Reranker (Feature Extraction) com Pooling + Dense
- Multimodal Embedding & Reranker Models with Sentence Transformers (post anterior)
☕ comentários · 0