Some checks failed
Code Quality Main / code-quality (push) Has been cancelled
Release Drafter / update_release_draft (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.10) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.8) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.9) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.10) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.8) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.9) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.10) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.8) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.9) (push) Has been cancelled
Tests / code-coverage (push) Has been cancelled
154 lines
5.3 KiB
Python
154 lines
5.3 KiB
Python
import warnings
|
|
from importlib.util import find_spec
|
|
|
|
import rootutils
|
|
from beartype.typing import Any, Callable, Dict, List, Optional, Tuple
|
|
from omegaconf import DictConfig
|
|
|
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
|
|
from flowdock.utils import pylogger, rich_utils
|
|
|
|
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
|
|
|
|
|
def extras(cfg: DictConfig) -> None:
    """Runs the optional pre-task utilities enabled under `cfg.extras`.

    Each utility is switched on by a truthy flag in `cfg.extras`:

    - `ignore_warnings`: silence Python warnings via `warnings.filterwarnings("ignore")`
    - `enforce_tags`: call `rich_utils.enforce_tags` (setting tags from command line)
    - `print_config`: call `rich_utils.print_config_tree` (Rich config printing)

    If `cfg` has no (or an empty) `extras` entry, a warning is logged and nothing else happens.

    :param cfg: A DictConfig object containing the config tree.
    """
    # nothing to do without an `extras` section
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # (flag name, log message, action) -- processed strictly in this order
    utilities = (
        (
            "ignore_warnings",
            "Disabling python warnings! <cfg.extras.ignore_warnings=True>",
            lambda: warnings.filterwarnings("ignore"),
        ),
        (
            "enforce_tags",
            "Enforcing tags! <cfg.extras.enforce_tags=True>",
            lambda: rich_utils.enforce_tags(cfg, save_to_file=True),
        ),
        (
            "print_config",
            "Printing config tree with Rich! <cfg.extras.print_config=True>",
            lambda: rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True),
        ),
    )

    for flag, message, run_utility in utilities:
        # the flag is re-read from `cfg.extras` right before each utility runs
        if cfg.extras.get(flag):
            log.info(message)
            run_utility()
|
|
|
|
|
|
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    The returned function keeps the metadata of `task_func` (`__name__`, `__doc__`, ...)
    and exposes the original through `__wrapped__`.

    :param task_func: The task function to be wrapped. It is called as `task_func(cfg=cfg)` and
        must return a `(metric_dict, object_dict)` pair.

    :return: The wrapped task function.
    """
    # local import keeps this block self-contained (no new file-level dependency)
    from functools import wraps

    @wraps(task_func)
    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise  # bare `raise` re-raises the active exception unchanged

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
|
|
|
|
|
|
def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule.

    The stored entry is converted with its `.item()` method before being returned.

    :param metric_dict: A dict containing metric values.
    :param metric_name: If provided, the name of the metric to retrieve.
    :return: If a metric name was provided, the value of the metric; otherwise `None`.
    :raises Exception: If `metric_name` is provided but is not a key of `metric_dict`.
    """
    # no name requested -> nothing to look up
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name in metric_dict:
        value = metric_dict[metric_name].item()
        log.info(f"Retrieved metric value! <{metric_name}={value}>")
        return value

    # requested metric is missing: fail loudly with hints on where the name may be wrong
    not_found_message = (
        f"Metric value not found! <metric_name={metric_name}>\n"
        "Make sure metric name logged in LightningModule is correct!\n"
        "Make sure `optimized_metric` name in `hparams_search` config is correct!"
    )
    raise Exception(not_found_message)
|
|
|
|
|
|
def read_strings_from_txt(path: str) -> List[str]:
    """Reads a text file and returns its lines as a list of strings.

    Trailing whitespace (including the newline) is stripped from every line; leading
    whitespace is kept, and empty lines are returned as empty strings.

    :param path: Path to the text file.
    :return: List of strings, one element per line of the file.
    """
    with open(path) as handle:
        # NOTE: iterating the handle yields the file one line at a time
        return [entry.rstrip() for entry in handle]
|
|
|
|
|
|
def fasta_to_dict(filename: str) -> Dict[str, str]:
    """Converts a FASTA file to a dictionary where the keys are the sequence IDs and the values are
    the sequences.

    The sequence ID is the full identifier line without its leading `>` (trailing whitespace
    removed). Sequence lines belonging to one identifier are concatenated in file order. Lines
    that are empty after stripping trailing whitespace are ignored. If an identifier occurs more
    than once, the last occurrence's sequence is kept.

    :param filename: Path to the FASTA file.
    :return: Dictionary with sequence IDs as keys and sequences as values.
    :raises ValueError: If sequence data appears before the first `>` identifier line.
    """
    # collect sequence fragments per ID and join once at the end (avoids repeated `+=`)
    fragments_by_id = {}
    seq_id = None  # ID of the record currently being read; None until the first header
    with open(filename) as file:
        for line in file:
            line = line.rstrip()  # remove trailing whitespace
            if not line:
                # blank lines carry no data; skipping them also keeps a leading blank
                # line from being treated as sequence data for a not-yet-seen ID
                continue
            if line.startswith(">"):  # identifier line
                seq_id = line[1:]  # remove the '>' character
                fragments_by_id[seq_id] = []
            elif seq_id is None:
                raise ValueError(
                    f"Invalid FASTA file '{filename}': sequence data found before the first "
                    "'>' identifier line."
                )
            else:  # sequence line
                fragments_by_id[seq_id].append(line)
    return {record_id: "".join(parts) for record_id, parts in fragments_by_id.items()}
|