Previews (R and shell scripts)

script_1.R


#   sncRNA-seq analysis pipeline to:
#   1) read Bowtie2-mapped small-RNA tables from multiple references (RNAcentral,
#      planarian tRNA set, rRNA set, transcriptome, genome),
#   2) assign an RNA class (miRNA, tRF, rRF, etc.) using custom rules,
#   3) compute a *consensus* annotation per unique sequence (resolve multi-hits),
#   4) build a sample-by-feature count matrix and CPM-filter it for DGE


# ----------------------------
# 0) Packages 
# ----------------------------
# Install every package in `pkgs` that cannot be loaded yet.
#
# pkgs: character vector of package names.
# bioc: FALSE installs from CRAN; TRUE installs through BiocManager
#       (BiocManager itself is installed from CRAN first when absent).
# Packages that are already available are left untouched.
install_if_missing <- function(pkgs, bioc = FALSE) {
  is_available <- function(pkg) requireNamespace(pkg, quietly = TRUE)

  for (pkg in pkgs) {
    if (is_available(pkg)) {
      next
    }

    if (bioc) {
      if (!is_available("BiocManager")) {
        install.packages("BiocManager", repos = "https://cloud.r-project.org")
      }
      BiocManager::install(pkg, update = FALSE, ask = FALSE)
    } else {
      install.packages(pkg, repos = "https://cloud.r-project.org")
    }
  }
}

# CRAN packages installed on demand. This script does not attach them with
# library(); functions are called as pkg::fun() below.
cran_pkgs <- c("data.table", "dplyr", "stringr", "purrr", "tidyr", "forcats", "ggplot2", "ggrepel", "scales")
# Bioconductor packages, installed through BiocManager.
bioc_pkgs <- c("GenomicAlignments", "Biostrings", "edgeR")

# Install whatever is missing before the first pkg::fun() call is reached.
install_if_missing(cran_pkgs, bioc = FALSE)
install_if_missing(bioc_pkgs, bioc = TRUE)

# ----------------------------
# 1) Configuration
# ----------------------------
# Central configuration: input directories (one per Bowtie2 reference),
# output directory, and filtering thresholds used by the functions below.
cfg <- list(
  # Input directories
  # Alignment of all small RNAs to the RNAcentral DB
  dir_rnacentral  = "F:/PARN_ELAC_silencing/smallRNA/calculated_data_bowtie2_end_to_end/mapped_to_RNAcentral/mapped_seq_with_strand_new/",
  # Alignment to the tRNA reference (hits are labelled "tRNA fragments")
  dir_trna_map    = "E:/Illumina/PARN_ELAC2_silencing/smallRNA/smallRNAwithAdapters/miRNA/bowtie2_mapped_SM_tRNA_seq_with_strand/",
  # Alignment of all small RNAs not mapped to RNAcentral DB to the SM rRNA reference
  dir_rrna_map    = "F:/PARN_ELAC_silencing/smallRNA/calculated_data_bowtie2_end_to_end/mapped_to_rRNA_and_genome/rRNA_seq_with_strand/",
  # Alignment of all small RNAs not mapped to RNAcentral DB to the transcriptome reference
  dir_transcript  = "F:/PARN_ELAC_silencing/smallRNA/calculated_data_bowtie2_end_to_end/mapped_to_transcriptome/mapped_seq_with_strand/",
  # Alignment of all small RNAs not mapped to RNAcentral DB or rRNA/transcriptome to the genome reference
  dir_genome      = "F:/PARN_ELAC_silencing/smallRNA/calculated_data_bowtie2_end_to_end/mapped_to_genome/mapped_seq_with_strand/",
  
  # Output (all .rds/.csv/.png results are written here)
  out_dir         = "F:/PARN_ELAC_silencing/smallRNA/clean_pipeline_out/",
  
  # Filtering
  keep_strand     = 0L,   # keep only rows whose strand column equals 0
  max_total_mm    = 3L,   # keep reads with (XM + non-match cigar ops) <= 3
  
  # CPM filtering for DGE: a feature is kept when its CPM exceeds
  # cpm_threshold in at least cpm_min_samples samples
  cpm_threshold   = 10,
  cpm_min_samples = 2L
)

# Create the output directory (and any missing parents) on first run.
if (!dir.exists(cfg$out_dir)) dir.create(cfg$out_dir, recursive = TRUE)

# ----------------------------
# 2) Helpers
# ----------------------------


# Reverse-complement a character vector of DNA sequences.
# Returns a character vector of the same length as `x`.
revcomp_chr <- function(x) {
  dna_set <- Biostrings::DNAStringSet(x)
  flipped <- Biostrings::reverseComplement(dna_set)
  as.character(flipped)
}

# Read one header-less mapping table and standardise its columns.
#
# path: file to read with data.table::fread (ragged rows are padded).
# Returns a data.table holding only the first nine columns, named
# read, strand, ref, position, qual, cigar, seq, XM, MD.
# Stops when the file has fewer than nine columns.
read_map_tbl <- function(path) {
  tbl <- data.table::fread(
    path,
    header = FALSE,
    fill = TRUE,
    showProgress = FALSE
  )

  if (ncol(tbl) < 9) {
    stop("File has < 9 columns: ", path)
  }

  # Anything beyond column nine is discarded.
  tbl <- tbl[, 1:9, with = FALSE]
  data.table::setnames(
    tbl,
    c("read", "strand", "ref", "position", "qual", "cigar", "seq", "XM", "MD")
  )
  tbl
}

# Turn Bowtie2 "XM:i:<n>" tags into integer mismatch counts.
# Values that cannot be parsed as an integer (including NA) become 0.
parse_xm <- function(xm_chr) {
  digits <- stringr::str_replace(xm_chr, "^XM:i:", "")
  counts <- suppressWarnings(as.integer(digits))
  replace(counts, is.na(counts), 0L)
}

# Summarise CIGAR strings with GenomicAlignments::cigarOpTable.
#
# cigar_chr: character vector of CIGAR strings.
# Returns a list of two integer vectors (one value per CIGAR):
#   match    - row sums of the "M" and "=" columns of the op table
#   nonmatch - row sums of every other column
cigar_nonmatch_ops <- function(cigar_chr) {
  op_tbl <- GenomicAlignments::cigarOpTable(cigar_chr)

  op_names <- colnames(op_tbl)
  match_ops <- intersect(op_names, c("M", "="))
  other_ops <- setdiff(op_names, match_ops)

  # Row-sum a set of columns; an empty set contributes zero everywhere.
  total_over <- function(cols) {
    if (length(cols) == 0) {
      return(rep.int(0L, nrow(op_tbl)))
    }
    rowSums(op_tbl[, cols, drop = FALSE])
  }

  list(
    nonmatch = as.integer(total_over(other_ops)),
    match    = as.integer(total_over(match_ops))
  )
}

# RNA class assignment.
#
# Maps each reference name in `ref_chr` to one RNA class by case-insensitive
# pattern matching. The first matching rule wins, so the order of the
# clauses below encodes priority; names matching nothing (or NA) fall
# through to "other fragments".
rna_type_from_ref <- function(ref_chr) {
  # Case-insensitive test of one pattern against all reference names.
  hit <- function(pattern) {
    stringr::str_detect(ref_chr, stringr::regex(pattern, ignore_case = TRUE))
  }

  dplyr::case_when(
    hit("multiple_hits") ~ "multiple_hits",
    hit("dd_Smed_v6") ~ "mRNA fragments",

    # rRNA-like
    hit("ribosomal|\\brRNA\\b|ITS1|ITS2|SpacerA|\\b28S\\b|\\b12S\\b|\\b16S\\b|5\\.8S|Schmed_cloneH735c") ~ "rRNA fragments",

    # tRNA-like
    hit("\\btRNA\\b|transfer") ~ "tRNA fragments",

    # sno/sn (snoRNA is tested first because "nuclear" would not match
    # "nucleolar", but "snoRNA" names should never reach the snRNA rule)
    hit("snoRNA|nucleolar") ~ "snoRNA fragments",
    hit("snRNA|spliceosomal|\\b7SK\\b|nuclear") ~ "snRNA fragments",

    # small regulatory
    hit("piRNA") ~ "piRNA",
    hit("miRNA|microRNA|\\bmiR\\b|Sme-|sme-lin|sme-let|Sme-Bantam") ~ "miRNA",

    # lnc
    hit("long_non-coding|\\blnc\\b") ~ "lncRNA fragments",

    TRUE ~ "other fragments"
  )
}

# Sample IDs look like ELAC13S, WT25S, GFP33S: <condition><replicate digit><3S|5S>.
# Drop the single replicate digit to obtain the group label
# (ELAC13S -> ELAC3S). IDs that do not fit the pattern are returned unchanged.
sample_to_group <- function(sample_id) {
  replicate_pattern <- "^(ELAC|WT|GFP|PARN)\\d(3S|5S)$"
  stringr::str_replace(sample_id, replicate_pattern, "\\1\\2")
}

# List the entry names (no directory prefix, no hidden files) found
# directly inside `dir_path`.
list_basenames <- function(dir_path) {
  list.files(path = dir_path, all.files = FALSE, full.names = FALSE)
}

# File names that are present in every directory of `dirs`.
common_basenames <- function(dirs) {
  per_dir <- purrr::map(dirs, list_basenames)
  shared <- Reduce(intersect, per_dir)
  shared
}

# ----------------------------
# 3) Per-sample processing
# ----------------------------
# Build the annotated, filtered mapping table for one sample.
#
# basename_file: file name shared by the RNAcentral, tRNA, rRNA and
#                transcriptome directories listed in `cfg`.
# cfg:           configuration list (directories, keep_strand, max_total_mm).
#
# Reads are taken from the sources in priority order: tRNA hits replace
# RNAcentral hits for the same read ID, rRNA hits are added only for reads
# not seen yet, and transcriptome hits only for what is still missing.
# Returns a data.table restricted to rows on `cfg$keep_strand` whose total
# mismatch count (XM + non-match CIGAR ops) is <= `cfg$max_total_mm`.
process_one_sample <- function(basename_file, cfg) {
  load_from <- function(dir_path) {
    read_map_tbl(file.path(dir_path, basename_file))
  }
  stack_tables <- function(top, bottom) {
    data.table::rbindlist(list(top, bottom), use.names = TRUE, fill = TRUE)
  }

  rc_hits   <- load_from(cfg$dir_rnacentral)
  trna_hits <- load_from(cfg$dir_trna_map)
  rrna_hits <- load_from(cfg$dir_rrna_map)
  tx_hits   <- load_from(cfg$dir_transcript)

  # Dedicated references get a fixed class; RNAcentral hits are classified
  # from their reference name.
  trna_hits[, RNA_type := "tRNA fragments"]
  rrna_hits[, RNA_type := "rRNA fragments"]
  rc_hits[, RNA_type := rna_type_from_ref(ref)]

  # 1) tRNA mappings win over RNAcentral mappings of the same read
  rc_kept <- rc_hits[!(read %in% trna_hits$read)]
  merged <- stack_tables(rc_kept, trna_hits)

  # 2) rRNA mappings only for reads that are still unassigned
  rrna_new <- rrna_hits[!(read %in% merged$read)]
  merged <- stack_tables(merged, rrna_new)

  # 3) transcriptome mappings (mRNA fragments) for whatever remains
  tx_new <- tx_hits[!(read %in% merged$read)]
  tx_new[, RNA_type := "mRNA fragments"]
  merged <- stack_tables(merged, tx_new)

  # Mismatch features
  cigar_counts <- cigar_nonmatch_ops(merged$cigar)
  merged[, `:=`(
    cigar_match    = cigar_counts$match,
    cigar_nonmatch = cigar_counts$nonmatch,
    xm_mm          = parse_xm(XM)
  )]
  merged[, all_mm := xm_mm + cigar_nonmatch]

  # Sequence in read orientation: reverse-complement anything not on strand 0
  merged[, orig_seq := ifelse(strand == 0, seq, revcomp_chr(seq))]

  # Strand and mismatch filters
  kept <- merged[strand == cfg$keep_strand & all_mm <= cfg$max_total_mm]

  out_cols <- c(
    "read", "strand", "ref", "position", "cigar", "seq", "orig_seq",
    "RNA_type", "XM", "MD", "cigar_match", "cigar_nonmatch", "xm_mm", "all_mm"
  )
  kept[, out_cols, with = FALSE]
}

# ----------------------------
# 4) Consensus annotation per unique sequence
# ----------------------------
#   - build unique (orig_seq, pos_ref) pairs from all samples,
#   - classify each pos_ref,
#   - if a sequence maps to multiple types, prefer:
#       (1) tRNA fragments
#       (2) entries containing "Schmidtea"
#       (3) "other fragments" (only if no tRNA)
#       else "multiple_hits"
# Compute one consensus annotation per unique read sequence.
#
# dt_all_samples: table with at least the columns orig_seq, position, ref
#                 (rows from all samples stacked together).
#
# Every distinct (orig_seq, position_ref) pair is classified with
# rna_type_from_ref(). For each orig_seq the annotation is chosen as:
#   1. the first tRNA-fragment candidate, else
#   2. the first candidate whose name contains "Schmidtea", else
#   3. the first "other fragments" candidate, else
#   4. the first candidate when all candidates share a single RNA type, else
#   5. the literal "multiple_hits".
# Rule 4 keeps unambiguous sequences (e.g. only mRNA or only rRNA hits) from
# being reported as "multiple_hits"; only sequences with genuinely
# conflicting types reach rule 5.
#
# Returns a data.table with one row per orig_seq and the columns
# orig_seq, correct_annotation, correct_RNA_type.
#
# NOTE(review): classification runs on "<position>_<ref>". "_" is a word
# character, so the \b-anchored patterns in rna_type_from_ref() cannot match
# at the very start of `ref` here - confirm this is acceptable for the
# reference names in use.
build_consensus_annotation <- function(dt_all_samples) {
  dt <- data.table::as.data.table(dt_all_samples)

  # pos_ref includes coordinate + ref
  dt[, pos_ref := paste(position, ref, sep = "_")]

  # Unique mapping candidates
  uniq <- unique(dt[, .(orig_seq, pos_ref)])

  # Classify based on pos_ref content
  uniq[, RNA_type := rna_type_from_ref(pos_ref)]

  schmidtea_rx <- stringr::regex("Schmidtea", ignore_case = TRUE)

  # Resolve each sequence to a single annotation (length-1 result is
  # recycled over all rows of the group).
  uniq[, correct_annotation := {
    pos_all <- pos_ref
    type_all <- RNA_type
    is_trna <- type_all == "tRNA fragments"
    is_schmidtea <- stringr::str_detect(pos_all, schmidtea_rx)
    is_other <- type_all == "other fragments"

    if (any(is_trna)) {
      pos_all[which(is_trna)[1]]
    } else if (any(is_schmidtea)) {
      pos_all[which(is_schmidtea)[1]]
    } else if (any(is_other)) {
      pos_all[which(is_other)[1]]
    } else if (length(unique(type_all)) == 1L) {
      # All candidates agree on the RNA type: not a multi-type conflict.
      pos_all[1]
    } else {
      "multiple_hits"
    }
  }, by = orig_seq]

  uniq[, correct_RNA_type := rna_type_from_ref(correct_annotation)]

  # One row per orig_seq
  out <- unique(uniq[, .(orig_seq, correct_annotation, correct_RNA_type)])
  out
}

# ----------------------------
# 5) Genome-only add-in 
# ----------------------------
# Load the genome mapping table of one sample and apply the same mismatch
# features and filters as process_one_sample().
#
# basename_file: file name inside cfg$dir_genome.
# cfg:           configuration list (dir_genome, keep_strand, max_total_mm).
# Every row is labelled RNA_type = "genome". Returns the filtered data.table
# with the same column layout as process_one_sample().
read_genome_only <- function(basename_file, cfg) {
  genome_hits <- read_map_tbl(file.path(cfg$dir_genome, basename_file))
  genome_hits[, RNA_type := "genome"]

  cigar_counts <- cigar_nonmatch_ops(genome_hits$cigar)
  genome_hits[, `:=`(
    cigar_match    = cigar_counts$match,
    cigar_nonmatch = cigar_counts$nonmatch,
    xm_mm          = parse_xm(XM)
  )]
  genome_hits[, all_mm := xm_mm + cigar_nonmatch]
  # Reverse-complement anything that is not on strand 0
  genome_hits[, orig_seq := ifelse(strand == 0, seq, revcomp_chr(seq))]

  kept <- genome_hits[strand == cfg$keep_strand & all_mm <= cfg$max_total_mm]

  out_cols <- c(
    "read", "strand", "ref", "position", "cigar", "seq", "orig_seq",
    "RNA_type", "XM", "MD", "cigar_match", "cigar_nonmatch", "xm_mm", "all_mm"
  )
  kept[, out_cols, with = FALSE]
}

# ----------------------------
# 6) Count matrix + CPM filtering
# ----------------------------
# Build a feature-by-sample count matrix.
#
# sample_tables: named list of tables, each with the columns orig_seq and
#                correct_annotation; list names become the column names.
# A feature is "<orig_seq> <correct_annotation>"; its count in a sample is
# the number of rows carrying it. Features absent from a sample get 0.
# Returns a numeric matrix with features as row names.
make_count_matrix <- function(sample_tables) {
  # Per-sample long table: (seq_anno, n, sample)
  count_one_sample <- function(sid) {
    tbl <- data.table::as.data.table(sample_tables[[sid]])
    tbl[, seq_anno := paste(orig_seq, correct_annotation, sep = " ")]
    per_feature <- tbl[, .(n = .N), by = .(seq_anno)]
    per_feature[, sample := sid]
    per_feature
  }

  long_counts <- data.table::rbindlist(
    lapply(names(sample_tables), count_one_sample),
    use.names = TRUE, fill = TRUE
  )

  wide_counts <- data.table::dcast(
    long_counts,
    seq_anno ~ sample,
    value.var = "n",
    fill = 0
  )

  # Drop the leading seq_anno column and convert to a numeric matrix for edgeR::cpm
  mat <- as.matrix(wide_counts[, -1, with = FALSE])
  storage.mode(mat) <- "numeric"
  rownames(mat) <- wide_counts$seq_anno

  mat
}

# Compute CPM with edgeR and flag features passing an expression filter.
#
# count_mat:     feature-by-sample count matrix.
# cpm_threshold: CPM value a sample must exceed to count as "expressed".
# min_samples:   minimum number of expressed samples required.
# Returns list(cpm = CPM matrix, keep = logical vector, one entry per row).
cpm_filter <- function(count_mat, cpm_threshold = 10, min_samples = 2L) {
  cpm_mat <- edgeR::cpm(count_mat)
  n_expressed <- rowSums(cpm_mat > cpm_threshold)
  list(cpm = cpm_mat, keep = n_expressed >= min_samples)
}

# ==============================================================================
# 7) Run the pipeline
# ==============================================================================
# The four directories read by process_one_sample(); the genome directory
# is matched separately when genome-only reads are added.
dirs_needed <- c(cfg$dir_rnacentral, cfg$dir_trna_map, cfg$dir_rrna_map, cfg$dir_transcript)
# A sample is processed only if a file with its name exists in all of them.
basenames <- common_basenames(dirs_needed)

if (length(basenames) == 0) {
  stop("No common basenames found across the required directories.")
}

# --- A) Build per-sample tables (no genome)
# Pre-allocate a list keyed by file name, then fill it one sample at a time.
sample_tbls <- purrr::set_names(vector("list", length(basenames)), basenames)
for (bn in basenames) {
  message("Processing (no genome): ", bn)
  sample_tbls[[bn]] <- process_one_sample(bn, cfg)
}

# Create a stable sample_id from basename (strip the last file extension)
sample_ids <- stringr::str_replace(basenames, "\\.[^.]*$", "")
names(sample_tbls) <- sample_ids

# Add group labels like ELAC3S / WT5S as column `set`
for (sid in names(sample_tbls)) {
  dt <- data.table::as.data.table(sample_tbls[[sid]])
  dt[, set := sample_to_group(sid)]
  sample_tbls[[sid]] <- dt
}

# --- B) Build consensus annotation map from all samples
# Stack every sample so that each unique sequence is resolved once, globally.
dt_all <- data.table::rbindlist(sample_tbls, use.names = TRUE, fill = TRUE)
anno_map <- build_consensus_annotation(dt_all)

# Save annotation map (RDS for R, CSV for inspection)
saveRDS(anno_map, file.path(cfg$out_dir, "consensus_annotation_map.rds"))
data.table::fwrite(anno_map, file.path(cfg$out_dir, "consensus_annotation_map.csv"))

# --- C) Attach consensus annotation to each sample table
sample_tbls_anno <- lapply(sample_tbls, function(dt) {
  dt <- data.table::as.data.table(dt)
  # Left join: every read row is kept, annotated by its orig_seq
  dt2 <- merge(dt, anno_map, by = "orig_seq", all.x = TRUE)
  
  # If a sequence has no entry in the map, fall back to its own
  # "<position>_<ref>" label and classify that label
  dt2[is.na(correct_annotation), correct_annotation := paste(position, ref, sep = "_")]
  dt2[is.na(correct_RNA_type),   correct_RNA_type   := rna_type_from_ref(correct_annotation)]
  
  dt2
})

saveRDS(sample_tbls_anno, file.path(cfg$out_dir, "sample_tables_no_genome_annotated.rds"))

# --- D) Add genome-only reads (time-consuming)

# Genome files are used only for samples that were also processed above.
basenames_g <- common_basenames(c(cfg$dir_genome, cfg$dir_rnacentral))
basenames_g <- intersect(basenames_g, basenames)  # keep aligned set

genome_tbls <- purrr::set_names(vector("list", length(basenames_g)), basenames_g)
for (bn in basenames_g) {
  message("Processing (genome): ", bn)
  genome_tbls[[bn]] <- read_genome_only(bn, cfg)
}
# Same extension-stripping rule as for sample_ids, so names line up
names(genome_tbls) <- stringr::str_replace(basenames_g, "\\.[^.]*$", "")

# Append genome reads not already assigned in no-genome set (by read ID)
sample_tbls_with_genome <- lapply(names(sample_tbls_anno), function(sid) {
  dt_no <- data.table::as.data.table(sample_tbls_anno[[sid]])
  
  # Samples without a genome table are passed through unchanged
  if (!sid %in% names(genome_tbls)) return(dt_no)
  
  dt_g <- data.table::as.data.table(genome_tbls[[sid]])
  dt_g <- dt_g[!(read %in% dt_no$read)]
  dt_g[, set := sample_to_group(sid)]
  # Genome hits are annotated by their reference name and a fixed class
  dt_g[, correct_annotation := ref]
  dt_g[, correct_RNA_type := "genome"]
  
  data.table::rbindlist(list(dt_no, dt_g), use.names = TRUE, fill = TRUE)
})
# lapply() over names() returns an unnamed list; restore the sample IDs
names(sample_tbls_with_genome) <- names(sample_tbls_anno)

saveRDS(sample_tbls_with_genome, file.path(cfg$out_dir, "sample_tables_with_genome_annotated.rds"))

# --- E) Build count matrix (features are "orig_seq + correct_annotation")
count_mat <- make_count_matrix(sample_tbls_with_genome)
saveRDS(count_mat, file.path(cfg$out_dir, "count_matrix.rds"))

# --- F) CPM + filtering
# flt$cpm is the full CPM matrix, flt$keep the per-feature keep mask
flt <- cpm_filter(count_mat, cfg$cpm_threshold, cfg$cpm_min_samples)
saveRDS(flt$cpm,  file.path(cfg$out_dir, "cpm_matrix.rds"))
saveRDS(flt$keep, file.path(cfg$out_dir, "cpm_keep_mask.rds"))

# Save filtered matrices (drop = FALSE keeps a matrix even for a single row)
count_mat_keep <- count_mat[flt$keep, , drop = FALSE]
cpm_keep <- flt$cpm[flt$keep, , drop = FALSE]
saveRDS(count_mat_keep, file.path(cfg$out_dir, "count_matrix_cpm_filtered.rds"))
saveRDS(cpm_keep,       file.path(cfg$out_dir, "cpm_matrix_filtered.rds"))

# --- G) QC plots 
# 1) Composition per "set" (ELAC3S/GFP3S/WT3S/...)
# Draw one pie chart per `set` showing the percentage of reads in each
# consensus RNA class, and save it as a PNG.
#
# sample_tables: list of per-sample tables with columns `set` and
#                `correct_RNA_type`.
# out_png:       output file path handed to ggplot2::ggsave().
plot_composition <- function(sample_tables, out_png) {
  all_reads <- data.table::rbindlist(sample_tables, use.names = TRUE, fill = TRUE)

  # Reads per (set, class), then percentage within each set
  class_counts <- all_reads[, .N, by = .(set, correct_RNA_type)]
  class_counts[, perc := 100 * N / sum(N), by = set]

  mapping <- ggplot2::aes(x = "", y = perc, fill = correct_RNA_type)
  no_legend_title <- ggplot2::theme(legend.title = ggplot2::element_blank())

  # A stacked bar in polar coordinates renders as a pie chart
  pie <- ggplot2::ggplot(class_counts, mapping) +
    ggplot2::geom_col(color = "white") +
    ggplot2::coord_polar(theta = "y") +
    ggplot2::facet_wrap(~set, nrow = 2) +
    ggplot2::theme_void() +
    no_legend_title

  ggplot2::ggsave(out_png, pie, width = 12, height = 6, dpi = 300)
}

# Composition plot over the final tables (including genome-only reads).
plot_composition(
  sample_tbls_with_genome,
  file.path(cfg$out_dir, "rna_type_composition_by_set.png")
)

message("Done. Outputs written to: ", cfg$out_dir)

script_3.R

# Purpose
#   Evaluate integration quality on a chosen embedding (PCA):
#     - iLISI: local mixing by batch/condition (higher = better mixing)
#     - cLISI-derived purity: local purity by cell type (higher = better purity)
#     - kBET: average rejection rate of local batch-mixing tests (lower = better mixing)


# ----------------------------
# 0) Package checks
# ----------------------------
# Abort with a clear message when `pkg` is not installed.
# Returns NULL invisibly when the package is available.
check_pkg <- function(pkg) {
  installed <- requireNamespace(pkg, quietly = TRUE)
  if (installed) {
    return(invisible(NULL))
  }
  stop("Missing package: ", pkg, " (install it first).", call. = FALSE)
}
# Fail fast if any dependency of the functions below is missing.
check_pkg("Seurat")
check_pkg("lisi")   # immunogenomics/LISI (R package name: lisi)
check_pkg("kBET")
check_pkg("ggplot2")
check_pkg("dplyr")

# ----------------------------
# 1) Core metric function
# ----------------------------
# Score how well an embedding mixes conditions while keeping cell types apart.
#
# Arguments
#   obj             Seurat object holding the embedding and metadata.
#   condition_col   metadata column used as "batch" for iLISI and kBET.
#   celltype_col    metadata column used for the cLISI-derived purity.
#   reduction       name of the dimensional reduction to evaluate.
#   dims            embedding dimensions to use.
#   lisi_perplexity perplexity passed to lisi::compute_lisi().
#   kbet_k0         neighbourhood size passed to kBET.
#   kbet_max_cells  kBET runs on at most this many cells (subsample is
#                   allocated per condition, roughly proportionally).
#   seed            RNG seed set before subsampling / kBET.
#
# Value: list with n_cells, n_batches, n_celltypes, iLISI_norm_stats and
# cPUR_norm_stats (named vectors: mean, median, p10, p90), and
# kBET_mean_reject, kBET_p025, kBET_p975 (NA when kBET yields no result).
#
# Cells with NA in either label column are dropped before scoring.
# A kBET error does not abort the function: it is reported as a warning and
# the kBET fields stay NA.
score_embedding <- function(obj,
                            condition_col   = "condition",
                            celltype_col    = "final_population",
                            reduction       = "pca",
                            dims            = 1:30,
                            lisi_perplexity = 30,   # effective neighborhood size in LISI (not kNN k)
                            kbet_k0         = 15,   # neighborhood size for kBET
                            kbet_max_cells  = 5000, # subsample for speed (stratified by condition)
                            seed            = 1) {
  
  md <- obj@meta.data
  stopifnot(all(c(condition_col, celltype_col) %in% colnames(md)))
  stopifnot(reduction %in% names(obj@reductions))
  
  emb_all <- Seurat::Embeddings(obj, reduction = reduction)
  stopifnot(max(dims) <= ncol(emb_all))
  emb <- emb_all[, dims, drop = FALSE]
  
  # Keep only complete cases for the two labels
  keep <- stats::complete.cases(md[, c(condition_col, celltype_col), drop = FALSE])
  emb  <- emb[keep, , drop = FALSE]
  md   <- md[keep, , drop = FALSE]
  
  # Force factors (required for correct level counting)
  batch    <- factor(md[[condition_col]])
  celltype <- factor(md[[celltype_col]])
  
  B <- nlevels(batch)
  C <- nlevels(celltype)
  
  # ---- LISI (per-cell) ----
  # compute_lisi returns a data.frame with one column per label in label_colnames
  lisi_vals <- lisi::compute_lisi(
    X              = emb,
    meta_data      = md[, c(condition_col, celltype_col), drop = FALSE],
    label_colnames = c(condition_col, celltype_col),
    perplexity     = lisi_perplexity
  )
  
  iLISI <- lisi_vals[[condition_col]]
  cLISI <- lisi_vals[[celltype_col]]
  
  # Normalize iLISI to 0..1 where 0 = no mixing, 1 = maximal mixing given B batches
  # (max(1, .) avoids division by zero when there is a single batch)
  iLISI_norm <- (iLISI - 1) / max(1, B - 1)
  
  # Convert cLISI to a "purity" score in 0..1 where 1 = perfectly pure neighborhoods
  # (cLISI = 1 means only one cell type locally; cLISI = C means full mixing)
  cPUR_norm <- (C - cLISI) / max(1, C - 1)
  
  # Summary statistics reported for each per-cell score
  stats_fun <- function(x) {
    c(
      mean   = mean(x, na.rm = TRUE),
      median = stats::median(x, na.rm = TRUE),
      p10    = unname(stats::quantile(x, 0.10, na.rm = TRUE)),
      p90    = unname(stats::quantile(x, 0.90, na.rm = TRUE))
    )
  }
  
  # ---- kBET ----
  set.seed(seed)
  idx <- seq_len(nrow(emb))
  if (length(idx) > kbet_max_cells) {
    by_batch <- split(idx, batch)
    sizes <- vapply(by_batch, length, integer(1))
    
    # Allocate approximately proportionally (guarantee at least 1 per batch if possible)
    alloc <- pmax(1L, floor(kbet_max_cells * sizes / sum(sizes)))
    # Trim if we overshot due to pmax(1)
    while (sum(alloc) > kbet_max_cells) {
      j <- which.max(alloc)
      if (alloc[j] > 1L) alloc[j] <- alloc[j] - 1L else break
    }
    
    # Draw the per-batch samples and restore the original cell order
    idx <- sort(unlist(Map(function(ix, m) sample(ix, size = min(m, length(ix))), by_batch, alloc), use.names = FALSE))
  }
  
  X_sub     <- emb[idx, , drop = FALSE]
  batch_sub <- factor(batch[idx])
  
  # Use k0 and turn off heuristic so kBET does not silently change neighborhood size
  kbet_res <- tryCatch(
    kBET::kBET(
      df        = X_sub,
      batch     = batch_sub,
      k0        = kbet_k0,
      heuristic = FALSE,
      do.pca    = FALSE,
      plot      = FALSE,
      verbose   = FALSE
    ),
    error = function(e) {
      # Keep going with NA metrics, but tell the user why they are missing
      warning(
        "kBET failed; kBET metrics are returned as NA: ",
        conditionMessage(e),
        call. = FALSE
      )
      e
    }
  )
  
  kBET_mean_reject <- NA_real_
  kBET_p025 <- NA_real_
  kBET_p975 <- NA_real_
  
  if (!inherits(kbet_res, "error") && is.list(kbet_res)) {
    if (!is.null(kbet_res$stats) && "kBET.observed" %in% colnames(kbet_res$stats)) {
      # Per-repetition observed rejection rates
      obs <- kbet_res$stats[, "kBET.observed"]
      kBET_mean_reject <- mean(obs, na.rm = TRUE)
      kBET_p025 <- unname(stats::quantile(obs, 0.025, na.rm = TRUE))
      kBET_p975 <- unname(stats::quantile(obs, 0.975, na.rm = TRUE))
    } else if (!is.null(kbet_res$summary) && "kBET.observed" %in% names(kbet_res$summary)) {
      # Fallback: single observed rate from the summary
      # NOTE(review): if kBET's summary column holds several rows this is
      # not a scalar - confirm against the installed kBET version.
      kBET_mean_reject <- unname(kbet_res$summary[["kBET.observed"]])
    }
  }
  
  list(
    n_cells          = nrow(emb),
    n_batches        = B,
    n_celltypes      = C,
    iLISI_norm_stats = stats_fun(iLISI_norm),
    cPUR_norm_stats  = stats_fun(cPUR_norm),
    kBET_mean_reject = kBET_mean_reject,
    kBET_p025        = kBET_p025,
    kBET_p975        = kBET_p975
  )
}

# ----------------------------
# 2) Helpers
# ----------------------------
# Load a score_embedding() result from an .RData file.
#
# path:        file written with save().
# object_name: name of the object to return; when NULL, the first loaded
#              object that is a list containing iLISI_norm_stats,
#              cPUR_norm_stats and kBET_mean_reject is returned.
# The file is loaded into a private environment, so nothing leaks into the
# caller's workspace. Stops when the requested / a score-like object is absent.
read_result_rdata <- function(path, object_name = NULL) {
  sandbox <- new.env(parent = emptyenv())
  loaded <- load(path, envir = sandbox)

  if (is.null(object_name)) {
    required <- c("iLISI_norm_stats", "cPUR_norm_stats", "kBET_mean_reject")
    looks_like_score <- function(nm) {
      candidate <- sandbox[[nm]]
      is.list(candidate) && all(required %in% names(candidate))
    }

    score_names <- Filter(looks_like_score, loaded)
    if (length(score_names) == 0) {
      stop("No score-like object found in: ", path, call. = FALSE)
    }
    return(sandbox[[score_names[1]]])
  }

  if (!object_name %in% loaded) {
    stop("Object '", object_name, "' not found in: ", path, call. = FALSE)
  }
  sandbox[[object_name]]
}

# ----------------------------
# 3) Summarize and plot (three metrics)
# ----------------------------
# Flatten several score_embedding() results into one summary table.
#
# res_list: named list of score_embedding() outputs; the names are used as
#           approach labels.
# Returns a data frame with one row per approach: cell/condition/cell-type
# counts, median/p10/p90 of the two LISI-based scores, and the kBET mean
# with its 2.5% / 97.5% quantiles.
summarize_scores <- function(res_list) {
  one_row <- function(label) {
    score <- res_list[[label]]
    mixing <- score$iLISI_norm_stats
    purity <- score$cPUR_norm_stats

    data.frame(
      approach     = label,
      n_cells      = score$n_cells,
      n_conditions = score$n_batches,
      n_celltypes  = score$n_celltypes,

      iLISI_med = unname(mixing["median"]),
      iLISI_p10 = unname(mixing["p10"]),
      iLISI_p90 = unname(mixing["p90"]),

      cPUR_med  = unname(purity["median"]),
      cPUR_p10  = unname(purity["p10"]),
      cPUR_p90  = unname(purity["p90"]),

      kBET_mean = unname(score$kBET_mean_reject),
      kBET_p025 = unname(score$kBET_p025),
      kBET_p975 = unname(score$kBET_p975),

      stringsAsFactors = FALSE
    )
  }

  rows <- lapply(names(res_list), one_row)
  dplyr::bind_rows(rows)
}

