"""Corto-style correlation network analysis between gene expression and metabolomics.

Two modes:
  * ``corto``    — correlate genes vs. metabolites only (two separate matrices).
  * ``combined`` — stack both matrices (with ``GENE_`` / ``MET_`` prefixes) and
    correlate everything against everything (gene-gene, met-met, gene-met edges).

Edges passing an |r| threshold (derived from a p-value via the t-distribution)
are bootstrapped: for each bootstrap resample of the samples, the strongest
edge per target is a "winner", and the fraction of bootstraps in which an edge
wins becomes its likelihood.
"""
import pandas as pd
import numpy as np
from scipy import stats
from typing import List, Tuple
import logging
from concurrent.futures import ProcessPoolExecutor
import argparse
import warnings

# NOTE(review): blanket warning suppression hides e.g. np.corrcoef divide-by-zero
# warnings on constant rows; kept to preserve original behavior.
warnings.filterwarnings('ignore')


def setup_logger(verbose: bool = False) -> logging.Logger:
    """Configure and return the 'CortoNetwork' logger.

    INFO level when *verbose*, WARNING otherwise. A console handler is attached
    only once, so repeated calls do not duplicate log output.
    """
    logger = logging.getLogger('CortoNetwork')
    logger.setLevel(logging.INFO if verbose else logging.WARNING)
    if not logger.handlers:  # avoid stacking handlers on repeated setup calls
        handler = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


