[model] support qwen3_tts #9638
Code Review
This pull request adds support for Qwen3-TTS SFT training with a dual-channel architecture, introducing a custom loss function, model loader, patched forward pass, and template for data collation. Key feedback includes a critical fix to retrieve the correct hidden states from the last layer for the entire batch instead of indexing the embedding layer of a single batch element. Additionally, recommendations were made to pad reference mel spectrograms of varying durations to prevent collation crashes, slice codec IDs consistently to avoid shape mismatches, optimize audio loading/resampling with librosa, and explicitly cast boolean masks to prevent implicit type promotion issues.
The current implementation accesses outputs.hidden_states[0][-1], which retrieves the embedding layer (index 0) and then indexes the last batch element (index [-1]). This is a critical bug because:
- It uses the embedding layer instead of the final layer's hidden states for sub-talker prediction.
- It discards all batch elements except the last one, which will cause shape mismatch errors during batch training (when batch_size > 1).
It should be changed to outputs.hidden_states[-1] to correctly retrieve the last layer's hidden states for the entire batch.
Suggested change:
    - hidden_states = outputs.hidden_states[0][-1]
    + hidden_states = outputs.hidden_states[-1]
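To illustrate the indexing argument, here is a minimal sketch. It assumes the talker returns `hidden_states` in the usual Hugging Face layout: a tuple with one `(batch, seq_len, hidden)` tensor per layer, embedding output first. The tuple below is a stand-in built from random tensors, not the actual Qwen3-TTS output, so the real structure should be checked before relying on this.

```python
import torch

batch_size, seq_len, hidden_size, num_layers = 4, 10, 16, 3

# Stand-in for outputs.hidden_states (assumed layout): the embedding output
# followed by one tensor per transformer layer.
hidden_states = tuple(
    torch.randn(batch_size, seq_len, hidden_size) for _ in range(num_layers + 1)
)

buggy = hidden_states[0][-1]  # embedding output, last batch element only
fixed = hidden_states[-1]     # final layer, whole batch

print(buggy.shape)  # torch.Size([10, 16])
print(fixed.shape)  # torch.Size([4, 10, 16])
```

Under that assumption, the first expression silently drops the batch dimension, which is why it only appears to work when batch_size is 1.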
    codec_mask[i, 8 + text_ids_len - 1:8 + text_ids_len - 1 + codec_ids_len] = True
    attention_mask[i, :8 + text_ids_len + codec_ids_len] = True

    ref_mels = torch.cat([data['tts_ref_mel'] for data in batch], dim=0)
Concatenating tts_ref_mel tensors directly using torch.cat will fail with a RuntimeError if the reference audios in the batch have different durations (and thus different mel spectrogram lengths).
To make the data collator robust, pad the mel spectrograms along the time dimension (dimension 1) to the maximum length in the current batch before concatenation.
    import torch.nn.functional as F

    ref_mels = [data['tts_ref_mel'] for data in batch]
    max_mel_len = max(mel.shape[1] for mel in ref_mels)
    padded_ref_mels = []
    for mel in ref_mels:
        padding_len = max_mel_len - mel.shape[1]
        if padding_len > 0:
            # Pad the end of the second-to-last dimension (time, dim 1 for a 3D mel).
            padded_mel = F.pad(mel, (0, 0, 0, padding_len))
        else:
            padded_mel = mel
        padded_ref_mels.append(padded_mel)
    ref_mels = torch.cat(padded_ref_mels, dim=0)

    input_text_embedding = self.talker.model.text_embedding(input_text_ids) * text_embedding_mask
    input_codec_embedding = self.talker.model.codec_embedding(input_codec_ids) * codec_embedding_mask
Multiplying float embedding tensors directly by boolean mask tensors (text_embedding_mask and codec_embedding_mask) relies on implicit type promotion, which can be slow or cause warnings/errors on certain PyTorch versions or hardware accelerators.
It is safer and more robust to explicitly cast the boolean masks to the embedding's data type before multiplication.
Suggested change:
    - input_text_embedding = self.talker.model.text_embedding(input_text_ids) * text_embedding_mask
    - input_codec_embedding = self.talker.model.codec_embedding(input_codec_ids) * codec_embedding_mask
    + input_text_embedding = self.talker.model.text_embedding(input_text_ids)
    + input_text_embedding = input_text_embedding * text_embedding_mask.to(input_text_embedding.dtype)
    + input_codec_embedding = self.talker.model.codec_embedding(input_codec_ids)
    + input_codec_embedding = input_codec_embedding * codec_embedding_mask.to(input_codec_embedding.dtype)
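The explicit cast can be sketched in isolation as follows. The embedding table, token IDs, and mask are made-up stand-ins, not the Qwen3-TTS modules. The embedding is computed first so that its dtype is available when the mask is cast.

```python
import torch

# Hypothetical stand-ins for the text embedding and its boolean mask.
embedding = torch.nn.Embedding(100, 8).to(torch.bfloat16)
input_ids = torch.tensor([[5, 7, 0, 0]])
# True for real tokens, False for padding; the trailing dim broadcasts
# over the embedding dimension.
mask = torch.tensor([[True, True, False, False]]).unsqueeze(-1)

embedded = embedding(input_ids)               # compute the embedding once
masked = embedded * mask.to(embedded.dtype)   # cast the mask explicitly

print(masked.dtype)
print(bool((masked[0, 2:] == 0).all()))
```

Casting to `embedded.dtype` keeps the result in the embedding's own precision, so the same code should work whether the model runs in float32 or bfloat16.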
    talker_hidden_states = hidden_states[codec_mask[:, :-1]]
    talker_codec_ids = codec_ids[codec_mask]
Currently, talker_hidden_states is indexed using codec_mask[:, :-1], while talker_codec_ids is indexed using codec_mask.
If codec_mask ever contains a True value in its last column (e.g., due to different sequence lengths or template changes), this will result in a shape mismatch between talker_codec_ids and talker_hidden_states, causing the training to crash. To ensure consistency and robustness, slice codec_ids and use codec_mask[:, :-1] for both.
Suggested change:
    - talker_hidden_states = hidden_states[codec_mask[:, :-1]]
    - talker_codec_ids = codec_ids[codec_mask]
    + talker_hidden_states = hidden_states[codec_mask[:, :-1]]
    + talker_codec_ids = codec_ids[:, :-1][codec_mask[:, :-1]]
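A small sketch of the mismatch, using made-up shapes. It assumes `hidden_states` covers one position fewer than `codec_ids`, as the `[:, :-1]` slice implies, and the mask values are chosen so that one row reaches the last column.

```python
import torch

batch_size, seq_len, hidden_size, num_codebooks = 2, 6, 4, 16

codec_ids = torch.randint(0, 100, (batch_size, seq_len, num_codebooks))
# Assumed: hidden states exist for positions 0..seq_len-2 only.
hidden_states = torch.randn(batch_size, seq_len - 1, hidden_size)

codec_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
codec_mask[0, 2:5] = True
codec_mask[1, 3:6] = True  # this row has True in the last column

talker_hidden_states = hidden_states[codec_mask[:, :-1]]
mismatched_ids = codec_ids[codec_mask]                # full mask
aligned_ids = codec_ids[:, :-1][codec_mask[:, :-1]]   # sliced consistently

print(talker_hidden_states.shape[0], mismatched_ids.shape[0], aligned_ids.shape[0])
# prints: 5 6 5
```

With the full mask, the IDs contain one more row than the hidden states. Slicing both the same way keeps them aligned regardless of what the last column holds.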
    audio, sr = librosa.load(ref_audio_path, sr=None, mono=True)
    if audio.ndim > 1:
        audio = np.mean(audio, axis=-1)
    if sr != 24000:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)
Loading the audio with sr=None and then calling librosa.resample is more verbose than necessary.
Passing sr=24000 to librosa.load resamples to 24000 Hz as part of loading (librosa.load performs the same resampling internally when the file's native rate differs), so the separate sr check and the explicit resample call can be dropped. Because mono=True already returns a 1D array, the subsequent ndim check is redundant as well.
Suggested change:
    - audio, sr = librosa.load(ref_audio_path, sr=None, mono=True)
    - if audio.ndim > 1:
    -     audio = np.mean(audio, axis=-1)
    - if sr != 24000:
    -     audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)
    + audio, sr = librosa.load(ref_audio_path, sr=24000, mono=True)
    # Add sub-talker codec embeddings (layers 1-15)
    for i in range(1, 16):
        codec_i_embedding = self.talker.code_predictor.get_input_embeddings()[i - 1](codec_ids[:, :, i])
        codec_i_embedding = codec_i_embedding * codec_mask.unsqueeze(-1)
