Merge remote-tracking branch 'origin/main' into distros

Ashwin Bharambe 2024-08-05 14:31:06 -07:00
commit f64668319c
3 changed files with 171 additions and 137 deletions


@@ -9,7 +9,6 @@ import uuid
from typing import AsyncGenerator
import httpx
from llama_models.llama3_1.api.datatypes import (
    BuiltinTool,
    CompletionMessage,
@@ -19,6 +18,8 @@ from llama_models.llama3_1.api.datatypes import (
)
from llama_models.llama3_1.api.tool_utils import ToolUtils
from llama_models.sku_list import resolve_model
from ollama import AsyncClient
from .api.config import OllamaImplConfig
@@ -36,6 +37,13 @@ from .api.endpoints import (
    Inference,
)
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
"Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"
# TODO: Add other variants for llama3.1
}
def get_provider_impl(config: OllamaImplConfig) -> Inference:
    assert isinstance(
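
The OLLAMA_SUPPORTED_SKUS table added above is what the new resolve_ollama_model method (further down in this diff) keys on: the requested name is resolved with resolve_model and its shortened descriptor is looked up in the table. A minimal standalone sketch of that lookup, assuming only what the diff shows; to_ollama_tag and _SKU_TO_OLLAMA are hypothetical names used here for illustration:

from llama_models.sku_list import resolve_model

# Hypothetical standalone copy of the OLLAMA_SUPPORTED_SKUS table above.
_SKU_TO_OLLAMA = {
    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
}


def to_ollama_tag(model_name: str) -> str:
    # Same resolution the resolve_ollama_model method below performs.
    model = resolve_model(model_name)
    if model is None:
        raise ValueError(f"Unknown model: {model_name}")
    descriptor = model.descriptor(shorten_default_variant=True)
    if descriptor not in _SKU_TO_OLLAMA:
        raise ValueError(
            f"Unsupported model: {model_name}, use one of: {', '.join(_SKU_TO_OLLAMA)}"
        )
    return _SKU_TO_OLLAMA[descriptor]


# e.g. to_ollama_tag("Meta-Llama3.1-8B-Instruct") -> "llama3.1:8b-instruct-fp16"
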
@@ -76,14 +84,41 @@ class OllamaInference(Inference):
        return ollama_messages

    def resolve_ollama_model(self, model_name: str) -> str:
        model = resolve_model(model_name)
        assert (
            model is not None
            and model.descriptor(shorten_default_variant=True) in OLLAMA_SUPPORTED_SKUS
        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(OLLAMA_SUPPORTED_SKUS.keys())}"

        return OLLAMA_SUPPORTED_SKUS.get(model.descriptor(shorten_default_variant=True))
    def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
        options = {}
        if request.sampling_params is not None:
            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
                if getattr(request.sampling_params, attr):
                    options[attr] = getattr(request.sampling_params, attr)
            if (
                request.sampling_params.repetition_penalty is not None
                and request.sampling_params.repetition_penalty != 1.0
            ):
                options["repeat_penalty"] = request.sampling_params.repetition_penalty

        return options
    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        # accumulate sampling params and other options to pass to ollama
        options = self.get_ollama_chat_options(request)
        ollama_model = self.resolve_ollama_model(request.model)

        if not request.stream:
            r = await self.client.chat(
                model=self.model,
                model=ollama_model,
                messages=self._messages_to_ollama_messages(request.messages),
                stream=False,
                # TODO: add support for options like temp, top_p, max_seq_length, etc
                options=options,
            )
            stop_reason = None
            if r["done"]:
                if r["done_reason"] == "stop":
                    stop_reason = StopReason.end_of_turn
@@ -107,9 +142,10 @@ class OllamaInference(Inference):
            )

            stream = await self.client.chat(
                model=self.model,
                model=ollama_model,
                messages=self._messages_to_ollama_messages(request.messages),
                stream=True,
                options=options,
            )

            buffer = ""