Files
flowdock/tests/helpers/run_if.py
Olamide Isreal a3ffec6a07
Some checks failed
Code Quality Main / code-quality (push) Has been cancelled
Release Drafter / update_release_draft (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.10) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.8) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.9) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.10) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.8) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.9) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.10) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.8) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.9) (push) Has been cancelled
Tests / code-coverage (push) Has been cancelled
Initial commit: FlowDock pipeline configured for WES execution
2026-03-16 15:23:29 +01:00

143 lines
4.4 KiB
Python

"""Adapted from:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py
"""
import sys
import pytest
import torch
from beartype.typing import Any, Dict, Optional
from packaging.version import Version
from pkg_resources import get_distribution
from pytest import MarkDecorator
from tests.helpers.package_available import (
_COMET_AVAILABLE,
_DEEPSPEED_AVAILABLE,
_FAIRSCALE_AVAILABLE,
_IS_WINDOWS,
_MLFLOW_AVAILABLE,
_NEPTUNE_AVAILABLE,
_SH_AVAILABLE,
_TPU_AVAILABLE,
_WANDB_AVAILABLE,
)
class RunIf:
    """RunIf wrapper for conditional skipping of tests.

    Fully compatible with `@pytest.mark`.

    Example:

    ```python
    @RunIf(min_torch="1.8")
    @pytest.mark.parametrize("arg1", [1.0, 2.0])
    def test_wrapper(arg1):
        assert arg1 > 0
    ```
    """

    def __new__(
        cls,
        min_gpus: int = 0,
        min_torch: Optional[str] = None,
        max_torch: Optional[str] = None,
        min_python: Optional[str] = None,
        skip_windows: bool = False,
        sh: bool = False,
        tpu: bool = False,
        fairscale: bool = False,
        deepspeed: bool = False,
        wandb: bool = False,
        neptune: bool = False,
        comet: bool = False,
        mlflow: bool = False,
        **kwargs: Any,
    ) -> MarkDecorator:
        """Creates a new `@RunIf` `MarkDecorator` decorator.

        `__new__` returns a `pytest.mark.skipif` marker instead of a `RunIf`
        instance, which is what lets `RunIf(...)` be used directly as a decorator.

        :param min_gpus: Min number of GPUs required to run test.
        :param min_torch: Minimum pytorch version to run test.
        :param max_torch: Maximum pytorch version to run test (exclusive).
        :param min_python: Minimum python version required to run test.
        :param skip_windows: Skip test for Windows platform.
        :param sh: If `sh` module is required to run the test.
        :param tpu: If TPU is required to run the test.
        :param fairscale: If `fairscale` module is required to run the test.
        :param deepspeed: If `deepspeed` module is required to run the test.
        :param wandb: If `wandb` module is required to run the test.
        :param neptune: If `neptune` module is required to run the test.
        :param comet: If `comet` module is required to run the test.
        :param mlflow: If `mlflow` module is required to run the test.
        :param kwargs: Native `pytest.mark.skipif` keyword arguments.
        :return: A `pytest.mark.skipif` marker that skips the test when any requested
            requirement is unmet; its reason lists only the unmet requirements.
        """
        # Parallel lists: conditions[i] is True when requirement reasons[i] is NOT met.
        conditions = []
        reasons = []

        if min_gpus:
            conditions.append(torch.cuda.device_count() < min_gpus)
            reasons.append(f"GPUs>={min_gpus}")

        if min_torch or max_torch:
            # Resolve the installed torch version once. `torch.__version__` is already
            # in scope and avoids the deprecated `pkg_resources` distribution lookup.
            torch_version = Version(str(torch.__version__))
            if min_torch:
                conditions.append(torch_version < Version(min_torch))
                reasons.append(f"torch>={min_torch}")
            if max_torch:
                conditions.append(torch_version >= Version(max_torch))
                reasons.append(f"torch<{max_torch}")

        if min_python:
            py_version = (
                f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
            )
            conditions.append(Version(py_version) < Version(min_python))
            reasons.append(f"python>={min_python}")

        if skip_windows:
            conditions.append(_IS_WINDOWS)
            reasons.append("does not run on Windows")

        # Optional backends/packages: (requested by caller, available here, label).
        # Tuple order fixes the order in which labels appear in the skip reason.
        for requested, available, label in (
            (tpu, _TPU_AVAILABLE, "TPU"),
            (sh, _SH_AVAILABLE, "sh"),
            (fairscale, _FAIRSCALE_AVAILABLE, "fairscale"),
            (deepspeed, _DEEPSPEED_AVAILABLE, "deepspeed"),
            (wandb, _WANDB_AVAILABLE, "wandb"),
            (neptune, _NEPTUNE_AVAILABLE, "neptune"),
            (comet, _COMET_AVAILABLE, "comet"),
            (mlflow, _MLFLOW_AVAILABLE, "mlflow"),
        ):
            if requested:
                conditions.append(not available)
                reasons.append(label)

        # Report only the requirements that actually failed.
        unmet = [reason for cond, reason in zip(conditions, reasons) if cond]
        return pytest.mark.skipif(
            condition=any(conditions),
            reason=f"Requires: [{' + '.join(unmet)}]",
            **kwargs,
        )