[model] support qwen3_tts by Jintao-Huang · Pull Request #9638 · modelscope/ms-swift · GitHub
Skip to content

[model] support qwen3_tts#9638

Open
Jintao-Huang wants to merge 12 commits into
modelscope:mainfrom
Jintao-Huang:support_qwen3_tts
Open

[model] support qwen3_tts#9638
Jintao-Huang wants to merge 12 commits into
modelscope:mainfrom
Jintao-Huang:support_qwen3_tts

Conversation

@Jintao-Huang

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for Qwen3-TTS SFT training with a dual-channel architecture, introducing a custom loss function, model loader, patched forward pass, and template for data collation. Key feedback includes a critical fix to retrieve the correct hidden states from the last layer for the entire batch instead of indexing the embedding layer of a single batch element. Additionally, recommendations were made to pad reference mel spectrograms of varying durations to prevent collation crashes, slice codec IDs consistently to avoid shape mismatches, optimize audio loading/resampling with librosa, and explicitly cast boolean masks to prevent implicit type promotion issues.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current implementation accesses outputs.hidden_states[0][-1], which retrieves the embedding layer (index 0) and then indexes the last batch element (index [-1]). This is a critical bug because:

  1. It uses the embedding layer instead of the final layer's hidden states for sub-talker prediction.
  2. It discards all batch elements except the last one, which will cause shape mismatch errors during batch training (when batch_size > 1).

It should be changed to outputs.hidden_states[-1] to correctly retrieve the last layer's hidden states for the entire batch.

Suggested change
hidden_states = outputs.hidden_states[0][-1]
hidden_states = outputs.hidden_states[-1]

codec_mask[i, 8 + text_ids_len - 1:8 + text_ids_len - 1 + codec_ids_len] = True
attention_mask[i, :8 + text_ids_len + codec_ids_len] = True

ref_mels = torch.cat([data['tts_ref_mel'] for data in batch], dim=0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Concatenating tts_ref_mel tensors directly using torch.cat will fail with a RuntimeError if the reference audios in the batch have different durations (and thus different mel spectrogram lengths).

To make the data collator robust, pad the mel spectrograms along the time dimension (dimension 1) to the maximum length in the current batch before concatenation.

        ref_mels = [data['tts_ref_mel'] for data in batch]
        max_mel_len = max(mel.shape[1] for mel in ref_mels)
        padded_ref_mels = []
        for mel in ref_mels:
            padding_len = max_mel_len - mel.shape[1]
            if padding_len > 0:
                padded_mel = F.pad(mel, (0, 0, 0, padding_len))
            else:
                padded_mel = mel
            padded_ref_mels.append(padded_mel)
        ref_mels = torch.cat(padded_ref_mels, dim=0)

Comment on lines +1778 to +1779
input_text_embedding = self.talker.model.text_embedding(input_text_ids) * text_embedding_mask
input_codec_embedding = self.talker.model.codec_embedding(input_codec_ids) * codec_embedding_mask

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Multiplying float embedding tensors directly by boolean mask tensors (text_embedding_mask and codec_embedding_mask) relies on implicit type promotion, which can be slow or cause warnings/errors on certain PyTorch versions or hardware accelerators.

It is safer and more robust to explicitly cast the boolean masks to the embedding's data type before multiplication.

Suggested change
input_text_embedding = self.talker.model.text_embedding(input_text_ids) * text_embedding_mask
input_codec_embedding = self.talker.model.codec_embedding(input_codec_ids) * codec_embedding_mask
input_text_embedding = self.talker.model.text_embedding(input_text_ids) * text_embedding_mask.to(input_text_embedding.dtype)
input_codec_embedding = self.talker.model.codec_embedding(input_codec_ids) * codec_embedding_mask.to(input_codec_embedding.dtype)

Comment on lines +1802 to +1803
talker_hidden_states = hidden_states[codec_mask[:, :-1]]
talker_codec_ids = codec_ids[codec_mask]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Currently, talker_hidden_states is indexed using codec_mask[:, :-1], while talker_codec_ids is indexed using codec_mask.

If codec_mask ever contains a True value in its last column (e.g., due to different sequence lengths or template changes), this will result in a shape mismatch between talker_codec_ids and talker_hidden_states, causing the training to crash. To ensure consistency and robustness, slice codec_ids and use codec_mask[:, :-1] for both.

Suggested change
talker_hidden_states = hidden_states[codec_mask[:, :-1]]
talker_codec_ids = codec_ids[codec_mask]
talker_hidden_states = hidden_states[codec_mask[:, :-1]]
talker_codec_ids = codec_ids[:, :-1][codec_mask[:, :-1]]

Comment on lines +1154 to +1158
audio, sr = librosa.load(ref_audio_path, sr=None, mono=True)
if audio.ndim > 1:
audio = np.mean(audio, axis=-1)
if sr != 24000:
audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Loading the audio with sr=None and then resampling it with librosa.resample is inefficient because it loads the entire audio at its native sampling rate into memory first before performing resampling.

You can load and resample the audio to 24000 Hz directly in a single highly optimized step by passing sr=24000 to librosa.load. This also guarantees a 1D mono array, making the subsequent ndim check redundant.

Suggested change
audio, sr = librosa.load(ref_audio_path, sr=None, mono=True)
if audio.ndim > 1:
audio = np.mean(audio, axis=-1)
if sr != 24000:
audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)
audio, sr = librosa.load(ref_audio_path, sr=24000, mono=True)

# Add sub-talker codec embeddings (layers 1-15)
for i in range(1, 16):
codec_i_embedding = self.talker.code_predictor.get_input_embeddings()[i - 1](codec_ids[:, :, i])
codec_i_embedding = codec_i_embedding * codec_mask.unsqueeze(-1)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Multiplying the float tensor codec_i_embedding by the boolean tensor codec_mask relies on implicit type promotion. It is safer and more robust to explicitly cast the boolean mask to the embedding's data type before multiplication.

Suggested change

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant