Declarative weight conversion framework#523
Open
jlamypoirier wants to merge 10 commits into
Open
Conversation
Mirrors the post-#508 config-side shape on the weight side. Adds ``_create_weight_converters`` + walker on ``ConfigSectionConverter`` with new primitives (Nested/BlockSequence/Linear/OutputProjection) in ``external.py``. Relocates ``KeyValueWeightConverter``/``TransposeSplitWeightConverter`` (formerly ``MLPLayer2Converter``) so the layers/multimodal converters can import them from the engine instead of llama.py. Migrates llama/mistral/qwen2/mixtral/mtp_llama to the new shape. Tied embeddings move from per-call ``drop_on_export=tied`` plumbing to the walker-central ``OutputProjectionWeightConverter`` marker. Legacy ``get_converters``/``get_parameter_converter``/``get_weight_and_bias_converters`` helpers stay in llama.py as shims for the not-yet-migrated converters (apriel/apriel2/gemma4/multimodal); cleanup commit removes them. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds ``DispatchWeightConverter`` (runtime-type dispatch on a single attribute) and ``TypedDictWeightConverter`` (per-key dispatch on a ``dict[str, Config]`` attribute) to the framework. Apriel's hybrid block-sequence uses ``BlockSequenceWeightConverter``'s ``dispatch_registry``; Apriel2 uses both new primitives — ``DispatchWeightConverter`` for the block mixer + normalization dispatch, ``TypedDictWeightConverter`` for the StochasticMixer sub-mixer fan-out. The Apriel2 Fixed/Pattern decoder section converters now contribute no weights of their own; the block fan-out runs once at the base-model level via ``BlockSequenceWeightConverter``, which already handles both shapes through its ``FixedBlockSequenceConfig`` / ``PatternBlockSequenceConfig`` dispatch. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Last of the model-side migrations: the multimodal Llava (Pixtral vision + Mistral text) and the multimodal Apriel2 (Apriel2 vision + Apriel2 text base) handlers now declare their weight conversions via ``_create_weight_converters``. ``PatchEmbeddingWeightConverter`` is now imported from the engine (relocated earlier); the local copies in ``llava.py`` are removed. Gemma4 keeps its imperative ``get_converters`` and continues to work via the ``ConfigSectionConverter.get_converters`` shim — its helper classes don't inherit ``ConfigSectionConverter`` so they don't get a free declarative migration. Revisit in cleanup or a follow-up. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Removes the post-migration deadweight: * ``ConfigSectionConverter.get_converters`` (section-shape shim) — every consumer now calls ``emit_weight_converters`` directly. * ``get_parameter_converter``, ``get_weight_and_bias_converters``, ``MLPLayer2Converter`` alias in llama.py — no remaining callers. * ``drop_on_export`` parameter plumbing throughout gemma4 — the only legitimate use case (head tied embeddings) is handled by ``OutputProjectionWeightConverter`` at the walker level. Gemma4 gains a local ``_linear_converters`` helper that builds ``.weight`` and (optional) ``.bias`` ``WeightConverter`` instances directly — Gemma4's helper classes don't inherit ``ConfigSectionConverter`` so the ``LinearWeightConverter`` declarative primitive doesn't apply. ``effective_bias`` stays in llama.py as a published helper — still used by Apriel/Apriel2 config-side ``CustomConfigConverter`` export_fns and the matching ``LinearWeightConverter.bias_fn`` lambdas. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Collaborator
Author
CI status [fixed as of #524]CI shows 60 failures, but these are pre-existing and identical to main's. Confirmed by diffing the failed-test names: All failures are Local + cluster results on the PR's HEAD:
|
* Drop ``AprielBlockConverter.get_converters`` — calls a now-nonexistent ``.get_converters`` on the block-converter registry values and is unreachable in practice (dispatch goes through ``BlockSequenceWeightConverter(dispatch_registry=...)``). * Drop the unused ``block_converter_class`` ClassVar from Apriel/Mistral/Qwen2/Mixtral head converters — only MTP-Llama's head reads it (kept on ``LlamaHeadConverter``). * Drop the ``exported_config`` parameter throughout: no surviving ``get_converters`` override reads it, and the ``__init__`` ``_export_config(model.config)`` precompute it powered is gone. Tied-embedding handling lives on ``OutputProjectionWeightConverter``. * Fold ``_FixedBlockFanoutWeightConverter`` into ``BlockSequenceWeightConverter`` via a ``config_attr=""`` sentinel for "section IS the block sequence" — kills the cross-package private import from ``llama.py`` into ``multimodal/apriel2.py``. * ``LinearWeightConverter.bias_fn`` and ``OutputProjectionWeightConverter._emit`` use direct attribute access instead of ``getattr(..., default)`` — misuse now surfaces as ``AttributeError`` rather than silently falling back to ``False``. * Tighten ``BlockSequenceWeightConverter``'s assertion to XOR — passing both ``block_converter_class`` and ``dispatch_registry`` no longer silently ignores the former. * Extract ``_join_prefix(parent, own)`` helper for the empty-handling rule shared across Nested/BlockSequence/Dispatch/TypedDict ``_emit`` methods. * Apriel2 base + multimodal aggregators get a ``block_converter_class`` ClassVar (matches ``LlamaBaseModelConverter``) instead of hardcoding ``Apriel2BlockConverter`` inline. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Items 1-7 from the second /review-coarse pass: 1. Delete unused IgnoreImportWeightConverter / IgnoreExportWeightConverter (last callers were the removed drop_on_export plumbing). 2. Migrate Gemma4 to declarative weight converters — Gemma4Attention/MLP/ MoEMLP/HybridMoEMLP/Block/Decoder now inherit ConfigSectionConverter and define _create_weight_converters. The Gemma4-specific transforms (shared-K/V branching, mlp-type dispatch with divergent HF prefixes, conditional norm_2, two-level hybrid-MoE norm descent) live as small private WeightConverter subclasses next to the existing MoE layer converters. Config side stays imperative under CustomConfigConverter at the aggregator (Gemma4 sliding/full block divergence prevents a uniform per-block declarative shape); each helper carries a blanket IgnoredConfigConverter to silence the static walker. 3. Add optional=True to NestedWeightConverter and fold Apriel2 multimodal's vision_encoder back into _create_weight_converters (skip when None). 4. Fold Llava head into LlavaBaseModelConverter._create_weight_converters (NestedWeightConverter with empty hf_prefix; LlavaHead's leaf names are already absolute). 5. Move block_converter_class ClassVar from LlamaHeadConverter to its sole reader MTPLlamaHeadConverter. 6. Replace BlockSequenceWeightConverter's config_attr="" sentinel with an explicit read_self=True flag (2 callers updated). 7. Delete the four pass-only HeadConverter subclasses (Mistral, Mixtral, Qwen2, Apriel); the head_converter_class ClassVar inherits from LlamaBaseModelConverter, and LlavaHeadConverter rebases on LlamaHeadConverter directly. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Items 1-4 from the third /review-coarse pass:
1. Split ``BlockSequenceWeightConverter`` into three flat primitives — single-class
(``BlockSequenceWeightConverter``), per-position dispatch
(``DispatchBlockSequenceWeightConverter``), and section-IS-the-block-sequence
(``SelfBlockSequenceWeightConverter``). Drops the dual XOR ``Assert.custom``
switchboard. Shared list-materialization extracted to ``_expand_block_sequence``.
2. Generalize two framework primitives so two of Gemma4's four private one-offs
fold in:
* ``DispatchWeightConverter`` gains ``hf_prefix_overrides`` for per-branch HF
paths (Gemma4's block.mlp dispatch where dense lands under ``mlp.<...>`` and
hybrid MoE flat-merges into the block root).
* ``NestedWeightConverter.config_attr`` accepts tuple/dotted paths for chained
``getattr`` (Gemma4's hybrid-MoE inner norms via ``("dense", "pre_norm")``).
Rename ``_join_prefix`` and ``_prepend_prefix`` to drop the underscore — now
public utilities used by Gemma4's remaining two one-offs.
3. Lift the one-line ``cls.emit_weight_converters(config, "", "")`` passthrough
into ``HuggingFaceBaseModelConverter.get_converters`` as a concrete default.
Apriel2 (text), Apriel2 multimodal, and Llava lose their overrides.
Apriel2BaseModelConverter now multi-inherits ``HuggingFaceBaseModelConverter``
so it picks up the default. Llama, Gemma4, MTP-Llama keep their overrides —
they splice ``head_converter_class.get_converters(config)`` separately because
the head needs the full ``GPTBaseModelConfig`` (MTP-Llama reads
``config.decoder.last_block_config`` for per-prediction-head fan-out).
4. ``AprielBlockConverter`` docstring: ``get_converters`` was removed in the
prior cleanup pass; update the docstring to describe the class as a registry
holder consumed by ``ListDispatchConfigConverter`` (config side) and
``DispatchBlockSequenceWeightConverter`` (weight side).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Six Gemma4 section converters owned their config side only nominally — the actual conversion lives on the aggregator's CustomConfigConverter (fast_llm_recurses=True). Each declared an empty-path IgnoredConfigConverter blanket-claim solely to silence the static architecture-coverage walker. Replace the boilerplate with an explicit weight_only ClassVar on ConfigSectionConverter that short-circuits both _create_config_converters (empty default) and check_architecture_coverage.
- Delete dead helpers: get_apriel2_decoder_converter and two unreferenced _get_weight_converters classmethods (the latter also broken — called get_converters with a now-removed two-arg signature). - Inline the now-unconditional ConfigSectionConverter coverage call in HuggingfaceStateDictCheckpointHandler._check_hf_coverage; every concrete base_model_converter_class is one. - Drop the _effective_bias import alias in apriel2 (no name conflict). - Trim docstrings that referenced the previous (removed) implementation: OutputProjectionWeightConverter / LinearWeightConverter / LlavaHeadConverter. - Accept bias_fn=True/False bool literals on LinearWeightConverter; replaces ~10 `lambda c: False` callsites including all `no_bias = lambda c: False` named bindings. - Hoist orphan trailing comments on Apriel2Fixed/PatternDecoderConverter into class docstrings.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Mirrors the post-#508 declarative config-side shape on the weight side. Every section converter now declares its weight mapping via
_create_weight_converters(cached, override-by-key), the walker flattens into the existinglist[WeightConverter]runtime, and tied embeddings move from per-calldrop_on_export=tiedplumbing into the walker-centralOutputProjectionWeightConvertermarker. Item 1 of the post-#508 deferred list.What changed
external.py— new primitivesNestedWeightConverter,BlockSequenceWeightConverter(Fixed/Pattern + optional per-positiondispatch_registry),DispatchWeightConverter(single-attribute type dispatch),TypedDictWeightConverter(per-key dispatch ondict[str, Config]),LinearWeightConverter(bundles.weight/.biaswith per-sectionbias_fn),OutputProjectionWeightConverter(walker-dropped whenroot_config.tied_embedding_weightis set). Generic transformsKeyValueWeightConverter/TransposeSplitWeightConverter(wasMLPLayer2Converter) /PatchEmbeddingWeightConverterrelocate here from llama.py / llava.py.get_parameter_converter,get_weight_and_bias_converters,MLPLayer2Converteralias, and thedrop_on_exportparameter plumbing are removed.effective_biasstays as a published helper for Apriel/Apriel2CustomConfigConverterexport_fns.Test plan
Follow-up
WeightConverter) is worth doing but not on the critical path — deferred.🤖 Generated with Claude Code