Codegen RuntimeWrapper orchestration into single function by bobrenjc93 · Pull Request #181271 · pytorch/pytorch · GitHub

Codegen RuntimeWrapper orchestration into single function #181271

Draft

bobrenjc93 wants to merge 2 commits into gh/bobrenjc93/875/base from gh/bobrenjc93/875/head

Conversation

Contributor

@bobrenjc93 bobrenjc93 commented Apr 23, 2026

Stack from ghstack (oldest at bottom):

Collapse _RuntimeCompiledFnInvoker.run,
_RuntimeForwardEpilogue.capture_orig_inputs,
increment_mutation_versions, and finalize into a single codegen'd
function with all branches resolved at compile time.

The generated function inlines:

- capture_orig_inputs: dict comprehension → baked {idx: args[idx]} literal
- increment_mutation_versions: conditional + generator → baked tuple
- compiled_invoker.run: trace_joint branch + detach indices inlined
- output arity validation: baked expected count
- split mutated inputs: baked slice index
- apply mutations / replay aliases: delegate to existing codegen'd functions
- dynamic dims: baked per-output dim sets
- grad_enabled_mutation: baked boolean

Generated code for inference (0 mutations, 1 alias, 1 input):

    def _runtime_wrapper(_compiled_fn_, _first_ctx_, _on_before_call_, args):
        orig_inputs = {0: args[0]}
        with _first_ctx_():
            grad_enabled = torch.is_grad_enabled()
            try:
                if grad_enabled: torch._C._set_grad_enabled(False)
                _on_before_call_()
                all_outs = _normalize_as_list_(_compiled_fn_(args))
            finally:
                if grad_enabled: torch._C._set_grad_enabled(True)
        del args
        if len(all_outs) != 1:
            raise AssertionError(...)
        fw_outs = all_outs
        ret_outs = _replay_aliases_(orig_inputs, fw_outs)
        return ret_outs
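The general shape of this codegen pattern can be sketched outside PyTorch as well: bake the compile-time decisions (which input indices to capture, the expected output arity) into a generated source string, `exec` it once, and call the resulting flat function on the hot path. The names below (`make_runtime_wrapper`, `alias_indices`, `expected_outs`) are illustrative stand-ins, not the actual PyTorch internals:

```python
# Hypothetical sketch of the codegen approach: all branches are resolved
# while building the source string, so the generated function has no
# per-call dispatch or branching for those decisions.
def make_runtime_wrapper(alias_indices, expected_outs, replay_aliases):
    # Bake the {idx: args[idx]} capture literal from the known indices.
    capture = ", ".join(f"{i}: args[{i}]" for i in alias_indices)
    src = f"""
def _runtime_wrapper(_compiled_fn_, args):
    orig_inputs = {{{capture}}}
    all_outs = _compiled_fn_(args)
    del args
    if len(all_outs) != {expected_outs}:
        raise AssertionError("output arity mismatch")
    return _replay_aliases_(orig_inputs, all_outs)
"""
    ns = {"_replay_aliases_": replay_aliases}
    exec(src, ns)  # compile the specialized wrapper once
    return ns["_runtime_wrapper"]

# Example: 1 aliased input at index 0, 1 expected output; the replay
# helper here is a trivial stand-in that just returns the outputs.
wrapper = make_runtime_wrapper([0], 1, lambda orig, outs: outs)
print(wrapper(lambda a: [a[0] * 2], [21]))  # -> [42]
```

The `exec`-once cost is paid at compile time; every subsequent call runs straight-line code with the constants inlined, which is where the dispatch-overhead savings in the table below come from.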

RuntimeWrapper orchestration step in isolation (us/call):

| Case | Before (method dispatch) | After (codegen) | Speedup |
|---|---|---|---|
| 0 alias, 0 mut, 5 args | 0.35 us | 0.17 us | 2.1x |
| 2 alias, 0 mut, 5 args | 0.41 us | 0.17 us | 2.5x |
| 0 alias, 2 mut, 5 args | 0.74 us | 0.25 us | 3.0x |
| 3 alias, 1 mut, 10 args | 0.79 us | 0.25 us | 3.2x |
| 5 alias, 3 mut, 20 args | 0.93 us | 0.32 us | 2.9x |
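The kind of overhead being measured can be reproduced with a toy microbenchmark: a path that goes through method dispatch on helper objects versus a flat function with the same decisions pre-baked. The classes below are illustrative stand-ins, not the actual RuntimeWrapper machinery:

```python
# Toy stand-ins for the pre-codegen orchestration: each step is a method
# call on a helper object, costing attribute lookup + call dispatch.
import timeit

class Epilogue:
    def capture_orig_inputs(self, args):
        return {i: args[i] for i in (0, 1)}
    def finalize(self, orig, outs):
        return outs

class Invoker:
    def run(self, fn, args):
        return fn(args)

EPI, INV = Epilogue(), Invoker()

def dispatch_path(fn, args):
    orig = EPI.capture_orig_inputs(args)
    return EPI.finalize(orig, INV.run(fn, args))

def flat_path(fn, args):
    # Same work, but indices and branches baked in as literals.
    orig = {0: args[0], 1: args[1]}  # mirrors the capture step
    return fn(args)

fn, args = (lambda a: a), [1, 2, 3, 4, 5]
t_dispatch = timeit.timeit(lambda: dispatch_path(fn, args), number=100_000)
t_flat = timeit.timeit(lambda: flat_path(fn, args), number=100_000)
print(f"dispatch: {t_dispatch:.3f}s  flat: {t_flat:.3f}s")
```

Absolute numbers will vary by machine; the point is that the flat path avoids the per-call attribute lookups and method calls entirely, in the same way the codegen'd `_runtime_wrapper` does.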


pytorch-bot Bot commented Apr 23, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/181271

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 15 New Failures

As of commit cdcd85d with merge base b627bfb (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bobrenjc93 added a commit that referenced this pull request Apr 23, 2026
ghstack-source-id: e255dd9
Pull Request resolved: #181271