# Reshape the summarize_scores() table into long form for plotting:
# one row per (approach, metric) with columns center, low, high.
make_plot_df <- function(df_sum) {
  mixing_rows <- dplyr::transmute(
    df_sum,
    approach,
    metric = "Condition mixing (iLISI, normalized)",
    center = iLISI_med, low = iLISI_p10, high = iLISI_p90
  )

  purity_rows <- dplyr::transmute(
    df_sum,
    approach,
    metric = "Cell-type purity (cLISI-derived, normalized)",
    center = cPUR_med, low = cPUR_p10, high = cPUR_p90
  )

  kbet_rows <- dplyr::transmute(
    df_sum,
    approach,
    metric = "kBET rejection rate by condition",
    center = kBET_mean, low = kBET_p025, high = kBET_p975
  )

  dplyr::bind_rows(mixing_rows, purity_rows, kbet_rows)
}

# Plot center/low/high per approach, one facet per metric.
#
# df_plot:         long table from make_plot_df().
# approach_levels: optional order of the approaches; by default the reverse
#                  order of first appearance is used.
# Returns the ggplot object (horizontal point-range plot).
plot_integration_metrics <- function(df_plot, approach_levels = NULL) {
  level_order <- if (is.null(approach_levels)) {
    rev(unique(df_plot$approach))
  } else {
    approach_levels
  }
  df_plot$approach <- factor(df_plot$approach, levels = level_order)

  # Grey range (low..high) with a black dot on the center value
  range_layer <- ggplot2::geom_pointrange(
    ggplot2::aes(ymin = low, ymax = high),
    color = "grey60",
    na.rm = TRUE
  )
  center_layer <- ggplot2::geom_point(color = "black", size = 2, na.rm = TRUE)

  ggplot2::ggplot(df_plot, ggplot2::aes(x = approach, y = center)) +
    range_layer +
    center_layer +
    ggplot2::coord_flip() +
    ggplot2::facet_wrap(~ metric, scales = "fixed") +
    ggplot2::theme_classic() +
    ggplot2::labs(x = NULL, y = NULL)
}


# res1 <- score_embedding(INTEGR_WEG0_PJ, condition_col = "condition", celltype_col = "CellType")
# res2 <- score_embedding(integrated_seurat_obj, condition_col = "condition", celltype_col = "possibly_final_anno")
# res3 <- score_embedding(result_obj, condition_col = "condition", celltype_col = "final_population")
# Save with save() as .RData, because read_result_rdata() below uses load():
# save(res1, file = "res1.RData"); save(res2, file = "res2.RData"); save(res3, file = "res3.RData")

# res_paths <- c(
#   "WT-only integrated (Harmony)"          = "G:/PhD_final/integration_presentation/res1.RData",
#   "All samples integrated (Harmony)"      = "G:/PhD_final/integration_presentation/res2.RData",
#   "All samples merged (no integration)"   = "G:/PhD_final/integration_presentation/res3.RData"
# )
# res_list <- lapply(res_paths, read_result_rdata)
# 
# df_sum  <- summarize_scores(res_list)
# df_plot <- make_plot_df(df_sum)
# 
# p <- plot_integration_metrics(df_plot, approach_levels = rev(names(res_paths)))
# 
# print(df_sum)
# print(p)

script_4.R

# Auto-annotation pipeline 

suppressPackageStartupMessages({
  library(Seurat)
  library(Matrix)
  library(dplyr)
  library(tidyr)
  library(stringr)
  library(purrr)
  library(openxlsx)
  library(igraph)
})
suppressPackageStartupMessages(library(future))
plan(sequential)

# On Windows, try to raise the R memory limit. The call is wrapped in try()
# and warnings are suppressed, so a failure here never stops the script.
if (.Platform$OS.type == "windows" && exists("memory.limit")) {
  try(suppressWarnings(memory.limit(size = 56000)), silent = TRUE)
}

# ---------------------------
# Config (tuned defaults)
# ---------------------------
# Pipeline configuration. Only the fields consumed by the helpers visible in
# this file are described in detail; the rest are tuning defaults whose
# consumers live further down the script.
cfg <- list(
  # Input / output locations
  paths = list(
    seurat_rdata   = "D:/scRNA-seq/AZ_final_obj/seurat_obj_new.RData",
    matrix_rds     = NULL,
    markers_xlsx   = "G:/PhD_final/cell_markers_curated_new_new.xlsx",
    anno_rdata     = "E:/Stringtie_anno/SM_anno/final/final_final/pfam_swiss_ncbi_merged_only_genes_dedup.RData",
    # Output workbook name carries a timestamp (minute resolution)
    out_xlsx       = sprintf(
      "G:/PhD_final/auto_annotation_%s.xlsx",
      format(Sys.time(), "%Y%m%d_%H%M")
    )
  ),
  # Checkpoints: directory, and whether to use qs (with this preset) instead of RDS
  ckpt_dir        = "G:/PhD_final/sncRNA/.auto_annot_ckpts",
  use_qs          = TRUE,
  qs_preset       = "balanced",
  
  # Assay / reduction names (also the defaults kept by diet_for_checkpoint())
  base_assay      = "SCT",
  pca_name        = "pca.auto",
  umap_name       = "umap.auto",
  # PC selection settings
  max_pcs         = 60L,
  variance_cut    = 0.90,
  knee_smooth     = 5L,
  
  # Clustering grid
  target_n_clusters = 60L,
  k_grid            = c(5L, 8L, 10L, 15L, 20L),
  # NOTE(review): 1.8 and 2.0 are already produced by seq(); the two trailing
  # values repeat them - confirm whether other values were intended.
  res_grid          = c(seq(0.4, 2.0, by = 0.2), 1.8, 2.0),
  res_init          = 0.6,
  res_max           = 10,
  grid_max_steps    = 20,
  
  tiny_frac_cut     = 0.015,
  agree_cut         = 0.75,
  
  # Sub-clustering settings
  sub_npcs          = 30L,
  seed              = 42L,
  sub_k_grid        = c(10L, 15L, 20L, 25L),
  sub_res_grid      = seq(0.5, 3.5, by = 0.25),
  sub_min_cells_for_split = 30L,
  sub_max_children  = 6L,
  sub_min_child_n    = 10L,
  sub_min_child_prop = 0.005,
  
  # Annotation / DEG settings
  annot_features_max    = 1500L,
  annot_skip_deg        = TRUE,
  deg_subsample_per_ident = 2000L,
  
  # UCell settings
  enable_ucell     = TRUE,
  ucell_min_genes  = 3L,
  ucell_cells_per_cluster = 1000L,
  ucell_max_signatures   = 300L,
  ucell_ncores     = 1L,
  
  write_round1_degs = TRUE
)
# Seed the RNG once, from the configuration
set.seed(cfg$seed)

# Optional packages
# TRUE when the optional package is installed (nothing is attached here).
# has_qs is consulted by ckpt_save() to choose between qs and RDS.
has_ucell     <- requireNamespace("UCell", quietly = TRUE)
has_fgsea     <- requireNamespace("fgsea", quietly = TRUE)
has_cellmanam <- requireNamespace("CellMaNam", quietly = TRUE)
has_qs        <- requireNamespace("qs", quietly = TRUE)
has_digest    <- requireNamespace("digest", quietly = TRUE)

# ---------------------------
# Helpers
# ---------------------------
# Null-default operator: `a` unless it is NULL, otherwise `b`.
`%||%` <- function(a, b) {
  if (is.null(a)) b else a
}
# Strip leading and trailing whitespace from each element of `x`.
trim <- function(x) {
  gsub("^\\s+|\\s+$", "", x)
}
# Canonicalise cluster labels: trim whitespace, then rewrite purely numeric
# labels ("12") and "X"-prefixed numeric labels ("X12") to "g12".
# Any other label is returned as-is.
canon_cluster <- function(v) {
  labels <- trimws(as.character(v))
  labels <- sub("^X([0-9]+)$", "g\\1", labels)
  sub("^([0-9]+)$", "g\\1", labels)
}
# Turn arbitrary labels into valid, unique Excel sheet names.
#
# x: character vector of proposed sheet names.
# Characters that Excel forbids in sheet names (* ? / \ [ ] :) are replaced
# by "_", names are cut to 31 characters, and duplicates receive a ".N"
# suffix. Returns a character vector of the same length as `x`.
#
# When the usual make.unique() result already fits in 31 characters it is
# returned unchanged. Otherwise the suffix would push a name past Excel's
# 31-character limit, so duplicates are renamed by shortening the base name
# until base + suffix fits.
sanitize_sheet <- function(x) {
  max_len <- 31L

  x <- gsub("[\\*\\?/\\\\\\[\\]:]", "_", x)
  x <- substr(x, 1, max_len)

  deduped <- make.unique(x)
  if (all(nchar(deduped) <= max_len, na.rm = TRUE)) {
    return(deduped)
  }

  # Length-aware de-duplication: try ".1", ".2", ... on a shortened base
  # until the candidate is not used by an earlier name.
  out <- character(length(x))
  for (i in seq_along(x)) {
    candidate <- x[i]
    k <- 0L
    while (candidate %in% out[seq_len(i - 1L)]) {
      k <- k + 1L
      suffix <- paste0(".", k)
      candidate <- paste0(substr(x[i], 1L, max_len - nchar(suffix)), suffix)
    }
    out[i] <- candidate
  }
  out
}
# Name of the argument Seurat::FindAllMarkers uses to pick the data layer:
# "layer" when the installed Seurat has that formal argument, else "slot".
pick_layer_arg <- function() {
  marker_args <- names(formals(Seurat::FindAllMarkers))
  if ("layer" %in% marker_args) {
    return("layer")
  }
  "slot"
}
# Look up each key in a named vector.
#
# keys:      values to translate.
# map_named: named vector; names are the lookup keys.
# Returns a vector as long as `keys`; keys without a match give NA.
apply_mapping <- function(keys, map_named) {
  pos <- match(keys, names(map_named))
  found <- which(!is.na(pos))

  mapped <- rep(NA_character_, length(keys))
  mapped[found] <- unname(map_named[pos[found]])
  mapped
}
# Collapse a vector or (nested) list into one "; "-separated string of its
# unique, non-NA values. NULL or all-NA / empty input gives "".
as_chr_collapse <- function(x) {
  if (is.null(x)) {
    return("")
  }

  flat <- if (is.list(x)) unlist(x, recursive = TRUE, use.names = FALSE) else x
  values <- unique(na.omit(as.character(flat)))

  if (length(values) == 0L) {
    return("")
  }
  paste(values, collapse = "; ")
}


# IO/ckpt ------------------------------------------------------------
# Directory used by all checkpoint helpers below.
CKPT_DIR <- cfg$ckpt_dir
# Make sure the checkpoint directory exists (created recursively, silently
# if it is already there). Always returns TRUE.
.dir_ok <- function() {
  dir.create(path = CKPT_DIR, recursive = TRUE, showWarnings = FALSE)
  TRUE
}
# Path of the qs checkpoint file for `stage`.
ckpt_path_qs <- function(stage) {
  file_name <- paste0("auto_annot_ckpt_", stage, ".qs")
  file.path(CKPT_DIR, file_name)
}
# Path of the RDS checkpoint file for `stage`.
ckpt_path_rds <- function(stage) {
  file_name <- paste0("auto_annot_ckpt_", stage, ".rds")
  file.path(CKPT_DIR, file_name)
}
# TRUE when a checkpoint for `stage` exists in either format (qs or RDS).
ckpt_has <- function(stage) {
  if (file.exists(ckpt_path_qs(stage))) {
    return(TRUE)
  }
  file.exists(ckpt_path_rds(stage))
}
# Write `value` as the checkpoint for `stage`.
# Uses qs (with cfg$qs_preset) when the qs package is installed and
# cfg$use_qs is TRUE; otherwise falls back to saveRDS().
ckpt_save <- function(stage, value) {
  .dir_ok()

  write_qs <- has_qs && isTRUE(cfg$use_qs)
  if (!write_qs) {
    return(saveRDS(value, ckpt_path_rds(stage)))
  }
  qs::qsave(value, ckpt_path_qs(stage), preset = cfg$qs_preset)
}
# Read the checkpoint for `stage`; a qs file is preferred over an RDS file
# when both exist.
ckpt_load <- function(stage) {
  qs_file <- ckpt_path_qs(stage)
  if (!file.exists(qs_file)) {
    return(readRDS(ckpt_path_rds(stage)))
  }
  qs::qread(qs_file)
}
# Open the workbook at `path` if the file exists, otherwise start a new one.
wb_load_or_new <- function(path) {
  if (!file.exists(path)) {
    return(openxlsx::createWorkbook())
  }
  openxlsx::loadWorkbook(path)
}
# Write workbook `wb` to `path`, replacing any existing file.
wb_save <- function(wb, path) {
  openxlsx::saveWorkbook(wb, file = path, overwrite = TRUE)
}
# Merge named values into the state list `st` and checkpoint the result.
#
# stage: checkpoint stage name.
# st:    current state list.
# ...:   named entries to set in `st` (unnamed arguments are ignored).
ckpt_update <- function(stage, st, ...) {
  updates <- list(...)
  for (field in names(updates)) {
    st[[field]] <- updates[[field]]
  }
  ckpt_save(stage, st)
}

# Memory diet ------------------------------------------------------------
# Shrink a Seurat object before checkpointing.
#
# obj:         Seurat object.
# keep_assay:  assay(s) to keep; also set as the default assay.
# keep_reduc:  reductions to keep.
# keep_graphs: graphs to keep (none by default).
# drop_counts, drop_scale: on SeuratObject < 5, drop the counts /
#              scale.data slots; on >= 5 only the "data" layer is requested.
# Everything not listed is removed from the assays, reductions and graphs
# slots, then Seurat::DietSeurat() is called with the arguments that match
# the installed SeuratObject version. Returns the slimmed object.
diet_for_checkpoint <- function(obj,
                                keep_assay  = cfg$base_assay,
                                keep_reduc  = c(cfg$pca_name, cfg$umap_name, paste0(cfg$umap_name, ".v2")),
                                keep_graphs = character(),
                                drop_counts = TRUE,
                                drop_scale  = TRUE) {
  DefaultAssay(obj) <- keep_assay

  # Keep only the wanted entries of a named slot list
  keep_named <- function(entries, wanted) {
    entries[intersect(names(entries), wanted)]
  }
  obj@assays     <- keep_named(obj@assays, keep_assay)
  obj@reductions <- keep_named(obj@reductions, keep_reduc)
  obj@graphs     <- keep_named(obj@graphs, keep_graphs)

  # If the version cannot be determined, assume the pre-v5 interface
  so_version <- tryCatch(
    utils::packageVersion("SeuratObject"),
    error = function(e) {
      package_version("4.0.0")
    }
  )
  is_v5 <- so_version >= package_version("5.0.0")

  if (!is_v5) {
    return(Seurat::DietSeurat(
      obj,
      assays = names(obj@assays),
      counts = !drop_counts,
      data = TRUE,
      scale.data = !drop_scale,
      dimreducs = names(obj@reductions),
      graphs = names(obj@graphs),
      features = NULL
    ))
  }

  Seurat::DietSeurat(
    obj,
    assays = names(obj@assays),
    dimreducs = names(obj@reductions),
    graphs = names(obj@graphs),
    layers = setNames(list("data"), names(obj@assays))
  )
}

# Stats helpers ----------------------------------------------------------
# Ensure an annotation result is a fully named vector with canonical
# cluster names.
#
# x:            result vector of an annotation method.
# expect_names: names to assign when `x` is unnamed (or has NA / empty
#               names) and has the same length; optional.
# method_name:  label used in the error message.
# Empty input is returned untouched. Stops when names are missing and cannot
# be filled from `expect_names`.
.named_or_fail <- function(x, expect_names, method_name) {
  if (length(x) == 0L) {
    return(x)
  }

  current <- names(x)
  names_missing <- is.null(current) || any(is.na(current) | current == "")

  if (names_missing) {
    can_fill <- !missing(expect_names) && length(x) == length(expect_names)
    if (!can_fill) {
      stop(
        sprintf(
          "Annotation method produced an unnamed vector; method=%s len=%d",
          method_name,
          length(x)
        )
      )
    }
    names(x) <- as.character(expect_names)
  }

  names(x) <- canon_cluster(names(x))
  x
}
# Run one annotation method defensively and checkpoint its result.
#
# label:           method name used in messages and in the checkpoint tag.
# fun:             zero-argument function returning a named vector of labels.
# expect_clusters: cluster IDs the result is restricted to (and used to name
#                  an unnamed result of matching length).
# ckpt_stage_tag:  prefix of the checkpoint stage ("<tag>_<label>").
# An error inside `fun` is logged and treated as an empty result. Saving the
# checkpoint is best-effort: a failure there is ignored.
# Returns the (possibly empty) named result vector.
.run_annot_method <- function(label,
                              fun,
                              expect_clusters,
                              ckpt_stage_tag = "annot") {
  message(sprintf("[annot] %s: running...", label))

  on_failure <- function(e) {
    message(sprintf("[annot] %s: ERROR -> %s", label, conditionMessage(e)))
    setNames(character(0), character(0))
  }
  result <- tryCatch(fun(), error = on_failure)

  result <- .named_or_fail(result, expect_clusters, label)
  if (length(result)) {
    # Drop labels for clusters that were not asked for
    wanted <- names(result) %in% as.character(expect_clusters)
    result <- result[wanted]
  }

  stage_tag <- paste0(ckpt_stage_tag, "_", label)
  payload <- list(
    method = label,
    result = result,
    expect_clusters = as.character(expect_clusters),
    saved_at = Sys.time()
  )
  try(ckpt_save(stage_tag, payload), silent = TRUE)

  message(sprintf("[annot] %s: %d label(s).", label, length(result)))
  result
}

# Choose how many principal components to keep from their standard deviations.
#
# stdev        PC standard deviations; non-finite and non-positive entries are
#              dropped and at most `max_pcs` leading values are used.
# variance_cut cumulative share of variance (within the retained PCs) that caps
#              the answer.
# smooth_k     half-width of the moving average applied to the variance ratios
#              before the second difference is taken (only when enough points).
# min_pcs      lower bound on the answer; also returned when no usable stdev
#              remains.
#
# Returns max(min_pcs, min(max(10, knee), variance cap)), limited to the number
# of usable PCs, where `knee` is the position of the smallest second difference
# plus one.
choose_pcs_by_knee <- function(stdev,
                               max_pcs = 50L,
                               variance_cut = 0.90,
                               smooth_k = 5L,
                               min_pcs = 25L) {
  usable <- stdev[is.finite(stdev) & stdev > 0]
  n_keep <- min(length(usable), max_pcs)
  usable <- usable[seq_len(n_keep)]
  if (length(usable) == 0L) {
    return(min_pcs)
  }

  explained <- usable^2 / sum(usable^2)
  reached <- which(cumsum(explained) >= variance_cut)
  var_cap <- if (length(reached) > 0L) reached[1] else length(explained)

  win_len <- smooth_k * 2 + 1
  curve <- explained
  if (length(curve) >= win_len) {
    curve <- stats::filter(curve, rep(1 / win_len, win_len), sides = 2)
    # The centered filter leaves NA at both ends; fall back to the raw ratios.
    edge <- is.na(curve)
    curve[edge] <- explained[edge]
  }
  knee <- which.min(diff(curve, differences = 2)) + 1L

  chosen <- max(min_pcs, min(max(10L, knee), var_cap))
  min(chosen, length(usable))
}

# Per-group column means of a sparse matrix via matrix products.
#
# X      dgCMatrix whose columns are the observations to be grouped.
# groups grouping of the columns of X (coerced to a factor, unused levels
#        dropped).
#
# Returns a Matrix object with one column per group level holding the mean of
# the corresponding columns of X; columns are named by the group levels.
avg_by_group_sparse <- function(X, groups) {
  stopifnot(inherits(X, "dgCMatrix"))
  groups <- droplevels(factor(groups))
  # One-hot membership matrix (observations x groups).
  indicator <- Matrix::sparse.model.matrix(~ groups - 1)
  group_sizes <- Matrix::colSums(indicator)
  inv_sizes <- Matrix::Diagonal(x = as.numeric(1 / group_sizes))
  group_means <- (X %*% indicator) %*% inv_sizes
  colnames(group_means) <- levels(groups)
  group_means
}
# Weighted modularity of a clustering on one of the object's stored graphs.
#
# obj        object with a `graphs` slot (list of adjacency matrices).
# graph.name name of the graph inside obj@graphs.
# membership cluster label per vertex; when named, it is reordered by the
#            graph's vertex names, otherwise it is taken positionally.
#
# The adjacency matrix is symmetrised as (S + t(S)) / 2, its diagonal is
# zeroed, and an undirected weighted igraph is built from the strictly
# positive upper-triangle entries.
# Returns NA_real_ when the graph is missing or has no positive edge,
# otherwise the value of igraph::modularity().
get_graph_modularity <- function(obj, graph.name, membership) {
  S <- obj@graphs[[graph.name]]
  if (is.null(S)) {
    return(NA_real_)
  }
  if (!methods::is(S, "dgCMatrix")) {
    S <- methods::as(S, "dgCMatrix")
  }
  S <- Matrix::drop0((S + Matrix::t(S)) / 2)
  Matrix::diag(S) <- 0
  trip <- as.data.frame(Matrix::summary(S))
  trip <- trip[trip$i < trip$j & trip$x > 0, c("i", "j", "x"), drop = FALSE]
  if (!nrow(trip)) {
    return(NA_real_)
  }
  vnames <- colnames(S)
  if (is.null(vnames)) {
    vnames <- seq_len(ncol(S))
  }
  g <- igraph::graph_from_data_frame(
    data.frame(
      from = vnames[trip$i],
      to = vnames[trip$j],
      weight = trip$x
    ),
    directed = FALSE,
    vertices = data.frame(name = vnames)
  )
  # Namespace-qualified so the helper works without igraph being attached,
  # like every other igraph call in this function.
  vertex_names <- igraph::V(g)$name
  mvec <- membership
  if (!is.null(names(mvec))) {
    mvec <- mvec[vertex_names]
  } else {
    names(mvec) <- vertex_names
  }
  memb <- as.integer(factor(mvec))
  igraph::modularity(g, membership = memb, weights = igraph::E(g)$weight)
}
# Score one clustering solution.
#
# obj, graph.name graph used for the modularity term (see
#                 get_graph_modularity()).
# idents          cluster label per cell.
# target_n        desired number of clusters, or NA to ignore closeness.
# tiny_cut        a cluster counts as "tiny" when it holds fewer than
#                 tiny_cut * n cells.
#
# Returns a one-row data.frame (n_clusters, modularity, tiny_frac, score).
# The score is -Inf when modularity is not finite or below 1e-6; otherwise
# 0.70 * closeness + 0.25 * modularity - 0.05 * min(0.4, 0.8 * tiny_frac).
score_solution <- function(obj,
                           graph.name,
                           idents,
                           target_n,
                           tiny_cut) {
  labels <- as.character(idents)
  n_cells <- length(labels)
  n_found <- length(unique(labels))
  sizes <- table(labels)
  tiny_frac <- 1
  if (length(sizes) > 0L) {
    tiny_frac <- sum(sizes < (tiny_cut * n_cells)) / length(sizes)
  }
  mod_value <- suppressWarnings(get_graph_modularity(obj, graph.name, labels))

  as_row <- function(score_value) {
    data.frame(
      n_clusters = n_found,
      modularity = mod_value,
      tiny_frac = tiny_frac,
      score = score_value
    )
  }

  if (!is.finite(mod_value) || mod_value < 1e-6) {
    return(as_row(-Inf))
  }

  closeness <- 0
  if (!is.na(target_n)) {
    closeness <- 1 - (abs(n_found - target_n) / max(target_n, n_found))
    closeness <- max(0, min(closeness, 1))
  }
  tiny_pen <- pmin(0.4, tiny_frac * 0.8)
  as_row(0.70 * closeness + 0.25 * mod_value - 0.05 * tiny_pen)
}

# Marker prep ------------------------------------------------------------
# Build the general / detailed marker reference tables from a curated sheet.
#
# cell_markers_df table of curated markers.
# gene_col        column holding the marker gene ids.
# general_col     column holding the coarse cell-population label.
# detailed_col    column holding the fine cell-population label.
# neg_col         optional column flagging negative markers (coerced with
#                 as.logical); FALSE for every row when absent.
# weight_col      optional numeric weight column; NA when absent.
#
# Returns list(general = , detailed = ): tibbles with columns gene,
# final_cluster, weight, neg (the detailed table also has parent_general).
# Rows with an empty gene, general or detailed value are dropped and duplicate
# rows removed. Stops when one of the three required columns is missing.
prepare_marker_ref <- function(cell_markers_df,
                               gene_col = "Markers_positive_SMESG",
                               general_col = "Cell_population_general",
                               detailed_col = "Cell_population_detailed",
                               neg_col = NULL,
                               weight_col = NULL) {
  df <- as.data.frame(cell_markers_df, stringsAsFactors = FALSE)

  # The previous match.arg(gene_col) only ever accepted the default column
  # name (its choices come from the single-string default), which made the
  # parameter unusable. Validate against the table instead.
  required <- c(gene_col, general_col, detailed_col)
  stopifnot(is.character(required), length(required) == 3L)
  absent <- setdiff(required, names(df))
  if (length(absent)) {
    stop(
      "prepare_marker_ref: column(s) not found in marker table: ",
      paste(absent, collapse = ", "),
      call. = FALSE
    )
  }

  # trim() is a helper that is not shown in this section.
  norm <- function(v) {
    trim(as.character(v))
  }
  gene     <- norm(df[[gene_col]])
  general  <- norm(df[[general_col]])
  detailed <- norm(df[[detailed_col]])
  neg_vec <- if (!is.null(neg_col) && neg_col %in% names(df)) {
    as.logical(df[[neg_col]])
  } else {
    FALSE
  }
  weight_vec <- if (!is.null(weight_col) && weight_col %in% names(df)) {
    suppressWarnings(as.numeric(df[[weight_col]]))
  } else {
    NA_real_
  }
  # Collapse known spelling variants of two labels.
  general[general %in% c("Protonephridia", "Protonephridia ")] <- "Protonephridia"
  detailed[detailed %in% c("Protonephridial tubule precursor ",
                           "Protonephridia tubule precursor")] <- "Protonephridial tubule precursor"
  pos <- tibble::tibble(
    gene = gene,
    general = general,
    detailed = detailed,
    neg = neg_vec,
    weight = weight_vec
  ) %>%
    dplyr::filter(gene != "" &
                    general != "" &
                    detailed != "") %>%
    dplyr::distinct()
  ref_general  <- pos %>% dplyr::transmute(gene, final_cluster = general, weight, neg)
  ref_detailed <- pos %>% dplyr::transmute(gene,
                                           final_cluster = detailed,
                                           parent_general = general,
                                           weight,
                                           neg)
  list(general = ref_general, detailed = ref_detailed)
}
# Feature names of the "data" layer of the configured base assay
# (cfg$base_assay).
.bg_genes <- function(obj) {
  data_layer <- Seurat::GetAssayData(obj, assay = cfg$base_assay, layer = "data")
  rownames(data_layer)
}

# DEG robust + cache ------------------------------------------------------
# Positive marker genes per group, one Seurat::FindMarkers call per level.
#
# obj                Seurat object.
# group_col          meta.data column used as identity.
# features_whitelist genes to test. When NULL: (variable features + genes in
#                    obj@misc$marker_genes) restricted to the "data" layer of
#                    cfg$base_assay, or every gene of that layer when the
#                    intersection is empty. Variable features are computed
#                    (nfeatures = 3000) when none are stored.
#
# Returns a tibble (cluster, gene, avg_log2FC, p_val_adj) keeping rows with
# finite values, p_val_adj < 0.05 and avg_log2FC > 0. An empty tibble with the
# same columns is returned when nothing passes. A FindMarkers failure for one
# level is logged and that level is skipped.
compute_degs_robust <- function(obj, group_col, features_whitelist = NULL) {
  DefaultAssay(obj) <- cfg$base_assay
  stopifnot(group_col %in% colnames(obj@meta.data))

  empty_degs <- function() {
    tibble::tibble(
      cluster = character(),
      gene = character(),
      avg_log2FC = double(),
      p_val_adj = double()
    )
  }
  # First column of `df` named in `cands`, else `default` recycled per row.
  # Column names differ between Seurat versions.
  pick <- function(df, cands, default) {
    hit <- intersect(cands, names(df))
    if (length(hit)) {
      return(df[[hit[1]]])
    }
    rep(default, nrow(df))
  }

  if (is.null(features_whitelist)) {
    bg <- rownames(Seurat::GetAssayData(obj, assay = cfg$base_assay, layer = "data"))
    hvgs <- tryCatch(VariableFeatures(obj), error = function(e) character(0))
    if (!length(hvgs)) {
      obj <- FindVariableFeatures(
        obj,
        assay = cfg$base_assay,
        nfeatures = 3000,
        verbose = FALSE
      )
      hvgs <- VariableFeatures(obj)
    }
    misc_markers <- tryCatch(
      unique(unlist(obj@misc$marker_genes)),
      error = function(e) character(0)
    )
    features_whitelist <- unique(intersect(union(hvgs, misc_markers), bg))
    if (!length(features_whitelist)) {
      features_whitelist <- bg
    }
  }

  layer_or_slot <- pick_layer_arg()
  # `obj` is a local copy, so the identity change below never reaches the
  # caller and needs no restore.
  Idents(obj) <- obj[[group_col]][, 1]
  lv <- levels(Idents(obj))
  if (!length(lv)) {
    message("compute_degs_robust: no levels in '", group_col, "'.")
    return(empty_degs())
  }

  pieces <- lapply(lv, function(cl) {
    args <- list(
      object = obj,
      ident.1 = cl,
      only.pos = TRUE,
      min.pct = 0.20,
      logfc.threshold = 0.25,
      test.use = "wilcox",
      verbose = FALSE,
      random.seed = cfg$seed,
      features = features_whitelist
    )
    if (!is.null(cfg$deg_subsample_per_ident)) {
      args$max.cells.per.ident <- cfg$deg_subsample_per_ident
    }
    if (layer_or_slot == "layer") {
      args$layer <- "data"
    } else {
      args$slot <- "data"
    }
    fm <- tryCatch(
      do.call(Seurat::FindMarkers, args),
      error = function(e) {
        # Best-effort per cluster, but leave a trace instead of silently
        # producing "no DEGs" for this level.
        message(sprintf(
          "compute_degs_robust: FindMarkers failed for '%s' -> %s",
          cl,
          conditionMessage(e)
        ))
        NULL
      }
    )
    if (is.null(fm) || !nrow(fm)) {
      return(NULL)
    }
    tibble::tibble(
      cluster    = canon_cluster(cl),
      gene       = rownames(fm),
      avg_log2FC = suppressWarnings(pick(
        fm, c("avg_log2FC", "avg_logFC", "log2FC"), NA_real_
      )),
      p_val_adj  = suppressWarnings(pick(
        fm,
        c("p_val_adj", "p_val.adj", "p_val_adj_fdr", "p_val"),
        1
      ))
    ) %>% dplyr::filter(is.finite(avg_log2FC),
                        is.finite(p_val_adj),
                        p_val_adj < 0.05,
                        avg_log2FC > 0)
  })

  res <- dplyr::bind_rows(pieces)
  if (!nrow(res)) {
    message("compute_degs_robust: no DEGs passed filters.")
    return(empty_degs())
  }
  res
}
# Short fingerprint of a character vector.
#
# The elements of `x` are joined with "|". When the global flag `has_digest`
# is TRUE the string is hashed with digest (xxhash64); otherwise the sum of
# its code points, modulo 2^31, is printed as "h" + 8 hex digits.
.safe_hash <- function(x) {
  joined <- paste(x, collapse = "|")
  if (has_digest) {
    return(digest::digest(joined, algo = "xxhash64"))
  }
  code_sum <- as.integer(sum(utf8ToInt(joined)))
  sprintf("h%08x", abs(code_sum %% 2^31))
}
# Cache tag for a DEG computation.
#
# The tag is .safe_hash() of: number of cells, number of groups, the group
# sizes sorted decreasingly, the number of distinct features and the first 200
# of the sorted distinct feature names.
.deg_ckpt_tag <- function(obj, group_col, features = NULL) {
  membership <- canon_cluster(as.character(obj[[group_col]][, 1]))
  group_sizes <- sort(as.integer(table(membership)), decreasing = TRUE)
  feature_ids <- sort(unique(as.character(features %||% character(0))))
  n_stub <- min(200L, length(feature_ids))
  fingerprint <- c(
    sprintf("n=%d", length(membership)),
    sprintf("k=%d", length(group_sizes)),
    paste0("sz:", paste(group_sizes, collapse = ",")),
    sprintf("f=%d", length(feature_ids)),
    paste0("feat:", paste(feature_ids[seq_len(n_stub)], collapse = ","))
  )
  .safe_hash(fingerprint)
}
# Path of the DEG checkpoint file "deg_<prefix>_<tag>.rds" inside CKPT_DIR.
.deg_ckpt_file <- function(prefix, tag) {
  file_name <- sprintf("deg_%s_%s.rds", prefix, tag)
  file.path(CKPT_DIR, file_name)
}
# Write a DEG table to its checkpoint file.
#
# The file location is derived from `prefix` and the tag computed by
# .deg_ckpt_tag(obj, group_col, features). The stored list holds saved_at,
# prefix, group_col, tag, features and degs, followed by the entries of
# `extra`.
deg_ckpt_save <- function(prefix,
                          obj,
                          group_col,
                          degs,
                          features = NULL,
                          extra = list()) {
  .dir_ok()
  tag <- .deg_ckpt_tag(obj, group_col, features)
  record <- list(
    saved_at = Sys.time(),
    prefix = prefix,
    group_col = group_col,
    tag = tag,
    features = features,
    degs = degs
  )
  saveRDS(c(record, extra), .deg_ckpt_file(prefix, tag))
}
# Read a DEG table back from its checkpoint file.
#
# Returns NULL when the file does not exist or does not contain a list with a
# `degs` entry; otherwise returns that entry.
deg_ckpt_load <- function(prefix, obj, group_col, features = NULL) {
  path <- .deg_ckpt_file(prefix, .deg_ckpt_tag(obj, group_col, features))
  if (!file.exists(path)) {
    return(NULL)
  }
  stored <- readRDS(path)
  if (is.list(stored) && !is.null(stored$degs)) {
    stored$degs
  } else {
    NULL
  }
}

# Neighbors/Clustering ----------------------------------------------------
# Build the neighbour graph "<base_assay>_snn_k<k>" unless the object already
# stores a graph under that name.
#
# dims      number of leading dimensions of `reduction` to use.
# k         k.param passed to FindNeighbors.
# Returns the (possibly updated) object.
run_neighbors_if_needed <- function(obj, dims, k, reduction = cfg$pca_name) {
  gname <- paste0(cfg$base_assay, "_snn_k", k)
  if (!is.null(obj@graphs[[gname]])) {
    return(obj)
  }
  FindNeighbors(
    obj,
    reduction = reduction,
    dims = 1:dims,
    k.param = k,
    graph.name = gname,
    verbose = FALSE
  )
}
# FindClusters with algorithm fallback.
#
# Tries algorithm 4, then 3, then 1 and returns the first result that does not
# raise an error. When all three fail, stops with a message naming the graph
# and resolution and carrying the message of the last underlying error.
safe_findclusters <- function(obj, graph.name, resolution, seed = cfg$seed) {
  last_error <- NULL
  for (alg in c(4, 3, 1)) {
    res <- tryCatch(
      FindClusters(
        obj,
        graph.name = graph.name,
        resolution = resolution,
        algorithm = alg,
        random.seed = seed,
        verbose = FALSE
      ),
      error = function(e) e
    )
    if (!inherits(res, "error")) {
      return(res)
    }
    last_error <- res
  }
  # Keep the historical message prefix; append the cause so failures are
  # diagnosable.
  stop(
    "FindClusters failed for graph=",
    graph.name,
    " res=",
    resolution,
    " (last error: ",
    conditionMessage(last_error),
    ")"
  )
}

