perf/chunktransform by d-v-b · Pull Request #3722 · zarr-developers/zarr-python · GitHub
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/zarr/abc/codec.py
108 changes: 107 additions & 1 deletion src/zarr/core/codec_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field
from itertools import islice, pairwise
from typing import TYPE_CHECKING, Any
from warnings import warn
Expand All @@ -14,6 +14,7 @@
Codec,
CodecPipeline,
GetResult,
SupportsSyncCodec,
)
from zarr.core.common import concurrent_map
from zarr.core.config import config
Expand Down Expand Up @@ -66,6 +67,111 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any:
return fill_value


@dataclass(slots=True, kw_only=True)
class ChunkTransform:
"""A synchronous codec chain bound to an ArraySpec.

Provides `encode` and `decode` for pure-compute codec operations
(no IO, no threading, no batching).

All codecs must implement `SupportsSyncCodec`. Construction will
raise `TypeError` if any codec does not.
"""

codecs: tuple[Codec, ...]
array_spec: ArraySpec

# (sync codec, input_spec) pairs in pipeline order.
_aa_codecs: tuple[tuple[SupportsSyncCodec[NDBuffer, NDBuffer], ArraySpec], ...] = field(
init=False, repr=False, compare=False
)
_ab_codec: SupportsSyncCodec[NDBuffer, Buffer] = field(init=False, repr=False, compare=False)
_ab_spec: ArraySpec = field(init=False, repr=False, compare=False)
_bb_codecs: tuple[SupportsSyncCodec[Buffer, Buffer], ...] = field(
init=False, repr=False, compare=False
)

def __post_init__(self) -> None:
non_sync = [c for c in self.codecs if not isinstance(c, SupportsSyncCodec)]
if non_sync:
names = ", ".join(type(c).__name__ for c in non_sync)
raise TypeError(
f"All codecs must implement SupportsSyncCodec. The following do not: {names}"
)

aa, ab, bb = codecs_from_list(list(self.codecs))

aa_codecs: list[tuple[SupportsSyncCodec[NDBuffer, NDBuffer], ArraySpec]] = []
spec = self.array_spec
for aa_codec in aa:
assert isinstance(aa_codec, SupportsSyncCodec)
aa_codecs.append((aa_codec, spec))
spec = aa_codec.resolve_metadata(spec)

self._aa_codecs = tuple(aa_codecs)
assert isinstance(ab, SupportsSyncCodec)
self._ab_codec = ab
self._ab_spec = spec
bb_sync: list[SupportsSyncCodec[Buffer, Buffer]] = []
for bb_codec in bb:
assert isinstance(bb_codec, SupportsSyncCodec)
bb_sync.append(bb_codec)
self._bb_codecs = tuple(bb_sync)

def decode(
self,
chunk_bytes: Buffer,
) -> NDBuffer:
"""Decode a single chunk through the full codec chain, synchronously.

Pure compute -- no IO.
"""
data: Buffer = chunk_bytes
for bb_codec in reversed(self._bb_codecs):
data = bb_codec._decode_sync(data, self._ab_spec)

chunk_array: NDBuffer = self._ab_codec._decode_sync(data, self._ab_spec)

for aa_codec, spec in reversed(self._aa_codecs):
chunk_array = aa_codec._decode_sync(chunk_array, spec)

return chunk_array

def encode(
self,
chunk_array: NDBuffer,
) -> Buffer | None:
"""Encode a single chunk through the full codec chain, synchronously.

Pure compute -- no IO.
"""
aa_data: NDBuffer = chunk_array
for aa_codec, spec in self._aa_codecs:
aa_result = aa_codec._encode_sync(aa_data, spec)
if aa_result is None:
return None
aa_data = aa_result

ab_result = self._ab_codec._encode_sync(aa_data, self._ab_spec)
if ab_result is None:
return None

bb_data: Buffer = ab_result
for bb_codec in self._bb_codecs:
bb_result = bb_codec._encode_sync(bb_data, self._ab_spec)
if bb_result is None:
return None
bb_data = bb_result

return bb_data

def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
for codec in self.codecs:
byte_length = codec.compute_encoded_size(byte_length, array_spec)
array_spec = codec.resolve_metadata(array_spec)
return byte_length


@dataclass(frozen=True)
class BatchedCodecPipeline(CodecPipeline):
"""Default codec pipeline.
Expand Down
145 changes: 145 additions & 0 deletions tests/test_sync_codec_pipeline.py
Loading