def load_data(expression_file: str,
              metabolite_file: str,
              logger: logging.Logger) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load expression and metabolomics CSVs and align them on shared samples.

    The expression file is expected to have 'gene_id' and 'transcript_id'
    columns (collapsed into a single 'gene_transcript' row index) with samples
    as the remaining columns. The metabolite file is expected to have a
    'CCLE_ID' sample column and metabolites as numeric columns; it is
    transposed so rows are metabolites and columns are samples.

    Returns (expression_df, metabolite_df) restricted to the common samples.
    Raises ValueError if the two matrices share no sample columns.
    """
    logger.info("Loading expression data...")
    exp_df = pd.read_csv(expression_file)

    logger.info("Processing expression data...")
    exp_df.set_index(['gene_id', 'transcript_id'], inplace=True)
    exp_df = exp_df.apply(pd.to_numeric, errors='coerce')
    # Flatten the MultiIndex into "gene_transcript" string ids.
    exp_df.index = [f"{idx[0]}_{idx[1]}" for idx in exp_df.index]
    logger.info(f"Expression matrix shape: {exp_df.shape}")

    logger.info("Loading metabolomics data...")
    met_df = pd.read_csv(metabolite_file)

    logger.info("Processing metabolomics data...")
    met_df.set_index('CCLE_ID', inplace=True)
    met_df = met_df.select_dtypes(include=[np.number])
    met_df = met_df.T  # rows: metabolites, columns: samples

    logger.info(f"Metabolomics matrix shape: {met_df.shape}")

    common_samples = list(set(exp_df.columns) & set(met_df.columns))
    if not common_samples:
        raise ValueError("No common samples between matrices")

    logger.info(f"Found {len(common_samples)} common samples")
    exp_df = exp_df[common_samples]
    met_df = met_df[common_samples]

    return exp_df, met_df


def remove_zero_variance(df: pd.DataFrame, logger: logging.Logger) -> pd.DataFrame:
    """Remove features (rows) with zero variance across samples."""
    logger.info(f"Checking variance in matrix of shape {df.shape}")
    variances = df.var(axis=1)  # renamed from 'vars' (shadowed builtin)
    keep = variances[variances > 0].index
    logger.info(f"Keeping {len(keep)} features with non-zero variance")
    return df.loc[keep]


def p2r(p: float, n: int) -> float:
    """Convert a two-sided p-value to the corresponding |r| threshold.

    Inverts the t-test for a Pearson correlation with n-2 degrees of freedom:
    t = ppf(p/2) is negative, but only t**2 enters the formula, so the result
    is the positive correlation magnitude at significance level *p* for *n*
    samples.
    """
    t = stats.t.ppf(p / 2, df=n - 2, loc=0, scale=1)
    r = np.sqrt((t ** 2) / (n - 2 + t ** 2))
    return r


def _threshold_edges(corr_block: pd.DataFrame, r_threshold: float) -> List[dict]:
    """Return [{'source', 'target', 'correlation'}] for all |r| >= threshold.

    Vectorized replacement for a row-by-row Python scan; NaN correlations are
    excluded (abs(NaN) >= thr is False), matching the original element-wise
    comparison. Row-major np.where order matches the original loop order.
    """
    values = corr_block.values
    rows, cols = np.where(np.abs(values) >= r_threshold)
    sources = corr_block.index
    targets = corr_block.columns
    return [
        {'source': sources[r], 'target': targets[c], 'correlation': values[r, c]}
        for r, c in zip(rows, cols)
    ]


def calculate_correlations_corto(expression_df: pd.DataFrame,
                                 metabolite_df: pd.DataFrame,
                                 r_threshold: float,
                                 logger: logging.Logger) -> pd.DataFrame:
    """Gene-vs-metabolite correlations only (corto approach), chunked by genes.

    Returns a DataFrame with columns source/target/correlation/type where
    type is always 'gene_metabolite'.
    """
    logger.info("Calculating correlations...")

    chunk_size = 1000  # genes per chunk; bounds the corrcoef matrix size
    n_chunks = int(np.ceil(len(expression_df) / chunk_size))

    edges = []
    for i in range(n_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, len(expression_df))
        logger.info(f"Processing chunk {i+1}/{n_chunks}")

        exp_chunk = expression_df.iloc[start_idx:end_idx]

        # np.corrcoef stacks both inputs; the off-diagonal block
        # [:n_genes, n_genes:] is genes-by-metabolites.
        chunk_corr = pd.DataFrame(
            np.corrcoef(exp_chunk, metabolite_df)[
                :exp_chunk.shape[0], exp_chunk.shape[0]:
            ],
            index=exp_chunk.index,
            columns=metabolite_df.index
        )

        for edge in _threshold_edges(chunk_corr, r_threshold):
            edge['type'] = 'gene_metabolite'
            edges.append(edge)

        del chunk_corr  # free the chunk correlation matrix promptly

    logger.info(f"Found {len(edges)} significant correlations")
    return pd.DataFrame(edges)


def calculate_correlations_combined(expression_df: pd.DataFrame,
                                    metabolite_df: pd.DataFrame,
                                    r_threshold: float,
                                    logger: logging.Logger) -> pd.DataFrame:
    """All-vs-all correlations on the stacked (prefixed) matrix, chunked.

    Rows are prefixed 'GENE_'/'MET_' so edge types can be recovered. Only the
    upper triangle (source < target, lexicographically) is kept, so each pair
    appears once and self-correlations are excluded.
    """
    logger.info("Combining matrices...")

    exp_prefixed = expression_df.copy()
    exp_prefixed.index = 'GENE_' + exp_prefixed.index
    met_prefixed = metabolite_df.copy()
    met_prefixed.index = 'MET_' + met_prefixed.index
    combined_df = pd.concat([exp_prefixed, met_prefixed])

    logger.info("Calculating correlations...")

    edges = []
    chunk_size = 1000
    n_chunks = int(np.ceil(len(combined_df) / chunk_size))

    for i in range(n_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, len(combined_df))
        logger.info(f"Processing chunk {i+1}/{n_chunks}")

        chunk = combined_df.iloc[start_idx:end_idx]
        chunk_corr = pd.DataFrame(
            np.corrcoef(chunk, combined_df)[
                :chunk.shape[0], chunk.shape[0]:
            ],
            index=chunk.index,
            columns=combined_df.index
        )

        for source in chunk_corr.index:
            for target in chunk_corr.columns:
                if source < target:  # upper triangle: each pair counted once
                    corr = chunk_corr.loc[source, target]
                    if abs(corr) >= r_threshold:
                        # renamed from 'type' (shadowed builtin)
                        edge_type = 'gene_gene' if 'GENE_' in source and 'GENE_' in target else \
                            'metabolite_metabolite' if 'MET_' in source and 'MET_' in target else \
                            'gene_metabolite'
                        edges.append({
                            'source': source,
                            'target': target,
                            'correlation': corr,
                            'type': edge_type
                        })

        del chunk_corr

    logger.info(f"Found {len(edges)} significant correlations")
    return pd.DataFrame(edges)


def _pick_winners(edges: List[dict]) -> List[str]:
    """For each target, return 'source_target' for its max-|r| edge."""
    winners = []
    edge_df = pd.DataFrame(edges)
    if not edge_df.empty:
        for target in edge_df['target'].unique():
            target_edges = edge_df[edge_df['target'] == target]
            winner = target_edges.loc[target_edges['correlation'].abs().idxmax()]
            winners.append(f"{winner['source']}_{winner['target']}")
    return winners


def bootstrap_network(expression_df: pd.DataFrame,
                      metabolite_df: pd.DataFrame,
                      r_threshold: float,
                      seed: int,
                      logger: logging.Logger) -> List[str]:
    """One bootstrap iteration for corto mode.

    Resamples the sample columns with replacement (seeded for reproducibility
    across processes), recomputes gene-vs-metabolite correlations, and returns
    the 'source_target' winner (strongest |r|) for each metabolite target.
    """
    np.random.seed(seed)
    sample_idx = np.random.choice(
        expression_df.shape[1],
        size=expression_df.shape[1],
        replace=True
    )
    boot_expression = expression_df.iloc[:, sample_idx]
    boot_metabolite = metabolite_df.iloc[:, sample_idx]

    edges = []
    chunk_size = 1000
    n_chunks = int(np.ceil(len(boot_expression) / chunk_size))

    for i in range(n_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, len(boot_expression))
        exp_chunk = boot_expression.iloc[start_idx:end_idx]

        chunk_corr = pd.DataFrame(
            np.corrcoef(exp_chunk, boot_metabolite)[
                :exp_chunk.shape[0], exp_chunk.shape[0]:
            ],
            index=exp_chunk.index,
            columns=boot_metabolite.index
        )
        edges.extend(_threshold_edges(chunk_corr, r_threshold))
        del chunk_corr

    return _pick_winners(edges)


def bootstrap_network_combined(expression_df: pd.DataFrame,
                               metabolite_df: pd.DataFrame,
                               r_threshold: float,
                               seed: int,
                               logger: logging.Logger) -> List[str]:
    """One bootstrap iteration for combined mode.

    BUG FIX: this function was referenced by main() for --mode combined but
    never defined, so combined-mode runs crashed with NameError at bootstrap
    time. Mirrors bootstrap_network on the prefixed stacked matrix; winner
    keys keep the 'GENE_'/'MET_' prefixes so they match main()'s valid_pairs
    (which are built before prefix stripping). Only upper-triangle pairs
    (source < target) are considered, matching calculate_correlations_combined.
    """
    np.random.seed(seed)
    sample_idx = np.random.choice(
        expression_df.shape[1],
        size=expression_df.shape[1],
        replace=True
    )

    exp_prefixed = expression_df.iloc[:, sample_idx].copy()
    exp_prefixed.index = 'GENE_' + exp_prefixed.index
    met_prefixed = metabolite_df.iloc[:, sample_idx].copy()
    met_prefixed.index = 'MET_' + met_prefixed.index
    combined_df = pd.concat([exp_prefixed, met_prefixed])

    edges = []
    chunk_size = 1000
    n_chunks = int(np.ceil(len(combined_df) / chunk_size))

    for i in range(n_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, len(combined_df))
        chunk = combined_df.iloc[start_idx:end_idx]

        chunk_corr = pd.DataFrame(
            np.corrcoef(chunk, combined_df)[
                :chunk.shape[0], chunk.shape[0]:
            ],
            index=chunk.index,
            columns=combined_df.index
        )
        for source in chunk_corr.index:
            for target in chunk_corr.columns:
                if source < target:
                    corr = chunk_corr.loc[source, target]
                    if abs(corr) >= r_threshold:
                        edges.append({
                            'source': source,
                            'target': target,
                            'correlation': corr
                        })
        del chunk_corr

    return _pick_winners(edges)


def main(args):
    """Run the full pipeline: load, filter, correlate, bootstrap, save."""
    logger = setup_logger(args.verbose)
    logger.info(f"Starting corto network analysis in {args.mode} mode...")

    try:
        expression_df, metabolite_df = load_data(args.expression_file,
                                                 args.metabolomics_file,
                                                 logger)

        expression_df = remove_zero_variance(expression_df, logger)
        metabolite_df = remove_zero_variance(metabolite_df, logger)

        # n = number of samples (columns) determines the r threshold.
        r_threshold = p2r(args.p_threshold, len(metabolite_df.columns))
        logger.info(f"Using correlation threshold: {r_threshold}")

        if args.mode == 'corto':
            edge_df = calculate_correlations_corto(
                expression_df, metabolite_df, r_threshold, logger
            )
        else:
            edge_df = calculate_correlations_combined(
                expression_df, metabolite_df, r_threshold, logger
            )

        # Guard: previously an empty edge_df produced an opaque KeyError below.
        if edge_df.empty:
            raise ValueError(
                "No correlations passed the threshold; "
                "try a less stringent --p_threshold"
            )

        # Edge keys (prefixed in combined mode) used to filter bootstrap winners.
        valid_pairs = set([f"{row['source']}_{row['target']}"
                           for _, row in edge_df.iterrows()])

        occurrences = pd.DataFrame({
            'source': edge_df['source'],
            'target': edge_df['target'],
            'correlation': edge_df['correlation'],
            'type': edge_df['type'],
            'occurrences': 0
        })
        occurrences.index = occurrences['source'] + '_' + occurrences['target']

        logger.info(f"Running {args.nbootstraps} bootstraps...")
        with ProcessPoolExecutor(max_workers=args.nthreads) as executor:
            futures = [
                executor.submit(
                    bootstrap_network if args.mode == 'corto'
                    else bootstrap_network_combined,
                    expression_df,
                    metabolite_df,
                    r_threshold,
                    i,
                    logger
                )
                for i in range(args.nbootstraps)
            ]

            bootstrap_winners = []
            for future in futures:
                # Only keep winners that were in original valid pairs.
                winners = future.result()
                valid_winners = [w for w in winners if w in valid_pairs]
                bootstrap_winners.extend(valid_winners)

        winner_counts = pd.Series(bootstrap_winners).value_counts()
        occurrences.loc[winner_counts.index, 'occurrences'] += winner_counts

        occurrences['likelihood'] = occurrences['occurrences'] / args.nbootstraps

        # Build the regulon: per source, the signed correlations ('tfmode')
        # and bootstrap likelihoods of its targets.
        regulon = {}
        for source in occurrences['source'].unique():
            source_edges = occurrences[occurrences['source'] == source]
            if args.mode == 'corto':
                regulon[source] = {
                    'tfmode': dict(zip(source_edges['target'],
                                       source_edges['correlation'])),
                    'likelihood': dict(zip(source_edges['target'],
                                           source_edges['likelihood']))
                }
            else:
                regulon[source] = {
                    'tfmode': dict(zip(source_edges['target'],
                                       source_edges['correlation'])),
                    'likelihood': dict(zip(source_edges['target'],
                                           source_edges['likelihood'])),
                    'edge_types': dict(zip(source_edges['target'],
                                           source_edges['type']))
                }

        logger.info("Saving results...")
        network_file = f'corto_network_{args.mode}.csv'
        regulon_file = f'corto_regulon_{args.mode}.txt'

        occurrences['support'] = occurrences['occurrences'] / args.nbootstraps
        occurrences['abs_correlation'] = abs(occurrences['correlation'])

        # Strip prefixes only in the CSV output (regulon keys keep them until
        # written below).
        if args.mode == 'combined':
            occurrences['source'] = occurrences['source'].str.replace('GENE_', '').str.replace('MET_', '')
            occurrences['target'] = occurrences['target'].str.replace('GENE_', '').str.replace('MET_', '')

        occurrences.sort_values('abs_correlation', ascending=False).to_csv(network_file)

        with open(regulon_file, 'w') as f:
            f.write(f"# Corto Regulon Analysis\n")
            f.write(f"# Mode: {args.mode}\n")
            f.write(f"# Parameters:\n")
            f.write(f"#   p-threshold: {args.p_threshold}\n")
            f.write(f"#   bootstraps: {args.nbootstraps}\n")
            f.write(f"#   edges found: {len(occurrences)}\n\n")

            for source, data in regulon.items():
                source_name = source.replace('GENE_', '').replace('MET_', '') if args.mode == 'combined' else source
                f.write(f"\n{source_name}:\n")
                for key, values in data.items():
                    f.write(f"  {key}:\n")
                    if key == 'edge_types':
                        for target, value in values.items():
                            target_name = target.replace('GENE_', '').replace('MET_', '')
                            f.write(f"    {target_name}: {value}\n")
                    else:
                        sorted_items = sorted(values.items(),
                                              key=lambda x: abs(x[1]),
                                              reverse=True)
                        for target, value in sorted_items:
                            target_name = target.replace('GENE_', '').replace('MET_', '') if args.mode == 'combined' else target
                            f.write(f"    {target_name}: {value:.4f}\n")

        logger.info("Analysis complete!")
        if args.mode == 'corto':
            logger.info(f"Found {len(occurrences)} significant gene-metabolite relationships")
        else:
            relationship_counts = occurrences['type'].value_counts()
            for rel_type, count in relationship_counts.items():
                logger.info(f"Found {count} significant {rel_type} relationships")
        logger.info(f"Results saved to {network_file} and {regulon_file}")

    except Exception as e:
        logger.error(f"Error during analysis: {str(e)}")
        raise


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run corto network analysis')
    parser.add_argument('--mode',
                        choices=['corto', 'combined'],
                        default='corto',
                        help='Analysis mode - either corto or combined (default: corto)')
    parser.add_argument('--expression_file',
                        required=True,
                        help='Path to expression data file')
    parser.add_argument('--metabolomics_file',
                        required=True,
                        help='Path to metabolomics data file')
    parser.add_argument('--p_threshold',
                        type=float,
                        default=1e-30,
                        help='P-value threshold')
    parser.add_argument('--nbootstraps',
                        type=int,
                        default=100,
                        help='Number of bootstrap iterations')
    parser.add_argument('--nthreads',
                        type=int,
                        default=4,
                        help='Number of parallel threads')
    parser.add_argument('--verbose',
                        action='store_true',
                        help='Print verbose output')
    args = parser.parse_args()
    main(args)