# Knee-based resolution picker -------------------------------------------
# Pick a clustering resolution from a (resolution, n_clusters) curve.
#
# res_vals  resolutions that were tried.
# nclu_vals number of clusters obtained at each resolution.
# target_n  desired number of clusters, used only by the fallback.
#
# The curve is sorted by resolution and made non-decreasing with cummax().
# The chosen resolution is the last point of the triple with the most negative
# second difference (first one on ties). With fewer than three points, or when
# no second difference is available, the fallback applies: the resolution
# whose cluster count is closest to `target_n`, or the middle resolution when
# `target_n` is NA.
pick_res_by_knee <- function(res_vals, nclu_vals, target_n = NA_integer_) {
  stopifnot(length(res_vals) == length(nclu_vals), length(res_vals) > 0)
  o <- order(res_vals)
  x <- as.numeric(res_vals[o])
  y <- cummax(as.numeric(nclu_vals[o]))

  # d2[j] is built from points j, j + 1 and j + 2.
  d2 <- diff(y, differences = 2)
  usable <- !is.na(d2)
  if (any(usable)) {
    d2[!usable] <- Inf
    return(x[which.min(d2) + 2L])
  }

  # Previously unreachable: which.max() over an all -Inf vector returned 1, so
  # short curves always yielded the smallest resolution and ignored target_n.
  if (is.na(target_n)) {
    return(x[ceiling(length(x) / 2)])
  }
  x[which.min(abs(y - target_n))]
}

# Grid search with stability & memory care ---------------------------
# Search neighbourhood size (k) and resolution for the best-scoring clustering.
#
# obj       Seurat object.
# dims      number of reduction dimensions handed to run_neighbors_if_needed().
# target_n  desired number of clusters; NA switches to a plain resolution
#           sweep scored by score_solution().
# k_grid    k values to try.
# res_init  starting resolution.
# res_max   growth of the upper bracket stops once it reaches this value.
# max_steps cap on bracket-growth steps and on bisection steps.
#
# For each k (target_n given): the resolution is multiplied by 1.5 until the
# cluster count reaches target_n, then bisected until the bracket is narrower
# than 0.05; the bracket end whose count is closest to target_n is that k's
# candidate. If the target is never reached, the closer of the first and last
# evaluation is the candidate. Across k, the candidate with the highest score
# wins (first one on ties).
#
# Returns list(obj, stats, keep_graph): `obj` carries the winning identities
# (also in `seurat_clusters`), only the graph named `keep_graph`, and every
# evaluated record in obj@misc$grid_diag; `stats` is the winning record.
grid_search_clusters <- function(obj,
                                 dims,
                                 target_n = cfg$target_n_clusters,
                                 k_grid = c(5L, 8L, 10L, 15L, 20L),
                                 res_init = 0.6,
                                 res_max  = 10,
                                 max_steps = 20) {
  attempts <- list()
  best <- NULL        # winning evaluation so far: list(rec = , memb = )
  best_k <- NULL
  best_graph <- NULL  # graph of the winning k, re-attached to `obj` at the end

  for (k in k_grid) {
    obj_k <- run_neighbors_if_needed(obj, dims, k)
    gname <- paste0(cfg$base_assay, "_snn_k", k)

    # Cluster at resolution r and score it. Only the score record and the
    # membership are kept; the clustered object itself is not retained.
    eval_r <- function(r) {
      x <- safe_findclusters(
        obj_k,
        graph.name = gname,
        resolution = r,
        seed = cfg$seed
      )
      memb <- as.character(Idents(x))
      names(memb) <- colnames(x)
      rec <- score_solution(x, gname, memb, target_n, cfg$tiny_frac_cut)
      rec$k <- k
      rec$resolution <- r
      list(rec = rec, memb = memb)
    }

    r_lo <- res_init
    e_lo <- eval_r(r_lo)
    n_lo <- e_lo$rec$n_clusters
    attempts[[length(attempts) + 1]] <- e_lo$rec

    if (is.na(target_n)) {
      # No target: sweep resolutions and keep the best score for this k.
      # The first grid point equals res_init and was evaluated above.
      r_seq <- seq(res_init, min(res_max, res_init + 4), by = 0.4)
      cand <- e_lo
      for (r in r_seq[-1]) {
        e <- eval_r(r)
        attempts[[length(attempts) + 1]] <- e$rec
        if (e$rec$score > cand$rec$score) {
          cand <- e
        }
      }
    } else {
      # Grow the upper bracket until the target count is reached.
      r_hi <- r_lo
      e_hi <- e_lo
      n_hi <- n_lo
      step <- 0
      while (n_hi < target_n && r_hi < res_max && step < max_steps) {
        r_hi <- r_hi * 1.5
        if (r_hi <= r_lo) {
          r_hi <- r_lo + 0.2
        }
        e_hi <- eval_r(r_hi)
        attempts[[length(attempts) + 1]] <- e_hi$rec
        n_hi <- e_hi$rec$n_clusters
        step <- step + 1
      }

      if (n_hi < target_n) {
        # Target never reached: take the closer of the two ends.
        cand <- if (abs(n_lo - target_n) <= abs(n_hi - target_n)) e_lo else e_hi
      } else {
        l_r <- r_lo
        l_e <- e_lo
        l_n <- n_lo
        h_r <- r_hi
        h_e <- e_hi
        h_n <- n_hi

        it <- 0
        while (it < max_steps && (abs(h_r - l_r) > 0.05)) {
          it <- it + 1
          m_r <- (l_r + h_r) / 2
          m_e <- eval_r(m_r)
          attempts[[length(attempts) + 1]] <- m_e$rec
          m_n <- m_e$rec$n_clusters
          if (m_n < target_n) {
            l_r <- m_r
            l_e <- m_e
            l_n <- m_n
          } else {
            h_r <- m_r
            h_e <- m_e
            h_n <- m_n
          }
        }

        cand <- if (abs(l_n - target_n) <= abs(h_n - target_n)) l_e else h_e
      }
    }

    if (is.null(best) || cand$rec$score > best$rec$score) {
      best <- cand
      best_k <- k
      # run_neighbors_if_needed() worked on the copy `obj_k`; remember the
      # graph so the returned object really contains `keep_graph`.
      best_graph <- obj_k@graphs[[gname]]
    }
    gc(FALSE)
  }

  stopifnot(!is.null(best))
  best_membership <- best$memb
  cl_fac <- factor(best_membership, levels = unique(best_membership))
  names(cl_fac) <- names(best_membership)
  Idents(obj) <- cl_fac[colnames(obj)]
  obj$seurat_clusters <- Idents(obj)

  keep_graph <- paste0(cfg$base_assay, "_snn_k", best_k)
  if (!is.null(best_graph)) {
    obj@graphs[[keep_graph]] <- best_graph
  }
  obj@graphs <- obj@graphs[intersect(names(obj@graphs), keep_graph)]

  obj@misc$grid_diag <- dplyr::bind_rows(attempts)

  list(obj = obj,
       stats = best$rec,
       keep_graph = keep_graph)
}

# Annotation methods ------------------------------------------------------
# Label each column of an average-expression matrix by marker voting.
#
# avg_mat       genes x clusters matrix (base matrix or Matrix object).
# marker_ref    table with columns gene and final_cluster.
# curated_genes genes considered; only those present in rownames(avg_mat).
# top_n         number of top genes (by row-wise z-score) taken per column.
#
# For every column, the `top_n` highest z-scored curated genes are looked up
# in `marker_ref` and the most frequent final_cluster wins ("Unknown" when
# none of them is listed). Returns a character vector named by the columns,
# or an empty named character vector when no curated gene is present.
annotate_avgexp_matrix <- function(avg_mat,
                                   marker_ref,
                                   curated_genes,
                                   top_n = 10) {
  stopifnot(is.matrix(avg_mat) || inherits(avg_mat, "Matrix"))
  usable <- intersect(curated_genes, rownames(avg_mat))
  if (length(usable) == 0L) {
    return(setNames(character(0), character(0)))
  }
  dense <- as.matrix(avg_mat[usable, , drop = FALSE])
  z <- t(scale(t(dense)))
  z[is.na(z)] <- 0

  label_column <- function(j) {
    ranked <- sort(z[, j], decreasing = TRUE)
    leaders <- head(names(ranked), top_n)
    hits <- dplyr::filter(marker_ref, gene %in% leaders)
    votes <- dplyr::count(hits, final_cluster, sort = TRUE)
    if (nrow(votes) == 0) {
      "Unknown"
    } else {
      votes$final_cluster[1]
    }
  }

  labels <- vapply(seq_len(ncol(z)), label_column, FUN.VALUE = character(1))
  names(labels) <- colnames(z)
  labels
}
# Label clusters by hypergeometric enrichment of their DEGs in marker sets.
#
# seurat_clusters_DEG table with columns cluster and gene.
# marker_ref          table with columns gene and final_cluster.
# bg_genes            gene universe; DEGs and markers outside it are ignored.
#
# For each cluster, every cell type gets the upper-tail hypergeometric p-value
# of the overlap between the cluster's DEGs and the type's markers; p-values
# are BH-adjusted within the cluster and the type with the smallest adjusted
# p-value is returned (first one on ties). "Unknown" is returned when the
# reference has no cell type. Returns a character vector named by cluster, or
# an empty named character vector when there are no DEG rows.
annotate_hypergeom <- function(seurat_clusters_DEG,
                               marker_ref,
                               bg_genes) {
  bg_genes <- unique(bg_genes)
  if (!nrow(seurat_clusters_DEG)) {
    return(setNames(character(0), character(0)))
  }
  clusters <- unique(seurat_clusters_DEG$cluster)
  total <- length(bg_genes)

  # Marker sets do not depend on the cluster: build them once.
  cell_types <- unique(as.character(marker_ref$final_cluster))
  marker_sets <- lapply(cell_types, function(ct) {
    intersect(unique(marker_ref$gene[marker_ref$final_cluster == ct]), bg_genes)
  })

  res <- vapply(clusters, function(cl) {
    if (!length(cell_types)) {
      return("Unknown")
    }
    cl_genes <- intersect(
      unique(seurat_clusters_DEG$gene[seurat_clusters_DEG$cluster == cl]),
      bg_genes
    )
    k <- length(cl_genes)
    pvals <- vapply(marker_sets, function(ct_genes) {
      m <- length(ct_genes)
      overlap <- length(intersect(cl_genes, ct_genes))
      stats::phyper(overlap - 1, m, total - m, k, lower.tail = FALSE)
    }, numeric(1))
    best <- which.min(stats::p.adjust(pvals, method = "BH"))
    if (length(best)) cell_types[best] else "Unknown"
  }, character(1), USE.NAMES = FALSE)
  names(res) <- clusters
  res
}
# Label clusters by majority vote of their DEGs over the marker reference.
#
# DEGs are joined to `marker_ref` on gene; for every cluster the final_cluster
# matched by the largest number of joined rows is returned (one row per
# cluster, ties not kept). Returns labels named by cluster, or an empty named
# character vector when there are no DEG rows.
annotate_majority <- function(seurat_clusters_DEG, marker_ref) {
  if (nrow(seurat_clusters_DEG) == 0L) {
    return(setNames(character(0), character(0)))
  }
  joined <- dplyr::inner_join(
    seurat_clusters_DEG,
    marker_ref,
    by = "gene",
    relationship = "many-to-many"
  )
  tallies <- dplyr::summarise(
    dplyr::group_by(joined, cluster, final_cluster),
    n = dplyr::n(),
    .groups = "drop"
  )
  winners <- dplyr::slice_max(
    dplyr::group_by(tallies, cluster),
    n,
    n = 1,
    with_ties = FALSE
  )
  setNames(winners$final_cluster, winners$cluster)
}
# Label clusters by a weighted log fold-change score over the marker reference.
#
# seurat_clusters_DEG table with columns cluster, gene, avg_log2FC and
#                     optionally p_val_adj.
# marker_ref          table with columns gene, final_cluster and optionally
#                     neg (negative-marker flag).
# gene_specificity    optional named numeric vector (by gene); values are
#                     floored at 0.2 and multiply the fold change.
# use_padj_weight     weight rows by -log10(p_val_adj) when that column exists.
# neg_penalty         amount subtracted from the row score of negative markers.
#
# Row score = max(0, avg_log2FC) * specificity - penalty. Per (cluster,
# final_cluster) the weighted mean of the row scores is computed and the
# best-scoring final_cluster per cluster is returned, named by cluster.
# Returns an empty named character vector when there are no DEG rows.
annotate_logfc <- function(seurat_clusters_DEG,
                           marker_ref,
                           gene_specificity = NULL,
                           use_padj_weight = TRUE,
                           neg_penalty = 0.2) {
  if (!nrow(seurat_clusters_DEG)) {
    return(setNames(character(0), character(0)))
  }
  df <- dplyr::inner_join(seurat_clusters_DEG,
                          marker_ref,
                          by = "gene",
                          relationship = "many-to-many")

  if (use_padj_weight && "p_val_adj" %in% names(seurat_clusters_DEG)) {
    # Floor (not cap) the p-value so that p = 0 yields a finite weight; the
    # former pmin() turned every p-value into <= 1e-300.
    df$w <- pmax(0, -log10(pmax(df$p_val_adj, 1e-300)))
  } else {
    df$w <- 1
  }

  if (!is.null(gene_specificity)) {
    df$spec <- pmax(0.2, gene_specificity[match(df$gene, names(gene_specificity))])
  } else {
    df$spec <- 1
  }

  if ("neg" %in% names(df) && any(!is.na(df$neg))) {
    # Element-wise test; isTRUE() on a whole column is always FALSE, so the
    # penalty was never applied. NA flags count as "not negative".
    df$neg_w <- ifelse(as.logical(df$neg) %in% TRUE, -neg_penalty, 0)
  } else {
    df$neg_w <- 0
  }

  # Weighted mean ignoring missing pairs; falls back to the plain mean when
  # all weights are zero (e.g. every p_val_adj equals 1).
  wmean <- function(x, w) {
    ok <- !is.na(x) & !is.na(w)
    x <- x[ok]
    w <- w[ok]
    if (!length(x)) {
      return(NA_real_)
    }
    if (sum(w) <= 0) {
      return(mean(x))
    }
    sum(x * w) / sum(w)
  }

  df <- df %>%
    dplyr::group_by(cluster, final_cluster) %>%
    dplyr::summarise(
      score = wmean(pmax(0, avg_log2FC) * spec + neg_w, w),
      .groups = "drop"
    ) %>%
    dplyr::group_by(cluster) %>%
    dplyr::slice_max(score, n = 1, with_ties = FALSE) %>%
    dplyr::ungroup()
  setNames(df$final_cluster, df$cluster)
}
# Label clusters with the CellMaNam workflow
# (calc_occurrence -> select_top_occ -> get_annotation -> cell_typing).
#
# seurat_obj          Seurat object; only the number of genes in the "data"
#                     layer of cfg$base_assay is read from it (max_genes).
# seurat_clusters_DEG table with columns cluster and gene.
# marker_ref          table with columns gene and final_cluster.
# top_n, p_val, level forwarded to select_top_occ() / cell_typing().
#
# Every CellMaNam call is wrapped so that an error, a NULL or a zero-row
# result ends the function with an empty named character vector; the same
# happens when the global flag `has_cellmanam` is not TRUE, when there are no
# DEG rows, or when the typing table lacks the columns annotation, full_names
# or completed. Otherwise, per annotation the rows with the highest
# `completed` are kept and their full_names are returned, named by
# canon_cluster(annotation).
annotate_cellmanam <- function(seurat_obj,
                               seurat_clusters_DEG,
                               marker_ref,
                               top_n = 2,
                               p_val = 0.01,
                               level = 1) {
  no_labels <- setNames(character(0), character(0))
  or_null <- function(expr) {
    tryCatch(expr, error = function(e) NULL)
  }
  is_blank <- function(tbl) {
    is.null(tbl) || !nrow(tbl)
  }

  if (!isTRUE(has_cellmanam)) {
    return(no_labels)
  }
  if (is_blank(seurat_clusters_DEG)) {
    return(no_labels)
  }
  DefaultAssay(seurat_obj) <- cfg$base_assay

  occ <- or_null(
    CellMaNam::calc_occurrence(
      markers_data  = marker_ref,
      features_col  = "gene",
      cell_column   = "final_cluster"
    )
  )
  if (is_blank(occ)) {
    return(no_labels)
  }

  occ_top <- or_null(CellMaNam::select_top_occ(occ, top_n = top_n))
  if (is_blank(occ_top)) {
    return(no_labels)
  }

  ann_tbl <- or_null(
    CellMaNam::get_annotation(
      cell_markers = dplyr::select(
        seurat_clusters_DEG,
        cell_annotation = cluster,
        markers = gene
      ),
      markers_occ  = occ_top,
      max_genes    = nrow(
        Seurat::GetAssayData(seurat_obj, assay = cfg$base_assay, layer = "data")
      )
    )
  )
  if (is_blank(ann_tbl)) {
    return(no_labels)
  }

  cell_types <- or_null(
    CellMaNam::cell_typing(
      annotation_data = ann_tbl,
      hierarchy_data = NULL,
      p_val = p_val,
      level = level,
      hierarchy = FALSE
    )
  )
  if (is_blank(cell_types)) {
    return(no_labels)
  }
  if (!all(c("annotation", "full_names", "completed") %in% names(cell_types))) {
    return(no_labels)
  }

  top_rows <- dplyr::ungroup(
    dplyr::filter(
      dplyr::group_by(cell_types, annotation),
      completed == max(completed, na.rm = TRUE)
    )
  )
  picked <- dplyr::select(top_rows, cluster = annotation, annotation = full_names)
  if (!nrow(picked)) {
    return(no_labels)
  }
  setNames(picked$annotation, canon_cluster(picked$cluster))
}
# Per-cluster set of expressed genes.
#
# seurat_obj     Seurat object; the "data" layer of cfg$base_assay is used.
# cluster_idents meta.data column holding the cluster of every cell
#                (passed through canon_cluster()).
# min_cells      a gene is kept for a cluster when it is > 0 in at least this
#                many of the cluster's cells.
#
# Returns a list, one element per cluster, of gene names.
cluster_expressed_bg <- function(seurat_obj, cluster_idents, min_cells = 5) {
  expr <- Seurat::GetAssayData(seurat_obj, assay = cfg$base_assay, layer = "data")
  membership <- canon_cluster(seurat_obj[[cluster_idents]][, 1])
  if (is.null(names(membership))) {
    names(membership) <- colnames(seurat_obj)
  }
  cells_by_cluster <- split(names(membership), membership)
  lapply(cells_by_cluster, function(cells) {
    if (length(cells) == 0L) {
      return(character(0))
    }
    n_expressing <- Matrix::rowSums(expr[, cells, drop = FALSE] > 0)
    rownames(expr)[n_expressing >= min_cells]
  })
}
# Hypergeometric cluster labelling with a cluster-specific gene universe.
#
# Same test as annotate_hypergeom(), but for each cluster the background is
# the set of genes returned by cluster_expressed_bg() for that cluster.
#
# seurat_obj          Seurat object used to derive the per-cluster background.
# seurat_clusters_DEG table with columns cluster and gene.
# marker_ref          table with columns gene and final_cluster.
# cluster_idents      meta.data column holding the clusters.
# min_cells_bg        forwarded as `min_cells` to cluster_expressed_bg().
#
# Returns a character vector named by canon_cluster(cluster). A cluster gets
# "Unknown" when it has no background genes, no DEG inside the background, or
# when the reference holds no cell type. Returns an empty named character
# vector when there are no DEG rows.
annotate_hypergeomX <- function(seurat_obj,
                                seurat_clusters_DEG,
                                marker_ref,
                                cluster_idents,
                                min_cells_bg = 5) {
  if (!nrow(seurat_clusters_DEG)) {
    return(setNames(character(0), character(0)))
  }
  seurat_clusters_DEG$cluster <- canon_cluster(as.character(seurat_clusters_DEG$cluster))
  clusters <- unique(seurat_clusters_DEG$cluster)
  bg_by_cluster <- cluster_expressed_bg(seurat_obj, cluster_idents,
                                        min_cells = min_cells_bg)

  # Marker genes per cell type do not depend on the cluster: build them once.
  cell_types <- unique(as.character(marker_ref$final_cluster))
  markers_by_type <- lapply(cell_types, function(ct) {
    unique(marker_ref$gene[marker_ref$final_cluster == ct])
  })

  vals <- vapply(clusters, function(cl) {
    bg_genes <- unique(bg_by_cluster[[cl]])
    # An empty reference used to hit `!nrow(NULL)` and abort with
    # "argument is of length zero"; treat it like any other no-evidence case.
    if (!length(bg_genes) || !length(cell_types)) {
      return("Unknown")
    }
    cl_genes <- intersect(
      unique(seurat_clusters_DEG$gene[seurat_clusters_DEG$cluster == cl]),
      bg_genes
    )
    if (!length(cl_genes)) {
      return("Unknown")
    }
    k <- length(cl_genes)
    pvals <- vapply(markers_by_type, function(type_markers) {
      ct_genes <- intersect(type_markers, bg_genes)
      m <- length(ct_genes)
      overlap <- length(intersect(cl_genes, ct_genes))
      stats::phyper(overlap - 1, m, length(bg_genes) - m, k, lower.tail = FALSE)
    }, numeric(1))
    best <- which.min(stats::p.adjust(pvals, method = "BH"))
    if (length(best)) cell_types[best] else "Unknown"
  }, character(1), USE.NAMES = FALSE)
  names(vals) <- clusters
  vals
}
# Label clusters by gene-set enrichment (fgsea) of marker sets.
#
# seurat_clusters_DEG table with columns cluster, gene and avg_log2FC.
# marker_ref          table with columns gene and final_cluster.
# minSize, maxSize    bounds on the number of a set's genes that are present
#                     in the cluster's ranking.
#
# Per cluster, genes are ranked by their maximum avg_log2FC, marker sets are
# restricted to ranked genes and filtered by size, and fgsea is run with
# nperm = 2000. The pathway with the smallest padj (largest |NES| on ties) is
# the label; "Unknown" when nothing can be tested. Returns labels named by
# cluster, or an empty named character vector when the global flag
# `has_fgsea` is FALSE or there are no DEG rows.
annotate_fgsea <- function(seurat_clusters_DEG,
                           marker_ref,
                           minSize = 3,
                           maxSize = 500) {
  if (!has_fgsea || !nrow(seurat_clusters_DEG)) {
    return(setNames(character(0), character(0)))
  }
  genesets <- split(marker_ref$gene, marker_ref$final_cluster)

  label_cluster <- function(df_cl) {
    ranks <- df_cl$avg_log2FC
    names(ranks) <- df_cl$gene
    if (length(ranks) == 0L) {
      return("Unknown")
    }
    # One value per gene: its largest fold change.
    ranks <- sort(tapply(ranks, names(ranks), max), decreasing = TRUE)
    candidates <- lapply(genesets, intersect, names(ranks))
    sizes <- lengths(candidates)
    candidates <- candidates[sizes >= minSize & sizes <= maxSize]
    if (length(candidates) == 0L) {
      return("Unknown")
    }
    gsr <- suppressWarnings(fgsea::fgsea(
      pathways = candidates,
      stats = ranks,
      nperm = 2000
    ))
    if (nrow(gsr) == 0) {
      return("Unknown")
    }
    ranking <- order(gsr$padj, -abs(gsr$NES))
    as.character(gsr$pathway[ranking][1])
  }

  per_cluster <- lapply(
    split(seurat_clusters_DEG, seurat_clusters_DEG$cluster),
    label_cluster
  )
  labs <- unlist(per_cluster)
  names(labs) <- names(per_cluster)
  labs
}

# ---- UCell lean ----
# Label clusters from UCell signature scores computed on a per-cluster
# subsample of cells.
#
# seurat_obj        Seurat object.
# marker_ref        table with columns gene and final_cluster (one signature
#                   per final_cluster).
# cluster_idents    meta.data column holding the clusters.
# min_genes         minimum number of a signature's genes that must be
#                   present in the assay.
# cells_per_cluster maximum number of cells sampled per cluster.
# max_signatures    keep at most this many signatures (largest first).
# ncores            forwarded to UCell::AddModuleScore_UCell().
#
# Returns list(obj, labels). `labels` is a character vector named by cluster
# holding, for each cluster, the signature with the highest median score. It
# is empty (and a message is emitted) whenever a precondition fails: UCell
# unavailable/disabled, cluster column missing, no genes, no usable
# signatures, no sampled cells, or no score columns produced.
fast_ucell_labels_lean <- function(seurat_obj,
                                   marker_ref,
                                   cluster_idents,
                                   min_genes   = cfg$ucell_min_genes,
                                   cells_per_cluster = cfg$ucell_cells_per_cluster,
                                   max_signatures    = cfg$ucell_max_signatures,
                                   ncores = cfg$ucell_ncores) {
  # Log the reason and hand back the object with no labels.
  skip <- function(...) {
    message(...)
    list(
      obj = seurat_obj,
      labels = setNames(character(0), character(0))
    )
  }

  if (!has_ucell || !isTRUE(cfg$enable_ucell)) {
    return(skip("[annot][ucell] UCell unavailable or disabled; skipping."))
  }
  if (!cluster_idents %in% colnames(seurat_obj@meta.data)) {
    return(skip("[annot][ucell] Cluster key '",
                cluster_idents,
                "' not found; skipping."))
  }

  DefaultAssay(seurat_obj) <- cfg$base_assay
  bg <- rownames(Seurat::GetAssayData(seurat_obj, assay = cfg$base_assay, layer = "data"))
  if (length(bg) == 0L) {
    return(skip("[annot][ucell] No background genes in assay; skipping."))
  }

  # One signature per final_cluster, restricted to genes present in the assay.
  sigs <- lapply(
    split(marker_ref$gene, marker_ref$final_cluster),
    function(v) intersect(unique(v), bg)
  )
  sigs <- sigs[lengths(sigs) >= min_genes]
  if (length(sigs) == 0L) {
    return(skip(
      "[annot][ucell] No usable signatures after filtering (min_genes=",
      min_genes,
      "); skipping."
    ))
  }
  by_size <- sort(lengths(sigs), decreasing = TRUE)
  sigs <- sigs[names(by_size)[seq_len(min(length(by_size), max_signatures))]]

  grp <- canon_cluster(as.character(seurat_obj[[cluster_idents]][, 1]))
  if (length(grp) != ncol(seurat_obj)) {
    return(skip("[annot][ucell] length(grp) != ncol(object); skipping."))
  }

  # Cap the number of cells per cluster by random sampling.
  sampled <- lapply(split(colnames(seurat_obj), grp), function(v) {
    if (length(v) <= cells_per_cluster) {
      v
    } else {
      sample(v, cells_per_cluster)
    }
  })
  subs <- unlist(sampled, use.names = FALSE)
  if (length(subs) == 0L) {
    return(skip("[annot][ucell] Subsampling kept 0 cells; skipping."))
  }

  sub <- subset(seurat_obj, cells = subs)
  pre_cols <- colnames(sub@meta.data)
  message(
    "[annot][ucell] Running UCell on ",
    length(subs),
    " cells across ",
    length(sigs),
    " signatures..."
  )
  sub <- UCell::AddModuleScore_UCell(sub,
                                     features = sigs,
                                     name = "U",
                                     ncores = ncores)

  # Score columns are the meta.data columns that appeared during scoring;
  # fall back to a name pattern when none is detected that way.
  post_cols <- setdiff(colnames(sub@meta.data), pre_cols)
  if (length(post_cols) == 0L) {
    post_cols <- grep("^U[._-]", colnames(sub@meta.data), value = TRUE)
  }
  if (length(post_cols) == 0L) {
    return(skip("[annot][ucell] No UCell score columns were created; skipping."))
  }

  # Map score column names to labels by stripping the "U" decorations.
  stripped <- gsub("^U[._-]*", "", post_cols, perl = TRUE)
  stripped <- gsub("(?:[._-]*(?:UCell|U))?$", "", stripped, perl = TRUE)
  col2label <- setNames(trim(stripped), post_cols)

  sub_grp <- canon_cluster(as.character(sub[[cluster_idents]][, 1]))
  df <- data.frame(cluster = sub_grp, sub@meta.data[, post_cols, drop = FALSE], check.names = FALSE)
  lab <- tapply(seq_len(nrow(df)), df$cluster, function(ix) {
    med <- suppressWarnings(apply(df[ix, post_cols, drop = FALSE], 2, stats::median, na.rm = TRUE))
    col2label[[names(which.max(med))]]
  })

  labs <- setNames(as.character(unlist(lab)), names(lab))
  list(obj = seurat_obj, labels = labs)
}

# Two-pass annotation orchestrator ---------------------------------------
annotate_all_methods <- function(obj,
                                 marker_ref,
                                 cluster_key_name = "cluster_key",
                                 prefix = "annot_detailed",
                                 limit_clusters = NULL,
                                 deg_prefix = "deg",
                                 degs_precomputed = NULL) {
  stopifnot(inherits(obj, "Seurat"))
  if (is.null(marker_ref) ||
      !nrow(marker_ref) ||
      !all(c("gene", "final_cluster") %in% names(marker_ref))) {
    stop("[annot] marker_ref must be a data.frame with columns: gene, final_cluster")
  }
  obj <- ensure_cluster_key(obj, cluster_key_name)
  
  old_id <- Idents(obj)
  on.exit(try(Idents(obj) <- old_id, silent = TRUE)
          , add = TRUE)
  Idents(obj) <- obj[[cluster_key_name]][, 1]
  clv <- Idents(obj)
  
  DefaultAssay(obj) <- cfg$base_assay
  X_full <- Seurat::GetAssayData(obj, assay = cfg$base_assay, layer = "data")
  if (!inherits(X_full, "dgCMatrix"))
    X_full <- methods::as(X_full, "dgCMatrix")
  bg <- rownames(X_full)
  
  curated_genes <- unique(as.character(marker_ref$gene))
  feats_cur <- intersect(curated_genes, bg)
  if (!length(feats_cur))
    stop("[annot] No curated markers found in assay background.")
  if (isTRUE(cfg$annot_features_max) &&
      length(feats_cur) > cfg$annot_features_max)
    feats_cur <- feats_cur[seq_len(cfg$annot_features_max)]
  
  obj_lean <- Seurat::DietSeurat(
    obj,
    assays = cfg$base_assay,
    counts = TRUE,
    data = TRUE,
    scale.data = FALSE,
    dimreducs = character(),
    graphs = character(),
    features = feats_cur
  )
  DefaultAssay(obj_lean) <- cfg$base_assay
  Idents(obj_lean) <- obj_lean[[cluster_key_name]][, 1]
  
  if (!is.null(limit_clusters)) {
    limit_clusters <- canon_cluster(as.character(limit_clusters))
    grp_all <- as.character(Idents(obj_lean))
    keep_cells <- colnames(obj_lean)[grp_all %in% limit_clusters]
    if (!length(keep_cells)) {
      message("[annot] limit_clusters matched 0 cells; returning empty table.")
      tab_empty <- tibble::tibble(cluster = character(0))
      if (!prefix %in% colnames(obj@meta.data))
        obj[[prefix]] <- NA_character_
      if (!"final_consensus" %in% colnames(obj@meta.data))
        obj$final_consensus <- NA_character_
      try(ckpt_save("annot_table_last",
                    list(table = tab_empty, saved_at = Sys.time())), silent = TRUE)
      return(list(
        obj = obj,
        table = tab_empty,
        degs = tibble::tibble()
      ))
    }
    obj_lean <- subset(obj_lean, cells = keep_cells)
  }
  
  clusters <- Idents(obj_lean)
  lv <- levels(clusters)
  expected_clusters <- lv
  
  tab_placeholder <- tibble::tibble(cluster = canon_cluster(as.character(expected_clusters)))
  try(ckpt_save("annot_table_last",
                list(table = tab_placeholder, saved_at = Sys.time())), silent = TRUE)
  
  message("[annot] DEGs: preparing ...")
  degs <- degs_precomputed
  if (is.null(degs)) {
    if (isTRUE(cfg$annot_skip_deg)) {
      cached <- try(deg_ckpt_load(
        prefix = deg_prefix,
        obj = obj,
        group_col = cluster_key_name,
        features = feats_cur
      ),
      silent = TRUE)
      if (!inherits(cached, "try-error") &&
          !is.null(cached))
        degs <- cached
    } else {
      layer_or_slot <- pick_layer_arg()
      pieces <- lapply(expected_clusters, function(cl) {
        args <- list(
          object = obj_lean,
          ident.1 = cl,
          only.pos = TRUE,
          min.pct = 0.20,
          logfc.threshold = 0.25,
          test.use = "wilcox",
          verbose = FALSE,
          features = feats_cur,
          random.seed = cfg$seed
        )
        if (!is.null(cfg$deg_subsample_per_ident))
          args$max.cells.per.ident <- cfg$deg_subsample_per_ident
        if (layer_or_slot == "layer")
          args$layer <- "data"
        else
          args$slot <- "data"
        fm <- tryCatch(
          do.call(Seurat::FindMarkers, args),
          error = function(e)
            NULL
        )
        if (is.null(fm) || !nrow(fm))
          return(NULL)
        tibble::tibble(
          cluster    = canon_cluster(as.character(cl)),
          gene       = rownames(fm),
          avg_log2FC = suppressWarnings(
            if ("avg_log2FC" %in% names(fm))
              fm$avg_log2FC
            else if ("avg_logFC" %in% names(fm))
              fm$avg_logFC
            else if ("log2FC" %in% names(fm))
              fm$log2FC
            else
              NA_real_
          ),
          p_val_adj  = suppressWarnings(
            if ("p_val_adj"  %in% names(fm))
              fm$p_val_adj
            else if ("p_val.adj" %in% names(fm))
              fm$p_val.adj
            else if ("p_val" %in% names(fm))
              fm$p_val
            else
              1
          )
        )
      })
      degs <- dplyr::bind_rows(pieces)
      if (!is.null(degs) && nrow(degs)) {
        degs <- dplyr::filter(
          degs,
          is.finite(avg_log2FC),
          is.finite(p_val_adj),
          p_val_adj < 0.05,
          avg_log2FC > 0
        )
        try(deg_ckpt_save(
          prefix = deg_prefix,
          obj = obj,
          group_col = cluster_key_name,
          degs = degs,
          features = feats_cur
        ),
        silent = TRUE)
      }
    }
  }
  message(sprintf(
    "[annot] DEGs: %s",
    ifelse(
      is.null(degs) ||
        !nrow(degs),
      "none (empty table)",
      sprintf("n=%d", nrow(degs))
    )
  ))
  
  X <- Seurat::GetAssayData(obj_lean, assay = cfg$base_assay, layer = "data")
  stopifnot(inherits(X, "dgCMatrix"))
  avg_gc   <- avg_by_group_sparse(X, clusters)
  avg_gc_m <- as.matrix(avg_gc)
  
  obj_lean@misc$annot_feats   <- feats_cur
  obj_lean@misc$annot_avg_exp <- avg_gc_m
  obj_lean@misc$annot_degs    <- degs
  assign(".annot_obj_lean", obj_lean, envir = .GlobalEnv)
  
  out <- list()
  curated_genes <- unique(as.character(marker_ref$gene))
  out$avg_exp   <- .run_annot_method("avg_exp", function()
    annotate_avgexp_matrix(avg_gc_m, marker_ref, curated_genes = curated_genes), expect_clusters = expected_clusters)
  out$hypergeom <- .run_annot_method("hypergeom", function()
    annotate_hypergeom(degs %||% tibble::tibble(), marker_ref, rownames(X)), expect_clusters = expected_clusters)
  out$majority  <- .run_annot_method("majority", function()
    annotate_majority(degs %||% tibble::tibble(), marker_ref), expect_clusters = expected_clusters)
  out$logfc     <- .run_annot_method("logfc", function()
    annotate_logfc(degs %||% tibble::tibble(), marker_ref, use_padj_weight = TRUE), expect_clusters = expected_clusters)
  out$cellmanam <- .run_annot_method("cellmanam", function()
    annotate_cellmanam(obj, degs %||% tibble::tibble(), marker_ref), expect_clusters = expected_clusters)
  out$hypergeomX <- .run_annot_method("hypergeomX", function()
    annotate_hypergeomX(
      obj,
      degs %||% tibble::tibble(),
      marker_ref,
      cluster_idents = cluster_key_name,
      min_cells_bg = 5
    ), expect_clusters = expected_clusters)
  out$gsea      <- .run_annot_method("gsea", function()
    annotate_fgsea(degs %||% tibble::tibble(), marker_ref), expect_clusters = expected_clusters)
  
  uc <- fast_ucell_labels_lean(obj, marker_ref, cluster_idents = cluster_key_name)
  obj <- uc$obj
  out$ucell     <- .run_annot_method("ucell", function()
    uc$labels, expect_clusters = expected_clusters)
  
  build_tbl <- function(named_vec, nm) {
    if (!length(named_vec))
      return(NULL)
    tibble::tibble(cluster = names(named_vec), !!nm := unname(named_vec))
  }
  dfs <- purrr::compact(lapply(names(out), function(nm)
    build_tbl(out[[nm]], nm)))
  tab <- if (length(dfs))
    Reduce(function(x, y)
      dplyr::full_join(x, y, by = "cluster"), dfs)
  else
    tibble::tibble(cluster = character(0))
  if (!nrow(tab))
    tab <- tibble::tibble(cluster = canon_cluster(as.character(expected_clusters)))
  tab$cluster <- canon_cluster(tab$cluster)
  if (!is.null(limit_clusters))
    tab <- dplyr::filter(tab, cluster %in% expected_clusters)
  
  methods <- intersect(
    c(
      "avg_exp",
      "hypergeom",
      "majority",
      "logfc",
      "cellmanam",
      "hypergeomX",
      "gsea",
      "ucell"
    ),
    names(tab)
  )
  weights <- c(
    avg_exp = 1,
    hypergeom = 1,
    majority = 1,
    logfc = 1,
    cellmanam = 1,
    hypergeomX = 1,
    gsea = 1,
    ucell = 1
  )
  weights <- weights[methods]
  wvote <- function(row) {
    # methods + weights already defined above
    vals <- as.list(row[methods])
    labs <- vapply(vals, function(x) as.character(x[[1]]), "", USE.NAMES = FALSE)
    keep <- !is.na(labs) & labs != "Unknown"
    if (!any(keep)) return(NA_character_)
    # align weights to the methods that actually voted
    ms   <- methods[keep]
    labs <- labs[keep]
    # sum method weights per label
    sc <- tapply(weights[ms], labs, sum, simplify = TRUE)
    names(which.max(sc))
  }
  
  if (nrow(tab) &&
      length(methods))
    tab$final_consensus <- apply(tab[, methods, drop = FALSE], 1, wvote)
  else
    tab$final_consensus <- NA_character_
  
  if (nrow(tab) && length(methods)) {
    tab$agree <- apply(tab[, methods, drop = FALSE], 1, function(x) {
      v <- x[!is.na(x) & x != "Unknown"]
      if (!length(v))
        return(0)
      max(table(v)) / length(v)
    })
  } else
    tab$agree <- numeric(nrow(tab))
  
  grp <- obj[[cluster_key_name]][, 1] %>% as.character() %>% canon_cluster()
  names(grp) <- colnames(obj)
  init_meta <- function(n)
    rep(NA_character_, n)
  # existing block writes: annot_detailed_avg_exp, ..., annot_detailed_ucell
  for (m in methods) {
    colname <- paste0(prefix, "_", m)
    if (!colname %in% colnames(obj@meta.data))
      obj[[colname]] <- init_meta(ncol(obj))
    cmap <- setNames(tab[[m]], tab$cluster)
    obj[[colname]] <- apply_mapping(grp, cmap)
  }
  # also write UNPREFIXED aliases by request (result_obj$avg_exp, etc.)
  for (m in methods) {
    cmap <- setNames(tab[[m]], tab$cluster)
    obj[[m]] <- apply_mapping(grp, cmap)
  }
  
  if (!prefix %in% colnames(obj@meta.data))
    obj[[prefix]] <- init_meta(ncol(obj))
  if (nrow(tab))
    obj[[prefix]] <- apply_mapping(grp, setNames(tab$final_consensus, tab$cluster))
  if (!"final_consensus" %in% colnames(obj@meta.data))
    obj$final_consensus <- init_meta(ncol(obj))
  if (nrow(tab))
    obj$final_consensus <- apply_mapping(grp, setNames(tab$final_consensus, tab$cluster))
  
  tab <- dplyr::arrange(tab, cluster)
  try(ckpt_save("annot_table_last", list(table = tab, saved_at = Sys.time())), silent = TRUE)
  
  res <- list(
    obj = obj,
    table = tab,
    degs = degs %||% tibble::tibble()
  )
  class(res) <- c("annot_res", "list")
  return(res)
}

