Fix: MiniCPM-V 4.6 training hangs on text-only samples with DeepSpeed #9639
randydl wants to merge 1 commit into
Code Review
This pull request introduces a _post_encode method in the MiniCPM template to handle cases during training where multimodal inputs (images/videos) are absent while DeepSpeed is enabled. It generates dummy vision embeddings to prevent DeepSpeed training issues. The review feedback suggests improving robustness by safely retrieving input_ids using .get() to avoid potential KeyErrors, and safely accessing the dtype of the vision tower to prevent AttributeErrors.
To improve robustness and prevent potential runtime errors:
- Use `inputs.get('input_ids')` instead of direct key access to avoid a `KeyError` if `input_ids` is missing or `None`.
- Standard PyTorch `nn.Module` objects do not have a `dtype` attribute. To avoid an `AttributeError` if `vision_tower` is wrapped or does not inherit from `PreTrainedModel`, safely retrieve the dtype using `getattr(base_model.vision_tower, 'dtype', inputs_embeds.dtype)`.
input_ids = inputs.get('input_ids')
if input_ids is None:
    return inputs
base_model = self.get_base_model(model)
inputs_embeds = base_model.get_input_embeddings()(input_ids)
patch_size = base_model.config.vision_config.patch_size
# Fall back to the embedding dtype if vision_tower exposes no dtype attribute.
vision_dtype = getattr(base_model.vision_tower, 'dtype', inputs_embeds.dtype)
# Dummy image covering a 4 x 4 patch grid.
dummy_pv = torch.zeros(
    1, 3, 4 * patch_size, 4 * patch_size,
    device=inputs_embeds.device, dtype=vision_dtype)
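The two defensive lookups suggested above can be tried in isolation. The following is a minimal sketch in plain Python, not the PR's code: `PlainModule`, `TypedModule`, `resolve_vision_dtype`, and `get_input_ids` are hypothetical stand-ins, and dtypes are represented as strings so the example runs without PyTorch.

```python
class PlainModule:
    """Hypothetical stand-in for a module that exposes no `dtype` attribute."""


class TypedModule:
    """Hypothetical stand-in for a module that does expose `dtype`."""
    dtype = "bfloat16"


def resolve_vision_dtype(vision_tower, fallback_dtype):
    # getattr with a default never raises AttributeError;
    # it returns the fallback when the attribute is absent.
    return getattr(vision_tower, "dtype", fallback_dtype)


def get_input_ids(inputs):
    # dict.get returns None for a missing key instead of raising KeyError,
    # so the caller can bail out early with a single `is None` check.
    return inputs.get("input_ids")


print(resolve_vision_dtype(PlainModule(), "float32"))
print(resolve_vision_dtype(TypedModule(), "float32"))
print(get_input_ids({"labels": [1, 2, 3]}))
```

The fallback only applies when the attribute is missing. If `dtype` exists but holds an unexpected value, that value is returned unchanged.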
Fix: MiniCPM-V 4.6 training hangs on text-only samples with DeepSpeed
Problem
When training MiniCPM-V 4.6 with DeepSpeed ZeRO, the training process hangs if the dataset contains pure-text samples (no image/video). This is because `MiniCPMV4_6Model.forward()` skips the vision encoder (`vision_tower` + `merger`) entirely since `pixel_values` and `pixel_values_videos` are both `None`.
Solution
Add a `_post_encode` method to `MiniCPMV4_6Template` that detects text-only samples under DeepSpeed and runs a minimal dummy image through the full vision pipeline (`vision_tower` → `merger`). The dummy features are then zeroed out via `image_embeds.mean() * 0.` and added to the text embeddings, which leaves the text embeddings numerically unchanged while still running the vision encoder on every sample.
The dummy image uses the smallest valid patch grid (`target_sizes=[[4, 4]]`), which works for both 16x and 4x downsample modes. It produces only 1 visual token in 16x mode, so the compute overhead is negligible.
Changes
- `swift/template/templates/minicpm.py`: Add `_post_encode` to `MiniCPMV4_6Template` and add the `is_deepspeed_enabled` import.
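As a toy illustration of the Problem and Solution sections, the sketch below models each rank as a thread and the vision encoder as a step that requires every rank to arrive at a shared barrier. This is an assumption about the failure mode, namely that ranks disagree on which collective operations they enter when only some of them run the vision encoder. It is not DeepSpeed code, and `run_step`, `has_image_per_rank`, and `use_dummy_fix` are hypothetical names.

```python
import threading


def run_step(has_image_per_rank, use_dummy_fix, timeout=0.5):
    """Simulate one training step; return a status string per rank.

    'ok'      : the rank completed the simulated collective
    'hang'    : the rank waited for peers that never arrived (timed out)
    'skipped' : the rank never entered the vision encoder
    """
    world_size = len(has_image_per_rank)
    # The barrier stands in for a collective that all ranks must enter.
    barrier = threading.Barrier(world_size)
    status = [None] * world_size

    def rank_fn(rank):
        # With the fix, text-only ranks also run the vision encoder
        # (on a dummy image), so every rank reaches the barrier.
        runs_vision = has_image_per_rank[rank] or use_dummy_fix
        if not runs_vision:
            status[rank] = "skipped"
            return
        try:
            barrier.wait(timeout=timeout)
            status[rank] = "ok"
        except threading.BrokenBarrierError:
            # A real job would block indefinitely; the timeout makes
            # the hang observable in this toy model.
            status[rank] = "hang"

    threads = [threading.Thread(target=rank_fn, args=(r,))
               for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return status


# Rank 0 has an image, rank 1 has a text-only sample.
print(run_step([True, False], use_dummy_fix=False))
print(run_step([True, False], use_dummy_fix=True))
```

Without the fix, the rank holding an image waits for a peer that never enters the vision encoder. With the fix, both ranks take the same path and the step completes.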