Files
digital-patients/rna2protexpression.py
Olamide Isreal 9e6a16c19b Initial commit: digital-patients pipeline (clean, no large files)
Large reference/model files excluded from repo - to be staged to S3 or baked into Docker images.
2026-03-26 15:15:23 +01:00

134 lines
6.3 KiB
Python

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from axial_attention import AxialAttention
from thefuzz import fuzz
import argparse
from joblib import load
# --- Load frozen artifacts shipped inside the Docker container ---
# NOTE(review): the directory name 'rna2protexpresson' looks like a typo of
# 'rna2protexpression', but the files ship under this exact path — do not "fix".
#load data from docker container
#tissue and ensg preprocessing
# Fitted label encoders mapping tissue names / ENSG gene ids to integer indices
# (used as embedding-table lookups by the model below).
tissue2number = load('/home/omic/rna2protexpresson/tissue2number.joblib')
ensg2number = load('/home/omic/rna2protexpresson/ensg2number.joblib')
#model weights
model_path = '/home/omic/rna2protexpresson/go_term_protein_expression_model.pth'
#get device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#get go terms
# GO-term annotation table indexed by ENSG id; its column count fixes the
# model's go_size and its index is used to filter genes at inference time.
go_df_work = pd.read_csv('/home/omic/rna2protexpresson/go_df_work.csv', index_col=0)
#model
class TissueSpecificProteinExpressionModel(nn.Module):
    """Axial-attention model predicting protein expression from RNA expression.

    Operates on a (gene x tissue) grid: the RNA expression map is augmented
    with per-gene embeddings (gene-ID + GO-term features) and per-tissue
    embeddings, passed through stacked axial attention, and collapsed by a
    1x1 convolution to one predicted value per (gene, tissue) cell.
    """
    def __init__(self, emb_dim = 64, dim_heads = 512, n_tissues = 29, n_ensg = 14601, go_size=go_df_work.shape[1]):
        # emb_dim:  per-axis embedding width; genes get emb_dim/2 channels from
        #           an ID embedding plus emb_dim/2 from GO terms, tissues get emb_dim.
        # go_size:  number of GO-term feature columns (defaults to the loaded
        #           go_df_work table, so the module must be importable after it).
        super(TissueSpecificProteinExpressionModel, self).__init__()
        self.emb_dim = emb_dim
        self.emb_ensg = nn.Embedding(n_ensg, int(emb_dim/2))
        self.emb_tiss = nn.Embedding(n_tissues, emb_dim)
        # First attention pass runs on the raw 1-channel RNA map.
        self.attn1 = AxialAttention(
            dim = 1, # embedding dimension
            dim_index = 1, # where is the embedding dimension
            dim_heads = dim_heads, # dimension of each head. defaults to dim // heads if not supplied
            heads = 1, # number of heads for multi-head attention
            num_dimensions = 2, # number of axial dimensions (images is 2, video is 3, or more)
            sum_axial_out = True # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
        )
        # After concatenation the channel count is emb_dim*2+1
        # (1 RNA channel + emb_dim gene/GO channels + emb_dim tissue channels).
        self.attn2 = AxialAttention(dim = emb_dim*2+1, dim_index = 1, dim_heads = dim_heads, heads = 1, num_dimensions = 2, sum_axial_out = True)
        self.attn3 = AxialAttention(dim = emb_dim*2+1, dim_index = 1, dim_heads = dim_heads, heads = 1, num_dimensions = 2, sum_axial_out = True)
        # 1x1 conv collapses all channels to a single output channel.
        self.con = nn.Conv2d(emb_dim*2+1, 1, 1, stride=1)
        self.batch_norm1 = nn.BatchNorm2d(emb_dim*2+1)
        self.batch_norm2 = nn.BatchNorm2d(emb_dim*2+1)
        self.emb_go = nn.Linear(go_size, int(emb_dim/2))

    def forward(self, x, emb_pos_prot, emb_pos_tissu, go_term):
        """Predict protein expression for every (gene, tissue) cell.

        x:             RNA expression — assumed (batch, n_genes, n_tissues); TODO confirm
        emb_pos_prot:  integer gene indices — assumed (batch, n_genes)
        emb_pos_tissu: integer tissue indices — assumed (batch, n_tissues)
        go_term:       GO-term features — assumed (batch, n_genes, go_size)
        Returns a (batch, n_genes, n_tissues) tensor.
        """
        #embeding go terms: project GO features to emb_dim/2 channels,
        # move channels to dim 1, add a trailing singleton tissue axis
        emb_in_go = self.emb_go(go_term)
        emb_in_go = torch.unsqueeze(torch.permute(emb_in_go, (0,2,1)),-1)
        #embeding proteins: gene-ID embedding supplies the other emb_dim/2 channels
        emb_in_p = self.emb_ensg(emb_pos_prot)
        emb_in_p = torch.unsqueeze(torch.permute(emb_in_p, (0,2,1)),-1)
        emb_in_p = torch.cat([emb_in_go,emb_in_p], dim = 1)
        # broadcast the per-gene embedding across the tissue axis
        emb_in_p = emb_in_p.expand(x.shape[0],self.emb_dim,emb_pos_prot.shape[1],emb_pos_tissu.shape[1])
        #embeding tissues: singleton gene axis, broadcast across genes
        emb_in_t = self.emb_tiss(emb_pos_tissu)
        emb_in_t = torch.unsqueeze(torch.permute(emb_in_t, (0,2,1)),-2)
        emb_in_t = emb_in_t.expand(x.shape[0],self.emb_dim,emb_pos_prot.shape[1],emb_pos_tissu.shape[1])
        #RNA expresson as a single-channel 2D map for the first attention pass
        x = torch.unsqueeze(x, 1)
        x = self.attn1(x)
        #concat protein embedding, tissue embedding and RNA expresson -> emb_dim*2+1 channels
        x = torch.cat([x,emb_in_p, emb_in_t], dim = 1)
        x = self.batch_norm1(x)
        x = self.attn2(x)
        x = self.batch_norm2(x)
        x = self.attn3(x)
        # collapse channels to a single prediction per (gene, tissue) cell
        x = self.con(x)
        x = torch.squeeze(x, 1)
        return x