# Print a compact view of the annotation table, ordered by cluster.
#
# tab: annotation table (data frame / tibble); only the known method columns,
#      `cluster`, `final_consensus` and `agree` are shown, when present.
# n:   optional maximum number of rows to print (NULL prints everything).
#
# Returns (invisibly) the table that was printed, or NULL when `tab` is NULL
# or has no rows (a message is emitted in that case).
print_annotation_summary <- function(tab, n = NULL) {
  if (!is.null(tab) && nrow(tab)) {
    wanted <- c(
      "cluster",
      "avg_exp", "hypergeom", "majority", "logfc",
      "cellmanam", "hypergeomX", "gsea", "ucell",
      "final_consensus", "agree"
    )
    present <- intersect(wanted, colnames(tab))
    view <- dplyr::arrange(tab[, present, drop = FALSE], cluster)
    if (!is.null(n)) {
      view <- utils::head(view, n)
    }
    print(view, row.names = FALSE)
    return(invisible(view))
  }
  message("[annot][summary] Empty annotation table.")
  invisible(NULL)
}

# Write the annotation table plus one DEG sheet per cluster into workbook `wb`.
#
# wb:           workbook handle passed to addWorksheet()/writeData().
# sheet_prefix: prefix for every sheet name ("<prefix>_annot", "<prefix>_<cluster>").
# annot_tab:    annotation table written to the "<prefix>_annot" sheet.
# degs:         DEG data frame with a `cluster` column; NULL / empty skips the
#               per-cluster sheets.
# anno_df:      optional gene annotation data frame. A `gene` column (or one of
#               the aliases below) and an `all_anno` column (or columns it can
#               be pasted from) are used to annotate DEG rows.
#
# Returns NULL invisibly. Stops when `degs` has rows but no `cluster` column.
write_results_xlsx <- function(wb, sheet_prefix, annot_tab, degs, anno_df) {
  annot_sheet <- paste0(sheet_prefix, "_annot")
  addWorksheet(wb, annot_sheet)
  writeData(wb, annot_sheet, annot_tab)
  if (is.null(degs) ||
      !is.data.frame(degs) || !nrow(degs))
    return(invisible(NULL))
  if (!"cluster" %in% names(degs))
    stop("DEG table lacks 'cluster' column.")
  by_cluster <- split(degs, as.character(degs$cluster))

  adf_ok <- is.data.frame(anno_df) && nrow(anno_df) > 0
  if (adf_ok) {
    adf <- anno_df
    # Accept a few alternative spellings of the gene-ID column.
    if (!"gene" %in% names(adf)) {
      alt <- intersect(names(adf),
                       c(
                         "Gene",
                         "GeneID",
                         "gene_id",
                         "Gene_name",
                         "GeneID_version"
                       ))
      if (length(alt))
        names(adf)[match(alt[1], names(adf))] <- "gene"
    }
    # Build a combined annotation string when none is provided.
    if (!"all_anno" %in% names(adf)) {
      alt <- intersect(
        c(
          "Annotation",
          "Uniprot_protein_name",
          "PFAM_domain_name",
          "NCBI_ID"
        ),
        names(adf)
      )
      adf$all_anno <- if (length(alt))
        do.call(paste, c(adf[alt], sep = "; "))
      else
        NA_character_
    }
    adf$gene <- trim(as.character(adf$gene))
    adf <- unique(adf[, c("gene", "all_anno")])
    # Version-stripped ID ("X.1" -> "X") used as a fallback lookup key.
    adf$gene_nov <- sub("\\.[0-9]+$", "", adf$gene)
    adf_slim <- unique(adf[, c("gene", "gene_nov", "all_anno")])
    nov_lookup <- unique(adf_slim[, c("gene_nov", "all_anno"), drop = FALSE])
  }

  norm_chr <- function(x)
    trim(as.character(x))
  for (nm in names(by_cluster)) {
    df <- by_cluster[[nm]]
    df <- as.data.frame(df, stringsAsFactors = FALSE)
    if (!"gene" %in% names(df))
      df$gene <- rownames(df)
    df$gene <- norm_chr(df$gene)
    if (adf_ok) {
      # 1) Exact match on the full gene ID.
      # NOTE(review): duplicated `gene` values with different annotations in
      # `anno_df` still duplicate DEG rows here - confirm whether intended.
      df1 <- dplyr::left_join(
        df,
        adf_slim[, c("gene", "all_anno"), drop = FALSE],
        by = "gene"
      )
      df1$all_anno <- vapply(df1$all_anno, as_chr_collapse, "", USE.NAMES = FALSE)
      # 2) Fallback on the version-stripped ID, only for rows still missing.
      need <- is.na(df1$all_anno) | df1$all_anno == ""
      if (any(need)) {
        nov_ids <- sub("\\.[0-9]+$", "", df1$gene[need])
        # match() returns exactly one index per query (first hit), so the
        # result always lines up with `need`; a join could return extra rows
        # when one stripped ID maps to several annotations and misalign the
        # positional assignment.
        hit <- match(nov_ids, nov_lookup$gene_nov)
        df1$all_anno[need] <- nov_lookup$all_anno[hit]
        df1$all_anno <- vapply(df1$all_anno, as_chr_collapse, "", USE.NAMES = FALSE)
      }
      df <- df1
    }
    sht <- sanitize_sheet(paste0(sheet_prefix, "_", nm))
    addWorksheet(wb, sht)
    writeData(wb, sht, df)
  }
  invisible(NULL)
}

# ---------------------------
# Cluster keys & subclustering
# ---------------------------

# Make sure `obj` carries a canonicalised cluster column named `key`.
#
# If the column already exists it is re-written through canon_cluster();
# otherwise it is created from Idents(obj). Stops when the column is missing
# and Idents(obj) is empty. Returns the updated object.
ensure_cluster_key <- function(obj, key = "cluster_key") {
  key_present <- key %in% colnames(obj@meta.data)
  if (key_present) {
    obj[[key]] <- canon_cluster(as.character(obj[[key]][, 1]))
    return(obj)
  }
  if (!length(Idents(obj))) {
    stop("'",
         key,
         "' not found in Seurat object meta.data and Idents() is empty.")
  }
  # Column absent: derive it from the active identities.
  obj[[key]] <- canon_cluster(as.character(Idents(obj)))
  obj
}

# Format a vector/table of sizes as "1:<size1>, 2:<size2>, ..." where the
# number before the colon is the position and the number after it is the
# size coerced to integer.
.pretty_sizes <- function(tab_named) {
  positions <- seq_along(tab_named)
  sizes <- as.integer(tab_named)
  pairs <- sprintf("%d:%d", positions, sizes)
  paste(pairs, collapse = ", ")
}
# Compute child-size thresholds for a parent of `n` cells.
#
# n:         number of cells in the parent.
# abs_min:   absolute lower bound on a child's size.
# prop_min:  minimum child size as a fraction of `n`.
# tiny_prop: "tiny" child size as a fraction of `n`.
#
# Returns a list with the rounded-up proportional thresholds
# (`prop_min_calc`, `tiny_min_calc`), the same thresholds floored at
# `abs_min` (`min_allowed`, `tiny_allowed`), and `abs_min` itself.
.min_allowed <- function(n, abs_min, prop_min, tiny_prop) {
  prop_cells <- ceiling(prop_min * n)
  tiny_cells <- ceiling(tiny_prop * n)
  list(
    min_allowed   = max(abs_min, prop_cells),
    tiny_allowed  = max(abs_min, tiny_cells),
    abs_min       = abs_min,
    prop_min_calc = prop_cells,
    tiny_min_calc = tiny_cells
  )
}
# --- Centroid-based label merge helper (hardened: NA labels, vanished levels) ---
# Reassign every label listed in `tiny_levels` to the closest remaining label.
#
# emb:         numeric embedding, one row per observation (coerced to matrix).
# lab:         labels, one per row of `emb`; NA labels are treated as the
#              literal level "NA".
# tiny_levels: labels to dissolve.
#
# Centroids are computed once from the incoming labels; each tiny level is
# moved to the non-tiny level whose centroid has the smallest squared
# Euclidean distance. `lab` is returned untouched when `tiny_levels` is empty,
# and as a character vector otherwise (unchanged apart from the NA guard when
# every level is tiny).
.merge_labels_by_centroid <- function(emb, lab, tiny_levels) {
  if (length(tiny_levels) == 0L) {
    return(lab)
  }
  lab <- as.character(lab)
  lab <- replace(lab, is.na(lab), "NA")
  keep_levels <- setdiff(names(table(lab)), tiny_levels)
  if (length(keep_levels) == 0L) {
    return(lab)
  }

  emb <- as.matrix(emb)
  stopifnot(nrow(emb) == length(lab))

  # Per-level centroids: column sums per level divided by the level sizes.
  grp <- factor(lab)
  centroids <- as.matrix(rowsum(emb, grp) / as.numeric(table(grp)))

  for (tiny in tiny_levels) {
    # A requested level that is not among the labels is simply ignored.
    if (tiny %in% rownames(centroids)) {
      sq_dist <- vapply(
        keep_levels,
        function(big) sum((centroids[tiny, ] - centroids[big, ])^2),
        numeric(1)
      )
      lab[lab == tiny] <- names(which.min(sq_dist))
    }
  }
  lab
}



# ---- CORE:  subclustering for ONE parent ---------------------
# Split one parent cluster into children and record them in `new_key_name`.
#
# Cells whose `cluster_key_name` label equals `parent_key`, or starts with
# "<parent_key>.", are subset, scaled on their variable features, reduced with
# an exact PCA, linked in an SNN graph and clustered over `res_grid` (the first
# resolution that yields more than one cluster is used). Children below the
# size thresholds are merged into the nearest remaining child by PCA centroid
# and the survivors are labelled "<parent_key>.<rank>" by decreasing size.
#
# obj_in:           Seurat object.
# parent_key:       label of the parent cluster; matched literally.
# cluster_key_name: meta.data column holding the current cluster labels.
# new_key_name:     meta.data column receiving the child labels; created from
#                   `cluster_key_name` when absent.
# res_grid:         clustering resolutions, tried in order.
# npcs:             requested number of PCs (further capped at 50 and by the
#                   estimated rank of the scaled matrix).
# k:                requested k.param for FindNeighbors (clamped to the parent
#                   size).
# min_abs, min_prop, tiny_prop: size thresholds, see .min_allowed().
# collapse_tiny:    merge "tiny" children before enforcing the minimum size.
#
# Returns `obj_in` with `new_key_name` updated (stored as a factor) when a
# split is accepted. Whenever a step fails or the split is rejected, a
# "[subclust] (keep ...)" message is emitted and the object is returned
# without child labels.
subcluster_one <- function(obj_in,
                           parent_key,
                           cluster_key_name = "cluster_key",
                           new_key_name     = "cluster_key_v2",
                           res_grid         = cfg$subclust_res,
                           npcs             = cfg$subclust_npcs,
                           k                = cfg$subclust_k,
                           min_abs          = cfg$min_child_abs,
                           min_prop         = cfg$min_child_prop,
                           tiny_prop        = cfg$subclust_tiny_prop,
                           collapse_tiny    = TRUE) {
  # --- helpers ---
  # TRUE when graph `gnm` exists in `o` and has at least one non-zero entry.
  has_valid_graph <- function(o, gnm) {
    G <- o@graphs[[gnm]]
    if (is.null(G))
      return(FALSE)
    if (!methods::is(G, "dgCMatrix"))
      G <- methods::as(G, "dgCMatrix")
    Matrix::nnzero(G) > 0
  }
  # First graph whose name contains "_snn", else the first graph, else NULL.
  pick_snn <- function(o) {
    cands <- names(o@graphs)
    if (!length(cands))
      return(NULL)
    snn <- cands[grepl("_snn", cands, fixed = TRUE)]
    if (!length(snn))
      snn <- cands[1]
    snn[1]
  }

  # --- find parent cells from PARENT labels only ---
  base_lab <- obj_in[[cluster_key_name]][, 1] |> as.character() |> canon_cluster()
  # Literal matching: a regex built from the key would treat the "." in keys
  # such as "g1.2" as a wildcard and could pull in unrelated clusters.
  lab_chr <- as.character(base_lab)
  parent_chr <- as.character(parent_key)
  is_parent <- lab_chr == parent_chr |
    startsWith(lab_chr, paste0(parent_chr, "."))
  cells_parent <- colnames(obj_in)[which(is_parent)]  # which() drops NA labels
  parent_n <- length(cells_parent)
  if (!parent_n) {
    message("[subclust] (", parent_key, ") no cells; skip.")
    return(obj_in)
  }

  # small parents: skip early
  if (parent_n < (cfg$sub_min_cells_for_split %||% 30L)) {
    message("[subclust] (keep ",
            parent_key,
            ") too few cells (n=",
            parent_n,
            "); skip.")
    return(obj_in)
  }

  thr <- .min_allowed(parent_n, min_abs, min_prop, tiny_prop)

  # subset
  DefaultAssay(obj_in) <- cfg$base_assay
  sub <- tryCatch(
    subset(obj_in, cells = cells_parent),
    error = function(e)
      NULL
  )
  if (is.null(sub) || ncol(sub) < 3L) {
    message("[subclust] (keep ", parent_key, ") subset invalid; skip.")
    return(obj_in)
  }

  # HVG + drop zero-variance
  if (length(VariableFeatures(sub)) < 200) {
    sub <- FindVariableFeatures(
      sub,
      selection.method = "vst",
      nfeatures = min(2000, nrow(sub)),
      verbose = FALSE
    )
  }
  feats <- intersect(VariableFeatures(sub), rownames(sub))
  if (!length(feats)) {
    message("[subclust] (keep ", parent_key, ") no HVGs; skip.")
    return(obj_in)
  }

  Xsub <- tryCatch(
    Seurat::GetAssayData(obj_in, assay = cfg$base_assay, layer = "data")[feats, cells_parent, drop = FALSE],
    error = function(e)
      NULL
  )
  if (is.null(Xsub)) {
    message("[subclust] (keep ", parent_key, ") no data layer; skip.")
    return(obj_in)
  }

  # Per-feature variance within the parent: E[x^2] - E[x]^2.
  mu  <- Matrix::rowMeans(Xsub)
  mu2 <- Matrix::rowMeans(Xsub^2)
  vrs <- as.numeric(mu2 - mu^2)
  feats <- feats[vrs > 0]
  if (!length(feats)) {
    message("[subclust] (keep ", parent_key, ") no variable signal; skip.")
    return(obj_in)
  }

  # scaling
  sub <- tryCatch(
    ScaleData(sub, features = feats, verbose = FALSE),
    error = function(e)
      NULL
  )
  if (is.null(sub)) {
    message("[subclust] (keep ", parent_key, ") ScaleData failed; skip.")
    return(obj_in)
  }

  # PCA (exact SVD, bounded PCs)
  use_pcs_max <- max(2L, min(length(feats), parent_n - 2L))
  Xsc <- tryCatch(
    Seurat::GetAssayData(sub, assay = cfg$base_assay, layer = "scale.data")[feats, , drop = FALSE],
    error = function(e)
      NULL
  )
  if (is.null(Xsc) || nrow(Xsc) < 2L || ncol(Xsc) < 3L) {
    message("[subclust] (keep ",
            parent_key,
            ") scaled matrix too small; skip.")
    return(obj_in)
  }
  rank_est <- tryCatch(
    as.integer(Matrix::rankMatrix(t(Xsc))),
    error = function(e)
      NA_integer_
  )
  if (is.finite(rank_est))
    use_pcs_max <- max(2L, min(use_pcs_max, rank_est - 1L))
  use_pcs <- min(npcs, 50L, use_pcs_max)

  sub <- tryCatch(
    RunPCA(
      sub,
      features = rownames(Xsc),
      npcs = use_pcs,
      approx = FALSE,
      verbose = FALSE
    ),
    error = function(e)
      NULL
  )
  if (is.null(sub) || is.null(sub@reductions$pca)) {
    message("[subclust] (keep ", parent_key, ") PCA failed; skip.")
    return(obj_in)
  }

  emb <- tryCatch(
    Embeddings(sub, "pca"),
    error = function(e)
      NULL
  )
  if (is.null(emb) || ncol(emb) < 2L) {
    message("[subclust] (keep ",
            parent_key,
            ") PCA embeddings too small; skip.")
    return(obj_in)
  }
  emb <- as.matrix(emb)
  storage.mode(emb) <- "double"
  avail <- min(ncol(emb), use_pcs)
  pc_idx <- seq_len(avail)

  # Neighbors with safe k
  k_eff <- max(5L, min(k, parent_n - 2L, max(10L, floor(parent_n * 0.25))))
  gname <- paste0(cfg$base_assay, "_snn_sub_k", k_eff)
  sub <- tryCatch(
    FindNeighbors(
      sub,
      dims = pc_idx,
      k.param = k_eff,
      graph.name = gname,
      verbose = FALSE
    ),
    error = function(e)
      NULL
  )
  if (is.null(sub)) {
    message("[subclust] (keep ",
            parent_key,
            ") FindNeighbors failed; skip.")
    return(obj_in)
  }

  # verify graph
  g_pick <- if (has_valid_graph(sub, gname))
    gname
  else
    pick_snn(sub)
  if (is.null(g_pick) || !has_valid_graph(sub, g_pick)) {
    message("[subclust] (keep ",
            parent_key,
            ") SNN graph empty/invalid; skip.")
    return(obj_in)
  }

  # scan resolutions via safe_findclusters(); stop at the first resolution
  # that produces more than one cluster
  found_multi <- FALSE
  for (res in res_grid) {
    sub_try <- try(safe_findclusters(
      sub,
      graph.name = g_pick,
      resolution = res,
      seed = cfg$seed
    ),
    silent = TRUE)
    if (inherits(sub_try, "try-error"))
      next
    if (length(unique(as.character(Idents(sub_try)))) > 1L) {
      sub <- sub_try
      found_multi <- TRUE
      break
    }
  }
  if (!found_multi) {
    message("[subclust] (keep ",
            parent_key,
            ") only one child at all resolutions; skip.")
    return(obj_in)
  }

  # children sizes (pre-merge)
  cl <- as.character(Idents(sub))
  tab <- sort(table(cl), decreasing = TRUE)
  message(
    "[subclust] (",
    parent_key,
    ") parent_n=",
    parent_n,
    " | children: ",
    .pretty_sizes(tab),
    " | min_allowed=",
    thr$min_allowed,
    " [abs=",
    thr$abs_min,
    ", prop=",
    min_prop,
    "→",
    thr$prop_min_calc,
    ", tiny=",
    tiny_prop,
    "→",
    thr$tiny_min_calc,
    "]"
  )

  # collapse tiny by centroid
  if (collapse_tiny && any(tab < thr$tiny_min_calc)) {
    tiny_lvls <- names(tab)[tab < thr$tiny_min_calc]
    cl <- .merge_labels_by_centroid(emb[, pc_idx, drop = FALSE], cl, tiny_lvls)
    Idents(sub) <- factor(cl)
    tab <- sort(table(cl), decreasing = TRUE)
    if (length(tab) <= 1L) {
      message("[subclust] (keep ",
              parent_key,
              ") tiny merge → single child; skip.")
      return(obj_in)
    }
  }

  # enforce min size
  if (any(tab < thr$min_allowed)) {
    small_lvls <- names(tab)[tab < thr$min_allowed]
    cl <- .merge_labels_by_centroid(emb[, pc_idx, drop = FALSE], cl, small_lvls)
    Idents(sub) <- factor(cl)
    tab <- sort(table(cl), decreasing = TRUE)
    if (length(tab) <= 1L) {
      message("[subclust] (keep ",
              parent_key,
              ") size merge → single child; skip.")
      return(obj_in)
    }
  }

  # final relabel: children numbered by decreasing size
  ord <- names(tab)
  idx_map <- setNames(seq_along(ord), ord)
  child_nums   <- unname(idx_map[as.character(Idents(sub))])
  child_labels <- paste0(parent_key, ".", child_nums)
  names(child_labels) <- Cells(sub)

  # ensure v2 column exists as character
  if (!(new_key_name %in% colnames(obj_in@meta.data))) {
    obj_in@meta.data[[new_key_name]] <- as.character(obj_in[[cluster_key_name]][, 1])
  } else if (!is.character(obj_in@meta.data[[new_key_name]])) {
    obj_in@meta.data[[new_key_name]] <- as.character(obj_in@meta.data[[new_key_name]])
  }

  stopifnot(all(cells_parent %in% rownames(obj_in@meta.data)))
  # write values in the exact order of 'cells_parent'
  vals <- unname(child_labels[match(cells_parent, names(child_labels))])
  if (anyNA(vals)) {
    # if any NA slipped in, abort cleanly
    message("[subclust] (keep ",
            parent_key,
            ") child label mismatch; skip.")
    return(obj_in)
  }
  obj_in@meta.data[cells_parent, new_key_name] <- vals
  obj_in@meta.data[[new_key_name]] <- factor(obj_in@meta.data[[new_key_name]])

  message("[subclust] (",
          parent_key,
          ") accepted split → ",
          .pretty_sizes(tab))
  obj_in
}




# --- subclustering -----------------------------------------------
# Subcluster every parent whose annotation agreement is at or below a cutoff.
#
# obj_in:           Seurat object (checked with inherits()).
# annot_table:      annotation table with `cluster` and `agree` columns.
# cluster_key_name: meta.data column with the cluster labels; child labels go
#                   to "<cluster_key_name>_v2".
# agree_cut:        parents with `agree <= agree_cut` are tried, lowest
#                   agreement first.
#
# Returns a list with `obj`, `changed` (TRUE when at least one parent was
# split) and `parents` (the accepted parents). When the loop is reached the
# list also carries `new_key`, the name of the child-label column.
auto_subcluster_suspicious <- function(obj_in,
                                       annot_table,
                                       cluster_key_name = "cluster_key",
                                       agree_cut = cfg$agree_cut) {
  stopifnot(inherits(obj_in, "Seurat"))
  obj_in <- ensure_cluster_key(obj_in, cluster_key_name)

  required_cols <- c("cluster", "agree")
  if (!all(required_cols %in% colnames(annot_table))) {
    message("[subclust] Annotation table lacks 'cluster' and/or 'agree'; skipping.")
    return(list(
      obj = obj_in,
      changed = FALSE,
      parents = character(0)
    ))
  }

  # Parents to try, ordered by increasing agreement.
  low_agree <- dplyr::filter(annot_table, agree <= agree_cut)
  low_agree <- dplyr::arrange(low_agree, agree)
  susp <- unique(dplyr::pull(low_agree, cluster))
  nk_name <- paste0(cluster_key_name, "_v2")

  if (length(susp) == 0L) {
    message("[subclust] No suspicious clusters (agree ≤ ",
            agree_cut,
            "); nothing to split.")
    return(list(
      obj = obj_in,
      changed = FALSE,
      parents = character(0)
    ))
  }

  preview <- paste(utils::head(susp, 10), collapse = ", ")
  ellipsis <- if (length(susp) > 10) " ..." else ""
  message(
    "[subclust] Will subcluster ",
    length(susp),
    " parent(s): ",
    preview,
    ellipsis
  )

  obj_work <- obj_in
  accepted <- character(0)
  for (p in susp) {
    attempt <- try(subcluster_one(
      obj_in = obj_work,
      parent_key = p,
      cluster_key_name = cluster_key_name,
      new_key_name = nk_name,
      collapse_tiny = TRUE
    ),
    silent = TRUE)

    if (inherits(attempt, "try-error")) {
      msg <- tryCatch(
        conditionMessage(attr(attempt, "condition")),
        error = function(e)
          "unknown error"
      )
      message("[subclust] (keep ", p, ") error: ", msg)
      next
    }

    # No child-label column at all: nothing was written for this parent.
    if (!(nk_name %in% colnames(attempt@meta.data))) {
      next
    }

    labels_now <- as.character(attempt[[nk_name]][, 1])
    has_children <- any(startsWith(labels_now, paste0(p, ".")), na.rm = TRUE)
    if (!has_children) {
      message("[subclust] (keep ", p, ") no new children written; skip.")
      next
    }
    accepted <- c(accepted, p)
    obj_work <- attempt
  }

  list(
    obj = obj_work,
    changed = length(accepted) > 0,
    parents = unique(accepted),
    new_key = nk_name
  )
}

# Decide whether the "subclust" stage has to be executed again.
#
# TRUE  - no "subclust" checkpoint exists, or the checkpoint reports a change
#         but its object is missing or lacks the child-label column.
# FALSE - the checkpoint reports no change, or it is consistent.
needs_rerun_subclust <- function() {
  if (ckpt_has("subclust")) {
    saved <- ckpt_load("subclust")
    if (isTRUE(saved$changed)) {
      key_col <- saved$new_key %||% "cluster_key_v2"
      saved_obj <- saved$obj
      if (is.null(saved_obj)) {
        return(TRUE)
      }
      return(!(key_col %in% colnames(saved_obj@meta.data)))
    }
    return(FALSE)
  }
  TRUE
}

# --- timing helpers ---
# Append one timing record to the "timings" checkpoint.
#
# stage:  stage name.
# t0, t1: start and end times; elapsed seconds are computed with difftime().
# ok:     whether the stage succeeded.
# note:   optional free-text note (stored as "" when NULL).
#
# Returns the new one-row record invisibly.
.stage_time_log_add <- function(stage,
                                t0,
                                t1,
                                ok = TRUE,
                                note = NULL) {
  note_txt <- if (is.null(note)) "" else as.character(note)
  entry <- data.frame(
    stage = stage,
    start = as.character(t0),
    end   = as.character(t1),
    elapsed_sec = as.numeric(difftime(t1, t0, units = "secs")),
    ok = ok,
    note = note_txt,
    stringsAsFactors = FALSE
  )
  # Extend the stored log when there is one, otherwise start a new log.
  history <- if (ckpt_has("timings")) {
    rbind(ckpt_load("timings"), entry)
  } else {
    entry
  }
  ckpt_save("timings", history)
  invisible(entry)
}
# Write the "timings" checkpoint into sheet `sheet_name` of workbook `wb`,
# replacing a sheet of that name if the workbook already has one.
# Returns the timing log invisibly, or NULL when no timings were recorded.
save_timings_to_xlsx <- function(wb, sheet_name = "timings") {
  if (ckpt_has("timings")) {
    timing_log <- ckpt_load("timings")
    already_there <- sheet_name %in% names(wb)
    if (already_there) {
      removeWorksheet(wb, sheet_name)
    }
    addWorksheet(wb, sheet_name)
    writeData(wb, sheet_name, timing_log)
    return(invisible(timing_log))
  }
  invisible(NULL)
}

# --- Subclustering config defaults ---
# Each knob keeps a value already present in `cfg` and only falls back to the
# default on the right (assuming `%||%` is the usual null-coalescing operator).
# subcluster_one() and auto_subcluster_suspicious() use these as default
# argument values, which R evaluates at call time, so defining them after the
# functions is fine.
# Resolutions scanned in order by subcluster_one() (default for `res_grid`).
cfg$subclust_res        <- cfg$subclust_res        %||% c(0.2, 0.4, 0.6, 0.8, 1.0, 1.5, 2.0)
# Requested number of PCs for the per-parent PCA (default for `npcs`).
cfg$subclust_npcs       <- cfg$subclust_npcs       %||% 30L
# Requested k for the per-parent neighbor graph (default for `k`).
cfg$subclust_k          <- cfg$subclust_k          %||% 30L
# Clusters with `agree` <= this value are treated as suspicious and subclustered.
cfg$agree_cut           <- cfg$agree_cut           %||% 0.6
# Absolute minimum number of cells per child (default for `min_abs`).
cfg$min_child_abs       <- cfg$min_child_abs       %||% 10L
# Minimum child size as a fraction of the parent (default for `min_prop`).
cfg$min_child_prop      <- cfg$min_child_prop      %||% 0.005
# Fraction of the parent below which a child counts as "tiny" (default for `tiny_prop`).
cfg$subclust_tiny_prop  <- cfg$subclust_tiny_prop  %||% 0.015

