diff --git a/llama_toolchain/inference/ollama.py b/llama_toolchain/inference/ollama.py
index 91727fd62..4f995b7ca 100644
--- a/llama_toolchain/inference/ollama.py
+++ b/llama_toolchain/inference/ollama.py
@@ -5,6 +5,7 @@ from typing import AsyncGenerator
 
 from ollama import AsyncClient
 
+from llama_models.sku_list import resolve_model
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
     CompletionMessage,
@@ -29,6 +30,12 @@ from .api.endpoints import (
     Inference,
 )
 
+# TODO: Eventually this will move to the llama cli model list command
+# mapping of Model SKUs to ollama models
+OLLAMA_SUPPORTED_SKUS = {
+    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"
+    # TODO: Add other variants for llama3.1
+}
 
 
 class OllamaInference(Inference):
@@ -61,14 +68,41 @@ class OllamaInference(Inference):
 
         return ollama_messages
 
+    def resolve_ollama_model(self, model_name: str) -> str:
+        model = resolve_model(model_name)
+        assert (
+            model is not None and
+            model.descriptor(shorten_default_variant=True) in OLLAMA_SUPPORTED_SKUS
+        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(OLLAMA_SUPPORTED_SKUS.keys())}"
+
+        return OLLAMA_SUPPORTED_SKUS.get(model.descriptor(shorten_default_variant=True))
+
+    def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
+        options = {}
+        if request.sampling_params is not None:
+            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
+                if getattr(request.sampling_params, attr):
+                    options[attr] = getattr(request.sampling_params, attr)
+            if (
+                request.sampling_params.repetition_penalty is not None and
+                request.sampling_params.repetition_penalty != 1.0
+            ):
+                options["repeat_penalty"] = request.sampling_params.repetition_penalty
+
+        return options
+
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+        # accumulate sampling params and other options to pass to ollama
+        options = self.get_ollama_chat_options(request)
+        ollama_model = self.resolve_ollama_model(request.model)
         if not request.stream:
             r = await self.client.chat(
-                model=self.model,
+                model=ollama_model,
                 messages=self._messages_to_ollama_messages(request.messages),
                 stream=False,
-                #TODO: add support for options like temp, top_p, max_seq_length, etc
+                options=options,
             )
+            stop_reason = None
             if r['done']:
                 if r['done_reason'] == 'stop':
                     stop_reason = StopReason.end_of_turn
@@ -92,9 +126,10 @@ class OllamaInference(Inference):
             )
 
         stream = await self.client.chat(
-            model=self.model,
+            model=ollama_model,
             messages=self._messages_to_ollama_messages(request.messages),
-            stream=True
+            stream=True,
+            options=options,
         )
 
         buffer = ""
diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py
index d9dbab6e9..c0628fc73 100644
--- a/tests/test_ollama_inference.py
+++ b/tests/test_ollama_inference.py
@@ -4,9 +4,10 @@ from datetime import datetime
 
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
-    InstructModel,
     UserMessage,
     StopReason,
+    SamplingParams,
+    SamplingStrategy,
     SystemMessage,
 )
 from llama_toolchain.inference.api_instance import (
@@ -84,13 +85,14 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
                 """
             ),
         )
+        self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
 
     async def asyncTearDown(self):
         await self.api.shutdown()
 
     async def test_text(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 UserMessage(
                     content="What is the capital of France?",
@@ -107,7 +109,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_tool_call(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -133,7 +135,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_code_execution(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -158,7 +160,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_custom_tool(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
            messages=[
                 self.system_prompt_with_custom_tool,
                 UserMessage(
@@ -174,7 +176,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
         completion_message = response.completion_message
         self.assertEqual(completion_message.content, "")
-        self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
+        self.assertTrue(
+            completion_message.stop_reason in {
+                StopReason.end_of_turn,
+                StopReason.end_of_message,
+            }
+        )
 
         self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
         self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point")
@@ -186,7 +193,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_text_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 UserMessage(
                     content="What is the capital of France?",
@@ -226,7 +233,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_tool_call_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -253,7 +260,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_custom_tool_call_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt_with_custom_tool,
                 UserMessage(
@@ -294,3 +301,37 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             events[-2].stop_reason,
             StopReason.end_of_turn
         )
+
+    def test_resolve_ollama_model(self):
+        ollama_model = self.api.resolve_ollama_model(self.valid_supported_model)
+        self.assertEqual(ollama_model, "llama3.1:8b-instruct-fp16")
+
+        invalid_model = "Meta-Llama3.1-8B"
+        with self.assertRaisesRegex(
+            AssertionError, f"Unsupported model: {invalid_model}"
+        ):
+            self.api.resolve_ollama_model(invalid_model)
+
+    async def test_ollama_chat_options(self):
+        request = ChatCompletionRequest(
+            model=self.valid_supported_model,
+            messages=[
+                UserMessage(
+                    content="What is the capital of France?",
+                ),
+            ],
+            stream=False,
+            sampling_params=SamplingParams(
+                sampling_strategy=SamplingStrategy.top_p,
+                top_p=0.99,
+                temperature=1.0,
+            )
+        )
+        options = self.api.get_ollama_chat_options(request)
+        self.assertEqual(
+            options,
+            {
+                "temperature": 1.0,
+                "top_p": 0.99,
+            }
+        )
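
Note (not part of the diff): a minimal sketch of how the two new helpers behave, mirroring test_resolve_ollama_model and test_ollama_chat_options above. The import path for ChatCompletionRequest and the way an OllamaInference instance is obtained are assumptions; the tests construct theirs through llama_toolchain.inference.api_instance.

from llama_models.llama3_1.api.datatypes import (
    SamplingParams,
    SamplingStrategy,
    UserMessage,
)
# Assumed import path; ollama.py itself imports ChatCompletionRequest from .api.endpoints.
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest


def show_ollama_mapping(api) -> None:
    # `api` is assumed to be an already-configured OllamaInference instance
    # (setup omitted here; the tests build theirs via the api_instance helpers).

    # SKU -> ollama model tag resolution introduced by OLLAMA_SUPPORTED_SKUS
    print(api.resolve_ollama_model("Meta-Llama3.1-8B-Instruct"))
    # -> "llama3.1:8b-instruct-fp16"

    request = ChatCompletionRequest(
        model="Meta-Llama3.1-8B-Instruct",
        messages=[UserMessage(content="What is the capital of France?")],
        stream=False,
        sampling_params=SamplingParams(
            sampling_strategy=SamplingStrategy.top_p,
            top_p=0.99,
            temperature=1.0,
        ),
    )

    # Sampling params are now translated into the `options` dict passed to ollama.
    print(api.get_ollama_chat_options(request))
    # -> {"temperature": 1.0, "top_p": 0.99}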