DR. ATABAK KH


Takeaway: For robust GO prediction, start with homology + PLM baselines, add label smoothing on PPI, and only then graduate to GNNs/multimodal fusion.

Problem framing (multilabel, hierarchical)

  • Labels: Gene Ontology (BP/MF/CC), DAG with is_a, part_of.
  • Task: multilabel classification with long-tail classes and hierarchy constraints.
  • Implication: predictions must be ancestor-closed (if a term is predicted, all of its is_a/part_of ancestors must be predicted too; see the toy sketch below).
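
A toy illustration of the ancestor-closure constraint (hand-built parent map with hypothetical term names; the real pipeline uses the full GO DAG):

# Toy DAG: each term maps to its direct is_a/part_of parents
parents = {
    "GO:kinase_activity": ["GO:catalytic_activity"],
    "GO:catalytic_activity": ["GO:molecular_function"],
    "GO:molecular_function": [],
}

def ancestor_closed(predicted: set, parents: dict) -> set:
    """Add every ancestor of every predicted term."""
    closed = set(predicted)
    stack = list(predicted)
    while stack:
        term = stack.pop()
        for parent in parents[term]:
            if parent not in closed:
                closed.add(parent)
                stack.append(parent)
    return closed

print(ancestor_closed({"GO:kinase_activity"}, parents))
# {'GO:kinase_activity', 'GO:catalytic_activity', 'GO:molecular_function'}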

Features to combine

1) Homology: BLAST/DIAMOND kNN hits; domain HMMs (Pfam). A label-transfer sketch follows this list.
2) Sequence embeddings: frozen PLMs (e.g., ESM-like, 1-3k dims).
3) Structure: DSSP features, secondary structure, solvent accessibility; optional 3D graph embeddings.
4) Networks: PPI/co-expression; adjacency matrix A for smoothing or a GNN.
5) Text: weak supervision from abstracts/full-text.
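
For item 1, a minimal label-transfer sketch, assuming a precomputed tab-separated hit file (query, subject, bitscore, e-value; this column layout is an assumption) and a dict mapping subject IDs to their GO terms:

from collections import defaultdict

def homology_transfer_scores(hits_tsv: str, go_annotations: dict):
    """
    Turn precomputed BLAST/DIAMOND hits into per-query GO scores.

    hits_tsv: TSV with columns query_id, subject_id, bitscore, evalue
              (hypothetical layout; adapt to your actual hit format).
    go_annotations: dict mapping subject_id -> set of GO term IDs.

    Returns: dict query_id -> {GO term: bitscore-weighted score in [0, 1]}.
    """
    scores = defaultdict(lambda: defaultdict(float))
    totals = defaultdict(float)

    with open(hits_tsv) as fh:
        for line in fh:
            query_id, subject_id, bitscore, _evalue = line.rstrip("\n").split("\t")[:4]
            bitscore = float(bitscore)
            totals[query_id] += bitscore
            for term in go_annotations.get(subject_id, ()):
                scores[query_id][term] += bitscore

    # Normalize by total bitscore so every score stays in [0, 1]
    return {
        q: {term: s / totals[q] for term, s in term_scores.items()}
        for q, term_scores in scores.items()
    }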

Minimal, high-value recipe

(a) Logistic baseline on PLM embeddings (balanced, calibrated)
(b) One-step label smoothing on normalized PPI
(c) Hierarchy closure + per-class threshold tuning (Fmax)

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# X: PLM embeddings [N, D], Y: multi-hot GO [N, C], A: row-normalized PPI adjacency
clf = OneVsRestClassifier(LogisticRegression(max_iter=4000, class_weight="balanced"))
clf.fit(X_tr, Y_tr)
P0 = clf.predict_proba(X_val)

alpha = 0.2
P1 = (1 - alpha) * P0 + alpha * (A @ P0)           # one-step smoothing
P1 = np.clip(P1, 0, 1)
P1 = close_under_ancestors(P1, go_dag)              # hierarchy consistency

th = tune_thresholds(P1, Y_val, metric="Fmax")      # per-class thresholds
Y_hat = (P1 >= th).astype(int)

Hyper-parameters (good starting points)

  • PLM embedding: L2-normalize; optionally PCA to 256-512 (see the preprocessing sketch after this list).
  • Logistic: C=1.0, class_weight="balanced", solver="lbfgs", max_iter=4000.
  • Smoothing: alpha = 0.1-0.3 (avoid oversmoothing hubs).
  • Threshold search: per-class sweep over [0.05..0.95].
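
A minimal preprocessing sketch for the first bullet (L2-normalization plus optional PCA; the 256-component setting is just the low end of the suggested range):

import numpy as np
from sklearn.decomposition import PCA

def preprocess_plm_embeddings(X_tr: np.ndarray, X_val: np.ndarray, n_components: int = 256):
    """L2-normalize PLM embeddings and optionally reduce dimensionality with PCA."""
    def l2_normalize(X):
        return X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-10)

    X_tr, X_val = l2_normalize(X_tr), l2_normalize(X_val)

    pca = PCA(n_components=n_components).fit(X_tr)   # fit on the training split only
    return pca.transform(X_tr), pca.transform(X_val)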

Evaluation & reporting

  • Fmax per ontology (BP/MF/CC), micro/macro-auPRC (a small metrics sketch follows this list).
  • Coverage (proteins with ≥1 label), ECE (calibration).
  • Hierarchy violations (should be ~0 after closure).
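
A sketch of the simpler quantities (micro/macro auPRC and coverage), assuming Y_true and scores P are aligned [N, C] arrays; Fmax, ECE, and violation counts need their own sweep/binning code:

import numpy as np
from sklearn.metrics import average_precision_score

def report_metrics(Y_true: np.ndarray, P: np.ndarray, threshold: float = 0.5):
    """Micro/macro auPRC plus coverage (fraction of proteins with >= 1 predicted label)."""
    micro_auprc = average_precision_score(Y_true, P, average="micro")

    # Macro auPRC only over classes with at least one positive example
    has_pos = Y_true.sum(axis=0) > 0
    macro_auprc = average_precision_score(Y_true[:, has_pos], P[:, has_pos], average="macro")

    coverage = float(((P >= threshold).sum(axis=1) > 0).mean())
    return {"micro_auprc": micro_auprc, "macro_auprc": macro_auprc, "coverage": coverage}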

Ablations to include (table)

  • PLM only vs +kNN homology vs +PPI smoothing.
  • Thresholding: global vs per-class.
  • With/without hierarchy closure.

Bottom line: this pipeline is simple, strong, and extensible; add GNNs (GCN/GAT) or multimodal fusion when the baseline plateaus.


Detailed Implementation

Step 1: Feature Extraction

Homology-based features (BLAST/DIAMOND):

from Bio.Blast import NCBIWWW, NCBIXML

def get_blast_homologs(sequence: str, database: str = "nr", evalue: float = 1e-5):
    """
    Get homologous proteins via BLAST.
    
    Returns:
        homolog_ids: List of UniProt IDs with significant hits
        scores: List of E-values
    """
    result = NCBIWWW.qblast("blastp", database, sequence, expect=evalue)
    blast_record = NCBIXML.read(result)
    
    homolog_ids = []
    scores = []
    for alignment in blast_record.alignments:
        for hsp in alignment.hsps:
            if hsp.expect < evalue:
                homolog_ids.append(alignment.title.split('|')[1])  # accession parsed from the hit title (format depends on the database)
                scores.append(hsp.expect)
    
    return homolog_ids, scores

def transfer_labels_from_homologs(protein_id: str, homolog_ids: list, go_annotations: dict):
    """
    Transfer GO labels from homologous proteins.
    
    Returns:
        transferred_labels: Set of GO terms from homologs
    """
    transferred_labels = set()
    for homolog_id in homolog_ids:
        if homolog_id in go_annotations:
            transferred_labels.update(go_annotations[homolog_id])
    return transferred_labels

PLM embeddings (ESM2):

from transformers import EsmModel, EsmTokenizer
import torch

def get_esm2_embedding(sequence: str, model_name: str = "facebook/esm2_t33_650M_UR50D"):
    """
    Get ESM2 embedding for a protein sequence.
    
    Returns:
        embedding: numpy array [D] of L2-normalized embedding
    """
    # For many sequences, load the tokenizer/model once outside this function
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = EsmTokenizer.from_pretrained(model_name)
    model = EsmModel.from_pretrained(model_name).to(device)
    model.eval()
    
    # Tokenize
    encoded = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(device)
    
    # Get embedding (mean pooling)
    with torch.no_grad():
        outputs = model(**encoded)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
        embedding = embedding / embedding.norm()  # L2 normalize
    
    return embedding.cpu().numpy()

Structure features (DSSP):

from Bio.PDB import PDBParser
from Bio.PDB.DSSP import DSSP

def extract_structure_features(pdb_file: str):
    """
    Extract secondary structure and solvent accessibility from PDB.
    
    Returns:
        features: Dictionary with SS, SA, etc.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_file)
    model = structure[0]
    
    dssp = DSSP(model, pdb_file)
    
    features = {
        'secondary_structure': [],
        'solvent_accessibility': [],
        'phi': [],
        'psi': []
    }
    
    for residue in dssp:
        features['secondary_structure'].append(residue[2])  # DSSP 8-state SS code (H, E, G, I, B, T, S, -)
        features['solvent_accessibility'].append(residue[3])  # RSA
        features['phi'].append(residue[4])
        features['psi'].append(residue[5])
    
    return features

Step 2: Label Smoothing on PPI

One-step label smoothing implementation:

import numpy as np
import scipy.sparse as sp

def label_smoothing_on_ppi(
    P0: np.ndarray,
    A: sp.csr_matrix,
    alpha: float = 0.2
):
    """
    Apply one-step label smoothing on PPI network.
    
    Args:
        P0: [N, C] initial predictions (probabilities)
        A: [N, N] PPI adjacency matrix (row-normalized inside this function)
        alpha: Smoothing coefficient (0.2 = 20% from neighbors)
    
    Returns:
        P1: [N, C] smoothed predictions
    """
    # Normalize adjacency (row-stochastic)
    A_normalized = A.copy()
    row_sums = np.array(A.sum(axis=1)).flatten()
    row_sums[row_sums == 0] = 1.0  # Avoid division by zero
    A_normalized = A_normalized.multiply(1.0 / row_sums[:, np.newaxis])
    
    # One-step smoothing: (1-alpha) * P0 + alpha * A @ P0
    P_smoothed = A_normalized @ P0
    P1 = (1 - alpha) * P0 + alpha * P_smoothed
    
    # Clip to [0, 1]
    P1 = np.clip(P1, 0, 1)
    
    return P1

Why it works:

  • Proteins with similar functions tend to interact (guilt-by-association)
  • Smoothing propagates labels from annotated to unannotated proteins
  • One step avoids oversmoothing (multiple steps can blur signal)

Tuning alpha:

  • alpha = 0.1: Conservative, mostly original predictions
  • alpha = 0.2: Balanced (good default)
  • alpha = 0.3: Aggressive, more neighbor influence
  • alpha > 0.3: Risk of oversmoothing, especially for hub proteins (a validation sweep sketch follows).
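
Rather than guessing, alpha can be chosen by a small validation sweep (a sketch, assuming label_smoothing_on_ppi from above and any scalar validation metric, e.g. the micro auPRC helper sketched earlier):

def select_alpha(P0, A, Y_val, metric_fn, candidates=(0.0, 0.1, 0.2, 0.3)):
    """Pick the smoothing coefficient that maximizes a validation metric."""
    best_alpha, best_score = 0.0, -1.0
    for alpha in candidates:
        P1 = label_smoothing_on_ppi(P0, A, alpha=alpha)
        score = metric_fn(Y_val, P1)
        if score > best_score:
            best_alpha, best_score = alpha, score
    return best_alpha, best_score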

Step 3: Hierarchy Closure

Ancestor closure (from previous article):

def close_under_ancestors(P: np.ndarray, go_dag):
    """
    Make predicted scores consistent with the GO hierarchy.

    Each ancestor's score is raised to at least the maximum score of its
    descendants, so thresholding the closed scores (at any value) yields
    ancestor-closed labels.

    Args:
        P: [N, C] class probabilities
        go_dag: GO DAG object exposing topo_order() and ancestors(term_idx)

    Returns:
        P_closed: [N, C] hierarchy-consistent scores
    """
    P_closed = P.copy()

    # Visit children before parents (reverse topological order) and push
    # each term's score up to all of its ancestors
    for term_idx in reversed(go_dag.topo_order()):
        for ancestor_idx in go_dag.ancestors(term_idx):
            P_closed[:, ancestor_idx] = np.maximum(
                P_closed[:, ancestor_idx], P_closed[:, term_idx]
            )

    return P_closed

Step 4: Threshold Tuning

Per-class threshold optimization:

from sklearn.metrics import f1_score

def tune_per_class_thresholds(
    P: np.ndarray,
    Y_true: np.ndarray,
    metric: str = "f1",
    thresholds: np.ndarray = None
):
    """
    Find optimal per-class thresholds for Fmax.
    
    Returns:
        best_thresholds: [C] array of optimal thresholds per class
        best_fmax: Maximum F1 score achieved
    """
    if thresholds is None:
        thresholds = np.arange(0.05, 0.95, 0.01)

    num_classes = Y_true.shape[1]
    # Default above 1.0 so classes skipped below are never predicted positive
    best_thresholds = np.full(num_classes, 1.01)

    for class_idx in range(num_classes):
        if Y_true[:, class_idx].sum() == 0:
            continue  # Skip classes with no positives in the validation set
        
        best_f1 = 0.0
        best_th = 0.5
        
        for th in thresholds:
            y_pred = (P[:, class_idx] >= th).astype(int)
            f1 = f1_score(Y_true[:, class_idx], y_pred, zero_division=0)
            
            if f1 > best_f1:
                best_f1 = f1
                best_th = th
        
        best_thresholds[class_idx] = best_th
    
    # Protein-centric Fmax with the tuned thresholds (compute_fmax defined elsewhere)
    Y_pred = (P >= best_thresholds).astype(int)
    fmax = compute_fmax(Y_true, Y_pred)
    
    return best_thresholds, fmax

Step 5: Advanced: GNN Implementation

GCN for GO prediction:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCNForGO(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes, num_layers=2):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(input_dim, hidden_dim))
        
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
        
        self.convs.append(GCNConv(hidden_dim, num_classes))
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x, edge_index):
        """
        Args:
            x: [N, D] node features (PLM embeddings)
            edge_index: [2, E] edge indices (PPI network)
        """
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = self.dropout(x)
        
        x = self.convs[-1](x, edge_index)
        return torch.sigmoid(x)  # Multi-label classification

Training loop:

def train_gnn(model, train_loader, val_loader, num_epochs=100):
    """
    Train GCN for GO prediction.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.BCELoss()
    
    best_fmax = 0.0
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)
            loss = criterion(out, batch.y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        # Validation
        model.eval()
        val_predictions = []
        val_labels = []
        with torch.no_grad():
            for batch in val_loader:
                out = model(batch.x, batch.edge_index)
                val_predictions.append(out.cpu().numpy())
                val_labels.append(batch.y.cpu().numpy())
        
        val_predictions = np.vstack(val_predictions)
        val_labels = np.vstack(val_labels)
        
        fmax = compute_fmax(val_labels, val_predictions)
        
        if fmax > best_fmax:
            best_fmax = fmax
            torch.save(model.state_dict(), 'best_model.pth')
        
        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Fmax = {fmax:.4f}")

Step 6: Multimodal Fusion

Late fusion of multiple features:

def multimodal_fusion(
    plm_embeddings: np.ndarray,
    structure_features: np.ndarray,
    homology_scores: np.ndarray,
    ppi_embeddings: np.ndarray
):
    """
    Combine multiple feature modalities.
    
    Returns:
        fused_features: [N, D] combined feature vector
    """
    # Normalize each modality (keepdims preserves the [N, 1] shape for broadcasting)
    plm_norm = plm_embeddings / (np.linalg.norm(plm_embeddings, axis=1, keepdims=True) + 1e-10)
    struct_norm = structure_features / (np.linalg.norm(structure_features, axis=1, keepdims=True) + 1e-10)
    homology_norm = homology_scores / (np.linalg.norm(homology_scores, axis=1, keepdims=True) + 1e-10)
    ppi_norm = ppi_embeddings / (np.linalg.norm(ppi_embeddings, axis=1, keepdims=True) + 1e-10)
    
    # Concatenate or weighted average
    # Option 1: Concatenate
    fused = np.concatenate([plm_norm, struct_norm, homology_norm, ppi_norm], axis=1)
    
    # Option 2: Weighted average (if same dimension)
    # weights = [0.5, 0.2, 0.2, 0.1]  # PLM gets most weight
    # fused = weights[0] * plm_norm + weights[1] * struct_norm + ...
    
    return fused

Early fusion (learned):

class MultimodalFusion(nn.Module):
    def __init__(self, plm_dim, struct_dim, homology_dim, ppi_dim, hidden_dim):
        super().__init__()
        self.plm_proj = nn.Linear(plm_dim, hidden_dim)
        self.struct_proj = nn.Linear(struct_dim, hidden_dim)
        self.homology_proj = nn.Linear(homology_dim, hidden_dim)
        self.ppi_proj = nn.Linear(ppi_dim, hidden_dim)
        
        # Learnable fusion weights
        self.fusion_weights = nn.Parameter(torch.ones(4) / 4)
    
    def forward(self, plm, struct, homology, ppi):
        # Project each modality to same dimension
        plm_proj = F.relu(self.plm_proj(plm))
        struct_proj = F.relu(self.struct_proj(struct))
        homology_proj = F.relu(self.homology_proj(homology))
        ppi_proj = F.relu(self.ppi_proj(ppi))
        
        # Weighted combination
        weights = F.softmax(self.fusion_weights, dim=0)
        fused = (weights[0] * plm_proj + 
                 weights[1] * struct_proj + 
                 weights[2] * homology_proj + 
                 weights[3] * ppi_proj)
        
        return fused

Complete Pipeline Example

End-to-end pipeline:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# 1. Load data
sequences = load_sequences("data/uniprot.fasta")
go_annotations = load_go_annotations("data/goa.tsv")
ppi_network = load_ppi_network("data/string.tsv")

# 2. Generate features
print("Generating PLM embeddings...")
X_plm = np.array([get_esm2_embedding(seq) for seq in sequences])

print("Computing homology features...")
X_homology = np.array([get_blast_homologs(seq) for seq in sequences])

print("Building PPI adjacency...")
A = build_ppi_adjacency(ppi_network, sequences)

# 3. Train baseline
print("Training logistic regression...")
clf = OneVsRestClassifier(
    LogisticRegression(max_iter=4000, class_weight="balanced", solver="lbfgs")
)
clf.fit(X_plm, Y_train)
P0 = clf.predict_proba(X_val)

# 4. Label smoothing (A must be aligned to the validation proteins, same row order as P0)
print("Applying label smoothing...")
alpha = 0.2
P1 = label_smoothing_on_ppi(P0, A, alpha=alpha)
P1 = np.clip(P1, 0, 1)

# 5. Hierarchy closure
print("Enforcing hierarchy consistency...")
P1 = close_under_ancestors(P1, go_dag)

# 6. Threshold tuning
print("Tuning thresholds...")
thresholds, fmax = tune_per_class_thresholds(P1, Y_val)
Y_pred = (P1 >= thresholds).astype(int)

# 7. Evaluate
print("Evaluating...")
fmax_bp, fmax_mf, fmax_cc = compute_fmax_by_ontology(Y_val, Y_pred, go_dag)
micro_auprc, macro_auprc = compute_auprc(Y_val, P1)

print(f"Fmax (BP): {fmax_bp:.3f}")
print(f"Fmax (MF): {fmax_mf:.3f}")
print(f"Fmax (CC): {fmax_cc:.3f}")
print(f"Micro AUPRC: {micro_auprc:.3f}")
print(f"Macro AUPRC: {macro_auprc:.3f}")

When to Graduate to GNNs

Stick with baseline if:

  • Fmax > 0.5 on validation set
  • Training time < 1 hour
  • Simple features (PLM + PPI smoothing) sufficient

Upgrade to GNNs when:

  • Baseline plateaus (Fmax < 0.4)
  • PPI network is dense and informative
  • You have structure data (3D coordinates)
  • You need to model complex protein interactions

GNN architecture selection:

  • GCN: Simple, fast, good for dense graphs
  • GAT: Attention over neighbors; useful when edge importance varies (see the sketch after this list)
  • GraphSAGE: Inductive learning, handles new proteins
  • GIN: More expressive, better for complex patterns
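
As an example of swapping architectures, a sketch of a two-layer GAT using torch_geometric's GATConv in place of GCNConv (the 4-head setting is only illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GATForGO(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes, heads=4):
        super().__init__()
        # Multi-head attention layer followed by a single-head output layer
        self.conv1 = GATConv(input_dim, hidden_dim, heads=heads)
        self.conv2 = GATConv(hidden_dim * heads, num_classes, heads=1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.dropout(x)
        return torch.sigmoid(self.conv2(x, edge_index))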

Ablation Study Template

Compare different approaches:

results = {}

# Baseline: PLM only
clf_plm = train_logistic(X_plm, Y_train)
P_plm = clf_plm.predict_proba(X_val)
results['PLM only'] = evaluate(Y_val, P_plm)

# + Homology
X_combined = np.concatenate([X_plm, X_homology], axis=1)
clf_hom = train_logistic(X_combined, Y_train)
P_hom = clf_hom.predict_proba(X_val)
results['PLM + Homology'] = evaluate(Y_val, P_hom)

# + PPI smoothing
P_smooth = label_smoothing_on_ppi(P_plm, A, alpha=0.2)
results['PLM + PPI smoothing'] = evaluate(Y_val, P_smooth)

# + Hierarchy closure
P_closed = close_under_ancestors(P_smooth, go_dag)
results['PLM + PPI + Hierarchy'] = evaluate(Y_val, P_closed)

# + Per-class thresholds
thresholds, _ = tune_per_class_thresholds(P_closed, Y_val)
Y_pred = (P_closed >= thresholds).astype(int)
results['Full pipeline'] = evaluate(Y_val, Y_pred)

# Print results table
print("\nAblation Results:")
print("=" * 60)
for method, metrics in results.items():
    print(f"{method:30s} Fmax={metrics['fmax']:.3f} AUPRC={metrics['auprc']:.3f}")

Common Pitfalls

Pitfall 1: Oversmoothing

Problem: Too many smoothing steps or high alpha blurs signal.

Solution:

  • Use one-step smoothing (alpha = 0.1-0.2)
  • Monitor performance: if smoothing hurts, reduce alpha
  • Filter or down-weight hub proteins (high degree) in the smoothing step (a normalization sketch follows)
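
One way to limit hub influence is symmetric normalization, D^-1/2 A D^-1/2, instead of the row-stochastic D^-1 A used above (a sketch; which normalization works better is an empirical question):

import numpy as np
import scipy.sparse as sp

def symmetric_normalize(A: sp.csr_matrix) -> sp.csr_matrix:
    """D^-1/2 A D^-1/2: down-weights edges incident to high-degree (hub) proteins."""
    degrees = np.asarray(A.sum(axis=1)).flatten()
    degrees[degrees == 0] = 1.0                      # isolated nodes: avoid division by zero
    d_inv_sqrt = sp.diags(1.0 / np.sqrt(degrees))
    return (d_inv_sqrt @ A @ d_inv_sqrt).tocsr()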

Pitfall 2: Ignoring Hierarchy

Problem: Predictions violate ancestor-child relationships.

Solution:

  • Always apply hierarchy closure
  • Use per-class thresholds (higher for general terms)
  • Report both pre- and post-closure metrics (a violation-count sketch follows)
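
A sketch for counting violations, assuming a list of (child_idx, parent_idx) column-index pairs extracted from the GO DAG:

import numpy as np

def count_hierarchy_violations(Y_pred: np.ndarray, child_parent_pairs) -> int:
    """Count (protein, edge) cases where a child term is predicted but its parent is not."""
    violations = 0
    for child_idx, parent_idx in child_parent_pairs:
        violations += int(np.sum((Y_pred[:, child_idx] == 1) & (Y_pred[:, parent_idx] == 0)))
    return violations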

Pitfall 3: Data Leakage

Problem: The test set contains proteins whose annotations were already available during the training period.

Solution:

  • Strict time-based split (test annotations only after the training cutoff; see the sketch after this list)
  • Remove IEA evidence from both train and test
  • Document split dates in paper/report
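
A sketch of such a split, assuming an annotation table whose records carry a protein ID, an evidence code, and an annotation date (field names are hypothetical; adapt to your GOA export):

from datetime import date

def time_based_split(annotations, cutoff: date, drop_evidence=("IEA",)):
    """Split annotation records at a date cutoff and drop unwanted evidence codes."""
    train, test = [], []
    for record in annotations:
        if record["evidence"] in drop_evidence:
            continue                          # drop electronic annotations everywhere
        (train if record["date"] < cutoff else test).append(record)
    return train, test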

Best Practices Summary

  1. Start simple: PLM + logistic regression baseline
  2. Add smoothing: One-step PPI label smoothing (alpha = 0.2)
  3. Enforce hierarchy: Ancestor closure after thresholding
  4. Tune thresholds: Per-class thresholds for Fmax
  5. Graduate carefully: Only add GNNs when baseline plateaus
  6. Ablate everything: Compare each component’s contribution
  7. Document splits: Time-based splits, evidence filtering

Bottom line: This pipeline is simple, strong, and extensible; add GNNs (GCN/GAT) or multimodal fusion when the baseline plateaus.
