Fix: MiniCPM-V 4.6 training hangs on text-only samples with DeepSpeed by randydl · Pull Request #9639 · modelscope/ms-swift · GitHub

Fix: MiniCPM-V 4.6 training hangs on text-only samples with DeepSpeed#9639

Open
randydl wants to merge 1 commit into modelscope:main from randydl:dev



@randydl randydl commented Jun 24, 2026


Fix: MiniCPM-V 4.6 training hangs on text-only samples with DeepSpeed

Problem

When training MiniCPM-V 4.6 with DeepSpeed ZeRO, the training process hangs if the dataset contains pure-text samples (no image/video). This is because:

  • For text-only samples, MiniCPMV4_6Model.forward() skips the vision encoder (vision_tower + merger) entirely since pixel_values and pixel_values_videos are both None.
  • Under DeepSpeed ZeRO, parameters are sharded across GPUs and only gathered on-demand via all-gather when a computation touches them.
  • When one GPU processes a text-only sample (no vision compute) while another processes an image sample (needs vision parameters), the all-gather synchronization deadlocks — one side never triggers the gather that the other side is waiting for.
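The mismatch described above can be reproduced in miniature without any GPUs. The following is a minimal sketch, not DeepSpeed itself: a `threading.Barrier` stands in for the all-gather of sharded vision parameters, each thread plays one rank, and a timeout replaces the indefinite hang so the script terminates. The function name `run_step` and the `"image"`/`"text"` labels are hypothetical.

```python
import threading


def run_step(samples, timeout=0.5):
    """Simulate one training step; samples[i] is "image" or "text" for rank i.

    The barrier only completes when every rank reaches it, which mimics a
    collective all-gather that all ranks must enter.
    """
    barrier = threading.Barrier(len(samples))
    results = [None] * len(samples)

    def rank_fn(rank, kind):
        if kind == "image":
            try:
                # Vision branch: touching sharded params triggers the "all-gather".
                barrier.wait(timeout=timeout)
                results[rank] = "ok"
            except threading.BrokenBarrierError:
                # A real job has no timeout here and would block forever.
                results[rank] = "hang"
        else:
            # Text-only branch: vision encoder skipped, so no gather is issued.
            results[rank] = "ok"

    threads = [
        threading.Thread(target=rank_fn, args=(rank, kind))
        for rank, kind in enumerate(samples)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


if __name__ == "__main__":
    print(run_step(["image", "image"]))  # both ranks enter the gather
    print(run_step(["image", "text"]))   # rank 0 waits for a gather rank 1 never issues
```

In the mixed batch, the image rank is the one that stalls: it enters the collective and waits for a peer that took the text-only path.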

Solution

Add a _post_encode method to MiniCPMV4_6Template that detects text-only samples under DeepSpeed and runs a minimal dummy image through the full vision pipeline (vision_tower → merger). The dummy features are then zeroed out via image_embeds.mean() * 0. and added to the text embeddings, which:

  • Forces DeepSpeed to all-gather all vision model parameters, preventing the deadlock
  • Is mathematically a no-op (adds zero), so it does not affect training results
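The idea behind the fix can be sketched with a toy model, again using a `threading.Barrier` as a stand-in for the parameter all-gather. This is an illustration under stated assumptions, not the code in the PR: `vision_encode`, `run_step_with_dummy`, the sample dictionaries, and the concatenation used to merge features are all hypothetical. Text-only ranks push a dummy input through the same vision path, then add `mean * 0` to their text embeddings.

```python
import threading


def vision_encode(pixels, barrier, timeout):
    """Toy vision encoder; the barrier stands in for the ZeRO all-gather."""
    barrier.wait(timeout=timeout)
    return [p * 0.5 for p in pixels]


def run_step_with_dummy(samples, timeout=0.5):
    """Every rank enters the vision path, using a dummy input when text-only."""
    barrier = threading.Barrier(len(samples))
    outputs = [None] * len(samples)

    def rank_fn(rank, sample):
        text_embeds = list(sample["text_embeds"])
        try:
            if sample.get("pixels") is None:
                # Text-only: run a dummy image so the gather is still issued.
                feats = vision_encode([0.0] * 16, barrier, timeout)
                zero = (sum(feats) / len(feats)) * 0.0
                # Adding zero leaves the text embeddings unchanged.
                outputs[rank] = [e + zero for e in text_embeds]
            else:
                feats = vision_encode(sample["pixels"], barrier, timeout)
                # Toy merge: append vision features to the text embeddings.
                outputs[rank] = text_embeds + feats
        except threading.BrokenBarrierError:
            outputs[rank] = "hang"

    threads = [
        threading.Thread(target=rank_fn, args=(rank, sample))
        for rank, sample in enumerate(samples)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return outputs


if __name__ == "__main__":
    batch = [
        {"text_embeds": [1.0, 2.0], "pixels": [2.0, 4.0]},
        {"text_embeds": [3.0, 4.0], "pixels": None},
    ]
    print(run_step_with_dummy(batch))
```

Because both ranks now reach the barrier, neither stalls, and the text-only rank's output equals its original embeddings.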

The dummy image uses the smallest valid patch grid (target_sizes=[[4, 4]]), which works for both 16x and 4x downsample modes, producing only 1 visual token in 16x mode — negligible compute overhead.
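The size arithmetic can be checked with a small helper. This is a sketch under assumptions: it reads "16x" and "4x" as the factor by which the patch count is divided to obtain visual tokens, which matches the stated 1 token for a 4×4 grid in 16x mode; the 4x result is inferred the same way. The helper name `dummy_image_plan` and the example `patch_size=14` are hypothetical, and the real value comes from the model's vision config.

```python
def dummy_image_plan(patch_size, grid=(4, 4), downsample=16):
    """Return the pixel shape and token count for a dummy image.

    grid is the patch grid (height, width), as in target_sizes=[[4, 4]].
    downsample is the assumed patch-to-token reduction factor (16 or 4).
    """
    height, width = grid
    num_patches = height * width
    if num_patches % downsample != 0:
        raise ValueError("patch grid is not divisible by the downsample factor")
    return {
        # (batch, channels, pixel height, pixel width)
        "pixel_shape": (1, 3, height * patch_size, width * patch_size),
        "num_patches": num_patches,
        "visual_tokens": num_patches // downsample,
    }


if __name__ == "__main__":
    print(dummy_image_plan(patch_size=14, downsample=16))
    print(dummy_image_plan(patch_size=14, downsample=4))
```

A 4×4 grid is the smallest square grid whose 16 patches divide evenly by both factors, which is consistent with the claim that it works in both modes.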

Changes

  • swift/template/templates/minicpm.py: Add _post_encode to MiniCPMV4_6Template, add is_deepspeed_enabled import

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a _post_encode method in the MiniCPM template to handle cases during training where multimodal inputs (images/videos) are absent while DeepSpeed is enabled. It generates dummy vision embeddings to prevent DeepSpeed training issues. The review feedback suggests improving robustness by safely retrieving input_ids using .get() to avoid potential KeyErrors, and safely accessing the dtype of the vision tower to prevent AttributeErrors.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +678 to +684


medium

To improve robustness and prevent potential runtime errors:

  1. Use inputs.get('input_ids') instead of direct key access to avoid a KeyError if input_ids is missing, and return early when the result is None.
  2. Standard PyTorch nn.Module objects do not have a dtype attribute. To avoid an AttributeError if vision_tower is wrapped or does not inherit from PreTrainedModel, safely retrieve the dtype using getattr(base_model.vision_tower, 'dtype', inputs_embeds.dtype).
            input_ids = inputs.get('input_ids')
            if input_ids is None:
                return inputs
            base_model = self.get_base_model(model)
            inputs_embeds = base_model.get_input_embeddings()(input_ids)
            patch_size = base_model.config.vision_config.patch_size
            vision_dtype = getattr(base_model.vision_tower, 'dtype', inputs_embeds.dtype)
            dummy_pv = torch.zeros(
                1, 3, 4 * patch_size, 4 * patch_size,
                device=inputs_embeds.device, dtype=vision_dtype)
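
The two defensive patterns in this suggestion can be exercised in isolation. The sketch below uses plain stand-in classes rather than real PyTorch modules, so the class names, the helper names, and the string dtypes are all hypothetical; only the `dict.get` and `getattr(..., default)` behavior is what is being shown.

```python
class PlainTower:
    """Stand-in for a wrapped module that exposes no dtype attribute."""


class TypedTower:
    """Stand-in for a module that does expose a dtype attribute."""

    dtype = "bfloat16"


def resolve_vision_dtype(vision_tower, fallback_dtype):
    # Use the tower's own dtype when present, else the embeddings' dtype.
    return getattr(vision_tower, "dtype", fallback_dtype)


def read_input_ids(inputs):
    # Returns None instead of raising KeyError when the key is absent.
    return inputs.get("input_ids")


if __name__ == "__main__":
    print(resolve_vision_dtype(TypedTower(), "float32"))
    print(resolve_vision_dtype(PlainTower(), "float32"))
    print(read_input_ids({"attention_mask": [1, 1]}))
```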



1 participant