Implementation:Sgl project Sglang Program State Fork Join
| Knowledge Sources | |
|---|---|
| Domains | Frontend_DSL, Parallel_Computing, LLM_Programming |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Concrete tool for creating parallel execution branches in SGLang frontend programs using fork and join operations.
Description
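The s.fork(size) method creates a ProgramStateGroup containing size independent ProgramState branches. Every branch starts from the parent's current prompt, so all branches share that prefix and, on the backend, its KV cache. Branches are indexed like a list (forks[i]) and execute independently. The .join() method merges the branches back into the parent state. In the default "gather_variable" mode, each variable generated in the branches becomes a list in the parent, with one entry per branch.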
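The default "gather_variable" join can be modeled in plain Python to show what ends up in the parent state. This is a toy sketch, not SGLang's implementation. ToyState and ToyGroup are hypothetical stand-ins. The shared prefix is modeled as copied text rather than a shared KV cache, and generation is replaced by a plain assignment.

```python
from typing import Any, Dict, List


class ToyState:
    """Hypothetical stand-in for a ProgramState: prompt text plus named variables."""

    def __init__(self, text: str = "") -> None:
        self.text = text
        self.variables: Dict[str, Any] = {}

    def fork(self, size: int = 1) -> "ToyGroup":
        # Every branch starts from the parent's current text (the shared prefix).
        return ToyGroup([ToyState(self.text) for _ in range(size)], self)

    def __getitem__(self, name: str) -> Any:
        return self.variables[name]


class ToyGroup:
    """Hypothetical stand-in for a ProgramStateGroup."""

    def __init__(self, states: List[ToyState], parent: ToyState) -> None:
        self.states = states
        self.parent = parent

    def __getitem__(self, index: int) -> ToyState:
        return self.states[index]

    def join(self) -> None:
        # gather_variable-style join: each variable name becomes a list in the
        # parent, ordered by branch index. The parent's text is left unchanged.
        gathered: Dict[str, List[Any]] = {}
        for state in self.states:
            for name, value in state.variables.items():
                gathered.setdefault(name, []).append(value)
        self.parent.variables.update(gathered)


parent = ToyState("Q: Explain quantum computing.\n")
forks = parent.fork(3)
for i in range(3):
    # A real SGLang program would call sgl.gen("answer", ...) here instead.
    forks[i].text += f"A: draft {i}"
    forks[i].variables["answer"] = f"draft {i}"
forks.join()
print(parent["answer"])  # ['draft 0', 'draft 1', 'draft 2']
```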
Usage
Call s.fork(n) inside an @sgl.function to create parallel branches. Use indexing to operate on individual branches, then call .join() to collect results.
Code Reference
Source Location
- Repository: sglang
- File: python/sglang/lang/interpreter.py
- Lines: L864-872 (ProgramState.fork)
- Join: python/sglang/lang/ir.py:L1015-1039 (ProgramStateGroup.join)
Signature
```python
# ProgramState.fork
def fork(
    self,
    size: int = 1,
    position_ids_offset: Optional[List[int]] = None,
) -> ProgramStateGroup:
    """Create parallel branches sharing the current prefix."""


# ProgramStateGroup
class ProgramStateGroup:
    def join(self, mode: str = "gather_variable"):
        """Collect results from branches."""

    def __getitem__(self, index: int) -> ProgramState:
        """Access individual branch."""
```
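Because a group is indexed like a list, it can also be looped over directly, for example with for i, f in enumerate(forks). The sketch below shows the Python mechanism this relies on. It assumes, as the signature above suggests, that __getitem__ indexes into an internal list. A class that defines __getitem__ but no __iter__ is still iterable: Python calls __getitem__ with 0, 1, 2, and so on, and stops at the first IndexError. IndexedGroup is a hypothetical class and is not part of SGLang.

```python
from typing import List


class IndexedGroup:
    """Hypothetical group that only defines __getitem__, like the signature above."""

    def __init__(self, branches: List[str]) -> None:
        self._branches = branches

    def __getitem__(self, index: int) -> str:
        # List indexing raises IndexError past the end, which ends iteration.
        return self._branches[index]


group = IndexedGroup(["branch-0", "branch-1", "branch-2"])
for i, branch in enumerate(group):
    print(i, branch)
```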
Import
```python
import sglang as sgl

# Used within @sgl.function via the ProgramState 's'
forks = s.fork(3)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
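| size | int | No | Number of parallel branches created by fork (default: 1) |
| position_ids_offset | Optional[List[int]] | No | Optional position-ID offsets for the new branches, passed to fork (default: None) |
| mode | str | No | Join mode passed to join: "gather_variable" (default) or "concate_and_append" |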
Outputs
| Name | Type | Description |
|---|---|---|
| Return value of fork | ProgramStateGroup | Group of size parallel ProgramState branches; access a branch with forks[i] |
Usage Examples
Parallel Sampling
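```python
import sglang as sgl

# Assumes an SGLang server is already running locally on the default port 30000.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))


@sgl.function
def parallel_gen(s, question):
    s += sgl.user(question)
    # Fork into 3 parallel branches that share the user turn as a prefix
    forks = s.fork(3)
    for i in range(3):
        forks[i] += sgl.assistant(sgl.gen("answer", max_tokens=128, temperature=0.8))
    # Default "gather_variable" mode: collect each branch's "answer" into a list
    forks.join()


state = parallel_gen.run(question="Explain quantum computing.")
# state["answer"] is now a list of 3 independently sampled answers, one per branch
for i, answer in enumerate(state["answer"]):
    print(f"Answer {i}: {answer[:80]}...")
```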
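When each branch generates more than one variable, "gather_variable" produces one list per variable. These lists can be regrouped into per-branch records with zip. This assumes they are index-aligned by branch, which follows from the join collecting values branch by branch. The values below are placeholders standing in for gathered lists such as state["answer"]. "reasoning" is a hypothetical second variable that the example above does not generate.

```python
from typing import Dict, List

# Placeholder values standing in for gathered lists after forks.join();
# "reasoning" is a hypothetical second generated variable.
gathered: Dict[str, List[str]] = {
    "reasoning": ["uses qubits", "uses superposition", "uses entanglement"],
    "answer": ["Answer A", "Answer B", "Answer C"],
}


def per_branch_records(variables: Dict[str, List[str]]) -> List[Dict[str, str]]:
    """Regroup {name: [v0, v1, ...]} into [{name: v0}, {name: v1}, ...]."""
    names = list(variables)
    # zip stops at the shortest list, so all lists should have one entry per branch.
    return [dict(zip(names, values)) for values in zip(*variables.values())]


records = per_branch_records(gathered)
for i, record in enumerate(records):
    print(i, record["answer"], "|", record["reasoning"])
```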
Best-of-N Selection
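```python
import sglang as sgl

# Assumes the default backend was set as in the previous example.


@sgl.function
def best_of_n(s, question, n=5):
    s += sgl.user(question)
    forks = s.fork(n)
    for i in range(n):
        forks[i] += sgl.assistant(sgl.gen("answer", max_tokens=200, temperature=0.9))
    forks.join()


state = best_of_n.run(question="Write a haiku about coding.")
# After join, state["answer"] holds all n candidates
answers = state["answer"]
# Select the best answer (here: the longest; any scoring function works)
best = max(answers, key=len)
print(best)
```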
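The selection step can be factored into a small helper with a pluggable scoring function. The sketch below defaults to the length heuristic mentioned in the example. It assumes the candidates have already been gathered into a list of strings. select_best is a hypothetical helper, and the score function could be replaced by any ranking you trust, such as a reward model.

```python
from typing import Callable, List, Optional


def select_best(
    candidates: List[str],
    score: Optional[Callable[[str], float]] = None,
) -> str:
    """Return the highest-scoring candidate; ties go to the earliest branch."""
    if not candidates:
        raise ValueError("no candidates to select from")
    if score is None:
        # Default heuristic from the example: prefer the longest answer.
        score = len
    # max() returns the first maximal element, so earlier branches win ties.
    return max(candidates, key=score)


answers = [
    "Bugs hide in the code",
    "Semicolons fall like\nautumn leaves",
    "Tabs or spaces?",
]
print(select_best(answers))
```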