update inference config to take model and not model_dir

Hardik Shah 2024-08-06 15:02:41 -07:00
parent 08c3802f45
commit 039861f1c7
9 changed files with 400 additions and 101 deletions


@@ -75,11 +75,13 @@ safetensors files to avoid downloading duplicate weights.
         from huggingface_hub import snapshot_download
         from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
+        from llama_toolchain.common.model_utils import model_local_dir

         repo_id = model.huggingface_repo
         if repo_id is None:
             raise ValueError(f"No repo id found for model {model.descriptor()}")
-        output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
+        output_dir = model_local_dir(model)
         os.makedirs(output_dir, exist_ok=True)
         try:
             true_output_dir = snapshot_download(
@@ -107,8 +109,9 @@ safetensors files to avoid downloading duplicate weights.
     def _meta_download(self, model: "Model", meta_url: str):
         from llama_models.sku_list import llama_meta_net_info
+        from llama_toolchain.common.model_utils import model_local_dir

-        output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
+        output_dir = model_local_dir(model)
         os.makedirs(output_dir, exist_ok=True)

         info = llama_meta_net_info(model)


@@ -0,0 +1,8 @@
+import os
+
+from llama_models.datatypes import Model
+from .config_dirs import DEFAULT_CHECKPOINT_DIR
+
+
+def model_local_dir(model: Model) -> str:
+    return os.path.join(DEFAULT_CHECKPOINT_DIR, model.descriptor())
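
For orientation, a minimal sketch of how this helper composes with model resolution (assuming DEFAULT_CHECKPOINT_DIR resolves to the ~/.llama/checkpoints root used by the tests below):

    # Sketch only: resolve a model descriptor to its local checkpoint directory.
    from llama_models.sku_list import resolve_model
    from llama_toolchain.common.model_utils import model_local_dir

    model = resolve_model("Meta-Llama3.1-8B-Instruct")
    if model is not None:
        # e.g. ~/.llama/checkpoints/Meta-Llama3.1-8B-Instruct
        print(model_local_dir(model))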


@@ -4,61 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from enum import Enum
-from typing import Literal, Optional, Union
+from typing import Optional

-from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 from strong_typing.schema import json_schema_type
-from typing_extensions import Annotated

 from llama_toolchain.inference.api import QuantizationConfig


-@json_schema_type
-class CheckpointType(Enum):
-    pytorch = "pytorch"
-    huggingface = "huggingface"
-
-
-@json_schema_type
-class PytorchCheckpoint(BaseModel):
-    checkpoint_type: Literal[CheckpointType.pytorch.value] = (
-        CheckpointType.pytorch.value
-    )
-    checkpoint_dir: str
-    tokenizer_path: str
-    model_parallel_size: int
-    quantization_format: CheckpointQuantizationFormat = (
-        CheckpointQuantizationFormat.bf16
-    )
-
-
-@json_schema_type
-class HuggingFaceCheckpoint(BaseModel):
-    checkpoint_type: Literal[CheckpointType.huggingface.value] = (
-        CheckpointType.huggingface.value
-    )
-    repo_id: str  # or model_name ?
-    model_parallel_size: int
-    quantization_format: CheckpointQuantizationFormat = (
-        CheckpointQuantizationFormat.bf16
-    )
-
-
-@json_schema_type
-class ModelCheckpointConfig(BaseModel):
-    checkpoint: Annotated[
-        Union[PytorchCheckpoint, HuggingFaceCheckpoint],
-        Field(discriminator="checkpoint_type"),
-    ]
-
-
 @json_schema_type
 class MetaReferenceImplConfig(BaseModel):
     model: str
-    checkpoint_config: ModelCheckpointConfig
     quantization: Optional[QuantizationConfig] = None
     torch_seed: Optional[int] = None
     max_seq_len: int
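
With the checkpoint classes removed, the provider config is keyed purely by the model descriptor; a sketch of the simplified construction (values illustrative, mirroring the updated test further down):

    from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig

    # Only the model descriptor is needed now; checkpoint paths are derived from it.
    config = MetaReferenceImplConfig(
        model="Meta-Llama3.1-8B-Instruct",
        max_seq_len=2048,
    )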


@@ -27,11 +27,22 @@ from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
 from llama_models.llama3_1.api.datatypes import Message
 from llama_models.llama3_1.api.model import Transformer
 from llama_models.llama3_1.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
 from termcolor import cprint

+from llama_toolchain.common.model_utils import model_local_dir
 from llama_toolchain.inference.api import QuantizationType

-from .config import CheckpointType, MetaReferenceImplConfig
+from .config import MetaReferenceImplConfig
+
+
+def model_checkpoint_dir(model) -> str:
+    checkpoint_dir = Path(model_local_dir(model))
+    if not Path(checkpoint_dir / "consolidated.00.pth").exists():
+        checkpoint_dir = checkpoint_dir / "original"
+
+    assert checkpoint_dir.exists(), f"Could not find checkpoint dir: {checkpoint_dir}"
+    return str(checkpoint_dir)
+
+
 @dataclass
@@ -51,9 +62,7 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
-        checkpoint = config.checkpoint_config.checkpoint
-        if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
-            raise NotImplementedError("HuggingFace checkpoints not supported yet")
+        model = resolve_model(config.model)

         if (
             config.quantization
@@ -67,7 +76,7 @@ class Llama:
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")

-        model_parallel_size = checkpoint.model_parallel_size
+        model_parallel_size = model.hardware_requirements.gpu_count
         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)
@@ -82,7 +91,8 @@ class Llama:
             sys.stdout = open(os.devnull, "w")

         start_time = time.time()
-        ckpt_dir = checkpoint.checkpoint_dir
+        ckpt_dir = model_checkpoint_dir(model)
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
         assert model_parallel_size == len(
@@ -103,7 +113,9 @@ class Llama:
             max_batch_size=config.max_batch_size,
             **params,
         )
-        tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)
+
+        tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
+        tokenizer = Tokenizer(model_path=tokenizer_path)
         assert (
             model_args.vocab_size == tokenizer.n_words
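
Both the weights directory and the tokenizer path are now derived from the resolved model rather than read from config; roughly (the helper's module path is assumed to be the meta_reference generation module shown above):

    import os

    from llama_models.sku_list import resolve_model
    from llama_toolchain.inference.meta_reference.generation import model_checkpoint_dir

    model = resolve_model("Meta-Llama3.1-8B-Instruct")
    # Falls back to <checkpoint_dir>/original when consolidated.00.pth is not at the top level.
    ckpt_dir = model_checkpoint_dir(model)
    tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")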


@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import os
 from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
@@ -12,9 +13,10 @@ from typing import Generator, List, Optional
 from llama_models.llama3_1.api.chat_format import ChatFormat
 from llama_models.llama3_1.api.datatypes import Message
 from llama_models.llama3_1.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model

 from .config import MetaReferenceImplConfig
-from .generation import Llama
+from .generation import Llama, model_checkpoint_dir
 from .parallel_utils import ModelParallelProcessGroup
@@ -60,11 +62,12 @@ class LlamaModelParallelGenerator:
     def __init__(self, config: MetaReferenceImplConfig):
         self.config = config
+        self.model = resolve_model(self.config.model)
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
-        checkpoint = self.config.checkpoint_config.checkpoint
-        self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))
+        checkpoint_dir = model_checkpoint_dir(self.model)
+        tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
+        self.formatter = ChatFormat(Tokenizer(tokenizer_path))

     def start(self):
         self.__enter__()
@@ -73,9 +76,8 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)

     def __enter__(self):
-        checkpoint = self.config.checkpoint_config.checkpoint
         self.group = ModelParallelProcessGroup(
-            checkpoint.model_parallel_size,
+            self.model.hardware_requirements.gpu_count,
             init_model_cb=partial(init_model_cb, self.config),
         )
         self.group.start()


@@ -44,11 +44,13 @@ OLLAMA_SUPPORTED_SKUS = {
 }


-def get_provider_impl(config: OllamaImplConfig) -> Inference:
+async def get_provider_impl(config: OllamaImplConfig) -> Inference:
     assert isinstance(
         config, OllamaImplConfig
     ), f"Unexpected config type: {type(config)}"
-    return OllamaInference(config)
+    impl = OllamaInference(config)
+    await impl.initialize()
+    return impl


 class OllamaInference(Inference):
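
Since get_provider_impl is now a coroutine that initializes the client before returning it, callers must await it; a sketch under the assumption that OllamaImplConfig takes the server URL as a field named url:

    import asyncio

    from llama_toolchain.inference.ollama.config import OllamaImplConfig
    from llama_toolchain.inference.ollama.ollama import get_provider_impl

    async def main() -> None:
        config = OllamaImplConfig(url="http://localhost:11434")  # field name assumed
        api = await get_provider_impl(config)  # OllamaInference, already initialized
        # ... issue chat completion requests against api ...

    asyncio.run(main())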

ollama_install.sh (new file)

@@ -0,0 +1,340 @@
#!/bin/sh
# This script installs Ollama on Linux.
# It detects the current operating system architecture and installs the appropriate version of Ollama.
set -eu
status() { echo ">>> $*" >&2; }
error() { echo "ERROR $*"; exit 1; }
warning() { echo "WARNING: $*"; }
TEMP_DIR=$(mktemp -d)
cleanup() { rm -rf $TEMP_DIR; }
trap cleanup EXIT
available() { command -v $1 >/dev/null; }
require() {
local MISSING=''
for TOOL in $*; do
if ! available $TOOL; then
MISSING="$MISSING $TOOL"
fi
done
echo $MISSING
}
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
aarch64|arm64) ARCH="arm64" ;;
*) error "Unsupported architecture: $ARCH" ;;
esac
IS_WSL2=false
KERN=$(uname -r)
case "$KERN" in
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
*icrosoft) error "Microsoft WSL1 is not currently supported. Please upgrade to WSL2 with 'wsl --set-version <distro> 2'" ;;
*) ;;
esac
VER_PARAM="${OLLAMA_VERSION:+?version=$OLLAMA_VERSION}"
SUDO=
if [ "$(id -u)" -ne 0 ]; then
# Not running as root, so sudo is required; make sure it is available
if ! available sudo; then
error "This script requires superuser permissions. Please re-run as root."
fi
SUDO="sudo"
fi
NEEDS=$(require curl awk grep sed tee xargs)
if [ -n "$NEEDS" ]; then
status "ERROR: The following tools are required but missing:"
for NEED in $NEEDS; do
echo " - $NEED"
done
exit 1
fi
status "Downloading ollama..."
curl --fail --show-error --location --progress-bar -o $TEMP_DIR/ollama "https://ollama.com/download/ollama-linux-${ARCH}${VER_PARAM}"
for BINDIR in /usr/local/bin /usr/bin /bin; do
echo $PATH | grep -q $BINDIR && break || continue
done
status "Installing ollama to $BINDIR..."
$SUDO install -o0 -g0 -m755 -d $BINDIR
$SUDO install -o0 -g0 -m755 $TEMP_DIR/ollama $BINDIR/ollama
install_success() {
status 'The Ollama API is now available at 127.0.0.1:11434.'
status 'Install complete. Run "ollama" from the command line.'
}
trap install_success EXIT
# Everything from this point onwards is optional.
configure_systemd() {
if ! id ollama >/dev/null 2>&1; then
status "Creating ollama user..."
$SUDO useradd -r -s /bin/false -U -m -d /usr/share/ollama ollama
fi
if getent group render >/dev/null 2>&1; then
status "Adding ollama user to render group..."
$SUDO usermod -a -G render ollama
fi
if getent group video >/dev/null 2>&1; then
status "Adding ollama user to video group..."
$SUDO usermod -a -G video ollama
fi
status "Adding current user to ollama group..."
$SUDO usermod -a -G ollama $(whoami)
status "Creating ollama systemd service..."
cat <<EOF | $SUDO tee /etc/systemd/system/ollama.service >/dev/null
[Unit]
Description=Ollama Service
After=network-online.target
[Service]
ExecStart=$BINDIR/ollama serve
User=ollama
Group=ollama
Restart=always
RestartSec=3
Environment="PATH=$PATH"
[Install]
WantedBy=default.target
EOF
SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
case $SYSTEMCTL_RUNNING in
running|degraded)
status "Enabling and starting ollama service..."
$SUDO systemctl daemon-reload
$SUDO systemctl enable ollama
start_service() { $SUDO systemctl restart ollama; }
trap start_service EXIT
;;
esac
}
if available systemctl; then
configure_systemd
fi
# WSL2 only supports GPUs via nvidia passthrough
# so check for nvidia-smi to determine if GPU is available
if [ "$IS_WSL2" = true ]; then
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
status "Nvidia GPU detected."
fi
install_success
exit 0
fi
# Install GPU dependencies on Linux
if ! available lspci && ! available lshw; then
warning "Unable to detect NVIDIA/AMD GPU. Install lspci or lshw to automatically detect and install GPU dependencies."
exit 0
fi
check_gpu() {
# Look for devices based on vendor ID for NVIDIA and AMD
case $1 in
lspci)
case $2 in
nvidia) available lspci && lspci -d '10de:' | grep -q 'NVIDIA' || return 1 ;;
amdgpu) available lspci && lspci -d '1002:' | grep -q 'AMD' || return 1 ;;
esac ;;
lshw)
case $2 in
nvidia) available lshw && $SUDO lshw -c display -numeric -disable network | grep -q 'vendor: .* \[10DE\]' || return 1 ;;
amdgpu) available lshw && $SUDO lshw -c display -numeric -disable network | grep -q 'vendor: .* \[1002\]' || return 1 ;;
esac ;;
nvidia-smi) available nvidia-smi || return 1 ;;
esac
}
if check_gpu nvidia-smi; then
status "NVIDIA GPU installed."
exit 0
fi
if ! check_gpu lspci nvidia && ! check_gpu lshw nvidia && ! check_gpu lspci amdgpu && ! check_gpu lshw amdgpu; then
install_success
warning "No NVIDIA/AMD GPU detected. Ollama will run in CPU-only mode."
exit 0
fi
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
# Look for pre-existing ROCm v6 before downloading the dependencies
for search in "${HIP_PATH:-''}" "${ROCM_PATH:-''}" "/opt/rocm" "/usr/lib64"; do
if [ -n "${search}" ] && [ -e "${search}/libhipblas.so.2" -o -e "${search}/lib/libhipblas.so.2" ]; then
status "Compatible AMD GPU ROCm library detected at ${search}"
install_success
exit 0
fi
done
status "Downloading AMD GPU dependencies..."
$SUDO rm -rf /usr/share/ollama/lib
$SUDO chmod o+x /usr/share/ollama
$SUDO install -o ollama -g ollama -m 755 -d /usr/share/ollama/lib/rocm
curl --fail --show-error --location --progress-bar "https://ollama.com/download/ollama-linux-amd64-rocm.tgz${VER_PARAM}" \
| $SUDO tar zx --owner ollama --group ollama -C /usr/share/ollama/lib/rocm .
install_success
status "AMD GPU ready."
exit 0
fi
CUDA_REPO_ERR_MSG="NVIDIA GPU detected, but your OS and Architecture are not supported by NVIDIA. Please install the CUDA driver manually https://docs.nvidia.com/cuda/cuda-installation-guide-linux/"
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-7-centos-7
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-8-rocky-8
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-9-rocky-9
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#fedora
install_cuda_driver_yum() {
status 'Installing NVIDIA repository...'
case $PACKAGE_MANAGER in
yum)
$SUDO $PACKAGE_MANAGER -y install yum-utils
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
$SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
else
error $CUDA_REPO_ERR_MSG
fi
;;
dnf)
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
$SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
else
error $CUDA_REPO_ERR_MSG
fi
;;
esac
case $1 in
rhel)
status 'Installing EPEL repository...'
# EPEL is required for third-party dependencies such as dkms and libvdpau
$SUDO $PACKAGE_MANAGER -y install https://dl.fedoraproject.org/pub/epel/epel-release-latest-$2.noarch.rpm || true
;;
esac
status 'Installing CUDA driver...'
if [ "$1" = 'centos' ] || [ "$1$2" = 'rhel7' ]; then
$SUDO $PACKAGE_MANAGER -y install nvidia-driver-latest-dkms
fi
$SUDO $PACKAGE_MANAGER -y install cuda-drivers
}
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#ubuntu
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian
install_cuda_driver_apt() {
status 'Installing NVIDIA repository...'
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb
else
error $CUDA_REPO_ERR_MSG
fi
case $1 in
debian)
status 'Enabling contrib sources...'
$SUDO sed 's/main/contrib/' < /etc/apt/sources.list | $SUDO tee /etc/apt/sources.list.d/contrib.list > /dev/null
if [ -f "/etc/apt/sources.list.d/debian.sources" ]; then
$SUDO sed 's/main/contrib/' < /etc/apt/sources.list.d/debian.sources | $SUDO tee /etc/apt/sources.list.d/contrib.sources > /dev/null
fi
;;
esac
status 'Installing CUDA driver...'
$SUDO dpkg -i $TEMP_DIR/cuda-keyring.deb
$SUDO apt-get update
[ -n "$SUDO" ] && SUDO_E="$SUDO -E" || SUDO_E=
DEBIAN_FRONTEND=noninteractive $SUDO_E apt-get -y install cuda-drivers -q
}
if [ ! -f "/etc/os-release" ]; then
error "Unknown distribution. Skipping CUDA installation."
fi
. /etc/os-release
OS_NAME=$ID
OS_VERSION=$VERSION_ID
PACKAGE_MANAGER=
for PACKAGE_MANAGER in dnf yum apt-get; do
if available $PACKAGE_MANAGER; then
break
fi
done
if [ -z "$PACKAGE_MANAGER" ]; then
error "Unknown package manager. Skipping CUDA installation."
fi
if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
case $OS_NAME in
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
fedora) [ $OS_VERSION -lt '39' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '39';;
amzn) install_cuda_driver_yum 'fedora' '37' ;;
debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;;
ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;;
*) exit ;;
esac
fi
if ! lsmod | grep -q nvidia || ! lsmod | grep -q nvidia_uvm; then
KERNEL_RELEASE="$(uname -r)"
case $OS_NAME in
rocky) $SUDO $PACKAGE_MANAGER -y install kernel-devel kernel-headers ;;
centos|rhel|amzn) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE kernel-headers-$KERNEL_RELEASE ;;
fedora) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE ;;
debian|ubuntu) $SUDO apt-get -y install linux-headers-$KERNEL_RELEASE ;;
*) exit ;;
esac
NVIDIA_CUDA_VERSION=$($SUDO dkms status | awk -F: '/added/ { print $1 }')
if [ -n "$NVIDIA_CUDA_VERSION" ]; then
$SUDO dkms install $NVIDIA_CUDA_VERSION
fi
if lsmod | grep -q nouveau; then
status 'Reboot to complete NVIDIA CUDA driver install.'
exit 0
fi
$SUDO modprobe nvidia
$SUDO modprobe nvidia_uvm
fi
# make sure the NVIDIA modules are loaded on boot with nvidia-persistenced
if command -v nvidia-persistenced > /dev/null 2>&1; then
$SUDO touch /etc/modules-load.d/nvidia.conf
MODULES="nvidia nvidia-uvm"
for MODULE in $MODULES; do
if ! grep -qxF "$MODULE" /etc/modules-load.d/nvidia.conf; then
echo "$MODULE" | sudo tee -a /etc/modules-load.d/nvidia.conf > /dev/null
fi
done
fi
status "NVIDIA GPU ready."
install_success


@@ -14,24 +14,18 @@ from llama_models.llama3_1.api.datatypes import (
     StopReason,
     SystemMessage,
 )
-from llama_toolchain.inference.api.config import (
-    InferenceConfig,
-    InlineImplConfig,
-    RemoteImplConfig,
-    ModelCheckpointConfig,
-    PytorchCheckpoint,
-    CheckpointQuantizationFormat,
-)
-from llama_toolchain.inference.api_instance import (
-    get_inference_api_instance,
-)
 from llama_toolchain.inference.api.datatypes import (
     ChatCompletionResponseEventType,
 )
+from llama_toolchain.inference.meta_reference.inference import get_provider_impl
+from llama_toolchain.inference.meta_reference.config import (
+    MetaReferenceImplConfig,
+)
 from llama_toolchain.inference.api.endpoints import ChatCompletionRequest

+MODEL = "Meta-Llama3.1-8B-Instruct"
+
 HELPER_MSG = """
 This test needs llama-3.1-8b-instruct models.
 Please download using the llama cli
@@ -50,32 +44,18 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
     @classmethod
     async def asyncSetUpClass(cls):
         # assert model exists on local
-        model_dir = os.path.expanduser(
-            "~/.llama/checkpoints/Meta-Llama-3.1-8B-Instruct/original/"
-        )
+        model_dir = os.path.expanduser(f"~/.llama/checkpoints/{MODEL}/original/")
         assert os.path.isdir(model_dir), HELPER_MSG

         tokenizer_path = os.path.join(model_dir, "tokenizer.model")
         assert os.path.exists(tokenizer_path), HELPER_MSG

-        inline_config = InlineImplConfig(
-            checkpoint_config=ModelCheckpointConfig(
-                checkpoint=PytorchCheckpoint(
-                    checkpoint_dir=model_dir,
-                    tokenizer_path=tokenizer_path,
-                    model_parallel_size=1,
-                    quantization_format=CheckpointQuantizationFormat.bf16,
-                )
-            ),
+        config = MetaReferenceImplConfig(
+            model=MODEL,
             max_seq_len=2048,
         )
-        inference_config = InferenceConfig(impl_config=inline_config)
-
-        # -- For faster testing iteration --
-        # remote_config = RemoteImplConfig(url="http://localhost:5000")
-        # inference_config = InferenceConfig(impl_config=remote_config)
-
-        cls.api = await get_inference_api_instance(inference_config)
+        cls.api = await get_provider_impl(config, {})
         await cls.api.initialize()

         current_date = datetime.now()

@@ -134,7 +114,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
         await cls.api.shutdown()

     async def asyncSetUp(self):
-        self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
+        self.valid_supported_model = MODEL

     async def test_text(self):
         request = ChatCompletionRequest(


@@ -10,14 +10,12 @@ from llama_models.llama3_1.api.datatypes import (
     SamplingStrategy,
     SystemMessage,
 )
-from llama_toolchain.inference.api_instance import (
-    get_inference_api_instance,
-)
 from llama_toolchain.inference.api.datatypes import (
     ChatCompletionResponseEventType,
 )
 from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
-from llama_toolchain.inference.api.config import InferenceConfig, OllamaImplConfig
+from llama_toolchain.inference.ollama.config import OllamaImplConfig
+from llama_toolchain.inference.ollama.ollama import get_provider_impl


 class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

@@ -30,9 +28,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         )

         # setup ollama
-        self.api = await get_inference_api_instance(
-            InferenceConfig(impl_config=ollama_config)
-        )
+        self.api = await get_provider_impl(ollama_config)
         await self.api.initialize()

         current_date = datetime.now()