# ---------------------------
# STAGES
# ---------------------------
# Stage "load": read the marker reference, the optional gene annotation and
# the expression data, then store everything in the "load" checkpoint.
#
# Inputs come from `cfg$paths`:
#   markers_xlsx - marker workbook (required; first sheet is read).
#   anno_rdata   - optional .RData holding a data frame with a `gene` column.
#   matrix_rds   - optional dgCMatrix used as the "data" layer; when absent,
#   seurat_rdata - .RData holding a Seurat object is required instead.
#
# Both .RData files are loaded into their own environment, so their objects
# can neither overwrite nor be confused with this function's local variables
# or with objects in the global environment.
#
# Returns TRUE invisibly; stops when a required input is missing or invalid.
stage_load <- function() {
  message("Stage load...")
  stopifnot(file.exists(cfg$paths$markers_xlsx))
  cell_markers <- openxlsx::read.xlsx(cfg$paths$markers_xlsx, sheet = 1)
  refs <- prepare_marker_ref(cell_markers, gene_col = "Markers_positive_SMESG")
  marker_ref_general  <- refs$general
  marker_ref_detailed <- refs$detailed
  marker_ref          <- marker_ref_detailed

  anno_df <- NULL
  if (!is.null(cfg$paths$anno_rdata) &&
      file.exists(cfg$paths$anno_rdata)) {
    anno_env <- new.env(parent = emptyenv())
    load(cfg$paths$anno_rdata, envir = anno_env)
    # Only objects that came from the file are candidates; scanning the
    # function frame would also see the marker tables, which have a `gene`
    # column too.
    nn <- ls(anno_env)
    is_anno <- vapply(nn, function(nm) {
      x <- get(nm, envir = anno_env, inherits = FALSE)
      is.data.frame(x) && "gene" %in% names(x)
    }, logical(1))
    hit <- nn[is_anno]
    if (length(hit)) {
      anno_df <- get(hit[1], envir = anno_env, inherits = FALSE)
      if (!"all_anno" %in% names(anno_df)) {
        alt <- intersect(
          c(
            "Annotation",
            "Uniprot_protein_name",
            "PFAM_domain_name",
            "NCBI_ID"
          ),
          names(anno_df)
        )
        anno_df$all_anno <- if (length(alt))
          do.call(paste, c(anno_df[alt], sep = "; "))
        else
          NA_character_
      }
      anno_df <- anno_df[, intersect(c("gene", "all_anno"), names(anno_df)), drop =
                           FALSE]
    }
  }
  if (!is.null(anno_df)) {
    # Flatten a possible list column to plain character.
    if (is.list(anno_df$all_anno)) {
      anno_df$all_anno <- vapply(anno_df$all_anno, as_chr_collapse, "", USE.NAMES = FALSE)
    } else {
      anno_df$all_anno <- as.character(anno_df$all_anno %||% "")
    }
    anno_df$gene <- trim(as.character(anno_df$gene))
    anno_df <- unique(anno_df[, c("gene", "all_anno")])
  }

  obj <- NULL
  if (!is.null(cfg$paths$matrix_rds) &&
      file.exists(cfg$paths$matrix_rds)) {
    message("Loading from matrix_rds...")
    mat <- readRDS(cfg$paths$matrix_rds)
    if (!inherits(mat, "dgCMatrix"))
      stop("matrix_rds must be dgCMatrix.")
    if (is.null(rownames(mat)))
      rownames(mat) <- paste0("g", seq_len(nrow(mat)))
    if (is.null(colnames(mat)))
      colnames(mat) <- paste0("c", seq_len(ncol(mat)))
    # The matrix goes into the "data" layer; "counts" gets a placeholder with
    # the same dimensions and names.
    empty_counts <- new("dgCMatrix",
                        Dim = dim(mat),
                        Dimnames = dimnames(mat))
    obj <- CreateSeuratObject(
      counts = empty_counts,
      assay = cfg$base_assay,
      min.cells = 0,
      min.features = 0
    )
    if ("SetAssayData" %in% getNamespaceExports("SeuratObject")) {
      obj <- SeuratObject::SetAssayData(
        obj,
        assay = cfg$base_assay,
        layer = "data",
        new.data = mat
      )
      obj <- SeuratObject::SetAssayData(
        obj,
        assay = cfg$base_assay,
        layer = "counts",
        new.data = empty_counts
      )
    } else {
      obj[[cfg$base_assay]]@data   <- mat
      obj[[cfg$base_assay]]@counts <- empty_counts
    }
    DefaultAssay(obj) <- cfg$base_assay
  } else {
    stopifnot(file.exists(cfg$paths$seurat_rdata))
    message("Loading from seurat_rdata...")
    rdata_env <- new.env(parent = emptyenv())
    load(cfg$paths$seurat_rdata, envir = rdata_env)
    # Preferred object names first, then any Seurat object stored in the file.
    preferred <- c("integrated_seurat_obj_annotated_new", "integrated_seurat_obj")
    in_file <- vapply(preferred, function(nm)
      exists(nm, envir = rdata_env, inherits = FALSE), logical(1))
    if (any(in_file)) {
      obj <- get(preferred[in_file][1], envir = rdata_env, inherits = FALSE)
    } else {
      nn <- ls(rdata_env)
      is_seurat <- vapply(nn, function(nm)
        inherits(get(nm, envir = rdata_env, inherits = FALSE), "Seurat"),
        logical(1))
      cand <- nn[is_seurat]
      if (!length(cand))
        stop("No Seurat object found inside: ", cfg$paths$seurat_rdata)
      obj <- get(cand[1], envir = rdata_env, inherits = FALSE)
    }
    DefaultAssay(obj) <- cfg$base_assay
  }

  obj_small <- diet_for_checkpoint(
    obj,
    keep_assay = cfg$base_assay,
    keep_reduc = character(),
    keep_graphs = character(),
    drop_counts = TRUE,
    drop_scale = TRUE
  )

  ckpt_save(
    "load",
    list(
      obj = obj_small,
      marker_ref_general  = marker_ref_general,
      marker_ref_detailed = marker_ref_detailed,
      marker_ref          = marker_ref,
      anno_df = anno_df
    )
  )
  invisible(TRUE)
}

# Scale and center `features` of `assay` so they are available in scale.data.
#
# obj:      Seurat object.
# assay:    assay to scale; also made the default assay of the result.
# features: feature names; those missing from the "data" layer are ignored.
# chunk:    number of features processed per block.
#
# All requested features are scaled in ONE ScaleData() call, with `chunk`
# forwarded as its `block.size`. Calling ScaleData() once per chunk would
# replace scale.data on every call and leave only the last chunk behind.
#
# Returns the object unchanged (apart from the default assay) when none of
# the features is present.
ensure_scaled <- function(obj, assay, features, chunk = 4000L) {
  DefaultAssay(obj) <- assay
  data_feats <- rownames(Seurat::GetAssayData(
    obj, assay = assay, layer = "data"
  ))
  present <- intersect(unique(features), data_feats)
  if (!length(present))
    return(obj)
  ScaleData(
    object = obj,
    assay = assay,
    features = present,
    do.center = TRUE,
    do.scale = TRUE,
    block.size = max(1L, as.integer(chunk)),
    verbose = FALSE
  )
}

# Stage "pca": pick variable features if none are set, scale them, run PCA
# and choose the number of PCs; the result is stored in the "pca" checkpoint
# together with the marker references and annotation carried over from the
# "load" checkpoint. Returns TRUE invisibly.
stage_pca <- function() {
  message("Stage PCA...")
  prev <- ckpt_load("load")
  seu <- prev$obj
  DefaultAssay(seu) <- cfg$base_assay

  # Compute variable features only when the object has none yet.
  if (length(VariableFeatures(seu)) == 0L) {
    seu <- FindVariableFeatures(
      object = seu,
      assay = cfg$base_assay,
      selection.method = "vst",
      nfeatures = 3000,
      verbose = FALSE
    )
  }
  var_feats <- VariableFeatures(seu)
  if (length(var_feats) == 0L) {
    stop("No variable features.")
  }

  seu <- ensure_scaled(
    seu,
    assay = cfg$base_assay,
    features = var_feats,
    chunk = 4000L
  )
  seu <- RunPCA(
    object = seu,
    features = var_feats,
    npcs = cfg$max_pcs,
    reduction.name = cfg$pca_name,
    reduction.key = "PCu_",
    verbose = FALSE
  )

  pc_sd <- seu@reductions[[cfg$pca_name]]@stdev
  if (length(pc_sd) == 0L) {
    stop("PCA failed.")
  }
  n_dims <- choose_pcs_by_knee(
    pc_sd,
    max_pcs = min(cfg$max_pcs, length(pc_sd)),
    variance_cut = cfg$variance_cut,
    smooth_k = cfg$knee_smooth,
    min_pcs = 25L
  )
  message("Chosen PCs: ", n_dims)

  payload <- list(
    obj = seu,
    dims = n_dims,
    marker_ref_general  = prev$marker_ref_general,
    marker_ref_detailed = prev$marker_ref_detailed,
    marker_ref          = prev$marker_ref,
    anno_df = prev$anno_df
  )
  ckpt_save("pca", payload)
  invisible(TRUE)
}

# Stage "grid": resume from the "pca" checkpoint, run grid_search_clusters()
# over k / resolution (defaults apply when the cfg entries are NULL), report
# the winning setting and store a slimmed object in the "grid" checkpoint.
# Returns TRUE invisibly.
stage_grid <- function() {
  message("Stage grid...")
  prev <- ckpt_load("pca")
  n_dims <- prev$dims

  search <- grid_search_clusters(
    prev$obj,
    n_dims,
    target_n = cfg$target_n_clusters,
    k_grid   = cfg$k_grid %||% c(5L, 8L, 10L, 15L, 20L),
    res_init = cfg$res_init %||% 0.6,
    res_max = cfg$res_max %||% 10,
    max_steps = cfg$grid_max_steps %||% 20
  )
  best <- search$stats
  graph_nm <- search$keep_graph

  summary_line <- sprintf(
    "Best: k=%s res=%s clusters=%s modularity=%s",
    best$k,
    best$resolution,
    best$n_clusters,
    round(best$modularity, 3)
  )
  message(summary_line)

  # Keep only the PCA reduction and the selected graph in the checkpoint.
  slim <- diet_for_checkpoint(search$obj,
                              keep_reduc  = cfg$pca_name,
                              keep_graphs = graph_nm)
  payload <- list(
    obj        = slim,
    dims = n_dims,
    best_stats = best,
    keep_graph = graph_nm,
    marker_ref_general  = prev$marker_ref_general,
    marker_ref_detailed = prev$marker_ref_detailed,
    marker_ref          = prev$marker_ref,
    anno_df    = prev$anno_df
  )
  ckpt_save("grid", payload)
  invisible(TRUE)
}

# Stage "umap": resume from the "grid" checkpoint, compute a UMAP on the first
# `dims` PCs, derive `cluster_key` from the current identities and update the
# "umap" checkpoint with a slimmed object. Returns TRUE invisibly.
stage_umap <- function() {
  message("Stage UMAP...")
  prev <- ckpt_load("grid")
  n_dims <- prev$dims

  embedded <- RunUMAP(
    prev$obj,
    reduction = cfg$pca_name,
    dims = 1:n_dims,
    n.neighbors = 30,
    min.dist = 0.5,
    metric = "cosine",
    reduction.name = cfg$umap_name,
    seed.use = cfg$seed,
    verbose = FALSE
  )

  # Canonical cluster labels are taken from the active identities.
  ident_labels <- as.character(Idents(embedded))
  embedded$cluster_key <- canon_cluster(ident_labels)

  kept_reductions <- c(cfg$pca_name, cfg$umap_name)
  slim <- diet_for_checkpoint(
    embedded,
    keep_reduc = kept_reductions,
    keep_graphs = prev$keep_graph
  )
  ckpt_update("umap", prev, obj = slim)
  invisible(TRUE)
}

# Stage "deg": resume from the "umap" checkpoint, compute DEGs per
# `cluster_key`, store them in the "deg" checkpoint (plus a best-effort
# deg_ckpt_save() cache), then refresh the "umap" checkpoint with a slimmed
# object. Returns TRUE invisibly.
stage_deg <- function() {
  message("Stage DEG...")
  prev <- ckpt_load("umap")
  grouping <- "cluster_key"

  seu <- ensure_cluster_key(prev$obj, grouping)
  deg_tab <- compute_degs_robust(seu, group_col = grouping, features_whitelist = NULL)
  if (!nrow(deg_tab)) {
    message("Stage DEG: no DEGs found.")
  }

  ckpt_update("deg",
              prev,
              obj = seu,
              degs = deg_tab,
              group_col = grouping)

  # Secondary DEG cache; a failure here is deliberately not fatal.
  try(
    deg_ckpt_save(
      prefix = "deg",
      obj = seu,
      group_col = grouping,
      degs = deg_tab,
      features = NULL
    ),
    silent = TRUE
  )

  kept_reductions <- c(cfg$pca_name, cfg$umap_name)
  slim <- diet_for_checkpoint(
    seu,
    keep_reduc = kept_reductions,
    keep_graphs = prev$keep_graph
  )
  ckpt_update("umap", prev, obj = slim)
  invisible(TRUE)
}

# Recover the marker reference tables and the gene annotation table.
#
# Sources are tried in this order: the supplied state `st`, then the "umap",
# "grid", "pca" and "load" checkpoints. The first usable table found for each
# slot wins. If no marker reference is found, it is rebuilt from
# cfg$paths$markers_xlsx; if no annotation is found and cfg$paths$anno_rdata
# exists, the annotation is read from that RData file.
#
# The RData file is loaded into a private environment. Loading it into this
# function's frame and scanning ls() (as an earlier version did) also scans
# the function's own locals, so a marker table (which has a `gene` column)
# could be returned as the annotation, and objects in the file could overwrite
# locals such as `mr`.
#
# @param st Optional checkpoint state (list) searched first; NULL to skip.
# @return list(marker_ref_general, marker_ref_detailed, marker_ref, anno_df);
#   entries are NULL when nothing usable was found.
rehydrate_refs <- function(st = NULL) {
  try_ck <- function(tag) {
    if (ckpt_has(tag)) ckpt_load(tag) else NULL
  }
  bins <- list(st,
               try_ck("umap"),
               try_ck("grid"),
               try_ck("pca"),
               try_ck("load"))
  bins <- bins[!vapply(bins, is.null, logical(1))]

  # A usable marker reference has rows and both `gene` and `final_cluster`.
  pick_df <- function(x)
    is.data.frame(x) &&
    nrow(x) > 0 && all(c("gene", "final_cluster") %in% names(x))
  # A usable annotation table has rows and a `gene` column.
  pick_anno <- function(x)
    is.data.frame(x) && nrow(x) > 0 && "gene" %in% names(x)

  mr_gen <- mr_det <- mr <- anno_df <- NULL
  for (b in bins) {
    if (is.null(mr_gen) && pick_df(b$marker_ref_general))
      mr_gen <- b$marker_ref_general
    if (is.null(mr_det) && pick_df(b$marker_ref_detailed))
      mr_det <- b$marker_ref_detailed
    if (is.null(mr) && pick_df(b$marker_ref))
      mr <- b$marker_ref
    if (is.null(anno_df) && pick_anno(b$anno_df))
      anno_df <- b$anno_df
  }

  # No marker reference in any checkpoint: rebuild all three from the XLSX.
  if (is.null(mr)) {
    stopifnot(file.exists(cfg$paths$markers_xlsx))
    cm  <- openxlsx::read.xlsx(cfg$paths$markers_xlsx, sheet = 1)
    refs <- prepare_marker_ref(cm, gene_col = "Markers_positive_SMESG")
    mr_gen <- refs$general
    mr_det <- refs$detailed
    mr <- mr_det
  }

  if (is.null(anno_df) &&
      !is.null(cfg$paths$anno_rdata) &&
      file.exists(cfg$paths$anno_rdata)) {
    anno_env <- new.env(parent = emptyenv())
    # Sorted so the choice among several candidates is alphabetical.
    loaded <- sort(load(cfg$paths$anno_rdata, envir = anno_env))
    is_candidate <- vapply(loaded, function(nm) {
      x <- get(nm, envir = anno_env, inherits = FALSE)
      is.data.frame(x) && "gene" %in% names(x)
    }, logical(1))
    hit <- loaded[is_candidate]
    if (length(hit)) {
      anno_df <- get(hit[1], envir = anno_env, inherits = FALSE)
      if (!"all_anno" %in% names(anno_df)) {
        # Build `all_anno` from whichever descriptive columns are available.
        alt <- intersect(
          c(
            "Annotation",
            "Uniprot_protein_name",
            "PFAM_domain_name",
            "NCBI_ID"
          ),
          names(anno_df)
        )
        anno_df$all_anno <- if (length(alt))
          do.call(paste, c(anno_df[alt], sep = "; "))
        else
          NA_character_
      }
      anno_df <- anno_df[, intersect(c("gene", "all_anno"), names(anno_df)),
                         drop = FALSE]
    }
  }

  list(
    marker_ref_general = mr_gen,
    marker_ref_detailed = mr_det,
    marker_ref = mr,
    anno_df = anno_df
  )
}

# Stage "annot1": first annotation round on `cluster_key`.
#
# Steps:
#   1. Load the "umap" checkpoint, rehydrate marker refs / annotation and
#      write them back to the "umap" checkpoint.
#   2. Obtain precomputed DEGs from the "deg" checkpoint, else from
#      deg_ckpt_load() (failures there are tolerated).
#   3. Run annotate_all_methods(). If it fails or returns something without
#      `obj` and `table`, fall back to a minimal table (one row per cluster,
#      consensus NA, agree 0) so later stages can still run.
#   4. Write the "round1" sheets to cfg$paths$out_xlsx and save the "annot1"
#      checkpoint (slimmed object + `tab1`).
# Returns TRUE invisibly.
stage_annot1 <- function() {
  message("Stage annot1...")
  st <- ckpt_load("umap")
  refs <- rehydrate_refs(st)
  st$marker_ref_general  <- refs$marker_ref_general
  st$marker_ref_detailed <- refs$marker_ref_detailed
  st$marker_ref          <- refs$marker_ref
  st$anno_df             <- refs$anno_df
  ckpt_update(
    "umap",
    st,
    marker_ref_general = st$marker_ref_general,
    marker_ref_detailed = st$marker_ref_detailed,
    marker_ref = st$marker_ref,
    anno_df = st$anno_df
  )

  mr <- st$marker_ref %||% st$marker_ref_detailed %||% st$marker_ref_general
  obj <- st$obj

  # Precomputed DEGs: "deg" checkpoint first, then the DEG cache.
  degs_in <- NULL
  if (ckpt_has("deg")) {
    d <- ckpt_load("deg")
    if (is.list(d) && !is.null(d$degs))
      degs_in <- d$degs
  }
  if (is.null(degs_in)) {
    degs_in <- try(
      deg_ckpt_load(
        prefix = "deg",
        obj = obj,
        group_col = "cluster_key",
        features = NULL
      ),
      silent = TRUE
    )
    if (inherits(degs_in, "try-error"))
      degs_in <- NULL
  }

  # Shared tail of both outcomes: write the round1 sheets, then checkpoint.
  finish_round1 <- function(obj, tab1, degs) {
    wb <- wb_load_or_new(cfg$paths$out_xlsx)
    write_results_xlsx(
      wb,
      sheet_prefix = "round1",
      annot_tab = tab1,
      degs = degs,
      anno_df = st$anno_df
    )
    save_timings_to_xlsx(wb)
    wb_save(wb, cfg$paths$out_xlsx)
    obj_small <- diet_for_checkpoint(
      obj,
      keep_reduc  = c(cfg$pca_name, cfg$umap_name),
      keep_graphs = st$keep_graph
    )
    ckpt_update("annot1", st, obj = obj_small, tab1 = tab1)
    invisible(TRUE)
  }

  ann1 <- try(
    annotate_all_methods(
      obj,
      mr,
      cluster_key_name = "cluster_key",
      prefix = "annot_detailed",
      limit_clusters = NULL,
      deg_prefix = "deg",
      degs_precomputed = degs_in
    ),
    silent = TRUE
  )

  good <- is.list(ann1) &&
    !is.null(ann1$obj) && !is.null(ann1$table)
  if (!good) {
    # conditionMessage() yields just the error text; as.character() on a
    # condition object would also paste in the deparsed call.
    msg <- if (inherits(ann1, "try-error")) {
      conditionMessage(attr(ann1, "condition"))
    } else if (!is.list(ann1)) {
      "non-list return"
    } else {
      "return value lacks 'obj' and/or 'table'"
    }
    message("[annot1] annotate_all_methods failed: ", msg)
    message("[annot1] Falling back to a minimal round1 table so the pipeline can proceed.")
    obj <- ensure_cluster_key(obj, "cluster_key")
    parents <- sort(unique(canon_cluster(as.character(
      obj$cluster_key
    ))))
    tab1 <- tibble::tibble(cluster = parents,
                           final_consensus = NA_character_,
                           agree = 0)
    try(ckpt_save("annot_table_last", list(table = tab1, saved_at = Sys.time())), silent = TRUE)
    fallback_degs <- if (isTRUE(cfg$write_round1_degs) &&
                         !is.null(degs_in) && nrow(degs_in))
      degs_in
    else
      NULL
    return(finish_round1(obj, tab1, fallback_degs))
  }

  obj  <- ann1$obj
  tab1 <- ann1$table %>% dplyr::arrange(cluster)
  message("[annot][summary] Round1 (detailed) first 20 rows:")
  invisible(print_annotation_summary(tab1, n = 20))
  try(ckpt_save("annot1_table", list(tab = tab1, saved_at = Sys.time())), silent = TRUE)

  round1_degs <- if (isTRUE(cfg$write_round1_degs))
    ann1$degs
  else
    NULL
  finish_round1(obj, tab1, round1_degs)
}

# Repeatedly sub-cluster every group in metadata column `key` that holds more
# than `max_cells` cells, refining `key` in place via subcluster_one().
# At most `max_passes` sweeps are made; a sweep that finds no oversized group
# ends the loop early. The resulting labels are copied to `cluster_key_final`.
#
# @param obj        Seurat object carrying metadata column `key`.
# @param key        Name of the cluster-label column to refine.
# @param max_cells  Size above which a group is split.
# @param max_passes Maximum number of sweeps.
# @return The updated object.
auto_split_large <- function(obj,
                             key        = "cluster_key_v2",
                             max_cells  = 1200L,
                             max_passes = 2L) {
  for (pass in seq_len(max_passes)) {
    # Group sizes, largest first.
    sizes <- sort(table(as.character(obj[[key]][,1])), decreasing = TRUE)
    oversized <- names(sizes)[sizes > max_cells]

    if (!length(oversized)) {
      message("[autosplit] pass ", pass, ": nothing above ", max_cells, " cells.")
      break
    }

    shown <- paste(head(oversized, 10), collapse=", ")
    more <- if (length(oversized)>10) " ..." else ""
    message("[autosplit] pass ", pass, ": splitting ", length(oversized), " parents: ",
            shown, more)

    for (parent in oversized) {
      # Same column is read and written: the key is refined in place.
      obj <- subcluster_one(
        obj_in = obj,
        parent_key = parent,
        cluster_key_name = key,
        new_key_name     = key,
        collapse_tiny    = TRUE
      )
    }
  }
  obj$cluster_key_final <- obj[[key]][,1]
  obj
}

# Stage "subclust": sub-cluster parents whose round-1 agreement is at or below
# cfg$agree_cut, then split oversized clusters.
#
# Outcomes written to the "subclust" checkpoint:
#   * no suspicious parents, or no accepted split: state with `changed = FALSE`
#   * accepted splits: state with the re-embedded object (UMAP stored as
#     "<umap_name>.v2"), `parents`, `new_key`, `changed = TRUE`, after
#     auto_split_large() has been applied to "cluster_key_v2".
#
# The state is updated by assignment (st$obj <- ...). Building it with
# c(st, list(obj = ...)) leaves two `obj` entries in the list, and `$obj`
# then returns the first, stale one.
# Returns TRUE invisibly.
stage_subclust <- function() {
  message("Stage subcluster...")
  st <- ckpt_load("annot1")
  if (is.null(st))
    stop("[subclust] Missing 'annot1' checkpoint. Run annot1 first.")
  obj  <- st$obj
  tab1 <- st$tab1
  obj <- ensure_cluster_key(obj, "cluster_key")
  Idents(obj) <- obj$cluster_key
  # if v2 exists but has no children (no dots), drop it
  if ("cluster_key_v2" %in% colnames(obj@meta.data) &&
      !any(grepl("\\.", obj$cluster_key_v2))) {
    obj$cluster_key_v2 <- NULL
  }

  # Save the state with the current object and `changed = FALSE`.
  save_unchanged <- function(obj) {
    st$obj <- obj
    st$changed <- FALSE
    ckpt_save("subclust", st)
    invisible(TRUE)
  }

  # Parents with low agreement; any failure (e.g. missing table) means none.
  susp <- tryCatch(
    tab1 %>%
      dplyr::filter(!is.na(agree) & agree <= cfg$agree_cut) %>%
      dplyr::pull(cluster) %>%
      unique() %>%
      as.character(),
    error = function(e)
      character(0)
  )
  if (!length(susp)) {
    message("[subclust] No parents below agree_cut (",
            cfg$agree_cut,
            "). Nothing to split.")
    return(save_unchanged(obj))
  }
  message(
    "[subclust] Will subcluster ",
    length(susp),
    " parent(s): ",
    paste(head(susp, 10), collapse = ", "),
    if (length(susp) > 10)
      " ..."
    else
      ""
  )

  res <- auto_subcluster_suspicious(
    obj,
    annot_table = tab1,
    cluster_key_name = "cluster_key",
    agree_cut = cfg$agree_cut
  )
  obj <- res$obj

  if (!isTRUE(res$changed)) {
    message("[subclust] No accepted splits (",
            length(res$parents),
            " accepted).")
    return(save_unchanged(obj))
  }

  # Reuse the PC count chosen earlier; fall back to 30 if unavailable.
  gd <- tryCatch(
    ckpt_load("grid"),
    error = function(e)
      NULL
  )
  dims_umap <- if (!is.null(gd) && length(gd$dims))
    gd$dims
  else
    30L
  obj <- RunUMAP(
    obj,
    reduction = cfg$pca_name,
    dims = 1:dims_umap,
    n.neighbors = 30,
    min.dist = 0.5,
    metric = "cosine",
    reduction.name = paste0(cfg$umap_name, ".v2"),
    seed.use = cfg$seed,
    verbose = FALSE
  )

  st$obj     <- obj
  st$parents <- unique(res$parents)
  st$new_key <- res$new_key
  st$changed <- TRUE
  tot <- ncol(st$obj)
  # ~1.25% of total cells, but at least 800
  max_cells_auto <- max(800L, round(0.0125 * tot))
  st$obj <- auto_split_large(st$obj, key = "cluster_key_v2", max_cells = max_cells_auto, max_passes = 2L)
  ckpt_save("subclust", st)
  invisible(TRUE)
}




# Stage "annot2_and_final": annotate the final clustering and export results.
#
# Steps:
#   1. Start from the "annot1" checkpoint; switch to the "subclust" object when
#      its `cluster_key_v2` contains children (labels with a dot) or more
#      distinct labels than `cluster_key`.
#   2. Copy the chosen key to `cluster_key_final`, compute DEGs for it and
#      cache them under prefix "deg_final" (best effort).
#   3. Re-run annotate_all_methods() on the final key.
#   4. Augment the annotation table with each cluster's parent (text before
#      the first dot) and the parent's round-1 consensus / agreement.
#   5. Write the "final" sheets to cfg$paths$out_xlsx and save the
#      "annot2_and_final" checkpoint.
# Returns TRUE invisibly.
stage_annot2_and_final <- function() {
  message("Stage annot2/final...")
  st1 <- ckpt_load("annot1")
  stopifnot(!is.null(st1))
  obj  <- st1$obj

  # Use subclustered key if available/meaningful
  sc <- tryCatch(
    ckpt_load("subclust"),
    error = function(e)
      NULL
  )
  use_v2 <- FALSE
  if (!is.null(sc) && !is.null(sc$obj)) {
    if ("cluster_key_v2" %in% colnames(sc$obj@meta.data)) {
      v2 <- as.character(sc$obj$cluster_key_v2)
      use_v2 <- any(grepl("\\.", v2), na.rm = TRUE) ||
        (length(unique(na.omit(v2))) > length(unique(as.character(
          sc$obj$cluster_key
        ))))
      if (use_v2)
        obj <- sc$obj
    }
  }
  final_key <- if (use_v2)
    "cluster_key_v2"
  else
    "cluster_key"
  message(
    if (use_v2)
      "[annot2] Using subclustered key (cluster_key_v2)."
    else
      "[annot2] Using round1 key (cluster_key)."
  )
  obj$cluster_key_final <- obj[[final_key]][, 1]

  # Rehydrate marker refs / annotations
  refs <- rehydrate_refs()
  mr   <- refs$marker_ref %||% refs$marker_ref_detailed %||% refs$marker_ref_general

  # compute DEGs for the FINAL grouping and cache them (prefix 'deg_final')
  degs_final <- compute_degs_robust(obj, group_col = "cluster_key_final", features_whitelist = NULL)
  if (nrow(degs_final)) {
    try(
      deg_ckpt_save(
        prefix = "deg_final",
        obj = obj,
        group_col = "cluster_key_final",
        degs = degs_final,
        features = NULL
      ),
      silent = TRUE
    )
  } else {
    message(
      "[annot2] WARNING: no DEGs passed filters for final key; DEG-based methods may return 0 labels."
    )
  }

  # Re-annotate using final key; feed the DEGs that were just computed
  ann2 <- annotate_all_methods(
    obj,
    mr,
    cluster_key_name = "cluster_key_final",
    prefix           = "annot_final",
    limit_clusters   = NULL,
    deg_prefix       = "deg_final",
    degs_precomputed = if (nrow(degs_final))
      degs_final
    else
      NULL
  )
  obj  <- ann2$obj

  # Coerce the annotation table BEFORE touching its columns, so a non-data-frame
  # table or one that only has row names is handled instead of failing early.
  tab2 <- ann2$table
  if (!is.data.frame(tab2))
    tab2 <- as.data.frame(tab2, stringsAsFactors = FALSE)
  if (!"cluster" %in% names(tab2)) {
    if (!is.null(rownames(tab2))) {
      tab2$cluster <- rownames(tab2)
    } else {
      stop("[annot2] 'tab2' lacks a 'cluster' column and rownames; cannot augment.")
    }
  }
  tab2$cluster <- as.character(tab2$cluster)
  tab2 <- tab2[order(tab2$cluster), , drop = FALSE]

  # Map each final label to its parent
  df_labels <- data.frame(cluster = sort(unique(as.character(
    obj$cluster_key_final
  ))), stringsAsFactors = FALSE)
  df_labels$parent <- sub("\\..*$", "", df_labels$cluster)

  # Slim round1 table (parent consensus/agree)
  tab1_slim <- as.data.frame(st1$tab1, stringsAsFactors = FALSE)
  tab1_cols <- c("cluster", "final_consensus", "agree")
  tab1_missing <- setdiff(tab1_cols, names(tab1_slim))
  if (length(tab1_missing))
    stop("[annot2] round-1 table lacks column(s): ",
         paste(tab1_missing, collapse = ", "))
  tab1_slim <- tab1_slim[, tab1_cols]
  names(tab1_slim) <- c("parent", "parent_consensus", "parent_agree")
  tab1_slim$parent <- as.character(tab1_slim$parent)
  message("tab2 class: ", paste(class(ann2$table), collapse = ", "))
  message("tab1 class: ", paste(class(st1$tab1), collapse = ", "))

  # Base merges (no S3 generics involved). merge(sort = FALSE) leaves the rows
  # in an unspecified order, so the cluster order is restored afterwards.
  tab2_tmp <- merge(tab2,
                    df_labels,
                    by = "cluster",
                    all.x = TRUE,
                    sort = FALSE)
  tab2_aug <- merge(tab2_tmp,
                    tab1_slim,
                    by = "parent",
                    all.x = TRUE,
                    sort = FALSE)
  tab2_aug <- tab2_aug[order(tab2_aug$cluster), , drop = FALSE]
  rownames(tab2_aug) <- NULL

  # Write Excel: final_annot (augmented) + per-cluster DEG sheets
  wb <- wb_load_or_new(cfg$paths$out_xlsx)
  write_results_xlsx(
    wb,
    sheet_prefix = "final",
    annot_tab    = tab2_aug,
    degs         = ann2$degs,
    anno_df      = refs$anno_df
  )
  save_timings_to_xlsx(wb)
  wb_save(wb, cfg$paths$out_xlsx)

  # Save checkpoint
  obj_small <- diet_for_checkpoint(obj, keep_reduc = c(cfg$pca_name, cfg$umap_name, paste0(cfg$umap_name, ".v2")))
  ckpt_save(
    "annot2_and_final",
    list(
      obj          = obj_small,
      final_key    = "cluster_key_final",
      final_labels = tab2_aug
    )
  )
  invisible(TRUE)
}


# Orchestrator ------------------------------------------------------------
# Run the pipeline stages in order, skipping stages that already have a
# checkpoint when `resume` is TRUE.
#
# The "subclust" stage is re-run even if its checkpoint exists when
# needs_rerun_subclust() says so. Every executed stage is timed and recorded
# through .stage_time_log_add(); a failing stage is recorded with ok = FALSE
# and its error message, after which the error is re-raised.
#
# @param resume     Skip stages whose checkpoint exists?
# @param stop_after Optional stage name; the run ends after that stage (also
#   when the stage was skipped).
# @return TRUE, invisibly.
auto_run <- function(resume = TRUE,
                     stop_after = NULL) {
  stages <- c("load",
              "pca",
              "grid",
              "umap",
              "deg",
              "annot1",
              "subclust",
              "annot2_and_final")
  fns <- list(
    load = stage_load,
    pca = stage_pca,
    grid = stage_grid,
    umap = stage_umap,
    deg = stage_deg,
    annot1 = stage_annot1,
    subclust = stage_subclust,
    annot2_and_final = stage_annot2_and_final
  )
  if (!is.null(stop_after) && !any(stages %in% stop_after)) {
    warning("auto_run: `stop_after` matches no stage; all stages will run.",
            call. = FALSE)
  }
  for (s in stages) {
    skip <- resume && ckpt_has(s)
    if (s == "subclust" &&
        resume && ckpt_has("subclust") && needs_rerun_subclust()) {
      message("[auto_run] subclust checkpoint is stale; re-running subclust.")
      skip <- FALSE
    }
    if (!skip) {
      message(">>> Running stage ", s)
      t0 <- Sys.time()
      # `ok` only becomes TRUE once the stage has returned normally.
      ok <- FALSE
      note <- NULL
      tryCatch({
        fns[[s]]()
        ok <- TRUE
      }, error = function(e) {
        # The handler has its own frame, so a plain `<-` here would never
        # reach the `note` that the `finally` block reads.
        note <<- conditionMessage(e)
        stop(e)
      }, finally = {
        t1 <- Sys.time()
        .stage_time_log_add(
          stage = s,
          t0 = t0,
          t1 = t1,
          ok = ok,
          note = note
        )
        message(sprintf("[timing] %s: %.1f sec", s,
                        as.numeric(difftime(t1, t0, units = "secs"))))
      })
    } else {
      message("Skipping stage ", s, " (found checkpoint).")
      next
    }
    if (!is.null(stop_after) &&
        identical(s, stop_after)) {
      message("Stop requested after stage: ", s)
      break
    }
  }
  invisible(TRUE)
}

# ---------------------------
# RUN
# ---------------------------
# Checkpoint directory comes from the configuration.
CKPT_DIR  <- cfg$ckpt_dir
# Path of the .rds checkpoint file for a given stage name.
ckpt_file <- function(stage)
  file.path(CKPT_DIR, paste0("auto_annot_ckpt_", stage, ".rds"))

# NOTE(review): hard-coded working directory; every relative path used below
# (and, if relative, CKPT_DIR) resolves against it.
setwd("G:/PhD_final/sncRNA")

# Clear only late-stage checkpoints if needed
# unlink(ckpt_path_qs("annot1"),          force = TRUE); unlink(ckpt_path_rds("annot1"),          force = TRUE)
# unlink(ckpt_path_qs("subclust"),        force = TRUE); unlink(ckpt_path_rds("subclust"),        force = TRUE)
# unlink(ckpt_path_qs("annot2_and_final"), force = TRUE)
# unlink(ckpt_path_rds("annot2_and_final"), force = TRUE)
# unlink(ckpt_path_qs("final"), force = TRUE)
# unlink(ckpt_path_rds("final"), force = TRUE)

# Run all stages, reusing existing checkpoints.
auto_run(resume = TRUE)

# Convenience diagnostics:
# diag <- if (ckpt_has("grid"))
#   ckpt_load("grid")$obj@misc$grid_diag
# else
#   NULL

# Convenience: pull final object (saved in annot2_and_final)
if (ckpt_has("annot2_and_final")) {
  result_obj <- ckpt_load("annot2_and_final")$obj
  message("Final object available as `result_obj` (labels in '$cluster_key_final').")
}
# If a lean annotation object was left in the global environment, print the
# class, dimensions and memory footprint of its data matrix and the structure
# of its cached `annot_avg_exp` entry.
if (exists(".annot_obj_lean", envir = .GlobalEnv)) {
  obj_lean <- get(".annot_obj_lean", envir = .GlobalEnv)
  X <- Seurat::GetAssayData(obj_lean, assay = cfg$base_assay, layer = "data")
  print(class(X))
  print(dim(X))
  print(object.size(X))
  str(obj_lean@misc$annot_avg_exp)
}
# --- Ad-hoc checks below; these are interactive scratch lines. ---
# NOTE(review): from here on `result_obj` is used unconditionally although it
# is only created above when the "annot2_and_final" checkpoint exists.
st <- ckpt_load("annot1")
v <- as.character(st$obj$cluster_key)   # st <- ckpt_load("annot1")
# as.character() drops names, so this prints NULL.
names(v)
sc <- ckpt_load("subclust")
# Counts how many list entries are called "obj" in the subclust checkpoint.
table(names(sc))["obj"]  # will be 2

#result_obj
#result_obj$final_consensus
# 1) Confirm final key has children
table(grepl("\\.", result_obj$cluster_key_final))
head(sort(unique(
  as.character(result_obj$cluster_key_final)
)))

