Add VLM pruning and PTQ with image-text calibration (Megatron-Bridge) #1792
kevalmorabia97 wants to merge 8 commits
📝 Walkthrough
Adds VLM-aware pruning and quantization paths for Megatron-Bridge models, plus CP/DP calibration and dataset support, tiny VLM fixtures, and end-to-end tests. Documentation is updated for the new VLM behavior and pruning scope.
9b3d399 to c92645a
/ok to test c92645a
/claude review
Codecov Report
@@ Coverage Diff @@
## main #1792 +/- ##
===========================================
- Coverage 77.29% 65.98% -11.31%
===========================================
Files 513 513
Lines 56920 56989 +69
===========================================
- Hits 43997 37606 -6391
- Misses 12923 19383 +6460
Claude review passed — no blocking issues found. LGTM
Summary
Reviewed all 11 changed files (source-prioritized: nas/plugins/mbridge.py, prune/plugins/mcore_minitron.py, utils/plugins/mbridge.py + megatron_generate.py, the example prune_minitron.py, the test fixtures, and docs).
Findings — CRITICAL: 0, IMPORTANT: 0, SUGGESTION: 0
The change cleanly extends Minitron pruning to the language model of VLMs by extracting model.language_model as the prune target while still saving the full wrapper. I traced the dataflow end-to-end and verified the higher-risk areas:
- Mode/State — _inherit_base_model_rules correctly copies the base GPTModel/MambaModel rule onto subclass registry keys (e.g. Qwen3VLGPTModel) so subclasses aren't frozen during convert_to_dynamic. The shared rule object is only read (customize_rule -> validate_rule produces a fresh object and never mutates input), so the aliasing is safe.
- Algorithm correctness — moving the assert export_config[hp_name] in hp.choices check outside the if hp.is_configurable guard is a genuine fix: it now rejects unachievable manual values for non-configurable hparams instead of silently overwriting model.config and producing a weights/config mismatch.
- VLM detection via getattr(unwrapped_model, 'language_model', unwrapped_model) is safe (plain GPT/Mamba models lack that attribute), and the mbridge.py assertion guards the inner type. The text_config write-back path and AutoModelForImageTextToText dummy-model selection are consistent.
- Plugin laziness — the new Qwen3-VL registrations in mbridge.py are correctly guarded behind try/except ImportError for older Megatron-Bridge builds.
- Backward compat — removed create_tiny_gemma3 fixtures have no remaining references; the create_tiny_qwen3vl_dir signature change is additive (optional return_model), so the existing export tests still pass.
Risk: low. Additive feature behind VLM detection; the only behavioral change to the existing LM path is the stricter (and more correct) export_config validation. MoE-VLM is not e2e-tested (documented in the PR description, with the bridge/transformers format-mismatch rationale).
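The getattr-based VLM detection described above can be illustrated with a minimal sketch. The classes below are hypothetical stand-ins, not the Megatron or Megatron-Bridge types, and get_prune_target is an illustrative name rather than a function from the PR.

```python
class ToyLanguageModel:
    """Hypothetical stand-in for a plain GPT/Mamba language model."""


class ToyVLMWrapper:
    """Hypothetical stand-in for a VLM wrapper that owns a language model."""

    def __init__(self):
        self.vision_model = object()  # frozen, never pruned in this sketch
        self.language_model = ToyLanguageModel()


def get_prune_target(unwrapped_model):
    # Plain language models have no `language_model` attribute, so the
    # default (the model itself) is returned for them.
    return getattr(unwrapped_model, "language_model", unwrapped_model)


plain = ToyLanguageModel()
vlm = ToyVLMWrapper()

# For a plain LM the prune target is the model itself.
assert get_prune_target(plain) is plain
# For a VLM the prune target is the inner LM, while the wrapper is what gets saved.
assert get_prune_target(vlm) is vlm.language_model
```

The fallback only stays safe as long as plain language models never grow a language_model attribute of their own, which is why the review also points to the type assertion on the inner module.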
c92645a to a8ad7b3
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 4
🧹 Nitpick comments (1)
modelopt/torch/utils/plugins/mbridge.py (1)
106-108: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Update the return contract for VLM wrappers.
This function now returns the full VLM wrapper as unwrapped_model, while the type/docstring still promise a GPTModel | MambaModel. Please widen the fourth return type and document that VLM callers should use .language_model as the GPT/Mamba pruning target. As per coding guidelines, “Document public APIs.”
Proposed contract update
     ) -> tuple[
         AutoBridge,
         GPTModelProvider | MambaModelProvider,
         list[MegatronModule],
    -    GPTModel | MambaModel,
    +    MegatronModule,
         AutoTokenizer,
     ]:
    @@
    -        A tuple of (bridge, provider, model, unwrapped_model, tokenizer).
    +        A tuple of (bridge, provider, model, unwrapped_model, tokenizer). For VLMs,
    +        unwrapped_model is the full wrapper and unwrapped_model.language_model is the
    +        GPT/Mamba pruning target.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/utils/plugins/mbridge.py` around lines 106 - 108, The return contract for this helper is stale: it can now return a full VLM wrapper as the unwrapped model, not just a GPTModel or MambaModel. Update the function’s declared return type and docstring around the logic that sets language_model from unwrapped_model so the fourth return value reflects the broader wrapper type, and explicitly note that VLM callers should use .language_model as the pruning target while preserving the wrapper for saving.
Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/megatron_bridge/README.md`:
- Around line 318-322: The VLM example command in the README is using a non-VLM
checkpoint id, which is misleading for this section. Update the
prune_minitron.py example to reference a VLM-specific Hugging Face model id
instead of Qwen/Qwen3.5-4B, keeping the rest of the command unchanged so the
example clearly matches the VLM workflow.
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 322-332: Refresh the hparam choices before checking export_config
so validation uses the post-modify granularity. In the logic around the hp/reset
flow in mcore_minitron.py, move the hp.reset_choices() call to happen before the
export_config[hp_name] in hp.choices assertion, or otherwise ensure hp.choices
is updated first, so the check in the export_config validation block sees the
current choices for each hparam.
In `@tests/_test_utils/torch/transformers_models.py`:
- Line 255: The in-function import of Qwen3_5Config inside the helper is
undocumented, so either move it into the module-level Transformers imports
alongside the other configs or add a brief inline justification matching the
Qwen3VLConfig pattern. Update the helper that uses Qwen3_5Config to explain why
the import must stay local, and ensure the reason clearly fits the allowed cases
(optional dependency, version gating, or similar).
In `@tests/examples/megatron_bridge/test_prune_minitron.py`:
- Around line 102-106: The current assertion in test_prune_minitron only
compares non-LM parameter counts, which can miss unintended weight changes in
the vision tower or projector. Update the test around the teacher_model and
pruned_model comparison to compare the actual non-language_model parameter
tensors directly, using the existing language_model filtering logic, so “left
untouched” is verified by value and not just shape/count. If needed in this
module, add the torch import at the top so you can use tensor equality checks.
📒 Files selected for processing (11)
- CHANGELOG.rst
- docs/source/guides/3_pruning.rst
- examples/megatron_bridge/README.md
- examples/megatron_bridge/prune_minitron.py
- examples/pruning/README.md
- modelopt/torch/nas/plugins/mbridge.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- modelopt/torch/utils/plugins/mbridge.py
- modelopt/torch/utils/plugins/megatron_generate.py
- tests/_test_utils/torch/transformers_models.py
- tests/examples/megatron_bridge/test_prune_minitron.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/megatron_bridge/quantize.py`:
- Around line 308-312: The warning in quantize.py is using copy-pasted pruning
terminology, so update the warn_rank_0 message in the VLM text-only calibration
branch to refer to quantization calibration or activation statistics instead of
“pruning importance.” Keep the existing conditions and context around is_vlm,
use_image_calib, and args.calib_dataset_name, but rewrite the wording so it
accurately describes what happens during calibration in this path.
📒 Files selected for processing (11)
- CHANGELOG.rst
- examples/megatron_bridge/README.md
- examples/megatron_bridge/export.py
- examples/megatron_bridge/prune_minitron.py
- examples/megatron_bridge/quantize.py
- modelopt/torch/utils/plugins/megatron_calibration.py
- modelopt/torch/utils/plugins/megatron_generate.py
- modelopt/torch/utils/vlm_dataset_utils.py
- tests/_test_utils/torch/transformers_models.py
- tests/examples/megatron_bridge/test_prune_minitron.py
- tests/examples/megatron_bridge/test_quantize_export.py
✅ Files skipped from review due to trivial changes (3)
- examples/megatron_bridge/export.py
- CHANGELOG.rst
- examples/megatron_bridge/README.md
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/torch/utils/plugins/megatron_generate.py
- examples/megatron_bridge/prune_minitron.py
e0f406b to 2135cd0
2135cd0 to 27f791b
/claude review
🧹 Nitpick comments (1)
modelopt/torch/utils/vlm_dataset_utils.py (1)
244-244: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Move local imports to module scope or document the deferral.
Line 244 and Line 508 add function-local imports without an explicit reason. If datasets is intentionally deferred as an optional/heavy dependency, add a short comment naming that reason; otherwise move these imports to the top-level import block.
As per coding guidelines, “Keep imports at the top of the file; place imports inside functions only when necessary ... with explicit justification.”
Also applies to: 508-508
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/utils/vlm_dataset_utils.py` at line 244, The local imports of interleave_datasets and load_dataset in vlm_dataset_utils should either be moved into the module-level import block or explicitly justified if they are intentionally deferred. Update the relevant import sites used by the dataset-loading logic so the symbols are imported at the top of the file, or add a short in-function comment near the import explaining that datasets is an optional/heavy dependency being loaded lazily.
Source: Coding guidelines
📒 Files selected for processing (16)
- CHANGELOG.rst
- docs/source/guides/3_pruning.rst
- examples/megatron_bridge/README.md
- examples/megatron_bridge/export.py
- examples/megatron_bridge/prune_minitron.py
- examples/megatron_bridge/quantize.py
- examples/pruning/README.md
- modelopt/torch/nas/plugins/mbridge.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- modelopt/torch/utils/plugins/mbridge.py
- modelopt/torch/utils/plugins/megatron_calibration.py
- modelopt/torch/utils/plugins/megatron_generate.py
- modelopt/torch/utils/vlm_dataset_utils.py
- tests/_test_utils/torch/transformers_models.py
- tests/examples/megatron_bridge/test_prune_minitron.py
- tests/examples/megatron_bridge/test_quantize_export.py
✅ Files skipped from review due to trivial changes (5)
- examples/megatron_bridge/export.py
- examples/pruning/README.md
- examples/megatron_bridge/README.md
- docs/source/guides/3_pruning.rst
- CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (10)
- modelopt/torch/utils/plugins/megatron_generate.py
- modelopt/torch/utils/plugins/mbridge.py
- modelopt/torch/nas/plugins/mbridge.py
- tests/examples/megatron_bridge/test_quantize_export.py
- modelopt/torch/utils/plugins/megatron_calibration.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- examples/megatron_bridge/quantize.py
- examples/megatron_bridge/prune_minitron.py
- tests/examples/megatron_bridge/test_prune_minitron.py
- tests/_test_utils/torch/transformers_models.py
Claude Review
Re-reviewed after the PR grew from 11 to 16 files (the prior Claude approval was on the earlier commit; the quantize.py VLM path, the VLM calibration loop, and DP dataset sharding are newer). Reviewed all modelopt/ and examples/ source plus the new tests and fixtures.
Findings — CRITICAL: 0, IMPORTANT: 1, SUGGESTION: 1
Most impactful
[IMPORTANT] DP sharding crashes on the default VLM streaming dataset (vlm_dataset_utils.py). The new map-vs-iterable discriminator hasattr(dataset, "__len__") misclassifies _HFDatasetsIterableWrapper — it is an IterableDataset that also defines __len__, so the default nemotron_vlm_dataset_v2 streaming dataset takes the DistributedSampler branch, which DataLoader rejects for iterable datasets (ValueError). Any VLM calibration with the default dataset and DP > 1 — the sharding feature this PR adds — fails. Both new VLM tests run DP=1 (only TP/PP), so they do not catch it. Suggested fix is an isinstance(dataset, IterableDataset) check; see the inline comment.
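A minimal sketch of why the length-based check misclassifies the streaming wrapper. To keep it dependency-free, IterableDataset below is a local stand-in for torch.utils.data.IterableDataset, and ToyStreamingWrapper is a hypothetical substitute for _HFDatasetsIterableWrapper; neither is the real class.

```python
class IterableDataset:
    """Local stand-in for torch.utils.data.IterableDataset (assumption)."""

    def __iter__(self):
        raise NotImplementedError


class ToyStreamingWrapper(IterableDataset):
    """Hypothetical streaming wrapper: iterable, but it also reports a length."""

    def __init__(self, num_samples):
        self._num_samples = num_samples

    def __iter__(self):
        return iter(range(self._num_samples))

    def __len__(self):
        return self._num_samples


def is_map_style_by_len(dataset):
    # The discriminator flagged in the review: anything with __len__ counts as map-style.
    return hasattr(dataset, "__len__")


def is_map_style_by_type(dataset):
    # The suggested fix: decide on the dataset type, not on __len__.
    return not isinstance(dataset, IterableDataset)


stream = ToyStreamingWrapper(8)
map_style = [10, 20, 30]  # a list is a simple map-style dataset

assert is_map_style_by_len(stream)        # wrong: would route to a sampler
assert not is_map_style_by_type(stream)   # right: stays on the iterable path
assert is_map_style_by_type(map_style)
```

With the real classes the same type check would apply; only the import of IterableDataset changes.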
[SUGGESTION] The non-LM quantizer-disable pattern f"*{name}*" in quantize.py is a substring match; safe for the tested VLMs but could over-match a generic child name. Consider anchoring to the submodule subtree.
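The over-match risk can be shown with a small sketch, assuming the patterns are matched glob-style as Python's fnmatch does. The module paths and the child name "mlp" are hypothetical and chosen only to trigger the collision.

```python
from fnmatch import fnmatchcase

# Hypothetical quantizer paths; real names depend on the model.
lm_quantizer = "language_model.decoder.layers.0.mlp.linear_fc1.weight_quantizer"
non_lm_quantizer = "mlp.proj.weight_quantizer"

name = "mlp"  # hypothetical short name of a non-LM top-level child

substring_pattern = f"*{name}*"  # matches the name anywhere in the path
anchored_pattern = f"{name}.*"   # matches only paths under the top-level child

# The substring pattern also disables a quantizer inside the language model.
assert fnmatchcase(lm_quantizer, substring_pattern)
# The anchored pattern leaves the language-model quantizer alone ...
assert not fnmatchcase(lm_quantizer, anchored_pattern)
# ... while still matching the non-LM child's own subtree.
assert fnmatchcase(non_lm_quantizer, anchored_pattern)
```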
Assessment
Risk: low-to-moderate. The feature is cleanly additive behind VLM detection, and end-to-end correctness (text_config write-back, language_model prune target, vision-quantizer disable, _inherit_base_model_rules, tuple-output / vocab_size handling in megatron_prefill) is sound and well-tested for the DP=1 paths. The one blocking item is the multi-DP streaming path, exercised neither by the tests nor the DP=1 fixtures — worth fixing before relying on DP-sharded VLM calibration.
🧹 Nitpick comments (1)
modelopt/torch/utils/plugins/mbridge.py (1)
146-152: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Update the return type to include VLM wrappers.
This path now intentionally returns the outer wrapper when .language_model exists, but the tuple annotation still says the fourth element is GPTModel | MambaModel. Please widen that element to the wrapper/base module type so type consumers do not assume the returned object is always the inner LM.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/utils/plugins/mbridge.py` around lines 146 - 152, The return type for the tuple produced in mbridge.py is too narrow because the logic now returns the outer VLM wrapper when .language_model exists instead of always returning the inner GPTModel/MambaModel. Update the annotation for the fourth tuple element in the relevant function to a broader wrapper/base module type so callers and type checkers reflect the actual value returned, and keep the runtime check against language_model using GPTModel/MambaModel in place.
📒 Files selected for processing (7)
- examples/megatron_bridge/prune_minitron.py
- examples/megatron_bridge/quantize.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- modelopt/torch/utils/plugins/mbridge.py
- modelopt/torch/utils/vlm_dataset_utils.py
- tests/_test_utils/torch/transformers_models.py
- tests/examples/megatron_bridge/test_prune_minitron.py
🚧 Files skipped from review as they are similar to previous changes (5)
- examples/megatron_bridge/quantize.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- examples/megatron_bridge/prune_minitron.py
- modelopt/torch/utils/vlm_dataset_utils.py
- tests/_test_utils/torch/transformers_models.py
ChenhanYu left a comment
Review summary
What it does (one paragraph)
Adds VLM (Qwen3-VL / Qwen3.5-VL / Gemma3-VL) support to the Megatron-Bridge examples for Minitron pruning (prune_minitron.py) and PTQ (quantize.py). The pattern: prune/quantize only the inner model.language_model; leave the vision tower + vision→LM projector in full precision; save the full VLM back. Calibration is image-text by default for VLMs — drives the full VLM forward over nemotron_vlm_dataset_v2 (or scienceqa) so the language model's quantizer/importance statistics see vision-conditioned activations.
Architectural call-outs
- Quantizer scope vs ModeloptState scope. quantize.py quantizes the root VLM (so ModeloptState lives where Megatron-save expects it) but explicitly disables quantizers on every non-LM top-level child via f"{name}.*", "enable": False. The name is anchored to top-level children so a short non-LM name can't collide-by-substring with an LM quantizer path. ID-based de-dup against language_model.modules() correctly handles aliases like self.decoder = language_model.decoder. (quantize.py:572-617)
- Pruning scope vs save scope. prune_minitron.py calls mtp.prune(language_model, ...) in-place but saves the full unwrapped_model wrapper: clean and correct.
- CP gate. get_megatron_vlm_calibration_forward_loop raises on cp_size != 1 because the multimodal forward merges vision tokens into the sequence after the (non-CP-split) vision tower, so any sequence split would misalign them. Text-only VLM calibration takes the existing text loop and gets CP for free.
- hidden_size pinning on VLMs. Shared with the vision→LM projector; permuting it would misalign injected image features. Skipped from hps_to_sort and explicitly disallowed in prune_export_config.
Correctness fixes worth highlighting (real bugs, not stylistic cleanup)
- mcore_minitron.py: validate export_config against non-configurable hparams too. Pre-PR: a value outside hp.choices (e.g. not aligned to the search-space divisor) was silently ignored while model.config was overwritten anyway → checkpoint whose weights don't match its declared config. Now asserts on every matching hparam, configurable or not, with a clearer error citing the *_divisor settings. (mcore_minitron.py:404-422)
- mcore_minitron.py: reset_choices() moved before the choices assertion; previously validation could see choices that had not been refreshed after modify().
- _inherit_base_model_rules: VLM language models (Qwen3VLGPTModel) are registered under their own DMRegistry key but reuse the GPTModel dynamic class. Without rule inheritance, the subclass gets frozen during convert_to_dynamic (no width/depth pruning). The function walks the model and copies the base rule onto the subclass key.
- mbridge.py: stop overwriting the bridge's native layer spec for MoE. Was provider.transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(...), which clobbered Qwen3.5's GatedDeltaNet + gated attention and Gemma3's custom spec. Now sets provider.moe_grouped_gemm = moe_grouped_gemm and lets the native spec read it at build time. This benefits all MoE Bridge users, not just VLMs.
- PP tied-embedding double-count in the param breakdown: under PP both the first and last stage own a copy of word_embeddings, so the allreduce-sum double-counts. Subtracted once. (prune_minitron.py:307-325)
- ScienceQA chat-template fix. Datasets that ship images but no messages field previously fell back to a text-only prompt → processor emitted no image tokens → VLM forward failed. Now synthesizes a single user turn with [{"type": "image"}, {"type": "text", "text": question}].
- Iterable vs map-style dataset DP sharding. Discriminates on isinstance(dataset, IterableDataset), not __len__: the streaming wrapper defines __len__ but is iterable, and DataLoader rejects a sampler on iterable datasets. Subtle.
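The export_config validation fix in the first two items can be sketched roughly as follows. ToyHparam and validate_export_config are hypothetical stand-ins that mirror only the behavior described in this review (refresh choices, then check every matching hparam); they are not the modelopt classes.

```python
from dataclasses import dataclass, field


@dataclass
class ToyHparam:
    """Hypothetical stand-in for a search-space hyperparameter."""

    original: int
    divisor: int
    is_configurable: bool
    choices: list = field(default_factory=list)

    def reset_choices(self):
        # Recompute the valid choices from the current divisor.
        self.choices = [v for v in range(self.divisor, self.original + 1, self.divisor)]


def validate_export_config(export_config, hparams):
    for name, hp in hparams.items():
        if name not in export_config:
            continue
        hp.reset_choices()  # refresh first, so the check never sees stale choices
        # Checked for every matching hparam, configurable or not.
        if export_config[name] not in hp.choices:
            raise ValueError(f"{name}={export_config[name]} is not in {hp.choices}")


def is_rejected(export_config, hparams):
    try:
        validate_export_config(export_config, hparams)
    except ValueError:
        return True
    return False


hparams = {
    "ffn_hidden_size": ToyHparam(original=4096, divisor=1024, is_configurable=True),
    "num_layers": ToyHparam(original=8, divisor=8, is_configurable=False),
}

assert not is_rejected({"ffn_hidden_size": 2048}, hparams)  # aligned to the divisor
assert is_rejected({"ffn_hidden_size": 1500}, hparams)      # not aligned
assert is_rejected({"num_layers": 4}, hparams)              # non-configurable, still checked
```

Rejecting the third case is the behavioral change: before the fix, a value like this would have been written to the config without the weights following it.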
Smells / things to consider
1. _patch_qwen35_moe_sequential_expert_mappings is a runtime monkey-patch of Qwen35MoEBridge._get_moe_lm_mappings from inside load_mbridge_model_from_hf. It self-detects already-patched and is no-op if the installed bridge has sequential mappings, which is fine. But:
   - The patch only fires on the first call to load_mbridge_model_from_hf. Any code path that imports Qwen35MoEBridge before that and caches its mappings would see the un-patched version. Unlikely in this example flow, but worth a one-liner in the docstring.
   - TODO: Remove once Megatron-Bridge maps sequential Qwen3.5 MoE experts natively (26.06.01 onwards) is a load-bearing TODO; please track it.
2. AutoMapping/GatedMLPMapping imports inside the patch function: if Megatron-Bridge ever renames or moves these, the patch silently hits except ImportError: return and pruning fans out on Qwen3.5 MoE with the original grouped-only mappings (which pruning doesn't support). Consider a warn_rank_0 on the ImportError branch so this fails loud instead of silent.
3. DP-sharded streaming is per-rank stride, not per-rank shard split. _ShardedIterable iterates the full stream and islice(rank, None, world)'s it, meaning every rank still pulls (and discards) all other ranks' samples from the underlying streaming load_dataset. With dp_size ≥ 8 and image media decode in the codec path you're burning network + CPU on samples each rank throws away. Future optimization, not a blocker.
4. Text-only VLM calibration runs the inner language_model directly via the wrapper lambda forward_loop(_model=None): text_forward_loop(language_model). CP groups are set up on the parent VLM wrapper; calling forward on the submodule should still pick up the right CP state because CP communicators are global on mpu, but I'd want a CP≥2 smoke test confirming this before relying on it. Currently untested per the PR description.
5. Quantized VLM HF export is not in this PR. export.py adds an 8-line TODO explaining that export_mcore_gpt_to_hf uses modelopt's per-arch mappings (no Qwen3-VL / Gemma3-VL coverage) and suggests routing through AutoBridge.export_hf_weights_quant. The PR ships PTQ + Megatron-checkpoint save, but HF-format quantized VLM export is a follow-up. Worth a tracking issue linked from the PR description.
6. MMLU regression direction on the image-text row: text-cal MMLU 0.51 vs image-cal MMLU 0.49 on the 10% split. The PR's > [!NOTE] correctly disclaims these as "high-level trends only," but the direction (image-cal trades MMLU for substantially better VLM benchmarks) deserves a confirmatory run on full MMLU before getting cited downstream. Not a merge blocker; just file the follow-up.
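The per-rank stride behavior in the DP-sharding item can be reproduced with a small sketch. strided_shard and counting_stream are illustrative names; this is an assumption about how the islice-based sharding behaves, not the _ShardedIterable implementation itself.

```python
from itertools import islice

pulled = []  # records every sample fetched from the underlying stream


def counting_stream(num_samples):
    """Toy stream that logs each sample it produces (stands in for a streaming dataset)."""
    for index in range(num_samples):
        pulled.append(index)
        yield index


def strided_shard(stream, rank, world):
    # Keep every `world`-th sample starting at `rank`; skipped samples are
    # still pulled from the stream and then discarded.
    return islice(stream, rank, None, world)


rank_1_samples = list(strided_shard(counting_stream(8), rank=1, world=4))

assert rank_1_samples == [1, 5]
# The rank kept 2 samples but the stream produced all 8.
assert pulled == list(range(8))
```

A shard-level split would avoid the discarded fetches, at the cost of depending on how the underlying dataset exposes its shards.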
Bottom line
The architecture is clean: prune/quantize target = model.language_model, calibration target = unwrapped_model, save target = unwrapped_model. That three-pointer split is the whole load-bearing idea, and the code consistently encodes it without leaking the abstraction. Multiple sneaky correctness fixes shipped along the way (export_config validation, MoE layer-spec preservation, PP tied-embedding double-count, ScienceQA chat-template) are independently valuable beyond the VLM feature itself.
Suggested follow-ups: track the quantized VLM HF export (#5) as its own issue, fail-loud on the Megatron-Bridge ImportError in _patch_qwen35_moe_sequential_expert_mappings (#2), and add a CP≥2 smoke test for text-only VLM calibration (#4). The runtime monkey-patch in #1 is acceptable with the TODO; flag it on the next Megatron-Bridge bump so the cleanup actually happens.
Prune the language model of Megatron-Bridge VLMs (Qwen3-VL, Qwen3.5-VL, Gemma3-VL) while leaving the vision tower intact. The example extracts model.language_model, prunes it in place (ffn/MoE dims + depth; hidden_size is skipped since it's shared with the vision projector), and saves the full VLM back.
- nas/plugins/mbridge.py: register Qwen3VLGPTModel -> _DynamicMCoreLanguageModel and Qwen3VLSelfAttention -> _DynamicSelfAttention.
- utils/plugins/mbridge.py: loader accepts VLM wrappers (validates the inner .language_model, returns the full wrapper for saving).
- prune/plugins/mcore_minitron.py: protected_layers option to keep VLM deepstack-injected layers during depth pruning; inherit base-model search-space rules for registered subclasses; validate export_config values.
- megatron_generate.py: read vocab_size from the inner language_model for VLMs.
- examples/megatron_bridge/prune_minitron.py: VLM detection, hidden_size skip, deepstack index protection + remap, and VLM-aware HF save.
- tests: tiny Qwen3-VL / Gemma3-VL / Qwen3.5-VL (dense) fixtures + parametrized test_prune_minitron_vlm.
- Docs: CHANGELOG (0.46), pruning + megatron_bridge READMEs, pruning guide.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Qwen3.5-VL supersedes Qwen3-VL and removed deepstack vision injection, which was the only consumer of protected_layers (no other supported VLM injects vision features at specific LM layers). Remove the protected_layers search option and the deepstack index protection/remap for simplicity.
- prune/plugins/mcore_minitron.py: remove the protected_layers config option, the sorted-layers reorder, and the depth-prune assert.
- examples/megatron_bridge/prune_minitron.py: remove the deepstack protected_layers computation and the deepstack-index remap on save.
- tests: drop the Qwen3-VL parametrization from test_prune_minitron_vlm (Gemma3-VL + Qwen3.5-VL remain; helper kept for the export test).
- Docs: drop Qwen3-VL from the supported VLM lists; the example prunes a smaller Qwen3.5-4B for faster runs.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
- Clarify in the README that VLM pruning targets (params / memory / export_config) apply to the language model only; the vision tower is not counted.
- Drop the now-dead deepstack assertion from the VLM test and assert the non-language-model params (vision tower, projector, lm_head) are unchanged.
- Simplify the layer_types kept-layer computation in prune_minitron.py.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
- quantize.py: quantize only a VLM's language model (vision tower + projector stay full precision; ModelOpt state on the root so the Megatron save works). Calibration modality is inferred from --calib_dataset_name: an image-text dataset drives the full VLM forward (vision-conditioned activations); a text dataset runs text-only LM calibration (ablations).
- prune_minitron.py: estimate pruning importance from image-text calibration (full VLM forward) by default, text-only otherwise; same dataset inference.
- Shared get_megatron_vlm_calibration_forward_loop (megatron_prefill-based, unwraps tuple outputs) + vlm_dataset_utils (scienceqa, nemotron_vlm_dataset_v2 with config-driven subsets/shard cap).
- Tiny VLM test fixtures (qwen3.5-vl, gemma3-vl) with vision tokens derived dynamically from the ref processor; VLM prune + quantize example tests.
- README + CHANGELOG; TODO in export.py noting VLM HF export is unsupported.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Review-comment fixes:
- vlm_dataset_utils: discriminate DP sharding on IterableDataset type rather
than __len__ (the streaming wrapper defines __len__), so DP>1 calibration
works for the default streaming VLM dataset instead of crashing in DataLoader.
- mcore_minitron: reset hparam choices before validating export_config so
ConcatHparam choices are refreshed after modify().
- quantize: anchor the non-LM quantizer-disable pattern to the child subtree
(f"{name}.*") so a short non-LM name cannot match an LM quantizer path;
fix copy-pasted "pruning importance" wording (calibration, not pruning).
- transformers_models: document the Qwen3.5-VL lazy import.
- test_prune_minitron_vlm: compare the vision tower/projector/lm_head by value
(byte-identical), not just counts.
Bug fixes:
- mcore_minitron: do not sort (importance-permute) hparams listed in
hparams_to_skip. For VLMs hidden_size is skipped because the language model
shares its residual dimension with the un-pruned vision projector; permuting
it (even when hidden_size is not reduced) misaligned the injected image
features and produced a permuted lm_head/embedding. Now the residual order is
preserved and the pruned VLM stays vision-aligned.
- prune_minitron: write the pruned shared-expert size back under both HF field
names (Qwen3.5-MoE uses shared_expert_intermediate_size), so the exported
config matches the pruned weights and reloads via from_pretrained.
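The hidden_size bug above can be reproduced on paper with a toy example. This is a pure-Python sketch with invented numbers, not the mcore_minitron code: it shows that permuting the residual dimension is harmless for inputs the language model produces itself, but breaks inputs coming from a frozen projector that still emits features in the original order.

```python
def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def permute(vec, order):
    """Reorder a vector along the residual dimension."""
    return [vec[i] for i in order]


lm_weight = [1.0, 2.0, 3.0, 4.0]       # one LM weight row over the residual dim
text_embedding = [0.5, 0.0, 1.0, 2.0]  # produced by the LM's own embedding table
image_feature = [2.0, 1.0, 0.0, 1.0]   # produced by the frozen vision projector
order = [3, 1, 0, 2]                   # hypothetical importance ranking

text_before = dot(lm_weight, text_embedding)
image_before = dot(lm_weight, image_feature)

# Sorting permutes the LM weight and the LM embedding consistently...
sorted_weight = permute(lm_weight, order)
text_after = dot(sorted_weight, permute(text_embedding, order))
# ...but the un-pruned projector keeps emitting features in the old order.
image_after = dot(sorted_weight, image_feature)

print(text_before == text_after, image_before == image_after)
```

Skipping the sort for hparams in hparams_to_skip keeps the residual order fixed, so the injected image features stay aligned.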
Test-util refactor:
- Merge the dense/MoE tiny-model builders behind a moe flag exposed via partial
aliases (get_tiny_qwen3[/_moe], create_tiny_qwen3_dir[/_moe], and the
Qwen3.5-VL equivalents), with no caller churn.
Verified test_prune_minitron_vlm[gemma3vl] and [qwen3_5_moe_vl] pass on 2 GPUs
(nemo:26.06) with byte-identical vision tower, projector, and lm_head.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
- Log a language-model / frozen-non-LM / full-model param breakdown (via num2hrb) before and after pruning, and clarify in the --prune_target_* help text and the runtime message that the target applies to the language-model tower only for VLMs (vision/audio encoders, projectors, etc. are frozen and excluded). Param counts use the dist utility (allreduce) and correct for tied embeddings under PP.
- Remap vision_config.deepstack_visual_indexes (Qwen3-VL) to the surviving layers after depth pruning so the exported config stays valid: renumber for dropped lower layers and snap a dropped injection layer to the nearest survivor (count preserved so the frozen vision projector still lines up). Warn (recommending full VLM training over LM-only) when a deepstack layer is dropped, since text-only distillation cannot recover the vision path.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
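The deepstack index remapping in this commit (renumber for dropped lower layers, snap a dropped injection layer to the nearest survivor, keep the count) could look roughly like the sketch below. The function name, the tie-breaking rule (prefer the lower layer), and the example values are assumptions for illustration; the actual prune_minitron.py logic may differ.

```python
from bisect import bisect_left


def remap_deepstack_indexes(deepstack, kept_layers):
    """Map 0-indexed original layer indices onto positions among surviving layers.

    Assumes kept_layers is non-empty. A dropped index snaps to the nearest
    surviving layer (ties go to the lower layer), so the count is preserved.
    """
    kept = sorted(kept_layers)
    remapped = []
    for idx in deepstack:
        pos = bisect_left(kept, idx)
        if pos < len(kept) and kept[pos] == idx:
            remapped.append(pos)  # layer survived: just renumber it
            continue
        # Layer was dropped: compare the surviving neighbours on each side.
        candidates = [p for p in (pos - 1, pos) if 0 <= p < len(kept)]
        nearest = min(candidates, key=lambda p: (abs(kept[p] - idx), p))
        remapped.append(nearest)
    return remapped


# Layers 3 and 6 of an 8-layer model were dropped.
print(remap_deepstack_indexes([1, 3, 6], [0, 1, 2, 4, 5, 7]))
```

Note that snapping can map two injection points onto the same surviving layer; the sketch keeps both entries because the commit message states that the count is preserved.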
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
a2943e5 to
67118fe
Thanks for the thorough pass — replies inline. Net: the monkey-patch path is now under CI, two items are addressed in code, and the rest are non-blocking follow-ups.
Now CI-covered: the qwen3_5_moe_vl prune case is enabled (CI is on nemo:26.06), so the sequential-expert mapping path — the monkey-patch and non-grouped MoE pruning — is exercised end-to-end, not just the unit test.
1 (monkey-patch ordering + TODO).
- The patch mutates the class method Qwen35MoEBridge._get_moe_lm_mappings; mappings are read at bridge-build time inside from_hf_pretrained, which runs after _patch_…() in this flow, and the patch is idempotent/self-detecting, so a prior import of the class doesn't see a stale cached copy. Added a docstring line noting it's called unconditionally before the model type is known.
- TODO is load-bearing and tracked: "Remove once Megatron-Bridge maps sequential Qwen3.5 MoE experts natively (patched in 26.06.01)."
2 (silent ImportError → grouped-only mappings). Fixed: AutoMapping/GatedMLPMapping are now top-level imports, so a rename/move fails loud at import time instead of silently degrading to unsupported grouped-only mappings. The Qwen35MoEBridge import stays guarded since the patch runs for every model load and an absent bridge (older container / non-Qwen model) is a legitimate no-op.
3 (DP streaming is per-rank stride, not shard-split). Agreed, real but a throughput optimization at large DP, not a correctness issue — will address in follow-up PR.
4 (CP≥2 smoke test for text-only VLM calibration). Text-only VLM calibration reuses the existing shared text loop, whose CP≥2 path is already covered by our other LLM tests — no separate test needed.
5 (quantized VLM HF export). Out of scope by design — export.py carries the TODO and the PR description calls it out; will address in follow-up PR via the AutoBridge.export_hf_weights_quant route.
6 (MMLU direction on the image-text row). The PR NOTE disclaims these as high-level trends from small splits; detailed experiments will be run as follow-up before any downstream citation.
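For item 3 above, the difference between a per-rank stride and a shard split can be sketched as follows. Both helper names are hypothetical and this is not the vlm_dataset_utils implementation; it only shows why striding is correct but reads more data than necessary at large DP.

```python
from itertools import islice


def per_rank_stride(stream, rank, world_size):
    """Each rank walks the whole stream and keeps every world_size-th sample."""
    return islice(stream, rank, None, world_size)


def shard_split(shards, rank, world_size):
    """Each rank is assigned whole shards and never touches the others."""
    return shards[rank::world_size]


# Rank 1 of 4 still has to iterate past samples 0, 2, 3, 4, ... to reach its own.
print(list(per_rank_stride(iter(range(10)), rank=1, world_size=4)))
# With a shard split, rank 1 of 2 only opens its own shards.
print(shard_split(["s0", "s1", "s2", "s3"], rank=1, world_size=2))
```

Either scheme gives each rank a disjoint set of samples, which matches the reply's point that this is a throughput concern rather than a correctness one.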
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
67118fe to
ddca770
| # TODO: Support exporting quantized VLMs. export_mcore_gpt_to_hf uses modelopt's own per-arch | ||
| # mcore<->HF mappings (mcore_qwen.py, ...), which don't cover Qwen3.5-VL / Gemma3-VL. Rather than | ||
| # authoring new per-model mappings, route the megatron->HF quant export through Megatron-Bridge's | ||
| # AutoBridge.export_hf_weights_quant(quantization_checker, quant_fn, quant_block_size): it reuses |
does MBridge's export work for quantized models? I thought there was some issue there which is why we still use MLM. also this comment is very long, can it be shorter and link to a JIRA ticket as a todo?
Export works for LLM (PP only). Yet to add some fixes for VLM export
| from modelopt.torch.utils.vlm_dataset_utils import get_supported_vlm_datasets | ||
|
|
||
| # Default calibration datasets when --calib_dataset_name is not set | ||
| DEFAULT_TEXT_CALIB_DATASET = "nemotron-post-training-dataset-v2" |
should these constants go inside the MBridge quantization utils or example? instead of in pruning
why does pruning use nemotron data only but quantize.py use CNN + Nemotron?
We mention defaults in both files. We have been using Nemotron-pt-v2 for pruning forever, so we need to test with cnn+ptv2 before changing the defaults.
| parser.add_argument( | ||
| "--calib_dataset_name", | ||
| type=str, | ||
| default="nemotron-post-training-dataset-v2", |
since quantize.py and prune_minitron.py share lots of arguments, could they use a common argparse then add custom args?
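One way to realize this suggestion is argparse's `parents` mechanism, sketched below. The flag --calib_dataset_name and the arguments seq_length and calib_batch_size appear in this PR; the builder function names, the numeric defaults, and `--example_prune_only_flag` are hypothetical placeholders, not the scripts' real interface.

```python
import argparse


def build_common_calib_parser() -> argparse.ArgumentParser:
    """Arguments shared by both scripts; add_help=False so children own -h."""
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--calib_dataset_name", type=str, default=None)
    parser.add_argument("--calib_batch_size", type=int, default=1)  # assumed default
    parser.add_argument("--seq_length", type=int, default=4096)     # assumed default
    return parser


def build_prune_parser() -> argparse.ArgumentParser:
    """Pruning script: inherits shared args, then adds its own."""
    parser = argparse.ArgumentParser(parents=[build_common_calib_parser()])
    parser.add_argument("--example_prune_only_flag", action="store_true")  # hypothetical
    # Each script can still keep its own default calibration dataset.
    parser.set_defaults(calib_dataset_name="nemotron-post-training-dataset-v2")
    return parser


args = build_prune_parser().parse_args(["--seq_length", "2048"])
print(args.calib_dataset_name, args.seq_length)
```

Building a fresh parent parser per call keeps `set_defaults` in one script from leaking into the other, which matters here because the two scripts currently use different default datasets.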
| total = dist.allreduce(_local(unwrapped_model)) # sum across pipeline ranks | ||
| lm = dist.allreduce(_local(language_model)) | ||
| # Under pipeline parallelism a tied embedding is materialized on both the first (word_embeddings) | ||
| # and last (output_layer) stage, so the sum double-counts it; subtract one copy. Only the first |
can this comment be shorter? in general Claude (?) comments are too long
| # and last (output_layer) stage, so the sum double-counts it; subtract one copy. Only the first | ||
| # stage owns ``word_embeddings``, so the allreduce-sum below yields exactly one copy. | ||
| if dist.size() > 1 and getattr(language_model, "share_embeddings_and_output_weights", False): | ||
| emb = dist.allreduce( |
why do you need an all reduce here? Can you just get the first and last PP rank and then do the subtraction?
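The double-counting that the diff comment describes can be checked with plain arithmetic. The sketch below simulates the allreduce with a Python `sum` over made-up per-stage counts; the function name and numbers are illustrative only and no distributed calls are involved.

```python
def count_total_params(per_stage_counts, tied_embedding_params, tied):
    """Sum per-pipeline-stage parameter counts, removing one tied-embedding copy.

    sum() stands in for an allreduce-sum across pipeline ranks. With tied
    embeddings and more than one stage, the same weight is materialized on the
    first stage (word_embeddings) and the last stage (output_layer).
    """
    total = sum(per_stage_counts)
    if tied and len(per_stage_counts) > 1:
        total -= tied_embedding_params  # subtract the duplicated copy
    return total


# Three stages; the first and last each hold a 50-parameter copy of the embedding.
print(count_total_params([150, 100, 150], tied_embedding_params=50, tied=True))
```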
| lt for i, lt in enumerate(text_cfg.layer_types) if i + 1 in kept_layer_nums | ||
| ] | ||
| # Qwen3-VL injects deepstack vision features at specific (0-indexed) LM layers. Remap those | ||
| # indices to the surviving layers so the exported config stays valid (no out-of-range index); |
if this is specific to Qwen3 VL, does it belong here? maybe in a qwen utility file ... also this is verbose
Applicable for depth pruning (layer dropping only) so only for pruning
Qwen3.5 VL has this key but it's empty. It could be used by other models as well.
| seq_length=args.seq_length, | ||
| batch_size=args.calib_batch_size, | ||
| ) | ||
| elif mtq.need_calibration(mtq_config): |
swap the order of this if else, LLMs are used more often and should come first
| export_config = None | ||
| # Sort all parameters for metric-based pruning | ||
| self.hps_to_sort = SUPPORTED_HPARAMS | ||
| # Sort all parameters that may be pruned. Skip the ones explicitly excluded from the |
very verbose comments in this file
| f"Invalid choice {export_config[hp_name]} for {n}! Available choices: {hp.choices}" | ||
| ) | ||
| hp.reset_choices() # Make sure ConcatHparam choices are updated after modify() | ||
| # Validate requested export_config values are achievable for *every* matching hparam, |
this comment is unnecessary, the assertion explains it
| __all__ = ["load_mbridge_model_from_hf", "load_modelopt_megatron_checkpoint"] | ||
|
|
||
|
|
||
| def _patch_qwen35_moe_sequential_expert_mappings() -> None: |
do model specific functions belong in plugins/mbridge.py? There should be a file for model specific functions
Temporary WAR until 26.06.01 patch is released next month. Everything else is generic
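The "idempotent/self-detecting" property claimed for the temporary patch can be shown with a generic sketch. `ExampleBridge`, `_get_mappings`, and the `_is_patched` marker are hypothetical stand-ins; the real Qwen35MoEBridge patch is not reproduced here.

```python
class ExampleBridge:
    """Hypothetical stand-in for a bridge class whose mappings need extending."""

    def _get_mappings(self):
        return ["grouped"]


def patch_example_bridge() -> None:
    """Wrap the class method once; later calls detect the marker and do nothing."""
    original = ExampleBridge._get_mappings
    if getattr(original, "_is_patched", False):
        return  # already patched, so calling again is a no-op

    def patched(self):
        return original(self) + ["sequential"]

    patched._is_patched = True
    ExampleBridge._get_mappings = patched


bridge = ExampleBridge()  # created before the patch is applied
patch_example_bridge()
patch_example_bridge()    # safe to call unconditionally
print(bridge._get_mappings())
```

Because the patch replaces the attribute on the class, instances and modules that imported the class earlier still see the patched method at call time.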
What does this PR do?
Type of change: New feature
Adds vision-language model (VLM) support to the Megatron-Bridge examples for both Minitron pruning (prune_minitron.py) and PTQ (quantize.py). Only the language model is pruned/quantized; the vision tower and vision→language projector are left in full precision, and the full VLM is saved back. hidden_size is skipped for pruning when it is shared with the vision→LM projector.
Supported VLMs (tested e2e): Qwen3.5-VL (dense; hybrid GatedDeltaNet + gated attention) and Gemma3-VL (sliding/full attention).
Calibration (image-text)
Calibration is conditioned on real image-text data so the language model's pruning importance / quantizer statistics see vision-conditioned activations. The modality is inferred from --calib_dataset_name:
- an image-text dataset (such as nemotron_vlm_dataset_v2) drives the full VLM forward;
- a text dataset runs text-only LM calibration.
A shared get_megatron_vlm_calibration_forward_loop (built on megatron_prefill) drives the full VLM forward over image-text pairs from vlm_dataset_utils (scienceqa, nemotron_vlm_dataset_v2, with config-driven subset/shard caps to bound downloads). It shards across data-parallel (DP) ranks like the text loop (#1804); context parallelism (CP) applies to text-only VLM calibration (the shared text loop), not the multimodal forward, because splitting the sequence would misalign the merged vision embeddings.
Results
Validated end-to-end on Cosmos-Reason2-2B (Qwen3-VL). Minitron NAS prunes the language-model tower 1.72B → ~1.59B (vision encoder + projector frozen), top_k=1. Calibration data drives pruning importance; image-text calibration runs the full VLM forward.
[Results table not recovered; its rows included text calibration (nemotron-post-training-dataset-v2) and image-text calibration (nemotron_vlm_dataset_v2).]
* Pruned MMLU on the 10% split (the pruning score function); baseline MMLU is the full set. The VLM-benchmark numbers for the text row were measured with a different text calibration set and are expected to be similar for nemotron-post-training-dataset-v2 (marked ~).
Note
These numbers come from short single runs on small eval splits; read them for high-level trends only, not as exact values.
Takeaways: pruning the LM tower of a VLM works end-to-end. Image-text calibration (this PR's feature) preserves the VLM benchmarks better than text-only — BLINK Rel-Depth ~0.77 vs ~0.69 and RealWorldQA ~0.61 vs ~0.57, both close to the unpruned baseline (0.76 / 0.61) — which is the motivation for calibrating on vision-conditioned activations.
Key changes
- quantize.py: quantizes the root model with non-LM (vision) quantizers disabled, so the ModelOpt state lives on the root (required by the Megatron save) while only the language model is quantized.
- prune_minitron.py: image-text (or text) calibration for VLM pruning importance.
- Shared get_megatron_vlm_calibration_forward_loop (megatron_prefill-based, unwraps tuple outputs, DP-sharded) + vlm_dataset_utils.
Usage
Testing
- test_prune_minitron.py::test_prune_minitron_vlm: Gemma3-VL, image-text (ScienceQA) calibration; full load → prune (depth + ffn) → save → reload.
- test_quantize_export.py::test_quantize_vlm: Qwen3.5-VL, text calibration; quantize LM → save Megatron checkpoint.
- Existing tests (test_prune_minitron, test_quantize_and_export) unchanged and passing.
Not in scope
export.py saves the Megatron checkpoint only for VLMs (tracked by a TODO in export.py). The recommended path is to route the megatron→HF quant export through Megatron-Bridge's AutoBridge.export_hf_weights_quant(quantization_checker, quant_fn, quant_block_size), which reuses the bridge's per-model mcore↔HF mapping, covering Qwen3.5-VL / Gemma3-VL and the vision tower/projector (left full precision) for free, so modelopt supplies only the checker + pack/scale fn + hf_quant_config (KV-cache scales need a separate path). This avoids re-authoring per-model mappings in modelopt (cf. #1482, "Add Qwen3VL MCore Export support from PR 895", and its Qwen3-VL-only mcore_qwen3vl.py).
Note
Qwen3.5-VL MoE is not tested e2e: the Megatron-Bridge weight conversion expects packed (gate_up_proj) experts that transformers' tiny checkpoint doesn't emit. MoE pruning itself is covered by test_mcore_qwen35_gdn_moe_pruning.
Before your PR is "Ready for review"
CONTRIBUTING.md: N/A
Additional Information
Follow-up to the GatedDeltaNet/MLA/latent-MoE pruning PR (#1747). Rebased on main to pick up CP/DP calibration (#1804); the VLM calibration loop now shards across DP ranks the same way. hidden_size pruning for VLMs (requires resizing the vision projector) is left for a future PR.
🤖 Generated with Claude Code
Summary by CodeRabbit
--cp_size flag in quantization examples.