added options to ollama inference

Hardik Shah 2024-08-02 14:44:22 -07:00
parent 09cf3fe78b
commit d7a4cdd70d
2 changed files with 89 additions and 13 deletions
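
In short, chat requests can now carry sampling parameters, which the Ollama provider translates into the options dict that ollama's chat() call accepts, and model SKUs are resolved to concrete ollama model tags. The following is a rough standalone sketch of that option translation in plain Python; the dataclass and function names here are illustrative stand-ins, not the identifiers used in the diff below.

from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingParamsSketch:
    # Illustrative stand-in for the SamplingParams type used in the diff below.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = None

def to_ollama_options(params: Optional[SamplingParamsSketch]) -> dict:
    # Same mapping as get_ollama_chat_options() in the diff: shared knobs are
    # copied through, and repetition_penalty is renamed to ollama's repeat_penalty.
    options = {}
    if params is None:
        return options
    for attr in ("temperature", "top_p", "top_k", "max_tokens"):
        value = getattr(params, attr)
        if value:
            options[attr] = value
    if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
        options["repeat_penalty"] = params.repetition_penalty
    return options

print(to_ollama_options(SamplingParamsSketch(temperature=1.0, top_p=0.99)))
# prints: {'temperature': 1.0, 'top_p': 0.99}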


@@ -5,6 +5,7 @@ from typing import AsyncGenerator

 from ollama import AsyncClient

+from llama_models.sku_list import resolve_model
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
     CompletionMessage,
@@ -29,6 +30,12 @@ from .api.endpoints import (
     Inference,
 )

+# TODO: Eventually this will move to the llama cli model list command
+# mapping of Model SKUs to ollama models
+OLLAMA_SUPPORTED_SKUS = {
+    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"
+    # TODO: Add other variants for llama3.1
+}


 class OllamaInference(Inference):
@@ -61,14 +68,41 @@ class OllamaInference(Inference):
         return ollama_messages

+    def resolve_ollama_model(self, model_name: str) -> str:
+        model = resolve_model(model_name)
+        assert (
+            model is not None and
+            model.descriptor(shorten_default_variant=True) in OLLAMA_SUPPORTED_SKUS
+        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(OLLAMA_SUPPORTED_SKUS.keys())}"
+
+        return OLLAMA_SUPPORTED_SKUS.get(model.descriptor(shorten_default_variant=True))
+
+    def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
+        options = {}
+        if request.sampling_params is not None:
+            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
+                if getattr(request.sampling_params, attr):
+                    options[attr] = getattr(request.sampling_params, attr)
+            if (
+                request.sampling_params.repetition_penalty is not None and
+                request.sampling_params.repetition_penalty != 1.0
+            ):
+                options["repeat_penalty"] = request.sampling_params.repetition_penalty
+
+        return options
+
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+        # accumulate sampling params and other options to pass to ollama
+        options = self.get_ollama_chat_options(request)
+        ollama_model = self.resolve_ollama_model(request.model)
         if not request.stream:
             r = await self.client.chat(
-                model=self.model,
+                model=ollama_model,
                 messages=self._messages_to_ollama_messages(request.messages),
                 stream=False,
-                #TODO: add support for options like temp, top_p, max_seq_length, etc
+                options=options,
             )
+            stop_reason = None
             if r['done']:
                 if r['done_reason'] == 'stop':
                     stop_reason = StopReason.end_of_turn
@@ -92,9 +126,10 @@ class OllamaInference(Inference):
             )

         stream = await self.client.chat(
-            model=self.model,
+            model=ollama_model,
             messages=self._messages_to_ollama_messages(request.messages),
-            stream=True
+            stream=True,
+            options=options,
         )

         buffer = ""
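
To show where that options dict ends up, here is a minimal direct call to the ollama Python client, approximately what the new chat_completion() wiring issues for the one supported SKU. The option values are illustrative, and the snippet assumes a local ollama server with llama3.1:8b-instruct-fp16 pulled. The second changed file, the unit tests for the Ollama provider, follows.

import asyncio

from ollama import AsyncClient

async def main() -> None:
    r = await AsyncClient().chat(
        model="llama3.1:8b-instruct-fp16",  # tag from OLLAMA_SUPPORTED_SKUS above
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        stream=False,
        options={"temperature": 1.0, "top_p": 0.99},  # dict built by get_ollama_chat_options()
    )
    # Same fields the provider inspects, plus the reply text.
    print(r["done"], r["done_reason"], r["message"]["content"])

asyncio.run(main())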


@@ -4,9 +4,10 @@ from datetime import datetime
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
-    InstructModel,
     UserMessage,
     StopReason,
+    SamplingParams,
+    SamplingStrategy,
     SystemMessage,
 )

 from llama_toolchain.inference.api_instance import (
@@ -84,13 +85,14 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
                """
            ),
        )
+        self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"

    async def asyncTearDown(self):
        await self.api.shutdown()

    async def test_text(self):
        request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
            messages=[
                UserMessage(
                    content="What is the capital of France?",
@@ -107,7 +109,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

    async def test_tool_call(self):
        request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
            messages=[
                self.system_prompt,
                UserMessage(
@@ -133,7 +135,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

    async def test_code_execution(self):
        request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
            messages=[
                self.system_prompt,
                UserMessage(
@@ -158,7 +160,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

    async def test_custom_tool(self):
        request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
            messages=[
                self.system_prompt_with_custom_tool,
                UserMessage(
@@ -174,7 +176,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
        completion_message = response.completion_message

        self.assertEqual(completion_message.content, "")
-        self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
+        self.assertTrue(
+            completion_message.stop_reason in {
+                StopReason.end_of_turn,
+                StopReason.end_of_message,
+            }
+        )

        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
        self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point")
@@ -186,7 +193,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

    async def test_text_streaming(self):
        request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
            messages=[
                UserMessage(
                    content="What is the capital of France?",
@@ -226,7 +233,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

    async def test_tool_call_streaming(self):
        request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
            messages=[
                self.system_prompt,
                UserMessage(
@@ -253,7 +260,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

    async def test_custom_tool_call_streaming(self):
        request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
            messages=[
                self.system_prompt_with_custom_tool,
                UserMessage(
@@ -294,3 +301,37 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
            events[-2].stop_reason,
            StopReason.end_of_turn
        )
+
+    def test_resolve_ollama_model(self):
+        ollama_model = self.api.resolve_ollama_model(self.valid_supported_model)
+        self.assertEqual(ollama_model, "llama3.1:8b-instruct-fp16")
+
+        invalid_model = "Meta-Llama3.1-8B"
+        with self.assertRaisesRegex(
+            AssertionError, f"Unsupported model: {invalid_model}"
+        ):
+            self.api.resolve_ollama_model(invalid_model)
+
+    async def test_ollama_chat_options(self):
+        request = ChatCompletionRequest(
+            model=self.valid_supported_model,
+            messages=[
+                UserMessage(
+                    content="What is the capital of France?",
+                ),
+            ],
+            stream=False,
+            sampling_params=SamplingParams(
+                sampling_strategy=SamplingStrategy.top_p,
+                top_p=0.99,
+                temperature=1.0,
+            )
+        )
+        options = self.api.get_ollama_chat_options(request)
+        self.assertEqual(
+            options,
+            {
+                "temperature": 1.0,
+                "top_p": 0.99,
+            }
+        )