Scalability Analysis#

This notebook provides a comprehensive benchmark analysis of sctrial on a real longitudinal immunotherapy dataset (Sade‑Feldman et al., Cell 2018; GSE120575).

Benchmark Objectives#

This benchmark is designed to answer the following questions:

  1. How does runtime scale with dataset size (cells, genes, participants)?

  2. How do different aggregation strategies and statistical approaches compare?

  3. How much memory is used across key operations?

  4. Do statistical outputs look well‑behaved and consistent?

  5. Are results reproducible across repeated runs?

Note: This notebook is a scalability and performance benchmark for sctrial; the demonstrated analyses use subsampled data (≤500 cells per participant-visit) and a subset of genes (up to 5,000 of ~56k).

[1]:
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=RuntimeWarning, message='invalid value encountered in sqrt')
# NOTE: cluster-robust SE warning is intentionally NOT suppressed.
# With ~10 participants, cluster-robust p-values are unreliable.
# Bootstrap (use_bootstrap=True) is recommended for real analyses
# but disabled here for benchmarking speed.

import time
import psutil
import gc
from functools import wraps
from typing import Callable, Any

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import scipy.sparse as sp
import sctrial as st
from statsmodels.stats.multitest import multipletests

pd.options.mode.chained_assignment = None
sc.settings.verbosity = 0

# =============================================================================
# CONFIGURATION
# =============================================================================
SEED = 42
np.random.seed(SEED)

# Analysis parameters
MIN_GENES_FOR_SCORE = 5
MIN_PAIRED_PER_ARM = 3
FDR_ALPHA = 0.25

# Benchmark parameters
MAX_CELLS_PER_PARTICIPANT_VISIT = 500
SCALING_GENE_COUNTS = [100, 500, 1000, 2000, 5000]  # Gene scaling test
SCALING_CELL_FRACTIONS = [0.25, 0.5, 0.75, 1.0]     # Cell scaling test
N_BENCHMARK_REPEATS = 3                              # Repeats for timing stability
PARALLEL_JOBS = 2  # Jobs for did_table_parallel benchmarking

# Storage for benchmark results
benchmark_results = {
    'timing': {},
    'memory': {},
    'scaling': [],
    'validation': {}
}

# =============================================================================
# BENCHMARK UTILITIES
# =============================================================================
def timed_run(func: Callable, *args, n_repeats: int = 1, **kwargs) -> tuple[Any, float, float]:
    """Run function with timing, return (result, mean_time, std_time)."""
    times = []
    result = None
    for _ in range(n_repeats):
        if result is not None:
            del result          # free prior iteration's result before next run
        gc.collect()
        start = time.perf_counter()
        result = func(*args, **kwargs)
        times.append(time.perf_counter() - start)
    return result, np.mean(times), np.std(times)

def memory_profile(func: Callable, *args, **kwargs) -> tuple[Any, float, float]:
    """Run function with memory profiling, return (result, delta_mb, current_mb).

    Measures process RSS (Resident Set Size) before and after the function call.
    RSS captures all allocations including C-level / numpy arrays, unlike
    tracemalloc which only tracks Python-level allocations and can report stale
    cumulative peaks across calls.
    """
    gc.collect()
    process = psutil.Process()
    rss_before = process.memory_info().rss
    result = func(*args, **kwargs)
    rss_after = process.memory_info().rss
    gc.collect()
    delta_mb = (rss_after - rss_before) / 1024 / 1024
    current_mb = rss_after / 1024 / 1024
    return result, max(delta_mb, 0), current_mb

def format_time(seconds: float) -> str:
    """Format seconds to human-readable string."""
    if seconds < 1:
        return f"{seconds*1000:.1f}ms"
    elif seconds < 60:
        return f"{seconds:.2f}s"
    else:
        return f"{seconds/60:.1f}min"

print("=" * 70)
print("SCTRIAL SCALABILITY BENCHMARK")
print("=" * 70)
print(f"sctrial version: {st.__version__}")
print(f"Random seed: {SEED}")
print(f"Benchmark repeats: {N_BENCHMARK_REPEATS}")
======================================================================
SCTRIAL SCALABILITY BENCHMARK
======================================================================
sctrial version: 0.3.3
Random seed: 42
Benchmark repeats: 3
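The `timed_run` pattern above (force a `gc.collect()`, then wrap each repeat in `time.perf_counter()`) can be exercised on its own. A minimal self-contained sketch, reproducing the utility so it runs outside the notebook:

```python
import gc
import time
import numpy as np

def timed_run(func, *args, n_repeats=1, **kwargs):
    """Run func n_repeats times; return (last_result, mean_time, std_time)."""
    times = []
    result = None
    for _ in range(n_repeats):
        if result is not None:
            del result  # release the previous iteration's result
        gc.collect()    # reduce GC noise inside the timed region
        start = time.perf_counter()
        result = func(*args, **kwargs)
        times.append(time.perf_counter() - start)
    return result, float(np.mean(times)), float(np.std(times))

# Time a small matrix multiply three times
result, mean_t, std_t = timed_run(
    lambda: np.ones((200, 200)) @ np.ones((200, 200)), n_repeats=3
)
print(f"mean {mean_t*1000:.2f}ms ± {std_t*1000:.2f}ms")
```

Collecting garbage before each repeat keeps allocator churn from a prior iteration out of the measured interval, which is why the notebook reports small standard deviations.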

1. Setup & Data Loading#

Load the Sade-Feldman melanoma immunotherapy dataset and establish the trial design.

[2]:
from sctrial.datasets import load_sade_feldman

# Load with memory profiling
print("Loading dataset...")
(adata, load_time, _) = timed_run(
    load_sade_feldman,
    max_cells_per_participant_visit=MAX_CELLS_PER_PARTICIPANT_VISIT,
    processed_name="sade_feldman_processed_v6_subsample.h5ad",
    allow_download=True,
    n_repeats=1
)
benchmark_results['timing']['data_load'] = load_time
print(f"  Load time: {format_time(load_time)}")
Loading dataset...
  Load time: 1.28s

Data Harmonization & Trial Design Setup#

[3]:
# =============================================================================
# RESPONSE HARMONIZATION
# =============================================================================
# Harmonize response labels at participant level using package API
# (assigns majority-vote label for participants with mixed response annotations)
adata = st.harmonize_response(adata)

RESPONSE_COL = "response_harmonized"

# =============================================================================
# IDENTIFY PAIRED PARTICIPANTS
# =============================================================================
visit_col = "visit"
adata.obs[visit_col] = adata.obs[visit_col].astype(str)
visits = [v for v in ["Pre", "Post"] if v in adata.obs[visit_col].unique()]

participant_summary = (
    adata.obs.groupby("participant_id")[visit_col].apply(set).reset_index()
)
participant_summary["has_Pre"] = participant_summary[visit_col].apply(lambda x: "Pre" in x)
participant_summary["has_Post"] = participant_summary[visit_col].apply(lambda x: "Post" in x)
participant_summary["is_paired"] = participant_summary["has_Pre"] & participant_summary["has_Post"]

# After harmonization, response labels are constant within each participant, so first() suffices
dominant_response = adata.obs.groupby("participant_id")[RESPONSE_COL].first()
participant_summary[RESPONSE_COL] = participant_summary["participant_id"].map(dominant_response)

paired_ids = set(participant_summary.loc[participant_summary["is_paired"], "participant_id"])
paired_by_response = (
    participant_summary[participant_summary["is_paired"]]
    .groupby(RESPONSE_COL)
    .size()
    .to_dict()
)

# Subset to paired participants
adata_paired = adata[adata.obs["participant_id"].isin(paired_ids)].copy()

# =============================================================================
# TRIAL DESIGN
# =============================================================================
design = st.TrialDesign(
    participant_col="participant_id",
    visit_col=visit_col,
    arm_col=RESPONSE_COL,
    arm_treated="Responder",
    arm_control="Non-responder",
    celltype_col="cell_type",
)

# Store dataset info for benchmarks
dataset_info = {
    'n_cells_full': adata.n_obs,
    'n_cells_paired': adata_paired.n_obs,
    'n_genes': adata.n_vars,
    'n_participants_full': adata.obs['participant_id'].nunique(),
    'n_participants_paired': len(paired_ids),
    'n_paired_responder': paired_by_response.get('Responder', 0),
    'n_paired_nonresponder': paired_by_response.get('Non-responder', 0),
}

benchmark_results['dataset_info'] = dataset_info

design
[3]:
TrialDesign(participant_col='participant_id', visit_col='visit', arm_col='response_harmonized', arm_treated='Responder', arm_control='Non-responder', celltype_col='cell_type', crossover_col=None, baseline_visit=None, followup_visit=None)
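The pairing logic in the cell above (a participant is "paired" when both Pre and Post visits are present) reduces to a per-participant set-membership check. A toy sketch with a hypothetical `obs` table:

```python
import pandas as pd

# Hypothetical cell-level metadata: one row per cell
obs = pd.DataFrame({
    "participant_id": ["P1", "P1", "P2", "P3", "P3", "P3"],
    "visit":          ["Pre", "Post", "Pre", "Pre", "Post", "Post"],
})

# Collect the set of visits observed for each participant
visit_sets = obs.groupby("participant_id")["visit"].apply(set)

# Paired = both a Pre and a Post sample exist
is_paired = visit_sets.apply(lambda s: {"Pre", "Post"} <= s)
paired_ids = set(is_paired[is_paired].index)
print(sorted(paired_ids))  # P1 and P3 are paired; P2 has Pre only
```

Because the check operates on visit *sets*, duplicate cells per visit (as in P3) do not inflate the pairing decision.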

2. Dataset Overview#

Here, we summarize the dataset characteristics and benchmark configuration in tables, then visualize key distributions to ground the benchmarking context.

[4]:
# =============================================================================
# DATASET SUMMARY TABLE
# =============================================================================
print("=" * 70)
print("DATASET SUMMARY")
print("=" * 70)

summary_df = pd.DataFrame([
    ("Full Dataset - Cells", f"{dataset_info['n_cells_full']:,}"),
    ("Full Dataset - Genes", f"{dataset_info['n_genes']:,}"),
    ("Full Dataset - Participants", f"{dataset_info['n_participants_full']}"),
    ("Paired Dataset - Cells", f"{dataset_info['n_cells_paired']:,}"),
    ("Paired Dataset - Participants", f"{dataset_info['n_participants_paired']}"),
    ("  Responders (paired)", f"{dataset_info['n_paired_responder']}"),
    ("  Non-responders (paired)", f"{dataset_info['n_paired_nonresponder']}"),
    ("Visits", f"{visits}"),
    ("Mixed-response participants", f"{(adata.obs.groupby('participant_id')['response'].nunique() > 1).sum()}"),
], columns=["Metric", "Value"])

display(summary_df.style.hide(axis='index'))

# =============================================================================
# BENCHMARK CONFIGURATION TABLE
# =============================================================================
config_df = pd.DataFrame([
    ("Scaling gene counts", f"{SCALING_GENE_COUNTS}"),
    ("Cell fractions", f"{SCALING_CELL_FRACTIONS}"),
    ("Repeats per test", f"{N_BENCHMARK_REPEATS}"),
    ("Min genes for scoring", f"{MIN_GENES_FOR_SCORE}"),
    ("Min paired per arm", f"{MIN_PAIRED_PER_ARM}"),
    ("FDR threshold", f"{FDR_ALPHA}"),
    ("Max cells/participant\u2011visit", f"{MAX_CELLS_PER_PARTICIPANT_VISIT}"),
], columns=["Parameter", "Value"])

print("\nBENCHMARK CONFIGURATION")
print("-" * 50)
display(config_df.style.hide(axis='index'))

# =============================================================================
# COMPREHENSIVE VISUALIZATION
# =============================================================================
fig, axes = plt.subplots(2, 4, figsize=(18, 10))

# 1. Cells per Response x Visit
ax1 = axes[0, 0]
cell_counts = adata_paired.obs.groupby([RESPONSE_COL, visit_col], observed=True).size().unstack(fill_value=0)
cell_counts.plot(kind="bar", ax=ax1, color=['#3498db', '#e74c3c'], edgecolor='black')
ax1.set_title("Cells by Response \u00d7 Visit", fontsize=11, fontweight='bold')
ax1.set_xlabel("Response")
ax1.set_ylabel("Number of Cells")
ax1.legend(title="Visit")
ax1.tick_params(axis='x', rotation=0)
for container in ax1.containers:
    ax1.bar_label(container, fmt='%d', fontsize=8)

# 2. Participants per Response x Visit
ax2 = axes[0, 1]
pt_counts = adata_paired.obs.groupby([RESPONSE_COL, visit_col], observed=True)["participant_id"].nunique().unstack(fill_value=0)
pt_counts.plot(kind="bar", ax=ax2, color=['#3498db', '#e74c3c'], edgecolor='black')
ax2.set_title("Participants by Response \u00d7 Visit", fontsize=11, fontweight='bold')
ax2.set_xlabel("Response")
ax2.set_ylabel("Number of Participants")
ax2.legend(title="Visit")
ax2.tick_params(axis='x', rotation=0)
for container in ax2.containers:
    ax2.bar_label(container, fmt='%d', fontsize=9)

# 3. Cells per Participant (strip plot — better than histogram for n=10)
ax3 = axes[0, 2]
cells_per_pt = adata_paired.obs.groupby("participant_id").size().reset_index(name="n_cells")
cells_per_pt[RESPONSE_COL] = cells_per_pt["participant_id"].map(
    adata_paired.obs.groupby("participant_id")[RESPONSE_COL].first()
)
sns.stripplot(data=cells_per_pt, x=RESPONSE_COL, y="n_cells", ax=ax3,
              palette={'Responder': '#2ecc71', 'Non-responder': '#e74c3c'},
              size=8, jitter=0.15, edgecolor='black', linewidth=0.5)
ax3.axhline(cells_per_pt["n_cells"].mean(), color='gray', linestyle='--', alpha=0.5,
            label=f'Mean: {cells_per_pt["n_cells"].mean():.0f}')
ax3.set_title("Cells per Participant", fontsize=11, fontweight='bold')
ax3.set_ylabel("Number of Cells")
ax3.legend(fontsize=8)

# 4. Cells per Participant-Visit Heatmap
ax4 = axes[0, 3]
cells_pv = adata_paired.obs.groupby(["participant_id", visit_col], observed=True).size().unstack(fill_value=0)
sns.heatmap(cells_pv, annot=True, fmt='d', cmap='YlOrRd', ax=ax4, cbar_kws={'label': 'Cells'})
ax4.set_title("Cells per Participant \u00d7 Visit", fontsize=11, fontweight='bold')
ax4.set_xlabel("Visit")
ax4.set_ylabel("Participant")

# 5. Response Balance
ax5 = axes[1, 0]
response_counts = pd.Series(paired_by_response)
colors = ['#e74c3c' if x == 'Non-responder' else '#2ecc71' for x in response_counts.index]
bars = ax5.bar(response_counts.index, response_counts.values, color=colors, edgecolor='black')
ax5.set_title("Paired Participants by Response", fontsize=11, fontweight='bold')
ax5.set_ylabel("Number of Participants")
ax5.bar_label(bars, fmt='%d', fontsize=11)

# 6. Data Sparsity (bar instead of pie — cleaner for 2-category)
ax6 = axes[1, 1]
if sp.issparse(adata_paired.X):
    sparsity = 1.0 - (adata_paired.X.nnz / (adata_paired.X.shape[0] * adata_paired.X.shape[1]))
else:
    sparsity = np.mean(adata_paired.X == 0)
bar_sp = ax6.bar(["Zero", "Non-zero"], [sparsity * 100, (1 - sparsity) * 100],
                 color=['#ecf0f1', '#3498db'], edgecolor='black')
ax6.bar_label(bar_sp, fmt='%.1f%%', fontsize=10)
ax6.set_ylabel("Percentage")
ax6.set_title("Expression Matrix Sparsity", fontsize=11, fontweight='bold')
ax6.set_ylim(0, 105)

# 7. Gene Detection per Cell
ax7 = axes[1, 2]
if sp.issparse(adata_paired.X):
    genes_per_cell = np.array((adata_paired.X > 0).sum(axis=1)).flatten()
else:
    genes_per_cell = np.sum(adata_paired.X > 0, axis=1)
ax7.hist(genes_per_cell, bins=30, color='#9b59b6', edgecolor='black', alpha=0.7)
ax7.axvline(np.mean(genes_per_cell), color='red', linestyle='--', label=f'Mean: {np.mean(genes_per_cell):.0f}')
ax7.set_title("Genes Detected per Cell", fontsize=11, fontweight='bold')
ax7.set_xlabel("Number of Genes")
ax7.set_ylabel("Frequency")
ax7.legend(fontsize=8)

# 8. Memory Footprint (log scale for readability)
ax8 = axes[1, 3]
if sp.issparse(adata_paired.X):
    matrix_mb = adata_paired.X.data.nbytes / 1024 / 1024
else:
    matrix_mb = adata_paired.X.nbytes / 1024 / 1024
obs_mb = adata_paired.obs.memory_usage(deep=True).sum() / 1024 / 1024
var_mb = adata_paired.var.memory_usage(deep=True).sum() / 1024 / 1024

mem_data = pd.Series({'Expression\nMatrix': matrix_mb, 'Observations\n(obs)': obs_mb, 'Variables\n(var)': var_mb})
mem_data.plot(kind='barh', ax=ax8, color=['#3498db', '#2ecc71', '#e74c3c'], edgecolor='black', log=True)
ax8.set_title("Memory Footprint", fontsize=11, fontweight='bold')
ax8.set_xlabel("Memory (MB, log scale)")
for i, v in enumerate(mem_data):
    ax8.text(v * 1.3, i, f'{v:.1f} MB', va='center', fontsize=9)

plt.suptitle("Dataset Overview for Benchmark Analysis", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"\nTotal estimated memory: {matrix_mb + obs_mb + var_mb:.1f} MB")

======================================================================
DATASET SUMMARY
======================================================================
Metric Value
Full Dataset - Cells 12,187
Full Dataset - Genes 55,737
Full Dataset - Participants 25
Paired Dataset - Cells 7,068
Paired Dataset - Participants 10
Responders (paired) 3
Non-responders (paired) 7
Visits ['Pre', 'Post']
Mixed-response participants 3

BENCHMARK CONFIGURATION
--------------------------------------------------
Parameter Value
Scaling gene counts [100, 500, 1000, 2000, 5000]
Cell fractions [0.25, 0.5, 0.75, 1.0]
Repeats per test 3
Min genes for scoring 5
Min paired per arm 3
FDR threshold 0.25
Max cells/participant‑visit 500
../_images/tutorials_stress_test_real_scale_8_4.png

Total estimated memory: 1509.0 MB
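The sparsity panel above computes the zero fraction from `nnz` on the sparse path and from an elementwise comparison on the dense path. A standalone check that the two formulas agree (noting that `nnz` counts *stored* entries, which equal true nonzeros for matrices built from dense data):

```python
import numpy as np
import scipy.sparse as sp

dense = np.array([[0.0, 1.0, 0.0],
                  [2.0, 0.0, 0.0]])
X = sp.csr_matrix(dense)

# Sparse path: 1 - stored-nonzeros / total entries
sparsity_sparse = 1.0 - X.nnz / (X.shape[0] * X.shape[1])

# Dense path: fraction of exact zeros
sparsity_dense = float(np.mean(dense == 0))

print(f"{sparsity_sparse:.3f}")  # 4 of 6 entries are zero
```

If a sparse matrix carries explicit stored zeros (e.g. after in-place edits), the two paths can diverge; calling `X.eliminate_zeros()` first restores agreement.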

3. Gene Signature Scoring#

Score biologically relevant gene signatures for downstream analysis.

[5]:
# =============================================================================
# DEFINE GENE SIGNATURES
# =============================================================================
available = set(adata_paired.var_names)

raw_signatures = {
    "Cytotoxicity": ["GZMB", "GZMA", "GZMH", "GZMK", "PRF1", "GNLY", "IFNG", "NKG7", "KLRD1", "KLRB1", "FASLG"],
    "Exhaustion": ["PDCD1", "LAG3", "HAVCR2", "TIGIT", "CTLA4", "TOX", "ENTPD1", "CXCL13", "EOMES"],
    "IFN_Response": ["ISG15", "IFI6", "IFIT1", "IFIT2", "IFIT3", "MX1", "MX2", "STAT1", "OAS1", "IRF7"],
    "Memory": ["IL7R", "TCF7", "LEF1", "CCR7", "SELL", "CD27", "CD28"],
    "Activation": ["CD69", "CD38", "HLA-DRA", "ICOS", "CD44", "IL2RA"],
}

# Filter signatures by gene availability
print("Signature Gene Coverage:")
print("-" * 50)
filtered_signatures = {}
coverage_data = []
for name, genes in raw_signatures.items():
    found = [g for g in genes if g in available]
    pct = len(found) / len(genes) * 100
    status = "OK" if len(found) >= MIN_GENES_FOR_SCORE else "SKIP"
    coverage_data.append({
        'Signature': name,
        'Found': len(found),
        'Total': len(genes),
        'Coverage': f"{pct:.0f}%",
        'Status': status
    })
    if len(found) >= MIN_GENES_FOR_SCORE:
        filtered_signatures[name] = found

coverage_df = pd.DataFrame(coverage_data)
display(coverage_df.style.hide(axis='index'))

# =============================================================================
# SCORE SIGNATURES WITH TIMING
# =============================================================================
if filtered_signatures:
    def score_signatures():
        return st.score_gene_sets(
            adata_paired,
            filtered_signatures,
            layer="log1p_tpm",
            method="zmean",
            prefix="sig_",
        )

    _, score_time, score_std = timed_run(score_signatures, n_repeats=N_BENCHMARK_REPEATS)
    benchmark_results['timing']['signature_scoring'] = score_time

    print(f"\nSignature scoring time: {format_time(score_time)} ± {format_time(score_std)}")
    print(f"Signatures scored: {len(filtered_signatures)}")

signature_cols = [c for c in adata_paired.obs.columns if c.startswith("sig_")]

# Visualize signature distributions
if signature_cols:
    fig, axes = plt.subplots(2, 3, figsize=(14, 8))
    axes = axes.flatten()

    for i, col in enumerate(signature_cols[:6]):
        ax = axes[i]
        for response in adata_paired.obs[RESPONSE_COL].unique():
            mask = adata_paired.obs[RESPONSE_COL] == response
            data = adata_paired.obs.loc[mask, col].dropna()
            ax.hist(data, bins=30, alpha=0.5, label=response, density=True)
        ax.set_title(col.replace('sig_', ''), fontweight='bold')
        ax.set_xlabel('Score')
        ax.set_ylabel('Density')
        ax.legend()

    # Hide unused axes
    for j in range(len(signature_cols), 6):
        axes[j].axis('off')

    plt.suptitle("Signature Score Distributions by Response", fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
Signature Gene Coverage:
--------------------------------------------------
Signature Found Total Coverage Status
Cytotoxicity 11 11 100% OK
Exhaustion 9 9 100% OK
IFN_Response 10 10 100% OK
Memory 7 7 100% OK
Activation 6 6 100% OK

Signature scoring time: 11.2ms ± 2.7ms
Signatures scored: 5
../_images/tutorials_stress_test_real_scale_10_3.png
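`st.score_gene_sets` with `method="zmean"` presumably z-scores each signature gene across cells and averages within the set; the exact sctrial implementation may differ, but the general pattern is a short NumPy computation. A sketch on toy data (all shapes and the zero-variance guard are illustrative assumptions, not sctrial internals):

```python
import numpy as np

rng = np.random.default_rng(0)
expr = rng.normal(size=(100, 6))   # cells x genes (toy log-expression)
gene_idx = [0, 2, 5]               # columns belonging to one signature

# Z-score each gene across cells, then average the signature's genes per cell
mu = expr.mean(axis=0)
sd = expr.std(axis=0)
sd[sd == 0] = 1.0                  # guard against constant genes
z = (expr - mu) / sd
scores = z[:, gene_idx].mean(axis=1)

print(scores.shape)                # one score per cell
```

Because each gene is centered before averaging, the scores are comparable across signatures of different sizes and mean zero across cells by construction.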
[6]:
# =============================================================================
# PAIRING VERIFICATION
# =============================================================================
print("Pairing Verification")
print("-" * 50)

if signature_cols:
    df_pv = (
        adata_paired.obs
        .groupby([design.participant_col, design.visit_col, design.arm_col], observed=True)[signature_cols]
        .mean()
        .reset_index()
    )

    valid_paired = {}
    for feat in signature_cols:
        wide = df_pv.pivot(index=design.participant_col, columns=design.visit_col, values=feat)
        if visits[0] in wide.columns and visits[1] in wide.columns:
            mask = wide[visits[0]].notna() & wide[visits[1]].notna()
            valid_paired[feat] = set(wide[mask].index)
        else:
            valid_paired[feat] = set()

    all_features_valid = set.intersection(*(valid_paired[f] for f in signature_cols))
    participant_arm = adata_paired.obs.groupby(design.participant_col)[design.arm_col].first()

    VALID_PAIRED_BY_RESPONSE = {
        arm: {pid for pid in all_features_valid if participant_arm.get(pid) == arm}
        for arm in [design.arm_treated, design.arm_control]
    }
    N_VALID_PAIRED = len(all_features_valid)

    print(f"Participants with valid Pre+Post scores: {N_VALID_PAIRED}")
    print(f"  {design.arm_treated}: {len(VALID_PAIRED_BY_RESPONSE[design.arm_treated])}")
    print(f"  {design.arm_control}: {len(VALID_PAIRED_BY_RESPONSE[design.arm_control])}")
else:
    # Fallback when no signatures were scored: treat all paired participants as valid,
    # storing ID sets per arm to match the structure built in the branch above
    participant_arm = adata_paired.obs.groupby(design.participant_col)[design.arm_col].first()
    VALID_PAIRED_BY_RESPONSE = {
        arm: {pid for pid in paired_ids if participant_arm.get(pid) == arm}
        for arm in [design.arm_treated, design.arm_control]
    }
    N_VALID_PAIRED = len(paired_ids)
Pairing Verification
--------------------------------------------------
Participants with valid Pre+Post scores: 10
  Responder: 3
  Non-responder: 7

4. Scaling Benchmarks#

Here, we benchmark how sctrial scales with dataset size. We vary gene counts and cell fractions to stress‑test performance.

  • Cell fractions refer to keeping a fixed proportion of cells (e.g., 25%, 50%, 75%, 100%) via stratified subsampling within each participant‑visit, preserving pairing and structure.

  • Method comparison evaluates the same study design (Responder vs Non‑responder; Pre vs Post) under different aggregation and analysis strategies.
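The difference-in-differences estimand underlying `did_table` is the treated arm's Pre→Post change minus the control arm's change, computed here on participant-visit means. A toy numeric sketch of that arithmetic (illustrative only, not sctrial's implementation):

```python
import pandas as pd

# Hypothetical participant-visit means for one feature
df = pd.DataFrame({
    "participant": ["R1", "R1", "R2", "R2", "N1", "N1", "N2", "N2"],
    "arm":   ["Responder"] * 4 + ["Non-responder"] * 4,
    "visit": ["Pre", "Post"] * 4,
    "value": [1.0, 2.0, 1.5, 3.0, 1.0, 1.2, 2.0, 2.1],
})

# Per-participant Pre -> Post change
wide = df.pivot(index="participant", columns="visit", values="value")
wide["delta"] = wide["Post"] - wide["Pre"]
arm = df.groupby("participant")["arm"].first()

# DiD = mean change (treated) - mean change (control)
did = wide.loc[arm == "Responder", "delta"].mean() - \
      wide.loc[arm == "Non-responder", "delta"].mean()
print(did)  # (1.0 + 1.5)/2 - (0.2 + 0.1)/2 = 1.25 - 0.15 = 1.1
```

Aggregating to participant-visit means first is what makes the participant, not the cell, the unit of analysis, which is why the `participant_visit` and `cell` aggregation modes compared below can give different inference.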

[7]:
# =============================================================================
# SCALING BENCHMARK: Genes
# =============================================================================
print("=" * 70)
print("GENE SCALING BENCHMARK")
print("=" * 70)
print(f"Testing DiD performance across {len(SCALING_GENE_COUNTS)} gene counts...")

gene_scaling_results = []

for n_genes in SCALING_GENE_COUNTS:
    if n_genes > adata_paired.n_vars:
        print(f"  Skipping {n_genes} genes (exceeds available {adata_paired.n_vars})")
        continue

    # Select genes
    _test_genes = list(adata_paired.var_names[:n_genes])

    # Benchmark DiD (serial) — capture loop variable via default arg
    def run_did(genes=_test_genes):
        return st.did_table(
            adata_paired,
            features=genes,
            design=design,
            visits=tuple(visits),
            aggregate="participant_visit",
            layer="log1p_tpm",
        )

    # Benchmark DiD (parallel)
    def run_did_parallel(genes=_test_genes):
        return st.did_table_parallel(
            adata_paired,
            features=genes,
            design=design,
            visits=tuple(visits),
            aggregate="participant_visit",
            layer="log1p_tpm",
            n_jobs=PARALLEL_JOBS,
        )

    _, mean_time, std_time = timed_run(run_did, n_repeats=N_BENCHMARK_REPEATS)
    _, mean_time_par, std_time_par = timed_run(run_did_parallel, n_repeats=N_BENCHMARK_REPEATS)

    gene_scaling_results.append({
        'n_genes': n_genes,
        'mean_time': mean_time,
        'std_time': std_time,
        'genes_per_sec': n_genes / mean_time,
        'mean_time_parallel': mean_time_par,
        'std_time_parallel': std_time_par,
        'genes_per_sec_parallel': n_genes / mean_time_par,
    })

    print(f"  {n_genes:,} genes: serial {format_time(mean_time)} ± {format_time(std_time)} "
          f"({n_genes/mean_time:.0f} genes/sec) | parallel {format_time(mean_time_par)} ± {format_time(std_time_par)} "
          f"({n_genes/mean_time_par:.0f} genes/sec)")

    # Free intermediate results between iterations to limit memory growth
    del _test_genes
    gc.collect()

gene_scaling_df = pd.DataFrame(gene_scaling_results)
benchmark_results['scaling'].append(('genes', gene_scaling_df))

# =============================================================================
# SCALING BENCHMARK: Cells
# =============================================================================
print("\n" + "=" * 70)
print("CELL SCALING BENCHMARK")
print("=" * 70)
print(f"Testing DiD performance across {len(SCALING_CELL_FRACTIONS)} cell fractions...")
print("Cell fractions keep a fixed proportion of cells per participant-visit (stratified), preserving pairing.")

cell_scaling_results = []
n_test_genes = 500  # Fixed gene count for cell scaling test

for frac in SCALING_CELL_FRACTIONS:
    n_cells = int(adata_paired.n_obs * frac)

    # Subsample cells (stratified by participant-visit)
    if frac < 1.0:
        sampled_idx = (
            adata_paired.obs
            .groupby([design.participant_col, design.visit_col], observed=True)
            .apply(lambda x: x.sample(frac=frac, random_state=SEED) if len(x) > 1 else x)
            .index.get_level_values(-1)
        )
        adata_sub = adata_paired[sampled_idx].copy()
    else:
        adata_sub = adata_paired

    _test_genes_cell = list(adata_sub.var_names[:n_test_genes])

    # Capture loop variables via default args
    def run_did_cells(ad=adata_sub, genes=_test_genes_cell):
        return st.did_table(
            ad,
            features=genes,
            design=design,
            visits=tuple(visits),
            aggregate="participant_visit",
            layer="log1p_tpm",
        )

    _, mean_time, std_time = timed_run(run_did_cells, n_repeats=N_BENCHMARK_REPEATS)

    cell_scaling_results.append({
        'cell_fraction': frac,
        'n_cells': adata_sub.n_obs,
        'mean_time': mean_time,
        'std_time': std_time,
    })

    print(f"  {frac*100:.0f}% ({adata_sub.n_obs:,} cells): {format_time(mean_time)} ± {format_time(std_time)}")

    # Free subsampled AnnData between iterations
    if frac < 1.0:
        del adata_sub, sampled_idx
    del _test_genes_cell
    gc.collect()

cell_scaling_df = pd.DataFrame(cell_scaling_results)
benchmark_results['scaling'].append(('cells', cell_scaling_df))

# =============================================================================
# SCALING VISUALIZATION  (publication-quality 2x2 panel)
# =============================================================================
fig, axes = plt.subplots(2, 2, figsize=(14, 11))

# ---- Colour palette & common style ----
C_SERIAL   = '#2c7bb6'
C_PARALLEL = '#d7191c'
C_CELL     = '#fdae61'
C_THRU_S   = '#2c7bb6'
C_THRU_P   = '#d7191c'
MARKER_KW  = dict(markersize=7, markeredgecolor='black', markeredgewidth=0.6)

# ---- (a) Gene-scaling: log-log with error bands ----
ax = axes[0, 0]
x_g = gene_scaling_df['n_genes'].values
y_s = gene_scaling_df['mean_time'].values
e_s = gene_scaling_df['std_time'].values

ax.fill_between(x_g, np.maximum(y_s - e_s, y_s * 0.5), y_s + e_s, color=C_SERIAL, alpha=0.12, zorder=1)
ax.plot(x_g, y_s, 'o-', color=C_SERIAL, linewidth=2, label='Serial', **MARKER_KW)

if 'mean_time_parallel' in gene_scaling_df.columns:
    y_p = gene_scaling_df['mean_time_parallel'].values
    e_p = gene_scaling_df['std_time_parallel'].values
    ax.fill_between(x_g, np.maximum(y_p - e_p, y_p * 0.5), y_p + e_p, color=C_PARALLEL, alpha=0.12, zorder=2)
    ax.plot(x_g, y_p, 's--', color=C_PARALLEL, linewidth=2,
            label=f'Parallel (n_jobs={PARALLEL_JOBS})', **MARKER_KW)

# Log-log + linear-scaling reference line
ax.set_xscale('log'); ax.set_yscale('log')
x_ref = np.array([x_g.min(), x_g.max()])
scale = y_s[0] / x_g[0]
ax.plot(x_ref, x_ref * scale, ':', color='grey', alpha=0.6, label='O(n) reference')
ax.legend(fontsize=9, framealpha=0.9)
ax.set_xlabel('Number of Genes', fontsize=11)
ax.set_ylabel('Runtime (s)', fontsize=11)
ax.set_title('(a)  Gene Scaling (log-log)', fontsize=12, fontweight='bold')
ax.grid(True, which='both', alpha=0.2)

# ---- (b) Cell-scaling: log-log with error band ----
ax = axes[0, 1]
x_c = cell_scaling_df['n_cells'].values
y_c = cell_scaling_df['mean_time'].values
e_c = cell_scaling_df['std_time'].values

ax.fill_between(x_c, np.maximum(y_c - e_c, y_c * 0.5), y_c + e_c, color=C_CELL, alpha=0.25)
ax.plot(x_c, y_c, 'D-', color=C_CELL, linewidth=2, label='500 genes', **MARKER_KW)

ax.set_xscale('log'); ax.set_yscale('log')
scale_c = y_c[0] / x_c[0]
ax.plot([x_c.min(), x_c.max()],
        [x_c.min() * scale_c, x_c.max() * scale_c],
        ':', color='grey', alpha=0.6, label='O(n) reference')
ax.legend(fontsize=9, framealpha=0.9)
ax.set_xlabel('Number of Cells', fontsize=11)
ax.set_ylabel('Runtime (s)', fontsize=11)
ax.set_title('(b)  Cell Scaling (log-log)', fontsize=12, fontweight='bold')
ax.grid(True, which='both', alpha=0.2)

# ---- (c) Throughput: serial vs parallel grouped lollipop ----
ax = axes[1, 0]
labels = gene_scaling_df['n_genes'].astype(str).values
y_pos = np.arange(len(labels))
width = 0.35

thru_s = gene_scaling_df['genes_per_sec'].values
ax.hlines(y_pos + width/2, 0, thru_s, color=C_THRU_S, linewidth=2)
ax.plot(thru_s, y_pos + width/2, 'o', color=C_THRU_S, markersize=8,
        markeredgecolor='black', markeredgewidth=0.6, label='Serial')

if 'genes_per_sec_parallel' in gene_scaling_df.columns:
    thru_p = gene_scaling_df['genes_per_sec_parallel'].values
    ax.hlines(y_pos - width/2, 0, thru_p, color=C_THRU_P, linewidth=2)
    ax.plot(thru_p, y_pos - width/2, 's', color=C_THRU_P, markersize=8,
            markeredgecolor='black', markeredgewidth=0.6, label='Parallel')

ax.set_yticks(y_pos)
ax.set_yticklabels([f'{int(x):,}' for x in gene_scaling_df['n_genes']])
ax.set_ylabel('Number of Genes', fontsize=11)
ax.set_xlabel('Throughput (genes / s)', fontsize=11)
ax.set_title('(c)  Throughput', fontsize=12, fontweight='bold')
ax.legend(fontsize=9, framealpha=0.9)
ax.grid(True, axis='x', alpha=0.2)

# ---- (d) Parallel speed-up factor ----
ax = axes[1, 1]
if 'mean_time_parallel' in gene_scaling_df.columns:
    speedup = gene_scaling_df['mean_time'] / gene_scaling_df['mean_time_parallel']
    ax.plot(x_g, speedup, 'o-', color='#756bb1', linewidth=2, **MARKER_KW)
    ax.axhline(1.0, color='grey', linestyle='--', alpha=0.5, label='No speed-up')
    ax.annotate('Overhead\ndominates',
                xy=(x_g[0], speedup.iloc[0]),
                xytext=(x_g[0] * 2.5, speedup.iloc[0] - 0.05),
                fontsize=8, color='grey', fontstyle='italic',
                arrowprops=dict(arrowstyle='->', color='grey', lw=0.8))
    ax.axhline(PARALLEL_JOBS, color='#d7191c', linestyle=':', alpha=0.5,
               label=f'Ideal ({PARALLEL_JOBS}\u00d7)')
    ax.set_xscale('log')
    ax.set_xlabel('Number of Genes', fontsize=11)
    ax.set_ylabel('Speed-up (serial / parallel)', fontsize=11)
    ax.set_title('(d)  Parallel Speed-up', fontsize=12, fontweight='bold')
    ax.legend(fontsize=9, framealpha=0.9)
    ax.grid(True, which='both', alpha=0.2)
else:
    ax.text(0.5, 0.5, 'Parallel data\nnot available',
            ha='center', va='center', transform=ax.transAxes, fontsize=12, color='grey')
    ax.set_title('(d)  Parallel Speed-up', fontsize=12, fontweight='bold')

plt.suptitle('Scaling Benchmarks', fontsize=14, fontweight='bold', y=1.01)
plt.tight_layout()
plt.show()
======================================================================
GENE SCALING BENCHMARK
======================================================================
Testing DiD performance across 5 gene counts...
  100 genes: serial 690.2ms ± 10.0ms (145 genes/sec) | parallel 1.37s ± 1.11s (73 genes/sec)
  500 genes: serial 1.64s ± 37.7ms (304 genes/sec) | parallel 1.11s ± 36.7ms (450 genes/sec)
  1,000 genes: serial 2.76s ± 29.5ms (362 genes/sec) | parallel 1.73s ± 17.5ms (577 genes/sec)
  2,000 genes: serial 5.05s ± 19.5ms (396 genes/sec) | parallel 2.92s ± 32.8ms (686 genes/sec)
  5,000 genes: serial 12.01s ± 73.5ms (416 genes/sec) | parallel 6.77s ± 38.1ms (738 genes/sec)

======================================================================
CELL SCALING BENCHMARK
======================================================================
Testing DiD performance across 4 cell fractions...
Cell fractions keep a fixed proportion of cells per participant-visit (stratified), preserving pairing.
  25% (1,767 cells): 1.21s ± 5.5ms
  50% (3,534 cells): 1.37s ± 20.8ms
  75% (5,301 cells): 1.48s ± 4.1ms
  100% (7,068 cells): 1.60s ± 15.5ms
../_images/tutorials_stress_test_real_scale_13_1.png
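
The cell-scaling benchmark keeps a fixed fraction of cells within each participant-visit group so that pairing is preserved. A minimal sketch of such stratified subsampling (the `participant`/`visit` column names and the helper itself are illustrative assumptions, not part of the sctrial API):

```python
import numpy as np
import pandas as pd

def stratified_cell_subsample(obs: pd.DataFrame, frac: float,
                              group_cols=("participant", "visit"),
                              seed: int = 42) -> np.ndarray:
    """Keep `frac` of cells within each participant-visit stratum (>=1 per stratum)."""
    rng = np.random.default_rng(seed)
    keep = []
    # GroupBy.indices maps each group key to the positional indices of its rows
    for _, positions in obs.groupby(list(group_cols)).indices.items():
        n_keep = max(1, int(round(frac * len(positions))))
        keep.extend(rng.choice(positions, size=n_keep, replace=False))
    return np.sort(np.asarray(keep))
```

Subsetting with `adata[stratified_cell_subsample(adata.obs, 0.5)]` would then reproduce the 50% fraction while keeping every participant-visit represented.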
[8]:
# =============================================================================
# METHOD COMPARISON BENCHMARK
# =============================================================================
print("=" * 70)
print("METHOD COMPARISON BENCHMARK")
print("=" * 70)
print("This compares: (1) DiD aggregation modes (participant_visit vs cell),")
print("(2) cross-sectional between-arm contrasts at each visit,")
print("and (3) within-arm paired changes. All use the same trial design and feature set.")



method_results = []

# Test different aggregation strategies
agg_modes = ["participant_visit", "cell"]
test_features = signature_cols if signature_cols else list(adata_paired.var_names[:100])

print("\nAggregation Mode Comparison:")
print("-" * 50)

for agg_mode in agg_modes:
    def run_agg_test(_agg=agg_mode):
        return st.did_table(
            adata_paired,
            features=test_features,
            design=design,
            visits=tuple(visits),
            aggregate=_agg,
            layer="log1p_tpm",
        )

    try:
        result, mean_time, std_time = timed_run(run_agg_test, n_repeats=N_BENCHMARK_REPEATS)
        n_results = len(result) if result is not None else 0
        method_results.append({
            'method': f'DiD ({agg_mode})',
            'mean_time': mean_time,
            'std_time': std_time,
            'n_results': n_results
        })
        print(f"  {agg_mode}: {format_time(mean_time)} ± {format_time(std_time)}")
    except Exception as e:
        print(f"  {agg_mode}: FAILED - {e}")
    gc.collect()

# Test cross-sectional comparisons
print("\nCross-sectional Comparison:")
print("-" * 50)

for v in visits:
    def run_cross(_v=v):
        return st.between_arm_comparison(
            adata_paired,
            visit=_v,
            features=test_features,
            design=design,
            aggregate="participant_visit",
            layer="log1p_tpm",
        )

    result, mean_time, std_time = timed_run(run_cross, n_repeats=N_BENCHMARK_REPEATS)
    method_results.append({
        'method': f'Cross-sectional ({v})',
        'mean_time': mean_time,
        'std_time': std_time,
        'n_results': len(result) if result is not None else 0
    })
    print(f"  {v}: {format_time(mean_time)} ± {format_time(std_time)}")
    gc.collect()

# Test within-arm comparisons
print("\nWithin-arm Comparison:")
print("-" * 50)

for arm in [design.arm_treated, design.arm_control]:
    if paired_by_response.get(arm, 0) < MIN_PAIRED_PER_ARM:
        print(f"  {arm}: SKIPPED (insufficient participants)")
        continue

    def run_within(_arm=arm):
        return st.within_arm_comparison(
            adata_paired,
            arm=_arm,
            features=test_features,
            design=design,
            visits=tuple(visits),
            aggregate="participant_visit",
            layer="log1p_tpm",
        )

    result, mean_time, std_time = timed_run(run_within, n_repeats=N_BENCHMARK_REPEATS)
    method_results.append({
        'method': f'Within-arm ({arm})',
        'mean_time': mean_time,
        'std_time': std_time,
        'n_results': len(result) if result is not None else 0
    })
    print(f"  {arm}: {format_time(mean_time)} ± {format_time(std_time)}")
    gc.collect()

method_df = pd.DataFrame(method_results)
benchmark_results['timing']['methods'] = method_df

# Visualization — horizontal lollipop chart with error bars
fig, ax = plt.subplots(figsize=(10, 5))
y_pos = np.arange(len(method_df))
palette = plt.cm.Set2(np.linspace(0, 1, len(method_df)))

ax.hlines(y_pos, 0, method_df['mean_time'], color='#aaaaaa', linewidth=1.5)
ax.errorbar(method_df['mean_time'], y_pos,
            xerr=method_df['std_time'], fmt='none', ecolor='#555555',
            capsize=4, capthick=1.2, zorder=3)
ax.scatter(method_df['mean_time'], y_pos, c=palette, s=120,
           edgecolors='black', linewidths=0.8, zorder=4)

for i, (t, s) in enumerate(zip(method_df['mean_time'], method_df['std_time'])):
    ax.text(t + s + 0.02, i, format_time(t), va='center', fontsize=10, fontweight='bold')

ax.set_yticks(y_pos)
ax.set_yticklabels(method_df['method'], fontsize=10)
ax.set_xlabel('Runtime (seconds)', fontsize=11)
ax.set_title('Method Comparison: Runtime', fontsize=13, fontweight='bold')
ax.grid(True, axis='x', alpha=0.2)
ax.set_axisbelow(True)
ax.margins(x=0.15)
plt.tight_layout()
plt.show()
======================================================================
METHOD COMPARISON BENCHMARK
======================================================================
This compares: (1) DiD aggregation modes (participant_visit vs cell),
(2) cross-sectional between-arm contrasts at each visit,
and (3) within-arm paired changes. All use the same trial design and feature set.

Aggregation Mode Comparison:
--------------------------------------------------
  participant_visit: 478.9ms ± 1.3ms
  cell: 558.2ms ± 14.2ms

Cross-sectional Comparison:
--------------------------------------------------
  Pre: 223.1ms ± 17.5ms
  Post: 277.9ms ± 2.4ms

Within-arm Comparison:
--------------------------------------------------
/var/folders/71/dc4p4yz15s74z9c69xy6sk_00000gt/T/ipykernel_52935/988223502.py:82: UserWarning: Only 3 clusters (participants) available. Cluster-robust standard errors are unreliable with fewer than 10 clusters. Consider using use_bootstrap=True for more reliable p-values.
  return st.within_arm_comparison(
  Responder: 288.2ms ± 16.5ms
/var/folders/71/dc4p4yz15s74z9c69xy6sk_00000gt/T/ipykernel_52935/988223502.py:82: UserWarning: Only 7 clusters (participants) available. Cluster-robust standard errors are unreliable with fewer than 10 clusters. Consider using use_bootstrap=True for more reliable p-values.
  return st.within_arm_comparison(
  Non-responder: 652.2ms ± 9.8ms
../_images/tutorials_stress_test_real_scale_14_6.png

5. Memory Profiling & DiD Analysis#

This section quantifies memory usage for key operations and then runs a large-scale DiD analysis on the first 2,000 genes. The goal is to measure resource footprint and end‑to‑end runtime under realistic conditions.
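
The `memory_profile` helper used below reports resident-set-size (RSS) deltas via `psutil`; RSS is a coarse, whole-process measure, so small allocations can legitimately show a delta of 0 MB. A minimal sketch of such a helper (an assumption about its implementation, not sctrial code):

```python
import gc
import psutil

def memory_profile(func):
    """Run func; return (result, RSS delta in MB, RSS after the call in MB)."""
    proc = psutil.Process()
    gc.collect()  # settle the heap so the delta reflects the call itself
    rss_before = proc.memory_info().rss / 1024 ** 2
    result = func()
    rss_after = proc.memory_info().rss / 1024 ** 2
    return result, max(0.0, rss_after - rss_before), rss_after
```

For peak-allocation tracking of Python objects, `tracemalloc` is an alternative, though it misses allocations made inside native extensions.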

[9]:
# =============================================================================
# MEMORY PROFILING
# =============================================================================
print("=" * 70)
print("MEMORY PROFILING")
print("=" * 70)

memory_profile_results = []
did_results = None  # initialize before conditional assignment

# Profile signature DiD
if signature_cols and len(visits) == 2:
    def did_signatures():
        return st.did_table(
            adata_paired,
            features=signature_cols,
            design=design,
            visits=tuple(visits),
            aggregate="participant_visit",
            layer="log1p_tpm",
        )

    did_results, delta_mb, current_mb = memory_profile(did_signatures)
    memory_profile_results.append({
        'operation': 'DiD (signatures)',
        'delta_mb': delta_mb,
        'current_mb': current_mb,
        'n_features': len(signature_cols)
    })
    print(f"DiD (signatures): RSS delta={delta_mb:.2f}MB, RSS current={current_mb:.2f}MB")

# Profile large-scale DiD (first N genes by dataset order)
test_gene_counts = [500, 1000, 2000]
print("\nLarge-scale DiD memory usage (first N genes by dataset order):")

for n_genes in test_gene_counts:
    if n_genes > adata_paired.n_vars:
        continue

    _test_genes_mem = list(adata_paired.var_names[:n_genes])

    def did_gw(genes=_test_genes_mem):
        return st.did_table(
            adata_paired,
            features=genes,
            design=design,
            visits=tuple(visits),
            aggregate="participant_visit",
            layer="log1p_tpm",
        )

    _, delta_mb, current_mb = memory_profile(did_gw)
    memory_profile_results.append({
        'operation': f'DiD ({n_genes} genes)',
        'delta_mb': delta_mb,
        'current_mb': current_mb,
        'n_features': n_genes
    })
    print(f"  {n_genes} genes: RSS delta={delta_mb:.2f}MB, total RSS={current_mb:.0f}MB")
    del _test_genes_mem
    gc.collect()

memory_df = pd.DataFrame(memory_profile_results)
benchmark_results['memory'] = memory_df

# =============================================================================
# DID RESULTS ANALYSIS
# =============================================================================
print("\n" + "=" * 70)
print("DID RESULTS: SIGNATURES")
print("=" * 70)

if did_results is not None and not did_results.empty:
    display(did_results[["feature", "beta_DiD", "se_DiD", "p_DiD", "FDR_DiD", "n_units"]].round(4))

    # Significance summary
    n_sig_nominal = (did_results['p_DiD'] < 0.05).sum()
    n_sig_fdr = (did_results['FDR_DiD'] < FDR_ALPHA).sum()
    print("\nSignificance summary:")
    print(f"  Nominal (p < 0.05): {n_sig_nominal}/{len(did_results)}")
    print(f"  FDR < {FDR_ALPHA}: {n_sig_fdr}/{len(did_results)}")

# Run large-scale benchmark (first 2,000 genes)
print("\n" + "=" * 70)
print("DID BENCHMARK: LARGE-SCALE (first 2,000 genes)")
print("=" * 70)

max_genes = min(2000, adata_paired.n_vars)
features_gw = list(adata_paired.var_names[:max_genes])

def did_genomewide():
    return st.did_table(
        adata_paired,
        features=features_gw,
        design=design,
        visits=tuple(visits),
        aggregate="participant_visit",
        layer="log1p_tpm",
    )

res_gw, gw_time, gw_std = timed_run(did_genomewide, n_repeats=N_BENCHMARK_REPEATS)
benchmark_results['timing']['did_genomewide'] = gw_time

print(f"Large-scale DiD ({max_genes} genes): {format_time(gw_time)} ± {format_time(gw_std)}")
print(f"Throughput: {max_genes/gw_time:.0f} genes/second")

if res_gw is not None and not res_gw.empty:
    print(f"\nTop 10 genes by p-value (from first {max_genes} genes — ordering bias possible):")
    display(res_gw.head(10)[["feature", "beta_DiD", "se_DiD", "p_DiD", "FDR_DiD", "n_units"]].round(4))
======================================================================
MEMORY PROFILING
======================================================================
DiD (signatures): RSS delta=0.00MB, RSS current=19413.62MB

Large-scale DiD memory usage (first N genes by dataset order):
  500 genes: RSS delta=0.00MB, total RSS=19414MB
  1000 genes: RSS delta=0.03MB, total RSS=19414MB
  2000 genes: RSS delta=0.00MB, total RSS=19414MB

======================================================================
DID RESULTS: SIGNATURES
======================================================================
feature beta_DiD se_DiD p_DiD FDR_DiD n_units
0 sig_Activation -2.0117 0.6568 0.0022 0.0110 10
1 sig_IFN_Response -1.6544 1.8360 0.3675 0.7991 10
2 sig_Exhaustion -0.7621 1.0886 0.4839 0.7991 10
3 sig_Memory -0.8186 2.0637 0.6916 0.7991 10
4 sig_Cytotoxicity -0.3820 1.5010 0.7991 0.7991 10

Significance summary:
  Nominal (p < 0.05): 1/5
  FDR < 0.25: 1/5

======================================================================
DID BENCHMARK: LARGE-SCALE (first 2,000 genes)
======================================================================
Large-scale DiD (2000 genes): 4.99s ± 87.9ms
Throughput: 401 genes/second

Top 10 genes by p-value (from first 2000 genes — ordering bias possible):
feature beta_DiD se_DiD p_DiD FDR_DiD n_units
0 ABCC8 2.5994 0.5947 0.0000 0.0247 10
1 RBM7 -2.0196 0.4885 0.0000 0.0355 10
2 IP6K2 -1.6238 0.4469 0.0003 0.1856 10
3 PEX3 -2.0206 0.5947 0.0007 0.3389 10
4 AP3M2 -2.0088 0.6968 0.0039 0.9966 10
5 SLC4A7 -1.7360 0.6064 0.0042 0.9966 10
6 SPAST -1.6535 0.5865 0.0048 0.9966 10
7 PRKCZ 2.0435 0.7251 0.0048 0.9966 10
8 C19orf60 -1.3039 0.4659 0.0051 0.9966 10
9 PIK3C3 -1.7934 0.6915 0.0095 0.9966 10

6. Statistical Validation#

We validate statistical outputs with diagnostic panels:

  • P‑value histogram: checks uniformity under the null. With only ~10 paired participants and real biological signal, perfect uniformity is not expected; deviations indicate a mix of true effects and low power rather than miscalibration.

  • QQ plot: compares observed vs expected p‑values.

  • Effect size distribution: verifies effect size spread and sign balance.

  • Volcano plot: highlights effect sizes vs significance.

  • Memory usage: summarizes peak memory by operation.

  • Signature forest plot: shows DiD effects for signature‑level results.
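
For the QQ panel, observed p-values are compared against expected uniform order statistics. The code below uses Hazen plotting positions, (i − 0.5)/n, which keep every expected value strictly inside (0, 1) so the −log10 transform stays finite:

```python
import numpy as np

def hazen_positions(n: int) -> np.ndarray:
    """Expected p-values for n sorted uniform draws: (i - 0.5) / n, i = 1..n."""
    return (np.arange(1, n + 1) - 0.5) / n

hazen_positions(4)  # array([0.125, 0.375, 0.625, 0.875])
```

Plotting `-np.log10(hazen_positions(n))` against the sorted observed `-np.log10(p)` values yields the QQ plot; points above the y = x line indicate an excess of small p-values relative to a uniform null.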

[10]:
# =============================================================================
# STATISTICAL VALIDATION
# =============================================================================
print("=" * 70)
print("STATISTICAL VALIDATION")
print("=" * 70)

validation_results = {}

# 1. P-value Distribution Check
if res_gw is not None and not res_gw.empty:
    valid_p = res_gw['p_DiD'].dropna()

    # Descriptive check: under a pure null, p-values would be uniform.
    # This set includes both null and non-null genes, so deviations
    # from uniformity are expected and do NOT indicate miscalibration.
    quintiles = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
    # np.histogram uses half-open bins [lo, hi) except the last bin, which
    # includes its right edge -- identical to the intended quintile binning.
    counts, _ = np.histogram(valid_p, bins=quintiles)
    observed_props = list(counts / len(valid_p))

    validation_results['p_value_uniformity'] = observed_props

    print("\n1. P-value Distribution (descriptive, not a calibration test):")
    print("   Reference: 0.20 per quintile if all genes were null (they are not)")
    for i, prop in enumerate(observed_props):
        print(f"   [{quintiles[i]:.1f}, {quintiles[i+1]:.1f}): {prop:.3f}")

# 2. Effect Size Distribution
if res_gw is not None and not res_gw.empty:
    valid_beta = res_gw['beta_DiD'].dropna()
    validation_results['beta_stats'] = {
        'mean': valid_beta.mean(),
        'std': valid_beta.std(),
        'median': valid_beta.median(),
        'min': valid_beta.min(),
        'max': valid_beta.max()
    }

    print("\n2. Effect Size Distribution:")
    print(f"   Mean: {valid_beta.mean():.4f}")
    print(f"   Std:  {valid_beta.std():.4f}")
    print(f"   Min:  {valid_beta.min():.4f}")
    print(f"   Max:  {valid_beta.max():.4f}")

# 3. Sample Size Consistency
if res_gw is not None and not res_gw.empty:
    n_units_values = res_gw['n_units'].unique()
    validation_results['n_units'] = list(n_units_values)

    print("\n3. Sample Size Consistency:")
    print(f"   n_units values: {sorted(n_units_values)}")
    print(f"   Expected: {N_VALID_PAIRED}")

benchmark_results['validation'] = validation_results

# =============================================================================
# VALIDATION VISUALIZATIONS
# =============================================================================
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# 1. P-value histogram
ax1 = axes[0, 0]
if res_gw is not None and not res_gw.empty:
    ax1.hist(res_gw['p_DiD'].dropna(), bins=20, color='#3498db', edgecolor='black', alpha=0.7)
    ax1.axhline(len(res_gw['p_DiD'].dropna())/20, color='red', linestyle='--', label='Uniform expectation')
    ax1.set_xlabel('P-value')
    ax1.set_ylabel('Frequency')
    ax1.set_title('P-value Distribution\n(first 2,000 genes)', fontweight='bold')
    ax1.legend()
else:
    ax1.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax1.transAxes)
    ax1.set_title('P-value Distribution', fontweight='bold')

# 2. QQ plot
ax2 = axes[0, 1]
if res_gw is not None and not res_gw.empty:
    valid_p = res_gw['p_DiD'].dropna().sort_values()
    n_p = len(valid_p)
    expected_p = (np.arange(1, n_p + 1) - 0.5) / n_p  # Hazen plotting positions
    ax2.scatter(-np.log10(expected_p), -np.log10(valid_p), alpha=0.5, s=10)
    max_val = max(-np.log10(expected_p).max(), -np.log10(valid_p).max())
    ax2.plot([0, max_val], [0, max_val], 'r--', label='y=x')
    ax2.set_xlabel('Expected -log10(p)')
    ax2.set_ylabel('Observed -log10(p)')
    ax2.set_title('QQ Plot\n(first 2,000 genes)', fontweight='bold')
    ax2.legend()
else:
    ax2.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax2.transAxes)
    ax2.set_title('QQ Plot', fontweight='bold')

# 3. Effect size distribution
ax3 = axes[0, 2]
if res_gw is not None and not res_gw.empty:
    ax3.hist(res_gw['beta_DiD'].dropna(), bins=30, color='#2ecc71', edgecolor='black', alpha=0.7)
    ax3.axvline(0, color='red', linestyle='--', label='Zero effect')
    ax3.set_xlabel('Beta (DiD effect)')
    ax3.set_ylabel('Frequency')
    ax3.set_title('Effect Size Distribution\n(first 2,000 genes)', fontweight='bold')
    ax3.legend()
else:
    ax3.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax3.transAxes)
    ax3.set_title('Effect Size Distribution', fontweight='bold')

# 4. Volcano plot (first 2,000 genes)
ax4 = axes[1, 0]
if res_gw is not None and not res_gw.empty:
    valid_res = res_gw[res_gw['p_DiD'].notna() & res_gw['beta_DiD'].notna()]
    neg_log_p = -np.log10(valid_res['p_DiD'].clip(lower=1e-300))

    # Color by significance
    colors = ['#e74c3c' if fdr < FDR_ALPHA else '#95a5a6'
              for fdr in valid_res['FDR_DiD'].fillna(1)]

    ax4.scatter(valid_res['beta_DiD'], neg_log_p, c=colors, alpha=0.6, s=20)
    ax4.axhline(-np.log10(0.05), color='blue', linestyle='--', alpha=0.5, label='p=0.05')
    ax4.axvline(0, color='gray', linestyle='-', alpha=0.3)
    ax4.set_xlabel('DiD Effect (beta)')
    ax4.set_ylabel('-log10(p)')
    ax4.set_title('Volcano Plot\n(first 2,000 genes)', fontweight='bold')
    ax4.legend()
else:
    ax4.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax4.transAxes)
    ax4.set_title('Volcano Plot', fontweight='bold')

# 5. Memory scaling — line plot for gene-scaling entries, dot for signature
ax5 = axes[1, 1]
if len(memory_df) > 0:
    col = 'current_mb' if 'current_mb' in memory_df.columns else 'delta_mb'
    # Separate signature entry from gene-scaling entries
    mask_genes = memory_df['n_features'] >= 100
    df_genes = memory_df[mask_genes].sort_values('n_features')
    df_sig = memory_df[~mask_genes]

    if len(df_genes) > 0:
        ax5.plot(df_genes['n_features'], df_genes[col], 'o-', color='#9b59b6',
                 linewidth=2, markersize=8, markeredgecolor='black', markeredgewidth=0.6,
                 label='Gene-count scaling')
    if len(df_sig) > 0:
        ax5.scatter(df_sig['n_features'], df_sig[col], marker='*', s=200,
                    color='#e6550d', edgecolors='black', linewidths=0.6, zorder=5,
                    label='Signatures only')
    ax5.set_xlabel('Number of Features')
    ax5.set_ylabel('Process RSS (MB)')
    ax5.set_title('Process Memory by Feature Count', fontweight='bold')
    ax5.legend(fontsize=8, framealpha=0.9)
    ax5.grid(True, alpha=0.2)
else:
    ax5.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax5.transAxes)
    ax5.set_title('Memory Usage', fontweight='bold')

# 6. Signature-level results — forest plot using package function
ax6 = axes[1, 2]
if did_results is not None and not did_results.empty:
    st.plot_did_forest(did_results, title='Signature DiD Effects', ax=ax6)
else:
    ax6.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax6.transAxes)
    ax6.set_title('Signature Effects', fontweight='bold')

plt.tight_layout()
plt.show()
======================================================================
STATISTICAL VALIDATION
======================================================================

1. P-value Distribution (descriptive, not a calibration test):
   Reference: 0.20 per quintile if all genes were null (they are not)
   [0.0, 0.2): 0.090
   [0.2, 0.4): 0.181
   [0.4, 0.6): 0.268
   [0.6, 0.8): 0.259
   [0.8, 1.0): 0.201

2. Effect Size Distribution:
   Mean: -0.3445
   Std:  0.8019
   Min:  -2.4177
   Max:  2.5994

3. Sample Size Consistency:
   n_units values: [10]
   Expected: 10
../_images/tutorials_stress_test_real_scale_18_1.png

7. GSEA Benchmark#

We benchmark gene set enrichment on expanded immune‑relevant gene sets. This evaluates both runtime and interpretability for pathway‑level summaries.
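
The benchmark cell below notes that high tie rates in the gene-level ranking can destabilize enrichment ranks. A quick diagnostic is the fraction of ranking values tied with at least one other value (a hypothetical helper, not part of sctrial):

```python
import numpy as np

def tie_rate(scores) -> float:
    """Fraction of values that share their exact value with another entry."""
    arr = np.asarray(scores)
    _, counts = np.unique(arr, return_counts=True)
    # counts > 1 picks values that occur more than once; sum their occurrences
    return float(counts[counts > 1].sum()) / arr.size

tie_rate([0.0, 0.0, 1.2, 3.4])  # -> 0.5
```

Applied to a DiD effect vector (e.g. `tie_rate(res_gw['beta_DiD'].round(6))`), values above roughly 0.2 suggest that many genes carry near-zero effects and that enrichment ranks should be interpreted cautiously.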

[11]:
# =============================================================================
# GSEA GENE SETS
# =============================================================================
print("=" * 70)
print("GSEA BENCHMARK")
print("=" * 70)

expanded_sets = {
    "Cytotoxicity": ["GZMB", "GZMA", "GZMH", "GZMK", "PRF1", "GNLY", "NKG7", "IFNG", "KLRD1", "KLRB1", "KLRK1", "FASLG", "CTSW"],
    "Exhaustion": ["PDCD1", "LAG3", "HAVCR2", "TIGIT", "CTLA4", "ENTPD1", "TOX", "CXCL13", "EOMES", "BATF", "IRF4", "ICOS", "LAYN"],
    "IFN_Response": ["ISG15", "IFI6", "IFIT1", "IFIT2", "IFIT3", "MX1", "MX2", "STAT1", "OAS1", "IRF7", "IFITM1", "IFITM3", "IFI44", "IFI44L", "OAS2"],
    "T_Cell_Activation": ["CD69", "CD38", "HLA-DRA", "ICOS", "IL2RA", "TNFRSF9", "CD27", "CD28", "TNFRSF4", "CD40LG", "IFNG"],
    "Memory_T": ["IL7R", "TCF7", "LEF1", "CCR7", "SELL", "CD27", "CD28", "BCL2", "LTB", "MALAT1"],
    "B_Cell_Response": ["CD79A", "CD79B", "MS4A1", "MZB1", "XBP1", "JCHAIN", "IGHG1", "IGHG2", "IGHA1", "IGKC", "CD74"],
    "Antigen_Presentation": ["HLA-DRA", "HLA-DRB1", "HLA-DPA1", "HLA-DPB1", "CD74", "HLA-A", "HLA-B", "HLA-C", "B2M"],
    "Myeloid_Inflammation": ["S100A8", "S100A9", "LYZ", "VCAN", "IL1B", "CXCL8", "CTSD", "LGALS3", "FCGR3A", "LILRB1"],
    "NK_Markers": ["NKG7", "KLRD1", "KLRB1", "GNLY", "PRF1", "TRAC", "FCGR3A", "TYROBP"],
    "Chemokines": ["CCL2", "CCL3", "CCL4", "CCL5", "CXCL9", "CXCL10", "CXCL11", "CXCL8", "CXCL2", "CXCL3"],
    "Cell_Cycle": ["MKI67", "TOP2A", "TYMS", "MCM5", "MCM6", "MCM7", "PCNA", "UBE2C", "CCNB1", "CCNB2"],
    "Apoptosis": ["BAX", "BCL2", "CASP3", "CASP8", "FAS", "FASLG", "BID", "BCL2L1", "TNFRSF10B"],
    "Costimulation": ["CD27", "CD28", "TNFRSF4", "TNFRSF9", "ICOS", "CD40", "CD40LG", "TNFSF9"],
}

gene_sets = {}
print("\nGene Set Coverage:")
for k, genes in expanded_sets.items():
    present = [g for g in genes if g in adata_paired.var_names]
    if len(present) >= MIN_GENES_FOR_SCORE:
        gene_sets[k] = present
        print(f"  {k}: {len(present)}/{len(genes)} genes")
    else:
        print(f"  {k}: SKIPPED ({len(present)}/{len(genes)} genes, need >= {MIN_GENES_FOR_SCORE})")

print(f"\nTotal gene sets for GSEA: {len(gene_sets)}")

# =============================================================================
# RUN GSEA WITH TIMING (single run — GSEA uses internal permutations so
# repeated benchmarking is redundant and very slow)
#
# NOTE: High tie rates (>20%) are expected in subsampled scRNA-seq data
# because many genes have near-zero DiD effects. This may affect enrichment
# rank stability. For publication analyses, use full gene sets and check
# fgsea's tie-handling warnings.
# =============================================================================
if gene_sets:
    def run_gsea():
        return st.run_gsea_did_multi(
            adata_paired,
            gene_sets={"Custom": gene_sets},
            design=design,
            visits=tuple(visits),
            min_size=MIN_GENES_FOR_SCORE,
        )

    gsea_res, gsea_time, gsea_std = timed_run(run_gsea, n_repeats=1)
    gsea_res = gsea_res.get("Custom", pd.DataFrame())
    benchmark_results["timing"]["gsea_did"] = gsea_time

    print("\nGSEA runtime:", format_time(gsea_time))
    print("Gene sets per second:", f"{len(gene_sets)/gsea_time:.1f}")

    if gsea_res is not None and not gsea_res.empty:
        print("\nGSEA Results:")
        display(gsea_res.round(4))

        # ---- Publication-quality GSEA visualization ----
        fig, ax = plt.subplots(figsize=(10, max(4, 0.4 * len(gsea_res))))
        if "NES" in gsea_res.columns:
            sorted_res = gsea_res.sort_values("NES")
            labels = sorted_res["Term"] if "Term" in sorted_res.columns else sorted_res.index
            nes_vals = sorted_res["NES"].values
            y_pos = np.arange(len(sorted_res))

            # Determine significance column
            fdr_col = [c for c in sorted_res.columns if "fdr" in c.lower() or "padj" in c.lower() or "q" in c.lower()]
            # Include "p-val" so gseapy's hyphenated "NOM p-val" column is matched
            p_col = [c for c in sorted_res.columns if "pval" in c.lower() or "p_val" in c.lower() or "p-val" in c.lower()]
            sig_col = fdr_col[0] if fdr_col else (p_col[0] if p_col else None)

            # Color by direction; grey if non-significant
            colors = []
            annotations = []
            for _, row in sorted_res.iterrows():
                base_color = "#c0392b" if row["NES"] < 0 else "#27ae60"
                if sig_col and sig_col in row.index:
                    q = row[sig_col]
                    if q < 0.05:
                        annotations.append("***")
                    elif q < 0.10:
                        annotations.append("**")
                    elif q < 0.25:
                        annotations.append("*")
                    else:
                        annotations.append("")
                        base_color = "#bdc3c7"
                else:
                    annotations.append("")
                colors.append(base_color)

            ax.hlines(y_pos, 0, nes_vals, color="#aaaaaa", linewidth=1.5, zorder=1)
            ax.scatter(nes_vals, y_pos, c=colors, s=120,
                       edgecolors="black", linewidths=0.8, zorder=3)

            for j, (nes, ann) in enumerate(zip(nes_vals, annotations)):
                offset = 0.05 if nes >= 0 else -0.05
                ha = "left" if nes >= 0 else "right"
                ax.text(nes + offset, j, f"{nes:.2f} {ann}".strip(),
                        va="center", ha=ha, fontsize=9, fontweight="bold")

            ax.axvline(0, color="black", linewidth=0.8)
            ax.set_yticks(y_pos)
            ax.set_yticklabels(labels, fontsize=10)
            ax.set_xlabel("Normalized Enrichment Score (NES)", fontsize=11)
            sig_note = ""
            if sig_col:
                sig_note = f"\n(* {sig_col}<0.25  ** <0.10  *** <0.05; grey = n.s.)"
            ax.set_title(f"GSEA DiD Enrichment{sig_note}",
                         fontsize=12, fontweight="bold")
            ax.grid(True, axis="x", alpha=0.2)
            ax.set_axisbelow(True)
            ax.margins(x=0.2)
        else:
            ax.text(0.5, 0.5, "NES not available",
                    ha="center", va="center", transform=ax.transAxes)
        plt.tight_layout()
        plt.show()
else:
    print("No gene sets available for GSEA.")

gc.collect()
======================================================================
GSEA BENCHMARK
======================================================================

Gene Set Coverage:
  Cytotoxicity: 13/13 genes
  Exhaustion: 13/13 genes
  IFN_Response: 15/15 genes
  T_Cell_Activation: 11/11 genes
  Memory_T: 10/10 genes
  B_Cell_Response: 10/11 genes
  Antigen_Presentation: 9/9 genes
  Myeloid_Inflammation: 9/10 genes
  NK_Markers: 8/8 genes
  Chemokines: 9/10 genes
  Cell_Cycle: 10/10 genes
  Apoptosis: 9/9 genes
  Costimulation: 8/8 genes

Total gene sets for GSEA:
13
2026-04-05 19:42:39,008 [WARNING] Duplicated values found in preranked stats: 24.84% of genes
The order of those genes will be arbitrary, which may produce unexpected results.

GSEA runtime: 2.0min
Gene sets per second: 0.1

GSEA Results:
Name Term ES NES NOM p-val FDR q-val FWER p-val Tag % Gene % Lead_genes
0 prerank IFN_Response -0.732314 -1.759546 0.0 0.030137 0.019 15/15 26.79% STAT1;MX1;MX2;IFITM1;ISG15;IFI44L;IFI44;IFI6;I...
1 prerank B_Cell_Response 0.676856 1.565822 0.046569 0.077976 0.185 6/10 10.52% IGHA1;IGHG1;IGHG2;CD79B;IGKC;CD79A
2 prerank Chemokines -0.671906 -1.411026 0.079038 0.448467 0.453 9/9 32.82% CCL3;CXCL11;CCL2;CXCL9;CXCL10;CXCL2;CCL4;CCL5;...
3 prerank Exhaustion -0.607116 -1.401916 0.092257 0.318113 0.473 10/13 33.25% IRF4;CTLA4;ICOS;EOMES;HAVCR2;LAYN;ENTPD1;LAG3;...
4 prerank Antigen_Presentation -0.655176 -1.366071 0.123311 0.294912 0.551 7/9 27.46% B2M;HLA-A;HLA-C;HLA-DPB1;CD74;HLA-DPA1;HLA-B
5 prerank Apoptosis -0.645602 -1.338906 0.13468 0.279843 0.614 5/9 25.18% FAS;TNFRSF10B;CASP3;FASLG;BID
6 prerank T_Cell_Activation -0.607957 -1.3337 0.149485 0.238943 0.622 8/11 27.76% IL2RA;CD38;ICOS;CD28;TNFRSF9;IFNG;CD27;CD69
7 prerank Cell_Cycle -0.616866 -1.322096 0.123457 0.217314 0.644 8/10 34.89% MCM6;PCNA;MCM5;CCNB1;TYMS;CCNB2;UBE2C;TOP2A
8 prerank Costimulation -0.624324 -1.249148 0.198962 0.276794 0.775 4/8 17.33% TNFSF9;ICOS;CD28;TNFRSF9
9 prerank Myeloid_Inflammation -0.578058 -1.205582 0.239796 0.294832 0.853 9/9 42.21% IL1B;S100A8;S100A9;LYZ;LGALS3;CTSD;LILRB1;VCAN...
10 prerank Memory_T -0.518895 -1.108468 0.33564 0.37743 0.941 4/10 34.57% IL7R;CD28;CD27;CCR7
11 prerank Cytotoxicity -0.452434 -1.03649 0.429048 0.425179 0.972 8/13 41.81% GZMK;FASLG;IFNG;CTSW;KLRK1;NKG7;GZMB;GZMH
12 prerank NK_Markers 0.317555 0.694485 0.847458 0.851921 0.996 8/8 68.25% KLRD1;KLRB1;GNLY;PRF1;TRAC;FCGR3A;NKG7;TYROBP
../_images/tutorials_stress_test_real_scale_20_4.png
[11]:
3662
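The "Duplicated values found in preranked stats" warning above reflects ties in the per-gene DiD statistics, which are common in subsampled scRNA-seq data where many genes have near-zero effects. A minimal sketch of how such a tie rate can be quantified before running GSEA (`tie_fraction` is a hypothetical helper, not part of sctrial; the toy vector stands in for per-gene DiD effects):

```python
import numpy as np

def tie_fraction(stats: np.ndarray) -> float:
    """Fraction of entries whose value occurs more than once."""
    values, counts = np.unique(stats, return_counts=True)
    duplicated = counts[counts > 1].sum()
    return duplicated / stats.size

# Toy example: three exact zeros and one repeated value give 5/8 tied entries
stats = np.array([0.0, 0.0, 0.0, 0.5, -1.2, 0.5, 2.3, -0.7])
print(f"{tie_fraction(stats):.2%}")  # prints 62.50%
```

If the tie fraction on real DiD statistics is high, consider breaking ties deterministically (e.g. with a fixed secondary sort key) so enrichment ranks are reproducible across runs.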

Workflow API Example#

Here we run the same DiD analysis through the fluent workflow API, with analysis options bundled in a DiDConfig object.

[12]:
if signature_cols:
    cfg = st.DiDConfig(aggregate='participant_visit', standardize=True, use_bootstrap=False, layer="log1p_tpm")
    wf_res = (
        st.workflow(adata_paired)
        .did_table(features=signature_cols[:1], design=design, visits=tuple(visits), config=cfg)
        .result()
    )
    display(wf_res)
beta_DiD se_DiD p_DiD beta_time p_time n_units resid_sd cov_type_used feature FDR_DiD
0 -0.382015 1.501016 0.799106 0.366814 0.631812 10 18.570902 cluster sig_Cytotoxicity 0.799106
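In the single-feature table above, FDR_DiD equals p_DiD, which is what Benjamini-Hochberg adjustment yields for a single test. As an illustrative sketch (assuming FDR_DiD is BH-adjusted across features; the p-values below, apart from the one taken from the table, are made up), the adjustment over several features can be reproduced with statsmodels:

```python
import numpy as np
from statsmodels.stats.multitest import multipletests

# First entry is p_DiD from the table above; the rest are hypothetical
p_vals = np.array([0.799106, 0.012, 0.048, 0.20])

# Benjamini-Hochberg at the notebook's FDR_ALPHA = 0.25
reject, q_vals, _, _ = multipletests(p_vals, alpha=0.25, method="fdr_bh")
print(q_vals.round(4))   # q-values: [0.7991, 0.048, 0.096, 0.2667]
print(reject)            # [False, True, True, False]
```

With a single feature the adjustment is a no-op, which is why FDR_DiD simply repeats p_DiD here.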