added options to ollama inference

Hardik Shah 2024-08-02 14:44:22 -07:00
parent 09cf3fe78b
commit d7a4cdd70d
2 changed files with 89 additions and 13 deletions


@@ -5,6 +5,7 @@ from typing import AsyncGenerator
 from ollama import AsyncClient
+from llama_models.sku_list import resolve_model
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
     CompletionMessage,
@@ -29,6 +30,12 @@ from .api.endpoints import (
     Inference,
 )

+# TODO: Eventually this will move to the llama cli model list command
+# mapping of Model SKUs to ollama models
+OLLAMA_SUPPORTED_SKUS = {
+    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"
+    # TODO: Add other variants for llama3.1
+}

 class OllamaInference(Inference):
@@ -61,14 +68,41 @@ class OllamaInference(Inference):
         return ollama_messages

+    def resolve_ollama_model(self, model_name: str) -> str:
+        model = resolve_model(model_name)
+        assert (
+            model is not None and
+            model.descriptor(shorten_default_variant=True) in OLLAMA_SUPPORTED_SKUS
+        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(OLLAMA_SUPPORTED_SKUS.keys())}"
+
+        return OLLAMA_SUPPORTED_SKUS.get(model.descriptor(shorten_default_variant=True))
+
+    def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
+        options = {}
+        if request.sampling_params is not None:
+            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
+                if getattr(request.sampling_params, attr):
+                    options[attr] = getattr(request.sampling_params, attr)
+            if (
+                request.sampling_params.repetition_penalty is not None and
+                request.sampling_params.repetition_penalty != 1.0
+            ):
+                options["repeat_penalty"] = request.sampling_params.repetition_penalty
+
+        return options
+
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+        # accumulate sampling params and other options to pass to ollama
+        options = self.get_ollama_chat_options(request)
+        ollama_model = self.resolve_ollama_model(request.model)
+
         if not request.stream:
             r = await self.client.chat(
-                model=self.model,
+                model=ollama_model,
                 messages=self._messages_to_ollama_messages(request.messages),
                 stream=False,
-                #TODO: add support for options like temp, top_p, max_seq_length, etc
+                options=options,
             )
             stop_reason = None
             if r['done']:
                 if r['done_reason'] == 'stop':
                     stop_reason = StopReason.end_of_turn
@@ -92,9 +126,10 @@ class OllamaInference(Inference):
             )

             stream = await self.client.chat(
-                model=self.model,
+                model=ollama_model,
                 messages=self._messages_to_ollama_messages(request.messages),
-                stream=True
+                stream=True,
+                options=options,
             )

             buffer = ""