diff --git a/Dockerfile b/Dockerfile index f2b842b..39a7d80 100755 --- a/Dockerfile +++ b/Dockerfile @@ -1,51 +1,23 @@ -# Use NVIDIA CUDA base image with Ubuntu 22.04 -FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04 +# Use PyTorch base image with CUDA support (much smaller than building from scratch) +FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime # Set environment variables ENV DEBIAN_FRONTEND=noninteractive \ - LANG=C.UTF-8 \ - LC_ALL=C.UTF-8 \ PYTHONUNBUFFERED=TRUE \ - PYTHONFAULTHANDLER=1 \ - PYTHONPYCACHEPREFIX='/tmp/.chai_pycache' \ - MYPY_CACHE_DIR='/tmp/.chai_mypy_cache' - -# Install system dependencies -RUN apt-get update && apt-get install -y --no-install-recommends \ - git \ - wget \ - curl \ - ca-certificates \ - python3.10 \ - python3.10-dev \ - python3-pip \ - build-essential \ - && rm -rf /var/lib/apt/lists/* + PYTHONFAULTHANDLER=1 # Set working directory WORKDIR /workspace -# Upgrade pip -RUN pip3 install --no-cache-dir --upgrade pip setuptools wheel +# Install chai_lab and transformers in a single layer +RUN pip install --no-cache-dir \ + chai_lab==0.5.2 \ + "transformers>=4.30.0" -# Install chai_lab first (this will install older PyTorch) -RUN pip3 install --no-cache-dir chai_lab==0.5.2 - -# Force uninstall old PyTorch and related packages -RUN pip3 uninstall -y torch torchvision torchaudio - -# Install PyTorch 2.6+ from main PyPI (has CUDA support built-in) -RUN pip3 install --no-cache-dir torch torchvision torchaudio - -# Upgrade transformers to ensure compatibility -RUN pip3 install --no-cache-dir --upgrade "transformers>=4.30.0" - -# Verify all installations -RUN python3 -c "import torch; v=torch.__version__.split('+')[0]; print(f'PyTorch: {v}'); major,minor=map(int,v.split('.')[:2]); assert (major==2 and minor>=6) or major>2, f'PyTorch {v} is too old, need 2.6+'" && \ - python3 -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" && \ - python3 -c "from transformers import EsmModel; print('transformers: OK')" && \ - python3 -c "import typer; print('typer: OK')" && \ - python3 -c "import chai_lab; print('chai_lab: OK')" && \ +# Verify installations +RUN python -c "import torch; print(f'PyTorch: {torch.__version__}')" && \ + python -c "from transformers import EsmModel; print('transformers: OK')" && \ + python -c "import chai_lab; print('chai_lab: OK')" && \ chai --help # Add entry point script