perf(efficient-did): cache polynomial sieve basis across DR nuisance fits by igerber · Pull Request #556 · igerber/diff-diff · GitHub
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion TODO.md
9 changes: 9 additions & 0 deletions diff_diff/efficient_did.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,11 @@ def fit(
m_hat_cache: Dict[Tuple, np.ndarray] = {}
r_hat_cache: Dict[Tuple[float, float], np.ndarray] = {}
s_hat_cache: Dict[float, np.ndarray] = {} # inverse propensities per group
# Per-fit cache of the polynomial sieve basis, keyed by (id(X), degree). The three
# sieve nuisance helpers all build the basis from the same fit-level
# `covariate_matrix`, so this cache shares each distinct degree's basis across them
# instead of rebuilding it per helper. It lives only for this fit() call.
sieve_basis_cache: Dict[Tuple[int, int], np.ndarray] = {}

if use_covariates:
assert covariates is not None # for type narrowing
Expand Down Expand Up @@ -934,6 +939,7 @@ def fit(
k_max=self.sieve_k_max,
criterion=self.sieve_criterion,
unit_weights=unit_level_weights,
basis_cache=sieve_basis_cache,
)
# m_{g', tpre, 1}(X)
key_gp_tpre = (gp, tpre_col_val, effective_p1_col)
Expand All @@ -950,6 +956,7 @@ def fit(
k_max=self.sieve_k_max,
criterion=self.sieve_criterion,
unit_weights=unit_level_weights,
basis_cache=sieve_basis_cache,
)
# r_{g, inf}(X) and r_{g, g'}(X) via sieve (Eq 4.1-4.2)
for comp in {np.inf, gp}:
Expand All @@ -966,6 +973,7 @@ def fit(
criterion=self.sieve_criterion,
ratio_clip=self.ratio_clip,
unit_weights=unit_level_weights,
basis_cache=sieve_basis_cache,
)

# Per-unit DR generated outcomes: shape (n_units, H)
Expand Down Expand Up @@ -998,6 +1006,7 @@ def fit(
k_max=self.sieve_k_max,
criterion=self.sieve_criterion,
unit_weights=unit_level_weights,
basis_cache=sieve_basis_cache,
)

# Conditional Omega*(X) with per-unit propensities (Eq 3.12)
Expand Down
41 changes: 38 additions & 3 deletions diff_diff/efficient_did_covariates.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def estimate_outcome_regression(
k_max: Optional[int] = None,
criterion: str = "bic",
unit_weights: Optional[np.ndarray] = None,
basis_cache: Optional[Dict[Tuple[int, int], np.ndarray]] = None,
) -> np.ndarray:
r"""Estimate conditional mean outcome change m_hat(X) via a polynomial sieve.

Expand Down Expand Up @@ -169,7 +170,7 @@ def estimate_outcome_regression(
if n_basis >= n_pos:
break

basis_all = _polynomial_sieve_basis(covariate_matrix, K)
basis_all = _sieve_basis_cached(covariate_matrix, K, basis_cache)
basis_group = basis_all[group_mask]

# Rank guard on the (weighted) design Gram, mirroring the propensity sieve.
Expand Down Expand Up @@ -288,6 +289,38 @@ def _polynomial_sieve_basis(X: np.ndarray, degree: int) -> np.ndarray:
return np.column_stack(columns)


def _sieve_basis_cached(
    X: np.ndarray, degree: int, cache: Optional[Dict[Tuple[int, int], np.ndarray]]
) -> np.ndarray:
    """Return the polynomial sieve basis for ``(X, degree)``, memoized in ``cache``.

    Thin memoizing wrapper around :func:`_polynomial_sieve_basis`.

    Parameters
    ----------
    X
        Covariate matrix the basis is built from. It is identified in the cache
        by object identity (``id(X)``), not by value.
    degree
        Polynomial degree forwarded to :func:`_polynomial_sieve_basis`.
    cache
        Dict mapping ``(id(X), degree)`` to a previously built basis, or ``None``
        to disable memoization. On a miss the freshly built basis is stored in
        this dict, so the caller's dict is mutated.

    Returns
    -------
    np.ndarray
        The basis array. With a cache, repeated calls for the same ``X`` object
        and ``degree`` return the *same* array object, so callers must treat it
        as read-only.

    Notes
    -----
    The intended owner of ``cache`` is a single ``EfficientDiD.fit()`` call that
    passes one fit-level ``covariate_matrix`` to every sieve nuisance helper.
    Because the key uses ``id(X)``, the scheme is only sound while ``X`` stays
    alive for the whole lifetime of ``cache``; a cache that outlives its matrix
    could see the id reused by a different array.

    With ``cache is None`` this function is a plain pass-through, so standalone
    callers get exactly the result of :func:`_polynomial_sieve_basis`.
    """
    if cache is not None:
        # Identity-based key: cheap to compute and never hashes array contents.
        lookup = (id(X), degree)
        cached = cache.get(lookup)
        if cached is not None:
            return cached
        # Miss: build once and remember it for later helpers in the same fit.
        built = _polynomial_sieve_basis(X, degree)
        cache[lookup] = built
        return built

    # No cache supplied: behave exactly like the uncached builder.
    return _polynomial_sieve_basis(X, degree)


def estimate_propensity_ratio_sieve(
covariate_matrix: np.ndarray,
mask_g: np.ndarray,
Expand All @@ -296,6 +329,7 @@ def estimate_propensity_ratio_sieve(
criterion: str = "bic",
ratio_clip: float = 20.0,
unit_weights: Optional[np.ndarray] = None,
basis_cache: Optional[Dict[Tuple[int, int], np.ndarray]] = None,
) -> np.ndarray:
r"""Estimate propensity ratio via sieve convex minimization (Eq 4.1-4.2).

Expand Down Expand Up @@ -396,7 +430,7 @@ def estimate_propensity_ratio_sieve(
if n_basis >= n_gp_pos:
break

basis_all = _polynomial_sieve_basis(covariate_matrix, K)
basis_all = _sieve_basis_cached(covariate_matrix, K, basis_cache)
Psi_gp = basis_all[mask_gp] # (n_gp, n_basis)
Psi_g = basis_all[mask_g] # (n_g, n_basis)

Expand Down Expand Up @@ -496,6 +530,7 @@ def estimate_inverse_propensity_sieve(
k_max: Optional[int] = None,
criterion: str = "bic",
unit_weights: Optional[np.ndarray] = None,
basis_cache: Optional[Dict[Tuple[int, int], np.ndarray]] = None,
) -> np.ndarray:
r"""Estimate s_{g'}(X) = 1/p_{g'}(X) via sieve convex minimization.

Expand Down Expand Up @@ -586,7 +621,7 @@ def estimate_inverse_propensity_sieve(
if n_basis >= n_group_pos:
break

basis_all = _polynomial_sieve_basis(covariate_matrix, K)
basis_all = _sieve_basis_cached(covariate_matrix, K, basis_cache)
Psi_gp = basis_all[group_mask]

# Normal equations (weighted when survey weights present):
Expand Down
104 changes: 104 additions & 0 deletions tests/test_efficient_did.py
Loading