feat(esrgan): add Real-ESRGAN model support for export, config, and perf by vortex-captain · Pull Request #480 · microsoft/winml-cli · GitHub
Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion src/winml/modelkit/commands/export.py
9 changes: 7 additions & 2 deletions src/winml/modelkit/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ def _load_model(config: WinMLEvaluationConfig) -> WinMLPreTrainedModel:
# ignored here (mirrors winml perf's ONNX path).
from transformers import AutoConfig

from ..export.io import ensure_hf_models_registered

ensure_hf_models_registered()
hf_config = AutoConfig.from_pretrained(config.model_id)
model = WinMLAutoModel.from_onnx(
onnx_path=config.model_path,
Expand Down Expand Up @@ -214,10 +217,12 @@ def _resolve_task(config: WinMLEvaluationConfig) -> str:

from transformers import AutoConfig

from ..loader.task import _detect_task_from_config
from ..export.io import ensure_hf_models_registered
from ..loader.task import _detect_task_and_class_from_config

ensure_hf_models_registered()
hf_config = AutoConfig.from_pretrained(config.model_id)
return _detect_task_from_config(hf_config)
return _detect_task_and_class_from_config(hf_config)[0]


def evaluate(config: WinMLEvaluationConfig) -> EvalResult:
Expand Down
91 changes: 90 additions & 1 deletion src/winml/modelkit/loader/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class WinMLLoaderConfig:
Requires trust_remote_code=True for security.
trust_remote_code: Whether to trust remote/custom code.
Required when using user_script.
loader_config_overrides: Optional patch applied recursively to the HF
config object after ``AutoConfig.from_pretrained``. Keys are
attribute names; nested dicts are merged into sub-configs (e.g.
``{"vision_config": {"image_size": 320}}``). Use cases include
selecting a non-default hyperparameter (``{"scale": 2}`` for
Real-ESRGAN) without committing a separate ``config.json``.

Example:
# Standard usage with auto-detection
Expand All @@ -71,6 +77,7 @@ class WinMLLoaderConfig:
module_path: str | None = None
user_script: str | None = None
trust_remote_code: bool = False
loader_config_overrides: dict[str, Any] | None = None

def to_dict(self) -> dict[str, Any]:
"""Serialize to dictionary.
Expand All @@ -91,6 +98,8 @@ def to_dict(self) -> dict[str, Any]:
result["user_script"] = self.user_script
if self.trust_remote_code:
result["trust_remote_code"] = self.trust_remote_code
if self.loader_config_overrides:
result["loader_config_overrides"] = self.loader_config_overrides
return result

@classmethod
Expand All @@ -110,9 +119,69 @@ def from_dict(cls, data: dict[str, Any]) -> WinMLLoaderConfig:
module_path=data.get("module_path"),
user_script=data.get("user_script"),
trust_remote_code=data.get("trust_remote_code", False),
loader_config_overrides=data.get("loader_config_overrides"),
)


def _deep_merge_dicts(base: dict, top: dict) -> dict:
"""Return a new dict deep-merging ``top`` on top of ``base``.

Nested dicts on both sides are merged recursively; otherwise ``top``'s
value wins. ``base`` is not mutated.
"""
out = dict(base)
for key, value in top.items():
if isinstance(value, dict) and isinstance(out.get(key), dict):
out[key] = _deep_merge_dicts(out[key], value)
else:
out[key] = value
return out


def apply_loader_config_overrides(
hf_config: PretrainedConfig,
overrides: dict[str, Any] | None,
) -> PretrainedConfig:
"""Return a new HF config with ``overrides`` deep-merged onto ``hf_config``.

Serializes the original config via :meth:`PretrainedConfig.to_dict`,
recursively deep-merges ``overrides`` into the resulting plain dict
(``overrides`` keys win on conflict, nested dicts merge into nested
dicts), then reconstructs the config via
``type(hf_config).from_dict(merged)``.

Going through ``to_dict`` / ``from_dict`` lets the config class's own
constructor handle validation, defaulting, and sub-config reconstruction
— including nested :class:`PretrainedConfig` fields like
``CLIPConfig.vision_config`` (HF's ``from_dict`` rebuilds them from
nested dicts automatically). A raw ``setattr`` loop, by contrast, can
silently create attributes the model class never reads when the
override key is missing on the original config.

Empty / ``None`` overrides return the original config unchanged.

Args:
hf_config: The HF :class:`PretrainedConfig` to patch.
overrides: Nested dict of overrides, or ``None``.

Returns:
A :class:`PretrainedConfig` of the same concrete type as
``hf_config`` with the overrides applied. May be the original
instance (when overrides are empty) or a freshly constructed one.
"""
if not overrides:
return hf_config

merged = _deep_merge_dicts(hf_config.to_dict(), overrides)
new_config = type(hf_config).from_dict(merged)
logger.debug(
"Applied loader_config_overrides to %s: %s",
type(hf_config).__name__,
overrides,
)
return new_config


def resolve_loader_config(
model_id: str | None = None,
*,
Expand All @@ -121,6 +190,7 @@ def resolve_loader_config(
model_type: str | None = None,
trust_remote_code: bool = False,
library_name: str = "transformers",
loader_config_overrides: dict[str, Any] | None = None,
) -> tuple[WinMLLoaderConfig, PretrainedConfig, type]:
"""Resolve all loader concerns from raw user inputs.

Expand Down Expand Up @@ -154,6 +224,11 @@ def resolve_loader_config(
When provided without task, the first supported task is used.
trust_remote_code: Whether to trust remote/custom code.
library_name: Source library for TasksManager lookup.
loader_config_overrides: Optional nested dict patched recursively onto
the HF config after it is loaded — see
:func:`apply_loader_config_overrides`. Stored on the returned
:class:`WinMLLoaderConfig` so downstream
``resolved_class.from_pretrained`` consumers can re-apply it.

Returns:
Tuple of:
Expand All @@ -169,8 +244,13 @@ def resolve_loader_config(
"""
from transformers import AutoConfig

from ..export.io import ensure_hf_models_registered
from .task import get_supported_tasks, resolve_task_and_model_class

# Ensure HF model registrations (AutoConfig.register, OnnxConfig overwrites,
# task-mapping fallbacks) have run before any AutoConfig / TasksManager calls.
ensure_hf_models_registered()

# 1. Load hf_config (depends on: model_id, model_type, or model_class)
if model_id is not None:
hf_config = AutoConfig.from_pretrained(
Expand Down Expand Up @@ -206,6 +286,10 @@ def resolve_loader_config(
f"attribute. Cannot proceed with config generation."
)

# 1a. Apply caller-supplied overrides — returns a new config when overrides
# are non-empty so the config class's own __init__ / from_dict validates.
hf_config = apply_loader_config_overrides(hf_config, loader_config_overrides)

# 2. Infer task (depends on: model_type param or hf_config.architectures)
if task is None and model_type is not None:
supported = get_supported_tasks(model_type, library_name=library_name)
Expand Down Expand Up @@ -241,6 +325,7 @@ def resolve_loader_config(
model_class=resolved_class.__name__,
model_type=resolved_model_type,
trust_remote_code=trust_remote_code,
loader_config_overrides=loader_config_overrides or None,
)

return loader_config, resolved_hf_config, resolved_class
Expand Down Expand Up @@ -313,4 +398,8 @@ def _resolve_hf_config_for_class(
return hf_config, hf_config.model_type


__all__ = ["WinMLLoaderConfig", "resolve_loader_config"]
__all__ = [
"WinMLLoaderConfig",
"apply_loader_config_overrides",
"resolve_loader_config",
]
27 changes: 25 additions & 2 deletions src/winml/modelkit/loader/hf.py
Loading
Loading