
Takeaway: For robust GO prediction, start with homology + PLM baselines, add label smoothing on PPI, and only then graduate to GNNs/multimodal fusion.

Problem framing (multilabel, hierarchical)

  • Labels: Gene Ontology (BP/MF/CC), DAG with is_a, part_of.
  • Task: multilabel classification with long-tail classes and hierarchy constraints.
  • Implication: predictions must be ancestor-closed.

Features to combine

1) Homology: BLAST/DIAMOND kNN; domain HMMs (Pfam).
2) Sequence embeddings: frozen PLMs (e.g., ESM-like, 1-3k dims).
3) Structure: DSSP features, secondary structure, solvent accessibility; optional 3D graph embeddings.
4) Networks: PPI/co-expression; adjacency matrix A for smoothing or a GNN.
5) Text: weak supervision from abstracts/full-text.

Minimal, high-value recipe

(a) Logistic baseline on PLM embeddings (balanced, calibrated)
(b) One-step label smoothing on normalized PPI
(c) Hierarchy closure + per-class threshold tuning (Fmax)

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# X: PLM embeddings [N, D], Y: multi-hot GO [N, C]
# A: row-normalized PPI over the same proteins (and row order) as the predictions being smoothed
clf = OneVsRestClassifier(LogisticRegression(max_iter=4000, class_weight="balanced"))
clf.fit(X_tr, Y_tr)
P0 = clf.predict_proba(X_val)

alpha = 0.2
P1 = (1 - alpha) * P0 + alpha * (A @ P0)            # one-step smoothing over the PPI graph
P1 = np.clip(P1, 0, 1)
P1 = close_under_ancestors(P1, go_dag)              # hierarchy consistency on the scores

th = tune_thresholds(P1, Y_val, metric="Fmax")      # per-class thresholds
Y_hat = (P1 >= th).astype(int)

Hyper-parameters (good starting points)

  • PLM embedding: L2-normalize; optionally PCA to 256-512 dims (see the sketch after this list).
  • Logistic: C=1.0, class_weight="balanced", solver="lbfgs", max_iter=4000.
  • Smoothing: alpha = 0.1-0.3 (avoid oversmoothing hubs).
  • Threshold search: per-class sweep over [0.05..0.95].
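
As a concrete example of the first bullet, here is a minimal preprocessing sketch (scikit-learn; the 256-dim target and the fit-on-train-only split are illustrative choices, not requirements):

from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize

def preprocess_plm_embeddings(X_train, X_val, n_components=256):
    """L2-normalize PLM embeddings and optionally reduce dimensionality with PCA."""
    X_train = normalize(X_train)          # row-wise L2 normalization
    X_val = normalize(X_val)

    pca = PCA(n_components=n_components, random_state=0)
    X_train = pca.fit_transform(X_train)  # fit on the training split only to avoid leakage
    X_val = pca.transform(X_val)

    # Re-normalize after projection so cosine/linear models downstream behave consistently
    return normalize(X_train), normalize(X_val), pca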

Evaluation & reporting

  • Fmax per ontology (BP/MF/CC), micro/macro-auPRC (a minimal Fmax sketch follows this list).
  • Coverage (proteins with ≥1 predicted term), ECE (expected calibration error).
  • Hierarchy violations (should be ~0 after closure).
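
The Fmax above is the CAFA-style protein-centric F-measure. The helper compute_fmax used in code later in this post is never defined there, so here is a minimal sketch; it assumes binary Y_true, scores P in [0, 1], and returns the best F together with the global threshold that achieves it:

import numpy as np

def compute_fmax(Y_true, P, thresholds=None):
    """Protein-centric Fmax: sweep one global threshold, return (fmax, best_threshold)."""
    if thresholds is None:
        thresholds = np.arange(0.01, 1.00, 0.01)

    has_true = Y_true.sum(axis=1) > 0              # proteins with at least one annotation
    fmax, best_t = 0.0, 0.0
    for t in thresholds:
        Y_pred = P >= t
        tp = np.logical_and(Y_pred, Y_true > 0).sum(axis=1)
        n_pred = Y_pred.sum(axis=1)

        covered = n_pred > 0                        # proteins with >=1 prediction at this threshold
        if not covered.any() or not has_true.any():
            continue
        precision = (tp[covered] / n_pred[covered]).mean()
        recall = (tp[has_true] / Y_true[has_true].sum(axis=1)).mean()
        if precision + recall > 0:
            f = 2 * precision * recall / (precision + recall)
            if f > fmax:
                fmax, best_t = f, t
    return fmax, best_t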

Ablations to include (table)

  • PLM only vs +kNN homology vs +PPI smoothing.
  • Thresholding: global vs per-class.
  • With/without hierarchy closure.

Bottom line: this pipeline is simple, strong, and extensible; add GNNs (GCN/GAT) or multimodal fusion when the baseline plateaus.


Detailed Implementation

Step 1: Feature Extraction

Homology-based features (BLAST/DIAMOND):

from Bio.Blast import NCBIWWW, NCBIXML
import subprocess

def get_blast_homologs(sequence: str, database: str = "nr", evalue: float = 1e-5):
    """
    Get homologous proteins via BLAST.
    
    Returns:
        homolog_ids: List of UniProt IDs with significant hits
        scores: List of E-values
    """
    result = NCBIWWW.qblast("blastp", database, sequence, expect=evalue)
    blast_record = NCBIXML.read(result)
    
    homolog_ids = []
    scores = []
    for alignment in blast_record.alignments:
        for hsp in alignment.hsps:
            if hsp.expect < evalue:
                homolog_ids.append(alignment.title.split('|')[1])  # UniProt ID
                scores.append(hsp.expect)
    
    return homolog_ids, scores

def transfer_labels_from_homologs(protein_id: str, homolog_ids: list, go_annotations: dict):
    """
    Transfer GO labels from homologous proteins.
    
    Returns:
        transferred_labels: Set of GO terms from homologs
    """
    transferred_labels = set()
    for homolog_id in homolog_ids:
        if homolog_id == protein_id:
            continue  # skip self-hits
        if homolog_id in go_annotations:
            transferred_labels.update(go_annotations[homolog_id])
    return transferred_labels
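
The web qblast call above is convenient but slow for more than a handful of proteins. For large-scale runs a local DIAMOND search is the usual substitute; a rough sketch, assuming a database already built with diamond makedb and the diamond binary on PATH:

import csv
import subprocess

def get_diamond_homologs(query_fasta: str, diamond_db: str, evalue: float = 1e-5):
    """Run a local DIAMOND blastp search; return {query_id: [(hit_id, e-value), ...]}."""
    out_tsv = "diamond_hits.tsv"
    subprocess.run(
        ["diamond", "blastp",
         "--query", query_fasta,
         "--db", diamond_db,
         "--out", out_tsv,
         "--outfmt", "6",          # tabular: qseqid sseqid pident length ... evalue bitscore
         "--evalue", str(evalue)],
        check=True,
    )

    hits = {}
    with open(out_tsv) as fh:
        for row in csv.reader(fh, delimiter="\t"):
            qseqid, sseqid, evalue_str = row[0], row[1], row[10]
            hits.setdefault(qseqid, []).append((sseqid, float(evalue_str)))
    return hits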

PLM embeddings (ESM2):

from transformers import EsmModel, EsmTokenizer
import torch

def get_esm2_embedding(sequence: str, model_name: str = "facebook/esm2_t33_650M_UR50D"):
    """
    Get ESM2 embedding for a protein sequence.
    
    Returns:
        embedding: numpy array [D] of L2-normalized embedding
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = EsmTokenizer.from_pretrained(model_name)
    model = EsmModel.from_pretrained(model_name).to(device)
    model.eval()
    
    # Tokenize
    encoded = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(device)
    
    # Get embedding (mean pooling)
    with torch.no_grad():
        outputs = model(**encoded)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
        embedding = embedding / embedding.norm()  # L2 normalize
    
    return embedding.cpu().numpy()
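
Reloading the 650M-parameter model for every sequence (as the helper above does) is very slow. For a whole proteome, load the model once and loop; a minimal sketch reusing the imports from the snippet above:

import numpy as np

def embed_sequences(sequences, model_name="facebook/esm2_t33_650M_UR50D", max_length=1024):
    """Embed many sequences with a single model load; returns an [N, D] numpy array."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = EsmTokenizer.from_pretrained(model_name)
    model = EsmModel.from_pretrained(model_name).to(device).eval()

    embeddings = []
    with torch.no_grad():
        for seq in sequences:
            encoded = tokenizer(seq, return_tensors="pt", truncation=True, max_length=max_length).to(device)
            hidden = model(**encoded).last_hidden_state.mean(dim=1).squeeze(0)
            embeddings.append((hidden / hidden.norm()).cpu().numpy())
    return np.stack(embeddings)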

Structure features (DSSP):

from Bio.PDB import PDBParser
from Bio.PDB.DSSP import DSSP

def extract_structure_features(pdb_file: str):
    """
    Extract secondary structure and solvent accessibility from PDB.
    
    Returns:
        features: Dictionary with SS, SA, etc.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_file)
    model = structure[0]
    
    dssp = DSSP(model, pdb_file)
    
    features = {
        'secondary_structure': [],
        'solvent_accessibility': [],
        'phi': [],
        'psi': []
    }
    
    # dssp maps (chain_id, res_id) -> (index, aa, SS, rel. ASA, phi, psi, ...)
    for key in dssp.keys():
        residue = dssp[key]
        features['secondary_structure'].append(residue[2])   # 8-state DSSP code (H, B, E, G, I, T, S, -)
        features['solvent_accessibility'].append(residue[3])  # relative solvent accessibility
        features['phi'].append(residue[4])
        features['psi'].append(residue[5])
    
    return features
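
The output above is per-residue and variable-length, while the fusion step later expects one fixed-length vector per protein. One simple (and by no means canonical) featurization is secondary-structure composition plus mean relative accessibility:

import numpy as np

def summarize_structure_features(features: dict) -> np.ndarray:
    """Collapse per-residue DSSP output into a small fixed-length vector."""
    ss = features['secondary_structure']
    n = max(len(ss), 1)

    helix = sum(s in ('H', 'G', 'I') for s in ss) / n     # helical states
    sheet = sum(s in ('E', 'B') for s in ss) / n          # strand/bridge states
    coil = 1.0 - helix - sheet

    # DSSP sometimes reports 'NA' for accessibility; keep only numeric values
    rsa = [x for x in features['solvent_accessibility'] if isinstance(x, (int, float))]
    mean_rsa = float(np.mean(rsa)) if rsa else 0.0

    return np.array([helix, sheet, coil, mean_rsa])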

Step 2: Label Smoothing on PPI

One-step label smoothing implementation:

import numpy as np
import scipy.sparse as sp

def label_smoothing_on_ppi(
    P0: np.ndarray,
    A: sp.csr_matrix,
    alpha: float = 0.2
):
    """
    Apply one-step label smoothing on PPI network.
    
    Args:
        P0: [N, C] initial predictions (probabilities)
        A: [N, N] PPI adjacency matrix (row-normalized inside this function)
        alpha: Smoothing coefficient (0.2 = 20% from neighbors)
    
    Returns:
        P1: [N, C] smoothed predictions
    """
    # Row-normalize the adjacency (row-stochastic), guarding isolated nodes against division by zero
    row_sums = np.array(A.sum(axis=1)).flatten()
    row_sums[row_sums == 0] = 1.0
    A_normalized = A.multiply(1.0 / row_sums[:, np.newaxis]).tocsr()
    
    # One-step smoothing: (1-alpha) * P0 + alpha * A @ P0
    P_smoothed = A_normalized @ P0
    P1 = (1 - alpha) * P0 + alpha * P_smoothed
    
    # Clip to [0, 1]
    P1 = np.clip(P1, 0, 1)
    
    return P1

Why it works:

  • Proteins with similar functions tend to interact (guilt-by-association)
  • Smoothing propagates labels from annotated to unannotated proteins
  • One step avoids oversmoothing (multiple steps can blur signal)

Tuning alpha:

  • alpha = 0.1: Conservative, mostly original predictions
  • alpha = 0.2: Balanced (good default)
  • alpha = 0.3: Aggressive, more neighbor influence
  • alpha > 0.3: Risk of oversmoothing, especially for hub proteins

Step 3: Hierarchy Closure

Ancestor closure (from previous article):

def close_under_ancestors(P: np.ndarray, go_dag, threshold: float = None):
    """
    Make predictions respect the GO hierarchy: an ancestor's score (or label)
    is always at least the maximum over its descendants.
    
    Args:
        P: [N, C] class probabilities
        go_dag: GO DAG object exposing topo_order() and ancestors(term_idx)
        threshold: if given, binarize first and return 0/1 labels;
                   if None (as in the recipe above), close the scores and threshold later
    
    Returns:
        [N, C] ancestor-closed scores, or binary labels if a threshold is given
    """
    out = (P >= threshold).astype(np.uint8) if threshold is not None else P.copy()
    
    # Topological order (parents before children); visit children before parents
    order = go_dag.topo_order()
    for term_idx in reversed(order):
        for ancestor_idx in go_dag.ancestors(term_idx):
            # If a child is on (or scored high), force the ancestor at least as high
            out[:, ancestor_idx] = np.maximum(out[:, ancestor_idx], out[:, term_idx])
    
    return out
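
To report the "hierarchy violations" metric from the evaluation checklist (and to sanity-check the closure above), a minimal counter, assuming the same go_dag interface:

def count_hierarchy_violations(Y: np.ndarray, go_dag) -> int:
    """Count (protein, term) pairs where a term is predicted but one of its ancestors is not."""
    violations = 0
    for term_idx in go_dag.topo_order():
        for ancestor_idx in go_dag.ancestors(term_idx):
            # Violation: child on, ancestor off
            violations += int(np.sum((Y[:, term_idx] == 1) & (Y[:, ancestor_idx] == 0)))
    return violations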

Step 4: Threshold Tuning

Per-class threshold optimization:

from sklearn.metrics import f1_score

def tune_per_class_thresholds(
    P: np.ndarray,
    Y_true: np.ndarray,
    metric: str = "f1",
    thresholds: np.ndarray = None
):
    """
    Find optimal per-class thresholds, then report the resulting Fmax.
    
    Returns:
        best_thresholds: [C] array of optimal thresholds per class
        fmax: protein-centric Fmax computed with those thresholds
    """
    if thresholds is None:
        thresholds = np.arange(0.05, 0.95, 0.01)
    
    num_classes = Y_true.shape[1]
    best_thresholds = np.full(num_classes, 0.5)  # default for classes that are never tuned
    
    for class_idx in range(num_classes):
        if Y_true[:, class_idx].sum() == 0:
            continue  # Skip classes with no positives
        
        best_f1 = 0.0
        best_th = 0.5
        
        for th in thresholds:
            y_pred = (P[:, class_idx] >= th).astype(int)
            f1 = f1_score(Y_true[:, class_idx], y_pred, zero_division=0)
            
            if f1 > best_f1:
                best_f1 = f1
                best_th = th
        
        best_thresholds[class_idx] = best_th
    
    # Protein-centric Fmax with the tuned thresholds (compute_fmax is sketched earlier in this post)
    Y_pred = (P >= best_thresholds).astype(int)
    fmax, _ = compute_fmax(Y_true, Y_pred)
    
    return best_thresholds, fmax

Step 5 (Advanced): GNN Implementation

GCN for GO prediction:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCNForGO(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes, num_layers=2):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(input_dim, hidden_dim))
        
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
        
        self.convs.append(GCNConv(hidden_dim, num_classes))
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x, edge_index):
        """
        Args:
            x: [N, D] node features (PLM embeddings)
            edge_index: [2, E] edge indices (PPI network)
        """
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = self.dropout(x)
        
        x = self.convs[-1](x, edge_index)
        return torch.sigmoid(x)  # Multi-label classification

Training loop:

def train_gnn(model, train_loader, val_loader, num_epochs=100):
    """
    Train GCN for GO prediction.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.BCELoss()
    
    best_fmax = 0.0
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)
            loss = criterion(out, batch.y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        # Validation
        model.eval()
        val_predictions = []
        val_labels = []
        with torch.no_grad():
            for batch in val_loader:
                out = model(batch.x, batch.edge_index)
                val_predictions.append(out.cpu().numpy())
                val_labels.append(batch.y.cpu().numpy())
        
        val_predictions = np.vstack(val_predictions)
        val_labels = np.vstack(val_labels)
        
        fmax, _ = compute_fmax(val_labels, val_predictions)
        
        if fmax > best_fmax:
            best_fmax = fmax
            torch.save(model.state_dict(), 'best_model.pth')
        
        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Fmax = {fmax:.4f}")

Step 6: Multimodal Fusion

Late fusion of multiple features:

def multimodal_fusion(
    plm_embeddings: np.ndarray,
    structure_features: np.ndarray,
    homology_scores: np.ndarray,
    ppi_embeddings: np.ndarray
):
    """
    Combine multiple feature modalities.
    
    Returns:
        fused_features: [N, D] combined feature vector
    """
    # L2-normalize each modality (small epsilon avoids division by zero)
    plm_norm = plm_embeddings / (np.linalg.norm(plm_embeddings, axis=1, keepdims=True) + 1e-10)
    struct_norm = structure_features / (np.linalg.norm(structure_features, axis=1, keepdims=True) + 1e-10)
    homology_norm = homology_scores / (np.linalg.norm(homology_scores, axis=1, keepdims=True) + 1e-10)
    ppi_norm = ppi_embeddings / (np.linalg.norm(ppi_embeddings, axis=1, keepdims=True) + 1e-10)
    
    # Concatenate or weighted average
    # Option 1: Concatenate
    fused = np.concatenate([plm_norm, struct_norm, homology_norm, ppi_norm], axis=1)
    
    # Option 2: Weighted average (if same dimension)
    # weights = [0.5, 0.2, 0.2, 0.1]  # PLM gets most weight
    # fused = weights[0] * plm_norm + weights[1] * struct_norm + ...
    
    return fused

Early fusion (learned):

class MultimodalFusion(nn.Module):
    def __init__(self, plm_dim, struct_dim, homology_dim, ppi_dim, hidden_dim):
        super().__init__()
        self.plm_proj = nn.Linear(plm_dim, hidden_dim)
        self.struct_proj = nn.Linear(struct_dim, hidden_dim)
        self.homology_proj = nn.Linear(homology_dim, hidden_dim)
        self.ppi_proj = nn.Linear(ppi_dim, hidden_dim)
        
        # Learnable fusion weights
        self.fusion_weights = nn.Parameter(torch.ones(4) / 4)
    
    def forward(self, plm, struct, homology, ppi):
        # Project each modality to same dimension
        plm_proj = F.relu(self.plm_proj(plm))
        struct_proj = F.relu(self.struct_proj(struct))
        homology_proj = F.relu(self.homology_proj(homology))
        ppi_proj = F.relu(self.ppi_proj(ppi))
        
        # Weighted combination
        weights = F.softmax(self.fusion_weights, dim=0)
        fused = (weights[0] * plm_proj + 
                 weights[1] * struct_proj + 
                 weights[2] * homology_proj + 
                 weights[3] * ppi_proj)
        
        return fused
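
A quick shape check for the learned-fusion module, with made-up dimensions (ESM2-650M embeddings are 1280-dim; the other sizes are purely illustrative):

fusion = MultimodalFusion(plm_dim=1280, struct_dim=4, homology_dim=64, ppi_dim=128, hidden_dim=256)

plm = torch.randn(8, 1280)        # batch of 8 proteins
struct = torch.randn(8, 4)
homology = torch.randn(8, 64)
ppi = torch.randn(8, 128)

fused = fusion(plm, struct, homology, ppi)
print(fused.shape)                # torch.Size([8, 256])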

Complete Pipeline Example

End-to-end pipeline:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# 1. Load data
sequences = load_sequences("data/uniprot.fasta")
go_annotations = load_go_annotations("data/goa.tsv")
ppi_network = load_ppi_network("data/string.tsv")

# 2. Generate features
print("Generating PLM embeddings...")
X_plm = np.array([get_esm2_embedding(seq) for seq in sequences])

print("Computing homology features...")
X_homology = np.array([get_blast_homologs(seq) for seq in sequences])

print("Building PPI adjacency...")
A = build_ppi_adjacency(ppi_network, sequences)

# 3. Train baseline
print("Training logistic regression...")
clf = OneVsRestClassifier(
    LogisticRegression(max_iter=4000, class_weight="balanced", solver="lbfgs")
)
clf.fit(X_train, Y_train)   # X_train / X_val are row splits of X_plm, Y_train / Y_val the matching GO labels
P0 = clf.predict_proba(X_val)

# 4. Label smoothing (A must cover the same proteins, in the same row order, as the rows of P0)
print("Applying label smoothing...")
alpha = 0.2
P1 = label_smoothing_on_ppi(P0, A, alpha=alpha)
P1 = np.clip(P1, 0, 1)

# 5. Hierarchy closure
print("Enforcing hierarchy consistency...")
P1 = close_under_ancestors(P1, go_dag)

# 6. Threshold tuning
print("Tuning thresholds...")
thresholds, fmax = tune_per_class_thresholds(P1, Y_val)
Y_pred = (P1 >= thresholds).astype(int)

# 7. Evaluate
print("Evaluating...")
fmax_bp, fmax_mf, fmax_cc = compute_fmax_by_ontology(Y_val, Y_pred, go_dag)
micro_auprc, macro_auprc = compute_auprc(Y_val, P1)

print(f"Fmax (BP): {fmax_bp:.3f}")
print(f"Fmax (MF): {fmax_mf:.3f}")
print(f"Fmax (CC): {fmax_cc:.3f}")
print(f"Micro AUPRC: {micro_auprc:.3f}")
print(f"Macro AUPRC: {macro_auprc:.3f}")

When to Graduate to GNNs

Stick with baseline if:

  • Fmax > 0.5 on validation set
  • Training time < 1 hour
  • Simple features (PLM + PPI smoothing) sufficient

Upgrade to GNNs when:

  • Baseline plateaus (Fmax < 0.4)
  • PPI network is dense and informative
  • You have structure data (3D coordinates)
  • You need to model complex protein interactions

GNN architecture selection:

  • GCN: Simple, fast, good for dense graphs
  • GAT: Attention mechanism, better for heterogeneous graphs
  • GraphSAGE: Inductive learning, handles new proteins
  • GIN: More expressive, better for complex patterns

Ablation Study Template

Compare different approaches:

results = {}

# Baseline: PLM only
clf_plm = train_logistic(X_plm, Y_train)
P_plm = clf_plm.predict_proba(X_val)
results['PLM only'] = evaluate(Y_val, P_plm)

# + Homology
X_combined = np.concatenate([X_plm, X_homology], axis=1)
clf_hom = train_logistic(X_combined, Y_train)
P_hom = clf_hom.predict_proba(X_val)
results['PLM + Homology'] = evaluate(Y_val, P_hom)

# + PPI smoothing
P_smooth = label_smoothing_on_ppi(P_plm, A, alpha=0.2)
results['PLM + PPI smoothing'] = evaluate(Y_val, P_smooth)

# + Hierarchy closure
P_closed = close_under_ancestors(P_smooth, go_dag)
results['PLM + PPI + Hierarchy'] = evaluate(Y_val, P_closed)

# + Per-class thresholds
thresholds, _ = tune_per_class_thresholds(P_closed, Y_val)
Y_pred = (P_closed >= thresholds).astype(int)
results['Full pipeline'] = evaluate(Y_val, Y_pred)

# Print results table
print("\nAblation Results:")
print("=" * 60)
for method, metrics in results.items():
    print(f"{method:30s} Fmax={metrics['fmax']:.3f} AUPRC={metrics['auprc']:.3f}")

Common Pitfalls

Pitfall 1: Oversmoothing

Problem: Too many smoothing steps or high alpha blurs signal.

Solution:

  • Use one-step smoothing (alpha = 0.1-0.2)
  • Monitor performance: if smoothing hurts, reduce alpha
  • Filter hub proteins (high degree) from smoothing (see the sketch after this list)
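
One way to implement the last point, as a hedged sketch: drop hub proteins as smoothing sources before normalizing, so a few promiscuous interactors do not dominate everyone else's smoothed scores (the 99th-percentile cutoff is arbitrary):

import numpy as np
import scipy.sparse as sp

def filter_hub_columns(A: sp.csr_matrix, degree_percentile: float = 99.0) -> sp.csr_matrix:
    """Remove hub proteins (degree above the given percentile) as sources of label smoothing."""
    degrees = np.array((A > 0).sum(axis=1)).flatten()
    cutoff = np.percentile(degrees, degree_percentile)

    keep = (degrees <= cutoff).astype(float)   # 1 for normal nodes, 0 for hubs
    return (A @ sp.diags(keep)).tocsr()        # hub columns zeroed: hubs no longer contribute to neighbors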

Pitfall 2: Ignoring Hierarchy

Problem: Predictions violate ancestor-child relationships.

Solution:

  • Always apply hierarchy closure
  • Use per-class thresholds (higher for general terms)
  • Report both pre- and post-closure metrics

Pitfall 3: Data Leakage

Problem: The test set contains proteins whose annotations were already available during the training period, which inflates scores.

Solution:

  • Strict time-based split (test on annotations added after the training cutoff; see the sketch after this list)
  • Remove IEA evidence from both train and test
  • Document split dates in paper/report
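
A minimal, hedged sketch of a time-based split; it assumes each annotation record carries the protein, GO term, evidence code, and the date the annotation was added (the cutoff date is illustrative):

from datetime import date

def temporal_split(annotations, cutoff=date(2023, 1, 1), exclude_evidence=("IEA",)):
    """
    Split GO annotation records into train/test by date.
    Each record is assumed to be (protein_id, go_term, evidence_code, date_added).
    """
    train, test = {}, {}
    for protein_id, go_term, evidence, added in annotations:
        if evidence in exclude_evidence:
            continue                      # drop electronically inferred annotations
        target = train if added < cutoff else test
        target.setdefault(protein_id, set()).add(go_term)

    # Optional: keep only "no-knowledge" test proteins (nothing annotated before the cutoff)
    test = {p: terms for p, terms in test.items() if p not in train}
    return train, test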

Best Practices Summary

  1. Start simple: PLM + logistic regression baseline
  2. Add smoothing: One-step PPI label smoothing (alpha = 0.2)
  3. Enforce hierarchy: Ancestor closure on scores, then re-close the binary labels after per-class thresholding
  4. Tune thresholds: Per-class thresholds for Fmax
  5. Graduate carefully: Only add GNNs when baseline plateaus
  6. Ablate everything: Compare each component’s contribution
  7. Document splits: Time-based splits, evidence filtering

Bottom line: This pipeline is simple, strong, and extensible; add GNNs (GCN/GAT) or multimodal fusion when the baseline plateaus.
