# syntax=docker/dockerfile:1

# NVIDIA CUDA 12.1 + cuDNN 8 development base image on Ubuntu 22.04.
# NOTE(review): the PyPI torch wheels installed below bundle their own CUDA
# runtime libraries; confirm the host NVIDIA driver supports the CUDA version
# those wheels were built against whenever torch is bumped.
FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04

# Build-time only: stops apt from prompting. Declared as ARG (not ENV) so it
# does not leak into the environment of the running container.
ARG DEBIAN_FRONTEND=noninteractive

# Runtime environment:
#   - UTF-8 locale for consistent text handling
#   - unbuffered stdout/stderr and faulthandler tracebacks for easier debugging
#   - Python bytecode and mypy caches redirected under /tmp
ENV LANG=C.UTF-8 \
    LC_ALL=C.UTF-8 \
    PYTHONUNBUFFERED=TRUE \
    PYTHONFAULTHANDLER=1 \
    PYTHONPYCACHEPREFIX='/tmp/.chai_pycache' \
    MYPY_CACHE_DIR='/tmp/.chai_mypy_cache'

# System dependencies (sorted alphabetically). update + install + cleanup run
# in a single layer so the apt package lists never persist in the image.
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        curl \
        git \
        python3-pip \
        python3.10 \
        python3.10-dev \
        wget \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /workspace

# Upgrade the Python packaging toolchain before installing anything else.
# `python3 -m pip` guarantees the upgraded pip is the one used afterwards,
# regardless of which pip3 wrapper comes first on PATH.
RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel

# chai_lab 0.5.2 installs an older PyTorch, but PyTorch 2.6+ is required.
# Install chai_lab, remove the torch stack it brought in, and install
# torch>=2.6 -- all in ONE layer. Uninstalling in a separate, later RUN does
# not shrink the image: the old torch wheel would remain in the earlier layer.
# The ">=2.6" pin matches the version assertion in the verification step.
# NOTE(review): if chai_lab's metadata pins an older torch, pip will print a
# dependency-conflict warning here; that override is intentional.
RUN python3 -m pip install --no-cache-dir chai_lab==0.5.2 \
    && python3 -m pip uninstall -y torch torchvision torchaudio \
    && python3 -m pip install --no-cache-dir "torch>=2.6" torchvision torchaudio

# Upgrade transformers to ensure compatibility (EsmModel is import-checked below).
RUN python3 -m pip install --no-cache-dir --upgrade "transformers>=4.30.0"

# Fail the build early if the Python environment is not as expected.
# The CUDA line is informational only: a GPU is usually not attached during
# `docker build`, so it will normally report False here.
RUN python3 -c "import torch; v=torch.__version__.split('+')[0]; print(f'PyTorch: {v}'); major,minor=map(int,v.split('.')[:2]); assert (major==2 and minor>=6) or major>2, f'PyTorch {v} is too old, need 2.6+'" \
    && python3 -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" \
    && python3 -c "from transformers import EsmModel; print('transformers: OK')" \
    && python3 -c "import typer; print('typer: OK')" \
    && python3 -c "import chai_lab; print('chai_lab: OK')" \
    && chai --help

# Entry point script, copied with its executable bit set in the same layer
# (avoids a separate RUN chmod layer).
COPY --chmod=0755 entrypoint.sh /workspace/

# NOTE(review): there is no USER directive, so the container runs as root.
# Consider switching to a non-root user once entrypoint.sh's filesystem and
# device requirements are confirmed.
ENTRYPOINT ["/workspace/entrypoint.sh"]