
modelscope/mcore-bridge


MCore-Bridge: Making Megatron training as simple as Transformers

Providing Megatron-Core model definitions for state-of-the-art large models


☎ Groups

To contact us and join the discussion, add our WeChat group:

WeChat Group

🎉 News

  • 🎉 2026.03.30: MCore-Bridge is released! Providing Megatron-Core model definitions for state-of-the-art large models and making Megatron training as simple as Transformers.

🛠️ Installation

To install using pip:

pip install mcore-bridge -U

# Using uv
pip install uv
uv pip install mcore-bridge -U --torch-backend=auto

To install from source:

# pip install git+https://github.com/modelscope/mcore-bridge.git

git clone https://github.com/modelscope/mcore-bridge.git
cd mcore-bridge
pip install -e .

# Using uv
uv pip install -e . --torch-backend=auto
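The Quick Start script further down was tested with transformers==5.2.0 and megatron-core==0.16.1 (see its "test env" comment). As a small convenience that is not part of MCore-Bridge itself, the sketch below uses only the standard library's importlib.metadata to print the installed versions next to those tested ones. The helper name installed_versions is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions the Quick Start script was tested with (from its "test env" comment).
TESTED_VERSIONS = {"transformers": "5.2.0", "megatron-core": "0.16.1"}


def installed_versions(dist_names):
    """Map each distribution name to its installed version, or None if it is absent."""
    found = {}
    for name in dist_names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(["mcore-bridge", *TESTED_VERSIONS]).items():
        tested = TESTED_VERSIONS.get(name)
        suffix = f" (tested with {tested})" if tested else ""
        print(f"{name}: {ver or 'not installed'}{suffix}")
```

A version mismatch is not necessarily an error; these are simply the versions the example was checked against.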

✨ Model List

The following is the list of models supported by MCore-Bridge:

Series    model_type
Qwen      qwen2, qwen2_moe, qwen2_vl, qwen2_5_vl, qwen2_5_omni,
          qwen3, qwen3_moe, qwen3_vl, qwen3_vl_moe, qwen3_omni_moe,
          qwen3_next, qwen3_5, qwen3_5_moe
DeepSeek  deepseek_v3, deepseek_v32
GLM       glm4, glm4_moe, glm4_moe_lite, glm4v, glm4v_moe, glm_moe_dsa
MiniMax   minimax_m2
Kimi      kimi_k2, kimi_vl
InternLM  internlm3, internvl_chat, internvl
Ovis      ovis2_5
Llama     llama, llama4
GPT-OSS   gpt_oss
ERNIE     ernie4_5, ernie4_5_moe
MiMo      mimo
Dots      dots1
OLMoE     olmoe
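The table keys supported models by the Hugging Face model_type field. A quick way to check whether a local checkpoint is listed is to read model_type from its config.json and compare it against the values above. The sketch below is a hypothetical helper, not an MCore-Bridge API, and its set is a hand copy of the table, so it can lag behind the repository.

```python
import json
from pathlib import Path

# model_type values copied by hand from the model list above; may lag behind the repo.
LISTED_MODEL_TYPES = frozenset({
    "qwen2", "qwen2_moe", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni",
    "qwen3", "qwen3_moe", "qwen3_vl", "qwen3_vl_moe", "qwen3_omni_moe",
    "qwen3_next", "qwen3_5", "qwen3_5_moe",
    "deepseek_v3", "deepseek_v32",
    "glm4", "glm4_moe", "glm4_moe_lite", "glm4v", "glm4v_moe", "glm_moe_dsa",
    "minimax_m2", "kimi_k2", "kimi_vl",
    "internlm3", "internvl_chat", "internvl", "ovis2_5",
    "llama", "llama4", "gpt_oss", "ernie4_5", "ernie4_5_moe",
    "mimo", "dots1", "olmoe",
})


def read_model_type(model_dir):
    """Return the model_type field of <model_dir>/config.json (None if the key is missing)."""
    with open(Path(model_dir) / "config.json", encoding="utf-8") as f:
        return json.load(f).get("model_type")


def is_listed(model_type):
    """True if model_type appears in the model list above."""
    return model_type in LISTED_MODEL_TYPES
```

For example, call is_listed(read_model_type(model_dir)) on a directory returned by snapshot_download. A listed model_type says nothing about which checkpoint sizes have been tested.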

🚀 Quick Start

For training with MCore-Bridge, refer to the ms-swift project. This section shows how to use MCore-Bridge programmatically.

Create the following file (test.py), then run CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 test.py. The sample code below shows how to use MCore-Bridge for model creation, weight loading, export, and saving.

The saved model can be used for inference by referring to the example code in the model card.

# test env: transformers==5.2.0 megatron-core==0.16.1
import os
import torch
import torch.distributed as dist
from megatron.core import mpu
from modelscope import snapshot_download
from transformers import AutoConfig, AutoProcessor
from mcore_bridge import ModelConfig, get_mcore_model, hf_to_mcore_config

is_rank0 = int(os.getenv('RANK')) == 0
torch.cuda.set_device(f"cuda:{os.getenv('LOCAL_RANK')}")
dist.init_process_group(backend='nccl')
TP, PP, EP, ETP = 2, 2, 2, 1
mpu.initialize_model_parallel(
    tensor_model_parallel_size=TP,
    pipeline_model_parallel_size=PP,
    expert_model_parallel_size=EP,
    expert_tensor_parallel_size=ETP,
)

model_dir = snapshot_download('Qwen/Qwen3.5-35B-A3B')
hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
config_kwargs = hf_to_mcore_config(hf_config)
config = ModelConfig(
    params_dtype=torch.bfloat16,
    tensor_model_parallel_size=TP,
    pipeline_model_parallel_size=PP,
    expert_model_parallel_size=EP,
    expert_tensor_parallel_size=ETP,
    sequence_parallel=True,
    mtp_num_layers=1,
    **config_kwargs)

# Create model
mg_models = get_mcore_model(config)

# Load weights
bridge = config.bridge
bridge.load_weights(mg_models, model_dir)

# Export weights: iterate over (name, parameter) pairs without writing to disk
for name, parameter in bridge.export_weights(mg_models):
    pass  # consume each exported weight here

# Save weights
output_dir = 'Qwen3.5-35B-A3B-HF'
bridge.save_weights(mg_models, output_dir)
if is_rank0:
    processor.save_pretrained(output_dir)
    hf_config.save_pretrained(output_dir)
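The script launches 4 processes, and the parallel sizes must divide that world size. In Megatron-Core's initialize_model_parallel, dense layers satisfy world_size = TP * PP * CP * DP and expert layers satisfy world_size = ETP * EP * PP * EDP, with ETP defaulting to TP when unset. The function below is a simplified sketch of that arithmetic (Megatron-Core performs additional checks) that can help when picking TP/PP/EP/ETP before launching; parallel_layout is a hypothetical name.

```python
def parallel_layout(world_size, tp, pp, ep=1, etp=None, cp=1):
    """Return data-parallel sizes implied by a Megatron-Core parallel layout.

    Simplified sketch of the arithmetic in initialize_model_parallel:
      dense layers:  world_size = TP * PP * CP * DP
      expert layers: world_size = ETP * EP * PP * EDP  (ETP defaults to TP)
    """
    etp = tp if etp is None else etp
    dense_group = tp * pp * cp
    expert_group = etp * ep * pp
    if world_size % dense_group or world_size % expert_group:
        raise ValueError(
            f"world_size={world_size} must be divisible by TP*PP*CP={dense_group} "
            f"and by ETP*EP*PP={expert_group}")
    return {"dp": world_size // dense_group, "expert_dp": world_size // expert_group}


# Quick Start settings: 4 GPUs, TP=2, PP=2, EP=2, ETP=1
print(parallel_layout(4, tp=2, pp=2, ep=2, etp=1))  # {'dp': 1, 'expert_dp': 1}
```

With the Quick Start settings both DP and expert DP come out to 1, so no rank holds a duplicate replica; on 8 GPUs the same settings would give DP=2 and expert DP=2.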

Using PEFT

MCore-Bridge is fully compatible with PEFT for LoRA training. The following shows how to use PEFT to prepare a PeftModel and save the incremental (LoRA) weights.
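The LoraConfig in the script below sets r=8 and lora_alpha=32. For a linear layer with in_features inputs and out_features outputs, LoRA adds two low-rank matrices A (r × in_features) and B (out_features × r), and with PEFT's default settings (use_rslora=False) the update B @ A is scaled by lora_alpha / r. The arithmetic sketch below makes those numbers concrete; the 4096 × 4096 layer is illustrative, not a Qwen3.5-4B dimension, and lora_overhead is a hypothetical helper.

```python
def lora_overhead(in_features, out_features, r=8, lora_alpha=32):
    """Trainable parameters and scaling added by LoRA to one linear layer.

    A has shape (r, in_features) and B has shape (out_features, r); PEFT's
    default (use_rslora=False) scales the low-rank update B @ A by lora_alpha / r.
    """
    lora_params = r * in_features + out_features * r
    frozen_params = in_features * out_features  # base weight only, bias ignored
    scaling = lora_alpha / r
    return lora_params, frozen_params, scaling


# Illustrative 4096 x 4096 projection with the script's r=8, lora_alpha=32
print(lora_overhead(4096, 4096))  # (65536, 16777216, 4.0)
```

For that layer the adapter trains 65,536 parameters next to 16,777,216 frozen ones (about 0.4%), which is why only the incremental weights need to be saved.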

You need to create the following file (test.py), then run CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 test.py.

import copy
import os
import torch
import torch.distributed as dist
from megatron.core import mpu
from modelscope import snapshot_download
from peft import LoraConfig, get_peft_model
from transformers import AutoConfig, AutoProcessor

from mcore_bridge import ModelConfig, get_mcore_model, hf_to_mcore_config, set_random_seed

is_rank0 = int(os.getenv('RANK')) == 0
torch.cuda.set_device(f"cuda:{os.getenv('LOCAL_RANK')}")
dist.init_process_group(backend='nccl')
TP, PP = 2, 2
mpu.initialize_model_parallel(
    tensor_model_parallel_size=TP,
    pipeline_model_parallel_size=PP,
)
# To correctly initialize the model randomly (full parameters/LoRA)
# you need to set the random seed.
set_random_seed(42)

model_dir = snapshot_download('Qwen/Qwen3.5-4B')
hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
config_kwargs = hf_to_mcore_config(hf_config)
config = ModelConfig(
    params_dtype=torch.bfloat16,
    tensor_model_parallel_size=TP,
    pipeline_model_parallel_size=PP,
    sequence_parallel=True,
    **config_kwargs)

# Create model and load weights
mg_models = get_mcore_model(config)
bridge = config.bridge
bridge.load_weights(mg_models, model_dir)

# Prepare PeftModel and load LoRA weights
# For multimodal models, it is recommended to use regex to specify target_modules
target_modules = r'^language_model.*\.(in_proj|out_proj|linear_fc1|linear_fc2|linear_qkv|linear_proj)$'
# When saving as safetensors, you need to store the corresponding HF target_modules
hf_target_modules = r'^model.language_model.*\.(in_proj_qkv|in_proj_z|in_proj_b|in_proj_a|out_proj|gate_proj|up_proj|down_proj|q_proj|k_proj|v_proj|o_proj)$'
lora_config = LoraConfig(task_type='CAUSAL_LM', r=8, lora_alpha=32, lora_dropout=0.05, target_modules=target_modules)
peft_models = [get_peft_model(model, lora_config) for model in mg_models]
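# Optional: load existing LoRA weights (peft_format=True);
# adjust the path to where those LoRA weights are stored
# bridge.load_weights(peft_models, model_dir, peft_format=True)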

# Export LoRA weights
for name, parameter in bridge.export_weights(mg_models, peft_format=True):
    pass

# Save LoRA weights
output_dir = 'Qwen3.5-4B-LoRA'
bridge.save_weights(mg_models, output_dir, peft_format=True)
if is_rank0:
    hf_lora_config = copy.copy(lora_config)
    hf_lora_config.target_modules = hf_target_modules
    hf_lora_config.save_pretrained(output_dir)
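When target_modules is a string, PEFT treats it as a regular expression and matches it with re.fullmatch against each module's full dotted name, which is why the script keeps separate Megatron-side and HF-side patterns. The snippet below applies the two patterns from the script to a few module names. Those names are illustrative stand-ins, not read from an actual model, so print the real names with named_modules() before relying on a pattern; matched is a hypothetical helper.

```python
import re

# Patterns copied verbatim from the script above.
target_modules = r'^language_model.*\.(in_proj|out_proj|linear_fc1|linear_fc2|linear_qkv|linear_proj)$'
hf_target_modules = r'^model.language_model.*\.(in_proj_qkv|in_proj_z|in_proj_b|in_proj_a|out_proj|gate_proj|up_proj|down_proj|q_proj|k_proj|v_proj|o_proj)$'

# Illustrative module names only; real names depend on the model definition.
mcore_names = [
    "language_model.decoder.layers.0.self_attention.linear_qkv",
    "language_model.decoder.layers.0.mlp.linear_fc1",
    "vision_model.blocks.0.attn.qkv",
]
hf_names = [
    "model.language_model.layers.0.linear_attn.in_proj_qkv",
    "model.language_model.layers.3.self_attn.q_proj",
    "model.visual.blocks.0.attn.qkv",
]


def matched(pattern, names):
    """Mimic PEFT's rule for a string target_modules: re.fullmatch on the full name."""
    return [name for name in names if re.fullmatch(pattern, name)]


print(matched(target_modules, mcore_names))
print(matched(hf_target_modules, hf_names))
```

In both cases the vision-tower name is skipped, which is the point of anchoring the pattern at language_model for multimodal models.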

Using the saved LoRA weights:

from transformers import Qwen3_5ForConditionalGeneration
from modelscope import snapshot_download
from peft import PeftModel

model_dir = snapshot_download('Qwen/Qwen3.5-4B')
model = Qwen3_5ForConditionalGeneration.from_pretrained(model_dir)
peft_model = PeftModel.from_pretrained(model, 'Qwen3.5-4B-LoRA')
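The adapter directory needs the HF-side target_modules: the script copies lora_config, swaps in hf_target_modules, and calls save_pretrained, which writes adapter_config.json. If the Megatron-side pattern ends up there instead, it will not match Transformers module names. The hypothetical helper below reads adapter_config.json and flags Megatron-style layer names such as linear_qkv or linear_fc1; treat it as a heuristic sketch, not an exhaustive check.

```python
import json
from pathlib import Path

# Megatron-Core layer names taken from the script's Megatron-side pattern;
# heuristic: these names do not appear in the HF-side pattern.
MCORE_ONLY_NAMES = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")


def check_adapter_config(adapter_dir):
    """Load adapter_config.json and reject Megatron-style target_modules."""
    path = Path(adapter_dir) / "adapter_config.json"
    cfg = json.loads(path.read_text(encoding="utf-8"))
    target = cfg.get("target_modules") or ""
    # target_modules may be saved as a regex string or as a list of names.
    text = target if isinstance(target, str) else " ".join(target)
    bad = [name for name in MCORE_ONLY_NAMES if name in text]
    if bad:
        raise ValueError(f"{path} targets Megatron-style modules: {bad}")
    return cfg
```

For the example above, check_adapter_config('Qwen3.5-4B-LoRA') should return the parsed config without raising before the directory is passed to PeftModel.from_pretrained.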

🏛 License

This framework is licensed under the Apache License (Version 2.0). For models and datasets, please refer to the original resource page and follow the corresponding License.
