meta-cla (bot) added the CLA Signed label on Jun 26, 2026
Nice work — the schema-derived in-place mapping, the dynamic-shape-safe shape/dtype guard, and the slot-lifetime transfer in _mark_read are all carefully done, and the test coverage (broadcast, dtype-change, dynamic shapes, on-device numeric chain) is thoughtful. One correctness issue stands out, plus a few smaller notes.
🔴 Correctness: non-mutated operand reads aren't recorded in seen_nodes
exir/passes/reinplace.py:478-484 — when a node is reinplaced, only the mutated args are added to seen_nodes:
But the reinplaced in-place op still reads its non-mutated operands at runtime. Because the reverse walk continues before the usual seen_nodes.update(node.all_input_nodes), those reads are never recorded. An earlier-executing node that mutates that same value can then be reinplaced and overwrite it before this op reads it.
Consider (all temps, so the input-placeholder guard doesn't save us):
    t = torch.exp(x)   # temp, 2 users
    w = torch.exp(z)   # temp, 1 user
    a = t + y          # arg0 = t (mutated)
    b = w + t          # arg0 = w (mutated, single-use temp), t read as arg1
    return a * b
Reverse walk with ops_to_inplace={add.Tensor}:
b = w + t → mutates w (dead single-use temp) → reinplaced to w.add_(t). Adds only w to seen_nodes (not t).
a = t + y → mutates t; t is not in seen_nodes → reinplaced to t.add_(y).
Execution order then runs t.add_(y) first, corrupting t before b = w.add_(t) reads it. Result: b == w + (t + y) instead of w + t, so ep.module() numerics are wrong. The existing test_multi_use_self_reinplaced_at_most_once doesn't catch this because there t is the mutated arg of both consumers (covered by the new mutated-arg tracking); here t is mutated by one and merely read by the other.
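To make the numeric effect concrete, here is a minimal sketch in plain Python rather than torch. `Buf` is a hypothetical one-float stand-in for a tensor buffer, and `wrongly_reinplaced` hand-writes the rewritten program that the reverse walk above would produce; neither is code from the PR.

```python
import math


class Buf:
    """Hypothetical stand-in for a tensor buffer holding a single float."""

    def __init__(self, value):
        self.value = value

    def add_(self, other):
        # In-place add: overwrites this buffer and returns it, like Tensor.add_.
        self.value += other.value
        return self


def functional(x, y, z):
    # Reference semantics: no value is ever overwritten.
    t = math.exp(x)
    w = math.exp(z)
    a = t + y
    b = w + t
    return a * b


def wrongly_reinplaced(x, y, z):
    # Both adds rewritten in place, as in the trace above.
    t = Buf(math.exp(x))
    w = Buf(math.exp(z))
    a = t.add_(Buf(y))  # overwrites t with t + y (a aliases t)
    b = w.add_(t)       # reads the already-overwritten t
    return a.value * b.value


x, y, z = 0.0, 2.0, 0.0
expected = functional(x, y, z)        # t=1, w=1, a=3, b=2
actual = wrongly_reinplaced(x, y, z)  # t=3, a=3, b=1+3=4
print(expected, actual)               # 6.0 12.0
```

With these inputs the functional program yields 6.0 while the doubly reinplaced one yields 12.0, which is the b == w + (t + y) corruption described above.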
Note: the MLX backend incidentally dodges this at build time, because _inplace_alias_slot / the unary handler require len(a_node.users) == 1 before aliasing, so MLX emits functionally and stays correct. But the shared reinplace_pass is now exercised on arbitrary binary ops, and the bug is real for any consumer of the pass (and for ep.module() replay).
Suggested fix — record all reads on the reinplace path too (superset of the current mutated-arg set, and still correct for the duplicate-mutated-arg case):
    if all_safe:
        ... rewrite ...
        seen_nodes.update(node.all_input_nodes)  # in-place op reads ALL its operands
        continue
I'd add a regression test that mirrors the scenario above and asserts that ep.module() matches eager.
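The difference between the two recording policies can be sketched with a toy reverse walk. This is a simplified model under stated assumptions, not the actual reinplace_pass: `Node`, `reinplace_candidates`, and the `record_all_reads` flag are hypothetical names, the only in-placeable op is "add", and the shape/dtype and placeholder guards are omitted.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    name: str
    op: str
    inputs: tuple  # input names; inputs[0] is the would-be mutated arg


def reinplace_candidates(nodes, temps, record_all_reads):
    """Return the names of nodes this toy reverse walk would rewrite in place."""
    seen = set()  # values read by some later-executing node
    reinplaced = []
    for node in reversed(nodes):
        mutated = node.inputs[0]
        if node.op == "add" and mutated in temps and mutated not in seen:
            reinplaced.append(node.name)
            # Policy under test: record every operand, or only the mutated one.
            seen.update(node.inputs if record_all_reads else (mutated,))
            continue
        seen.update(node.inputs)
    return reinplaced


graph = [
    Node("t", "exp", ("x",)),
    Node("w", "exp", ("z",)),
    Node("a", "add", ("t", "y")),
    Node("b", "add", ("w", "t")),
    Node("out", "mul", ("a", "b")),
]
temps = {"t", "w"}

buggy = reinplace_candidates(graph, temps, record_all_reads=False)
fixed = reinplace_candidates(graph, temps, record_all_reads=True)
print(buggy)  # ['b', 'a']
print(fixed)  # ['b']
```

Recording only the mutated arg lets both adds be rewritten; recording all operands leaves a = t + y functional because t is already known to be read later by b.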
🟡 Minor / discussion
_make_inplace_passthrough_handler relies on "output write is the last emitted op." For leaky_relu/gelu/clamp the functional handlers compute intermediates (cond, scaled, lifted constants) into fresh temps before the final write to the aliased out==self slot, so it's correct today, but it's an implicit contract on handler internals (ops.py:4660+). If a future edit to one of those functional handlers reads self after writing out, this silently corrupts. Worth a one-line assertion or at least the contract being stated at each delegated handler, not only in the passthrough docstring.
_resync_output_specs (passes.py) positionally re-syncs spec.arg.name. This assumes reinplace never changes output count or order — true for the current rewrite (1:1 replace_all_uses_with + erase), but brittle if the pass ever drops/reorders outputs. A short comment asserting len(output_specs) == len(out_args) (or an explicit check) would make a future violation fail loudly instead of silently mis-pairing names.
getattr(spec.arg, "name", None) — OutputSpec.arg can be a non-TensorArgument (e.g. ConstantArgument) without a name; the getattr guard handles it, but pairing a renamed tensor output against a constant-arg spec by position could still mis-sync if outputs are heterogeneous. Edge case only; current tests don't hit it.
Lazy import in MLXReinplacePass.call to break the cycle with ops.py is reasonable and well-commented. No change needed — just confirming the cycle is real (ops.py imports from reinplace at module load).
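The positional re-sync concern raised for _resync_output_specs comes down to how zip behaves on a length mismatch. A minimal sketch, using a hypothetical `resync_names` helper and plain string names instead of real OutputSpec objects:

```python
def resync_names(spec_names, out_names):
    """Pair spec names with output names by position, failing loudly on a count mismatch."""
    if len(spec_names) != len(out_names):
        raise ValueError(
            f"output count changed: {len(spec_names)} specs vs {len(out_names)} outputs"
        )
    return list(zip(spec_names, out_names))


specs = ["add", "mul", "exp"]
outputs_after_pass = ["add_", "exp"]  # hypothetical: one output was dropped

# A bare zip truncates silently and pairs "mul" with "exp".
silent = list(zip(specs, outputs_after_pass))
print(silent)  # [('add', 'add_'), ('mul', 'exp')]

# The explicit length check turns the same situation into an error.
try:
    resync_names(specs, outputs_after_pass)
    raised = False
except ValueError:
    raised = True
print(raised)  # True
```

On Python 3.10 or newer, zip(spec_names, out_names, strict=True) gives the same loud failure without a hand-written check.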
✅ Things that look good
Dynamic-shape handling via statically_known_true (conservative = no reinplace on unprovable dims) is the right call and is tested.
The shape and dtype guard correctly blocks lt_-style dtype changes and broadcast-growth — good catch to gate binary ops on this rather than excluding them.
Building a fresh ops_to_inplace set per call and not using DEFAULT_INPLACEABLE_OPS keeps index_put/KV-cache patterns untouched — clearly reasoned in the docstring.
_mark_read slot-lifetime transfer correctly avoids reclaiming an aliased temp; the comment explains the hazard well.
Overall: solid PR. The seen_nodes gap is the one item I'd consider blocking, since it affects the shared pass even though MLX's own handler guards around it.
· branch mlx-reinplace
This revision addresses every point from the prior review — the blocking seen_nodes correctness issue, the _resync_output_specs length assertion, and the passthrough-handler contract assertion are all in. Re-verifying the headline concern below.
✅ The blocking seen_nodes bug is correctly resolved
The prior review flagged that the reinplace path skipped seen_nodes.update(node.all_input_nodes), so a value read as a non-mutated operand by a reinplaced node could be overwritten by an earlier-executing reinplaced node. This PR fixes it — but via a different mechanism than the originally suggested seen_nodes.update(...), so it's worth confirming the new approach holds.
The author's comment at exir/passes/reinplace.py:479-491 argues the non-mutated reads are recorded because the reverse walk revisits the inserted in-place node. Tracing FX's linked-list semantics, this is sound:
The reverse iterator (_NodeIter, advancing via _prev) is positioned at node when the body runs.
inserting_before(node) links the new in-place node as node._prev.
erase_node → _remove_from_list only rewires the neighbors' pointers; it leaves the erased node's own _prev pointing at the inserted node. So the next __next__ yields the inserted node.
The inserted node's target is the edge in-place op, which is not a key in resolved (only functional ops are), so it hits the generic branch at reinplace.py:427-430 and runs seen_nodes.update(node.all_input_nodes) — recording all operands, mutated and non-mutated alike.
Replaying the prior counterexample (b = w + t reinplaced first, t read as arg1): after b.add_ is inserted and revisited, t lands in seen_nodes, so a = t + y is no longer reinplaced. Correct. The on-device ReinplaceChainTest and ep.module() numeric checks back this up.
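The linked-list argument above can be checked on a toy model. The sketch below is not torch.fx code: `ToyNode` and `reverse_walk` are hypothetical, written only to mirror the behaviour the review describes (unlinking rewires the neighbours and leaves the erased node's own `_prev` untouched).

```python
class ToyNode:
    """Hypothetical node in a circular doubly linked list with a sentinel root."""

    def __init__(self, name):
        self.name = name
        self._prev = self
        self._next = self
        self._erased = False

    def prepend(self, new):
        # Link `new` immediately before this node.
        p = self._prev
        p._next, new._prev = new, p
        new._next, self._prev = self, new

    def remove_from_list(self):
        # Rewire only the neighbours; this node's own pointers stay as they were.
        p, n = self._prev, self._next
        p._next, n._prev = n, p
        self._erased = True


def reverse_walk(root):
    # Walk backwards from the tail, skipping nodes marked as erased.
    cur = root._prev
    while cur is not root:
        if not cur._erased:
            yield cur
        cur = cur._prev


root = ToyNode("root")
for name in ("a", "b"):
    root.prepend(ToyNode(name))  # appends at the tail: root, a, b

visited = []
for node in reverse_walk(root):
    visited.append(node.name)
    if node.name == "b":
        inplace = ToyNode("b_inplace")
        node.prepend(inplace)    # like inserting before the current node
        node.remove_from_list()  # like erasing the current node

print(visited)  # ['b', 'b_inplace', 'a']
```

The walk yields the inserted node right after the erased one, which is the revisit the correctness argument depends on. If the iterator instead cached the predecessor before running the loop body, the inserted node would be skipped.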
Minor: the explicit for arg_idx in mutated_args: seen_nodes.add(...) at lines 490-491 is now strictly redundant with the revisit (which records the mutated args too, since they're also all_input_nodes of the inserted node). It's harmless belt-and-suspenders, but the comment could note that the revisit is the load-bearing mechanism and the explicit add is just defensive. Today's wording reads as if the explicit add is what prevents the double-reinplace, when really the revisit does.
🟡 Minor / discussion
Coupling to FX iterator internals. The whole correctness argument rests on "erase_node leaves _prev intact, so reversed() visits the inserted node next." That's stable, long-standing FX behavior, but it's an implicit dependency. If you wanted to be robust against a future FX iterator change (or a reader who doesn't trust the revisit), adding the suggested seen_nodes.update(node.all_input_nodes) directly on the reinplace path would make the pass self-contained and independent of iterator semantics — it's a superset of what the revisit records, so it can't hurt. Worth considering as defense-in-depth even though the current code is correct.
_make_inplace_addsub_handler (ops.py:4500ish) takes the alpha-constant dtype from n.args[0] (self/arg0), but alpha scales arg1 (b). For the in-place arithmetic ops this PR targets, self and other share a dtype (the reinplace guard already enforces self↔out dtype match), so this is fine in practice — just slightly surprising to read "scale b" while sourcing the dtype from a. A one-word comment would settle it.
REGISTRY._handlers.get(_func_aten) (ops.py:4670ish) reaches into a private attribute to look up the already-registered functional handler. Works, but if REGISTRY exposes a public getter it'd be cleaner; otherwise fine as an internal-module access.
✅ Things that look good
_resync_output_specs now asserts len(output_specs) == len(out_args), so a future rewrite that drops/reorders outputs fails loudly instead of silently mis-pairing names via zip truncation — exactly the prior suggestion.
_make_inplace_passthrough_handler asserts the delegated handler returned the aliased slot, catching the "stopped writing n's slot" violation; the docstring honestly flags the residual read-after-write hazard it can't catch.
_mark_read's slot-lifetime transfer (program_builder.py:335-355) correctly skips reclaiming a temp slot the consumer has aliased in place — the comment explains the use-after-free hazard well.
Building a fresh ops_to_inplace per call and deliberately not using DEFAULT_INPLACEABLE_OPS keeps index_put/KV-cache patterns untouched.
Dynamic-shape handling via statically_known_true (conservative — no reinplace on unprovable dims) is right and tested; the shape+dtype guard correctly blocks lt_ dtype changes and broadcast growth.
Overall: the prior blocking issue is genuinely resolved and the feedback is well-incorporated. My only real suggestion is the defense-in-depth note — adding the explicit all_input_nodes update on the reinplace path so correctness doesn't hinge on FX iterator internals. Not blocking. Nice work.
· branch mlx-reinplace
This adds ExecuTorch's reinplace pass to the MLX backend's default passes and targets all unary/binary ops.