#run model from pandas dataframe
def run_model(model, X_test, goterm, scaler):
    """Run inference on a list of RNA-expression DataFrames.

    model:   trained TissueSpecificProteinExpressionModel (moved to `device`).
    X_test:  list of DataFrames, each (genes x tissues) of RNA expression.
    goterm:  list of GO-term feature frames aligned row-wise with X_test.
    scaler:  the StandardScaler fitted at training time (from the checkpoint).
    Returns a numpy array of predictions, one slab per element of X_test.
    """
    model.eval()
    # Use the *fitted* scaler: fit_transform here would refit on test data and
    # silently discard the training statistics stored in the checkpoint.
    X_test_scaled = [scaler.transform(np.array(i).reshape(i.shape[0], -1)).reshape(i.shape) for i in X_test]
    # Stack to one ndarray first — FloatTensor on a list of arrays is slow/deprecated.
    X_test_tensor = torch.FloatTensor(np.array(X_test_scaled)).to(device)
    X_test_tissues = [tissue2number.transform(i.columns) for i in X_test]
    # transform, not fit_transform: the gene->index mapping must match the one
    # learned at training time (fit_transform would re-number per chunk).
    X_test_ensg = [ensg2number.transform(i.index) for i in X_test]
    X_test_tissues = torch.IntTensor(np.array(X_test_tissues)).to(device)
    X_test_ensg = torch.IntTensor(np.array(X_test_ensg)).to(device)
    X_go = torch.FloatTensor(np.array(goterm)).to(device)
    test_dataset = TensorDataset(X_test_tensor, X_test_ensg, X_test_tissues, X_go)
    test_loader = DataLoader(test_dataset, batch_size=1)
    preds = []
    with torch.no_grad():
        for batch_X, ensg_numb, tissue_numb, batch_go in test_loader:
            preds.append(model(batch_X, ensg_numb, tissue_numb, batch_go).cpu().numpy())
    # Concatenate every batch so multi-element X_test is fully predicted
    # (the original returned after the first batch only).
    return np.concatenate(preds, axis=0)
