diff --git a/llama_toolchain/inference/ollama.py b/llama_toolchain/inference/ollama.py
index 560960f8b..9526a8665 100644
--- a/llama_toolchain/inference/ollama.py
+++ b/llama_toolchain/inference/ollama.py
@@ -9,7 +9,6 @@ import uuid
 from typing import AsyncGenerator
 
 import httpx
-
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
     CompletionMessage,
@@ -19,6 +18,8 @@ from llama_models.llama3_1.api.datatypes import (
 )
 from llama_models.llama3_1.api.tool_utils import ToolUtils
 
+from llama_models.sku_list import resolve_model
+
 from ollama import AsyncClient
 
 from .api.config import OllamaImplConfig
@@ -36,6 +37,13 @@ from .api.endpoints import (
     Inference,
 )
 
+# TODO: Eventually this will move to the llama cli model list command
+# mapping of Model SKUs to ollama models
+OLLAMA_SUPPORTED_SKUS = {
+    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"
+    # TODO: Add other variants for llama3.1
+}
+
 
 def get_provider_impl(config: OllamaImplConfig) -> Inference:
     assert isinstance(
@@ -76,14 +84,41 @@ class OllamaInference(Inference):
 
         return ollama_messages
 
+    def resolve_ollama_model(self, model_name: str) -> str:
+        model = resolve_model(model_name)
+        assert (
+            model is not None
+            and model.descriptor(shorten_default_variant=True) in OLLAMA_SUPPORTED_SKUS
+        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(OLLAMA_SUPPORTED_SKUS.keys())}"
+
+        return OLLAMA_SUPPORTED_SKUS.get(model.descriptor(shorten_default_variant=True))
+
+    def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
+        options = {}
+        if request.sampling_params is not None:
+            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
+                if getattr(request.sampling_params, attr):
+                    options[attr] = getattr(request.sampling_params, attr)
+            if (
+                request.sampling_params.repetition_penalty is not None
+                and request.sampling_params.repetition_penalty != 1.0
+            ):
+                options["repeat_penalty"] = request.sampling_params.repetition_penalty
+
+        return options
+
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+        # accumulate sampling params and other options to pass to ollama
+        options = self.get_ollama_chat_options(request)
+        ollama_model = self.resolve_ollama_model(request.model)
         if not request.stream:
             r = await self.client.chat(
-                model=self.model,
+                model=ollama_model,
                 messages=self._messages_to_ollama_messages(request.messages),
                 stream=False,
-                # TODO: add support for options like temp, top_p, max_seq_length, etc
+                options=options,
             )
+            stop_reason = None
             if r["done"]:
                 if r["done_reason"] == "stop":
                     stop_reason = StopReason.end_of_turn
@@ -107,9 +142,10 @@ class OllamaInference(Inference):
             )
 
         stream = await self.client.chat(
-            model=self.model,
+            model=ollama_model,
             messages=self._messages_to_ollama_messages(request.messages),
             stream=True,
+            options=options,
         )
 
         buffer = ""
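Editor's note: the heart of the ollama.py change is the SKU-to-Ollama-tag mapping (OLLAMA_SUPPORTED_SKUS plus resolve_ollama_model) and the translation of sampling parameters into Ollama client options (get_ollama_chat_options). Below is a minimal, standalone sketch of that option-mapping logic so it can be read and run without Ollama or llama_models installed; SamplingParamsSketch is a simplified stand-in for the real SamplingParams type, used only to keep the example self-contained.

from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParamsSketch:
    # Stand-in for llama_models' SamplingParams; only the fields the mapping reads.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = 1.0


def ollama_options(params: Optional[SamplingParamsSketch]) -> dict:
    # Mirrors OllamaInference.get_ollama_chat_options(): copy truthy sampling
    # fields verbatim and rename repetition_penalty to Ollama's "repeat_penalty",
    # skipping the 1.0 no-op default.
    options = {}
    if params is not None:
        for attr in ("temperature", "top_p", "top_k", "max_tokens"):
            value = getattr(params, attr)
            if value:
                options[attr] = value
        if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
            options["repeat_penalty"] = params.repetition_penalty
    return options


print(ollama_options(SamplingParamsSketch(temperature=1.0, top_p=0.99)))
# -> {'temperature': 1.0, 'top_p': 0.99}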
diff --git a/tests/test_inference.py b/tests/test_inference.py
index bbd4a2c1a..ad7bf6d19 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -10,14 +10,12 @@ from datetime import datetime
 
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
-    InstructModel,
     UserMessage,
     StopReason,
     SystemMessage,
 )
 
 from llama_toolchain.inference.api.config import (
-    ImplType,
     InferenceConfig,
     InlineImplConfig,
     RemoteImplConfig,
@@ -31,11 +29,7 @@ from llama_toolchain.inference.api_instance import (
 from llama_toolchain.inference.api.datatypes import (
     ChatCompletionResponseEventType,
 )
-from llama_toolchain.inference.api.endpoints import (
-    ChatCompletionRequest
-)
-from llama_toolchain.inference.inference import InferenceImpl
-from llama_toolchain.inference.event_logger import EventLogger
+from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
 
 
 HELPER_MSG = """
@@ -56,7 +50,9 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
     @classmethod
     async def asyncSetUpClass(cls):
         # assert model exists on local
-        model_dir = os.path.expanduser("~/.llama/checkpoints/Meta-Llama-3.1-8B-Instruct/original/")
+        model_dir = os.path.expanduser(
+            "~/.llama/checkpoints/Meta-Llama-3.1-8B-Instruct/original/"
+        )
         assert os.path.isdir(model_dir), HELPER_MSG
 
         tokenizer_path = os.path.join(model_dir, "tokenizer.model")
@@ -73,17 +69,11 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
             ),
             max_seq_len=2048,
         )
-        inference_config = InferenceConfig(
-            impl_config=inline_config
-        )
+        inference_config = InferenceConfig(impl_config=inline_config)
 
         # -- For faster testing iteration --
-        # remote_config = RemoteImplConfig(
-        #     url="http://localhost:5000"
-        # )
-        # inference_config = InferenceConfig(
-        #     impl_config=remote_config
-        # )
+        # remote_config = RemoteImplConfig(url="http://localhost:5000")
+        # inference_config = InferenceConfig(impl_config=remote_config)
 
         cls.api = await get_inference_api_instance(inference_config)
         await cls.api.initialize()
@@ -91,17 +81,20 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
 
         current_date = datetime.now()
         formatted_date = current_date.strftime("%d %B %Y")
         cls.system_prompt = SystemMessage(
-            content=textwrap.dedent(f"""
+            content=textwrap.dedent(
+                f"""
                 Environment: ipython
                 Tools: brave_search
 
                 Cutting Knowledge Date: December 2023
                 Today Date:{formatted_date}
-            """),
+                """
+            ),
         )
 
         cls.system_prompt_with_custom_tool = SystemMessage(
-            content=textwrap.dedent("""
+            content=textwrap.dedent(
+                """
                 Environment: ipython
                 Tools: brave_search, wolfram_alpha, photogen
@@ -140,9 +133,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
     async def asyncTearDownClass(cls):
         await cls.api.shutdown()
 
+    async def asyncSetUp(self):
+        self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
+
     async def test_text(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 UserMessage(
                     content="What is the capital of France?",
@@ -160,7 +156,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_text_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 UserMessage(
                     content="What is the capital of France?",
@@ -175,13 +171,9 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
             events.append(chunk.event)
             # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
 
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
         self.assertEqual(
-            events[0].event_type,
-            ChatCompletionResponseEventType.start
-        )
-        self.assertEqual(
-            events[-1].event_type,
-            ChatCompletionResponseEventType.complete
+            events[-1].event_type, ChatCompletionResponseEventType.complete
         )
 
         response = ""
@@ -192,7 +184,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_custom_tool_call(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 InferenceTests.system_prompt_with_custom_tool,
                 UserMessage(
@@ -214,8 +206,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
         # instead of end_of_turn
         # self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
-        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point")
+        self.assertEqual(
+            len(completion_message.tool_calls), 1, completion_message.tool_calls
+        )
+        self.assertEqual(
+            completion_message.tool_calls[0].tool_name, "get_boiling_point"
+        )
 
         args = completion_message.tool_calls[0].arguments
         self.assertTrue(isinstance(args, dict))
@@ -223,7 +219,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_tool_call_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -239,32 +235,21 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
             # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
             events.append(chunk.event)
 
-        self.assertEqual(
-            events[0].event_type,
-            ChatCompletionResponseEventType.start
-        )
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
         # last event is of type "complete"
         self.assertEqual(
-            events[-1].event_type,
-            ChatCompletionResponseEventType.complete
+            events[-1].event_type, ChatCompletionResponseEventType.complete
         )
         # last but one event should be eom with tool call
         self.assertEqual(
-            events[-2].event_type,
-            ChatCompletionResponseEventType.progress
-        )
-        self.assertEqual(
-            events[-2].stop_reason,
-            StopReason.end_of_message
-        )
-        self.assertEqual(
-            events[-2].delta.content.tool_name,
-            BuiltinTool.brave_search
+            events[-2].event_type, ChatCompletionResponseEventType.progress
         )
+        self.assertEqual(events[-2].stop_reason, StopReason.end_of_message)
+        self.assertEqual(events[-2].delta.content.tool_name, BuiltinTool.brave_search)
 
     async def test_custom_tool_call_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 InferenceTests.system_prompt_with_custom_tool,
                 UserMessage(
@@ -279,29 +264,15 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
             # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
             events.append(chunk.event)
 
-        self.assertEqual(
-            events[0].event_type,
-            ChatCompletionResponseEventType.start
-        )
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
         # last event is of type "complete"
         self.assertEqual(
-            events[-1].event_type,
-            ChatCompletionResponseEventType.complete
-        )
-        self.assertEqual(
-            events[-1].stop_reason,
-            StopReason.end_of_turn
+            events[-1].event_type, ChatCompletionResponseEventType.complete
        )
+        self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn)
         # last but one event should be eom with tool call
         self.assertEqual(
-            events[-2].event_type,
-            ChatCompletionResponseEventType.progress
-        )
-        self.assertEqual(
-            events[-2].stop_reason,
-            StopReason.end_of_turn
-        )
-        self.assertEqual(
-            events[-2].delta.content.tool_name,
-            "get_boiling_point"
+            events[-2].event_type, ChatCompletionResponseEventType.progress
         )
+        self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
+        self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point")
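Editor's note: with the InstructModel enum dropped from the imports, requests in both test files now name the model by its SKU string. Below is a request built the way the updated tests build one, assuming llama_toolchain and llama_models are installed; api stands for an already-initialized inference API instance like the one asyncSetUpClass creates above.

from llama_models.llama3_1.api.datatypes import UserMessage
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest

# The model is now identified by its SKU descriptor string rather than an enum member.
request = ChatCompletionRequest(
    model="Meta-Llama3.1-8B-Instruct",
    messages=[
        UserMessage(content="What is the capital of France?"),
    ],
    stream=False,
)

# Both backends expose chat_completion() as an async generator, so even the
# non-streaming path is consumed with `async for`, exactly as the tests do:
#
#     async for r in api.chat_completion(request):
#         response = r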
diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py
index d9dbab6e9..67493db25 100644
--- a/tests/test_ollama_inference.py
+++ b/tests/test_ollama_inference.py
@@ -4,9 +4,10 @@ from datetime import datetime
 
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
-    InstructModel,
     UserMessage,
     StopReason,
+    SamplingParams,
+    SamplingStrategy,
     SystemMessage,
 )
 
@@ -15,23 +16,16 @@ from llama_toolchain.inference.api_instance import (
 from llama_toolchain.inference.api.datatypes import (
     ChatCompletionResponseEventType,
 )
-from llama_toolchain.inference.api.endpoints import (
-    ChatCompletionRequest
-)
-from llama_toolchain.inference.api.config import (
-    InferenceConfig,
-    OllamaImplConfig
-)
-from llama_toolchain.inference.ollama import (
-    OllamaInference
-)
+from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
+from llama_toolchain.inference.api.config import InferenceConfig, OllamaImplConfig
 
 
 class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
     async def asyncSetUp(self):
+        self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
         ollama_config = OllamaImplConfig(
-            model="llama3.1",
+            model="llama3.1:8b-instruct-fp16",
             url="http://localhost:11434",
         )
@@ -44,18 +38,21 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
         current_date = datetime.now()
         formatted_date = current_date.strftime("%d %B %Y")
         self.system_prompt = SystemMessage(
-            content=textwrap.dedent(f"""
+            content=textwrap.dedent(
+                f"""
                 Environment: ipython
                 Tools: brave_search
 
                 Cutting Knowledge Date: December 2023
                 Today Date:{formatted_date}
-            """),
+                """
+            ),
         )
 
         self.system_prompt_with_custom_tool = SystemMessage(
-            content=textwrap.dedent("""
+            content=textwrap.dedent(
+                """
                 Environment: ipython
                 Tools: brave_search, wolfram_alpha, photogen
@@ -78,19 +75,19 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
                 - If looking for real time information use relevant functions before falling back to brave_search
                 - Function calls MUST follow the specified format, start with <function= and end with </function>
                 - Required parameters MUST be specified
-                - Only call one function at a time
                 - Put the entire function call reply on one line
                 """
             ),
         )
+        self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
 
     async def asyncTearDown(self):
         await self.api.shutdown()
 
     async def test_text(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 UserMessage(
                     content="What is the capital of France?",
@@ -103,11 +100,13 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             response = r
 
         self.assertTrue("Paris" in response.completion_message.content)
-        self.assertEqual(response.completion_message.stop_reason, StopReason.end_of_turn)
+        self.assertEqual(
+            response.completion_message.stop_reason, StopReason.end_of_turn
+        )
 
     async def test_tool_call(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -125,15 +124,19 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         self.assertEqual(completion_message.content, "")
         self.assertEqual(completion_message.stop_reason, StopReason.end_of_message)
 
-        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search)
+        self.assertEqual(
+            len(completion_message.tool_calls), 1, completion_message.tool_calls
+        )
+        self.assertEqual(
+            completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search
+        )
         self.assertTrue(
             "president" in completion_message.tool_calls[0].arguments["query"].lower()
         )
 
     async def test_code_execution(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -151,14 +154,18 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         self.assertEqual(completion_message.content, "")
         self.assertEqual(completion_message.stop_reason, StopReason.end_of_message)
 
-        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter)
+        self.assertEqual(
+            len(completion_message.tool_calls), 1, completion_message.tool_calls
+        )
+        self.assertEqual(
+            completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter
+        )
         code = completion_message.tool_calls[0].arguments["code"]
         self.assertTrue("def " in code.lower(), code)
 
     async def test_custom_tool(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt_with_custom_tool,
                 UserMessage(
@@ -174,19 +181,28 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         completion_message = response.completion_message
 
         self.assertEqual(completion_message.content, "")
-        self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
+        self.assertTrue(
+            completion_message.stop_reason
+            in {
+                StopReason.end_of_turn,
+                StopReason.end_of_message,
+            }
+        )
 
-        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point")
+        self.assertEqual(
+            len(completion_message.tool_calls), 1, completion_message.tool_calls
+        )
+        self.assertEqual(
+            completion_message.tool_calls[0].tool_name, "get_boiling_point"
+        )
 
         args = completion_message.tool_calls[0].arguments
         self.assertTrue(isinstance(args, dict))
         self.assertTrue(args["liquid_name"], "polyjuice")
 
-
     async def test_text_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 UserMessage(
                     content="What is the capital of France?",
@@ -204,19 +220,14 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         for e in events[1:-1]:
             response += e.delta
 
-        self.assertEqual(
-            events[0].event_type,
-            ChatCompletionResponseEventType.start
-        )
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
         # last event is of type "complete"
         self.assertEqual(
-            events[-1].event_type,
-            ChatCompletionResponseEventType.complete
+            events[-1].event_type, ChatCompletionResponseEventType.complete
         )
         # last but 1 event should be of type "progress"
         self.assertEqual(
-            events[-2].event_type,
-            ChatCompletionResponseEventType.progress
+            events[-2].event_type, ChatCompletionResponseEventType.progress
         )
         self.assertEqual(
             events[-2].stop_reason,
@@ -226,7 +237,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_tool_call_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -241,19 +252,15 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
             events.append(chunk.event)
 
-        self.assertEqual(
-            events[0].event_type,
-            ChatCompletionResponseEventType.start
-        )
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
         # last event is of type "complete"
         self.assertEqual(
-            events[-1].event_type,
-            ChatCompletionResponseEventType.complete
+            events[-1].event_type, ChatCompletionResponseEventType.complete
         )
 
     async def test_custom_tool_call_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt_with_custom_tool,
                 UserMessage(
@@ -268,29 +275,49 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
             events.append(chunk.event)
 
-        self.assertEqual(
-            events[0].event_type,
-            ChatCompletionResponseEventType.start
-        )
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
         # last event is of type "complete"
         self.assertEqual(
-            events[-1].event_type,
-            ChatCompletionResponseEventType.complete
-        )
-        self.assertEqual(
-            events[-1].stop_reason,
-            StopReason.end_of_turn
+            events[-1].event_type, ChatCompletionResponseEventType.complete
        )
+        self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn)
         # last but one event should be eom with tool call
         self.assertEqual(
-            events[-2].event_type,
-            ChatCompletionResponseEventType.progress
+            events[-2].event_type, ChatCompletionResponseEventType.progress
         )
+        self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point")
+        self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
+
+    def test_resolve_ollama_model(self):
+        ollama_model = self.api.resolve_ollama_model(self.valid_supported_model)
+        self.assertEqual(ollama_model, "llama3.1:8b-instruct-fp16")
+
+        invalid_model = "Meta-Llama3.1-8B"
+        with self.assertRaisesRegex(
+            AssertionError, f"Unsupported model: {invalid_model}"
+        ):
+            self.api.resolve_ollama_model(invalid_model)
+
+    async def test_ollama_chat_options(self):
+        request = ChatCompletionRequest(
+            model=self.valid_supported_model,
+            messages=[
+                UserMessage(
+                    content="What is the capital of France?",
+                ),
+            ],
+            stream=False,
+            sampling_params=SamplingParams(
+                sampling_strategy=SamplingStrategy.top_p,
+                top_p=0.99,
+                temperature=1.0,
+            ),
+        )
+        options = self.api.get_ollama_chat_options(request)
         self.assertEqual(
-            events[-2].delta.content.tool_name,
-            "get_boiling_point"
-        )
-        self.assertEqual(
-            events[-2].stop_reason,
-            StopReason.end_of_turn
+            options,
+            {
+                "temperature": 1.0,
+                "top_p": 0.99,
+            },
         )
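Editor's note: a hedged end-to-end sketch of driving the new Ollama backend outside the test suite. It assumes an Ollama server is reachable at http://localhost:11434 with the llama3.1:8b-instruct-fp16 model already pulled, and it reuses the InferenceConfig / get_inference_api_instance wiring shown in tests/test_inference.py; whether the Ollama implementation is selected purely from the impl_config type is an assumption, since that wiring is not part of this diff.

import asyncio

from llama_models.llama3_1.api.datatypes import (
    SamplingParams,
    SamplingStrategy,
    UserMessage,
)
from llama_toolchain.inference.api.config import InferenceConfig, OllamaImplConfig
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
from llama_toolchain.inference.api_instance import get_inference_api_instance


async def main() -> None:
    # Stand up the Ollama-backed inference API (assumes a local Ollama server
    # with the llama3.1:8b-instruct-fp16 model already pulled).
    config = InferenceConfig(
        impl_config=OllamaImplConfig(
            model="llama3.1:8b-instruct-fp16",
            url="http://localhost:11434",
        )
    )
    api = await get_inference_api_instance(config)
    await api.initialize()

    # The request names the Llama SKU; resolve_ollama_model() maps it to the
    # Ollama tag, and get_ollama_chat_options() turns sampling_params into
    # {"temperature": 1.0, "top_p": 0.99}, as asserted in test_ollama_chat_options.
    request = ChatCompletionRequest(
        model="Meta-Llama3.1-8B-Instruct",
        messages=[UserMessage(content="What is the capital of France?")],
        sampling_params=SamplingParams(
            sampling_strategy=SamplingStrategy.top_p,
            top_p=0.99,
            temperature=1.0,
        ),
        stream=False,
    )
    async for r in api.chat_completion(request):
        print(r.completion_message.content)

    await api.shutdown()


if __name__ == "__main__":
    asyncio.run(main())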