mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
update inference config to take model and not model_dir
This commit is contained in:
parent
08c3802f45
commit
039861f1c7
9 changed files with 400 additions and 101 deletions
|
@ -75,11 +75,13 @@ safetensors files to avoid downloading duplicate weights.
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||||
|
|
||||||
|
from llama_toolchain.common.model_utils import model_local_dir
|
||||||
|
|
||||||
repo_id = model.huggingface_repo
|
repo_id = model.huggingface_repo
|
||||||
if repo_id is None:
|
if repo_id is None:
|
||||||
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
||||||
|
|
||||||
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
|
output_dir = model_local_dir(model)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
try:
|
try:
|
||||||
true_output_dir = snapshot_download(
|
true_output_dir = snapshot_download(
|
||||||
|
@ -107,8 +109,9 @@ safetensors files to avoid downloading duplicate weights.
|
||||||
|
|
||||||
def _meta_download(self, model: "Model", meta_url: str):
|
def _meta_download(self, model: "Model", meta_url: str):
|
||||||
from llama_models.sku_list import llama_meta_net_info
|
from llama_models.sku_list import llama_meta_net_info
|
||||||
|
from llama_toolchain.common.model_utils import model_local_dir
|
||||||
|
|
||||||
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
|
output_dir = model_local_dir(model)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
info = llama_meta_net_info(model)
|
info = llama_meta_net_info(model)
|
||||||
|
|
8
llama_toolchain/common/model_utils.py
Normal file
8
llama_toolchain/common/model_utils.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
import os
|
||||||
|
from llama_models.datatypes import Model
|
||||||
|
|
||||||
|
from .config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def model_local_dir(model: Model) -> str:
|
||||||
|
return os.path.join(DEFAULT_CHECKPOINT_DIR, model.descriptor())
|
|
@ -4,61 +4,17 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from typing import Optional
|
||||||
from typing import Literal, Optional, Union
|
|
||||||
|
|
||||||
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from strong_typing.schema import json_schema_type
|
from strong_typing.schema import json_schema_type
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from llama_toolchain.inference.api import QuantizationConfig
|
from llama_toolchain.inference.api import QuantizationConfig
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class CheckpointType(Enum):
|
|
||||||
pytorch = "pytorch"
|
|
||||||
huggingface = "huggingface"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class PytorchCheckpoint(BaseModel):
|
|
||||||
checkpoint_type: Literal[CheckpointType.pytorch.value] = (
|
|
||||||
CheckpointType.pytorch.value
|
|
||||||
)
|
|
||||||
checkpoint_dir: str
|
|
||||||
tokenizer_path: str
|
|
||||||
model_parallel_size: int
|
|
||||||
quantization_format: CheckpointQuantizationFormat = (
|
|
||||||
CheckpointQuantizationFormat.bf16
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class HuggingFaceCheckpoint(BaseModel):
|
|
||||||
checkpoint_type: Literal[CheckpointType.huggingface.value] = (
|
|
||||||
CheckpointType.huggingface.value
|
|
||||||
)
|
|
||||||
repo_id: str # or model_name ?
|
|
||||||
model_parallel_size: int
|
|
||||||
quantization_format: CheckpointQuantizationFormat = (
|
|
||||||
CheckpointQuantizationFormat.bf16
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ModelCheckpointConfig(BaseModel):
|
|
||||||
checkpoint: Annotated[
|
|
||||||
Union[PytorchCheckpoint, HuggingFaceCheckpoint],
|
|
||||||
Field(discriminator="checkpoint_type"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class MetaReferenceImplConfig(BaseModel):
|
class MetaReferenceImplConfig(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
checkpoint_config: ModelCheckpointConfig
|
|
||||||
quantization: Optional[QuantizationConfig] = None
|
quantization: Optional[QuantizationConfig] = None
|
||||||
torch_seed: Optional[int] = None
|
torch_seed: Optional[int] = None
|
||||||
max_seq_len: int
|
max_seq_len: int
|
||||||
|
|
|
@ -27,11 +27,22 @@ from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
|
||||||
from llama_models.llama3_1.api.datatypes import Message
|
from llama_models.llama3_1.api.datatypes import Message
|
||||||
from llama_models.llama3_1.api.model import Transformer
|
from llama_models.llama3_1.api.model import Transformer
|
||||||
from llama_models.llama3_1.api.tokenizer import Tokenizer
|
from llama_models.llama3_1.api.tokenizer import Tokenizer
|
||||||
|
from llama_models.sku_list import resolve_model
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
|
from llama_toolchain.common.model_utils import model_local_dir
|
||||||
from llama_toolchain.inference.api import QuantizationType
|
from llama_toolchain.inference.api import QuantizationType
|
||||||
|
|
||||||
from .config import CheckpointType, MetaReferenceImplConfig
|
from .config import MetaReferenceImplConfig
|
||||||
|
|
||||||
|
|
||||||
|
def model_checkpoint_dir(model) -> str:
|
||||||
|
checkpoint_dir = Path(model_local_dir(model))
|
||||||
|
if not Path(checkpoint_dir / "consolidated.00.pth").exists():
|
||||||
|
checkpoint_dir = checkpoint_dir / "original"
|
||||||
|
|
||||||
|
assert checkpoint_dir.exists(), f"Could not find checkpoint dir: {checkpoint_dir}"
|
||||||
|
return str(checkpoint_dir)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -51,9 +62,7 @@ class Llama:
|
||||||
This method initializes the distributed process group, sets the device to CUDA,
|
This method initializes the distributed process group, sets the device to CUDA,
|
||||||
and loads the pre-trained model and tokenizer.
|
and loads the pre-trained model and tokenizer.
|
||||||
"""
|
"""
|
||||||
checkpoint = config.checkpoint_config.checkpoint
|
model = resolve_model(config.model)
|
||||||
if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
|
|
||||||
raise NotImplementedError("HuggingFace checkpoints not supported yet")
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
config.quantization
|
config.quantization
|
||||||
|
@ -67,7 +76,7 @@ class Llama:
|
||||||
if not torch.distributed.is_initialized():
|
if not torch.distributed.is_initialized():
|
||||||
torch.distributed.init_process_group("nccl")
|
torch.distributed.init_process_group("nccl")
|
||||||
|
|
||||||
model_parallel_size = checkpoint.model_parallel_size
|
model_parallel_size = model.hardware_requirements.gpu_count
|
||||||
if not model_parallel_is_initialized():
|
if not model_parallel_is_initialized():
|
||||||
initialize_model_parallel(model_parallel_size)
|
initialize_model_parallel(model_parallel_size)
|
||||||
|
|
||||||
|
@ -82,7 +91,8 @@ class Llama:
|
||||||
sys.stdout = open(os.devnull, "w")
|
sys.stdout = open(os.devnull, "w")
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
ckpt_dir = checkpoint.checkpoint_dir
|
ckpt_dir = model_checkpoint_dir(model)
|
||||||
|
|
||||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||||
assert model_parallel_size == len(
|
assert model_parallel_size == len(
|
||||||
|
@ -103,7 +113,9 @@ class Llama:
|
||||||
max_batch_size=config.max_batch_size,
|
max_batch_size=config.max_batch_size,
|
||||||
**params,
|
**params,
|
||||||
)
|
)
|
||||||
tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)
|
|
||||||
|
tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
|
||||||
|
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
model_args.vocab_size == tokenizer.n_words
|
model_args.vocab_size == tokenizer.n_words
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -12,9 +13,10 @@ from typing import Generator, List, Optional
|
||||||
from llama_models.llama3_1.api.chat_format import ChatFormat
|
from llama_models.llama3_1.api.chat_format import ChatFormat
|
||||||
from llama_models.llama3_1.api.datatypes import Message
|
from llama_models.llama3_1.api.datatypes import Message
|
||||||
from llama_models.llama3_1.api.tokenizer import Tokenizer
|
from llama_models.llama3_1.api.tokenizer import Tokenizer
|
||||||
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from .config import MetaReferenceImplConfig
|
from .config import MetaReferenceImplConfig
|
||||||
from .generation import Llama
|
from .generation import Llama, model_checkpoint_dir
|
||||||
from .parallel_utils import ModelParallelProcessGroup
|
from .parallel_utils import ModelParallelProcessGroup
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,11 +62,12 @@ class LlamaModelParallelGenerator:
|
||||||
|
|
||||||
def __init__(self, config: MetaReferenceImplConfig):
|
def __init__(self, config: MetaReferenceImplConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.model = resolve_model(self.config.model)
|
||||||
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
|
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
|
||||||
# while the tool-use loop is going
|
# while the tool-use loop is going
|
||||||
checkpoint = self.config.checkpoint_config.checkpoint
|
checkpoint_dir = model_checkpoint_dir(self.model)
|
||||||
self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))
|
tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
|
||||||
|
self.formatter = ChatFormat(Tokenizer(tokenizer_path))
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
self.__enter__()
|
self.__enter__()
|
||||||
|
@ -73,9 +76,8 @@ class LlamaModelParallelGenerator:
|
||||||
self.__exit__(None, None, None)
|
self.__exit__(None, None, None)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
checkpoint = self.config.checkpoint_config.checkpoint
|
|
||||||
self.group = ModelParallelProcessGroup(
|
self.group = ModelParallelProcessGroup(
|
||||||
checkpoint.model_parallel_size,
|
self.model.hardware_requirements.gpu_count,
|
||||||
init_model_cb=partial(init_model_cb, self.config),
|
init_model_cb=partial(init_model_cb, self.config),
|
||||||
)
|
)
|
||||||
self.group.start()
|
self.group.start()
|
||||||
|
|
|
@ -44,11 +44,13 @@ OLLAMA_SUPPORTED_SKUS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_provider_impl(config: OllamaImplConfig) -> Inference:
|
async def get_provider_impl(config: OllamaImplConfig) -> Inference:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
config, OllamaImplConfig
|
config, OllamaImplConfig
|
||||||
), f"Unexpected config type: {type(config)}"
|
), f"Unexpected config type: {type(config)}"
|
||||||
return OllamaInference(config)
|
impl = OllamaInference(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
||||||
|
|
||||||
|
|
||||||
class OllamaInference(Inference):
|
class OllamaInference(Inference):
|
||||||
|
|
340
ollama_install.sh
Normal file
340
ollama_install.sh
Normal file
|
@ -0,0 +1,340 @@
|
||||||
|
#!/bin/sh
|
||||||
|
# This script installs Ollama on Linux.
|
||||||
|
# It detects the current operating system architecture and installs the appropriate version of Ollama.
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
status() { echo ">>> $*" >&2; }
|
||||||
|
error() { echo "ERROR $*"; exit 1; }
|
||||||
|
warning() { echo "WARNING: $*"; }
|
||||||
|
|
||||||
|
TEMP_DIR=$(mktemp -d)
|
||||||
|
cleanup() { rm -rf $TEMP_DIR; }
|
||||||
|
trap cleanup EXIT
|
||||||
|
|
||||||
|
available() { command -v $1 >/dev/null; }
|
||||||
|
require() {
|
||||||
|
local MISSING=''
|
||||||
|
for TOOL in $*; do
|
||||||
|
if ! available $TOOL; then
|
||||||
|
MISSING="$MISSING $TOOL"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
echo $MISSING
|
||||||
|
}
|
||||||
|
|
||||||
|
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
|
||||||
|
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
case "$ARCH" in
|
||||||
|
x86_64) ARCH="amd64" ;;
|
||||||
|
aarch64|arm64) ARCH="arm64" ;;
|
||||||
|
*) error "Unsupported architecture: $ARCH" ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
IS_WSL2=false
|
||||||
|
|
||||||
|
KERN=$(uname -r)
|
||||||
|
case "$KERN" in
|
||||||
|
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
|
||||||
|
*icrosoft) error "Microsoft WSL1 is not currently supported. Please upgrade to WSL2 with 'wsl --set-version <distro> 2'" ;;
|
||||||
|
*) ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
VER_PARAM="${OLLAMA_VERSION:+?version=$OLLAMA_VERSION}"
|
||||||
|
|
||||||
|
SUDO=
|
||||||
|
if [ "$(id -u)" -ne 0 ]; then
|
||||||
|
# Running as root, no need for sudo
|
||||||
|
if ! available sudo; then
|
||||||
|
error "This script requires superuser permissions. Please re-run as root."
|
||||||
|
fi
|
||||||
|
|
||||||
|
SUDO="sudo"
|
||||||
|
fi
|
||||||
|
|
||||||
|
NEEDS=$(require curl awk grep sed tee xargs)
|
||||||
|
if [ -n "$NEEDS" ]; then
|
||||||
|
status "ERROR: The following tools are required but missing:"
|
||||||
|
for NEED in $NEEDS; do
|
||||||
|
echo " - $NEED"
|
||||||
|
done
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
status "Downloading ollama..."
|
||||||
|
curl --fail --show-error --location --progress-bar -o $TEMP_DIR/ollama "https://ollama.com/download/ollama-linux-${ARCH}${VER_PARAM}"
|
||||||
|
|
||||||
|
for BINDIR in /usr/local/bin /usr/bin /bin; do
|
||||||
|
echo $PATH | grep -q $BINDIR && break || continue
|
||||||
|
done
|
||||||
|
|
||||||
|
status "Installing ollama to $BINDIR..."
|
||||||
|
$SUDO install -o0 -g0 -m755 -d $BINDIR
|
||||||
|
$SUDO install -o0 -g0 -m755 $TEMP_DIR/ollama $BINDIR/ollama
|
||||||
|
|
||||||
|
install_success() {
|
||||||
|
status 'The Ollama API is now available at 127.0.0.1:11434.'
|
||||||
|
status 'Install complete. Run "ollama" from the command line.'
|
||||||
|
}
|
||||||
|
trap install_success EXIT
|
||||||
|
|
||||||
|
# Everything from this point onwards is optional.
|
||||||
|
|
||||||
|
configure_systemd() {
|
||||||
|
if ! id ollama >/dev/null 2>&1; then
|
||||||
|
status "Creating ollama user..."
|
||||||
|
$SUDO useradd -r -s /bin/false -U -m -d /usr/share/ollama ollama
|
||||||
|
fi
|
||||||
|
if getent group render >/dev/null 2>&1; then
|
||||||
|
status "Adding ollama user to render group..."
|
||||||
|
$SUDO usermod -a -G render ollama
|
||||||
|
fi
|
||||||
|
if getent group video >/dev/null 2>&1; then
|
||||||
|
status "Adding ollama user to video group..."
|
||||||
|
$SUDO usermod -a -G video ollama
|
||||||
|
fi
|
||||||
|
|
||||||
|
status "Adding current user to ollama group..."
|
||||||
|
$SUDO usermod -a -G ollama $(whoami)
|
||||||
|
|
||||||
|
status "Creating ollama systemd service..."
|
||||||
|
cat <<EOF | $SUDO tee /etc/systemd/system/ollama.service >/dev/null
|
||||||
|
[Unit]
|
||||||
|
Description=Ollama Service
|
||||||
|
After=network-online.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
ExecStart=$BINDIR/ollama serve
|
||||||
|
User=ollama
|
||||||
|
Group=ollama
|
||||||
|
Restart=always
|
||||||
|
RestartSec=3
|
||||||
|
Environment="PATH=$PATH"
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=default.target
|
||||||
|
EOF
|
||||||
|
SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
|
||||||
|
case $SYSTEMCTL_RUNNING in
|
||||||
|
running|degraded)
|
||||||
|
status "Enabling and starting ollama service..."
|
||||||
|
$SUDO systemctl daemon-reload
|
||||||
|
$SUDO systemctl enable ollama
|
||||||
|
|
||||||
|
start_service() { $SUDO systemctl restart ollama; }
|
||||||
|
trap start_service EXIT
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
|
if available systemctl; then
|
||||||
|
configure_systemd
|
||||||
|
fi
|
||||||
|
|
||||||
|
# WSL2 only supports GPUs via nvidia passthrough
|
||||||
|
# so check for nvidia-smi to determine if GPU is available
|
||||||
|
if [ "$IS_WSL2" = true ]; then
|
||||||
|
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
|
||||||
|
status "Nvidia GPU detected."
|
||||||
|
fi
|
||||||
|
install_success
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install GPU dependencies on Linux
|
||||||
|
if ! available lspci && ! available lshw; then
|
||||||
|
warning "Unable to detect NVIDIA/AMD GPU. Install lspci or lshw to automatically detect and install GPU dependencies."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
check_gpu() {
|
||||||
|
# Look for devices based on vendor ID for NVIDIA and AMD
|
||||||
|
case $1 in
|
||||||
|
lspci)
|
||||||
|
case $2 in
|
||||||
|
nvidia) available lspci && lspci -d '10de:' | grep -q 'NVIDIA' || return 1 ;;
|
||||||
|
amdgpu) available lspci && lspci -d '1002:' | grep -q 'AMD' || return 1 ;;
|
||||||
|
esac ;;
|
||||||
|
lshw)
|
||||||
|
case $2 in
|
||||||
|
nvidia) available lshw && $SUDO lshw -c display -numeric -disable network | grep -q 'vendor: .* \[10DE\]' || return 1 ;;
|
||||||
|
amdgpu) available lshw && $SUDO lshw -c display -numeric -disable network | grep -q 'vendor: .* \[1002\]' || return 1 ;;
|
||||||
|
esac ;;
|
||||||
|
nvidia-smi) available nvidia-smi || return 1 ;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
|
if check_gpu nvidia-smi; then
|
||||||
|
status "NVIDIA GPU installed."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! check_gpu lspci nvidia && ! check_gpu lshw nvidia && ! check_gpu lspci amdgpu && ! check_gpu lshw amdgpu; then
|
||||||
|
install_success
|
||||||
|
warning "No NVIDIA/AMD GPU detected. Ollama will run in CPU-only mode."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
|
||||||
|
# Look for pre-existing ROCm v6 before downloading the dependencies
|
||||||
|
for search in "${HIP_PATH:-''}" "${ROCM_PATH:-''}" "/opt/rocm" "/usr/lib64"; do
|
||||||
|
if [ -n "${search}" ] && [ -e "${search}/libhipblas.so.2" -o -e "${search}/lib/libhipblas.so.2" ]; then
|
||||||
|
status "Compatible AMD GPU ROCm library detected at ${search}"
|
||||||
|
install_success
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
status "Downloading AMD GPU dependencies..."
|
||||||
|
$SUDO rm -rf /usr/share/ollama/lib
|
||||||
|
$SUDO chmod o+x /usr/share/ollama
|
||||||
|
$SUDO install -o ollama -g ollama -m 755 -d /usr/share/ollama/lib/rocm
|
||||||
|
curl --fail --show-error --location --progress-bar "https://ollama.com/download/ollama-linux-amd64-rocm.tgz${VER_PARAM}" \
|
||||||
|
| $SUDO tar zx --owner ollama --group ollama -C /usr/share/ollama/lib/rocm .
|
||||||
|
install_success
|
||||||
|
status "AMD GPU ready."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
CUDA_REPO_ERR_MSG="NVIDIA GPU detected, but your OS and Architecture are not supported by NVIDIA. Please install the CUDA driver manually https://docs.nvidia.com/cuda/cuda-installation-guide-linux/"
|
||||||
|
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-7-centos-7
|
||||||
|
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-8-rocky-8
|
||||||
|
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-9-rocky-9
|
||||||
|
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#fedora
|
||||||
|
install_cuda_driver_yum() {
|
||||||
|
status 'Installing NVIDIA repository...'
|
||||||
|
|
||||||
|
case $PACKAGE_MANAGER in
|
||||||
|
yum)
|
||||||
|
$SUDO $PACKAGE_MANAGER -y install yum-utils
|
||||||
|
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
|
||||||
|
$SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
|
||||||
|
else
|
||||||
|
error $CUDA_REPO_ERR_MSG
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
dnf)
|
||||||
|
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
|
||||||
|
$SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
|
||||||
|
else
|
||||||
|
error $CUDA_REPO_ERR_MSG
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
case $1 in
|
||||||
|
rhel)
|
||||||
|
status 'Installing EPEL repository...'
|
||||||
|
# EPEL is required for third-party dependencies such as dkms and libvdpau
|
||||||
|
$SUDO $PACKAGE_MANAGER -y install https://dl.fedoraproject.org/pub/epel/epel-release-latest-$2.noarch.rpm || true
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
status 'Installing CUDA driver...'
|
||||||
|
|
||||||
|
if [ "$1" = 'centos' ] || [ "$1$2" = 'rhel7' ]; then
|
||||||
|
$SUDO $PACKAGE_MANAGER -y install nvidia-driver-latest-dkms
|
||||||
|
fi
|
||||||
|
|
||||||
|
$SUDO $PACKAGE_MANAGER -y install cuda-drivers
|
||||||
|
}
|
||||||
|
|
||||||
|
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#ubuntu
|
||||||
|
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian
|
||||||
|
install_cuda_driver_apt() {
|
||||||
|
status 'Installing NVIDIA repository...'
|
||||||
|
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
|
||||||
|
curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb
|
||||||
|
else
|
||||||
|
error $CUDA_REPO_ERR_MSG
|
||||||
|
fi
|
||||||
|
|
||||||
|
case $1 in
|
||||||
|
debian)
|
||||||
|
status 'Enabling contrib sources...'
|
||||||
|
$SUDO sed 's/main/contrib/' < /etc/apt/sources.list | $SUDO tee /etc/apt/sources.list.d/contrib.list > /dev/null
|
||||||
|
if [ -f "/etc/apt/sources.list.d/debian.sources" ]; then
|
||||||
|
$SUDO sed 's/main/contrib/' < /etc/apt/sources.list.d/debian.sources | $SUDO tee /etc/apt/sources.list.d/contrib.sources > /dev/null
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
status 'Installing CUDA driver...'
|
||||||
|
$SUDO dpkg -i $TEMP_DIR/cuda-keyring.deb
|
||||||
|
$SUDO apt-get update
|
||||||
|
|
||||||
|
[ -n "$SUDO" ] && SUDO_E="$SUDO -E" || SUDO_E=
|
||||||
|
DEBIAN_FRONTEND=noninteractive $SUDO_E apt-get -y install cuda-drivers -q
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ ! -f "/etc/os-release" ]; then
|
||||||
|
error "Unknown distribution. Skipping CUDA installation."
|
||||||
|
fi
|
||||||
|
|
||||||
|
. /etc/os-release
|
||||||
|
|
||||||
|
OS_NAME=$ID
|
||||||
|
OS_VERSION=$VERSION_ID
|
||||||
|
|
||||||
|
PACKAGE_MANAGER=
|
||||||
|
for PACKAGE_MANAGER in dnf yum apt-get; do
|
||||||
|
if available $PACKAGE_MANAGER; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ -z "$PACKAGE_MANAGER" ]; then
|
||||||
|
error "Unknown package manager. Skipping CUDA installation."
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
|
||||||
|
case $OS_NAME in
|
||||||
|
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
|
||||||
|
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
|
||||||
|
fedora) [ $OS_VERSION -lt '39' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '39';;
|
||||||
|
amzn) install_cuda_driver_yum 'fedora' '37' ;;
|
||||||
|
debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;;
|
||||||
|
ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;;
|
||||||
|
*) exit ;;
|
||||||
|
esac
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! lsmod | grep -q nvidia || ! lsmod | grep -q nvidia_uvm; then
|
||||||
|
KERNEL_RELEASE="$(uname -r)"
|
||||||
|
case $OS_NAME in
|
||||||
|
rocky) $SUDO $PACKAGE_MANAGER -y install kernel-devel kernel-headers ;;
|
||||||
|
centos|rhel|amzn) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE kernel-headers-$KERNEL_RELEASE ;;
|
||||||
|
fedora) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE ;;
|
||||||
|
debian|ubuntu) $SUDO apt-get -y install linux-headers-$KERNEL_RELEASE ;;
|
||||||
|
*) exit ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
NVIDIA_CUDA_VERSION=$($SUDO dkms status | awk -F: '/added/ { print $1 }')
|
||||||
|
if [ -n "$NVIDIA_CUDA_VERSION" ]; then
|
||||||
|
$SUDO dkms install $NVIDIA_CUDA_VERSION
|
||||||
|
fi
|
||||||
|
|
||||||
|
if lsmod | grep -q nouveau; then
|
||||||
|
status 'Reboot to complete NVIDIA CUDA driver install.'
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
$SUDO modprobe nvidia
|
||||||
|
$SUDO modprobe nvidia_uvm
|
||||||
|
fi
|
||||||
|
|
||||||
|
# make sure the NVIDIA modules are loaded on boot with nvidia-persistenced
|
||||||
|
if command -v nvidia-persistenced > /dev/null 2>&1; then
|
||||||
|
$SUDO touch /etc/modules-load.d/nvidia.conf
|
||||||
|
MODULES="nvidia nvidia-uvm"
|
||||||
|
for MODULE in $MODULES; do
|
||||||
|
if ! grep -qxF "$MODULE" /etc/modules-load.d/nvidia.conf; then
|
||||||
|
echo "$MODULE" | sudo tee -a /etc/modules-load.d/nvidia.conf > /dev/null
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
status "NVIDIA GPU ready."
|
||||||
|
install_success
|
|
@ -14,24 +14,18 @@ from llama_models.llama3_1.api.datatypes import (
|
||||||
StopReason,
|
StopReason,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_toolchain.inference.api.config import (
|
|
||||||
InferenceConfig,
|
|
||||||
InlineImplConfig,
|
|
||||||
RemoteImplConfig,
|
|
||||||
ModelCheckpointConfig,
|
|
||||||
PytorchCheckpoint,
|
|
||||||
CheckpointQuantizationFormat,
|
|
||||||
)
|
|
||||||
from llama_toolchain.inference.api_instance import (
|
|
||||||
get_inference_api_instance,
|
|
||||||
)
|
|
||||||
from llama_toolchain.inference.api.datatypes import (
|
from llama_toolchain.inference.api.datatypes import (
|
||||||
ChatCompletionResponseEventType,
|
ChatCompletionResponseEventType,
|
||||||
)
|
)
|
||||||
|
from llama_toolchain.inference.meta_reference.inference import get_provider_impl
|
||||||
|
from llama_toolchain.inference.meta_reference.config import (
|
||||||
|
MetaReferenceImplConfig,
|
||||||
|
)
|
||||||
|
|
||||||
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
|
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
|
||||||
|
|
||||||
|
|
||||||
|
MODEL = "Meta-Llama3.1-8B-Instruct"
|
||||||
HELPER_MSG = """
|
HELPER_MSG = """
|
||||||
This test needs llama-3.1-8b-instruct models.
|
This test needs llama-3.1-8b-instruct models.
|
||||||
Please donwload using the llama cli
|
Please donwload using the llama cli
|
||||||
|
@ -50,32 +44,18 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def asyncSetUpClass(cls):
|
async def asyncSetUpClass(cls):
|
||||||
# assert model exists on local
|
# assert model exists on local
|
||||||
model_dir = os.path.expanduser(
|
model_dir = os.path.expanduser(f"~/.llama/checkpoints/{MODEL}/original/")
|
||||||
"~/.llama/checkpoints/Meta-Llama-3.1-8B-Instruct/original/"
|
|
||||||
)
|
|
||||||
assert os.path.isdir(model_dir), HELPER_MSG
|
assert os.path.isdir(model_dir), HELPER_MSG
|
||||||
|
|
||||||
tokenizer_path = os.path.join(model_dir, "tokenizer.model")
|
tokenizer_path = os.path.join(model_dir, "tokenizer.model")
|
||||||
assert os.path.exists(tokenizer_path), HELPER_MSG
|
assert os.path.exists(tokenizer_path), HELPER_MSG
|
||||||
|
|
||||||
inline_config = InlineImplConfig(
|
config = MetaReferenceImplConfig(
|
||||||
checkpoint_config=ModelCheckpointConfig(
|
model=MODEL,
|
||||||
checkpoint=PytorchCheckpoint(
|
|
||||||
checkpoint_dir=model_dir,
|
|
||||||
tokenizer_path=tokenizer_path,
|
|
||||||
model_parallel_size=1,
|
|
||||||
quantization_format=CheckpointQuantizationFormat.bf16,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
max_seq_len=2048,
|
max_seq_len=2048,
|
||||||
)
|
)
|
||||||
inference_config = InferenceConfig(impl_config=inline_config)
|
|
||||||
|
|
||||||
# -- For faster testing iteration --
|
cls.api = await get_provider_impl(config, {})
|
||||||
# remote_config = RemoteImplConfig(url="http://localhost:5000")
|
|
||||||
# inference_config = InferenceConfig(impl_config=remote_config)
|
|
||||||
|
|
||||||
cls.api = await get_inference_api_instance(inference_config)
|
|
||||||
await cls.api.initialize()
|
await cls.api.initialize()
|
||||||
|
|
||||||
current_date = datetime.now()
|
current_date = datetime.now()
|
||||||
|
@ -134,7 +114,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||||
await cls.api.shutdown()
|
await cls.api.shutdown()
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
|
self.valid_supported_model = MODEL
|
||||||
|
|
||||||
async def test_text(self):
|
async def test_text(self):
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
|
|
|
@ -10,14 +10,12 @@ from llama_models.llama3_1.api.datatypes import (
|
||||||
SamplingStrategy,
|
SamplingStrategy,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from llama_toolchain.inference.api_instance import (
|
|
||||||
get_inference_api_instance,
|
|
||||||
)
|
|
||||||
from llama_toolchain.inference.api.datatypes import (
|
from llama_toolchain.inference.api.datatypes import (
|
||||||
ChatCompletionResponseEventType,
|
ChatCompletionResponseEventType,
|
||||||
)
|
)
|
||||||
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
|
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
|
||||||
from llama_toolchain.inference.api.config import InferenceConfig, OllamaImplConfig
|
from llama_toolchain.inference.ollama.config import OllamaImplConfig
|
||||||
|
from llama_toolchain.inference.ollama.ollama import get_provider_impl
|
||||||
|
|
||||||
|
|
||||||
class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
|
class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||||
|
@ -30,9 +28,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# setup ollama
|
# setup ollama
|
||||||
self.api = await get_inference_api_instance(
|
self.api = await get_provider_impl(ollama_config)
|
||||||
InferenceConfig(impl_config=ollama_config)
|
|
||||||
)
|
|
||||||
await self.api.initialize()
|
await self.api.initialize()
|
||||||
|
|
||||||
current_date = datetime.now()
|
current_date = datetime.now()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue