From 039861f1c72ad68f9ba0c8a8149dffda54bc697d Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Tue, 6 Aug 2024 15:02:41 -0700 Subject: [PATCH] update inference config to take model and not model_dir --- llama_toolchain/cli/download.py | 7 +- llama_toolchain/common/model_utils.py | 8 + .../inference/meta_reference/config.py | 48 +-- .../inference/meta_reference/generation.py | 26 +- .../meta_reference/model_parallel.py | 14 +- llama_toolchain/inference/ollama/ollama.py | 6 +- ollama_install.sh | 340 ++++++++++++++++++ tests/test_inference.py | 42 +-- tests/test_ollama_inference.py | 10 +- 9 files changed, 400 insertions(+), 101 deletions(-) create mode 100644 llama_toolchain/common/model_utils.py create mode 100644 ollama_install.sh diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py index 1fa420f4b..b268d3b8d 100644 --- a/llama_toolchain/cli/download.py +++ b/llama_toolchain/cli/download.py @@ -75,11 +75,13 @@ safetensors files to avoid downloading duplicate weights. from huggingface_hub import snapshot_download from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError + from llama_toolchain.common.model_utils import model_local_dir + repo_id = model.huggingface_repo if repo_id is None: raise ValueError(f"No repo id found for model {model.descriptor()}") - output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor() + output_dir = model_local_dir(model) os.makedirs(output_dir, exist_ok=True) try: true_output_dir = snapshot_download( @@ -107,8 +109,9 @@ safetensors files to avoid downloading duplicate weights. def _meta_download(self, model: "Model", meta_url: str): from llama_models.sku_list import llama_meta_net_info + from llama_toolchain.common.model_utils import model_local_dir - output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor() + output_dir = model_local_dir(model) os.makedirs(output_dir, exist_ok=True) info = llama_meta_net_info(model) diff --git a/llama_toolchain/common/model_utils.py b/llama_toolchain/common/model_utils.py new file mode 100644 index 000000000..af3929cb7 --- /dev/null +++ b/llama_toolchain/common/model_utils.py @@ -0,0 +1,8 @@ +import os +from llama_models.datatypes import Model + +from .config_dirs import DEFAULT_CHECKPOINT_DIR + + +def model_local_dir(model: Model) -> str: + return os.path.join(DEFAULT_CHECKPOINT_DIR, model.descriptor()) diff --git a/llama_toolchain/inference/meta_reference/config.py b/llama_toolchain/inference/meta_reference/config.py index 0f5bf8eb4..45e0247b7 100644 --- a/llama_toolchain/inference/meta_reference/config.py +++ b/llama_toolchain/inference/meta_reference/config.py @@ -4,61 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from enum import Enum -from typing import Literal, Optional, Union +from typing import Optional -from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat - -from pydantic import BaseModel, Field +from pydantic import BaseModel from strong_typing.schema import json_schema_type -from typing_extensions import Annotated from llama_toolchain.inference.api import QuantizationConfig -@json_schema_type -class CheckpointType(Enum): - pytorch = "pytorch" - huggingface = "huggingface" - - -@json_schema_type -class PytorchCheckpoint(BaseModel): - checkpoint_type: Literal[CheckpointType.pytorch.value] = ( - CheckpointType.pytorch.value - ) - checkpoint_dir: str - tokenizer_path: str - model_parallel_size: int - quantization_format: CheckpointQuantizationFormat = ( - CheckpointQuantizationFormat.bf16 - ) - - -@json_schema_type -class HuggingFaceCheckpoint(BaseModel): - checkpoint_type: Literal[CheckpointType.huggingface.value] = ( - CheckpointType.huggingface.value - ) - repo_id: str # or model_name ? - model_parallel_size: int - quantization_format: CheckpointQuantizationFormat = ( - CheckpointQuantizationFormat.bf16 - ) - - -@json_schema_type -class ModelCheckpointConfig(BaseModel): - checkpoint: Annotated[ - Union[PytorchCheckpoint, HuggingFaceCheckpoint], - Field(discriminator="checkpoint_type"), - ] - - @json_schema_type class MetaReferenceImplConfig(BaseModel): model: str - checkpoint_config: ModelCheckpointConfig quantization: Optional[QuantizationConfig] = None torch_seed: Optional[int] = None max_seq_len: int diff --git a/llama_toolchain/inference/meta_reference/generation.py b/llama_toolchain/inference/meta_reference/generation.py index 70580995d..0f8df84ac 100644 --- a/llama_toolchain/inference/meta_reference/generation.py +++ b/llama_toolchain/inference/meta_reference/generation.py @@ -27,11 +27,22 @@ from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput from llama_models.llama3_1.api.datatypes import Message from llama_models.llama3_1.api.model import Transformer from llama_models.llama3_1.api.tokenizer import Tokenizer +from llama_models.sku_list import resolve_model from termcolor import cprint +from llama_toolchain.common.model_utils import model_local_dir from llama_toolchain.inference.api import QuantizationType -from .config import CheckpointType, MetaReferenceImplConfig +from .config import MetaReferenceImplConfig + + +def model_checkpoint_dir(model) -> str: + checkpoint_dir = Path(model_local_dir(model)) + if not Path(checkpoint_dir / "consolidated.00.pth").exists(): + checkpoint_dir = checkpoint_dir / "original" + + assert checkpoint_dir.exists(), f"Could not find checkpoint dir: {checkpoint_dir}" + return str(checkpoint_dir) @dataclass @@ -51,9 +62,7 @@ class Llama: This method initializes the distributed process group, sets the device to CUDA, and loads the pre-trained model and tokenizer. 
""" - checkpoint = config.checkpoint_config.checkpoint - if checkpoint.checkpoint_type != CheckpointType.pytorch.value: - raise NotImplementedError("HuggingFace checkpoints not supported yet") + model = resolve_model(config.model) if ( config.quantization @@ -67,7 +76,7 @@ class Llama: if not torch.distributed.is_initialized(): torch.distributed.init_process_group("nccl") - model_parallel_size = checkpoint.model_parallel_size + model_parallel_size = model.hardware_requirements.gpu_count if not model_parallel_is_initialized(): initialize_model_parallel(model_parallel_size) @@ -82,7 +91,8 @@ class Llama: sys.stdout = open(os.devnull, "w") start_time = time.time() - ckpt_dir = checkpoint.checkpoint_dir + ckpt_dir = model_checkpoint_dir(model) + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" assert model_parallel_size == len( @@ -103,7 +113,9 @@ class Llama: max_batch_size=config.max_batch_size, **params, ) - tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path) + + tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model") + tokenizer = Tokenizer(model_path=tokenizer_path) assert ( model_args.vocab_size == tokenizer.n_words diff --git a/llama_toolchain/inference/meta_reference/model_parallel.py b/llama_toolchain/inference/meta_reference/model_parallel.py index 58fbd2177..dee05d8d5 100644 --- a/llama_toolchain/inference/meta_reference/model_parallel.py +++ b/llama_toolchain/inference/meta_reference/model_parallel.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import os from copy import deepcopy from dataclasses import dataclass from functools import partial @@ -12,9 +13,10 @@ from typing import Generator, List, Optional from llama_models.llama3_1.api.chat_format import ChatFormat from llama_models.llama3_1.api.datatypes import Message from llama_models.llama3_1.api.tokenizer import Tokenizer +from llama_models.sku_list import resolve_model from .config import MetaReferenceImplConfig -from .generation import Llama +from .generation import Llama, model_checkpoint_dir from .parallel_utils import ModelParallelProcessGroup @@ -60,11 +62,12 @@ class LlamaModelParallelGenerator: def __init__(self, config: MetaReferenceImplConfig): self.config = config - + self.model = resolve_model(self.config.model) # this is a hack because Agent's loop uses this to tokenize and check if input is too long # while the tool-use loop is going - checkpoint = self.config.checkpoint_config.checkpoint - self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path)) + checkpoint_dir = model_checkpoint_dir(self.model) + tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model") + self.formatter = ChatFormat(Tokenizer(tokenizer_path)) def start(self): self.__enter__() @@ -73,9 +76,8 @@ class LlamaModelParallelGenerator: self.__exit__(None, None, None) def __enter__(self): - checkpoint = self.config.checkpoint_config.checkpoint self.group = ModelParallelProcessGroup( - checkpoint.model_parallel_size, + self.model.hardware_requirements.gpu_count, init_model_cb=partial(init_model_cb, self.config), ) self.group.start() diff --git a/llama_toolchain/inference/ollama/ollama.py b/llama_toolchain/inference/ollama/ollama.py index 3afd1326b..f28d3c637 100644 --- a/llama_toolchain/inference/ollama/ollama.py +++ b/llama_toolchain/inference/ollama/ollama.py @@ -44,11 +44,13 @@ OLLAMA_SUPPORTED_SKUS = { } -def get_provider_impl(config: OllamaImplConfig) -> 
Inference: +async def get_provider_impl(config: OllamaImplConfig) -> Inference: assert isinstance( config, OllamaImplConfig ), f"Unexpected config type: {type(config)}" - return OllamaInference(config) + impl = OllamaInference(config) + await impl.initialize() + return impl class OllamaInference(Inference): diff --git a/ollama_install.sh b/ollama_install.sh new file mode 100644 index 000000000..aa8b3e5e3 --- /dev/null +++ b/ollama_install.sh @@ -0,0 +1,340 @@ +#!/bin/sh +# This script installs Ollama on Linux. +# It detects the current operating system architecture and installs the appropriate version of Ollama. + +set -eu + +status() { echo ">>> $*" >&2; } +error() { echo "ERROR $*"; exit 1; } +warning() { echo "WARNING: $*"; } + +TEMP_DIR=$(mktemp -d) +cleanup() { rm -rf $TEMP_DIR; } +trap cleanup EXIT + +available() { command -v $1 >/dev/null; } +require() { + local MISSING='' + for TOOL in $*; do + if ! available $TOOL; then + MISSING="$MISSING $TOOL" + fi + done + + echo $MISSING +} + +[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.' + +ARCH=$(uname -m) +case "$ARCH" in + x86_64) ARCH="amd64" ;; + aarch64|arm64) ARCH="arm64" ;; + *) error "Unsupported architecture: $ARCH" ;; +esac + +IS_WSL2=false + +KERN=$(uname -r) +case "$KERN" in + *icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;; + *icrosoft) error "Microsoft WSL1 is not currently supported. Please upgrade to WSL2 with 'wsl --set-version 2'" ;; + *) ;; +esac + +VER_PARAM="${OLLAMA_VERSION:+?version=$OLLAMA_VERSION}" + +SUDO= +if [ "$(id -u)" -ne 0 ]; then + # Running as root, no need for sudo + if ! available sudo; then + error "This script requires superuser permissions. Please re-run as root." + fi + + SUDO="sudo" +fi + +NEEDS=$(require curl awk grep sed tee xargs) +if [ -n "$NEEDS" ]; then + status "ERROR: The following tools are required but missing:" + for NEED in $NEEDS; do + echo " - $NEED" + done + exit 1 +fi + +status "Downloading ollama..." +curl --fail --show-error --location --progress-bar -o $TEMP_DIR/ollama "https://ollama.com/download/ollama-linux-${ARCH}${VER_PARAM}" + +for BINDIR in /usr/local/bin /usr/bin /bin; do + echo $PATH | grep -q $BINDIR && break || continue +done + +status "Installing ollama to $BINDIR..." +$SUDO install -o0 -g0 -m755 -d $BINDIR +$SUDO install -o0 -g0 -m755 $TEMP_DIR/ollama $BINDIR/ollama + +install_success() { + status 'The Ollama API is now available at 127.0.0.1:11434.' + status 'Install complete. Run "ollama" from the command line.' +} +trap install_success EXIT + +# Everything from this point onwards is optional. + +configure_systemd() { + if ! id ollama >/dev/null 2>&1; then + status "Creating ollama user..." + $SUDO useradd -r -s /bin/false -U -m -d /usr/share/ollama ollama + fi + if getent group render >/dev/null 2>&1; then + status "Adding ollama user to render group..." + $SUDO usermod -a -G render ollama + fi + if getent group video >/dev/null 2>&1; then + status "Adding ollama user to video group..." + $SUDO usermod -a -G video ollama + fi + + status "Adding current user to ollama group..." + $SUDO usermod -a -G ollama $(whoami) + + status "Creating ollama systemd service..." 
+    cat <<EOF | $SUDO tee /etc/systemd/system/ollama.service >/dev/null
+[Unit]
+Description=Ollama Service
+After=network-online.target
+
+[Service]
+ExecStart=$BINDIR/ollama serve
+User=ollama
+Group=ollama
+Restart=always
+RestartSec=3
+Environment="PATH=$PATH"
+
+[Install]
+WantedBy=default.target
+EOF
+    SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
+    case $SYSTEMCTL_RUNNING in
+        running|degraded)
+            status "Enabling and starting ollama service..."
+            $SUDO systemctl daemon-reload
+            $SUDO systemctl enable ollama
+
+            start_service() { $SUDO systemctl restart ollama; }
+            trap start_service EXIT
+            ;;
+    esac
+}
+
+if available systemctl; then
+    configure_systemd
+fi
+
+# WSL2 only supports GPUs via nvidia passthrough
+# so check for nvidia-smi to determine if GPU is available
+if [ "$IS_WSL2" = true ]; then
+    if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
+        status "Nvidia GPU detected."
+    fi
+    install_success
+    exit 0
+fi
+
+# Install GPU dependencies on Linux
+if ! available lspci && ! available lshw; then
+    warning "Unable to detect NVIDIA/AMD GPU. Install lspci or lshw to automatically detect and install GPU dependencies."
+    exit 0
+fi
+
+check_gpu() {
+    # Look for devices based on vendor ID for NVIDIA and AMD
+    case $1 in
+        lspci)
+            case $2 in
+                nvidia) available lspci && lspci -d '10de:' | grep -q 'NVIDIA' || return 1 ;;
+                amdgpu) available lspci && lspci -d '1002:' | grep -q 'AMD' || return 1 ;;
+            esac ;;
+        lshw)
+            case $2 in
+                nvidia) available lshw && $SUDO lshw -c display -numeric -disable network | grep -q 'vendor: .* \[10DE\]' || return 1 ;;
+                amdgpu) available lshw && $SUDO lshw -c display -numeric -disable network | grep -q 'vendor: .* \[1002\]' || return 1 ;;
+            esac ;;
+        nvidia-smi) available nvidia-smi || return 1 ;;
+    esac
+}
+
+if check_gpu nvidia-smi; then
+    status "NVIDIA GPU installed."
+    exit 0
+fi
+
+if ! check_gpu lspci nvidia && ! check_gpu lshw nvidia && ! check_gpu lspci amdgpu && ! check_gpu lshw amdgpu; then
+    install_success
+    warning "No NVIDIA/AMD GPU detected. Ollama will run in CPU-only mode."
+    exit 0
+fi
+
+if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
+    # Look for pre-existing ROCm v6 before downloading the dependencies
+    for search in "${HIP_PATH:-''}" "${ROCM_PATH:-''}" "/opt/rocm" "/usr/lib64"; do
+        if [ -n "${search}" ] && [ -e "${search}/libhipblas.so.2" -o -e "${search}/lib/libhipblas.so.2" ]; then
+            status "Compatible AMD GPU ROCm library detected at ${search}"
+            install_success
+            exit 0
+        fi
+    done
+
+    status "Downloading AMD GPU dependencies..."
+    $SUDO rm -rf /usr/share/ollama/lib
+    $SUDO chmod o+x /usr/share/ollama
+    $SUDO install -o ollama -g ollama -m 755 -d /usr/share/ollama/lib/rocm
+    curl --fail --show-error --location --progress-bar "https://ollama.com/download/ollama-linux-amd64-rocm.tgz${VER_PARAM}" \
+        | $SUDO tar zx --owner ollama --group ollama -C /usr/share/ollama/lib/rocm .
+    install_success
+    status "AMD GPU ready."
+    exit 0
+fi
+
+CUDA_REPO_ERR_MSG="NVIDIA GPU detected, but your OS and Architecture are not supported by NVIDIA. 
Please install the CUDA driver manually https://docs.nvidia.com/cuda/cuda-installation-guide-linux/" +# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-7-centos-7 +# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-8-rocky-8 +# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-9-rocky-9 +# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#fedora +install_cuda_driver_yum() { + status 'Installing NVIDIA repository...' + + case $PACKAGE_MANAGER in + yum) + $SUDO $PACKAGE_MANAGER -y install yum-utils + if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then + $SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo + else + error $CUDA_REPO_ERR_MSG + fi + ;; + dnf) + if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then + $SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo + else + error $CUDA_REPO_ERR_MSG + fi + ;; + esac + + case $1 in + rhel) + status 'Installing EPEL repository...' + # EPEL is required for third-party dependencies such as dkms and libvdpau + $SUDO $PACKAGE_MANAGER -y install https://dl.fedoraproject.org/pub/epel/epel-release-latest-$2.noarch.rpm || true + ;; + esac + + status 'Installing CUDA driver...' + + if [ "$1" = 'centos' ] || [ "$1$2" = 'rhel7' ]; then + $SUDO $PACKAGE_MANAGER -y install nvidia-driver-latest-dkms + fi + + $SUDO $PACKAGE_MANAGER -y install cuda-drivers +} + +# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#ubuntu +# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian +install_cuda_driver_apt() { + status 'Installing NVIDIA repository...' + if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb" >/dev/null ; then + curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb + else + error $CUDA_REPO_ERR_MSG + fi + + case $1 in + debian) + status 'Enabling contrib sources...' + $SUDO sed 's/main/contrib/' < /etc/apt/sources.list | $SUDO tee /etc/apt/sources.list.d/contrib.list > /dev/null + if [ -f "/etc/apt/sources.list.d/debian.sources" ]; then + $SUDO sed 's/main/contrib/' < /etc/apt/sources.list.d/debian.sources | $SUDO tee /etc/apt/sources.list.d/contrib.sources > /dev/null + fi + ;; + esac + + status 'Installing CUDA driver...' + $SUDO dpkg -i $TEMP_DIR/cuda-keyring.deb + $SUDO apt-get update + + [ -n "$SUDO" ] && SUDO_E="$SUDO -E" || SUDO_E= + DEBIAN_FRONTEND=noninteractive $SUDO_E apt-get -y install cuda-drivers -q +} + +if [ ! -f "/etc/os-release" ]; then + error "Unknown distribution. Skipping CUDA installation." +fi + +. /etc/os-release + +OS_NAME=$ID +OS_VERSION=$VERSION_ID + +PACKAGE_MANAGER= +for PACKAGE_MANAGER in dnf yum apt-get; do + if available $PACKAGE_MANAGER; then + break + fi +done + +if [ -z "$PACKAGE_MANAGER" ]; then + error "Unknown package manager. Skipping CUDA installation." +fi + +if ! 
check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then + case $OS_NAME in + centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;; + rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;; + fedora) [ $OS_VERSION -lt '39' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '39';; + amzn) install_cuda_driver_yum 'fedora' '37' ;; + debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;; + ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;; + *) exit ;; + esac +fi + +if ! lsmod | grep -q nvidia || ! lsmod | grep -q nvidia_uvm; then + KERNEL_RELEASE="$(uname -r)" + case $OS_NAME in + rocky) $SUDO $PACKAGE_MANAGER -y install kernel-devel kernel-headers ;; + centos|rhel|amzn) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE kernel-headers-$KERNEL_RELEASE ;; + fedora) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE ;; + debian|ubuntu) $SUDO apt-get -y install linux-headers-$KERNEL_RELEASE ;; + *) exit ;; + esac + + NVIDIA_CUDA_VERSION=$($SUDO dkms status | awk -F: '/added/ { print $1 }') + if [ -n "$NVIDIA_CUDA_VERSION" ]; then + $SUDO dkms install $NVIDIA_CUDA_VERSION + fi + + if lsmod | grep -q nouveau; then + status 'Reboot to complete NVIDIA CUDA driver install.' + exit 0 + fi + + $SUDO modprobe nvidia + $SUDO modprobe nvidia_uvm +fi + +# make sure the NVIDIA modules are loaded on boot with nvidia-persistenced +if command -v nvidia-persistenced > /dev/null 2>&1; then + $SUDO touch /etc/modules-load.d/nvidia.conf + MODULES="nvidia nvidia-uvm" + for MODULE in $MODULES; do + if ! grep -qxF "$MODULE" /etc/modules-load.d/nvidia.conf; then + echo "$MODULE" | sudo tee -a /etc/modules-load.d/nvidia.conf > /dev/null + fi + done +fi + +status "NVIDIA GPU ready." +install_success diff --git a/tests/test_inference.py b/tests/test_inference.py index ad7bf6d19..2714482e8 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -14,24 +14,18 @@ from llama_models.llama3_1.api.datatypes import ( StopReason, SystemMessage, ) - -from llama_toolchain.inference.api.config import ( - InferenceConfig, - InlineImplConfig, - RemoteImplConfig, - ModelCheckpointConfig, - PytorchCheckpoint, - CheckpointQuantizationFormat, -) -from llama_toolchain.inference.api_instance import ( - get_inference_api_instance, -) from llama_toolchain.inference.api.datatypes import ( ChatCompletionResponseEventType, ) +from llama_toolchain.inference.meta_reference.inference import get_provider_impl +from llama_toolchain.inference.meta_reference.config import ( + MetaReferenceImplConfig, +) + from llama_toolchain.inference.api.endpoints import ChatCompletionRequest +MODEL = "Meta-Llama3.1-8B-Instruct" HELPER_MSG = """ This test needs llama-3.1-8b-instruct models. 
Please donwload using the llama cli @@ -50,32 +44,18 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): @classmethod async def asyncSetUpClass(cls): # assert model exists on local - model_dir = os.path.expanduser( - "~/.llama/checkpoints/Meta-Llama-3.1-8B-Instruct/original/" - ) + model_dir = os.path.expanduser(f"~/.llama/checkpoints/{MODEL}/original/") assert os.path.isdir(model_dir), HELPER_MSG tokenizer_path = os.path.join(model_dir, "tokenizer.model") assert os.path.exists(tokenizer_path), HELPER_MSG - inline_config = InlineImplConfig( - checkpoint_config=ModelCheckpointConfig( - checkpoint=PytorchCheckpoint( - checkpoint_dir=model_dir, - tokenizer_path=tokenizer_path, - model_parallel_size=1, - quantization_format=CheckpointQuantizationFormat.bf16, - ) - ), + config = MetaReferenceImplConfig( + model=MODEL, max_seq_len=2048, ) - inference_config = InferenceConfig(impl_config=inline_config) - # -- For faster testing iteration -- - # remote_config = RemoteImplConfig(url="http://localhost:5000") - # inference_config = InferenceConfig(impl_config=remote_config) - - cls.api = await get_inference_api_instance(inference_config) + cls.api = await get_provider_impl(config, {}) await cls.api.initialize() current_date = datetime.now() @@ -134,7 +114,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): await cls.api.shutdown() async def asyncSetUp(self): - self.valid_supported_model = "Meta-Llama3.1-8B-Instruct" + self.valid_supported_model = MODEL async def test_text(self): request = ChatCompletionRequest( diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py index 67493db25..714521084 100644 --- a/tests/test_ollama_inference.py +++ b/tests/test_ollama_inference.py @@ -10,14 +10,12 @@ from llama_models.llama3_1.api.datatypes import ( SamplingStrategy, SystemMessage, ) -from llama_toolchain.inference.api_instance import ( - get_inference_api_instance, -) from llama_toolchain.inference.api.datatypes import ( ChatCompletionResponseEventType, ) from llama_toolchain.inference.api.endpoints import ChatCompletionRequest -from llama_toolchain.inference.api.config import InferenceConfig, OllamaImplConfig +from llama_toolchain.inference.ollama.config import OllamaImplConfig +from llama_toolchain.inference.ollama.ollama import get_provider_impl class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): @@ -30,9 +28,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): ) # setup ollama - self.api = await get_inference_api_instance( - InferenceConfig(impl_config=ollama_config) - ) + self.api = await get_provider_impl(ollama_config) await self.api.initialize() current_date = datetime.now()