# 2) Confirm per-method unprefixed columns exist
head(colnames(result_obj@meta.data)[colnames(result_obj@meta.data) %in%
                                      c(
                                        "avg_exp",
                                        "hypergeom",
                                        "majority",
                                        "logfc",
                                        "cellmanam",
                                        "hypergeomX",
                                        "gsea",
                                        "ucell",
                                        "final_consensus",
                                        "final_agree"
                                      )])

# 3) Excel: look for a 'final_annot' sheet and per-cluster 'final_<cluster>' sheets (with DEGs)
table(grepl("\\.", result_obj$cluster_key_final))           # subclusters present
head(result_obj@meta.data$avg_exp)                           # unprefixed methods exist
"final_annot" %in% openxlsx::getSheetNames(cfg$paths$out_xlsx)
# Manual smoke test of subcluster_one() on a single parent ("g60").
st <- ckpt_load("annot1")
obj <- st$obj
Idents(obj) <- obj$cluster_key
obj2 <- subcluster_one(
  obj_in = obj,
  parent_key = "g60",
  cluster_key_name = "cluster_key",
  new_key_name = "cluster_key_v2"
)
# Should not error; either split or skip cleanly:
table(startsWith(as.character(obj2$cluster_key_v2), "g60."))
unique(result_obj$annot_final)
#result_obj$cluster_key_final
warnings()
# Quick look at the label columns of the final object.
head(result_obj$cluster_key_final)
unique(result_obj$cluster_key_final)
unique(result_obj$cluster_key_v2)
unique(result_obj$final_consensus)

script_5.R


# Marker statistics per cluster 
#   1) Computes average expression and % detected for a given gene set per cluster.
#   2) Builds a long, join-ready marker table from Seurat::DotPlot() data.
#   3) Builds a per-cluster marker summary (top markers passing a % threshold).
#   4) Optionally maps a curated per-cluster annotation file back to the Seurat object.

# ----------------------------
# 0) Minimal dependencies
# ----------------------------
stopifnot(requireNamespace("Seurat", quietly = TRUE))
stopifnot(requireNamespace("Matrix", quietly = TRUE))
stopifnot(requireNamespace("dplyr", quietly = TRUE))
stopifnot(requireNamespace("openxlsx", quietly = TRUE))

# ----------------------------
# 1) Small helpers
# ----------------------------

# Seurat v4/v5 compatibility: GetAssayData() uses slot (v4) or layer (v5).
# Fetch an expression matrix from a Seurat object across Seurat v4 / v5.
#
# v5 (SeuratObject >= 5.0.0) addresses matrices with `layer=`, v4 with
# `slot=`. The version is read from the installed SeuratObject package:
# GetAssayData() is an S3 generic whose own formals are only `object` and
# `...`, so looking for "layer" in formals(GetAssayData) can never succeed.
#
# @param obj           Seurat object.
# @param assay         Assay name.
# @param layer_or_slot Layer (v5) / slot (v4) name, e.g. "data".
# @return The matrix returned by Seurat::GetAssayData().
.get_data <- function(obj, assay, layer_or_slot = "data") {
  has_layers <- utils::packageVersion("SeuratObject") >= "5.0.0"
  if (has_layers) {
    return(Seurat::GetAssayData(obj, assay = assay, layer = layer_or_slot))
  }
  Seurat::GetAssayData(obj, assay = assay, slot = layer_or_slot)
}

# Remove leading "g" only when it precedes a Greek letter (to keep g1/g2 etc intact).
# Drop a leading "g" only when the next character is a Greek letter; every
# other label (g1, g2, ...) is returned unchanged. Input is coerced to
# character first.
.strip_g_before_greek <- function(x) {
  labels <- as.character(x)
  g_then_greek <- "^g(?=\\p{Greek})"
  sub(pattern = g_then_greek, replacement = "", x = labels, perl = TRUE)
}

# Safe column extraction from meta.data
# Return metadata column `col` of `obj` as a vector; stops when the column is
# not present in `obj@meta.data`.
.get_md <- function(obj, col) {
  stopifnot(col %in% colnames(obj@meta.data))
  methods::slot(obj, "meta.data")[[col]]
}

# ----------------------------
# 2) Core computations
# ----------------------------

# Compute average expression + percent detected per group for gene list.
#
# Per-group mean expression and detection rate for a set of genes.
#
# Genes are matched against the matrix of the requested `assay` / layer (not
# against the object's default assay), so a non-default assay cannot cause an
# out-of-bounds subscript. Cells whose group label is NA are dropped.
#
# @param obj             Seurat object.
# @param genes           Character vector of gene names (at least one).
# @param group_col       Metadata column holding the group labels.
# @param assay           Assay to read; defaults to the object's default assay.
# @param layer_or_slot   Layer/slot passed to .get_data().
# @param fix_group_names If TRUE, labels go through .strip_g_before_greek().
# @return list with `genes_use` (genes found), `group_levels` (column names),
#   `avg_mat` (genes x groups, row means) and `pct_mat` (genes x groups,
#   fraction of cells with a value > 0).
compute_avg_and_pct <- function(obj,
                                genes,
                                group_col,
                                assay = Seurat::DefaultAssay(obj),
                                layer_or_slot = "data",
                                fix_group_names = TRUE) {
  stopifnot(is.character(genes), length(genes) > 0)
  stopifnot(group_col %in% colnames(obj@meta.data))

  mat <- .get_data(obj, assay = assay, layer_or_slot = layer_or_slot)

  genes_use <- intersect(genes, rownames(mat))
  if (!length(genes_use)) stop("None of the provided genes are present in the object.", call. = FALSE)

  grp <- as.character(.get_md(obj, group_col))
  if (fix_group_names) grp <- .strip_g_before_greek(grp)

  # Drop NA groups
  keep_cells <- !is.na(grp)
  if (!any(keep_cells)) {
    stop("All cells have NA in '", group_col, "'; nothing to summarise.", call. = FALSE)
  }
  grp <- grp[keep_cells]
  mat <- mat[genes_use, keep_cells, drop = FALSE]

  # Column indices of `mat` belonging to each group.
  idx_by_grp <- split(seq_along(grp), grp)

  summarise_groups <- function(m) {
    out <- do.call(cbind, lapply(idx_by_grp, function(ii) {
      Matrix::rowMeans(m[, ii, drop = FALSE])
    }))
    colnames(out) <- names(idx_by_grp)
    rownames(out) <- rownames(mat)
    out
  }
  avg_mat <- summarise_groups(mat)
  pct_mat <- summarise_groups(mat > 0)

  stopifnot(identical(colnames(avg_mat), colnames(pct_mat)))

  list(
    genes_use = genes_use,
    group_levels = colnames(avg_mat),
    avg_mat = avg_mat,
    pct_mat = pct_mat
  )
}

# For each gene: report top1/top2 groups and the delta between them.
# For each gene (row), report the group with the highest and second-highest
# average expression and the difference between the two.
#
# Works with a single group as well (the top2_* and delta12 columns are then
# NA) and returns a zero-row data frame for an empty matrix. Group names are
# taken from colnames(avg_mat) explicitly because extracting one row of a
# one-column matrix drops the name.
#
# @param avg_mat Numeric matrix genes x groups (with dimnames) of mean expression.
# @param pct_mat Matrix of the same dimensions and dimnames with detection rates.
# @return data.frame with columns gene, top1_pop, top1_avg, top1_pct,
#   top2_pop, top2_avg, top2_pct, delta12; sorted by decreasing delta12, then
#   decreasing top1_avg (NA last).
rank_top_groups_per_gene <- function(avg_mat, pct_mat) {
  stopifnot(all(dim(avg_mat) == dim(pct_mat)))

  if (!nrow(avg_mat) || !ncol(avg_mat)) {
    return(data.frame(
      gene = character(0), top1_pop = character(0),
      top1_avg = numeric(0), top1_pct = numeric(0),
      top2_pop = character(0), top2_avg = numeric(0),
      top2_pct = numeric(0), delta12 = numeric(0),
      stringsAsFactors = FALSE
    ))
  }
  genes <- rownames(avg_mat)
  pops <- colnames(avg_mat)
  stopifnot(!is.null(genes), !is.null(pops))

  rows <- lapply(genes, function(g) {
    v <- avg_mat[g, ]
    names(v) <- pops
    ord <- order(v, decreasing = TRUE, na.last = TRUE)

    has_second <- length(ord) >= 2
    top1 <- pops[ord[1]]
    top2 <- if (has_second) pops[ord[2]] else NA_character_

    data.frame(
      gene      = g,
      top1_pop  = top1,
      top1_avg  = unname(v[ord[1]]),
      top1_pct  = unname(pct_mat[g, top1]),
      top2_pop  = top2,
      top2_avg  = if (has_second) unname(v[ord[2]]) else NA_real_,
      top2_pct  = if (has_second) unname(pct_mat[g, top2]) else NA_real_,
      delta12   = if (has_second) unname(v[ord[1]] - v[ord[2]]) else NA_real_,
      stringsAsFactors = FALSE
    )
  })

  out <- do.call(rbind, rows)
  out <- out[order(-out$delta12, -out$top1_avg, na.last = TRUE), , drop = FALSE]
  rownames(out) <- NULL
  out
}

# Build a long marker table for a marker reference XLSX using DotPlot() output.
# Long, join-ready marker table: one row per (group, marker gene) with the
# mean expression and percent of expressing cells taken from
# Seurat::DotPlot()$data, the group size, and the reference columns from the
# marker XLSX.
#
# @param obj             Seurat object.
# @param markers_xlsx    Path to the marker reference workbook.
# @param marker_gene_col Column of the workbook holding the gene identifiers.
# @param group_col       Metadata column with the group labels.
# @param assay           Assay used for the feature check and for DotPlot().
# @param layer_or_slot   Layer/slot used for the feature check.
# @return Table sorted by group, then decreasing pct_expr and mean_expr.
build_marker_stats_table <- function(obj,
                                     markers_xlsx,
                                     marker_gene_col = "Markers_positive_SMESG",
                                     group_col = "cluster_key_final",
                                     assay = "SCT",
                                     layer_or_slot = "data") {
  stopifnot(file.exists(markers_xlsx))
  stopifnot(group_col %in% colnames(obj@meta.data))

  # Marker reference: keep the gene column plus the three descriptive columns.
  marker_sheet <- openxlsx::read.xlsx(markers_xlsx)
  names(marker_sheet) <- make.unique(names(marker_sheet))

  required <- c(
    marker_gene_col,
    "Cell_population_general",
    "Cell_population_detailed",
    "Markers_positive_common.name"
  )
  absent <- setdiff(required, names(marker_sheet))
  if (length(absent)) stop("Missing columns in marker XLSX: ", paste(absent, collapse = ", "), call. = FALSE)

  ref_tbl <- marker_sheet[, required, drop = FALSE]
  ref_tbl[[marker_gene_col]] <- trimws(gsub("\\t", "", ref_tbl[[marker_gene_col]]))
  ref_tbl <- unique(ref_tbl)

  # Exact-match filter against the features of the chosen assay/layer.
  assay_features <- rownames(.get_data(obj, assay = assay, layer_or_slot = layer_or_slot))
  candidates <- unique(ref_tbl[[marker_gene_col]])
  candidates <- candidates[!is.na(candidates) & candidates != ""]
  markers_found <- candidates[candidates %in% assay_features]

  if (!length(markers_found)) stop("No marker genes from XLSX were found in the object (assay/layer mismatch?).", call. = FALSE)

  # Number of cells per group.
  cluster_sizes <- as.data.frame(table(obj[[group_col, drop = TRUE]]), stringsAsFactors = FALSE)
  colnames(cluster_sizes) <- c(group_col, "n_cells")

  # Only the data behind the dot plot is used, not the plot itself.
  dot_data <- Seurat::DotPlot(
    object   = obj,
    features = markers_found,
    group.by = group_col,
    assay    = assay
  )$data

  per_group <- dplyr::transmute(
    dot_data,
    !!group_col := .data$id,
    gene        = .data$features.plot,
    mean_expr   = .data$avg.exp,
    pct_expr    = .data$pct.exp
  )
  with_sizes <- dplyr::left_join(per_group, cluster_sizes, by = group_col)
  with_refs <- dplyr::left_join(with_sizes, ref_tbl, by = setNames(marker_gene_col, "gene"))

  dplyr::arrange(
    with_refs,
    .data[[group_col]],
    dplyr::desc(.data$pct_expr),
    dplyr::desc(.data$mean_expr)
  )
}

# Summarize markers per cluster for quick manual review.
# One row per cluster listing the markers detected in at least `thr` percent
# of its cells, for quick manual review.
#
# A common name that occurs more than once within a cluster is disambiguated
# as "name [gene]". The per-name count is computed as an integer; ave() on the
# character column itself would return the counts as strings.
#
# @param stats_tbl Table as produced by build_marker_stats_table().
# @param group_col Name of the cluster column in `stats_tbl`.
# @param thr       Minimum `pct_expr` for a marker to be listed.
# @return Table with n_cells, n_markers_passing, markers_common ("label
#   (xx.x%)" entries, highest percentage first) and markers_SMESG; sorted by
#   decreasing n_markers_passing, then cluster.
build_cluster_marker_summary <- function(stats_tbl,
                                         group_col = "cluster_key_final",
                                         thr = 20) {
  needed <- c(group_col, "gene", "pct_expr", "mean_expr", "n_cells",
              "Markers_positive_common.name")
  stopifnot(all(needed %in% colnames(stats_tbl)))

  stats_tbl |>
    dplyr::mutate(
      Markers_positive_common.name = trimws(gsub("\\s+", " ", .data$Markers_positive_common.name))
    ) |>
    dplyr::filter(
      !is.na(.data$Markers_positive_common.name),
      .data$Markers_positive_common.name != "",
      .data$pct_expr >= thr
    ) |>
    dplyr::group_by(.data[[group_col]]) |>
    dplyr::mutate(
      # How many passing rows in this cluster share the same common name.
      name_n_in_cluster = ave(
        seq_along(.data$Markers_positive_common.name),
        .data$Markers_positive_common.name,
        FUN = length
      ),
      marker_label = ifelse(
        .data$name_n_in_cluster > 1L,
        paste0(.data$Markers_positive_common.name, " [", .data$gene, "]"),
        .data$Markers_positive_common.name
      )
    ) |>
    dplyr::summarise(
      n_cells = dplyr::first(.data$n_cells),
      n_markers_passing = dplyr::n_distinct(.data$Markers_positive_common.name),
      markers_common = paste(
        .data$marker_label[order(-.data$pct_expr, .data$marker_label)],
        sprintf("(%.1f%%)", .data$pct_expr[order(-.data$pct_expr, .data$marker_label)]),
        sep = " ",
        collapse = "; "
      ),
      markers_SMESG = paste(sort(unique(.data$gene)), collapse = "; "),
      .groups = "drop"
    ) |>
    dplyr::arrange(dplyr::desc(.data$n_markers_passing), .data[[group_col]])
}

# Compute mode + purity of an existing annotation per cluster.
# For every cluster, find the most frequent value of an existing annotation
# column and how dominant it is.
#
# Cells with NA in either column are ignored. Ties for the most frequent
# label are broken by keeping a single row (with_ties = FALSE).
#
# @param obj         Object with a `meta.data` data frame.
# @param cluster_col Metadata column with cluster labels.
# @param anno_col    Metadata column with the annotation to summarise.
# @return Table with the cluster column, anno_mode, anno_purity (percent,
#   one decimal) and n_cells_meta (cells counted for the cluster).
mode_and_purity_by_cluster <- function(obj, cluster_col, anno_col) {
  wanted <- c(cluster_col, anno_col)
  meta <- dplyr::select(obj@meta.data, dplyr::all_of(wanted))
  meta <- dplyr::filter(
    meta,
    !is.na(.data[[cluster_col]]),
    !is.na(.data[[anno_col]])
  )

  # Cells per (cluster, annotation) pair.
  label_counts <- dplyr::count(
    meta,
    .data[[cluster_col]],
    .data[[anno_col]],
    name = "n_label"
  )

  by_cluster <- dplyr::group_by(label_counts, .data[[cluster_col]])
  by_cluster <- dplyr::mutate(
    by_cluster,
    n_cluster = sum(.data$n_label),
    purity = .data$n_label / .data$n_cluster
  )
  winners <- dplyr::slice_max(by_cluster, order_by = .data$n_label, n = 1, with_ties = FALSE)
  winners <- dplyr::ungroup(winners)

  dplyr::transmute(
    winners,
    !!cluster_col := .data[[cluster_col]],
    anno_mode = .data[[anno_col]],
    anno_purity = round(100 * .data$purity, 1),
    n_cells_meta = .data$n_cluster
  )
}

# Map curated per-cluster annotations back onto each cell.
# Write a curated per-cluster label onto every cell.
#
# The workbook's first column is taken as the cluster id and its second as the
# curated population; the first occurrence of a cluster wins and rows with a
# missing cluster are ignored. Cells whose cluster has no (or an empty)
# curated label receive `fallback`. The fallback is included in the factor
# levels; without that, those cells would silently become NA.
#
# @param obj          Seurat object.
# @param curated_xlsx Path to the curated workbook.
# @param cluster_col  Metadata column with the cluster labels.
# @param out_col      Metadata column to create/overwrite.
# @param fallback     Label for cells without a curated annotation.
# @return The object with `out_col` added as a factor (levels in workbook
#   order, followed by `fallback` if it is not already one of them).
apply_curated_cluster_annotation <- function(obj,
                                             curated_xlsx,
                                             cluster_col = "cluster_key_final",
                                             out_col = "final_population",
                                             fallback = "Unknown") {
  stopifnot(file.exists(curated_xlsx))
  stopifnot(cluster_col %in% colnames(obj@meta.data))

  cur <- openxlsx::read.xlsx(curated_xlsx)
  if (ncol(cur) < 2) stop("Curated XLSX must have at least 2 columns: cluster and final_population.", call. = FALSE)

  # Standardize expected names: first col = cluster, second col = final_population.
  names(cur)[1:2] <- c("cluster", "final_population")
  cur <- cur[!duplicated(cur$cluster) & !is.na(cur$cluster), c("cluster", "final_population"), drop = FALSE]

  map <- setNames(as.character(cur$final_population), as.character(cur$cluster))
  cl <- as.character(obj[[cluster_col, drop = TRUE]])

  anno <- unname(map[cl])
  anno[is.na(anno) | anno == ""] <- fallback

  # Levels: curated labels in workbook order (blank / NA removed) + fallback.
  lvl <- unique(as.character(cur$final_population))
  lvl <- lvl[!is.na(lvl) & lvl != ""]
  lvl <- union(lvl, fallback)

  obj[[out_col]] <- factor(anno, levels = lvl)
  obj
}



# Parameters
cluster_col <- "cluster_key_final"      # cluster labels
assay_use   <- "SCT"                    # or Seurat::DefaultAssay(result_obj)
layer_use   <- "data"                   # "data" for log-normalized/SCT

# Marker XLSX workflow
markers_xlsx <- "G:/PhD_final/tables/cell_markers_curated_new_new_new_new.xlsx" #Table S4

# Output paths
out_stats_xlsx   <- "G:/PhD_final/tables/cluster_marker_metrics_by_cluster_key_final.xlsx"
out_summary_xlsx <- "G:/PhD_final/tables/cluster_marker_summary2.xlsx"

# Optional curated mapping
curated_xlsx <- "G:/PhD_final/tables/cluster_marker_summary_verified.xlsx"
out_obj_rdata <- "G:/PhD_final/result_obj_new.RData"

# --- Build marker statistics and write to XLSX ---
# `result_obj` must already exist in the session.
stats_tbl <- build_marker_stats_table(
  obj = result_obj,
  markers_xlsx = markers_xlsx,
  marker_gene_col = "Markers_positive_SMESG",
  group_col = cluster_col,
  assay = assay_use,
  layer_or_slot = layer_use
)
openxlsx::write.xlsx(stats_tbl, file = out_stats_xlsx, overwrite = TRUE)

# --- Summarize markers per cluster (for manual review) ---
cluster_marker_summary <- build_cluster_marker_summary(stats_tbl, group_col = cluster_col, thr = 20)

# If an earlier annotation column exists, show its per-cluster mode and purity
# right after the cluster column.
anno_col_existing <- "final_population_fixed"
if (anno_col_existing %in% colnames(result_obj@meta.data)) {
  old_anno <- mode_and_purity_by_cluster(result_obj, cluster_col = cluster_col, anno_col = anno_col_existing)
  cluster_marker_summary <- cluster_marker_summary |>
    dplyr::left_join(old_anno, by = cluster_col) |>
    # Columns are selected by name with all_of(); `.data$col` inside a
    # selection helper is deprecated in tidyselect.
    dplyr::relocate(
      dplyr::all_of(c("anno_mode", "anno_purity")),
      .after = dplyr::all_of(cluster_col)
    )
}

openxlsx::write.xlsx(cluster_marker_summary, file = out_summary_xlsx, overwrite = TRUE)

# --- Apply curated annotation back to Seurat object ---
if (file.exists(curated_xlsx)) {
  result_obj <- apply_curated_cluster_annotation(
    obj = result_obj,
    curated_xlsx = curated_xlsx,
    cluster_col = cluster_col,
    out_col = "final_population",
    fallback = "Unknown"
  )

  # Basic cleanup: trim whitespace; levels are rebuilt from the trimmed labels.
  result_obj$final_population <- factor(trimws(as.character(result_obj$final_population)))

  save(result_obj, file = out_obj_rdata)
}

script_7.sh

#!/usr/bin/env bash
# Strict mode: abort on command failure (-e), on unset variables (-u) and on
# a failure anywhere in a pipeline (pipefail).
set -euo pipefail
#set -x
# ---------- SM paths ----------
BASEDIR="/mnt/d/scRNA-seq/small_RNA"
DATADIR="${BASEDIR}/Data"
REFDIR="${BASEDIR}/reference"
REF_FASTA="${REFDIR}/SM_ncRNA_filtered.fa"
BOWTIE_PREFIX="${REFDIR}/SM_ncRNA_filtered_bowtie2_index"  
# NOTE(review): GTF is disabled here but "${GTF}" is still read by the GTF
# autogeneration step further down; with `set -u` that expansion is an error.
#GTF="${REFDIR}/final_anno_only_ncRNA.gtf"                       
THREADS=8
# 3' adapter handed to cutadapt pass 1 and the minimum read length kept by
# both cutadapt passes.
ADAPTER="AGATCGGAAGAGCACACGTCTGAACTCCAGTCAC"
MINLEN=14
# ----------------------------------------

# ---------- Conda env  ----------
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate srna_env

# ---------- Directory layout ----------
# All outputs live under ${OUT}; one sub-directory per pipeline stage.
OUT="${BASEDIR}/GenXPro_our"
RAWQC="${OUT}/qc/raw"
TRIMQC="${OUT}/qc/trimmed"
TRIM1="${OUT}/trimmed_adapter_q"       # pass1
CLEAN="${OUT}/cleaned_poly"            # pass2
EXTRACT="${OUT}/umi_extracted"         # UMI after trimming
DEDUP="${OUT}/dedup_pre_align"         # pre-align dedup
MAP="${OUT}/map_bowtie2"
COUNT="${OUT}/counts"
TPM="${OUT}/tpm"
LOGS="${OUT}/logs"
mkdir -p "$OUT" "$RAWQC" "$TRIMQC" "$TRIM1" "$CLEAN" "$EXTRACT" "$DEDUP" "$MAP" "$COUNT" "$TPM" "$LOGS"

# ---------- Helper: relaxed UMI extractor (8nt 5′ UMI, optional 4nt 3′ UMI) ----------
# This section only WRITES the helper script. The extraction itself runs in
# the "UMI extraction AFTER trimming" stage, once cutadapt has produced the
# *.clean.fastq.gz inputs.
#
# A leftover debug stage used to live here. It ran the extractor before any
# trimming had happened (so a fresh run exited with "No inputs"), processed
# only the first sample (`break`), and then prepended a second "@" to every
# header line even though the helper already writes "@<title>". It is removed.
#
# Helper behaviour (see the Python below):
#   - reads shorter than 9 nt are dropped;
#   - the first 8 nt are taken as the 5' UMI;
#   - if at least 5 nt remain, the last 4 nt are taken as the 3' UMI,
#     otherwise the 3' UMI is recorded as NNNN and the remainder is the insert;
#   - the UMI pair is appended to the header as " UMI:<umi5>-<umi3>".
set -x
EXTRACT_PY="${OUT}/_umi_extract_relaxed.py"
cat > "$EXTRACT_PY" << 'PY'
import sys, gzip, os
from Bio.SeqIO.QualityIO import FastqGeneralIterator

inp, outp = sys.argv[1], sys.argv[2]
os.makedirs(os.path.dirname(outp), exist_ok=True)

def open_in(p):
    return gzip.open(p, "rt") if p.endswith(".gz") else open(p, "r")

with open_in(inp) as fh, gzip.open(outp, "wb") as out:
    for h, s, q in FastqGeneralIterator(fh):
        if len(s) < 9:
            continue
        umi5 = s[:8]; body = s[8:]
        if len(body) >= 5:
            umi3 = body[-4:]; ins = body[:-4]; qins = q[8:-4]
        else:
            umi3 = "NNNN"; ins = body; qins = q[8:8+len(ins)]
        if not ins:
            continue
        out.write(f"@{h.rstrip()} UMI:{umi5}-{umi3}\n{ins}\n+\n{qins}\n".encode())
PY



# ---------- Helper: FASTQ dedup by exact (sequence + UMI pair) ----------
# Writes a small Python helper to ${OUT}; it is invoked later as
#   python "$DEDUP_PY" <in.fastq[.gz]> <out.fastq[.gz]> <stats.txt>
# The helper keeps the first record seen for each (sequence, UMI) pair, where
# the UMI is the token following the last "UMI:" in the header, and writes
# "kept" / "removed_duplicates" counts to the stats file.
# Only an MD5 digest of "sequence|UMI" is stored per unique record.
DEDUP_PY="${OUT}/_dedup_by_seq_umi.py"
cat > "$DEDUP_PY" << 'PY'
import sys, gzip, hashlib, os

inp, outp, logp = sys.argv[1], sys.argv[2], sys.argv[3]
os.makedirs(os.path.dirname(outp), exist_ok=True)

def opn_r(p): return gzip.open(p, "rt") if p.endswith(".gz") else open(p, "r")
def opn_w(p): return gzip.open(p, "wb") if p.endswith(".gz") else open(p, "wb")

def umi_from_header(h):
    # header starts with '@'; UMI stored like "... UMI:NNNNNNNN-NNNN"
    i = h.rfind("UMI:")
    return "" if i == -1 else h[i+4:].split()[0]

seen=set(); kept=dup=0
with opn_r(inp) as r, opn_w(outp) as w, open(logp,"w") as lg:
    while True:
        h = r.readline()
        if not h: break
        s = r.readline().rstrip("\n")
        plus = r.readline()
        q = r.readline().rstrip("\n")

        if not h.startswith("@"):    # guard against malformed records
            continue
        umi = umi_from_header(h.rstrip())
        key = hashlib.md5((s + "|" + umi).encode()).hexdigest()

        if key in seen:
            dup += 1
            continue
        seen.add(key); kept += 1

        w.write(h.encode())
        w.write((s + "\n").encode())
        w.write(plus.encode())
        w.write((q + "\n").encode())

    lg.write(f"kept\t{kept}\nremoved_duplicates\t{dup}\n")
PY


# ---------- FastQC raw ----------
# Best-effort QC: a FastQC failure must not abort the pipeline.
fastqc -t "${THREADS}" -o "$RAWQC" "${DATADIR}"/*.fastq.gz || true

# ---------- Cutadapt pass 1: adapter + quality (as in vendor report) ----------
# A comment cannot follow a line-continuation backslash: "\ #..." escapes the
# space, ends the command on that line, and passes " #check!!!!!!!" to cutadapt
# as an extra argument. The reminder therefore lives on its own line.
echo "Cutadapt pass 1 (adapter+qtrim)"
for fq in "${DATADIR}"/*.fastq.gz; do
  base=$(basename "$fq" .fastq.gz)
  # check!!!!!!!  (author's reminder: re-verify these trimming parameters)
  cutadapt -e 0.1 -O 3 -q 20 -m "${MINLEN}" -n 8 \
           -a "${ADAPTER}" \
           -o "${TRIM1}/${base}.trim1.fastq.gz" "$fq" \
           > "${TRIM1}/${base}.trim1.log"
done

# ---------- Cutadapt pass 2: homopolymer cleaning ----------
# Removes 3' homopolymer runs (A/T/C/G x10, minimum overlap 10), up to 3 rounds.
echo "Cutadapt pass 2 (homopolymer cleaning)"
for fq in "${TRIM1}"/*.trim1.fastq.gz; do
  base=$(basename "$fq" .trim1.fastq.gz)
  # check!!!!!!!  (author's reminder: re-verify these homopolymer settings)
  cutadapt -a 'A{10};o=10' -a 'T{10};o=10' -a 'C{10};o=10' -a 'G{10};o=10' \
           -n 3 -m "${MINLEN}" \
           -o "${CLEAN}/${base}.clean.fastq.gz" "$fq" \
           > "${CLEAN}/${base}.clean.log"
done

# ---------- UMI extraction AFTER trimming (relaxed) ----------
echo "UMI extraction (fast; 5' UMI required, 3' UMI optional)"
# sanity: inputs present?
# Collect inputs with a glob instead of `ls | wc -l`: under `set -e` plus
# `pipefail` a failing `ls` (no matches) aborted the script inside the command
# substitution, before the error message below could be printed.
shopt -s nullglob
clean_inputs=( "${CLEAN}"/*.clean.fastq.gz )
shopt -u nullglob
if [ "${#clean_inputs[@]}" -eq 0 ]; then
  echo "ERROR: No inputs in ${CLEAN}/*.clean.fastq.gz" >&2; exit 1
fi

# parallel if available, else serial
if command -v parallel >/dev/null 2>&1; then
  # Reduce each path to its sample name (dots escaped so only the literal
  # ".clean.fastq.gz" suffix is stripped), then fan out one job per sample.
  printf '%s\n' "${clean_inputs[@]}" \
  | sed 's#.*/##; s/\.clean\.fastq\.gz$//' \
  | parallel -j "${THREADS}" '
      python -u "'"${EXTRACT_PY}"'" \
        "'"${CLEAN}"'"/{}.clean.fastq.gz \
        "'"${EXTRACT}"'"/{}.umi.fastq.gz \
      2> "'"${LOGS}"'"/{}.umi_extract.stderr
    '
else
  for fq in "${clean_inputs[@]}"; do
    base=$(basename "$fq" .clean.fastq.gz)
    python -u "${EXTRACT_PY}" "$fq" "${EXTRACT}/${base}.umi.fastq.gz" \
      2> "${LOGS}/${base}.umi_extract.stderr"
  done
fi


