Implementation:FlagOpen FlagEmbedding LLARA Pretrain Modeling
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Large Language Models, Retrieval Augmented Generation, Neural Networks |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A specialized LLaMA-based model architecture for LLARA pretraining that uses a modified causal attention mask, so the model learns passage summarization and continuation prediction in a single forward pass.
Description
This module implements PreLlamaModel, a modified LLaMA architecture designed for LLARA (LLM adapted for dense RetrievAl) pretraining, together with the NewLlamaModel decoder subclass and the PreModel wrapper. The key innovation is a custom attention mechanism that lets the model learn two complementary tasks at once: backward summarization (summarizing the preceding passage) and forward prediction (predicting the content that follows). This is achieved through specialized position IDs and attention masks that create two independent attention streams within a single forward pass.
The implementation extends the standard LLaMA model with a modified causal mask that blocks attention between the summary and prediction streams while letting both attend to the input passage. It also includes bag-of-words (BoW) losses for the summary and prediction tasks to encourage vocabulary coverage, support for combining the autoregressive and BoW objectives, and handling for the eight summary and eight prediction special tokens.
The architecture is built on top of HuggingFace's LLaMA implementation with minimal modifications focused on the attention mechanism, making it compatible with standard LLaMA infrastructure while enabling the dual summarize-and-predict pretraining intended for retrieval-augmented generation.
Usage
Use this model for pretraining LLARA on long-context tasks where learning both backward summarization and forward continuation improves retrieval and generation capabilities.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/LLARA/pretrain/modeling.py
- Lines: 1-441
Signature
class NewLlamaModel(LlamaModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]: ...
class PreLlamaModel(LlamaForCausalLM):
    def __init__(self, config): ...
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_summarize_ids: Optional[torch.LongTensor] = None,
        output_predict_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]: ...
class PreModel(nn.Module):
    def __init__(self, model: AutoModel = None): ...
    def forward(self, *args, **kwargs): ...
    def save(self, output_dir: str): ...
Import
from modeling import PreLlamaModel, PreModel
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
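| input_ids | torch.LongTensor | Yes | Input token IDs (batch_size, seq_len); each row must end with the summarize suffix followed by the predict suffix (see Input Format) |
| attention_mask | Optional[torch.Tensor] | No | Padding mask (batch_size, seq_len): 1 for real tokens, 0 for padding |
| position_ids | Optional[torch.LongTensor] | No | Position IDs (auto-computed if None) |
| labels | Optional[torch.LongTensor] | No | Labels (batch_size, seq_len) for the autoregressive loss; shifted internally |
| output_summarize_ids | Optional[torch.LongTensor] | No | Target token IDs (batch_size, k) for the summarization BoW loss; IDs ≤ 2 are ignored |
| output_predict_ids | Optional[torch.LongTensor] | No | Target token IDs (batch_size, k) for the prediction BoW loss; IDs ≤ 2 are ignored |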
| past_key_values | Optional[Cache] | No | Cached key-value pairs for generation |
| use_cache | Optional[bool] | No | Use KV cache for generation |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.FloatTensor | ar_loss + bow_loss, where bow_loss is the mean of the summarization and prediction BoW losses when both are present (see Combined Loss) |
| logits | torch.FloatTensor | Language model logits (batch_size, seq_len, vocab_size) |
| past_key_values | Optional[Cache] | Updated KV cache |
| hidden_states | Optional[Tuple] | Hidden states from all layers |
| attentions | Optional[Tuple] | Attention weights from all layers |
Architecture Overview
Input Format
The model expects input in this specific format:
[passage tokens...]
[summarize prompt: ", summarize the above passage within eight words: "]
[special tokens: <s1><s2><s3><s4><s5><s6><s7><s8>]
[predict prompt: ", predict the following passage within eight words: "]
[special tokens: <s9><s10><s11><s12><s13><s14><s15><s16>]
Token IDs:
- Summarize prompt: [9162, 19138, 675, 278, 2038, 13382, 2629, 9475, 3838, 29901, 29871]
- Summary special tokens: [32008, 32011, 32004, 32013, 32007, 32005, 32002, 32014]
- Predict prompt: [9162, 8500, 278, 1494, 13382, 2629, 9475, 3838, 29901, 29871]
- Prediction special tokens: [32000, 32009, 32012, 32001, 32010, 32003, 32006, 32015]
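To make the layout concrete, the sketch below appends both suffixes to an already-tokenized passage using the ID lists above. The constant and function names are hypothetical, and the passage IDs are arbitrary toy values rather than real tokenizer output.

```python
# Token ID lists copied from the "Token IDs" list above.
SUMMARIZE_PROMPT_IDS = [9162, 19138, 675, 278, 2038, 13382, 2629, 9475, 3838, 29901, 29871]
SUMMARY_SPECIAL_IDS = [32008, 32011, 32004, 32013, 32007, 32005, 32002, 32014]
PREDICT_PROMPT_IDS = [9162, 8500, 278, 1494, 13382, 2629, 9475, 3838, 29901, 29871]
PREDICTION_SPECIAL_IDS = [32000, 32009, 32012, 32001, 32010, 32003, 32006, 32015]

# Hypothetical names for the two suffixes appended after every passage.
SUMMARIZE_SUFFIX_IDS = SUMMARIZE_PROMPT_IDS + SUMMARY_SPECIAL_IDS
PREDICT_SUFFIX_IDS = PREDICT_PROMPT_IDS + PREDICTION_SPECIAL_IDS


def build_input_ids(passage_ids):
    """Append the summarize suffix, then the predict suffix, to a tokenized passage."""
    return list(passage_ids) + SUMMARIZE_SUFFIX_IDS + PREDICT_SUFFIX_IDS


ids = build_input_ids([1, 450, 4996])  # toy passage IDs
print(len(SUMMARIZE_SUFFIX_IDS), len(PREDICT_SUFFIX_IDS), len(ids))  # 19 18 40
```

The suffixes therefore add 19 + 18 = 37 tokens to every passage.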
Custom Attention Mechanism
The key innovation is the modified attention mask:
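# Standard causal mask: causal_mask starts filled with torch.finfo(dtype).min;
# triu(diagonal=1) keeps that value only above the diagonal (future keys)
causal_mask = torch.triu(causal_mask, diagonal=1)
# Indexed as (batch, heads, query, key): rows = predict suffix queries,
# columns = summarize suffix keys, so the prediction stream cannot see the summary stream
causal_mask[
    :,
    :,
    -len(predict_suffix_ids):,
    -len(predict_suffix_ids) - len(summarize_suffix_ids):-len(predict_suffix_ids),
] = torch.finfo(dtype).min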
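This creates three attention regions:
1. Passage region: standard causal attention.
2. Summary region (summarize prompt + s1-s8): attends to the passage and causally within itself; it cannot see the prediction region because that region lies in its future.
3. Prediction region (predict prompt + s9-s16): attends to the passage and causally within itself; the explicit block above hides the summary region from it.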
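The two-stream mask can be reproduced on a toy sequence. The sketch below is a plain-Python stand-in (nested lists instead of tensors, `float("-inf")` in place of `torch.finfo(dtype).min`), and the function name `build_llara_mask` is hypothetical; it applies the standard causal rule and then the one extra block described above.

```python
NEG = float("-inf")  # stands in for torch.finfo(dtype).min


def build_llara_mask(n_passage, sum_len, pred_len):
    """Return an (L, L) additive mask: 0.0 = may attend, NEG = blocked.

    Rows are query positions, columns are key positions.
    sum_len / pred_len are the full suffix lengths (prompt + special tokens).
    """
    total = n_passage + sum_len + pred_len
    # Standard causal mask: block every key that lies after the query.
    mask = [[0.0 if k <= q else NEG for k in range(total)] for q in range(total)]
    # Extra block: predict-suffix rows may not see summarize-suffix columns.
    sum_start = n_passage
    pred_start = n_passage + sum_len
    for q in range(pred_start, total):
        for k in range(sum_start, pred_start):
            mask[q][k] = NEG
    return mask


m = build_llara_mask(n_passage=4, sum_len=3, pred_len=2)
for row in m:
    print(" ".join("0" if v == 0.0 else "x" for v in row))
```

Only the second step differs from an ordinary decoder mask; the summary stream's inability to see the prediction stream falls out of the causal rule alone.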
Position ID Modification
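# Predict-suffix tokens reuse the position IDs of the first len(predict_suffix_ids)
# summarize-suffix tokens, so both suffixes start right after the passage
position_ids[i][-len(predict_suffix_ids):] = position_ids[i][
    -len(summarize_suffix_ids) - len(predict_suffix_ids):-len(summarize_suffix_ids)
]
As a result, each suffix is positioned as if it directly followed the passage, and the prediction stream's positions do not continue from the end of the summary stream.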
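A plain-Python sketch of the same slice assignment on a single row of position IDs (the function name `llara_position_ids` is hypothetical; `sum_len` and `pred_len` are the full suffix lengths, 19 and 18 for the prompts above):

```python
def llara_position_ids(seq_len, sum_len, pred_len):
    """Positions 0..seq_len-1, then the predict suffix reuses the
    positions at which the summarize suffix starts."""
    pos = list(range(seq_len))
    # The right-hand slice is copied before assignment, so this is safe.
    pos[-pred_len:] = pos[-sum_len - pred_len:-sum_len]
    return pos


# 3 passage tokens, 2-token summarize suffix, 2-token predict suffix
print(llara_position_ids(seq_len=7, sum_len=2, pred_len=2))  # [0, 1, 2, 3, 4, 3, 4]
```

Both suffixes begin at position 3, the first position after the passage.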
Training Objectives
Autoregressive Loss
Standard next-token prediction loss:
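from torch.nn import CrossEntropyLoss
# Shift so that logits at position t are scored against the token at t + 1
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
vocab_size = shift_logits.size(-1)
ar_loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))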
Trains the model to predict next tokens in the passage and special token sequences.
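The shift is easy to get wrong, so here is a plain-Python sketch of the same shifted cross-entropy for one sequence (lists instead of tensors; the function name `ar_loss` is reused only for illustration, and the -100 skip mirrors `CrossEntropyLoss`'s default `ignore_index`):

```python
import math


def ar_loss(logits, labels, ignore_index=-100):
    """Mean next-token cross-entropy: logits[t] is scored against labels[t + 1].

    logits: seq_len rows, each a list of vocab-size scores; labels: seq_len IDs.
    """
    total, count = 0.0, 0
    for t in range(len(labels) - 1):
        target = labels[t + 1]
        if target == ignore_index:
            continue  # skipped, like CrossEntropyLoss's default ignore_index
        row = logits[t]
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # log-sum-exp
        total += log_z - row[target]
        count += 1
    return total / count if count else float("nan")


uniform = [[0.0] * 5 for _ in range(4)]
print(ar_loss(uniform, [0, 1, 2, 3]))  # log(5) ≈ 1.6094
```

With uniform logits every prediction has probability 1/5, so the loss is log 5 regardless of the labels.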
Bag-of-Words Summarization Loss
Encourages the summary positions to put probability mass on the target vocabulary:
import torch
# The 8 summary special tokens sit just before the predict suffix
# (predict prompt + 8 prediction tokens)
special_logits = logits[:, -len(predict_suffix_ids) - 8:-len(predict_suffix_ids), :]
# Max pooling over the 8 positions -> (batch_size, vocab_size)
special_logits, _ = torch.max(special_logits, dim=1)
# Log-probabilities over the vocabulary
possibility = torch.log_softmax(special_logits, dim=-1)
# Negative log-likelihood of each sample's unique target IDs (IDs <= 2 are dropped)
bow_summarize_loss = 0
for p, temp_output_ids in zip(possibility, output_summarize_ids):
    unique_useful_ids = torch.unique(temp_output_ids[temp_output_ids > 2])
    bow_summarize_loss -= torch.mean(p[unique_useful_ids])
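Process:
1. Take the logits at the 8 summary special-token positions.
2. Max-pool across those positions, giving one score per vocabulary entry.
3. Apply log-softmax to get log-probabilities over the vocabulary.
4. Minimize the mean negative log-probability of the unique target IDs (IDs 0-2, i.e. <unk>, <s>, </s>, are skipped).
5. Scale by 1/10 to balance against the AR loss.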
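Steps 1-4 can be sketched for a single sample in plain Python (lists instead of tensors; the function name `bow_loss_one` is hypothetical, and the 1/10 scaling of step 5 is left out):

```python
import math


def bow_loss_one(special_logits, target_ids):
    """BoW loss for one sample.

    special_logits: 8 rows (one per special-token position), each a list of vocab scores.
    target_ids: target token IDs; IDs <= 2 (<unk>, <s>, </s>) are ignored.
    """
    vocab = len(special_logits[0])
    # Steps 1-2: max-pool each vocabulary entry over the 8 positions
    pooled = [max(row[v] for row in special_logits) for v in range(vocab)]
    # Step 3: log-softmax over the vocabulary
    m = max(pooled)
    log_z = m + math.log(sum(math.exp(x - m) for x in pooled))
    log_p = [x - log_z for x in pooled]
    # Step 4: mean negative log-probability of the unique useful target IDs
    useful = sorted({i for i in target_ids if i > 2})
    return -sum(log_p[i] for i in useful) / len(useful)


print(bow_loss_one([[0.0] * 10 for _ in range(8)], [1, 5, 5, 7]))  # log(10) ≈ 2.3026
```

Because of the max-pool, a single special-token position that scores a target word highly is enough to lower the loss for that word.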
Bag-of-Words Prediction Loss
Same as summarization but for prediction tokens:
# The 8 prediction special tokens are the last 8 positions
special_logits = logits[:, -8:, :]
# Max pooling over the 8 positions and log-probabilities over the vocabulary
special_logits, _ = torch.max(special_logits, dim=1)
possibility = torch.log_softmax(special_logits, dim=-1)
# Maximize the log-probability of each sample's unique target IDs (IDs <= 2 are dropped)
bow_predict_loss = 0
for p, temp_output_ids in zip(possibility, output_predict_ids):
    unique_useful_ids = torch.unique(temp_output_ids[temp_output_ids > 2])
    bow_predict_loss -= torch.mean(p[unique_useful_ids])
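The negative slice offsets only land on the special tokens if they are measured from the right suffix length. The sketch below checks this on a toy ID list built from the token IDs above; it assumes the summary slice offset is the length of the whole predict suffix (prompt plus 8 prediction tokens). The suffix variable names are hypothetical and the passage IDs are toy values.

```python
# Token ID lists from the "Token IDs" list above; suffix names are hypothetical.
summarize_suffix_ids = (
    [9162, 19138, 675, 278, 2038, 13382, 2629, 9475, 3838, 29901, 29871]
    + [32008, 32011, 32004, 32013, 32007, 32005, 32002, 32014]
)
predict_prompt_ids = [9162, 8500, 278, 1494, 13382, 2629, 9475, 3838, 29901, 29871]
predict_suffix_ids = predict_prompt_ids + [32000, 32009, 32012, 32001, 32010, 32003, 32006, 32015]

ids = [1, 450, 4996] + summarize_suffix_ids + predict_suffix_ids  # toy passage IDs first

summary_slice = ids[-len(predict_suffix_ids) - 8:-len(predict_suffix_ids)]
prediction_slice = ids[-8:]
# Offset by the 10-token prompt only: lands inside the predict prompt instead
prompt_only_slice = ids[-len(predict_prompt_ids) - 8:-len(predict_prompt_ids)]

print(summary_slice == summarize_suffix_ids[-8:])  # True
print(prediction_slice == predict_suffix_ids[-8:])  # True
print(prompt_only_slice == predict_prompt_ids[:8])  # True
```

If the offset were only the 10-token predict prompt, the slice would pick up the first eight prompt tokens rather than <s1>-<s8>.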
Combined Loss
# Combine BoW losses (bow_loss stays None when neither BoW loss was computed)
bow_loss = None
if bow_summarize_loss > 0 and bow_predict_loss > 0:
    bow_loss = (bow_summarize_loss + bow_predict_loss) / 2
elif bow_summarize_loss > 0:
    bow_loss = bow_summarize_loss
elif bow_predict_loss > 0:
    bow_loss = bow_predict_loss
# Total loss
if ar_loss is not None and bow_loss is not None:
    loss = ar_loss + bow_loss
elif ar_loss is None:
    loss = bow_loss
else:
    loss = ar_loss
Implementation Details
Special Token Vocabulary
16 new special tokens added to LLaMA vocabulary:
- <s1> to <s8>: Summary tokens (IDs: 32008, 32011, 32004, 32013, 32007, 32005, 32002, 32014)
- <s9> to <s16>: Prediction tokens (IDs: 32000, 32009, 32012, 32001, 32010, 32003, 32006, 32015)
These tokens start with freshly initialized embeddings and are learned during pretraining.
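A quick sanity check on the ID lists above: the 16 special tokens fill the range 32000-32015 exactly once, which is why the extended vocabulary needs 32016 entries. The variable names are illustrative only.

```python
summary_ids = [32008, 32011, 32004, 32013, 32007, 32005, 32002, 32014]
prediction_ids = [32000, 32009, 32012, 32001, 32010, 32003, 32006, 32015]

all_ids = summary_ids + prediction_ids
no_overlap = not set(summary_ids) & set(prediction_ids)
contiguous = sorted(all_ids) == list(range(32000, 32016))
min_vocab_size = max(all_ids) + 1

print(no_overlap, contiguous, min_vocab_size)  # True True 32016
```

The names <s1>-<s16> map to these IDs out of numeric order, so inputs should be built from the ID lists rather than by assuming <s1> is 32000.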
Attention Mask Shape
Input sequence (example with a short passage):
[P1 P2 P3 ... Pn | sum_prompt | s1-s8 | pred_prompt | s9-s16]
Attention mask by block (rows = queries, columns = keys; 0 = can attend, -∞ = blocked, stored as torch.finfo(dtype).min; "causal" = each token sees itself and earlier tokens in the same block):
| Query ↓ / Key → | P1...Pn | sum_prompt | s1-s8 | pred_prompt | s9-s16 |
|---|---|---|---|---|---|
| P1...Pn | causal | -∞ | -∞ | -∞ | -∞ |
| sum_prompt | 0 | causal | -∞ | -∞ | -∞ |
| s1-s8 | 0 | 0 | causal | -∞ | -∞ |
| pred_prompt | 0 | -∞ | -∞ | causal | -∞ |
| s9-s16 | 0 | -∞ | -∞ | 0 | causal |
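Key: Summary tokens (s1-s8) cannot see prediction tokens (s9-s16) because those come later in the sequence (ordinary causal masking), and prediction tokens cannot see the summarize suffix because of the explicit block.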
Cache Position Handling
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values else 0
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device
)
Properly handles KV cache for generation (though not typically used during pretraining).
Usage Examples
Pretraining
import torch
from transformers import LlamaTokenizer
from modeling import PreLlamaModel, PreModel
# Suffix token IDs from the "Token IDs" list above. The model locates these tokens
# by position, so append them as IDs rather than as text joined with spaces.
summarize_prompt_ids = [9162, 19138, 675, 278, 2038, 13382, 2629, 9475, 3838, 29901, 29871]
summary_special_ids = [32008, 32011, 32004, 32013, 32007, 32005, 32002, 32014]
predict_prompt_ids = [9162, 8500, 278, 1494, 13382, 2629, 9475, 3838, 29901, 29871]
prediction_special_ids = [32000, 32009, 32012, 32001, 32010, 32003, 32006, 32015]
# Initialize from pretrained LLaMA weights and extend the embedding matrix
model = PreLlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf")
model.resize_token_embeddings(32016)  # 32000 base tokens + 16 special tokens
model.gradient_checkpointing_enable()
# Wrap in PreModel for convenience (assumes PreModel.forward passes its arguments through)
model = PreModel(model=model)
# Prepare data
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Example passage
passage = "Machine learning is a subset of artificial intelligence..."
passage_ids = tokenizer.encode(passage)  # list of IDs, starting with BOS
# Passage + summarize suffix + predict suffix
input_ids = torch.tensor([
    passage_ids
    + summarize_prompt_ids + summary_special_ids
    + predict_prompt_ids + prediction_special_ids
])
# Ground truth vocabularies for BoW loss (IDs <= 2, such as BOS, are ignored)
summary_vocab = ["machine", "learning", "artificial", "intelligence", "subset"]
predict_vocab = ["algorithms", "data", "patterns", "prediction", "models"]
output_summarize_ids = tokenizer.encode(" ".join(summary_vocab), return_tensors="pt")
output_predict_ids = tokenizer.encode(" ".join(predict_vocab), return_tensors="pt")
# Forward pass
outputs = model(
    input_ids=input_ids,
    labels=input_ids,  # For AR loss
    output_summarize_ids=output_summarize_ids,
    output_predict_ids=output_predict_ids,
)
loss = outputs.loss
loss.backward()
# Loss breakdown (if debugging):
# ar_loss: autoregressive next-token prediction
# bow_summarize_loss: encourage summary to cover key terms
# bow_predict_loss: encourage prediction to cover continuation terms
Model Saving
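# Save checkpoint
model.save("/path/to/checkpoint")
# Saves (assuming save() delegates to save_pretrained; file names depend on the transformers version):
# - model weights (model.safetensors or pytorch_model.bin, possibly sharded)
# - config.json
# PreModel only wraps the model, so save the tokenizer separately if needed
tokenizer.save_pretrained("/path/to/checkpoint")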
Inference (Generation)
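import torch
from transformers import LlamaTokenizer
from modeling import PreLlamaModel
# Load pretrained model
model = PreLlamaModel.from_pretrained("/path/to/checkpoint")
model.eval()
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Prepare input
passage = "Your long passage here..."
input_text = passage + ", summarize the above passage within eight words: "
input_ids = tokenizer.encode(input_text, return_tensors="pt")
# Greedily generate up to 8 new tokens.
# Caveat: the forward pass assumes the input ends with the full summarize and predict
# suffixes; a bare prompt like this one may not get the intended masks and position IDs.
with torch.no_grad():
    outputs = model.generate(input_ids, max_new_tokens=8, do_sample=False, num_beams=1)
new_tokens = outputs[0, input_ids.shape[1]:]
# The summary positions are trained to carry semantic information for downstream
# retrieval, so the decoded tokens are not expected to be readable text
print(tokenizer.decode(new_tokens))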
Design Rationale
Why Bidirectional Pretraining?
Traditional language models only predict forward. LLARA adds backward summarization:
- Forward prediction: Learn what comes next
- Backward summarization: Learn what came before
This dual objective is intended to better prepare the model for retrieval-augmented generation, where understanding both the context and its likely continuation matters.
Why Bag-of-Words Loss?
Pure autoregressive training on special tokens may not learn meaningful representations. BoW loss:
- Encourages tokens to cover relevant vocabulary
- Doesn't enforce specific order (bag-of-words)
- Provides semantic guidance without being overly prescriptive
- Scaled down (1/10) to not dominate AR loss
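The "doesn't enforce order" property follows from scoring a set of IDs against one pooled distribution. A plain-Python sketch (hypothetical `bow_loss` helper operating on already max-pooled logits; toy numbers) shows that reordering, duplicating, or adding filtered IDs to the target leaves the loss unchanged:

```python
import math


def bow_loss(pooled_logits, target_ids):
    """Mean negative log-softmax probability of the unique target IDs above 2."""
    m = max(pooled_logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in pooled_logits))
    useful = {i for i in target_ids if i > 2}
    return -sum(pooled_logits[i] - log_z for i in useful) / len(useful)


pooled = [0.1, 0.2, 0.3, 2.0, -1.0, 0.5, 1.5]  # toy pooled logits, vocab of 7
a = bow_loss(pooled, [3, 5, 6])
b = bow_loss(pooled, [6, 3, 5, 5, 1])  # reordered, duplicated, plus a filtered ID
print(math.isclose(a, b))  # True
```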
Why 8 Tokens?
8 tokens balances:
- Expressiveness: Enough tokens to encode useful information
- Efficiency: Not too many to slow down inference
- Granularity: Roughly equivalent to a concise summary or key continuation hint
Why Block Attention Between Summary and Prediction?
- Independence: Each stream conditions only on the passage and its own prompt, so the summary and the prediction are learned as separate views of the same passage
- Parallel Training: Both objectives computed in one forward pass
- Prevent Leakage: Prediction shouldn't cheat by looking at summary
Training Considerations
Memory Usage
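- Slightly longer sequences: the two prompts (11 + 10 tokens) and 16 special tokens add 37 tokens per passage, so memory is driven mainly by passage length, and FlashAttention-2 cannot be used to reduce it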
- Enable gradient checkpointing for large models
- Use mixed precision training (FP16/BF16)
Hyperparameters
- BoW weight: Currently 0.1 (the BoW loss is scaled by 1/10 relative to the AR loss)
- Learning rate: Typical LLaMA pretraining LR (1e-4 to 3e-4)
- Batch size: Depends on passage length
- Warmup steps: Important for stable training
Data Requirements
- Long passages (ideally 1000+ tokens)
- Ground truth vocabulary for BoW loss:
  * Summary: key terms from the passage
  * Prediction: relevant terms from the continuation
- Can be extracted automatically or provided manually
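When target vocabularies of different lengths are batched into one output_summarize_ids or output_predict_ids tensor, they must be padded to a common width. A minimal sketch, assuming right-padding with ID 0 (<unk> in the LLaMA vocabulary), which the BoW loss drops along with every other ID ≤ 2; the helper names are hypothetical and the target IDs are toy values:

```python
def pad_targets(batch_target_ids, pad_id=0):
    """Right-pad variable-length target ID lists into a rectangular batch."""
    width = max(len(ids) for ids in batch_target_ids)
    return [list(ids) + [pad_id] * (width - len(ids)) for ids in batch_target_ids]


def useful_ids(ids):
    """The IDs the BoW loss actually scores: unique IDs above 2."""
    return sorted({i for i in ids if i > 2})


batch = [[1, 4823, 6509, 4823], [1, 13436]]  # toy targets, each starting with BOS
padded = pad_targets(batch)
print(padded)  # [[1, 4823, 6509, 4823], [1, 13436, 0, 0]]
print([useful_ids(row) for row in padded])  # [[4823, 6509], [13436]]
```

Padding does not change which IDs are scored, so no separate target mask is needed under this assumption.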
Flash Attention Compatibility
Note: The code explicitly disables Flash Attention 2:
if self.config._attn_implementation == "flash_attention_2":
raise ValueError("You can not use flash attention to pretrain")
Reason: the FlashAttention-2 path in transformers supports only causal attention with a padding mask and cannot apply an arbitrary additive 4D mask such as the summary/prediction block above. Use eager attention or SDPA instead.
Gradient Checkpointing
model.gradient_checkpointing_enable()
Highly recommended for pretraining large models on long sequences. Trades computation for memory by recomputing activations during backward pass.
Limitations
- Fixed format: Requires specific prompt structure
- Special tokens: Needs vocabulary extension
- Attention implementation: Cannot use Flash Attention 2
- Long sequences: May be memory-intensive
- BoW targets: Requires preparing target vocabularies
Future Extensions
Possible improvements:
- Adaptive number of tokens: Allow variable-length summaries/predictions
- Hierarchical attention: Multiple levels of summarization
- Cross-attention: Let prediction attend to summary indirectly
- Dynamic weighting: Learn optimal BoW loss weight
- Retrieval integration: Use summary tokens for semantic search