Implementation:Microsoft DeepSpeedExamples Add Argument CIFAR
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation |
| Repository | Microsoft/DeepSpeedExamples |
| Title | Add_Argument_CIFAR |
| Type | Function Doc |
| Source File | training/cifar/cifar10_deepspeed.py |
| Lines | 14-108 |
| Import | Direct function in cifar10_deepspeed.py |
| Implements | Principle:Microsoft_DeepSpeedExamples_DeepSpeed_CLI_Integration |
Overview
Concrete tool for setting up DeepSpeed-compatible argument parsing in the CIFAR-10 example.
Description
The `add_argument()` function in `cifar10_deepspeed.py` builds an argument parser that merges application-specific training arguments with DeepSpeed's standard configuration arguments, then parses the command line. It is the entry point for the entire CIFAR-10 DeepSpeed workflow: it is called in `__main__` before any other initialization.
The function performs three key operations:
- Creates an `argparse.ArgumentParser` with the description "CIFAR".
- Adds custom arguments for training control, mixed precision, ZeRO, and MoE configuration.
- Calls `deepspeed.add_config_arguments(parser)` to inject DeepSpeed-specific arguments (`--deepspeed`, `--deepspeed_config`, etc.).
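The three steps can be sketched without DeepSpeed installed. In this sketch, `add_config_arguments_standin` and `build_parser` are hypothetical names: the stand-in only mimics the two flags named above (`--deepspeed`, `--deepspeed_config`) and is not DeepSpeed's implementation, which registers its own argument group. Like the real call on line 112 of the excerpt, it returns the parser so it can be reassigned.

```python
import argparse


def add_config_arguments_standin(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Hypothetical stand-in for deepspeed.add_config_arguments.

    Adds only the two flags mentioned on this page; use the real
    deepspeed.add_config_arguments in practice.
    """
    group = parser.add_argument_group("DeepSpeed (stand-in)")
    group.add_argument("--deepspeed", default=False, action="store_true")
    group.add_argument("--deepspeed_config", default=None, type=str)
    return parser


def build_parser() -> argparse.ArgumentParser:
    # Step 1: create the parser.
    parser = argparse.ArgumentParser(description="CIFAR")
    # Step 2: application-specific arguments (a small subset of the real ones).
    parser.add_argument("-e", "--epochs", default=30, type=int)
    parser.add_argument("--dtype", default="fp16", type=str, choices=["bf16", "fp16", "fp32"])
    # Step 3: inject the DeepSpeed-style flags into the same parser.
    return add_config_arguments_standin(parser)


args = build_parser().parse_args(["--deepspeed", "--dtype", "bf16"])
print(args.deepspeed, args.dtype, args.epochs)  # True bf16 30
```

Because both sets of flags live on one parser, a single `parse_args()` call yields one namespace that downstream code can pass around unchanged.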
The resulting `args` namespace is then passed to both `get_ds_config(args)` (to build the DeepSpeed config dictionary) and `deepspeed.initialize(args=args, ...)` (to configure the engine).
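As a rough illustration of the hand-off to the config builder, the sketch below turns `--dtype` and `--stage` into DeepSpeed-style config sections. `build_ds_config_sketch` is hypothetical and deliberately minimal: the example's real `get_ds_config(args)` sets additional keys, and only the `fp16`, `bf16`, and `zero_optimization` entries here are meant to mirror real DeepSpeed config fields.

```python
import argparse
import json


def build_ds_config_sketch(args: argparse.Namespace) -> dict:
    """Hypothetical, minimal mapping from parsed CLI args to a DeepSpeed-style config."""
    return {
        # At most one precision section is enabled; both are off for fp32.
        "fp16": {"enabled": args.dtype == "fp16"},
        "bf16": {"enabled": args.dtype == "bf16"},
        # The ZeRO stage comes straight from --stage (0 disables ZeRO partitioning).
        "zero_optimization": {"stage": args.stage},
    }


args = argparse.Namespace(dtype="bf16", stage=2)
print(json.dumps(build_ds_config_sketch(args), sort_keys=True))
# {"bf16": {"enabled": true}, "fp16": {"enabled": false}, "zero_optimization": {"stage": 2}}
```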
Code Reference
File: training/cifar/cifar10_deepspeed.py, Lines 14-108
```python
import argparse

import deepspeed


def add_argument():
    parser = argparse.ArgumentParser(description="CIFAR")

    # For train.
    parser.add_argument(
        "-e",
        "--epochs",
        default=30,
        type=int,
        help="number of total epochs (default: 30)",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local rank passed from distributed launcher",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=2000,
        help="output logging information at a given interval",
    )

    # For mixed precision training.
    parser.add_argument(
        "--dtype",
        default="fp16",
        type=str,
        choices=["bf16", "fp16", "fp32"],
        help="Datatype used for training",
    )

    # For ZeRO Optimization.
    parser.add_argument(
        "--stage",
        default=0,
        type=int,
        choices=[0, 1, 2, 3],
        help="ZeRO optimization stage",
    )

    # For MoE (Mixture of Experts).
    parser.add_argument(
        "--moe",
        default=False,
        action="store_true",
        help="use deepspeed mixture of experts (moe)",
    )
    parser.add_argument(
        "--ep-world-size", default=1, type=int, help="(moe) expert parallel world size"
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        nargs="+",
        default=[1],
        help="number of experts list, MoE related.",
    )
    parser.add_argument(
        "--mlp-type",
        type=str,
        default="standard",
        help="Only applicable when num-experts > 1, accepts [standard, residual]",
    )
    parser.add_argument(
        "--top-k", default=1, type=int, help="(moe) gating top 1 and 2 supported"
    )
    parser.add_argument(
        "--min-capacity",
        default=0,
        type=int,
        help="(moe) minimum capacity of an expert regardless of the capacity_factor",
    )
    parser.add_argument(
        "--noisy-gate-policy",
        default=None,
        type=str,
        help="(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter",
    )
    parser.add_argument(
        "--moe-param-group",
        default=False,
        action="store_true",
        help="(moe) create separate moe param groups, required when using ZeRO w. MoE",
    )

    # Include DeepSpeed configuration arguments.
    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args()

    return args
```
Signature
```python
def add_argument() -> argparse.Namespace:
    """Build and parse CLI arguments for CIFAR-10 DeepSpeed training.

    Returns:
        argparse.Namespace: Parsed arguments including both custom and DeepSpeed args.
    """
```
I/O Contract
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | (none) | -- | Reads from sys.argv via parser.parse_args() |
| Output | args | argparse.Namespace | Combined namespace with custom + DeepSpeed arguments |
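Because `parse_args()` is called with no argument list, it reads `sys.argv[1:]`. One way to exercise such a function in a test is to patch `sys.argv`, as in this sketch. `add_argument_subset` is a hypothetical trimmed copy of the parser (no DeepSpeed flags) used only for illustration; it also shows argparse turning dashed option names such as `--log-interval` into attributes such as `log_interval`, and `nargs="+"` collecting a list.

```python
import argparse
import sys
from unittest import mock


def add_argument_subset() -> argparse.Namespace:
    """Hypothetical trimmed version of add_argument() without DeepSpeed flags."""
    parser = argparse.ArgumentParser(description="CIFAR")
    parser.add_argument("-e", "--epochs", default=30, type=int)
    parser.add_argument("--log-interval", type=int, default=2000)
    parser.add_argument("--num-experts", type=int, nargs="+", default=[1])
    # No argument list is passed, so argparse falls back to sys.argv[1:].
    return parser.parse_args()


# sys.argv[0] is the program name and is skipped by argparse.
fake_argv = ["cifar10_deepspeed.py", "-e", "5", "--num-experts", "2", "4"]
with mock.patch.object(sys, "argv", fake_argv):
    args = add_argument_subset()

print(args.epochs, args.log_interval, args.num_experts)  # 5 2000 [2, 4]
```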
Argument Reference
Training Arguments
| Argument | Type | Default | Description |
|---|---|---|---|
| -e / --epochs | int | 30 | Number of total training epochs |
| --local_rank | int | -1 | Local rank set by the DeepSpeed distributed launcher |
| --log-interval | int | 2000 | Print loss statistics every N mini-batches |
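The `--log-interval` row can be made concrete with a modulo check. This sketch assumes the training loop reports when `i % log_interval == log_interval - 1` for a zero-based mini-batch index `i`, a convention common in PyTorch CIFAR tutorials; `logging_steps` is a hypothetical helper and not part of the example.

```python
from typing import List


def logging_steps(num_batches: int, log_interval: int) -> List[int]:
    """Hypothetical helper: zero-based mini-batch indices at which loss would be printed."""
    # Assumed convention: report after every `log_interval` mini-batches,
    # i.e. at i = log_interval - 1, 2 * log_interval - 1, ...
    return [i for i in range(num_batches) if i % log_interval == log_interval - 1]


print(logging_steps(num_batches=10, log_interval=3))  # [2, 5, 8]
print(logging_steps(num_batches=4000, log_interval=2000))  # [1999, 3999]
```

With the default of 2000, an epoch with fewer than 2000 mini-batches would print nothing under this convention.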
Mixed Precision Arguments
| Argument | Type | Default | Choices | Description |
|---|---|---|---|---|
| --dtype | str | "fp16" | bf16, fp16, fp32 | Training data type; bf16 and fp16 enable mixed precision, fp32 trains in full precision |
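Because `--dtype` declares `choices`, argparse rejects any other value before training starts. The standalone check below re-declares only this one argument: an invalid value makes argparse print a usage error and raise `SystemExit` with status 2.

```python
import argparse

# Re-declare only the --dtype argument, as in add_argument().
parser = argparse.ArgumentParser(description="CIFAR")
parser.add_argument("--dtype", default="fp16", type=str, choices=["bf16", "fp16", "fp32"])

print(parser.parse_args([]).dtype)  # fp16

exit_code = None
try:
    parser.parse_args(["--dtype", "fp8"])
except SystemExit as exc:
    # argparse prints an "invalid choice" usage error to stderr, then exits.
    exit_code = exc.code
print(exit_code)  # 2
```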
ZeRO Arguments
| Argument | Type | Default | Choices | Description |
|---|---|---|---|---|
| --stage | int | 0 | 0, 1, 2, 3 | ZeRO optimization stage |
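For orientation, the cumulative meaning of each `--stage` value can be written as a lookup table: stage 1 partitions optimizer states across data-parallel ranks, stage 2 also partitions gradients, and stage 3 also partitions the model parameters. The `ZERO_PARTITIONED` name is illustrative only.

```python
# Illustrative lookup: what each ZeRO stage partitions across data-parallel ranks.
# Each stage includes everything partitioned by the stages below it.
ZERO_PARTITIONED = {
    0: frozenset(),  # ZeRO disabled: every rank keeps full states, gradients, and params
    1: frozenset({"optimizer_states"}),
    2: frozenset({"optimizer_states", "gradients"}),
    3: frozenset({"optimizer_states", "gradients", "parameters"}),
}

for stage, parts in ZERO_PARTITIONED.items():
    print(stage, sorted(parts))
```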
MoE Arguments
| Argument | Type | Default | Description |
|---|---|---|---|
| --moe | flag | False | Enable Mixture of Experts |
| --ep-world-size | int | 1 | Expert parallel world size |
| --num-experts | int (nargs="+") | [1] | Number of experts per MoE layer (list) |
| --mlp-type | str | "standard" | MLP type: "standard" or "residual" (only applicable when num-experts > 1) |
| --top-k | int | 1 | Top-k gating (1 or 2 supported) |
| --min-capacity | int | 0 | Minimum capacity of an expert regardless of the capacity factor |
| --noisy-gate-policy | str | None | Noisy gating policy: None, RSample, or Jitter (top-1 gating only) |
| --moe-param-group | flag | False | Create separate MoE param groups (required for ZeRO + MoE) |
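The help strings imply a few cross-argument rules that argparse itself does not enforce. The sketch collects them in a hypothetical `check_moe_args` helper. The rules come from the help text on this page (top-1/top-2 gating, noisy gating only with top-1, `standard`/`residual` MLP types, separate param groups when ZeRO is combined with MoE); treating any `--stage` above 0 as "using ZeRO" and treating the literal string "None" as no policy are assumptions.

```python
import argparse
from typing import List


def check_moe_args(args: argparse.Namespace) -> List[str]:
    """Hypothetical validator for the MoE-related flags; returns a list of problems."""
    problems: List[str] = []
    if not args.moe:
        return problems  # MoE flags are not checked when --moe is off.
    # Assumption: the help text lists None as valid, so "None" means no policy.
    policy = None if args.noisy_gate_policy in (None, "None") else args.noisy_gate_policy
    if args.top_k not in (1, 2):
        problems.append("--top-k must be 1 or 2")
    if policy not in (None, "RSample", "Jitter"):
        problems.append("--noisy-gate-policy must be None, RSample, or Jitter")
    if policy is not None and args.top_k != 1:
        problems.append("--noisy-gate-policy is only supported with --top-k 1")
    if args.mlp_type not in ("standard", "residual"):
        problems.append("--mlp-type must be standard or residual")
    # Assumption: any --stage above 0 counts as "using ZeRO".
    if args.stage > 0 and not args.moe_param_group:
        problems.append("--moe-param-group is required when combining ZeRO with MoE")
    return problems


bad = argparse.Namespace(
    moe=True, top_k=2, noisy_gate_policy="Jitter", mlp_type="standard",
    stage=2, moe_param_group=False,
)
for message in check_moe_args(bad):
    print(message)
```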
Usage Example
```python
# In __main__ of cifar10_deepspeed.py:
if __name__ == "__main__":
    args = add_argument()
    main(args)
```

CLI invocations that feed into add_argument():

```bash
# Basic run with default fp16
deepspeed cifar10_deepspeed.py --deepspeed

# Specify dtype and ZeRO stage
deepspeed cifar10_deepspeed.py --deepspeed --dtype bf16 --stage 2 --epochs 10

# Enable MoE with 4 experts
deepspeed --num_gpus=2 cifar10_deepspeed.py --deepspeed \
    --moe --num-experts 4 --top-k 1 --ep-world-size 2 --moe-param-group
```
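The `deepspeed` launcher typically passes `--local_rank=<n>` to each process it spawns and also exports a `LOCAL_RANK` environment variable, which is why the parser declares `--local_rank` with a default of -1. The sketch shows one defensive way to resolve the rank; `resolve_local_rank` is a hypothetical helper, not part of the example.

```python
import argparse
import os
from typing import Mapping, Optional


def resolve_local_rank(args: argparse.Namespace, env: Optional[Mapping[str, str]] = None) -> int:
    """Hypothetical helper: prefer --local_rank, fall back to LOCAL_RANK, else -1."""
    env = os.environ if env is None else env
    if args.local_rank != -1:
        return args.local_rank
    return int(env.get("LOCAL_RANK", -1))


parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)

print(resolve_local_rank(parser.parse_args(["--local_rank=1"]), env={}))  # 1
print(resolve_local_rank(parser.parse_args([]), env={"LOCAL_RANK": "3"}))  # 3
print(resolve_local_rank(parser.parse_args([]), env={}))  # -1
```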
How Arguments Flow to Downstream Components
```
add_argument()
      |
      v
args (Namespace)
      |
      +-----> get_ds_config(args) --> ds_config (dict)
      |                                   |
      +-----> deepspeed.initialize(args=args, ..., config=ds_config)
      |
      +-----> Net(args) --> model (uses args.moe, args.num_experts, etc.)
      |
      +-----> main(args) --> training loop (uses args.epochs, args.log_interval)
```
Related Pages
- Principle:Microsoft_DeepSpeedExamples_DeepSpeed_CLI_Integration -- The principle this implementation realizes
- Implementation:Microsoft_DeepSpeedExamples_DeepSpeed_Initialize_CIFAR -- Consumes the args produced by this function
- Implementation:Microsoft_DeepSpeedExamples_Net_DeepSpeed -- Uses `args.moe` and related MoE arguments
- Environment:Microsoft_DeepSpeedExamples_CIFAR10_Training_Environment