get ollama working

Hardik Shah 2024-08-07 17:52:49 -07:00
parent ea50086190
commit 171a178783
9 changed files with 151 additions and 375 deletions


@@ -10,6 +10,7 @@ set -euo pipefail
 # Define color codes
 RED='\033[0;31m'
+GREEN='\033[0;32m'
 NC='\033[0m' # No Color

 error_handler() {
@@ -78,6 +79,8 @@ pip_dependencies="$3"
 ensure_conda_env_python310 "$env_name" "$pip_dependencies"

+echo -e "${GREEN}Successfully setup distribution environment. Starting to configure ....${NC}"
+
 eval "$(conda shell.bash hook)"
 conda deactivate && conda activate "$env_name"


@@ -41,7 +41,10 @@ def model_checkpoint_dir(model) -> str:
     if not Path(checkpoint_dir / "consolidated.00.pth").exists():
         checkpoint_dir = checkpoint_dir / "original"
-    assert checkpoint_dir.exists(), f"Could not find checkpoint dir: {checkpoint_dir}"
+    assert checkpoint_dir.exists(), (
+        f"Could not find checkpoint dir: {checkpoint_dir}."
+        f"Please download model using `llama download {model.descriptor()}`"
+    )
     return str(checkpoint_dir)


@@ -10,5 +10,7 @@ from strong_typing.schema import json_schema_type
 @json_schema_type
 class OllamaImplConfig(BaseModel):
-    model: str = Field(..., description="The name of the model in ollama catalog")
-    url: str = Field(..., description="The URL for the ollama server")
+    url: str = Field(
+        default="http://localhost:11434",
+        description="The URL for the ollama server",
+    )
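With the `model` field dropped and a default URL added, the provider config can now be built with no arguments. A self-contained sketch of the resulting behavior (the model class is restated here rather than imported; the remote host below is made up for illustration):

    from pydantic import BaseModel, Field

    class OllamaImplConfig(BaseModel):
        url: str = Field(
            default="http://localhost:11434",
            description="The URL for the ollama server",
        )

    print(OllamaImplConfig().url)  # -> http://localhost:11434
    print(OllamaImplConfig(url="http://gpu-box:11434").url)  # explicit override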


@@ -5,10 +5,10 @@
 # the root directory of this source tree.
 import uuid
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Dict

 import httpx

 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
     CompletionMessage,
@@ -17,11 +17,8 @@ from llama_models.llama3_1.api.datatypes import (
     ToolCall,
 )
 from llama_models.llama3_1.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
-from ollama import AsyncClient
+from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -33,18 +30,21 @@ from llama_toolchain.inference.api import (
     ToolCallDelta,
     ToolCallParseStatus,
 )
+from ollama import AsyncClient

 from .config import OllamaImplConfig

 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
 OLLAMA_SUPPORTED_SKUS = {
-    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"
-    # TODO: Add other variants for llama3.1
+    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
+    "Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
 }

-async def get_provider_impl(config: OllamaImplConfig) -> Inference:
+async def get_provider_impl(
+    config: OllamaImplConfig, _deps: Dict[Api, ProviderSpec]
+) -> Inference:
     assert isinstance(
         config, OllamaImplConfig
     ), f"Unexpected config type: {type(config)}"
@@ -57,15 +57,14 @@ class OllamaInference(Inference):
     def __init__(self, config: OllamaImplConfig) -> None:
         self.config = config
-        self.model = config.model

+    @property
+    def client(self) -> AsyncClient:
+        return AsyncClient(host=self.config.url)

     async def initialize(self) -> None:
-        self.client = AsyncClient(host=self.config.url)
         try:
-            status = await self.client.pull(self.model)
-            assert (
-                status["status"] == "success"
-            ), f"Failed to pull model {self.model} in ollama"
+            await self.client.ps()
         except httpx.ConnectError:
             print(
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
@@ -81,7 +80,11 @@ class OllamaInference(Inference):
     def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
         ollama_messages = []
         for message in messages:
-            ollama_messages.append({"role": message.role, "content": message.content})
+            if message.role == "ipython":
+                role = "tool"
+            else:
+                role = message.role
+            ollama_messages.append({"role": role, "content": message.content})

         return ollama_messages
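Ollama's chat API has no `ipython` role, so tool-response messages are relabeled as `tool` before being sent. A tiny self-contained illustration of the mapping:

    def to_ollama_role(role: str) -> str:
        # llama3.1 marks tool output with the "ipython" role; ollama expects "tool"
        return "tool" if role == "ipython" else role

    assert to_ollama_role("ipython") == "tool"
    assert to_ollama_role("assistant") == "assistant"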
@@ -112,6 +115,21 @@ class OllamaInference(Inference):
         # accumulate sampling params and other options to pass to ollama
         options = self.get_ollama_chat_options(request)
         ollama_model = self.resolve_ollama_model(request.model)
+
+        res = await self.client.ps()
+        need_model_pull = True
+        for r in res["models"]:
+            if ollama_model == r["model"]:
+                need_model_pull = False
+                break
+
+        if need_model_pull:
+            print(f"Pulling model: {ollama_model}")
+            status = await self.client.pull(ollama_model)
+            assert (
+                status["status"] == "success"
+            ), f"Failed to pull model {self.model} in ollama"
+
         if not request.stream:
             r = await self.client.chat(
                 model=ollama_model,
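Model pulling now happens on demand inside `chat_completion`: the server's model list is checked first and `pull()` runs only when the requested tag is missing. A condensed sketch of that check-then-pull flow as a free function (the function name and example tag are made up; the `ps()`/`pull()` result fields follow the hunk above):

    from ollama import AsyncClient

    async def ensure_model(client: AsyncClient, ollama_model: str) -> None:
        # ask the server which models it already has
        res = await client.ps()
        present = any(m["model"] == ollama_model for m in res["models"])

        if not present:
            print(f"Pulling model: {ollama_model}")
            status = await client.pull(ollama_model)
            assert status["status"] == "success", (
                f"Failed to pull model {ollama_model} in ollama"
            )

    # e.g. await ensure_model(AsyncClient(host="http://localhost:11434"), "llama3.1:8b-instruct-fp16")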
@@ -141,7 +159,6 @@ class OllamaInference(Inference):
                 delta="",
             )
         )
-
         stream = await self.client.chat(
             model=ollama_model,
             messages=self._messages_to_ollama_messages(request.messages),
@@ -154,11 +171,10 @@ class OllamaInference(Inference):
         stop_reason = None
         async for chunk in stream:
-            # check if ollama is done
             if chunk["done"]:
-                if chunk["done_reason"] == "stop":
+                if stop_reason is None and chunk["done_reason"] == "stop":
                     stop_reason = StopReason.end_of_turn
-                elif chunk["done_reason"] == "length":
+                elif stop_reason is None and chunk["done_reason"] == "length":
                     stop_reason = StopReason.out_of_tokens
                 break
@@ -176,7 +192,7 @@ class OllamaInference(Inference):
                         ),
                     )
                 )
-                buffer = buffer[len("<|python_tag|>") :]
+                buffer += text
                 continue

             if ipython:
@@ -214,7 +230,6 @@ class OllamaInference(Inference):
         # parse tool calls and report errors
         message = decode_assistant_message_from_content(buffer, stop_reason)
         parsed_tool_calls = len(message.tool_calls) > 0
-
         if ipython and not parsed_tool_calls:
             yield ChatCompletionResponseStreamChunk(


@@ -10,14 +10,14 @@ from pydantic import BaseModel
 class LlamaGuardShieldConfig(BaseModel):
-    model: str
-    excluded_categories: List[str]
+    model: str = "Llama-Guard-3-8B"
+    excluded_categories: List[str] = []
     disable_input_check: bool = False
     disable_output_check: bool = False

 class PromptGuardShieldConfig(BaseModel):
-    model: str
+    model: str = "Prompt-Guard-86M"

 class SafetyConfig(BaseModel):
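With defaults on every field, both shield configs can now be constructed without arguments. A small self-contained sketch of that behavior (restating the model above rather than importing it, since the module path is not shown in this diff):

    from typing import List

    from pydantic import BaseModel

    class LlamaGuardShieldConfig(BaseModel):
        model: str = "Llama-Guard-3-8B"
        excluded_categories: List[str] = []
        disable_input_check: bool = False
        disable_output_check: bool = False

    # No arguments needed; pydantic copies the mutable [] default per instance
    print(LlamaGuardShieldConfig().model)  # -> Llama-Guard-3-8B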


@@ -1,340 +0,0 @@
#!/bin/sh
# This script installs Ollama on Linux.
# It detects the current operating system architecture and installs the appropriate version of Ollama.
set -eu
status() { echo ">>> $*" >&2; }
error() { echo "ERROR $*"; exit 1; }
warning() { echo "WARNING: $*"; }
TEMP_DIR=$(mktemp -d)
cleanup() { rm -rf $TEMP_DIR; }
trap cleanup EXIT
available() { command -v $1 >/dev/null; }
require() {
local MISSING=''
for TOOL in $*; do
if ! available $TOOL; then
MISSING="$MISSING $TOOL"
fi
done
echo $MISSING
}
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
aarch64|arm64) ARCH="arm64" ;;
*) error "Unsupported architecture: $ARCH" ;;
esac
IS_WSL2=false
KERN=$(uname -r)
case "$KERN" in
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
*icrosoft) error "Microsoft WSL1 is not currently supported. Please upgrade to WSL2 with 'wsl --set-version <distro> 2'" ;;
*) ;;
esac
VER_PARAM="${OLLAMA_VERSION:+?version=$OLLAMA_VERSION}"
SUDO=
if [ "$(id -u)" -ne 0 ]; then
# Running as root, no need for sudo
if ! available sudo; then
error "This script requires superuser permissions. Please re-run as root."
fi
SUDO="sudo"
fi
NEEDS=$(require curl awk grep sed tee xargs)
if [ -n "$NEEDS" ]; then
status "ERROR: The following tools are required but missing:"
for NEED in $NEEDS; do
echo " - $NEED"
done
exit 1
fi
status "Downloading ollama..."
curl --fail --show-error --location --progress-bar -o $TEMP_DIR/ollama "https://ollama.com/download/ollama-linux-${ARCH}${VER_PARAM}"
for BINDIR in /usr/local/bin /usr/bin /bin; do
echo $PATH | grep -q $BINDIR && break || continue
done
status "Installing ollama to $BINDIR..."
$SUDO install -o0 -g0 -m755 -d $BINDIR
$SUDO install -o0 -g0 -m755 $TEMP_DIR/ollama $BINDIR/ollama
install_success() {
status 'The Ollama API is now available at 127.0.0.1:11434.'
status 'Install complete. Run "ollama" from the command line.'
}
trap install_success EXIT
# Everything from this point onwards is optional.
configure_systemd() {
if ! id ollama >/dev/null 2>&1; then
status "Creating ollama user..."
$SUDO useradd -r -s /bin/false -U -m -d /usr/share/ollama ollama
fi
if getent group render >/dev/null 2>&1; then
status "Adding ollama user to render group..."
$SUDO usermod -a -G render ollama
fi
if getent group video >/dev/null 2>&1; then
status "Adding ollama user to video group..."
$SUDO usermod -a -G video ollama
fi
status "Adding current user to ollama group..."
$SUDO usermod -a -G ollama $(whoami)
status "Creating ollama systemd service..."
cat <<EOF | $SUDO tee /etc/systemd/system/ollama.service >/dev/null
[Unit]
Description=Ollama Service
After=network-online.target
[Service]
ExecStart=$BINDIR/ollama serve
User=ollama
Group=ollama
Restart=always
RestartSec=3
Environment="PATH=$PATH"
[Install]
WantedBy=default.target
EOF
SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
case $SYSTEMCTL_RUNNING in
running|degraded)
status "Enabling and starting ollama service..."
$SUDO systemctl daemon-reload
$SUDO systemctl enable ollama
start_service() { $SUDO systemctl restart ollama; }
trap start_service EXIT
;;
esac
}
if available systemctl; then
configure_systemd
fi
# WSL2 only supports GPUs via nvidia passthrough
# so check for nvidia-smi to determine if GPU is available
if [ "$IS_WSL2" = true ]; then
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
status "Nvidia GPU detected."
fi
install_success
exit 0
fi
# Install GPU dependencies on Linux
if ! available lspci && ! available lshw; then
warning "Unable to detect NVIDIA/AMD GPU. Install lspci or lshw to automatically detect and install GPU dependencies."
exit 0
fi
check_gpu() {
# Look for devices based on vendor ID for NVIDIA and AMD
case $1 in
lspci)
case $2 in
nvidia) available lspci && lspci -d '10de:' | grep -q 'NVIDIA' || return 1 ;;
amdgpu) available lspci && lspci -d '1002:' | grep -q 'AMD' || return 1 ;;
esac ;;
lshw)
case $2 in
nvidia) available lshw && $SUDO lshw -c display -numeric -disable network | grep -q 'vendor: .* \[10DE\]' || return 1 ;;
amdgpu) available lshw && $SUDO lshw -c display -numeric -disable network | grep -q 'vendor: .* \[1002\]' || return 1 ;;
esac ;;
nvidia-smi) available nvidia-smi || return 1 ;;
esac
}
if check_gpu nvidia-smi; then
status "NVIDIA GPU installed."
exit 0
fi
if ! check_gpu lspci nvidia && ! check_gpu lshw nvidia && ! check_gpu lspci amdgpu && ! check_gpu lshw amdgpu; then
install_success
warning "No NVIDIA/AMD GPU detected. Ollama will run in CPU-only mode."
exit 0
fi
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
# Look for pre-existing ROCm v6 before downloading the dependencies
for search in "${HIP_PATH:-''}" "${ROCM_PATH:-''}" "/opt/rocm" "/usr/lib64"; do
if [ -n "${search}" ] && [ -e "${search}/libhipblas.so.2" -o -e "${search}/lib/libhipblas.so.2" ]; then
status "Compatible AMD GPU ROCm library detected at ${search}"
install_success
exit 0
fi
done
status "Downloading AMD GPU dependencies..."
$SUDO rm -rf /usr/share/ollama/lib
$SUDO chmod o+x /usr/share/ollama
$SUDO install -o ollama -g ollama -m 755 -d /usr/share/ollama/lib/rocm
curl --fail --show-error --location --progress-bar "https://ollama.com/download/ollama-linux-amd64-rocm.tgz${VER_PARAM}" \
| $SUDO tar zx --owner ollama --group ollama -C /usr/share/ollama/lib/rocm .
install_success
status "AMD GPU ready."
exit 0
fi
CUDA_REPO_ERR_MSG="NVIDIA GPU detected, but your OS and Architecture are not supported by NVIDIA. Please install the CUDA driver manually https://docs.nvidia.com/cuda/cuda-installation-guide-linux/"
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-7-centos-7
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-8-rocky-8
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-9-rocky-9
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#fedora
install_cuda_driver_yum() {
status 'Installing NVIDIA repository...'
case $PACKAGE_MANAGER in
yum)
$SUDO $PACKAGE_MANAGER -y install yum-utils
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
$SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
else
error $CUDA_REPO_ERR_MSG
fi
;;
dnf)
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
$SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
else
error $CUDA_REPO_ERR_MSG
fi
;;
esac
case $1 in
rhel)
status 'Installing EPEL repository...'
# EPEL is required for third-party dependencies such as dkms and libvdpau
$SUDO $PACKAGE_MANAGER -y install https://dl.fedoraproject.org/pub/epel/epel-release-latest-$2.noarch.rpm || true
;;
esac
status 'Installing CUDA driver...'
if [ "$1" = 'centos' ] || [ "$1$2" = 'rhel7' ]; then
$SUDO $PACKAGE_MANAGER -y install nvidia-driver-latest-dkms
fi
$SUDO $PACKAGE_MANAGER -y install cuda-drivers
}
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#ubuntu
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian
install_cuda_driver_apt() {
status 'Installing NVIDIA repository...'
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb
else
error $CUDA_REPO_ERR_MSG
fi
case $1 in
debian)
status 'Enabling contrib sources...'
$SUDO sed 's/main/contrib/' < /etc/apt/sources.list | $SUDO tee /etc/apt/sources.list.d/contrib.list > /dev/null
if [ -f "/etc/apt/sources.list.d/debian.sources" ]; then
$SUDO sed 's/main/contrib/' < /etc/apt/sources.list.d/debian.sources | $SUDO tee /etc/apt/sources.list.d/contrib.sources > /dev/null
fi
;;
esac
status 'Installing CUDA driver...'
$SUDO dpkg -i $TEMP_DIR/cuda-keyring.deb
$SUDO apt-get update
[ -n "$SUDO" ] && SUDO_E="$SUDO -E" || SUDO_E=
DEBIAN_FRONTEND=noninteractive $SUDO_E apt-get -y install cuda-drivers -q
}
if [ ! -f "/etc/os-release" ]; then
error "Unknown distribution. Skipping CUDA installation."
fi
. /etc/os-release
OS_NAME=$ID
OS_VERSION=$VERSION_ID
PACKAGE_MANAGER=
for PACKAGE_MANAGER in dnf yum apt-get; do
if available $PACKAGE_MANAGER; then
break
fi
done
if [ -z "$PACKAGE_MANAGER" ]; then
error "Unknown package manager. Skipping CUDA installation."
fi
if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
case $OS_NAME in
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
fedora) [ $OS_VERSION -lt '39' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '39';;
amzn) install_cuda_driver_yum 'fedora' '37' ;;
debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;;
ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;;
*) exit ;;
esac
fi
if ! lsmod | grep -q nvidia || ! lsmod | grep -q nvidia_uvm; then
KERNEL_RELEASE="$(uname -r)"
case $OS_NAME in
rocky) $SUDO $PACKAGE_MANAGER -y install kernel-devel kernel-headers ;;
centos|rhel|amzn) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE kernel-headers-$KERNEL_RELEASE ;;
fedora) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE ;;
debian|ubuntu) $SUDO apt-get -y install linux-headers-$KERNEL_RELEASE ;;
*) exit ;;
esac
NVIDIA_CUDA_VERSION=$($SUDO dkms status | awk -F: '/added/ { print $1 }')
if [ -n "$NVIDIA_CUDA_VERSION" ]; then
$SUDO dkms install $NVIDIA_CUDA_VERSION
fi
if lsmod | grep -q nouveau; then
status 'Reboot to complete NVIDIA CUDA driver install.'
exit 0
fi
$SUDO modprobe nvidia
$SUDO modprobe nvidia_uvm
fi
# make sure the NVIDIA modules are loaded on boot with nvidia-persistenced
if command -v nvidia-persistenced > /dev/null 2>&1; then
$SUDO touch /etc/modules-load.d/nvidia.conf
MODULES="nvidia nvidia-uvm"
for MODULE in $MODULES; do
if ! grep -qxF "$MODULE" /etc/modules-load.d/nvidia.conf; then
echo "$MODULE" | sudo tee -a /etc/modules-load.d/nvidia.conf > /dev/null
fi
done
fi
status "NVIDIA GPU ready."
install_success


@@ -6,6 +6,7 @@ flake8
 httpx
 huggingface-hub
 json-strong-typing
+llama-models
 omegaconf
 pre-commit
 pydantic==1.10.13


@@ -13,6 +13,7 @@ from llama_models.llama3_1.api.datatypes import (
     UserMessage,
     StopReason,
     SystemMessage,
+    ToolResponseMessage,
 )
 from llama_toolchain.inference.api.datatypes import (
     ChatCompletionResponseEventType,
@@ -256,3 +257,33 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
         )
         self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
         self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point")
+
+    async def test_multi_turn(self):
+        request = ChatCompletionRequest(
+            model=self.valid_supported_model,
+            messages=[
+                self.system_prompt,
+                UserMessage(
+                    content="Search the web and tell me who the "
+                    "44th president of the United States was",
+                ),
+                ToolResponseMessage(
+                    call_id="1",
+                    tool_name=BuiltinTool.brave_search,
+                    # content='{"query": "44th president of the United States", "top_k": [{"title": "Barack Obama | The White House", "url": "https://www.whitehouse.gov/about-the-white-house/presidents/barack-obama/", "description": "<strong>Barack Obama</strong> served as the 44th President of the United States. His story is the American story \\u2014 values from the heartland, a middle-class upbringing in a strong family, hard work and education as the means of getting ahead, and the conviction that a life so blessed should be lived in service ...", "type": "search_result"}, {"title": "Barack Obama \\u2013 The White House", "url": "https://trumpwhitehouse.archives.gov/about-the-white-house/presidents/barack-obama/", "description": "After working his way through college with the help of scholarships and student loans, <strong>President Obama</strong> moved to Chicago, where he worked with a group of churches to help rebuild communities devastated by the closure of local steel plants.", "type": "search_result"}, [{"type": "video_result", "url": "https://www.instagram.com/reel/CzMZbJmObn9/", "title": "Fifteen years ago, on Nov. 4, Barack Obama was elected as ...", "description": ""}, {"type": "video_result", "url": "https://video.alexanderstreet.com/watch/the-44th-president-barack-obama?context=channel:barack-obama", "title": "The 44th President (Barack Obama) - Alexander Street, a ...", "description": "You need to enable JavaScript to run this app"}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=iyL7_2-em5k", "title": "Barack Obama for Kids | Learn about the life and contributions ...", "description": "Enjoy the videos and music you love, upload original content, and share it all with friends, family, and the world on YouTube."}, {"type": "video_result", "url": "https://www.britannica.com/video/172743/overview-Barack-Obama", "title": "President of the United States of America Barack Obama | Britannica", "description": "[NARRATOR] Barack Obama was elected the 44th president of the United States in 2008, becoming the first African American to hold the office. Obama vowed to bring change to the political system."}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=rvr2g8-5dcE", "title": "The 44th President: In His Own Words - Toughest Day | Special ...", "description": "President Obama reflects on his toughest day in the Presidency and seeing Secret Service cry for the first time. Watch the premiere of The 44th President: In..."}]]}',
+                    content='"Barack Obama"',
+                ),
+            ],
+            stream=True,
+        )
+        iterator = self.api.chat_completion(request)
+
+        events = []
+        async for chunk in iterator:
+            events.append(chunk.event)
+
+        response = ""
+        for e in events[1:-1]:
+            response += e.delta
+
+        self.assertTrue("obama" in response.lower())


@@ -9,6 +9,7 @@ from llama_models.llama3_1.api.datatypes import (
     SamplingParams,
     SamplingStrategy,
     SystemMessage,
+    ToolResponseMessage,
 )
 from llama_toolchain.inference.api.datatypes import (
     ChatCompletionResponseEventType,
@@ -21,14 +22,10 @@ from llama_toolchain.inference.ollama.ollama import get_provider_impl
 class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
     async def asyncSetUp(self):
-        self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
-        ollama_config = OllamaImplConfig(
-            model="llama3.1:8b-instruct-fp16",
-            url="http://localhost:11434",
-        )
+        ollama_config = OllamaImplConfig(url="http://localhost:11434")

         # setup ollama
-        self.api = await get_provider_impl(ollama_config)
+        self.api = await get_provider_impl(ollama_config, {})
         await self.api.initialize()

         current_date = datetime.now()
@@ -245,7 +242,6 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         iterator = self.api.chat_completion(request)
         events = []
         async for chunk in iterator:
-            # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
             events.append(chunk.event)

         self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
@@ -253,6 +249,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         self.assertEqual(
             events[-1].event_type, ChatCompletionResponseEventType.complete
         )
+        # last but one event should be eom with tool call
+        self.assertEqual(
+            events[-2].event_type, ChatCompletionResponseEventType.progress
+        )
+        self.assertEqual(events[-2].stop_reason, StopReason.end_of_message)
+        self.assertEqual(events[-2].delta.content.tool_name, BuiltinTool.brave_search)

     async def test_custom_tool_call_streaming(self):
         request = ChatCompletionRequest(
@@ -317,3 +319,62 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
                 "top_p": 0.99,
             },
         )
+
+    async def test_multi_turn(self):
+        request = ChatCompletionRequest(
+            model=self.valid_supported_model,
+            messages=[
+                self.system_prompt,
+                UserMessage(
+                    content="Search the web and tell me who the "
+                    "44th president of the United States was",
+                ),
+                ToolResponseMessage(
+                    call_id="1",
+                    tool_name=BuiltinTool.brave_search,
+                    content='{"query": "44th president of the United States", "top_k": [{"title": "Barack Obama | The White House", "url": "https://www.whitehouse.gov/about-the-white-house/presidents/barack-obama/", "description": "<strong>Barack Obama</strong> served as the 44th President of the United States. His story is the American story \\u2014 values from the heartland, a middle-class upbringing in a strong family, hard work and education as the means of getting ahead, and the conviction that a life so blessed should be lived in service ...", "type": "search_result"}, {"title": "Barack Obama \\u2013 The White House", "url": "https://trumpwhitehouse.archives.gov/about-the-white-house/presidents/barack-obama/", "description": "After working his way through college with the help of scholarships and student loans, <strong>President Obama</strong> moved to Chicago, where he worked with a group of churches to help rebuild communities devastated by the closure of local steel plants.", "type": "search_result"}, [{"type": "video_result", "url": "https://www.instagram.com/reel/CzMZbJmObn9/", "title": "Fifteen years ago, on Nov. 4, Barack Obama was elected as ...", "description": ""}, {"type": "video_result", "url": "https://video.alexanderstreet.com/watch/the-44th-president-barack-obama?context=channel:barack-obama", "title": "The 44th President (Barack Obama) - Alexander Street, a ...", "description": "You need to enable JavaScript to run this app"}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=iyL7_2-em5k", "title": "Barack Obama for Kids | Learn about the life and contributions ...", "description": "Enjoy the videos and music you love, upload original content, and share it all with friends, family, and the world on YouTube."}, {"type": "video_result", "url": "https://www.britannica.com/video/172743/overview-Barack-Obama", "title": "President of the United States of America Barack Obama | Britannica", "description": "[NARRATOR] Barack Obama was elected the 44th president of the United States in 2008, becoming the first African American to hold the office. Obama vowed to bring change to the political system."}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=rvr2g8-5dcE", "title": "The 44th President: In His Own Words - Toughest Day | Special ...", "description": "President Obama reflects on his toughest day in the Presidency and seeing Secret Service cry for the first time. Watch the premiere of The 44th President: In..."}]]}',
+                ),
+            ],
+            stream=True,
+        )
+        iterator = self.api.chat_completion(request)
+
+        events = []
+        async for chunk in iterator:
+            events.append(chunk.event)
+
+        response = ""
+        for e in events[1:-1]:
+            response += e.delta
+
+        self.assertTrue("obama" in response.lower())
+
+    async def test_tool_call_code_streaming(self):
+        request = ChatCompletionRequest(
+            model=self.valid_supported_model,
+            messages=[
+                self.system_prompt,
+                UserMessage(
+                    content="Write code to answer this question: What is the 100th prime number?",
+                ),
+            ],
+            stream=True,
+        )
+        iterator = self.api.chat_completion(request)
+        events = []
+        async for chunk in iterator:
+            events.append(chunk.event)
+
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
+        # last event is of type "complete"
+        self.assertEqual(
+            events[-1].event_type, ChatCompletionResponseEventType.complete
+        )
+        # last but one event should be eom with tool call
+        self.assertEqual(
+            events[-2].event_type, ChatCompletionResponseEventType.progress
+        )
+        self.assertEqual(events[-2].stop_reason, StopReason.end_of_message)
+        self.assertEqual(
+            events[-2].delta.content.tool_name, BuiltinTool.code_interpreter
+        )