# ---------- Deduplicate BEFORE mapping: exact (UMI pair + insert) ----------
# One dedup run per UMI-extracted FASTQ; stats go next to the output.
echo "Deduplicate (pre-align) by (UMI pair + insert sequence)"
for fq in "${EXTRACT}"/*.umi.fastq.gz; do
  base=$(basename "$fq" .umi.fastq.gz)
  python "${DEDUP_PY}" "$fq" "${DEDUP}/${base}.dedup.fastq.gz" "${DEDUP}/${base}.dedup.stats.txt"
done

# ---------- Build Bowtie2 index if missing ----------
# Either a small (.bt2) or a large (.bt2l) index counts as present.
if [ ! -e "${BOWTIE_PREFIX}.1.bt2" ] && [ ! -e "${BOWTIE_PREFIX}.1.bt2l" ]; then
  echo "Building Bowtie2 index for ${REF_FASTA}"
  bowtie2-build "${REF_FASTA}" "${BOWTIE_PREFIX}"
fi

# ---------- Map to ncRNA with Bowtie2 --sensitive --local ----------
# Unmapped reads are dropped (-F 4) before coordinate sorting and indexing.
# Expansions are quoted so sample names are never word-split or globbed.
echo "Bowtie2 mapping (--sensitive --local) to ncRNA"
mkdir -p "${MAP}"
for fq in "${DEDUP}"/*.dedup.fastq.gz; do
  base=$(basename "$fq" .dedup.fastq.gz)
  echo "$base"
  bowtie2 --threads "${THREADS}" --sensitive --local \
          -x "${BOWTIE_PREFIX}" -U "$fq" \
    2> "${MAP}/${base}.bowtie2.log" \
  | samtools view -b -F 4 - \
  | samtools sort -@4 -o "${MAP}/${base}.sorted.bam"
  samtools index "${MAP}/${base}.sorted.bam"
done

# ---------- GTF (autogenerate if absent) ----------
# GTF may be unset (its assignment in the config section is commented out).
# Under `set -u` a bare "${GTF}" would abort the script, so default it to the
# empty string first; an empty or missing file triggers autogeneration.
GTF="${GTF:-}"
if [ ! -s "${GTF}" ]; then
  echo "No GTF found at ${GTF}. Autogenerating from FASTA headers (single exon per record)."
  GTF="${REFDIR}/autogen_ncRNA_from_fasta.gtf"
  # One "transcript" feature per FASTA record, spanning 1..sequence length on
  # the + strand; the record name is the first whitespace-delimited token of
  # the header and is used for both transcript_id and gene_id.
  python - "$REF_FASTA" "$GTF" << 'PYCODE'
import sys, gzip
fa, gtf = sys.argv[1], sys.argv[2]
def op(p): return gzip.open(p,'rt') if p.endswith('.gz') else open(p)
with op(fa) as fh, open(gtf,'w') as out:
    name=None; seq=[]
    def flush(nm, seqlen):
        if not nm: return
        out.write(f"{nm}\tgenxpro\ttranscript\t1\t{seqlen}\t.\t+\t.\ttranscript_id \"{nm}\"; gene_id \"{nm}\";\n")
    for line in fh:
        if line.startswith('>'):
            if name is not None:
                flush(name, len(''.join(seq)))
            name=line[1:].strip().split()[0]
            seq=[]
        else:
            seq.append(line.strip())
    if name is not None:
        flush(name, len(''.join(seq)))
PYCODE
fi

# ---------- htseq-count ----------
echo "htseq-count on BAMs"
mkdir -p "${COUNT}"
# Use the GTF chosen earlier (user-supplied or autogenerated). It used to be
# overwritten here unconditionally, which ignored a user-supplied annotation.
# The autogenerated path remains the default when GTF is unset.
GTF="${GTF:-${REFDIR}/autogen_ncRNA_from_fasta.gtf}"
if [ ! -s "${GTF}" ]; then
  echo "ERROR: GTF not found or empty: ${GTF}" >&2
  exit 1
fi
for bam in "${MAP}"/*.sorted.bam; do
  base=$(basename "$bam" .sorted.bam)
  echo "$base"
  # Coordinate-sorted BAM, unstranded, no MAPQ filter; multi-feature reads are
  # split fractionally between features.
  htseq-count -f bam -r pos -s no -a 0 -t transcript -i transcript_id \
    --nonunique=fraction \
    "$bam" "$GTF" > "${COUNT}/${base}.counts.txt"
  #htseq-count \
  #  -f bam \
  #  -r pos \                # coordinate-sorted BAM
  #  -s no \                 # unstranded
  #  -a 0 \                  # no min AQual filter
  #  -t transcript \         # your GTF uses 'transcript'
  #  -i transcript_id \      # attribute you wrote
  #  "$bam" "$GTF" > "${COUNT}/${base}.counts.txt"
done



# featureCounts over all BAMs at once, against the same annotation as
# htseq-count; multi-mapping reads are counted fractionally (-M --fraction).
featureCounts -T "${THREADS}" -s 0 -t transcript -g transcript_id -M --fraction \
  -a "${GTF}" \
  -o "${COUNT}/featureCounts.transcript.txt" \
  "${MAP}"/*.sorted.bam

# ---------- Merge counts and compute TPM ----------
# Transcript lengths come from the FASTA index (columns 1-2 of the .fai).
# The inline Python merges every *.counts.txt in ${COUNT} (skipping the "__"
# summary rows) into counts_matrix.tsv and tpm_matrix.tsv in ${TPM}.
# TPM per sample = (count / length_kb) / sum(count / length_kb) * 1e6;
# transcripts without a known length contribute 0.
# NOTE(review): counts_matrix.tsv is written with int(), which truncates the
# fractional counts produced by `--nonunique=fraction`; TPM uses the
# untruncated values. Confirm the truncation is intended.
echo "Merge counts and compute TPM"
samtools faidx "${REF_FASTA}"
mkdir -p "${TPM}"
cut -f1,2 "${REF_FASTA}.fai" > "${TPM}/lengths.tsv"

python - "${COUNT}" "${TPM}/lengths.tsv" "${TPM}" << 'PYCODE'
import sys, os, glob, csv
from collections import defaultdict
count_dir, len_path, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
lengths = {}
with open(len_path) as f:
    for line in f:
        tid, ln = line.rstrip().split('\t')[:2]
        lengths[tid] = float(ln)
samples=[]; counts=defaultdict(dict)
for fn in sorted(glob.glob(os.path.join(count_dir, "*.counts.txt"))):
    s=os.path.basename(fn).replace(".counts.txt",""); samples.append(s)
    with open(fn) as fh:
        for row in fh:
            if row.startswith("__"): continue
            tid,c=row.rstrip().split('\t'); counts[tid][s]=float(c)
all_ids=list(counts.keys())
with open(os.path.join(out_dir,"counts_matrix.tsv"),"w",newline="") as out:
    w=csv.writer(out,delimiter='\t'); w.writerow(["transcript_id"]+samples)
    for tid in all_ids: w.writerow([tid]+[int(counts[tid].get(s,0)) for s in samples])
def tpm_for_sample(s):
    rpk={}; 
    for tid in all_ids:
        ln=lengths.get(tid,0.0); c=counts[tid].get(s,0.0)
        rpk[tid]=0.0 if ln<=0 else c/(ln/1000.0)
    denom=sum(rpk.values()) or 1.0
    return {tid:(v/denom)*1e6 for tid,v in rpk.items()}
per={s:tpm_for_sample(s) for s in samples}
with open(os.path.join(out_dir,"tpm_matrix.tsv"),"w",newline="") as out:
    w=csv.writer(out,delimiter='\t'); w.writerow(["transcript_id"]+samples)
    for tid in all_ids: w.writerow([tid]+[f"{per[s][tid]:.6f}" for s in samples])
PYCODE

# ---------- MultiQC ----------
# Aggregate every tool log found under ${OUT} into a single report.
multiqc "${OUT}" -o "${OUT}/multiqc"

# Final summary: one line per output directory.
printf '%s\n' \
  "Done." \
  "Outputs:" \
  "  - Trimmed:       ${TRIM1}" \
  "  - Cleaned:       ${CLEAN}" \
  "  - UMI-extracted: ${EXTRACT}" \
  "  - Deduped:       ${DEDUP}" \
  "  - BAMs:          ${MAP}" \
  "  - Counts:        ${COUNT}" \
  "  - TPM:           ${TPM}"

script_8.R

###############################################################################
# small RNA activity / enrichment in scRNA-seq 
#
# Core idea
# 1) Build sRNA→target gene sets by seed scanning against UTR (± CDS).
# 2) Score each cell by a control-matched module score for the target set.
#    - This reduces “cell type baseline” bias that often makes one population dominate.
# 3) Summarize per (celltype × timepoint × genotype) and test genotype effects
#    using a permutation test (directional, based on bulk sRNA change).
#
# Note
# - There is only 1 library per (condition × timepoint), so genotype p-values are
#   exploratory (cells are not true biological replicates). Treat them as a prioritization aid.
###############################################################################

options(stringsAsFactors = FALSE)
set.seed(1)

suppressPackageStartupMessages({
  library(Seurat)
  library(Matrix)
  library(Biostrings)
  library(data.table)
  library(stringr)
  library("xlsx")
  library(ggplot2)
})
# ---------------------- sRNA sequence extraction -----------------------------
# Pull the nucleotide sequence embedded in an sRNA identifier.
#   id: a single identifier string.
# Returns the first run of 15-100 nucleotide letters (A/C/G/T/U/N, any case),
# upper-cased and converted to the DNA alphabet (U -> T), or NA_character_
# when the identifier contains no such run.
extract_seq_from_id <- function(id) {
  nt_run <- stringr::str_extract(id, "[ACGTUNacgtun]{15,100}")
  if (!is.na(nt_run)) {
    chartr("U", "T", toupper(nt_run))
  } else {
    NA_character_
  }
}
# ------------------------------- load data -----------------------------------
# Each load() restores the object named in the trailing comment.
load("G:/PhD_final/result_obj_new.RData")                  # result_obj: Seurat obj
#load("G:/PhD_final/tRNA_miRNA_selected_raw_counts.RData")  # tRNA_miRNA_selected_raw_counts
load("G:/PhD_final/all_stringtie_selected.RData")          # all_stringtie_selected
load("D:/scRNA-seq/tRF_motif/cds_tx_seq.RData")            # cds_tx_seq
load("D:/scRNA-seq/tRF_motif/tx2gene.RData")               # tx2gene
load("D:/scRNA-seq/tRF_motif/utr_tx_seq.RData")            # utr_tx_seq
load("D:/scRNA-seq/AZ_final_obj/filtered_DEG_abr_new.RData")# filtered_DE (optional)
load("G:/PhD_final/final_bulk_DGE.RData")                  # final_test_DGE (optional)
#load("G:/PhD_final/tables/bulk_dir_tbl_LFC057.RData")      # bulk_dir_tbl (sRNA DE)

# Bulk sRNA differential-expression table (all RNA types except rRFs).
not_rRNA <- read.xlsx("D:/Elac2/final_results/tables/DGE_other_than_rRF_snRNA_filtered_new.xlsx", sheetIndex = 1)
unique(not_rRNA$RNA_type)

# Keep the 3 dpa ELAC-vs-WT contrast for miRNAs, piRNAs and tRFs, and only the
# three columns used downstream.
bulk_rows <- not_rRNA$set == "Elac_vs_WT_dpa3" &
  not_rRNA$RNA_type %in% c("miRNAs", "piRNAs", "tRFs")
bulk_dir_tbl <- not_rRNA[bulk_rows, ]
bulk_dir_tbl <- bulk_dir_tbl[, c("snRNA_type", "Sequence", "log2FoldChange")]
colnames(bulk_dir_tbl) <- c("snRNA_type", "Sequence", "bulk_log2FC_3dpa")

# Sequences are stored in the DNA alphabet; the sRNA label is
# "<sequence> <snRNA_type>".
bulk_dir_tbl$Sequence <- chartr("U", "T", bulk_dir_tbl$Sequence)
bulk_dir_tbl$sRNA <- paste(bulk_dir_tbl$Sequence, bulk_dir_tbl$snRNA_type, sep = " ")

# ----------------------------- user parameters --------------------------------
# Output directory for this analysis; created (recursively) if missing.
OUT_DIR <- "G:/PhD_final/tables/srna_activity_sc"
if (!dir.exists(OUT_DIR)) dir.create(OUT_DIR, recursive = TRUE)

# Seurat parsing
ASSAY_USE <- "RNA"              # or "SCT" 
SLOT_USE  <- "data"             # "data" = log-normalized
MIN_CELLS_PER_GROUP <- 20       # per (celltype,timepoint,genotype)

# Which genotypes to compare
# All three are kept when subsetting cells; ELAC is later coded as group 1.
GENO_KEEP <- c("WT", "ELAC", "GFP")   

# Timepoints to analyze (NULL = all detected)
#TIMEPOINTS_TO_USE <- c(0, 16, 24, 72)
TIMEPOINTS_TO_USE <- 72
# Target scanning
USE_CDS_IN_SCAN <- TRUE
CDS_WEIGHT_MULT <- 0.5          # downweight CDS sites vs UTR
# NOTE(review): scan_targets_genelevel() applies this threshold to transcript
# scores before they are summed per gene.
MIN_GENE_SCORE  <- 3            # minimum weighted site score to keep a gene
TOP_N_GENES     <- 500          # cap target set size (keeps scoring stable)

# Module score (control-matched)
N_BINS_EXPR     <- 24           # expression bins for control matching
CTRL_PER_TARGET <- 20           # controls sampled per target gene
MIN_TARGETS_FOR_SCORE <- 15

# Permutation test (run only on top strata per sRNA to keep runtime sane)
DO_PERMUTATION  <- TRUE
N_PERM          <- 2000
# NOTE(review): TEST_TOP_STRATA is not referenced anywhere else in the visible
# part of this script; confirm whether the top-strata restriction is applied.
TEST_TOP_STRATA <- 15           # per (sRNA,model): compute perm p only for top |delta| strata

# If bulk_dir_tbl has expected_target_change (DOWN_in_ELAC/UP_in_ELAC) we will use it; else derive from bulk_log2FC_3dpa
# expected_target_change = "DOWN_in_ELAC" means sRNA is UP in ELAC (targets expected DOWN) => activity expected HIGHER in ELAC.

# Names must match the entries of `model_def`.
TARGET_MODELS <- c("miRNA_canonical", "piRNA_extended", "off1_7mer", "off2_7mer", "off3_7mer")

# Cache (target scanning is expensive)
# If this file exists it is loaded as-is; delete it after changing any
# scanning parameter or target model so the targets are rebuilt.
TARGET_CACHE_RDS <- file.path(OUT_DIR, "targets_cache_corr.rds")

# ----------------------------- sanity checks ---------------------------------
# The Seurat object must have been loaded; work on a copy with the requested
# assay made the default.
stopifnot(exists("result_obj"))
seu <- result_obj
Seurat::DefaultAssay(seu) <- ASSAY_USE

# The bulk table must exist and carry the sRNA label column.
stopifnot(exists("bulk_dir_tbl"))
bulk_dir_tbl <- as.data.table(bulk_dir_tbl)
stopifnot("sRNA" %in% names(bulk_dir_tbl))

# Coerce both sequence collections to DNAStringSet when they are not already.
if (!inherits(utr_tx_seq, "DNAStringSet")) {
  utr_tx_seq <- Biostrings::DNAStringSet(utr_tx_seq)
}
if (!inherits(cds_tx_seq, "DNAStringSet")) {
  cds_tx_seq <- Biostrings::DNAStringSet(cds_tx_seq)
}

# Normalise a transcript-to-gene mapping to a named vector.
#   tx2gene_obj: either a named vector (returned unchanged) or a data frame
#                with one transcript column (tx / transcript / transcript_id /
#                tx_id) and one gene column (gene / gene_id / geneid);
#                column names are matched case-insensitively.
# Returns a character vector of gene IDs named by transcript ID.
# Stops if the columns cannot be found or the input has another format.
as_tx2gene_named <- function(tx2gene_obj) {
  if (is.vector(tx2gene_obj) && !is.null(names(tx2gene_obj))) {
    return(tx2gene_obj)
  }
  if (!is.data.frame(tx2gene_obj)) {
    stop("Unrecognized tx2gene format.")
  }

  lowered <- tolower(colnames(tx2gene_obj))
  # First column whose lower-cased name is among `aliases`, else NA.
  first_col_among <- function(aliases) {
    colnames(tx2gene_obj)[which(lowered %in% aliases)[1]]
  }
  tx_col <- first_col_among(c("tx", "transcript", "transcript_id", "tx_id"))
  gene_col <- first_col_among(c("gene", "gene_id", "geneid"))
  if (is.na(tx_col) || is.na(gene_col)) stop("tx2gene needs transcript and gene columns.")

  setNames(
    as.character(tx2gene_obj[[gene_col]]),
    as.character(tx2gene_obj[[tx_col]])
  )
}
# Transcript-named lookup used by the target scan to map transcripts to genes.
tx2gene_map <- as_tx2gene_named(tx2gene)

# ------------------------- metadata column discovery --------------------------
# Return the first entry of `candidates` that is a column name of `df`,
# or NA_character_ when none is present.
pick_col <- function(df, candidates) {
  present <- candidates %in% colnames(df)
  if (!any(present)) {
    return(NA_character_)
  }
  candidates[present][[1]]
}

meta <- seu@meta.data

# Candidate column names are tried in order of preference.
celltype_candidates <- c("final_population", "celltype_use", "celltype", "CellType", "seurat_clusters")
condition_candidates <- c("condition_correct", "sc_condition_full", "condition", "orig.ident")
CELLTYPE_COL <- pick_col(meta, celltype_candidates)
COND_COL     <- pick_col(meta, condition_candidates)

# Interactive check of the population labels that are present.
unique(meta$final_population)
if (is.na(CELLTYPE_COL) || is.na(COND_COL)) {
  stop("Could not find metadata columns for celltype and condition. Update CELLTYPE_COL / COND_COL candidates.")
}

# Split condition labels such as "WT13S", "ELAC24S", "GFP0S" into parts.
#   x: vector of condition labels (coerced to character).
# Returns a data.table with
#   genotype       - the leading run of letters,
#   timepoint      - the first run of digits as integer (NA when absent),
#   condition_full - the original label.
parse_condition <- function(x) {
  labels <- as.character(x)
  leading_letters <- stringr::str_extract(labels, "^[A-Za-z]+")
  first_digits <- stringr::str_extract(labels, "[0-9]+")
  data.table(
    genotype = leading_letters,
    timepoint = suppressWarnings(as.integer(first_digits)),
    condition_full = labels
  )
}

# Derive genotype / timepoint / cell-type columns from the discovered metadata
# columns and attach them to the Seurat object (assignment is positional, in
# the row order of `meta`).
parsed <- parse_condition(meta[[COND_COL]])
seu$genotype <- parsed$genotype
seu$timepoint <- parsed$timepoint
seu$celltype_use <- as.character(meta[[CELLTYPE_COL]])

# keep timepoints
# TP_USE = requested timepoints that actually occur (all present ones if NULL).
tp_present <- sort(unique(seu$timepoint[!is.na(seu$timepoint)]))
TP_USE <- if (is.null(TIMEPOINTS_TO_USE)) tp_present else intersect(tp_present, TIMEPOINTS_TO_USE)
if (length(TP_USE) == 0) stop("No requested timepoints found in Seurat metadata.")

# subset
# Cells must have a kept genotype, a kept timepoint and a non-missing cell type.
cells_keep <- which(seu$genotype %in% GENO_KEEP & seu$timepoint %in% TP_USE & !is.na(seu$celltype_use))
seu_use <- subset(seu, cells = rownames(seu@meta.data)[cells_keep])

# group filter: require MIN_CELLS_PER_GROUP per (celltype,timepoint,genotype)
md <- as.data.table(seu_use@meta.data, keep.rownames = "cell")
gcnt <- md[, .N, by=.(celltype_use, timepoint, genotype)]
good_groups <- gcnt[N >= MIN_CELLS_PER_GROUP]
md <- md[good_groups, on=.(celltype_use, timepoint, genotype), nomatch=0L]
seu_use <- subset(seu_use, cells = md$cell)

# Rebuild the metadata table from the final subset (one row per kept cell).
md <- as.data.table(seu_use@meta.data, keep.rownames = "cell")

# -------------------------- expression matrix --------------------------------
# NOTE(review): recent SeuratObject releases deprecate `slot=` in favour of
# `layer=` in GetAssayData(); confirm against the installed version.
Seurat::DefaultAssay(seu_use) <- ASSAY_USE
expr <- Seurat::GetAssayData(seu_use, slot = SLOT_USE)
if (!inherits(expr, "dgCMatrix")) expr <- as(expr, "dgCMatrix")


# Reverse-complement a single DNA string; returns a plain character string.
rc_dna <- function(dna) {
  forward <- Biostrings::DNAString(dna)
  as.character(Biostrings::reverseComplement(forward))
}

# ---------------------------- target models ----------------------------------
# Canonical miRNA seed-site patterns, written 5'->3' on the target strand.
#   seq_dna: sRNA sequence in the DNA alphabet (single string).
# Returns a named character vector, or character(0) when the sequence is NA
# or shorter than 8 nt:
#   "7mer-m8" - match to sRNA positions 2-8,
#   "7mer-1A" - match to positions 2-7 followed by an A,
#   "8mer-1A" - match to positions 2-8 followed by an A.
# The A sits opposite sRNA position 1. On the target strand that is
# immediately 3' of the seed match, so it is APPENDED to the reverse
# complement (e.g. let-7 8mer site: CTACCTCA). It was previously prepended,
# which put the A at the wrong end of the site.
# Any target cache built with the old patterns must be deleted and rebuilt.
get_patterns_miRNA_canonical <- function(seq_dna) {
  if (is.na(seq_dna) || nchar(seq_dna) < 8) return(character(0))
  s7 <- substr(seq_dna, 2, 8)  # 2–8
  s6 <- substr(seq_dna, 2, 7)  # 2–7
  c(
    "7mer-m8" = rc_dna(s7),
    "7mer-1A" = paste0(rc_dna(s6), "A"),
    "8mer-1A" = paste0(rc_dna(s7), "A")
  )
}

# Extended (piRNA-style) pairing patterns: reverse complements of sRNA
# positions 2-8, 2-11 and 2-12.
#   seq_dna: sRNA sequence in the DNA alphabet (single string).
# Returns a named character vector ("seed_2_8", "ext_2_11", "ext_2_12"),
# or character(0) when the sequence is NA or shorter than 12 nt.
get_patterns_piRNA_extended <- function(seq_dna) {
  usable <- !is.na(seq_dna) && nchar(seq_dna) >= 12
  if (!usable) {
    return(character(0))
  }
  # Window end positions; every window starts at position 2.
  window_end <- c("seed_2_8" = 8, "ext_2_11" = 11, "ext_2_12" = 12)
  c(
    "seed_2_8" = rc_dna(substr(seq_dna, 2, window_end[["seed_2_8"]])),
    "ext_2_11" = rc_dna(substr(seq_dna, 2, window_end[["ext_2_11"]])),
    "ext_2_12" = rc_dna(substr(seq_dna, 2, window_end[["ext_2_12"]]))
  )
}

# Single k-mer pattern: reverse complement of the k nucleotides of the sRNA
# starting at position `offset`.
#   seq_dna: sRNA sequence in the DNA alphabet (single string).
#   offset : 1-based start position of the window.
#   k      : window length.
# Returns a length-1 character vector named "<k>mer_off<offset>", or
# character(0) when the sequence is NA or too short for the window.
get_patterns_offset_kmer <- function(seq_dna, offset = 1, k = 7) {
  window_end <- offset + k - 1
  if (is.na(seq_dna) || nchar(seq_dna) < window_end) {
    return(character(0))
  }
  pattern <- rc_dna(substr(seq_dna, offset, window_end))
  setNames(pattern, paste0(k, "mer_off", offset))
}

# Target models. Each entry is a function of the sRNA sequence returning
# list(patterns = <named character vector>, weights = <named numeric vector>);
# weights are matched to patterns by name during scanning.
model_def <- local({
  # The three offset models differ only in the window start position.
  offset_7mer_model <- function(start_pos) {
    force(start_pos)
    function(seq_dna) {
      pats <- get_patterns_offset_kmer(seq_dna, offset = start_pos, k = 7)
      list(patterns = pats, weights = setNames(2, names(pats)))
    }
  }

  list(
    miRNA_canonical = function(seq_dna) {
      list(
        patterns = get_patterns_miRNA_canonical(seq_dna),
        weights = c("7mer-m8" = 2, "7mer-1A" = 1, "8mer-1A" = 3)
      )
    },
    piRNA_extended = function(seq_dna) {
      list(
        patterns = get_patterns_piRNA_extended(seq_dna),
        weights = c("seed_2_8" = 1, "ext_2_11" = 3, "ext_2_12" = 4)
      )
    },
    off1_7mer = offset_7mer_model(1),
    off2_7mer = offset_7mer_model(2),
    off3_7mer = offset_7mer_model(3)
  )
})

# ---------------------- fast scanning with PDict ------------------------------
# Weighted count of pattern occurrences in every subject sequence.
#   subject_seqs  : Biostrings sequence set to scan (names become result names;
#                   positional indices are used when it is unnamed).
#   patterns_named: named character vector of DNA patterns.
#   weights_named : named numeric vector; only patterns whose name also occurs
#                   here are scanned.
#   fixed         : passed through to the Biostrings matching functions.
# Returns a numeric vector with one score per subject:
#   sum over patterns of weight * number of occurrences.
# Returns an empty named numeric vector when no pattern is usable.
#
# Patterns are grouped by length because a PDict needs constant-width patterns.
# If the PDict route fails, each pattern is counted separately.
count_weighted_hits <- function(subject_seqs, patterns_named, weights_named, fixed = TRUE) {
  no_score <- setNames(numeric(0), character(0))
  if (length(patterns_named) == 0) return(no_score)

  nm <- intersect(names(patterns_named), names(weights_named))
  patterns_named <- patterns_named[nm]
  weights_named  <- weights_named[nm]
  if (length(patterns_named) == 0) return(no_score)

  subj_names <- names(subject_seqs)
  if (is.null(subj_names)) subj_names <- as.character(seq_along(subject_seqs))

  ns <- length(subject_seqs)
  score <- setNames(numeric(ns), subj_names)

  idx_by_len <- split(seq_along(patterns_named), nchar(unname(patterns_named)))

  for (idx in idx_by_len) {
    pats <- Biostrings::DNAStringSet(unname(patterns_named[idx]))
    names(pats) <- names(patterns_named)[idx]
    np <- length(pats)

    # Both routes yield a patterns x subjects matrix. The fallback builds that
    # orientation explicitly (one row per pattern) instead of relying on a
    # shape-based transpose, which was ambiguous when ns == np and produced a
    # plain vector when there was a single subject.
    cnt <- tryCatch(
      {
        pd <- Biostrings::PDict(pats)
        Biostrings::vcountPDict(pd, subject_seqs, fixed = fixed)
      },
      error = function(e) {
        per_pattern <- lapply(seq_along(pats), function(i) {
          Biostrings::vcountPattern(pats[[i]], subject_seqs, fixed = fixed)
        })
        matrix(
          as.integer(unlist(per_pattern, use.names = FALSE)),
          nrow = np, ncol = ns, byrow = TRUE
        )
      }
    )

    # Defensive: a bare vector is only unambiguous for a single pattern.
    if (is.null(dim(cnt))) {
      cnt <- if (np == 1L) matrix(as.integer(cnt), nrow = 1L) else as.matrix(cnt)
    } else {
      cnt <- as.matrix(cnt)
    }

    # final sanity check
    if (nrow(cnt) != np || ncol(cnt) != ns) {
      stop(sprintf(
        "Unexpected count matrix shape: %d×%d; expected %d×%d (patterns×subjects).",
        nrow(cnt), ncol(cnt), np, ns
      ))
    }

    rownames(cnt) <- names(pats)
    colnames(cnt) <- subj_names

    # Weights aligned to the pattern rows; a missing weight counts as 0.
    w2 <- as.numeric(weights_named[rownames(cnt)])
    w2[is.na(w2)] <- 0

    score <- score + as.numeric(matrix(w2, nrow = 1) %*% cnt)
  }

  score
}



# Scan UTR (and optionally CDS) sequences for weighted pattern hits and
# collapse the transcript scores to genes.
#   seq_dna       : sRNA sequence; NA short-circuits to an empty result.
#   utr_seqs      : UTR sequences, named by transcript.
#   cds_seqs      : CDS sequences, named by transcript.
#   tx2gene_named : gene IDs named by transcript ID.
#   patterns      : named patterns to search for.
#   weights       : named weights, matched to patterns by name.
#   use_cds       : also scan CDS sequences when TRUE.
#   cds_mult      : multiplier applied to CDS scores before adding to UTR scores.
#   min_gene_score: threshold applied to each transcript's combined score
#                   before the per-gene sum.
#   top_n_genes   : keep at most this many genes, highest score first.
# Returns a data.table(gene, weight); zero rows when nothing qualifies.
scan_targets_genelevel <- function(seq_dna,
                                   utr_seqs, cds_seqs, tx2gene_named,
                                   patterns, weights,
                                   use_cds = TRUE,
                                   cds_mult = 0.5,
                                   min_gene_score = 3,
                                   top_n_genes = 400) {
  no_targets <- function() data.table(gene = character(0), weight = numeric(0))
  if (is.na(seq_dna) || length(patterns) == 0) return(no_targets())

  tx_score <- count_weighted_hits(utr_seqs, patterns, weights, fixed = FALSE)

  if (isTRUE(use_cds)) {
    cds_score <- count_weighted_hits(cds_seqs, patterns, weights, fixed = FALSE)
    # Combine over the union of transcript IDs; a transcript missing from one
    # set contributes 0 from that set.
    tx_ids <- union(names(tx_score), names(cds_score))
    combined <- setNames(numeric(length(tx_ids)), tx_ids)
    combined[names(tx_score)] <- combined[names(tx_score)] + tx_score
    combined[names(cds_score)] <- combined[names(cds_score)] + cds_mult * cds_score
    tx_score <- combined
  }

  passing_tx <- names(tx_score)[tx_score >= min_gene_score]
  if (length(passing_tx) == 0) return(no_targets())

  # Drop transcripts without a usable gene mapping.
  gene_of_tx <- as.character(tx2gene_named[passing_tx])
  has_gene <- !is.na(gene_of_tx) & nzchar(gene_of_tx)
  if (!any(has_gene)) return(no_targets())

  # Sum transcript scores per gene, rank, and cap the set size.
  per_gene <- tapply(tx_score[passing_tx[has_gene]], gene_of_tx[has_gene], sum)
  per_gene <- sort(per_gene, decreasing = TRUE)
  per_gene <- per_gene[seq_len(min(length(per_gene), top_n_genes))]

  data.table(gene = names(per_gene), weight = as.numeric(per_gene))
}

# ----------------------- build / load target cache ----------------------------
# One row per unique sRNA label, with the nucleotide sequence parsed from it.
srnas_to_run <- unique(bulk_dir_tbl$sRNA)
srna_seq_tbl <- data.table(sRNA = srnas_to_run)
srna_seq_tbl[, seq_dna := vapply(sRNA, extract_seq_from_id, character(1))]
#if (file.exists(TARGET_CACHE_RDS)) file.remove(TARGET_CACHE_RDS)

if (!file.exists(TARGET_CACHE_RDS)) {
  message("Building targets (can be slow). Will save cache to: ", TARGET_CACHE_RDS)

  # names: "<sRNA>__<model>" ; value: data.table(gene, weight)
  target_cache <- list()
  for (sid in srna_seq_tbl$sRNA) {
    seq_dna <- srna_seq_tbl[sRNA == sid, seq_dna][1]
    if (is.na(seq_dna)) next

    for (m in TARGET_MODELS) {
      def <- model_def[[m]](seq_dna)
      if (length(def$patterns) == 0) next

      tg <- scan_targets_genelevel(
        seq_dna = seq_dna,
        utr_seqs = utr_tx_seq,
        cds_seqs = cds_tx_seq,
        tx2gene_named = tx2gene_map,
        patterns = def$patterns,
        weights = def$weights,
        use_cds = USE_CDS_IN_SCAN,
        cds_mult = CDS_WEIGHT_MULT,
        min_gene_score = MIN_GENE_SCORE,
        top_n_genes = TOP_N_GENES
      )

      # Only genes present in the scRNA matrix can be scored; store the set
      # only if something is left.
      tg <- tg[gene %in% rownames(expr)]
      if (nrow(tg) > 0) {
        key <- paste0(sid, "__", m)
        target_cache[[key]] <- tg
      }
    }
  }
  saveRDS(target_cache, TARGET_CACHE_RDS)
} else {
  message("Loading target cache: ", TARGET_CACHE_RDS)
  target_cache <- readRDS(TARGET_CACHE_RDS)
}

if (length(target_cache) == 0) stop("No targets in cache. Check sequence parsing and scan thresholds.")

#saveRDS(target_cache, "G:/PhD_final/tables/srna_activity_sc/targets_cache_copy.rds")
#TARGET_CACHE_RDS <- file.path(OUT_DIR, "targets_cache.rds")
# ------------------ expression bins for control matching ----------------------
# Genes are binned by their mean expression across all retained cells; control
# genes are later drawn from the same bin as each target gene.
gene_means <- Matrix::rowMeans(expr)

# Make sure the means carry gene IDs even if rowMeans() dropped them.
if (is.null(names(gene_means))) {
  names(gene_means) <- rownames(expr)
}

gene_means <- gene_means[is.finite(gene_means)]
gene_means <- gene_means[intersect(names(gene_means), rownames(expr))]

# Quantile break points; ties are collapsed so cut() gets unique breaks.
bin_probs <- seq(0, 1, length.out = N_BINS_EXPR + 1)
qs <- unique(quantile(gene_means, probs = bin_probs, na.rm = TRUE))
if (length(qs) < 5) stop("Expression binning failed (too few unique quantiles).")

# Integer bin index per gene, named by gene ID; unbinned genes are dropped.
gene_bin <- setNames(
  cut(gene_means, breaks = qs, include.lowest = TRUE, labels = FALSE),
  names(gene_means)
)
gene_bin <- gene_bin[!is.na(gene_bin)]

# Bin index (as character) -> gene IDs in that bin.
genes_by_bin <- split(names(gene_bin), gene_bin)


# Deterministic integer seed derived from a string: the sum of its code points
# plus 131 times its length, reduced modulo .Machine$integer.max.
seed_from_string <- function(s) {
  code_points <- utf8ToInt(s)
  raw_seed <- sum(code_points) + 131 * length(code_points)
  as.integer(raw_seed %% .Machine$integer.max)
}

# Draw expression-matched control genes for a target set.
#   target_genes   : gene IDs; duplicates and genes without a bin are ignored.
#   ctrl_per_target: maximum number of controls drawn per target gene.
#   seed           : RNG seed, set once before sampling (changes global RNG state).
# For every target, controls are sampled without replacement from the genes in
# the same expression bin (globals `gene_bin` / `genes_by_bin`), excluding all
# targets. Returns the unique control gene IDs; character(0) if none.
pick_controls_matched <- function(target_genes, ctrl_per_target = 20L, seed = 1L) {
  targets <- unique(target_genes)
  targets <- targets[targets %in% names(gene_bin)]
  if (length(targets) == 0) return(character(0))

  set.seed(seed)
  # Targets are visited in order so the RNG stream is consumed deterministically.
  draws <- lapply(targets, function(gene_id) {
    bin_id <- gene_bin[[gene_id]]
    candidates <- setdiff(genes_by_bin[[as.character(bin_id)]], targets)
    if (length(candidates) == 0) {
      return(character(0))
    }
    n_draw <- min(length(candidates), ctrl_per_target)
    sample(candidates, size = n_draw, replace = FALSE)
  })
  unique(unlist(draws, use.names = FALSE))
}

# ---------------------- activity scoring (per cell) ---------------------------
# Score = (weighted mean target expr) - (mean matched-control expr)
# Activity (repression) = -Score, so that “targets DOWN” => activity increases.

# Per-cell (per-column) mean expression of a gene set, optionally weighted.
#   expr_mat: genes x cells matrix with gene IDs as row names.
#   genes   : gene IDs; those absent from the matrix are ignored.
#   weights : NULL for a plain mean, or a numeric vector named by gene ID.
#             Genes without a weight get 0; if the weights sum to <= 0 the
#             plain mean is used instead.
# Returns one value per cell; all NA when no gene is usable.
weighted_mean_expr <- function(expr_mat, genes, weights = NULL) {
  all_missing <- function() rep(NA_real_, ncol(expr_mat))
  if (length(genes) == 0) return(all_missing())

  usable <- genes[genes %in% rownames(expr_mat)]
  if (length(usable) == 0) return(all_missing())

  gene_rows <- expr_mat[usable, , drop = FALSE]
  if (is.null(weights)) {
    return(Matrix::colMeans(gene_rows))
  }

  gene_w <- weights[match(usable, names(weights))]
  gene_w[is.na(gene_w)] <- 0
  total_w <- sum(gene_w)
  if (total_w <= 0) {
    return(Matrix::colMeans(gene_rows))
  }
  gene_w <- gene_w / total_w

  # (cells x genes) %*% normalised weights -> one value per cell
  as.numeric(Matrix::t(gene_rows) %*% gene_w)
}

# Per-cell activity score of one sRNA target set.
#   expr_mat       : genes x cells expression matrix.
#   tg_dt          : table with columns `gene` and `weight`.
#   ctrl_per_target: controls sampled per target gene.
#   min_targets    : minimum number of unique targets, and of controls,
#                    required to score.
#   seed           : RNG seed for the control sampling.
# Returns -(weighted mean target expression - mean control expression) per
# cell, or all NA when there are too few targets or controls.
score_srna_activity <- function(expr_mat, tg_dt, ctrl_per_target=20L, min_targets=15L, seed=1L) {
  not_scored <- function() rep(NA_real_, ncol(expr_mat))

  target_genes <- unique(tg_dt$gene)
  if (length(target_genes) < min_targets) return(not_scored())

  # Positive weights only, named by gene.
  gene_weights <- setNames(tg_dt$weight, tg_dt$gene)
  gene_weights <- gene_weights[gene_weights > 0]

  control_genes <- pick_controls_matched(target_genes, ctrl_per_target = ctrl_per_target, seed = seed)
  if (length(control_genes) < min_targets) return(not_scored())

  target_level  <- weighted_mean_expr(expr_mat, target_genes, weights = gene_weights)
  control_level <- weighted_mean_expr(expr_mat, control_genes, weights = NULL)

  # Sign flipped so that lower target expression means higher activity.
  -(target_level - control_level)
}

# ---------------------- permutation test helper -------------------------------
# Label-permutation test for a difference in group means.
#   x          : numeric values; non-finite entries are dropped.
#   g01        : group labels (1 vs 0); entries with NA labels are dropped.
#   nperm      : number of label permutations.
#   seed       : RNG seed, set before permuting (changes global RNG state).
#   alternative: "two.sided", "greater" or "less", relative to mean(1) - mean(0).
# Returns list(p, obs): the permutation p-value with the +1 correction and the
# observed difference. Both are NA unless exactly two groups remain.
perm_test_delta_mean <- function(x, g01, nperm=2000L, seed=1L, alternative=c("two.sided","greater","less")) {
  alternative <- match.arg(alternative)

  usable <- is.finite(x) & !is.na(g01)
  x <- x[usable]
  g01 <- g01[usable]
  if (length(unique(g01)) != 2) return(list(p=NA_real_, obs=NA_real_))

  # Difference in means for a given labelling of the same values.
  delta_for <- function(labels) mean(x[labels == 1]) - mean(x[labels == 0])
  obs <- delta_for(g01)

  set.seed(seed)
  n_obs <- length(x)
  labels_int <- as.integer(g01)
  null <- vapply(
    seq_len(nperm),
    function(i) delta_for(sample(labels_int, size = n_obs, replace = FALSE)),
    numeric(1)
  )

  # Number of permuted deltas at least as extreme as the observed one.
  n_extreme <- switch(
    alternative,
    two.sided = sum(abs(null) >= abs(obs)),
    greater = sum(null >= obs),
    less = sum(null <= obs)
  )

  list(p = (1 + n_extreme) / (nperm + 1), obs = obs)
}

# ------------------- expected direction from bulk table -----------------------
# Use expected_target_change when the bulk table already carries it; otherwise
# derive it from the bulk log2 fold change: sRNA up in ELAC (log2FC > 0) means
# its targets are expected DOWN in ELAC.
bulk_cols <- names(bulk_dir_tbl)
if (!("expected_target_change" %in% bulk_cols)) {
  if (!("bulk_log2FC_3dpa" %in% bulk_cols)) {
    stop("bulk_dir_tbl needs expected_target_change or bulk_log2FC_3dpa")
  }
  bulk_dir_tbl[, expected_target_change := ifelse(bulk_log2FC_3dpa > 0, "DOWN_in_ELAC", "UP_in_ELAC")]
}
# expected activity delta sign: DOWN_in_ELAC -> activity higher in ELAC (+)
bulk_dir_tbl[, expected_activity_sign := ifelse(expected_target_change == "DOWN_in_ELAC", +1, -1)]

# ------------------- run: per sRNA × model activity summary -------------------
# Prepare strata ids: one stratum per (celltype, timepoint).
md[, strata := paste(celltype_use, timepoint, sep="||")]
# geno01: 1 for ELAC, 0 for every other kept genotype.
md[, geno01 := ifelse(genotype == "ELAC", 1L, 0L)]

# Keep only strata with both genotypes present and enough cells.
# n_WT counts WT cells explicitly. `geno01 == 0` also matches GFP cells, so a
# stratum holding only GFP and ELAC cells used to pass as if it had WT cells.
strata_ok <- md[, .(n_WT = sum(genotype == "WT"), n_ELAC = sum(genotype == "ELAC")), by=strata]
strata_ok <- strata_ok[n_WT >= MIN_CELLS_PER_GROUP & n_ELAC >= MIN_CELLS_PER_GROUP]
if (nrow(strata_ok) == 0) stop("No strata pass MIN_CELLS_PER_GROUP for both WT and ELAC.")

md <- md[strata %in% strata_ok$strata]

# Main results collector
res_list <- list()

keys <- names(target_cache)
message("Scoring ", length(keys), " (sRNA,model) target sets.")


# Score every cached (sRNA, model) target set and summarise it per stratum.
# One data.table per key is appended to `res_list`, with one row per stratum:
# mean activity per genotype, ELAC - WT delta, Cohen's d, expected direction
# from the bulk table, and permutation p-values.
for (key in keys) {
  # Cache keys have the form "<sRNA>__<model>".
  sid   <- sub("__.*$", "", key)
  model <- sub("^.*__", "", key)

  tg_dt <- target_cache[[key]]
  if (is.null(tg_dt) || nrow(tg_dt) < MIN_TARGETS_FOR_SCORE) next

  # expected direction info from bulk (3 dpa); looked up before the expensive
  # scoring so sRNAs without a bulk entry are skipped early.
  bulk_row <- bulk_dir_tbl[sRNA == sid]
  if (nrow(bulk_row) == 0) next
  exp_sign <- bulk_row$expected_activity_sign[1]

  seed_key <- seed_from_string(key)

  # score activity
  activity <- score_srna_activity(
    expr_mat = expr,
    tg_dt = tg_dt,
    ctrl_per_target = CTRL_PER_TARGET,
    min_targets = MIN_TARGETS_FOR_SCORE,
    seed = seed_key
  )
  if (is.null(activity) || all(is.na(activity))) next

  # attach to metadata (keep only cells that exist in expr)
  dt <- md[, .(cell, celltype_use, timepoint, genotype, geno01, strata)]
  dt <- dt[cell %in% colnames(expr)]

  # Align by cell name when the score vector is named, else by column position.
  if (!is.null(names(activity))) {
    dt[, activity := activity[cell]]
  } else {
    dt[, activity := activity[match(cell, colnames(expr))]]
  }

  # summarize means per group
  sm <- dt[, .(
    n = .N,
    mean_activity = mean(activity, na.rm = TRUE),
    sd_activity = sd(activity, na.rm = TRUE)
  ), by = .(celltype_use, timepoint, genotype, strata)]

  # wide tables (one column per genotype); same formula, hence same row order
  smw  <- data.table::dcast(sm,  celltype_use + timepoint + strata ~ genotype, value.var = "mean_activity")
  smn  <- data.table::dcast(sm,  celltype_use + timepoint + strata ~ genotype, value.var = "n")
  smsd <- data.table::dcast(sm,  celltype_use + timepoint + strata ~ genotype, value.var = "sd_activity")

  # need WT and ELAC to compute delta
  if (!all(c("WT", "ELAC") %in% names(smw))) next
  if (!all(c("WT", "ELAC") %in% names(smn))) next
  if (!all(c("WT", "ELAC") %in% names(smsd))) next

  smw[, delta_ELAC_minus_WT := ELAC - WT]
  smw[, n_WT   := smn[["WT"]]]
  smw[, n_ELAC := smn[["ELAC"]]]

  # effect size (Cohen's d) using pooled SD
  pooled_sd <- sqrt(
    ((smw$n_WT - 1) * (smsd[["WT"]]^2) + (smw$n_ELAC - 1) * (smsd[["ELAC"]]^2)) /
      pmax(smw$n_WT + smw$n_ELAC - 2, 1)
  )
  pooled_sd <- pmax(pooled_sd, 1e-8)  # avoid division by zero
  smw[, cohen_d := delta_ELAC_minus_WT / pooled_sd]

  smw[, expected_activity_sign := exp_sign]
  # NOTE(review): this column repeats the expected_target_change label
  # ("DOWN_in_ELAC" when activity is expected HIGHER in ELAC); the name suggests
  # an activity direction. Left as-is for downstream compatibility.
  smw[, expected_activity_change := ifelse(exp_sign > 0, "DOWN_in_ELAC", "UP_in_ELAC")]
  smw[, bulk_log2FC_3dpa := bulk_row$bulk_log2FC_3dpa[1]]
  smw[, expected_target_change := bulk_row$expected_target_change[1]]

  # expected-direction filter at stratum level
  smw[, delta_expected := delta_ELAC_minus_WT * expected_activity_sign]
  smw[, ok_dir := is.finite(delta_expected) & (delta_expected > 0)]

  # permutation p-values
  smw[, perm_p_two_sided := NA_real_]
  smw[, perm_p_expected  := NA_real_]

  if (isTRUE(DO_PERMUTATION)) {

    # Every stratum whose delta has the expected sign is tested.
    # NOTE(review): TEST_TOP_STRATA is not applied here.
    strata_to_test <- unique(smw[ok_dir == TRUE, strata])

    for (st in strata_to_test) {
      # Restrict to WT and ELAC cells so the permutation test addresses the
      # same contrast as delta_ELAC_minus_WT. geno01 codes every non-ELAC
      # genotype as 0, so GFP cells were previously pooled into the WT group.
      cells_st <- dt[strata == st & is.finite(activity) & genotype %in% c("WT", "ELAC")]

      # require both groups to have enough cells
      n0 <- cells_st[, sum(geno01 == 0L)]
      n1 <- cells_st[, sum(geno01 == 1L)]
      if (n0 < MIN_CELLS_PER_GROUP || n1 < MIN_CELLS_PER_GROUP) next

      # two-sided
      pt2 <- perm_test_delta_mean(
        x = cells_st$activity,
        g01 = cells_st$geno01,
        nperm = N_PERM,
        seed = seed_key + seed_from_string(st),
        alternative = "two.sided"
      )

      # directional (expected sign)
      alt <- if (exp_sign > 0) "greater" else "less"
      ptd <- perm_test_delta_mean(
        x = cells_st$activity,
        g01 = cells_st$geno01,
        nperm = N_PERM,
        seed = seed_key + 7L + seed_from_string(st),
        alternative = alt
      )

      smw[strata == st, perm_p_two_sided := pt2$p]
      smw[strata == st, perm_p_expected  := ptd$p]
    }
  }

  # annotate and store
  smw[, sRNA := sid]
  smw[, model := model]
  smw[, n_targets := nrow(tg_dt)]

  res_list[[length(res_list) + 1L]] <- smw
}

# Stack the per-(sRNA, model) stratum tables into one result table.
res <- rbindlist(res_list, fill = TRUE)
if (nrow(res) == 0) stop("No results produced. Check target cache + scoring thresholds.")

# multiple testing correction within each (sRNA,model) over strata
res[, perm_p_expected_BH := p.adjust(perm_p_expected, method = "BH"), by=.(sRNA, model)]
res[, perm_p_two_sided_BH := p.adjust(perm_p_two_sided, method = "BH"), by=.(sRNA, model)]

# prioritization score that respects expected direction (only meaningful where perm_p_expected computed)
# The 1e-300 offset keeps -log10() finite.
res[, score_expected := expected_activity_sign * delta_ELAC_minus_WT * (-log10(perm_p_expected_BH + 1e-300))]

# write outputs
#fwrite(res, file.path(OUT_DIR, "srna_activity_moduleScore_controlMatched_72_only_correct_delta.tsv"), sep = "\t")
#save(res,file="G:/PhD_final/tables/srna_activity_moduleScore_controlMatched_72_only_correct_delta.RData")
#save(activity,file="G:/PhD_final/tables/srna_activity_72_only_correct_delta.RData")
# “best hit” per (sRNA,model) using BH directional p, then |delta|
best <- res[order(perm_p_expected_BH, -abs(delta_ELAC_minus_WT))]
best <- best[, .SD[1], by=.(sRNA, model)]
#fwrite(best, file.path(OUT_DIR, "srna_activity_bestHit.tsv"), sep = "\t")

# Consistency filter: keep hits where ELAC mean activity lies on the same side
# of BOTH controls (below GFP and WT, or above GFP and WT). It needs a GFP
# column, which exists only if GFP cells were summarised; fail with a clear
# message instead of an obscure zero-length assignment error.
if (!("GFP" %in% names(best))) {
  stop("Consistency filter needs a 'GFP' mean-activity column, but no GFP group was summarised.")
}
best$consistensy <- ifelse(((best$ELAC<best$GFP) & (best$ELAC<best$WT)) |
                             ((best$ELAC>best$GFP) & (best$ELAC>best$WT)),
                           TRUE,FALSE)
best <- best[best$consistensy==TRUE,]
#fwrite(best, file.path(OUT_DIR, "srna_activity_bestHit.tsv"), sep = "\t")
# Interactive inspection of one sRNA of interest.
best[best$sRNA=="GCATCGGTGGTTCAGTGGTAGAATGCTCGCCT 5'-tiRNA-Gly-GCC",]
#load("G:/PhD_final/tables/srna_activity_72_only_correct_delta.RData")#activity

library(scCustomize)
result_obj


# Named colour lookup: population label -> hex colour.
# The table is assembled from one sub-vector per lineage (each lineage shares
# a hue family) and concatenated in a fixed order. local() keeps the
# per-lineage helpers out of the global environment, so DETAILED_COLS is the
# only object created here.
DETAILED_COLS <- local({
  # Epidermis (blue)
  epidermis <- c(
    "Early epidermal progenitor" = "#ABCEDF",
    "Late epidermal progenitor" = "#6693C4",
    "Epidermis (broad)" = "#5685BD",
    "Epidermis (multiciliated)" = "#3468B0",
    "Epidermal progenitor (multiciliated)" = "#4576B6",
    "Epidermal secretory gland progenitor" = "#99BFD8"
  )

  # Eye (cyan)
  eye <- c(
    "Eye progenitor" = "#00D3D3"
  )

  # Intestine / immune-like (green)
  intestine <- c(
    "Basal cell" = "#D3E4BA",
    "Goblet cell" = "#8DB180",
    "Phagocyte (broad)" = "#497F46"
  )

  # Pigment (brown)
  pigment <- c(
    "Body pigment progenitor" = "#C08A5A",
    "Body pigment cell" = "#8E5A3C"
  )

  # Muscle (red)
  muscle <- c(
    "Muscle progenitor" = "#F2C7CA",
    "BWM (dorsal midline)" = "#E6A3A6",
    "ECM-producing muscle" = "#CC5C5D",
    "Posterior pole/PCG muscle" = "#C1393A"
  )

  # Neoblasts (grey)
  neoblasts <- c(
    "σ-neoblast (broad-lineage)" = "#E3E3E3",
    "Muscle neoblast" = "#D5D5D5",
    "Protonephridial neoblast" = "#B8B8B8",
    "Pharyngeal neoblast" = "#7F7F7F",
    "γ-neoblast (intestinal-fated)" = "#AAAAAA",
    "ζ-neoblast (epidermal-fated)" = "#8D8D8D",
    "ν-neoblast (neural-fated)" = "#707070",
    "GLIRP-1⁺ parenchymal neoblast" = "#B3B3B3",
    "PGRN⁺ parenchymal neoblast" = "#BEBEBE",
    "FER3L-2⁺ parenchymal neoblast" = "#C7C7C7"
  )

  # Neural lineage (violet)
  neural <- c(
    "Neural progenitor (broad)" = "#E9E4FA",
    "Glutamatergic neural progenitor" = "#DED5F6",
    "Neuropeptidergic neural progenitor" = "#D3C7F2",
    "Mechanosensory neural progenitor" = "#C9BAEE",
    "PKD⁺ sensory neural progenitor" = "#BFADE8",
    "Glia" = "#AF98D5",
    "Brain branch neuron" = "#D8CCF3",
    "Catecholaminergic neuron" = "#CDBFEB",
    "Cholinergic neuron" = "#C3B2E4",
    "Glutamatergic neuron" = "#A58ACE",
    "Mechanosensory neuron" = "#9171BF",
    "Neuropeptidergic neuron" = "#8764B8",
    "PKD⁺ sensory neuron" = "#7349A9",
    "Serotonergic neuron" = "#693CA2"
  )

  # Parenchyma (sand)
  parenchyma <- c(
    "AQP⁺ parenchymal cell" = "#F9E29D",
    "LDLRR-1⁺ parenchymal cell" = "#F3D38E",
    "GLIRP-1⁺ parenchymal progenitor" = "#EDC281",
    "PSAP⁺ parenchymal progenitor" = "#F1D6A0",
    "PSAP⁺ parenchymal cell" = "#EBB670",
    "PGRN⁺ parenchymal cell" = "#EFC57F",
    "FER3L-2⁺ parenchymal progenitor" = "#E8B567",
    "NKX2⁺ parenchymal progenitor" = "#E39F55",
    "PTF⁺ head parenchymal progenitor" = "#E6A860",
    "SSPO⁺ parenchymal progenitor" = "#E19A51",
    "SSPO⁺ parenchymal cell" = "#DD8B42",
    "Abraçada cell" = "#FAE8B4"
  )

  # Pharynx (yellow)
  pharynx <- c(
    "Pharyngeal epithelium" = "#FFFF00",
    "Pharyngeal progenitor" = "#D9D900",
    "Pharyngeal phagocytic-type cell" = "#C9C900"
  )

  # Protonephridia (wine)
  protonephridia <- c(
    "Protonephridial flame cell" = "#874A68",
    "Protonephridial tubule cell" = "#87345F"
  )

  # Concatenation order defines the order of the final lookup vector.
  c(
    epidermis,
    eye,
    intestine,
    pigment,
    muscle,
    neoblasts,
    neural,
    parenchyma,
    pharynx,
    protonephridia
  )
})


################################################################################
# Rescaled plot_sncRNA_activity_umap
# (The interactive help lookup `?scale_fill_gradient2` that preceded this
# definition was removed so that sourcing the script has no help-viewer side
# effect.)
#
# Draw a two-panel UMAP figure for one sncRNA activity score.
#   Left panel : every cell in `bg_col`; cells of `focus_pops` coloured by
#                population using `detailed_cols`.
#   Right panel: one facet per condition; non-focus cells and focus cells
#                without a finite value in `bg_col`, the remaining focus cells
#                coloured by activity on a diverging low/mid/high gradient.
#
# Arguments
#   seu            Seurat object.
#   activity       Named vector of per-cell scores, indexed by cell name
#                  (`activity[Cells(seu)]`); cells without a finite value are
#                  drawn as background.
#   sid, model     Single strings. They only build the metadata column name
#                  and the default title; they do NOT select anything from
#                  `activity`.
#   focus_pops     Populations to highlight. The default reads column 3 of the
#                  global table `best` for rows with sRNA == sid.
#   detailed_cols  Named vector population -> colour; must cover focus_pops.
#   reduction      Name of the reduction in `seu`; its first two dimensions
#                  are plotted.
#   pop_col, cond_col
#                  Metadata columns holding population and condition.
#   facet_map      Named character vector condition -> plotmath label string.
#                  Conditions missing from the map are shown as plain text.
#   facet_levels   Facet order (plotmath strings); default: facet_map order
#                  followed by any unmapped conditions.
#   title          Figure title; default "<sid> | <model>".
#   rel_widths     Relative widths of the left and right panels.
#   low_col, mid_col, high_col
#                  Gradient colours; blended towards white when
#                  `pastelize` is TRUE (`pastel_strength` in 0..1, mid colour
#                  uses half the strength).
#   pt_bg, pt_pop, pt_act, alpha_bg, alpha_act, bg_col
#                  Point sizes, alphas and the background colour.
#   legend_dot_size, strip_text_size, legend_text_size, legend_title_size,
#   title_size     Text and legend sizing.
#   title_gap_lines
#                  Bottom margin of the title strip, passed to
#                  ggplot2::margin().
#   activity_mode  "raw" plots the scores as given; "delta_vs_ref" subtracts a
#                  baseline (median of the focus cells in `ref_condition`).
#   delta_within   "pop": one baseline per population; "global": one baseline
#                  over all focus cells of the reference condition.
#   act_clip_q     Quantiles of the plotted values used as colour limits when
#                  `act_limits` is NULL; values outside are squished.
#   act_limits     Optional explicit numeric limits of length 2.
#   diverging_symmetric, midpoint
#                  If TRUE the limits are made symmetric around `midpoint`.
#   act_transform, modulus_p
#                  Scale transform: "identity", "modulus"
#                  (scales::transform_modulus(p = modulus_p)) or "auto"
#                  (modulus in delta mode, identity otherwise).
#   use_fill, outline_col, outline_stroke
#                  If TRUE draw shape-21 points mapped to `fill` with the given
#                  outline; otherwise plain points mapped to `colour`.
#   order_by_abs   If TRUE, focus cells are drawn in increasing distance from
#                  `midpoint`, so the most extreme values end up on top.
#
# Value
#   A list: plot (combined cowplot figure), p_left, p_right, act_col, df,
#   df_focus, seu (with the activity column added), activity_limits,
#   activity_mode, ref_condition, act_transform.
#
# Errors
#   Stops when inputs fail validation, when no cell name of `activity` matches
#   `seu`, when the reduction or reference condition is missing, or when no
#   finite activity value exists inside focus_pops.
plot_sncRNA_activity_umap_rescaled <- function(
    seu,
    activity,
    sid,
    model,
    focus_pops = unlist(c(unique(subset(best,sRNA==sid)[,3]))),
    detailed_cols,
    reduction = "umap.d33.nn100.md0.3",
    pop_col = "final_population",
    cond_col = "condition_correct",
    facet_map = NULL,
    facet_levels = NULL,
    title = NULL,
    rel_widths = c(1, 2),
    
    # palette (will be pastelized by default)
    low_col  = "#A6CEE3",
    mid_col  = "#FFFFFF",
    high_col = "#FB9A99",
    pastelize = TRUE,
    pastel_strength = 0.2,   # 0..1, higher = lighter
    
    # point sizes
    pt_bg = 0.10,
    pt_pop = 0.25,
    pt_act = 0.55,
    
    # alphas and background
    alpha_bg = 0.45,
    alpha_act = 0.95,
    bg_col = "grey88",
    
    # legends + text
    legend_dot_size = 3,
    strip_text_size = 11,
    legend_text_size  = 12,
    legend_title_size = 13,
    title_size        = 14,
    title_gap_lines   = 6,
    
    # activity visualization controls
    activity_mode = c("raw", "delta_vs_ref"),
    ref_condition = "WT72",
    delta_within = c("pop", "global"),
    act_clip_q = c(0.05, 0.95),
    act_limits = NULL,
    diverging_symmetric = TRUE,
    midpoint = 0,
    
    # transform (new ggplot2 arg name is transform, not trans)
    act_transform = c("auto", "identity", "modulus"),
    modulus_p = 0.7,
    
    # drawing style
    use_fill = FALSE,         # default FALSE to avoid dark outlines from shape 21
    outline_col = NA,         # NA removes border for shape 21
    outline_stroke = 0,
    order_by_abs = TRUE
) {
  # ---- helpers ----
  # Blend a colour towards white: w = 0 keeps the colour, w = 1 gives white.
  .mix_with_white <- function(col, w = 0.55) {
    if (is.na(col) || !nzchar(col)) return(col)
    rgb <- grDevices::col2rgb(col) / 255
    rgb2 <- rgb * (1 - w) + 1 * w
    grDevices::rgb(rgb2[1], rgb2[2], rgb2[3])
  }
  
  # Turn plain text into a single-quoted plotmath string literal. Embedded
  # quotes/backslashes are escaped so label_parsed() can always parse it, and
  # the result does not depend on the platform (unlike shQuote()).
  .as_plotmath_string <- function(x) {
    paste0("'", gsub("(['\\\\])", "\\\\\\1", x), "'")
  }
  
  # ---- checks ----
  stopifnot(inherits(seu, "Seurat"))
  stopifnot(!is.null(names(activity)))
  stopifnot(is.character(sid), length(sid) == 1)
  stopifnot(is.character(model), length(model) == 1)
  stopifnot(is.character(reduction), length(reduction) == 1)
  stopifnot(is.character(pop_col), length(pop_col) == 1)
  stopifnot(is.character(cond_col), length(cond_col) == 1)
  
  if (missing(detailed_cols) || is.null(detailed_cols)) {
    stop("Provide 'detailed_cols' (named vector: population -> color).")
  }
  if (!all(focus_pops %in% names(detailed_cols))) {
    stop("Some focus_pops are missing from names(detailed_cols): ",
         paste(setdiff(focus_pops, names(detailed_cols)), collapse = ", "))
  }
  
  activity_mode <- match.arg(activity_mode)
  delta_within  <- match.arg(delta_within)
  act_transform <- match.arg(act_transform)
  
  # ---- pastelize palette (reduces "dark" look on dense UMAPs) ----
  if (isTRUE(pastelize)) {
    w <- max(0, min(1, pastel_strength))
    low_col  <- .mix_with_white(low_col,  w = w)
    mid_col  <- .mix_with_white(mid_col,  w = w / 2)   # keep midpoint close to white
    high_col <- .mix_with_white(high_col, w = w)
  }
  
  # ---- activity column name ----
  act_col <- paste0("act__", make.names(paste(sid, model, sep = "__")))
  
  # ---- align activity to Seurat cells and store ----
  # Cells absent from `activity` come back as NA; re-naming restores the cell
  # names that NA lookups lose.
  cells <- Seurat::Cells(seu)
  act_vec <- activity[cells]
  names(act_vec) <- cells
  
  n_match <- sum(is.finite(act_vec))
  if (n_match == 0) {
    stop("No overlap between names(activity) and Cells(seu). ",
         "Check: head(names(activity)) vs head(Cells(seu)).")
  }
  seu[[act_col]] <- act_vec
  
  # ---- embeddings ----
  if (!reduction %in% names(seu@reductions)) {
    stop("Reduction '", reduction, "' not found. Available: ",
         paste(names(seu@reductions), collapse = ", "))
  }
  um <- Seurat::Embeddings(seu, reduction = reduction)[, 1:2, drop = FALSE]
  colnames(um) <- c("UMAP_1", "UMAP_2")
  
  df <- data.table::as.data.table(um, keep.rownames = "cell")
  
  meta <- Seurat::FetchData(seu, vars = c(pop_col, cond_col, act_col))
  meta$cell <- rownames(meta)
  
  df <- data.table::merge.data.table(
    df,
    data.table::as.data.table(meta),
    by = "cell"
  )
  
  data.table::setnames(df, pop_col, "pop")
  data.table::setnames(df, cond_col, "cond")
  data.table::setnames(df, act_col, "activity_raw")
  
  # ---- facet labels (plotmath strings) ----
  if (is.null(facet_map)) {
    facet_map <- c(
      "ELAC72" = "paste(italic('Smed ELAC2'), ' KD 72 hpa')",
      "GFP72"  = "'GFP mock 72 hpa'",
      "WT72"   = "'WT 72 hpa'"
    )
  }
  
  df[, condition_facet := {
    cc <- as.character(cond)
    out <- unname(facet_map[cc])
    # conditions without a mapped label fall back to their own name
    out[is.na(out)] <- .as_plotmath_string(cc[is.na(out)])
    out
  }]
  
  if (is.null(facet_levels)) {
    base_levels <- unname(facet_map)
    extras <- setdiff(unique(df$condition_facet), base_levels)
    facet_levels <- c(base_levels, extras)
  }
  df[, condition_facet := factor(condition_facet, levels = facet_levels)]
  
  # ---- focus flags ----
  df[, is_focus := pop %in% focus_pops]
  
  # ---- compute activity to plot (raw or delta) ----
  # Only focus cells carry a value; everything else stays NA (background).
  df[, activity_plot := as.numeric(NA)]
  df[is_focus == TRUE, activity_plot := activity_raw]
  
  df_focus <- df[is_focus == TRUE & is.finite(activity_plot)]
  if (nrow(df_focus) == 0) stop("No finite activity values found within focus_pops.")
  
  if (activity_mode == "delta_vs_ref") {
    if (!ref_condition %in% unique(df$cond)) {
      stop("ref_condition '", ref_condition, "' not found in cond. Present: ",
           paste(sort(unique(df$cond)), collapse = ", "))
    }
    
    ref_dt <- df[is_focus == TRUE & cond == ref_condition & is.finite(activity_plot)]
    if (nrow(ref_dt) == 0) {
      stop("No finite focus activity values in ref_condition = '", ref_condition, "'.")
    }
    
    if (delta_within == "pop") {
      base <- ref_dt[, .(baseline = stats::median(activity_plot, na.rm = TRUE)), by = pop]
      
      # Populations without any reference cell get an NA baseline, which turns
      # all their values into NA (drawn as background). Say so instead of
      # dropping them silently.
      no_ref <- setdiff(unique(as.character(df_focus$pop)), as.character(base$pop))
      if (length(no_ref) > 0) {
        warning("No finite '", ref_condition, "' activity for focus population(s): ",
                paste(no_ref, collapse = ", "),
                ". Their cells are drawn as background.",
                call. = FALSE)
      }
      
      df <- data.table::merge.data.table(df, base, by = "pop", all.x = TRUE)
      df[is_focus == TRUE & is.finite(activity_plot), activity_plot := activity_plot - baseline]
      df[, baseline := NULL]
    } else {
      baseline <- stats::median(ref_dt$activity_plot, na.rm = TRUE)
      df[is_focus == TRUE & is.finite(activity_plot), activity_plot := activity_plot - baseline]
    }
    
    df_focus <- df[is_focus == TRUE & is.finite(activity_plot)]
  }
  
  # ---- order so extremes draw on top ----
  # Distance is measured from `midpoint`, the centre of the diverging scale
  # (identical to abs(activity_plot) for the default midpoint = 0).
  if (order_by_abs) df_focus <- df_focus[order(abs(activity_plot - midpoint))]
  
  # ---- robust limits (clip + squish) ----
  vals <- df_focus$activity_plot
  
  if (!is.null(act_limits)) {
    if (!is.numeric(act_limits) || length(act_limits) != 2) {
      stop("act_limits must be numeric length 2, e.g. c(-0.2, 0.2).")
    }
    lims <- sort(as.numeric(act_limits))
  } else {
    qq <- stats::quantile(vals, probs = act_clip_q, na.rm = TRUE, names = FALSE)
    lims <- sort(as.numeric(qq))
  }
  
  if (diverging_symmetric) {
    max_abs <- max(abs(lims - midpoint))
    lims <- midpoint + c(-max_abs, max_abs)
  }
  
  # ---- pick transform object ----
  act_transform_obj <- switch(
    act_transform,
    "identity" = "identity",
    "modulus"  = scales::transform_modulus(p = modulus_p),
    "auto"     = if (activity_mode == "delta_vs_ref") scales::transform_modulus(p = modulus_p) else "identity"
  )
  
  # ---- right panel background data ----
  df_nonfocus <- df[is_focus == FALSE]
  df_focus_na <- df[is_focus == TRUE & !is.finite(activity_plot)]
  
  # ---- left: highlight focus pops ----
  p_left <- ggplot2::ggplot(df, ggplot2::aes(UMAP_1, UMAP_2)) +
    ggplot2::geom_point(color = bg_col, size = pt_bg, alpha = alpha_bg) +
    ggplot2::geom_point(
      data = df[is_focus == TRUE],
      ggplot2::aes(color = pop),
      size = pt_pop,
      alpha = 1
    ) +
    ggplot2::scale_color_manual(
      values = detailed_cols[focus_pops],
      breaks = focus_pops,
      name = "Population"
    ) +
    ggplot2::guides(
      color = ggplot2::guide_legend(override.aes = list(size = legend_dot_size, alpha = 1))
    ) +
    ggplot2::theme_void() +
    ggplot2::theme(
      legend.position = "right",
      legend.title = ggplot2::element_text(size = legend_title_size),
      legend.text  = ggplot2::element_text(size = legend_text_size),
      text = ggplot2::element_text(size = 11)
    )
  
  # ---- right: non-focus grey, focus colored by activity_plot ----
  scale_name <- if (activity_mode == "delta_vs_ref") "Δ Activity" else "Activity"
  
  # Background layers are the same for both drawing styles.
  p_right <- ggplot2::ggplot(df, ggplot2::aes(UMAP_1, UMAP_2)) +
    ggplot2::geom_point(data = df_nonfocus, color = bg_col, size = pt_bg, alpha = alpha_bg) +
    ggplot2::geom_point(data = df_focus_na, color = bg_col, size = pt_bg, alpha = alpha_bg)
  
  # Only the activity layer and the aesthetic its gradient binds to differ.
  if (use_fill) {
    activity_layer <- ggplot2::geom_point(
      data = df_focus,
      ggplot2::aes(fill = activity_plot),
      shape = 21,
      color = outline_col,
      stroke = outline_stroke,
      size = pt_act,
      alpha = alpha_act
    )
    gradient_scale <- ggplot2::scale_fill_gradient2
  } else {
    activity_layer <- ggplot2::geom_point(
      data = df_focus,
      ggplot2::aes(color = activity_plot),
      size = pt_act,
      alpha = alpha_act
    )
    gradient_scale <- ggplot2::scale_color_gradient2
  }
  
  p_right <- p_right +
    activity_layer +
    gradient_scale(
      name = scale_name,
      low = low_col, mid = mid_col, high = high_col,
      midpoint = midpoint,
      limits = lims,
      oob = scales::squish,   # values beyond the limits take the end colours
      transform = act_transform_obj
    ) +
    ggplot2::facet_wrap(~ condition_facet, nrow = 1, labeller = ggplot2::label_parsed) +
    ggplot2::theme_void() +
    ggplot2::theme(
      plot.margin = ggplot2::margin(t = 0, r = 18, b = 0, l = 0),  # more space on the right
      legend.box.margin = ggplot2::margin(t = 0, r = 6, b = 0, l = 6),
      legend.title = ggplot2::element_text(size = legend_title_size),
      legend.text  = ggplot2::element_text(size = legend_text_size),
      strip.text   = ggplot2::element_text(size = strip_text_size)
    )
  
  # ---- combine + title ----
  core <- cowplot::plot_grid(
    p_left, p_right,
    ncol = 2,
    rel_widths = rel_widths,
    align = "h"
  )
  
  if (is.null(title)) title <- paste0(sid, " | ", model)
  
  title_grob <- cowplot::ggdraw() +
    cowplot::draw_label(
      title,
      x = 0.05, y = 1,
      hjust = 0, vjust = 1.25,
      fontface = "bold",
      size = title_size
    ) +
    ggplot2::theme(
      plot.margin = ggplot2::margin(t = 0, r = 0, b = title_gap_lines, l = 0)
    )
  
  final <- cowplot::plot_grid(
    title_grob,
    core,
    ncol = 1,
    rel_heights = c(0.10, 1)
  )
  
  list(
    plot = final,
    p_left = p_left,
    p_right = p_right,
    act_col = act_col,
    df = df,
    df_focus = df_focus,
    seu = seu,
    activity_limits = lims,
    activity_mode = activity_mode,
    ref_condition = ref_condition,
    act_transform = act_transform_obj
  )
}

# Four figure panels, one per sncRNA / targeting model. All of them use the
# same settings: per-population delta against WT72, colour limits at the
# 10th/90th percentile and a blue-white-red gradient.
# NOTE(review): every call passes the same `activity`, `focus_pops` and
# `seu_use` objects although `sid`/`model` differ. Inside the plotting function
# `sid`/`model` only label the output, so confirm that `activity` holds the
# scores of the sRNA being plotted before each call.

# Panel A: 5'-tiRNA-Gly-GCC, miRNA_canonical model
out1 <- plot_sncRNA_activity_umap_rescaled(
  sid = "GCATCGGTGGTTCAGTGGTAGAATGCTCGCCT 5'-tiRNA-Gly-GCC",
  model = "miRNA_canonical",
  title = "5'-tiRNA-Gly-GCC (GCATCGGTGGTTCAGTGGTAGAATGCTCGCCT)",
  seu = seu_use,
  activity = activity,
  focus_pops = focus_pops,
  detailed_cols = DETAILED_COLS,
  activity_mode = "delta_vs_ref",
  delta_within = "pop",
  ref_condition = "WT72",
  act_clip_q = c(0.10, 0.90),
  low_col = "blue",
  mid_col = "white",
  high_col = "red"
)
#out1_new$plot

# Panel B: miR-9a-5p, off1_7mer model
out2 <- plot_sncRNA_activity_umap_rescaled(
  sid = "TCTTTGGTTTTCTAGC sme-miR-9a-5p",
  model = "off1_7mer",
  title = "miR-9a-5p (TCTTTGGTTTTCTAGC)",
  seu = seu_use,
  activity = activity,
  focus_pops = focus_pops,
  detailed_cols = DETAILED_COLS,
  activity_mode = "delta_vs_ref",
  delta_within = "pop",
  ref_condition = "WT72",
  act_clip_q = c(0.10, 0.90),
  low_col = "blue",
  mid_col = "white",
  high_col = "red"
)

# Panel C: miR-190a-3p, off1_7mer model
out3 <- plot_sncRNA_activity_umap_rescaled(
  sid = "ACCACTGACCGAGCATATCC sme-miR-190a-3p",
  model = "off1_7mer",
  title = "miR-190a-3p (ACCACTGACCGAGCATATCC)",
  seu = seu_use,
  activity = activity,
  focus_pops = focus_pops,
  detailed_cols = DETAILED_COLS,
  activity_mode = "delta_vs_ref",
  delta_within = "pop",
  ref_condition = "WT72",
  act_clip_q = c(0.10, 0.90),
  low_col = "blue",
  mid_col = "white",
  high_col = "red"
)

# Panel D: piRNA, piRNA_extended model
out4 <- plot_sncRNA_activity_umap_rescaled(
  sid = "CACATGACATGTATACTCTACAAACGCAC piRNA",
  model = "piRNA_extended",
  title = "piRNA (CACATGACATGTATACTCTACAAACGCAC)",
  seu = seu_use,
  activity = activity,
  focus_pops = focus_pops,
  detailed_cols = DETAILED_COLS,
  activity_mode = "delta_vs_ref",
  delta_within = "pop",
  ref_condition = "WT72",
  act_clip_q = c(0.10, 0.90),
  low_col = "blue",
  mid_col = "white",
  high_col = "red"
)

# Stack the four composite figures vertically and label them A-D.
# NOTE(review): each `$plot` is already a cowplot composite with its legends
# drawn inside it; verify that `common.legend = TRUE` has the intended effect
# on such inputs.
library(ggpubr)
ggarrange(
  out1$plot,
  out2$plot,
  out3$plot,
  out4$plot,
  ncol = 1,
  labels = LETTERS[1:4],
  common.legend = TRUE
)