GH-135379: Support limited scalar replacement for replicated uops in the JIT code generator. by markshannon · Pull Request #135563 · python/cpython · GitHub
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
347 changes: 176 additions & 171 deletions Include/internal/pycore_uop_ids.h

Large diffs are not rendered by default.

37 changes: 30 additions & 7 deletions Include/internal/pycore_uop_metadata.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions Python/bytecodes.c
61 changes: 59 additions & 2 deletions Python/executor_cases.c.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions Python/generated_cases.c.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Python/optimizer.c
Original file line number Diff line number Diff line change
Expand Up @@ -1292,8 +1292,8 @@ uop_optimize(
for (int pc = 0; pc < length; pc++) {
int opcode = buffer[pc].opcode;
int oparg = buffer[pc].oparg;
if (oparg < _PyUop_Replication[opcode]) {
buffer[pc].opcode = opcode + oparg + 1;
if (oparg < _PyUop_Replication[opcode].stop && oparg >= _PyUop_Replication[opcode].start) {
buffer[pc].opcode = opcode + oparg + 1 - _PyUop_Replication[opcode].start;
assert(strncmp(_PyOpcode_uop_name[buffer[pc].opcode], _PyOpcode_uop_name[opcode], strlen(_PyOpcode_uop_name[opcode])) == 0);
}
else if (is_terminator(&buffer[pc])) {
Expand Down
38 changes: 33 additions & 5 deletions Tools/cases_generator/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class Uop:
properties: Properties
_size: int = -1
implicitly_created: bool = False
replicated = 0
replicated = range(0)
replicates: "Uop | None" = None
# Size of the instruction(s), only set for uops containing the INSTRUCTION_SIZE macro
instruction_size: int | None = None
Expand Down Expand Up @@ -868,6 +868,28 @@ def compute_properties(op: parser.CodeDef) -> Properties:
needs_prev=variable_used(op, "prev_instr"),
)

def expand(items: list[StackItem], oparg: int) -> list[StackItem]:
    """Replace a single oparg-sized array item with individual scalar items.

    An item counts as an array here when its ``size`` expression mentions
    ``oparg``. The expansion is attempted only when exactly one such item
    exists; with zero or several, ``items`` is returned as-is.

    For the single array item, ``oparg`` in its size expression is replaced
    by the given constant and the expression is evaluated to get a count.
    The item is then replaced by ``count`` items named ``<name>_0``,
    ``<name>_1``, ... each with an empty size and the original item's
    ``peek`` and ``used`` values.

    Args:
        items: Stack items to inspect; this list is never mutated.
        oparg: Constant substituted for ``oparg`` in the size expression.

    Returns:
        ``items`` itself when nothing is expanded (or the size cannot be
        evaluated), otherwise a new list with the array item scalarized.
    """
    # Only replace array item with scalar if no more than one item is an array
    array_indices = [i for i, item in enumerate(items) if "oparg" in item.size]
    if len(array_indices) != 1:
        return items
    index = array_indices[0]
    array_item = items[index]
    size_expr = array_item.size.replace("oparg", str(oparg))
    # NOTE(review): eval() runs arbitrary Python. The size text presumably
    # comes from the trusted instruction definitions -- confirm, and never
    # route untrusted input through here.
    try:
        count = int(eval(size_expr))
    except (ValueError, NameError, SyntaxError, TypeError):
        # A size that still refers to other names, or is not a valid Python
        # expression, makes eval() raise NameError/SyntaxError rather than
        # ValueError; treat all of these as "cannot scalarize".
        return items
    scalars = [
        StackItem(f"{array_item.name}_{i}", "", array_item.peek, array_item.used)
        for i in range(count)
    ]
    return items[:index] + scalars + items[index + 1:]

def scalarize_stack(stack: StackEffect, oparg: int) -> StackEffect:
    """Apply ``expand`` to both sides of a stack effect for a fixed oparg.

    ``stack.inputs`` is processed first, then ``stack.outputs``; each
    attribute is reassigned to whatever ``expand`` returns for it. The
    same ``stack`` object is modified in place and returned.
    """
    for side in ("inputs", "outputs"):
        setattr(stack, side, expand(getattr(stack, side), oparg))
    return stack

def make_uop(
name: str,
Expand All @@ -887,20 +909,26 @@ def make_uop(
)
for anno in op.annotations:
if anno.startswith("replicate"):
result.replicated = int(anno[10:-1])
text = anno[10:-1]
start, stop = text.split(":")
result.replicated = range(int(start), int(stop))
break
else:
return result
for oparg in range(result.replicated):
for oparg in result.replicated:
name_x = name + "_" + str(oparg)
properties = compute_properties(op)
properties.oparg = False
properties.const_oparg = oparg
stack = analyze_stack(op)
if not variable_used(op, "oparg"):
stack = scalarize_stack(stack, oparg)
else:
properties.const_oparg = oparg
rep = Uop(
name=name_x,
context=op.context,
annotations=op.annotations,
stack=analyze_stack(op),
stack=stack,
caches=analyze_caches(inputs),
local_stores=find_variable_stores(op),
body=op.block,
Expand Down
8 changes: 6 additions & 2 deletions Tools/cases_generator/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,13 @@ def inst_header(self) -> InstHeader | None:
while anno := self.expect(lx.ANNOTATION):
if anno.text == "replicate":
self.require(lx.LPAREN)
times = self.require(lx.NUMBER)
stop = self.require(lx.NUMBER)
start_text = "0"
if self.expect(lx.COLON):
start_text = stop.text
stop = self.require(lx.NUMBER)
self.require(lx.RPAREN)
annotations.append(f"replicate({times.text})")
annotations.append(f"replicate({start_text}:{stop.text})")
else:
annotations.append(anno.text)
tkn = self.expect(lx.INST)
Expand Down
8 changes: 5 additions & 3 deletions Tools/cases_generator/uop_metadata_generator.py
Loading