DR. ATABAK KH

Cloud Platform Modernization Architect specializing in transforming legacy systems into reliable, observable, and cost-efficient Cloud platforms.

Certified: Google Professional Cloud Architect, AWS Solutions Architect, MapR Cluster Administrator

Goal: Make results re-runnable and comparable (CAFA-style).

1) Data & splits

  • Sources: UniProt sequences, GOA annotations; optional PPI (STRING/BioGRID).
  • Time-based split: train on annotations dated <= T0, test on those in (T0, T1]; this prevents future leakage.
  • Remove evidence types you disallow (e.g., IEA) from both sides.

2) Label processing

  • Propagate labels up the DAG (ancestor closure).
  • Filter ultra-rare terms, or keep them and reweight with class_weight.
  • Store class mapping + DAG snapshot with a version tag.
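A minimal sketch of the versioning step, assuming the GO OBO file carries a `data-version:` header line (go-basic.obo releases do) and that `terms` is the column order of the label matrix. The helper name `write_label_manifest` and the JSON keys are hypothetical.

```python
import hashlib
import json
from pathlib import Path


def write_label_manifest(terms, obo_path, out_path, version_tag):
    """Store the label-column order plus a fingerprint of the GO DAG snapshot.

    Hypothetical helper: the JSON keys below are illustrative, not a standard.
    """
    obo_bytes = Path(obo_path).read_bytes()

    # Scan the OBO header (lines before the first [Term] stanza) for data-version
    data_version = None
    for line in obo_bytes.decode("utf-8").splitlines():
        if line.startswith("[Term]"):
            break
        if line.startswith("data-version:"):
            data_version = line.split(":", 1)[1].strip()

    manifest = {
        "version_tag": version_tag,
        "go_data_version": data_version,
        "obo_sha256": hashlib.sha256(obo_bytes).hexdigest(),
        "terms": list(terms),  # column j of the label matrix is terms[j]
    }
    Path(out_path).write_text(json.dumps(manifest, indent=2))
    return manifest
```

Hashing the OBO file catches the case where two snapshots share a release tag but differ in content.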

3) Features

  • PLM embeddings (cached .npy): N x D float32; L2-normalize.
  • PPI: sparse CSR, symmetrically normalized adjacency A.
  • Optional text TF-IDF for weak supervision.

4) Models

  • Baselines: kNN/BLAST label transfer; logistic regression (one-vs-rest).
  • Advanced: MLP (2-3 layers), GCN/GAT on PPI; late fusion of logits.
  • Calibration: isotonic or temperature scaling on validation.
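Temperature scaling fits one scalar T on validation logits so that sigmoid(z / T) minimizes binary cross-entropy. A minimal NumPy sketch, assuming a single global temperature shared by all GO terms and a log-spaced grid search in place of a gradient-based optimizer; the function name is hypothetical.

```python
import numpy as np


def fit_temperature(logits, labels, grid=None):
    """Grid-search a single temperature T > 0 for multi-label sigmoid outputs.

    logits, labels: arrays of the same shape (n_proteins x n_terms), labels in {0, 1}.
    Returns the T that minimizes mean binary cross-entropy of sigmoid(logits / T).
    """
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=float)
    if grid is None:
        grid = np.logspace(-1, 1, 201)  # T from 0.1 to 10

    best_t, best_loss = 1.0, np.inf
    for t in grid:
        z = logits / t
        # Numerically stable BCE with logits: max(z, 0) - z*y + log(1 + exp(-|z|))
        loss = np.mean(np.maximum(z, 0) - z * labels + np.log1p(np.exp(-np.abs(z))))
        if loss < best_loss:
            best_t, best_loss = float(t), loss
    return best_t
```

At test time, calibrated probabilities are `1 / (1 + np.exp(-test_logits / T))`. T > 1 softens an overconfident model; T < 1 sharpens an underconfident one.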

5) Metrics & plots

  • Fmax (BP/MF/CC), micro/macro-auPRC, coverage.
  • Calibration: reliability curves, ECE.
  • Save results.json and PR curves (PNG/SVG).
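Expected calibration error (ECE) bins predicted probabilities and compares each bin's mean confidence with its observed positive rate, weighted by bin size. A minimal sketch, assuming equal-width bins over all (protein, term) pairs pooled together; per-term or per-ontology ECE would apply the same function to column subsets.

```python
import numpy as np


def expected_calibration_error(probs, labels, n_bins=10):
    """Binned ECE for binary/multi-label probabilities (all entries pooled)."""
    probs = np.asarray(probs, dtype=float).ravel()
    labels = np.asarray(labels, dtype=float).ravel()

    # Assign each probability to one of n_bins equal-width bins on [0, 1]
    bin_ids = np.minimum((probs * n_bins).astype(int), n_bins - 1)

    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        if not in_bin.any():
            continue
        confidence = probs[in_bin].mean()   # mean predicted probability
        frequency = labels[in_bin].mean()   # observed fraction of positives
        ece += in_bin.mean() * abs(confidence - frequency)
    return float(ece)
```

The per-bin (confidence, frequency) pairs computed inside the loop are the points of a reliability curve.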

6) Repro artifacts

  • environment.yml / requirements.txt (pin PLM version).
  • Makefile: prepare, embed, train, eval, plot.
  • Seed control: PYTHONHASHSEED, NumPy, torch.

Minimal Makefile:

.PHONY: prepare embed train eval plot

prepare:
	python scripts/prepare.py --cutoff 2024-12-31

embed:
	python scripts/embed_plm.py --model esm2_t33 --out data/emb.npy

train:
	python scripts/train.py --cfg cfgs/logreg.yaml

eval:
	python scripts/eval.py --cfg cfgs/logreg.yaml --out results/logreg.json

plot:
	python scripts/plot_pr.py --in results/logreg.json --out figs/pr.png

Error analysis (quick wins)

  • Confusion by information content (IC) bins (rare vs common terms).
  • Per-ontology breakdown (BP/MF/CC).
  • Check hierarchy violations before/after closure step.
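A hierarchy violation is a prediction where a child term scores higher than one of its parents, something ancestor-closed labels can never contain. A minimal sketch, assuming the DAG is given as a hypothetical list of `(child_col, parent_col)` column-index pairs; the repair step (max-propagating child scores to ancestors) is one common post-hoc fix, not the only one.

```python
import numpy as np


def count_hierarchy_violations(scores, edges):
    """Count (protein, edge) pairs where score[child] > score[parent]."""
    scores = np.asarray(scores)
    child = np.array([c for c, _ in edges], dtype=int)
    parent = np.array([p for _, p in edges], dtype=int)
    return int((scores[:, child] > scores[:, parent]).sum())


def enforce_consistency(scores, edges):
    """Raise each parent's score to at least its children's (max-propagation).

    Repeats passes over the edge list until nothing changes, so scores
    propagate through multi-level paths regardless of edge order.
    """
    fixed = np.array(scores, dtype=float, copy=True)
    changed = True
    while changed:
        changed = False
        for c, p in edges:
            new_parent = np.maximum(fixed[:, p], fixed[:, c])
            if np.any(new_parent > fixed[:, p]):
                fixed[:, p] = new_parent
                changed = True
    return fixed
```

Counting violations before and after the repair shows how often the raw model disagrees with the ontology.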

Deliverables: zipped results/ + cfgs/ + short README = review-ready.


Detailed Implementation Guide

Step 1: Data Preparation and Splits

Time-based split rationale:

  • Prevents temporal leakage: test annotations must postdate the training cutoff
  • Mimics real-world deployment: train on the past, predict the future
  • Standard in CAFA (Critical Assessment of Functional Annotation) evaluations

Example implementation:

import pandas as pd

def prepare_data_splits(
    uniprot_file: str,
    goa_file: str,
    cutoff_date: str = "2024-12-31",
    test_end_date: str = "2025-03-31",
    excluded_evidence: tuple = ("IEA",),
):
    """
    Split data by time to prevent future leakage.

    Assumes both inputs are CSV exports with a 'date' column; the GOA
    export also needs an 'evidence' column (GO evidence code).

    Args:
        uniprot_file: UniProt sequences file (CSV with a 'date' column)
        goa_file: GOA annotations file (CSV with 'date' and 'evidence' columns)
        cutoff_date: Training cutoff T0 (YYYY-MM-DD, inclusive)
        test_end_date: Test window end T1 (YYYY-MM-DD, inclusive)
        excluded_evidence: Evidence codes dropped from both splits (e.g., IEA)

    Returns:
        train_sequences, test_sequences, train_annotations, test_annotations
    """
    t0 = pd.Timestamp(cutoff_date)
    t1 = pd.Timestamp(test_end_date)

    # Load sequences
    sequences = pd.read_csv(uniprot_file)
    sequences['date'] = pd.to_datetime(sequences['date'])

    # Load annotations and drop disallowed evidence codes from both sides
    annotations = pd.read_csv(goa_file)
    annotations['date'] = pd.to_datetime(annotations['date'])
    annotations = annotations[~annotations['evidence'].isin(excluded_evidence)]

    # Time-based split: train <= T0, test in (T0, T1]
    train_sequences = sequences[sequences['date'] <= t0]
    test_sequences = sequences[(sequences['date'] > t0) & (sequences['date'] <= t1)]

    train_annotations = annotations[annotations['date'] <= t0]
    test_annotations = annotations[
        (annotations['date'] > t0) & (annotations['date'] <= t1)
    ]

    return train_sequences, test_sequences, train_annotations, test_annotations
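CAFA further restricts benchmarks to proteins that gain annotations in the test window; its "no-knowledge" set contains proteins with no experimental annotation before T0. A minimal pandas sketch of that filter, assuming the evidence filter above stands in for CAFA's experimental-evidence list and the `protein_id` column used elsewhere in this guide; the function name is hypothetical.

```python
import pandas as pd


def no_knowledge_targets(train_annotations: pd.DataFrame,
                         test_annotations: pd.DataFrame) -> pd.DataFrame:
    """Keep test annotations only for proteins with no annotation at or before T0."""
    seen_before_cutoff = set(train_annotations["protein_id"])
    mask = ~test_annotations["protein_id"].isin(seen_before_cutoff)
    return test_annotations[mask].reset_index(drop=True)
```

Proteins that were already annotated before T0 and gain new terms afterwards form CAFA's separate "limited-knowledge" setting and are worth reporting on their own.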

PPI data integration (optional):

import numpy as np
import pandas as pd
import scipy.sparse as sp

def load_ppi_network(string_file: str, confidence_threshold: float = 0.7):
    """
    Load protein-protein interaction network from STRING.

    Args:
        string_file: STRING links file (columns: protein1, protein2, combined_score)
        confidence_threshold: Minimum confidence in [0, 1] (STRING scores are 0-1000)

    Returns:
        adjacency_matrix: symmetric scipy.sparse.csr_matrix with weights in [0, 1]
        protein_to_index: dict mapping protein ID to matrix index
    """
    # STRING links files are whitespace-delimited
    ppi_df = pd.read_csv(string_file, sep=r'\s+')
    ppi_df = ppi_df[ppi_df['combined_score'] >= confidence_threshold * 1000]

    # Create protein index mapping
    all_proteins = sorted(set(ppi_df['protein1']) | set(ppi_df['protein2']))
    protein_to_index = {p: i for i, p in enumerate(all_proteins)}

    # Build sparse adjacency matrix
    rows = ppi_df['protein1'].map(protein_to_index).to_numpy()
    cols = ppi_df['protein2'].map(protein_to_index).to_numpy()
    data = ppi_df['combined_score'].to_numpy() / 1000.0  # Scale to [0, 1]

    n = len(all_proteins)
    adjacency = sp.csr_matrix((data, (rows, cols)), shape=(n, n))

    # Make symmetric (undirected graph). STRING usually lists each edge in both
    # directions, so take the element-wise max; summing A + A.T would double
    # the weights and, once capped at 1.0, erase the confidence differences.
    adjacency = adjacency.maximum(adjacency.T).tocsr()

    return adjacency, protein_to_index
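A GCN needs one feature row per PPI node, in the order given by `protein_to_index`. A minimal sketch, assuming the embeddings come with a parallel list of protein IDs (hypothetical `emb_ids`) already mapped into the same namespace as the PPI nodes (STRING uses its own `taxon.ENSP...` identifiers, so a mapping step is usually needed), and that nodes without an embedding get a zero row plus a flag so they can be masked out of the loss.

```python
import numpy as np


def align_features(protein_to_index, emb_ids, embeddings):
    """Reorder embedding rows to match adjacency node order.

    Returns:
        features: float32 array [n_nodes, D]; zeros where no embedding exists
        has_embedding: bool array [n_nodes]
    """
    embeddings = np.asarray(embeddings, dtype=np.float32)
    n_nodes = len(protein_to_index)
    features = np.zeros((n_nodes, embeddings.shape[1]), dtype=np.float32)
    has_embedding = np.zeros(n_nodes, dtype=bool)

    for row, pid in enumerate(emb_ids):
        node = protein_to_index.get(pid)
        if node is not None:  # embedded protein that is also in the PPI graph
            features[node] = embeddings[row]
            has_embedding[node] = True
    return features, has_embedding
```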

Step 2: Label Processing and DAG Propagation

Ancestor closure implementation:

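import networkx as nx
import pandas as pd
from goatools import obo_parser

def propagate_annotations(annotations: pd.DataFrame, go_dag_file: str):
    """
    Propagate GO annotations up the DAG (ancestor closure).

    Args:
        annotations: DataFrame with columns [protein_id, go_term]
        go_dag_file: Path to GO OBO file

    Returns:
        propagated_annotations: DataFrame with ancestor-closed labels
    """
    # Load GO DAG (goatools follows is_a edges by default; CAFA-style closure
    # usually also follows part_of, which requires optional_attrs={'relationship'})
    go_dag = obo_parser.GODag(go_dag_file)

    # Build a child -> parent graph. term.parents holds GOTerm objects, not IDs,
    # and go_dag maps alt_ids to the same term, so use the primary .id values.
    G = nx.DiGraph()
    for _, term in go_dag.items():
        for parent in term.parents:
            G.add_edge(term.id, parent.id)

    # Propagate annotations, caching the closure of each term
    closure = {}
    propagated = set()
    for protein_id, go_term in annotations[['protein_id', 'go_term']].itertuples(index=False):
        if go_term not in go_dag:
            propagated.add((protein_id, go_term))  # unknown/obsolete ID: keep as-is
            continue
        term_id = go_dag[go_term].id  # resolve alt_id to primary ID
        if term_id not in closure:
            # Edges point child -> parent, so GO ancestors are nx.descendants in G
            closure[term_id] = nx.descendants(G, term_id) if term_id in G else set()
        propagated.add((protein_id, term_id))
        propagated.update((protein_id, a) for a in closure[term_id])

    return pd.DataFrame(sorted(propagated), columns=['protein_id', 'go_term'])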
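Once labels are closed under ancestors, the classifiers need a binary protein x term matrix with a fixed column order. A minimal sketch using scikit-learn's `MultiLabelBinarizer`, assuming the `[protein_id, go_term]` DataFrame produced above; the function name is hypothetical.

```python
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer


def build_label_matrix(annotations: pd.DataFrame, terms=None):
    """Return (protein_ids, terms, Y) with Y[i, j] = 1 if protein i has term j.

    Pass `terms` from the training split to give val/test matrices the same
    column order; terms not in that list are ignored (with a warning).
    """
    grouped = annotations.groupby("protein_id")["go_term"].apply(set)
    mlb = MultiLabelBinarizer(classes=list(terms) if terms is not None else None)
    Y = mlb.fit_transform(grouped.tolist())
    return grouped.index.tolist(), list(mlb.classes_), Y
```

Without `terms`, the binarizer sorts term IDs, which gives a stable order to store alongside the DAG snapshot.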

Filtering rare terms:

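def filter_rare_terms(annotations: pd.DataFrame, min_count: int = 10):
    """
    Filter out GO terms annotated to too few distinct proteins.

    Compute the term list on the training split only and apply it to
    validation/test, so test labels do not shape the label space.

    Args:
        annotations: DataFrame with [protein_id, go_term]
        min_count: Minimum number of distinct proteins per term

    Returns:
        filtered_annotations: DataFrame with rare terms removed
    """
    # Count distinct proteins (duplicate rows, e.g. one per evidence code, count once)
    term_counts = annotations.groupby('go_term')['protein_id'].nunique()
    common_terms = term_counts[term_counts >= min_count].index

    return annotations[annotations['go_term'].isin(common_terms)]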
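For the "keep and reweight" option, a per-term positive weight (negatives / positives) is the multi-label analogue of `class_weight='balanced'` and has the shape expected by the `pos_weight` argument of PyTorch's `BCEWithLogitsLoss`. A minimal NumPy sketch; the upper cap is an assumption to keep ultra-rare terms from dominating the loss.

```python
import numpy as np


def positive_weights(Y, max_weight=100.0):
    """Per-term weight = #negatives / #positives, capped at max_weight.

    Y: binary label matrix [n_proteins, n_terms] from the training split only.
    Terms with zero positives get max_weight (they add no positive loss anyway).
    """
    Y = np.asarray(Y, dtype=float)
    pos = Y.sum(axis=0)
    neg = Y.shape[0] - pos
    ratio = np.divide(neg, pos, out=np.full_like(pos, max_weight), where=pos > 0)
    return np.minimum(ratio, max_weight)
```

In a PyTorch training loop this would be passed as `nn.BCEWithLogitsLoss(pos_weight=torch.as_tensor(w, dtype=torch.float32))`.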

Step 3: Feature Engineering

PLM embeddings (cached):

import os

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

def generate_plm_embeddings(
    sequences: list,
    model_name: str = "facebook/esm2_t33_650M_UR50D",
    batch_size: int = 32,
    cache_file: str = "embeddings.npy"
):
    """
    Generate protein language model embeddings.

    Args:
        sequences: List of protein sequences (strings)
        model_name: HuggingFace model identifier
        batch_size: Batch size for inference
        cache_file: Path to save embeddings (delete it if sequences or model change)

    Returns:
        embeddings: float32 numpy array [N, D] of L2-normalized embeddings
    """
    if os.path.exists(cache_file):
        print(f"Loading cached embeddings from {cache_file}")
        cached = np.load(cache_file)
        if cached.shape[0] != len(sequences):
            raise ValueError(f"{cache_file} has {cached.shape[0]} rows, expected {len(sequences)}")
        return cached

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()

    embeddings = []
    with torch.no_grad():
        for i in range(0, len(sequences), batch_size):
            batch = sequences[i:i+batch_size]

            # Tokenize (sequences longer than max_length are truncated)
            encoded = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=1024,
                return_tensors="pt"
            ).to(device)

            # Mean pooling over real tokens only: padding positions are masked out
            # (the <cls>/<eos> special tokens are still included in the average)
            outputs = model(**encoded)
            hidden = outputs.last_hidden_state
            mask = encoded["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            batch_embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

            # L2 normalize
            batch_embeddings = torch.nn.functional.normalize(batch_embeddings, p=2, dim=1)

            embeddings.append(batch_embeddings.float().cpu().numpy())

    embeddings = np.vstack(embeddings).astype(np.float32)

    # Save cache
    np.save(cache_file, embeddings)
    print(f"Saved embeddings to {cache_file}")

    return embeddings
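The `.npy` cache above stores only the matrix, so a cache built from a different sequence list or model can load silently. A minimal sketch, assuming the row IDs and a model tag are saved next to the embeddings in one `.npz` file; the function names and file layout are hypothetical.

```python
import numpy as np


def save_embedding_cache(path, ids, embeddings, model_name):
    """Save embeddings together with the row IDs and the model that produced them."""
    np.savez(path,
             ids=np.asarray(ids, dtype=str),
             embeddings=np.asarray(embeddings, dtype=np.float32),
             model_name=np.asarray(model_name))


def load_embedding_cache(path, expected_ids, model_name):
    """Load a cache, refusing it if the row order or model differs from what is expected."""
    with np.load(path) as cache:
        ids = cache["ids"].tolist()
        cached_model = cache["model_name"].item()
        embeddings = cache["embeddings"]
    if cached_model != model_name:
        raise ValueError(f"cache built with {cached_model}, expected {model_name}")
    if ids != list(expected_ids):
        raise ValueError("cached IDs do not match the requested sequence list/order")
    return embeddings
```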

PPI network normalization:

import numpy as np
import scipy.sparse as sp

def normalize_adjacency(adjacency: sp.csr_matrix):
    """
    Symmetrically normalize adjacency matrix for GCN (with self-loops).

    Returns:
        normalized_adj: D^(-1/2) (A + I) D^(-1/2) as a CSR matrix,
        where D is the degree matrix of A + I
    """
    # Add self-loops so each node keeps its own features
    adjacency = adjacency + sp.eye(adjacency.shape[0], format="csr")

    # Compute degree matrix (>= 1 thanks to the self-loops)
    degree = np.asarray(adjacency.sum(axis=1)).flatten()
    degree_sqrt_inv = 1.0 / np.sqrt(degree)
    degree_sqrt_inv[np.isinf(degree_sqrt_inv)] = 0.0

    # Normalize: D^(-1/2) (A + I) D^(-1/2)
    D_inv_sqrt = sp.diags(degree_sqrt_inv)
    normalized = D_inv_sqrt @ adjacency @ D_inv_sqrt

    return normalized.tocsr()

Step 4: Model Training

Baseline: Logistic Regression

from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import StandardScaler

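def train_logistic_baseline(X_train, y_train, X_val, y_val):
    """
    Train one-vs-rest logistic regression baseline.

    Args:
        X_train, X_val: embedding matrices [N, D]
        y_train, y_val: binary label matrices [N, n_terms]

    Returns:
        clf: Trained classifier
        scaler: Fitted StandardScaler (reuse it on the test set)
        metrics: Dictionary of validation metrics
    """
    # Scale features (fit on training data only)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    # Train one binary classifier per GO term
    clf = OneVsRestClassifier(
        LogisticRegression(
            max_iter=4000,
            class_weight='balanced',
            solver='lbfgs',
            random_state=42
        ),
        n_jobs=-1
    )
    clf.fit(X_train_scaled, y_train)

    # Evaluate with the Step 5 metrics (compute_fmax, compute_auprc)
    y_pred_proba = clf.predict_proba(X_val_scaled)
    fmax, threshold = compute_fmax(y_val, y_pred_proba)
    micro_auprc, macro_auprc = compute_auprc(y_val, y_pred_proba)
    metrics = {
        "fmax": fmax,
        "fmax_threshold": threshold,
        "micro_auprc": micro_auprc,
        "macro_auprc": macro_auprc,
    }

    return clf, scaler, metrics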
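The kNN label-transfer baseline from the overview can reuse the L2-normalized embeddings directly: cosine similarity becomes a dot product, and each test protein's score for a term is the similarity-weighted vote of its k nearest training proteins. A minimal NumPy sketch under those assumptions; it builds a dense similarity matrix (chunk the test set for large N) and is embedding-based, not BLAST-based transfer.

```python
import numpy as np


def knn_label_transfer(X_train, Y_train, X_test, k=10):
    """Score terms for each test protein from its k most similar training proteins.

    X_*: L2-normalized embeddings, so X_test @ X_train.T is cosine similarity.
    Y_train: binary label matrix [n_train, n_terms].
    Returns scores in [0, 1]: similarity-weighted fraction of neighbors with each term.
    """
    X_train = np.asarray(X_train, dtype=float)
    X_test = np.asarray(X_test, dtype=float)
    Y_train = np.asarray(Y_train, dtype=float)

    sims = X_test @ X_train.T                               # [n_test, n_train]
    k = min(k, X_train.shape[0])
    nn_idx = np.argpartition(-sims, k - 1, axis=1)[:, :k]   # k most similar, unordered
    nn_sims = np.clip(np.take_along_axis(sims, nn_idx, axis=1), 0.0, None)

    weights = nn_sims / (nn_sims.sum(axis=1, keepdims=True) + 1e-12)
    # Weighted vote: [n_test, k] x [n_test, k, n_terms] -> [n_test, n_terms]
    return np.einsum("ik,ikt->it", weights, Y_train[nn_idx])
```

Because neighbor labels are ancestor-closed, the weighted vote never scores a child above its parent.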

Advanced: GCN for PPI

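import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F

def to_torch_sparse(adj: sp.spmatrix) -> torch.Tensor:
    """Convert a (normalized) SciPy sparse matrix to a torch sparse COO tensor."""
    coo = adj.tocoo()
    indices = torch.from_numpy(np.vstack([coo.row, coo.col])).long()
    values = torch.from_numpy(coo.data).float()
    return torch.sparse_coo_tensor(indices, values, coo.shape)

class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, adj):
        # GCN propagation: H' = ReLU(A_hat H W), A_hat = D^(-1/2) (A + I) D^(-1/2)
        # adj must be a torch sparse tensor (see to_torch_sparse)
        x = self.linear(x)
        x = torch.sparse.mm(adj, x)
        return F.relu(x)

class GCNClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.gcn1 = GCNLayer(input_dim, hidden_dim)
        self.gcn2 = GCNLayer(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, adj):
        x = self.gcn1(x, adj)
        x = self.gcn2(x, adj)
        # Return raw logits (multi-label): train with nn.BCEWithLogitsLoss for
        # numerical stability, apply torch.sigmoid at inference, and keep the
        # logits for late fusion
        return self.classifier(x)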
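Late fusion (listed among the advanced models in the overview) combines per-term logits from, say, the embedding MLP and the PPI GCN before the sigmoid. A minimal NumPy sketch, assuming two models that return raw logits (a model ending in a sigmoid can be converted back with the logit function) and a single mixing weight tuned on validation by binary cross-entropy; tuning by Fmax instead is equally reasonable. The function names are hypothetical.

```python
import numpy as np


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy computed stably from logits."""
    z = np.asarray(logits, dtype=float)
    y = np.asarray(labels, dtype=float)
    return float(np.mean(np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))))


def tune_fusion_weight(logits_a, logits_b, labels, grid=None):
    """Pick w in grid minimizing BCE of w * logits_a + (1 - w) * logits_b."""
    logits_a = np.asarray(logits_a, dtype=float)
    logits_b = np.asarray(logits_b, dtype=float)
    if grid is None:
        grid = np.linspace(0.0, 1.0, 21)
    losses = [bce_with_logits(w * logits_a + (1.0 - w) * logits_b, labels) for w in grid]
    return float(grid[int(np.argmin(losses))])
```

Fused probabilities are then `1 / (1 + np.exp(-(w * logits_a + (1 - w) * logits_b)))`; fusing in logit space lets a confident model outvote an uncertain one.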

Step 5: Evaluation Metrics

CAFA-style evaluation:

import numpy as np
from sklearn.metrics import average_precision_score

def compute_fmax(y_true, y_pred_proba, thresholds=None):
    """
    Compute CAFA-style (protein-centric) Fmax.

    At each threshold, precision is averaged over proteins with at least one
    prediction, recall over all benchmark proteins (>= 1 true term); Fmax is
    the best harmonic mean of the two averages.

    Args:
        y_true: binary label matrix [n_proteins, n_terms]
        y_pred_proba: predicted scores in [0, 1], same shape
        thresholds: thresholds to scan (default 0.01..0.99)

    Returns:
        fmax: Maximum F-measure
        best_threshold: Threshold achieving Fmax
    """
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(y_pred_proba, dtype=float)

    # Evaluate only benchmark proteins (at least one true term)
    keep = y_true.sum(axis=1) > 0
    y_true, scores = y_true[keep], scores[keep]
    n_true = y_true.sum(axis=1)

    if thresholds is None:
        thresholds = np.arange(0.01, 1.0, 0.01)

    fmax, best_threshold = 0.0, 0.0
    for threshold in thresholds:
        y_pred = scores >= threshold
        n_pred = y_pred.sum(axis=1)
        covered = n_pred > 0  # proteins with at least one prediction
        if not covered.any():
            continue
        tp = (y_true & y_pred).sum(axis=1)

        # Per-protein precision (covered proteins) and recall (all proteins)
        precision = np.mean(tp[covered] / n_pred[covered])
        recall = np.mean(tp / n_true)
        if precision + recall == 0:
            continue
        f = 2 * precision * recall / (precision + recall)

        if f > fmax:
            fmax, best_threshold = f, float(threshold)

    return fmax, best_threshold

def compute_auprc(y_true, y_pred_proba):
    """
    Compute micro and macro averaged AUPRC.

    Uses average precision (step-wise area under the PR curve); trapezoidal
    auc(recall, precision) tends to overestimate PR area.
    """
    y_true = np.asarray(y_true)
    y_pred_proba = np.asarray(y_pred_proba)

    # Micro-averaged: pool all (protein, term) pairs
    micro_auprc = average_precision_score(y_true.ravel(), y_pred_proba.ravel())

    # Macro-averaged: per term, skipping terms with no positives
    per_term = [
        average_precision_score(y_true[:, j], y_pred_proba[:, j])
        for j in range(y_true.shape[1])
        if y_true[:, j].sum() > 0
    ]
    macro_auprc = float(np.mean(per_term)) if per_term else 0.0

    return micro_auprc, macro_auprc
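Coverage, listed with the metrics in the overview, is commonly reported as the fraction of benchmark proteins that receive at least one prediction. A minimal sketch that evaluates it at a given threshold (for example the `best_threshold` returned by `compute_fmax`); reporting it at the Fmax threshold is an assumption, not a fixed rule.

```python
import numpy as np


def compute_coverage(y_true, y_pred_proba, threshold):
    """Fraction of benchmark proteins (>= 1 true term) with >= 1 score >= threshold."""
    y_true = np.asarray(y_true)
    scores = np.asarray(y_pred_proba)
    benchmark = y_true.sum(axis=1) > 0
    if not benchmark.any():
        return 0.0
    predicted = (scores[benchmark] >= threshold).any(axis=1)
    return float(predicted.mean())
```

Reporting coverage next to Fmax matters because precision in Fmax is averaged only over covered proteins, so a method that abstains on hard proteins can look more precise than it is.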

Step 6: Reproducibility Artifacts

Complete Makefile:

.PHONY: prepare embed train eval plot clean

# Configuration
CUTOFF_DATE = 2024-12-31
MODEL = esm2_t33_650M_UR50D
CONFIG = cfgs/logreg.yaml

# Data preparation
prepare:
	python scripts/prepare.py \
		--uniprot data/uniprot.fasta \
		--goa data/goa.tsv \
		--cutoff $(CUTOFF_DATE) \
		--out data/splits/

# Generate embeddings for every split used downstream (train/val/test)
embed:
	for split in train val test; do \
		python scripts/embed_plm.py \
			--sequences data/splits/$${split}_sequences.fasta \
			--model $(MODEL) \
			--batch-size 32 \
			--out data/embeddings/$${split}_emb.npy || exit 1; \
	done

# Train model
train:
	python scripts/train.py \
		--cfg $(CONFIG) \
		--train-emb data/embeddings/train_emb.npy \
		--train-labels data/splits/train_labels.npy \
		--val-emb data/embeddings/val_emb.npy \
		--val-labels data/splits/val_labels.npy \
		--out models/logreg.pkl

# Evaluate
eval:
	python scripts/eval.py \
		--model models/logreg.pkl \
		--test-emb data/embeddings/test_emb.npy \
		--test-labels data/splits/test_labels.npy \
		--out results/logreg.json

# Plot results
plot:
	python scripts/plot_pr.py \
		--in results/logreg.json \
		--out figs/pr_curve.png

# Clean intermediate files
clean:
	rm -rf data/embeddings/*.npy
	rm -rf models/*.pkl
	rm -rf results/*.json

Environment file:

# environment.yml
name: go-prediction
channels:
  - conda-forge
  - pytorch
dependencies:
  - python=3.10
  - numpy=1.24.3
  - scipy=1.11.1
  - pandas=2.0.3
  - scikit-learn=1.3.0
  - pytorch=2.0.1
  - transformers=4.33.2
  - networkx=3.1
  - matplotlib=3.7.2
  - seaborn=0.12.2
  - pip
  - pip:
    - goatools==1.3.0
    - biopython==1.81
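Pinning covers installation; recording the versions actually loaded at run time makes each results.json self-describing. A minimal sketch using the standard library's `importlib.metadata`; the package list mirrors the environment file above, and the `environment` key is a hypothetical layout.

```python
import json
import platform
from importlib import metadata


def runtime_versions(packages=("numpy", "pandas", "scikit-learn", "torch",
                                "transformers", "networkx", "goatools")):
    """Return {package: installed version or None} plus the Python version."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


def add_versions_to_results(results_path):
    """Attach runtime versions to an existing results JSON file."""
    with open(results_path) as f:
        results = json.load(f)
    results["environment"] = runtime_versions()
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    return results
```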

Seed control:

import os
import random
import numpy as np
import torch

def set_seeds(seed: int = 42):
    """Set all random seeds for reproducibility."""
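    # PYTHONHASHSEED only affects interpreters started after it is set, so also
    # export it in the shell/Makefile; setting it here covers child processes
    os.environ['PYTHONHASHSEED'] = str(seed)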
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Error Analysis and Debugging

Confusion by information content:

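def analyze_by_ic(y_true, y_pred_proba, ic_scores):
    """
    Analyze errors by information content (IC) bins.

    IC measures term specificity: high IC = specific (rare), low IC = general (common).
    ic_scores: numpy array with one IC value per column of y_true / y_pred_proba.
    The 2.0 / 4.0 cutoffs are illustrative; tune them to your IC distribution.
    """
    ic_scores = np.asarray(ic_scores)

    # Bin terms by IC (low IC = common/general, high IC = rare/specific)
    ic_bins = {
        'common (IC < 2)': ic_scores < 2.0,
        'medium (2 <= IC < 4)': (ic_scores >= 2.0) & (ic_scores < 4.0),
        'rare (IC >= 4)': ic_scores >= 4.0
    }

    for bin_name, mask in ic_bins.items():
        if not mask.any():
            continue  # no terms fall in this bin
        fmax, _ = compute_fmax(y_true[:, mask], y_pred_proba[:, mask])
        print(f"{bin_name} terms: Fmax = {fmax:.3f} ({mask.sum()} terms)")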
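The IC bins above need per-term IC scores. A common annotation-based definition is IC(t) = -log2 P(t), where P(t) is the fraction of annotated proteins carrying term t after ancestor closure, so an ancestor never has a higher IC than its descendants. A minimal NumPy sketch, assuming a binary, ancestor-closed training label matrix; the parent-conditioned "information accretion" used for CAFA's Smin is a different quantity and is not implemented here.

```python
import numpy as np


def information_content(Y):
    """IC(t) = -log2(fraction of annotated proteins annotated with t).

    Y: binary, ancestor-closed label matrix [n_proteins, n_terms].
    Terms with no annotations get IC = inf.
    """
    Y = np.asarray(Y, dtype=float)
    annotated = Y.sum(axis=1) > 0
    n = annotated.sum()
    if n == 0:
        raise ValueError("no annotated proteins")
    freq = Y[annotated].sum(axis=0) / n
    with np.errstate(divide="ignore"):  # log2(0) -> -inf, so IC -> inf
        return -np.log2(freq)
```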

Per-ontology breakdown:

def analyze_by_ontology(y_true, y_pred_proba, go_terms, go_dag):
    """
    Analyze performance separately for BP, MF, CC.

    go_terms: list of GO IDs, one per column of y_true / y_pred_proba
    go_dag: goatools GODag used to look up each term's namespace
    """
    namespaces = {
        'BP': 'biological_process',
        'MF': 'molecular_function',
        'CC': 'cellular_component'
    }

    for ont, namespace in namespaces.items():
        term_indices = [
            j for j, t in enumerate(go_terms)
            if t in go_dag and go_dag[t].namespace == namespace
        ]
        if term_indices:
            fmax, _ = compute_fmax(y_true[:, term_indices], y_pred_proba[:, term_indices])
            print(f"{ont}: Fmax = {fmax:.3f} ({len(term_indices)} terms)")

Best Practices Summary

  1. Time-based splits: Always use temporal splits to prevent leakage
  2. Ancestor closure: Propagate labels up the DAG for consistency
  3. Cache embeddings: Save PLM embeddings to avoid recomputation
  4. Version control: Tag data, code, and model versions
  5. Documentation: Include README with setup, usage, and results
  6. Reproducibility: Pin all dependencies, set random seeds
  7. Error analysis: Analyze failures by IC, ontology, hierarchy violations

Deliverables checklist:

  • results/ directory with JSON metrics and plots
  • cfgs/ directory with all configuration files
  • README.md with setup instructions and results summary
  • environment.yml or requirements.txt with pinned versions
  • Makefile for reproducible execution
  • Seed control in all scripts
  • Data version tags (cutoff dates, GO version)
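Packaging can be scripted so the archive always contains the same pieces. A minimal standard-library sketch; the directory names follow the checklist, and the default archive name `deliverables.zip` is hypothetical.

```python
import zipfile
from pathlib import Path


def package_deliverables(root=".", out="deliverables.zip",
                         include=("results", "cfgs", "README.md")):
    """Zip results/, cfgs/ and README.md, preserving paths relative to root."""
    root = Path(root)
    missing = [name for name in include if not (root / name).exists()]
    if missing:
        raise FileNotFoundError(f"missing deliverables: {missing}")

    with zipfile.ZipFile(out, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for name in include:
            path = root / name
            files = sorted(path.rglob("*")) if path.is_dir() else [path]
            for f in files:
                if f.is_file():
                    zf.write(f, arcname=f.relative_to(root).as_posix())
    return out
```

Failing on a missing directory, rather than zipping whatever exists, keeps an incomplete package from looking review-ready.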


This is a personal blog. The views, thoughts, and opinions expressed here are my own and do not represent, reflect, or constitute the views, policies, or positions of any employer, university, client, or organization I am associated with or have been associated with.

© Copyright 2017-2025