[MLX] enable reinplacement for unary/binary ops by metascroy · Pull Request #20557 · pytorch/executorch · GitHub

[MLX] enable reinplacement for unary/binary ops#20557

Merged
metascroy merged 2 commits into
mainfrom
mlx-reinplace
Jun 29, 2026

Conversation

@metascroy

Contributor

This adds ExecuTorch's reinplace pass to the MLX backend's default passes and targets all unary/binary ops.

@pytorch-bot

pytorch-bot Bot commented Jun 26, 2026


@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 26, 2026
@metascroy

Contributor Author

@claude review this code

@claude

claude Bot commented Jun 26, 2026


Claude finished @metascroy's task in 14m 14s


Code Review — [MLX] enable reinplacement for unary/binary ops

  • Read the diff against origin/main
  • Review exir/passes/reinplace.py changes
  • Review MLX backend changes (ops.py, passes.py, program_builder.py)
  • Review tests
  • Post review feedback

Nice work — the schema-derived in-place mapping, the dynamic-shape-safe shape/dtype guard, and the slot-lifetime transfer in _mark_read are all carefully done, and the test coverage (broadcast, dtype-change, dynamic shapes, on-device numeric chain) is thoughtful. One correctness issue stands out, plus a few smaller notes.


🔴 Correctness: non-mutated operand reads aren't recorded in seen_nodes

exir/passes/reinplace.py:478-484 — when a node is reinplaced, only the mutated args are added to seen_nodes:

if all_safe:
    for arg_idx in mutated_args:
        seen_nodes.add(node.args[arg_idx])
    ...
    continue                         # <-- skips seen_nodes.update(node.all_input_nodes)

But the reinplaced in-place op still reads its non-mutated operands at runtime. Because the reinplace branch hits continue before the usual seen_nodes.update(node.all_input_nodes), those reads are never recorded. An earlier-executing node that mutates the same value can then also be reinplaced and overwrite it before this op reads it.

Consider (all temps, so the input-placeholder guard doesn't save us):

t = torch.exp(x)   # temp, 2 users
w = torch.exp(z)   # temp, 1 user
a = t + y          # arg0 = t  (mutated)
b = w + t          # arg0 = w  (mutated, single-use temp), t read as arg1
return a * b

Reverse walk with ops_to_inplace={add.Tensor}:

  1. b = w + t → mutates w (dead single-use temp) → reinplaced to w.add_(t). Adds only w to seen_nodes (not t).
  2. a = t + y → mutates t; t is not in seen_nodes → reinplaced to t.add_(y).

Execution order then runs t.add_(y) first, corrupting t before b = w.add_(t) reads it. Result: b == w + (t + y) instead of w + t. ep.module() numerics are wrong. The existing test_multi_use_self_reinplaced_at_most_once doesn't catch this because there t is the mutated arg of both consumers (covered by the new mutated-arg tracking); here t is mutated by one and merely read by the other.
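The corruption can be reproduced without any graph machinery. The following is a minimal sketch that uses plain Python lists in place of tensors; `add`, `add_`, and `mul` are hypothetical stand-ins for the functional and in-place ops, and the values of `t`, `w`, and `y` are arbitrary rather than the result of `exp`:

```python
def add(a, b):
    """Functional add: returns a fresh list and leaves its inputs untouched."""
    return [p + q for p, q in zip(a, b)]


def add_(a, b):
    """In-place add: overwrites `a` and returns it, like Tensor.add_."""
    for i, q in enumerate(b):
        a[i] += q
    return a


def mul(a, b):
    return [p * q for p, q in zip(a, b)]


def functional(t, w, y):
    a = add(t, y)
    b = add(w, t)
    return mul(a, b)


def both_reinplaced(t, w, y):
    a = add_(t, y)  # overwrites t
    b = add_(w, t)  # reads the already-overwritten t
    return mul(a, b)


t, w, y = [1.0, 2.0], [10.0, 20.0], [100.0, 200.0]
expected = functional(list(t), list(w), list(y))
actual = both_reinplaced(list(t), list(w), list(y))
print(expected, actual)
```

In the second variant `b` ends up as `w + (t + y)`, so the two results differ even though both programs look like the same chain of adds.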

Note: the MLX backend incidentally dodges this at build time, because _inplace_alias_slot / the unary handler require len(a_node.users) == 1 before aliasing — so MLX emits functionally and stays correct. But the shared reinplace_pass is now exercised on arbitrary binary ops, and the bug is real for any consumer of the pass (and for ep.module() replay).

Suggested fix — record all reads on the reinplace path too (superset of the current mutated-arg set, and still correct for the duplicate-mutated-arg case):

if all_safe:
    ... rewrite ...
    seen_nodes.update(node.all_input_nodes)   # in-place op reads ALL its operands
    continue
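To see why recording all reads closes the gap, here is a toy model of the reverse walk. It is a sketch under stated assumptions, not the code in exir/passes/reinplace.py: `Node`, `GRAPH`, and `reinplace` are hypothetical, only `add` is treated as inplaceable, and the mutated argument is always argument 0.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    name: str
    op: str            # "placeholder", "exp", "add", or "mul"
    args: tuple = ()   # names of the input nodes


# The graph from the scenario above, in execution order.
GRAPH = [
    Node("x", "placeholder"),
    Node("y", "placeholder"),
    Node("z", "placeholder"),
    Node("t", "exp", ("x",)),
    Node("w", "exp", ("z",)),
    Node("a", "add", ("t", "y")),
    Node("b", "add", ("w", "t")),
    Node("out", "mul", ("a", "b")),
]


def reinplace(graph, record_all_reads):
    """Return the names of add nodes the toy pass would rewrite in place."""
    placeholders = {n.name for n in graph if n.op == "placeholder"}
    seen = set()       # values read by nodes that execute later
    rewritten = []
    for node in reversed(graph):
        if node.op == "add":
            target = node.args[0]  # add_ mutates its first argument
            if target not in seen and target not in placeholders:
                rewritten.append(node.name)
                if record_all_reads:
                    seen.update(node.args)  # suggested fix: record every read
                else:
                    seen.add(target)        # behavior described in the review
                continue
        seen.update(node.args)
    return rewritten


print(reinplace(GRAPH, record_all_reads=False))  # both adds are rewritten
print(reinplace(GRAPH, record_all_reads=True))   # only b is rewritten
```

With `record_all_reads=True`, `t` is already in `seen` when the walk reaches `a`, so `a` stays functional and `b` still reads an unmodified `t`.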

I'd add a regression test mirroring the scenario above asserting ep.module() matches eager.


🟡 Minor / discussion

  • _make_inplace_passthrough_handler relies on "output write is the last emitted op." For leaky_relu/gelu/clamp the functional handlers compute intermediates (cond, scaled, lifted constants) into fresh temps before the final write to the aliased out==self slot, so it's correct today — but it's an implicit contract on handler internals (ops.py:4660+). If a future edit to one of those functional handlers reads self after writing out, this silently corrupts. Worth a one-line assertion or at least the contract being stated at each delegated handler, not only in the passthrough docstring.

  • _resync_output_specs (passes.py) positionally re-syncs spec.arg.name. This assumes reinplace never changes output count or order — true for the current rewrite (1:1 replace_all_uses_with + erase), but brittle if the pass ever drops/reorders outputs. A short comment asserting len(output_specs) == len(out_args) (or an explicit check) would make a future violation fail loudly instead of silently mis-pairing names.

  • getattr(spec.arg, "name", None): OutputSpec.arg can be a non-TensorArgument (e.g. ConstantArgument) without a name; the getattr guard handles it, but pairing a renamed tensor output against a constant-arg spec by position could still mis-sync if outputs are heterogeneous. Edge case only; current tests don't hit it.

  • Lazy import in MLXReinplacePass.call to break the cycle with ops.py is reasonable and well-commented. No change needed — just confirming the cycle is real (ops.py imports from reinplace at module load).
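The explicit check suggested for the positional re-sync could look roughly like the sketch below. Everything here is hypothetical: `FakeTensorArg`, `FakeConstArg`, `FakeSpec`, and `resync_output_names` only imitate the shape of the real spec objects, and the actual `_resync_output_specs` signature is not shown in this review.

```python
from dataclasses import dataclass


@dataclass
class FakeTensorArg:
    name: str


@dataclass
class FakeConstArg:
    value: object  # deliberately has no `name` attribute


@dataclass
class FakeSpec:
    arg: object


def resync_output_names(output_specs, out_names):
    """Positionally copy new output names into specs, failing loudly on a mismatch."""
    if len(output_specs) != len(out_names):
        raise ValueError(
            f"expected {len(output_specs)} outputs, got {len(out_names)}"
        )
    for spec, new_name in zip(output_specs, out_names):
        # Skip args that carry no name (e.g. constants).
        if getattr(spec.arg, "name", None) is not None:
            spec.arg.name = new_name


specs = [FakeSpec(FakeTensorArg("add")), FakeSpec(FakeConstArg(3))]
resync_output_names(specs, ["add_", None])

try:
    resync_output_names(specs, ["add_"])
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```

The length check turns a dropped or added output into an immediate error instead of a silent name mis-pairing; it does not by itself detect a reordering of outputs.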


✅ Things that look good

  • Dynamic-shape handling via statically_known_true (conservative = no reinplace on unprovable dims) is the right call and is tested.
  • The shape and dtype guard correctly blocks lt_-style dtype changes and broadcast-growth — good catch to gate binary ops on this rather than excluding them.
  • Building a fresh ops_to_inplace set per call and not using DEFAULT_INPLACEABLE_OPS keeps index_put/KV-cache patterns untouched — clearly reasoned in the docstring.
  • _mark_read slot-lifetime transfer correctly avoids reclaiming an aliased temp; the comment explains the hazard well.
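As an illustration of the shape and dtype guard, the sketch below decides whether a binary op may write its result into its first operand. It is a simplified, static-shape-only model with hypothetical names (`broadcast_shape`, `can_reinplace`); the real pass works on graph metadata and uses statically_known_true for symbolic dimensions.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Broadcast two static shapes, aligning dimensions from the right."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {a} and {b} do not broadcast")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def can_reinplace(self_shape, self_dtype, other_shape, out_dtype):
    """In-place is safe only if the result fits self's shape and dtype exactly."""
    same_shape = broadcast_shape(self_shape, other_shape) == tuple(self_shape)
    return same_shape and out_dtype == self_dtype


# other broadcasts into self: the result has self's shape, so in-place is fine.
ok = can_reinplace((2, 3), "float32", (3,), "float32")
# self would have to grow from (3,) to (2, 3): blocked.
grows = can_reinplace((3,), "float32", (2, 3), "float32")
# comparison ops produce bool, which does not match self's dtype: blocked.
dtype_change = can_reinplace((2, 3), "float32", (2, 3), "bool")
```

Gating on the computed output shape and dtype is what lets binary ops be included at all: the cases that cannot alias are rejected individually rather than excluding the whole op family.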

Overall: solid PR. The seen_nodes gap is the one item I'd consider blocking, since it affects the shared pass even though MLX's own handler guards around it.

@github-actions


This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@meta-codesync

meta-codesync Bot commented Jun 26, 2026

Contributor

@metascroy has imported this pull request. If you are a Meta employee, you can view this in D109898238.

@metascroy

Contributor Author

@claude review this code

@claude

claude Bot commented Jun 26, 2026


@metascroy metascroy requested a review from Gasoonjia June 26, 2026 23:59
@metascroy metascroy merged commit 0f3303f into main Jun 29, 2026
235 of 237 checks passed
@metascroy metascroy deleted the mlx-reinplace branch June 29, 2026 19:26