def run_save_protrin_esxpression_predicition(TPM, goterm, model, scaler, tissue2number, ensg2number, split_len, name):
    """Predict protein expression from a TPM table and write it to CSV.

    TPM:     DataFrame (genes x tissues), column labels of the form 'prefix:tissue'.
    goterm:  GO-term feature frame aligned row-wise with TPM.
    Writes '<name>_Protein_Expression_log2.csv' with the log2-scale predictions.
    """
    # Strip the 'prefix:' part of every column label, keeping the tissue name.
    TPM.columns = [col.split(':')[1] for col in TPM.columns]
    # Keep only the tissues the encoders/model were trained on.
    TPM = TPM[TPM.columns.intersection(tissue2number.classes_)]
    # Gene filtering is handled upstream via the go_df_work merge, so no
    # ensg2number.classes_ intersection is needed here.
    # The model was trained on log2-transformed TPM values.
    TPM = np.log2(TPM)
    # Run inference in gene chunks of at most `split_len` rows.
    n_chunks = int(np.ceil(TPM.shape[0] / split_len))
    tpm_chunks = np.array_split(TPM, n_chunks)
    go_chunks = np.array_split(goterm, n_chunks)
    chunk_preds = [run_model(model, [t], [g], scaler) for t, g in zip(tpm_chunks, go_chunks)]
    # Chunks stack along the gene axis (axis 1 of the batched output);
    # squeeze drops the batch dimension, then restore labels.
    pred = pd.DataFrame(np.squeeze(np.concatenate(chunk_preds, 1)))
    pred = pred.set_index(TPM.index)
    pred.columns = TPM.columns
    # Persist log2-scale predictions.
    pred.to_csv(f'{name}_Protein_Expression_log2.csv')
# Restore the trained model and the fitted input scaler from one checkpoint.
# NOTE(review): torch.load unpickles arbitrary objects — only ever load trusted
# checkpoints here (weights_only cannot be used since the scaler is pickled in).
checkpoint = torch.load(model_path, map_location=device)
model = TissueSpecificProteinExpressionModel().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# StandardScaler fitted on the training TPM data; reused at inference.
scaler = checkpoint['scaler']
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Predict protein expression from RNA expression')
    # NOTE(review): '--borzoi_ouput' (sic) is the established CLI interface of
    # the pipeline — renaming the flag would break existing callers.
    parser.add_argument('--borzoi_ouput', help='Output from borzoi step', required=True)
    args = parser.parse_args()
    #borzoi output TPM
    TPM = pd.read_csv(args.borzoi_ouput, index_col=0)
    #get output name: strip the '_TPM.csv' suffix from the input path
    out_name = args.borzoi_ouput.split('_TPM.csv')[0]
    # Inner-join ONCE with the GO-term table (the original merged twice).
    # The join both drops genes absent from go_df_work and aligns the GO
    # features row-for-row with the surviving TPM rows.
    n_tpm_cols = TPM.shape[1]
    merged = pd.merge(TPM, go_df_work, left_index=True, right_index=True)
    X_go = merged.iloc[:, n_tpm_cols:]   # GO-term feature columns
    TPM = merged.iloc[:, :n_tpm_cols]    # TPM columns only
    #run and save output to Protein_Expression_log2.csv
    run_save_protrin_esxpression_predicition(TPM, X_go, model, scaler, tissue2number, ensg2number, 2500, out_name)