# File: corto/corto-matrix-combination.py
# (410 lines, 16 KiB, Python)
import pandas as pd
import numpy as np
from scipy import stats
from sklearn.linear_model import LinearRegression
from typing import Dict, List, Optional, Tuple, Literal
import logging
from concurrent.futures import ProcessPoolExecutor
import argparse
import warnings
warnings.filterwarnings('ignore')
def setup_logger(verbose: bool = False) -> logging.Logger:
    """Configure and return the shared 'CortoNetwork' logger.

    Args:
        verbose: emit INFO-level messages when True, otherwise WARNING and up.

    Returns:
        The (singleton) ``logging.Logger`` named 'CortoNetwork'.
    """
    logger = logging.getLogger('CortoNetwork')
    logger.setLevel(logging.INFO if verbose else logging.WARNING)
    # Only attach a handler the first time: getLogger returns the same object
    # on every call, so unconditionally adding a StreamHandler (as the original
    # did) duplicates every message after repeated setup.
    if not logger.handlers:
        handler = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
def load_data(expression_file: str, metabolite_file: str, logger: logging.Logger) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load and preprocess the expression and metabolomics matrices.

    Args:
        expression_file: CSV with 'gene_id' and 'transcript_id' columns plus
            one column per sample.
        metabolite_file: CSV with a 'CCLE_ID' sample column plus one numeric
            column per metabolite (transposed here to metabolite x sample).
        logger: progress logger.

    Returns:
        Tuple of (expression_df, metabolite_df), both features x samples,
        restricted to the samples present in both files (sorted order).

    Raises:
        ValueError: if the two files share no sample identifiers.
    """
    logger.info("Loading expression data...")
    exp_df = pd.read_csv(expression_file)
    # Set multi-index and convert to numeric matrix
    logger.info("Processing expression data...")
    exp_df.set_index(['gene_id', 'transcript_id'], inplace=True)
    exp_df = exp_df.apply(pd.to_numeric, errors='coerce')
    # Flatten the (gene, transcript) multi-index into a single 'gene_transcript' label.
    exp_df.index = [f"{idx[0]}_{idx[1]}" for idx in exp_df.index]
    logger.info(f"Expression matrix shape: {exp_df.shape}")
    # Load metabolite data
    logger.info("Loading metabolomics data...")
    met_df = pd.read_csv(metabolite_file)
    logger.info("Processing metabolomics data...")
    met_df.set_index('CCLE_ID', inplace=True)
    met_df = met_df.select_dtypes(include=[np.number])
    met_df = met_df.T  # metabolites become rows, samples become columns
    logger.info(f"Metabolomics matrix shape: {met_df.shape}")
    # Align samples. Sort the intersection so column order (and therefore any
    # downstream sampling by positional index) is reproducible across runs —
    # a bare list(set & set) is ordered arbitrarily.
    common_samples = sorted(set(exp_df.columns) & set(met_df.columns))
    if not common_samples:
        raise ValueError("No common samples between matrices")
    logger.info(f"Found {len(common_samples)} common samples")
    exp_df = exp_df[common_samples]
    met_df = met_df[common_samples]
    return exp_df, met_df
def remove_zero_variance(df: pd.DataFrame, logger: logging.Logger) -> pd.DataFrame:
    """Drop rows (features) whose variance across samples is not strictly positive.

    Rows whose variance is 0 or NaN (all-constant or all-missing) carry no
    correlation signal and would produce NaN in np.corrcoef.

    Args:
        df: feature x sample matrix.
        logger: progress logger.

    Returns:
        View of ``df`` restricted to rows with variance > 0.
    """
    logger.info(f"Checking variance in matrix of shape {df.shape}")
    # Renamed from `vars`, which shadowed the builtin of the same name.
    row_variance = df.var(axis=1)
    keep = row_variance[row_variance > 0].index
    logger.info(f"Keeping {len(keep)} features with non-zero variance")
    return df.loc[keep]
def p2r(p: float, n: int) -> float:
    """Translate a two-sided p-value into the matching |r| cutoff for n samples."""
    dof = n - 2
    # Critical t for a two-tailed test at level p. stats.t.ppf returns the
    # lower (negative) tail value here; the sign is irrelevant once squared.
    t_crit = stats.t.ppf(p / 2, df=dof, loc=0, scale=1)
    t_sq = t_crit ** 2
    # Standard t <-> r identity: r = sqrt(t^2 / (dof + t^2)).
    return np.sqrt(t_sq / (dof + t_sq))
def calculate_correlations_corto(expression_df: pd.DataFrame,
                                 metabolite_df: pd.DataFrame,
                                 r_threshold: float,
                                 logger: logging.Logger) -> pd.DataFrame:
    """Correlate every gene with every metabolite, keeping matrices separate.

    Args:
        expression_df: gene x sample matrix.
        metabolite_df: metabolite x sample matrix (columns aligned with
            ``expression_df``).
        r_threshold: minimum |Pearson r| for an edge to be kept.
        logger: progress logger.

    Returns:
        DataFrame with columns source, target, correlation, type
        ('gene_metabolite'); empty DataFrame if nothing passes the threshold.
    """
    logger.info("Calculating correlations...")
    # Process genes in chunks so the corrcoef working set stays bounded.
    chunk_size = 1000  # Adjust based on available memory
    n_chunks = int(np.ceil(len(expression_df) / chunk_size))
    met_names = metabolite_df.index.to_numpy()
    edges = []
    for i in range(n_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, len(expression_df))
        logger.info(f"Processing chunk {i+1}/{n_chunks}")
        exp_chunk = expression_df.iloc[start_idx:end_idx]
        n_genes = exp_chunk.shape[0]
        # np.corrcoef stacks both inputs; the top-right block is gene-vs-metabolite r.
        corr = np.corrcoef(exp_chunk, metabolite_df)[:n_genes, n_genes:]
        gene_names = exp_chunk.index.to_numpy()
        # Vectorized thresholding replaces the original O(genes*metabolites)
        # Python loop with per-cell .loc lookups; row-major order of np.nonzero
        # matches the original gene-then-metabolite iteration order.
        rows, cols = np.nonzero(np.abs(corr) >= r_threshold)
        for r_idx, c_idx in zip(rows, cols):
            edges.append({
                'source': gene_names[r_idx],
                'target': met_names[c_idx],
                'correlation': corr[r_idx, c_idx],
                'type': 'gene_metabolite'
            })
    logger.info(f"Found {len(edges)} significant correlations")
    return pd.DataFrame(edges)
def calculate_correlations_combined(expression_df: pd.DataFrame,
                                    metabolite_df: pd.DataFrame,
                                    r_threshold: float,
                                    logger: logging.Logger) -> pd.DataFrame:
    """Correlate all features against all features on one stacked matrix.

    Genes are prefixed 'GENE_' and metabolites 'MET_' so edge endpoints can be
    classified. Each unordered pair is emitted once (lexicographic dedup).

    Args:
        expression_df: gene x sample matrix.
        metabolite_df: metabolite x sample matrix (columns aligned).
        r_threshold: minimum |Pearson r| for an edge to be kept.
        logger: progress logger.

    Returns:
        DataFrame with columns source, target, correlation, type
        ('gene_gene' | 'metabolite_metabolite' | 'gene_metabolite').
    """
    logger.info("Combining matrices...")
    # Add prefixes and combine
    exp_prefixed = expression_df.copy()
    exp_prefixed.index = 'GENE_' + exp_prefixed.index
    met_prefixed = metabolite_df.copy()
    met_prefixed.index = 'MET_' + met_prefixed.index
    combined_df = pd.concat([exp_prefixed, met_prefixed])
    logger.info("Calculating correlations...")
    edges = []
    chunk_size = 1000
    n_chunks = int(np.ceil(len(combined_df) / chunk_size))
    all_names = combined_df.index.to_numpy()
    for i in range(n_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, len(combined_df))
        logger.info(f"Processing chunk {i+1}/{n_chunks}")
        chunk = combined_df.iloc[start_idx:end_idx]
        n_rows = chunk.shape[0]
        corr = np.corrcoef(chunk, combined_df)[:n_rows, n_rows:]
        chunk_names = chunk.index.to_numpy()
        # Vectorized thresholding; np.nonzero's row-major order matches the
        # original nested source/target iteration.
        rows, cols = np.nonzero(np.abs(corr) >= r_threshold)
        for r_idx, c_idx in zip(rows, cols):
            source = chunk_names[r_idx]
            target = all_names[c_idx]
            if source < target:  # keep upper triangle only: one edge per unordered pair
                # startswith is the correct prefix test (the original used
                # substring containment); `edge_type` avoids shadowing the
                # builtin `type`.
                if source.startswith('GENE_') and target.startswith('GENE_'):
                    edge_type = 'gene_gene'
                elif source.startswith('MET_') and target.startswith('MET_'):
                    edge_type = 'metabolite_metabolite'
                else:
                    edge_type = 'gene_metabolite'
                edges.append({
                    'source': source,
                    'target': target,
                    'correlation': corr[r_idx, c_idx],
                    'type': edge_type
                })
    logger.info(f"Found {len(edges)} significant correlations")
    return pd.DataFrame(edges)
def bootstrap_network(expression_df: pd.DataFrame,
                      metabolite_df: pd.DataFrame,
                      r_threshold: float,
                      seed: int,
                      logger: logging.Logger) -> List[str]:
    """Run one bootstrap iteration and return the winning edge labels.

    Samples the columns (samples) with replacement, recomputes gene-metabolite
    correlations, and for each metabolite keeps the single gene with the
    strongest |r| above the threshold.

    Args:
        expression_df: gene x sample matrix.
        metabolite_df: metabolite x sample matrix (columns aligned).
        r_threshold: minimum |Pearson r| for an edge to be considered.
        seed: RNG seed so each bootstrap iteration is reproducible.
        logger: progress logger (unused here, kept for a uniform worker signature).

    Returns:
        List of "source_target" strings, one winner per metabolite that had
        at least one passing edge.
    """
    np.random.seed(seed)
    # Sample columns with replacement — the same resampling applied to both matrices.
    sample_idx = np.random.choice(
        expression_df.shape[1],
        size=expression_df.shape[1],
        replace=True
    )
    boot_expression = expression_df.iloc[:, sample_idx]
    boot_metabolite = metabolite_df.iloc[:, sample_idx]
    met_names = boot_metabolite.index.to_numpy()
    edges = []
    chunk_size = 1000  # Process in chunks to save memory
    n_chunks = int(np.ceil(len(boot_expression) / chunk_size))
    for i in range(n_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, len(boot_expression))
        exp_chunk = boot_expression.iloc[start_idx:end_idx]
        n_genes = exp_chunk.shape[0]
        # Top-right block of the stacked corrcoef = gene-vs-metabolite r.
        corr = np.corrcoef(exp_chunk, boot_metabolite)[:n_genes, n_genes:]
        gene_names = exp_chunk.index.to_numpy()
        # Vectorized thresholding instead of the per-cell Python double loop.
        rows, cols = np.nonzero(np.abs(corr) >= r_threshold)
        for r_idx, c_idx in zip(rows, cols):
            edges.append({
                'source': gene_names[r_idx],
                'target': met_names[c_idx],
                'correlation': corr[r_idx, c_idx]
            })
    if not edges:
        return []
    edge_df = pd.DataFrame(edges)
    # One winner per target: the edge with maximal |r| (idxmax keeps the first
    # row on ties, matching the original per-target loop).
    winner_idx = edge_df['correlation'].abs().groupby(edge_df['target']).idxmax()
    winner_rows = edge_df.loc[winner_idx]
    return [f"{row['source']}_{row['target']}" for _, row in winner_rows.iterrows()]
def main(args):
    """Drive the full corto pipeline: load, filter, correlate, bootstrap, save.

    Writes two artifacts into the working directory:
    ``corto_network_<mode>.csv`` (edge table with bootstrap support) and
    ``corto_regulon_<mode>.txt`` (per-source regulon listing).

    NOTE(review): in 'combined' mode the bootstrap step submits
    ``bootstrap_network_combined`` to the process pool, but that function is
    not defined anywhere in this view of the file — confirm it exists
    elsewhere, otherwise combined mode raises NameError at the bootstrap stage.
    """
    # Setup logging
    logger = setup_logger(args.verbose)
    logger.info(f"Starting corto network analysis in {args.mode} mode...")
    try:
        # Load data
        expression_df, metabolite_df = load_data(args.expression_file, args.metabolomics_file, logger)
        # Remove zero variance features
        expression_df = remove_zero_variance(expression_df, logger)
        metabolite_df = remove_zero_variance(metabolite_df, logger)
        # Calculate correlation threshold. After load_data, columns are
        # samples, so len(metabolite_df.columns) is the sample count n for p2r.
        r_threshold = p2r(args.p_threshold, len(metabolite_df.columns))
        logger.info(f"Using correlation threshold: {r_threshold}")
        # Calculate initial correlations based on mode
        if args.mode == 'corto':
            edge_df = calculate_correlations_corto(
                expression_df,
                metabolite_df,
                r_threshold,
                logger
            )
        else:
            edge_df = calculate_correlations_combined(
                expression_df,
                metabolite_df,
                r_threshold,
                logger
            )
        # Store valid pairs for bootstrapping: bootstrap winners are only
        # counted if the pair was significant on the full data.
        valid_pairs = set([f"{row['source']}_{row['target']}" for _, row in edge_df.iterrows()])
        # Initialize occurrence tracking using valid pairs
        occurrences = pd.DataFrame({
            'source': edge_df['source'],
            'target': edge_df['target'],
            'correlation': edge_df['correlation'],
            'type': edge_df['type'],  # both correlation functions emit a 'type' column
            'occurrences': 0
        })
        # Index by the same "source_target" key the bootstrap winners use.
        occurrences.index = occurrences['source'] + '_' + occurrences['target']
        # Run bootstraps
        logger.info(f"Running {args.nbootstraps} bootstraps...")
        with ProcessPoolExecutor(max_workers=args.nthreads) as executor:
            # One task per bootstrap; the iteration number doubles as RNG seed.
            futures = [
                executor.submit(
                    bootstrap_network if args.mode == 'corto' else bootstrap_network_combined,
                    expression_df,
                    metabolite_df,
                    r_threshold,
                    i,
                    logger
                )
                for i in range(args.nbootstraps)
            ]
            bootstrap_winners = []
            for future in futures:
                # Only keep winners that were in original valid pairs
                winners = future.result()
                valid_winners = [w for w in winners if w in valid_pairs]
                bootstrap_winners.extend(valid_winners)
        # Update occurrences: count how often each pair won across bootstraps.
        winner_counts = pd.Series(bootstrap_winners).value_counts()
        occurrences.loc[winner_counts.index, 'occurrences'] += winner_counts
        # Calculate final likelihoods (fraction of bootstraps the edge won)
        occurrences['likelihood'] = occurrences['occurrences'] / args.nbootstraps
        # Create regulon object: per-source dict of target->correlation and
        # target->likelihood maps (viper-style tfmode/likelihood layout).
        regulon = {}
        for source in occurrences['source'].unique():
            source_edges = occurrences[occurrences['source'] == source]
            if args.mode == 'corto':
                regulon[source] = {
                    'tfmode': dict(zip(source_edges['target'], source_edges['correlation'])),
                    'likelihood': dict(zip(source_edges['target'], source_edges['likelihood']))
                }
            else:
                # For combined mode, include edge types
                regulon[source] = {
                    'tfmode': dict(zip(source_edges['target'], source_edges['correlation'])),
                    'likelihood': dict(zip(source_edges['target'], source_edges['likelihood'])),
                    'edge_types': dict(zip(source_edges['target'], source_edges['type']))
                }
        # Save results
        logger.info("Saving results...")
        # Save network with additional stats
        network_file = f'corto_network_{args.mode}.csv'
        regulon_file = f'corto_regulon_{args.mode}.txt'
        occurrences['support'] = occurrences['occurrences'] / args.nbootstraps
        occurrences['abs_correlation'] = abs(occurrences['correlation'])
        # Remove prefixes if in combined mode
        if args.mode == 'combined':
            occurrences['source'] = occurrences['source'].str.replace('GENE_', '').str.replace('MET_', '')
            occurrences['target'] = occurrences['target'].str.replace('GENE_', '').str.replace('MET_', '')
        occurrences.sort_values('abs_correlation', ascending=False).to_csv(network_file)
        # Save regulon with pretty formatting
        with open(regulon_file, 'w') as f:
            f.write(f"# Corto Regulon Analysis\n")
            f.write(f"# Mode: {args.mode}\n")
            f.write(f"# Parameters:\n")
            f.write(f"# p-threshold: {args.p_threshold}\n")
            f.write(f"# bootstraps: {args.nbootstraps}\n")
            f.write(f"# edges found: {len(occurrences)}\n\n")
            for source, data in regulon.items():
                source_name = source.replace('GENE_', '').replace('MET_', '') if args.mode == 'combined' else source
                f.write(f"\n{source_name}:\n")
                for key, values in data.items():
                    f.write(f" {key}:\n")
                    if key == 'edge_types':
                        # Edge types are plain strings, no numeric formatting.
                        for target, value in values.items():
                            target_name = target.replace('GENE_', '').replace('MET_', '')
                            f.write(f" {target_name}: {value}\n")
                    else:
                        # Numeric maps are printed strongest-|value| first.
                        sorted_items = sorted(values.items(), key=lambda x: abs(x[1]), reverse=True)
                        for target, value in sorted_items:
                            target_name = target.replace('GENE_', '').replace('MET_', '') if args.mode == 'combined' else target
                            f.write(f" {target_name}: {value:.4f}\n")
        logger.info("Analysis complete!")
        if args.mode == 'corto':
            logger.info(f"Found {len(occurrences)} significant gene-metabolite relationships")
        else:
            relationship_counts = occurrences['type'].value_counts()
            for rel_type, count in relationship_counts.items():
                logger.info(f"Found {count} significant {rel_type} relationships")
        logger.info(f"Results saved to {network_file} and {regulon_file}")
    except Exception as e:
        logger.error(f"Error during analysis: {str(e)}")
        raise
if __name__ == "__main__":
    # Command-line entry point: build the option parser and hand off to main().
    cli = argparse.ArgumentParser(description='Run corto network analysis')
    cli.add_argument('--mode', choices=['corto', 'combined'], default='corto',
                     help='Analysis mode - either corto or combined (default: corto)')
    # The two input files are mandatory.
    cli.add_argument('--expression_file', required=True,
                     help='Path to expression data file')
    cli.add_argument('--metabolomics_file', required=True,
                     help='Path to metabolomics data file')
    # Tuning knobs with sensible defaults.
    cli.add_argument('--p_threshold', type=float, default=1e-30,
                     help='P-value threshold')
    cli.add_argument('--nbootstraps', type=int, default=100,
                     help='Number of bootstrap iterations')
    cli.add_argument('--nthreads', type=int, default=4,
                     help='Number of parallel threads')
    cli.add_argument('--verbose', action='store_true',
                     help='Print verbose output')
    main(cli.parse_args())