Source code for sctrial.scoring

"""Gene-set scoring: z-mean, Seurat-style, and AUCell methods."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd
import scipy.sparse as sp
from anndata import AnnData

if TYPE_CHECKING:
    from ctxcore.genesig import GeneSignature

# Public names exported by this module via ``import *``.
__all__ = ["score_gene_sets", "score_gene_sets_aucell", "ScoreMethod"]

# Module logger; score_gene_sets reports missing genes (warning/info) and
# non-finite expression values (warning) through it.
logger = logging.getLogger(__name__)

# Allowed values for ``score_gene_sets(method=...)``; the function also checks
# for exactly these two strings at runtime and raises ValueError otherwise.
ScoreMethod = Literal["zmean", "mean"]

try:
    from pyscenic.aucell import aucell, create_rankings
except ImportError:  # pragma: no cover
    create_rankings = None
    aucell = None


def _dense_float_columns(X: np.ndarray | sp.spmatrix, gidx: np.ndarray) -> np.ndarray:
    """Return columns ``gidx`` of ``X`` as a dense float64 ndarray (cells x genes)."""
    sub = X[:, gidx]
    if sp.issparse(sub):
        sub = sub.toarray()
    # Always float64 so sparse/dense and float32/float64 inputs score identically.
    return np.asarray(sub, dtype=np.float64)


def _reduce_scores(sub: np.ndarray, method: str) -> np.ndarray:
    """Collapse a dense float64 (cells x genes) block into one score per cell.

    NaN entries are treated as missing. ``method`` is ``"mean"`` or
    ``"zmean"``. For zmean, genes whose standard deviation is NaN or
    <= 1e-12 are dropped; if no gene survives, every score is NaN.
    """
    import warnings

    # nanmean/nanstd report all-NaN slices and ddof >= n through the
    # ``warnings`` module, which np.errstate does not control, so silence both.
    with warnings.catch_warnings(), np.errstate(all="ignore"):
        warnings.simplefilter("ignore", category=RuntimeWarning)
        if method == "mean":
            return np.nanmean(sub, axis=1)

        mu = np.nanmean(sub, axis=0, keepdims=True)
        sd = np.nanstd(sub, axis=0, ddof=1, keepdims=True)
        # NaN sd (e.g. a single cell) compares False and is excluded too.
        valid_genes = (sd > 1e-12).ravel()
        if not valid_genes.any():
            return np.full(sub.shape[0], np.nan)
        z = (sub[:, valid_genes] - mu[:, valid_genes]) / sd[:, valid_genes]
        return np.nanmean(z, axis=1)


def score_gene_sets(
    adata: AnnData,
    gene_sets: dict[str, list[str]],
    *,
    layer: str | None = None,
    method: ScoreMethod = "zmean",
    prefix: str = "",
    min_genes: int = 3,
    overwrite: bool = True,
) -> AnnData:
    """Score gene sets and store results in `adata.obs`.

    Parameters
    ----------
    adata
        AnnData object containing expression data.
    gene_sets
        Dictionary mapping set names to collections of gene names. Each value
        must be a ``list``, ``tuple``, ``set``, ``frozenset``, one-dimensional
        ``numpy.ndarray``, ``pandas.Index`` or ``pandas.Series`` (a bare
        string is rejected). Duplicate gene names within a set are
        automatically removed.
    layer
        Expression matrix source. If None, uses `adata.X`. For log1p-CPM
        workflows, use layer="log1p_cpm".
    method
        Scoring method:

        - "mean": mean expression across genes.
        - "zmean": z-score each gene across cells (within the current
          AnnData), then average z-scores across genes. This is the
          recommended method as it accounts for different expression
          scales across genes.
    prefix
        Prefix to add to column names (e.g., ``ms_`` for module scores).
    min_genes
        Minimum number of genes from the set that must be present in the
        data. If fewer genes overlap, the score is set to NaN and a warning
        is logged. Default is 3.
    overwrite
        If False, skip gene sets that already have a column in adata.obs.

    Returns
    -------
    AnnData
        The input AnnData with new columns added to obs.

    Raises
    ------
    ValueError
        If ``method`` is unknown, ``gene_sets`` is not a non-empty dict,
        ``prefix`` is not a string, or ``min_genes < 1``.
    KeyError
        If ``layer`` is given but not present in ``adata.layers``.
    TypeError
        If a gene-set value is not one of the accepted collection types.

    Notes
    -----
    **Precision:** the selected gene columns are always converted to float64
    before scoring, so dense and sparse inputs of any dtype yield the same
    float64 scores.

    **Zero-variance gene handling (zmean method):** Genes with zero or
    near-zero variance (std < 1e-12) are excluded from the z-mean
    calculation. If ALL genes in a set have zero variance (or the variance
    is undefined, e.g. a single cell), the score is NaN.

    The zmean method computes: mean(z_i) where
    z_i = (x_i - mean(x_i)) / std(x_i) for each gene i across all cells
    (sample standard deviation, ddof=1).

    **Non-finite expression values:** NaN and inf values in the expression
    matrix are excluded from score computation (treated as missing). A
    warning is logged when non-finite values are detected. If *all* values
    for a cell are non-finite the resulting score will be NaN; NumPy's
    "empty slice" runtime warnings for such cases are suppressed.
    """
    if method not in ("zmean", "mean"):
        raise ValueError(f"Unknown method '{method}'. Use 'zmean' or 'mean'.")
    if not isinstance(gene_sets, dict) or not gene_sets:
        raise ValueError("gene_sets must be a non-empty dict of name -> gene list.")
    if not isinstance(prefix, str):
        raise ValueError("prefix must be a string.")
    if min_genes < 1:
        raise ValueError("min_genes must be >= 1.")
    if layer is not None and layer not in adata.layers:
        raise KeyError(f"Layer '{layer}' not found in adata.layers.")

    # Validate every set before touching adata so a bad entry cannot leave obs
    # partially updated. Strings are iterable but never a valid gene set.
    accepted = (list, tuple, set, frozenset, np.ndarray, pd.Index, pd.Series)
    for name, gset in gene_sets.items():
        if not isinstance(gset, accepted) or getattr(gset, "ndim", 1) != 1:
            raise TypeError(
                f"Gene set '{name}' must be a list of gene names, got {type(gset).__name__}."
            )

    X = adata.layers[layer] if layer is not None else adata.X
    if sp.issparse(X) and X.format not in ("csr", "csc"):
        # Formats such as COO lack column indexing; convert once. CSR/CSC
        # support it already, so they are used as-is instead of copied.
        X = X.tocsr()

    gene_pos = {g: i for i, g in enumerate(adata.var_names)}

    for name, gset in gene_sets.items():
        col = f"{prefix}{name}"
        if not overwrite and col in adata.obs.columns:
            continue

        # dict.fromkeys de-duplicates while keeping first-seen order.
        requested = list(dict.fromkeys(gset))
        use = [g for g in requested if g in gene_pos]
        n_unique_requested = len(requested)
        n_found = len(use)

        if n_found < min_genes:
            logger.warning(
                "Gene set '%s': only %d/%d unique genes found in data "
                "(min_genes=%d); setting score to NaN.",
                name,
                n_found,
                n_unique_requested,
                min_genes,
            )
            adata.obs[col] = np.nan
            continue
        if n_found < n_unique_requested:
            logger.info(
                "Gene set '%s': %d/%d unique genes found in data.",
                name,
                n_found,
                n_unique_requested,
            )

        gidx = np.array([gene_pos[g] for g in use], dtype=np.intp)
        sub = _dense_float_columns(X, gidx)

        finite_mask = np.isfinite(sub)
        n_nonfinite = int(np.count_nonzero(~finite_mask))
        if n_nonfinite:
            logger.warning(
                "Gene set '%s': %d non-finite value(s) in expression "
                "matrix; these are excluded from scoring.",
                name,
                n_nonfinite,
            )
            # Turn +/-inf into NaN so the nan-aware reductions skip them.
            sub = np.where(finite_mask, sub, np.nan)

        adata.obs[col] = _reduce_scores(sub, method)

    return adata
def score_gene_sets_aucell(
    adata: AnnData,
    gene_sets: dict[str, list[str]] | dict[str, GeneSignature],
    *,
    layer: str | None = None,
    prefix: str = "aucell_",
    overwrite: bool = False,
    seed: int | None = None,
) -> AnnData:
    """Score gene sets using AUCell (pySCENIC).

    Requires pyscenic to be installed. Genes are ranked once per cell with
    ``create_rankings`` and every gene set is scored against those shared
    rankings.

    Parameters
    ----------
    adata
        AnnData object containing expression data.
    gene_sets
        Dictionary mapping set names to lists of genes (or GeneSignature
        objects). Plain gene lists are restricted to genes present in
        ``adata.var_names``; a set with no such genes gets a NaN column.
    layer
        Expression layer to use. If None, uses `adata.X`.
    prefix
        Prefix to add to output columns (default: ``aucell_``).
    overwrite
        If False, skip sets that already exist in `adata.obs`.
    seed
        Random seed forwarded to ``create_rankings``, which shuffles genes
        before ranking to break ties. Set it for reproducible scores; the
        default ``None`` leaves the ranking unseeded.

    Returns
    -------
    AnnData
        The input AnnData with AUCell scores added to `adata.obs`.

    Raises
    ------
    ImportError
        If pyscenic is not installed.
    KeyError
        If ``layer`` is given but not present in ``adata.layers``.
    """
    if create_rankings is None or aucell is None:
        raise ImportError(
            "pyscenic is required for AUCell scoring. Install with 'pip install pyscenic'."
        )
    if layer is not None and layer not in adata.layers:
        raise KeyError(f"Layer '{layer}' not found in adata.layers.")

    from ctxcore.genesig import GeneSignature

    # ``aucell`` ranks its (expression) input itself; feeding it precomputed
    # rankings would rank the ranks and invert gene order. ``aucell4r`` is
    # the pySCENIC entry point that consumes a ranking matrix directly.
    from pyscenic.aucell import aucell4r

    # Decide what to score first (None = no genes present -> NaN column), so
    # the expensive densify + ranking step is skipped when nothing is needed.
    pending: list[tuple[str, GeneSignature | None]] = []
    for name, genes in gene_sets.items():
        col = f"{prefix}{name}"
        if (not overwrite) and col in adata.obs.columns:
            continue
        if isinstance(genes, GeneSignature):
            pending.append((col, genes))
            continue
        genes_present = [g for g in genes if g in adata.var_names]
        pending.append((col, GeneSignature(name, genes_present) if genes_present else None))

    rankings = None
    if any(gs is not None for _, gs in pending):
        X = adata.layers[layer] if layer is not None else adata.X
        if hasattr(X, "toarray"):
            X = X.toarray()
        df = pd.DataFrame(X, columns=adata.var_names, index=adata.obs_names)
        rankings = create_rankings(df, seed=seed)

    for col, gs in pending:
        if gs is None:
            adata.obs[col] = np.nan
            continue
        # A single signature gives a worker pool nothing to parallelize.
        auc = aucell4r(rankings, [gs], num_workers=1)
        # One-column result (named after the signature, not the dict key);
        # select it positionally and align rows on cell names.
        adata.obs[col] = auc.iloc[:, 0].reindex(adata.obs_names).to_numpy()

    return adata