get ollama working

Hardik Shah 2024-08-07 17:52:49 -07:00
parent ea50086190
commit 171a178783
9 changed files with 151 additions and 375 deletions


@@ -5,10 +5,10 @@
# the root directory of this source tree.
import uuid
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Dict
import httpx
from llama_models.llama3_1.api.datatypes import (
BuiltinTool,
CompletionMessage,
@@ -17,11 +17,8 @@ from llama_models.llama3_1.api.datatypes import (
ToolCall,
)
from llama_models.llama3_1.api.tool_utils import ToolUtils
from llama_models.sku_list import resolve_model
-from ollama import AsyncClient
+from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -33,18 +30,21 @@ from llama_toolchain.inference.api import (
ToolCallDelta,
ToolCallParseStatus,
)
+from ollama import AsyncClient
from .config import OllamaImplConfig
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
"Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"
# TODO: Add other variants for llama3.1
"Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
}
-async def get_provider_impl(config: OllamaImplConfig) -> Inference:
+async def get_provider_impl(
+    config: OllamaImplConfig, _deps: Dict[Api, ProviderSpec]
+) -> Inference:
assert isinstance(
config, OllamaImplConfig
), f"Unexpected config type: {type(config)}"
@@ -57,15 +57,14 @@ class OllamaInference(Inference):
def __init__(self, config: OllamaImplConfig) -> None:
self.config = config
self.model = config.model
+    @property
+    def client(self) -> AsyncClient:
+        return AsyncClient(host=self.config.url)
async def initialize(self) -> None:
-        self.client = AsyncClient(host=self.config.url)
try:
-            status = await self.client.pull(self.model)
-            assert (
-                status["status"] == "success"
-            ), f"Failed to pull model {self.model} in ollama"
+            await self.client.ps()
except httpx.ConnectError:
print(
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
@@ -81,7 +80,11 @@ class OllamaInference(Inference):
def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
ollama_messages = []
for message in messages:
ollama_messages.append({"role": message.role, "content": message.content})
if message.role == "ipython":
role = "tool"
else:
role = message.role
ollama_messages.append({"role": role, "content": message.content})
return ollama_messages
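The remapping above is needed because the Llama 3.1 "ipython" role (tool output fed back to the model) has no equivalent in Ollama's chat API, which uses "tool" instead; the rule in isolation:

def to_ollama_role(role: str) -> str:
    # "ipython" carries tool results in the Llama 3.1 message format;
    # Ollama's chat endpoint calls that role "tool", everything else passes through
    return "tool" if role == "ipython" else role


assert to_ollama_role("ipython") == "tool"
assert to_ollama_role("assistant") == "assistant"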
@@ -112,6 +115,21 @@ class OllamaInference(Inference):
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
ollama_model = self.resolve_ollama_model(request.model)
+        res = await self.client.ps()
+        need_model_pull = True
+        for r in res["models"]:
+            if ollama_model == r["model"]:
+                need_model_pull = False
+                break
+        if need_model_pull:
+            print(f"Pulling model: {ollama_model}")
+            status = await self.client.pull(ollama_model)
+            assert (
+                status["status"] == "success"
+            ), f"Failed to pull model {self.model} in ollama"
if not request.stream:
r = await self.client.chat(
model=ollama_model,
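A self-contained restatement of the pull-on-demand check added above (the helper name is invented). Note that ps() reports only models currently loaded in memory, so a model that is installed but idle would still trigger a pull here, matching the behavior in the diff.

from ollama import AsyncClient


async def ensure_model_available(client: AsyncClient, ollama_model: str) -> None:
    # skip the pull when the model already appears among the running models
    running = await client.ps()
    if any(entry["model"] == ollama_model for entry in running["models"]):
        return
    print(f"Pulling model: {ollama_model}")
    status = await client.pull(ollama_model)
    assert status["status"] == "success", f"Failed to pull model {ollama_model} in ollama"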
@@ -141,7 +159,6 @@
delta="",
)
)
stream = await self.client.chat(
model=ollama_model,
messages=self._messages_to_ollama_messages(request.messages),
@@ -154,11 +171,10 @@
stop_reason = None
async for chunk in stream:
# check if ollama is done
if chunk["done"]:
if chunk["done_reason"] == "stop":
if stop_reason is None and chunk["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif chunk["done_reason"] == "length":
elif stop_reason is None and chunk["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
break
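The `stop_reason is None` guards added in this hunk keep a stop reason decided earlier in the loop from being overwritten when Ollama's final chunk arrives. A standalone sketch of the mapping, assuming StopReason is importable from the same datatypes module as the other imports above:

from typing import Optional

from llama_models.llama3_1.api.datatypes import StopReason


def merge_stop_reason(current: Optional[StopReason], done_reason: str) -> Optional[StopReason]:
    # adopt Ollama's done_reason only if no stop reason has been decided yet
    if current is None and done_reason == "stop":
        return StopReason.end_of_turn
    if current is None and done_reason == "length":
        return StopReason.out_of_tokens
    return current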
@@ -176,7 +192,7 @@
),
)
)
-                buffer = buffer[len("<|python_tag|>") :]
+                buffer += text
continue
if ipython:
@@ -214,7 +230,6 @@
# parse tool calls and report errors
message = decode_assistant_message_from_content(buffer, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(