diff --git a/tests/test_inference.py b/tests/test_inference.py
index bbd4a2c1a..ad7bf6d19 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -10,14 +10,12 @@ from datetime import datetime
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
-    InstructModel,
     UserMessage,
     StopReason,
     SystemMessage,
 )
 from llama_toolchain.inference.api.config import (
-    ImplType,
     InferenceConfig,
     InlineImplConfig,
     RemoteImplConfig,
 )
 
@@ -31,11 +29,7 @@ from llama_toolchain.inference.api_instance import (
 from llama_toolchain.inference.api.datatypes import (
     ChatCompletionResponseEventType,
 )
-from llama_toolchain.inference.api.endpoints import (
-    ChatCompletionRequest
-)
-from llama_toolchain.inference.inference import InferenceImpl
-from llama_toolchain.inference.event_logger import EventLogger
+from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
 
 
 HELPER_MSG = """
@@ -56,7 +50,9 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
     @classmethod
     async def asyncSetUpClass(cls):
         # assert model exists on local
-        model_dir = os.path.expanduser("~/.llama/checkpoints/Meta-Llama-3.1-8B-Instruct/original/")
+        model_dir = os.path.expanduser(
+            "~/.llama/checkpoints/Meta-Llama-3.1-8B-Instruct/original/"
+        )
         assert os.path.isdir(model_dir), HELPER_MSG
 
         tokenizer_path = os.path.join(model_dir, "tokenizer.model")
@@ -73,17 +69,11 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
             ),
             max_seq_len=2048,
         )
-        inference_config = InferenceConfig(
-            impl_config=inline_config
-        )
+        inference_config = InferenceConfig(impl_config=inline_config)
 
         # -- For faster testing iteration --
-        # remote_config = RemoteImplConfig(
-        #     url="http://localhost:5000"
-        # )
-        # inference_config = InferenceConfig(
-        #     impl_config=remote_config
-        # )
+        # remote_config = RemoteImplConfig(url="http://localhost:5000")
+        # inference_config = InferenceConfig(impl_config=remote_config)
 
         cls.api = await get_inference_api_instance(inference_config)
         await cls.api.initialize()
@@ -91,17 +81,20 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
         current_date = datetime.now()
         formatted_date = current_date.strftime("%d %B %Y")
         cls.system_prompt = SystemMessage(
-            content=textwrap.dedent(f"""
+            content=textwrap.dedent(
+                f"""
                 Environment: ipython
                 Tools: brave_search
 
                 Cutting Knowledge Date: December 2023
                 Today Date:{formatted_date}
-            """),
+                """
+            ),
         )
 
         cls.system_prompt_with_custom_tool = SystemMessage(
-            content=textwrap.dedent("""
+            content=textwrap.dedent(
+                """
                 Environment: ipython
                 Tools: brave_search, wolfram_alpha, photogen
 
@@ -140,9 +133,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
     async def asyncTearDownClass(cls):
         await cls.api.shutdown()
 
+    async def asyncSetUp(self):
+        self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
+
     async def test_text(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 UserMessage(
                     content="What is the capital of France?",
@@ -160,7 +156,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_text_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 UserMessage(
                     content="What is the capital of France?",
@@ -175,13 +171,9 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
             events.append(chunk.event)
             # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
 
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
         self.assertEqual(
-            events[0].event_type,
-            ChatCompletionResponseEventType.start
-        )
-        self.assertEqual(
-            events[-1].event_type,
-            ChatCompletionResponseEventType.complete
+            events[-1].event_type, ChatCompletionResponseEventType.complete
         )
 
         response = ""
@@ -192,7 +184,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_custom_tool_call(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 InferenceTests.system_prompt_with_custom_tool,
                 UserMessage(
@@ -214,8 +206,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
         # instead of end_of_turn
         # self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
 
-        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point")
+        self.assertEqual(
+            len(completion_message.tool_calls), 1, completion_message.tool_calls
+        )
+        self.assertEqual(
+            completion_message.tool_calls[0].tool_name, "get_boiling_point"
+        )
 
         args = completion_message.tool_calls[0].arguments
         self.assertTrue(isinstance(args, dict))
@@ -223,7 +219,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_tool_call_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -239,32 +235,21 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
             # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
             events.append(chunk.event)
 
-        self.assertEqual(
-            events[0].event_type,
-            ChatCompletionResponseEventType.start
-        )
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
         # last event is of type "complete"
         self.assertEqual(
-            events[-1].event_type,
-            ChatCompletionResponseEventType.complete
+            events[-1].event_type, ChatCompletionResponseEventType.complete
         )
         # last but one event should be eom with tool call
         self.assertEqual(
-            events[-2].event_type,
-            ChatCompletionResponseEventType.progress
-        )
-        self.assertEqual(
-            events[-2].stop_reason,
-            StopReason.end_of_message
-        )
-        self.assertEqual(
-            events[-2].delta.content.tool_name,
-            BuiltinTool.brave_search
+            events[-2].event_type, ChatCompletionResponseEventType.progress
         )
+        self.assertEqual(events[-2].stop_reason, StopReason.end_of_message)
+        self.assertEqual(events[-2].delta.content.tool_name, BuiltinTool.brave_search)
 
     async def test_custom_tool_call_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 InferenceTests.system_prompt_with_custom_tool,
                 UserMessage(
@@ -279,29 +264,15 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
             # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
             events.append(chunk.event)
 
-        self.assertEqual(
-            events[0].event_type,
-            ChatCompletionResponseEventType.start
-        )
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
         # last event is of type "complete"
        self.assertEqual(
-            events[-1].event_type,
-            ChatCompletionResponseEventType.complete
-        )
-        self.assertEqual(
-            events[-1].stop_reason,
-            StopReason.end_of_turn
+            events[-1].event_type, ChatCompletionResponseEventType.complete
        )
+        self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn)
         # last but one event should be eom with tool call
         self.assertEqual(
-            events[-2].event_type,
-            ChatCompletionResponseEventType.progress
-        )
-        self.assertEqual(
-            events[-2].stop_reason,
-            StopReason.end_of_turn
-        )
-        self.assertEqual(
-            events[-2].delta.content.tool_name,
-            "get_boiling_point"
+            events[-2].event_type, ChatCompletionResponseEventType.progress
         )
+        self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
+        self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point")
diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py
index c0628fc73..67493db25 100644
--- a/tests/test_ollama_inference.py
+++ b/tests/test_ollama_inference.py
@@ -16,23 +16,16 @@ from llama_toolchain.inference.api_instance import (
 from llama_toolchain.inference.api.datatypes import (
     ChatCompletionResponseEventType,
 )
-from llama_toolchain.inference.api.endpoints import (
-    ChatCompletionRequest
-)
-from llama_toolchain.inference.api.config import (
-    InferenceConfig,
-    OllamaImplConfig
-)
-from llama_toolchain.inference.ollama import (
-    OllamaInference
-)
+from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
+from llama_toolchain.inference.api.config import InferenceConfig, OllamaImplConfig
 
 
 class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def asyncSetUp(self):
+        self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
         ollama_config = OllamaImplConfig(
-            model="llama3.1",
+            model="llama3.1:8b-instruct-fp16",
             url="http://localhost:11434",
         )
 
@@ -45,18 +38,21 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         current_date = datetime.now()
         formatted_date = current_date.strftime("%d %B %Y")
         self.system_prompt = SystemMessage(
-            content=textwrap.dedent(f"""
+            content=textwrap.dedent(
+                f"""
                 Environment: ipython
                 Tools: brave_search
 
                 Cutting Knowledge Date: December 2023
                 Today Date:{formatted_date}
-            """),
+                """
+            ),
         )
 
         self.system_prompt_with_custom_tool = SystemMessage(
-            content=textwrap.dedent("""
+            content=textwrap.dedent(
+                """
                 Environment: ipython
                 Tools: brave_search, wolfram_alpha, photogen
 
@@ -79,7 +75,6 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
                 - If looking for real time information use relevant functions before falling back to brave_search
                 - Function calls MUST follow the specified format, start with <function= and end with </function>
                 - Required parameters MUST be specified
-                - Only call one function at a time
                 - Put the entire function call reply on one line
 
                 """
@@ -105,7 +100,9 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             response = r
 
         self.assertTrue("Paris" in response.completion_message.content)
-        self.assertEqual(response.completion_message.stop_reason, StopReason.end_of_turn)
+        self.assertEqual(
+            response.completion_message.stop_reason, StopReason.end_of_turn
+        )
 
     async def test_tool_call(self):
         request = ChatCompletionRequest(
@@ -127,8 +124,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         self.assertEqual(completion_message.content, "")
         self.assertEqual(completion_message.stop_reason, StopReason.end_of_message)
 
-        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search)
+        self.assertEqual(
+            len(completion_message.tool_calls), 1, completion_message.tool_calls
+        )
+        self.assertEqual(
+            completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search
+        )
         self.assertTrue(
             "president" in completion_message.tool_calls[0].arguments["query"].lower()
         )
@@ -153,8 +154,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         self.assertEqual(completion_message.content, "")
         self.assertEqual(completion_message.stop_reason, StopReason.end_of_message)
 
-        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter)
+        self.assertEqual(
+            len(completion_message.tool_calls), 1, completion_message.tool_calls
+        )
+        self.assertEqual(
+            completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter
+        )
 
         code = completion_message.tool_calls[0].arguments["code"]
         self.assertTrue("def " in code.lower(), code)
@@ -177,20 +182,24 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         self.assertEqual(completion_message.content, "")
         self.assertTrue(
-            completion_message.stop_reason in {
+            completion_message.stop_reason
+            in {
                 StopReason.end_of_turn,
                 StopReason.end_of_message,
             }
         )
 
-        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
-        self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point")
+        self.assertEqual(
+            len(completion_message.tool_calls), 1, completion_message.tool_calls
+        )
+        self.assertEqual(
+            completion_message.tool_calls[0].tool_name, "get_boiling_point"
+        )
 
         args = completion_message.tool_calls[0].arguments
         self.assertTrue(isinstance(args, dict))
         self.assertTrue(args["liquid_name"], "polyjuice")
 
-
     async def test_text_streaming(self):
         request = ChatCompletionRequest(
             model=self.valid_supported_model,
             messages=[
@@ -211,19 +220,14 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         for e in events[1:-1]:
             response += e.delta
 
-        self.assertEqual(
-            events[0].event_type,
-            ChatCompletionResponseEventType.start
-        )
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
         # last event is of type "complete"
         self.assertEqual(
-            events[-1].event_type,
-            ChatCompletionResponseEventType.complete
+            events[-1].event_type, ChatCompletionResponseEventType.complete
         )
         # last but 1 event should be of type "progress"
         self.assertEqual(
-            events[-2].event_type,
-            ChatCompletionResponseEventType.progress
+            events[-2].event_type, ChatCompletionResponseEventType.progress
         )
         self.assertEqual(
             events[-2].stop_reason,
@@ -248,14 +252,10 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
             events.append(chunk.event)
 
-        self.assertEqual(
-            events[0].event_type,
-            ChatCompletionResponseEventType.start
-        )
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
         # last event is of type "complete"
         self.assertEqual(
-            events[-1].event_type,
-            ChatCompletionResponseEventType.complete
+            events[-1].event_type, ChatCompletionResponseEventType.complete
         )
 
     async def test_custom_tool_call_streaming(self):
@@ -275,32 +275,18 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ")
             events.append(chunk.event)
 
-        self.assertEqual(
-            events[0].event_type,
-            ChatCompletionResponseEventType.start
-        )
+        self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start)
         # last event is of type "complete"
         self.assertEqual(
-            events[-1].event_type,
-            ChatCompletionResponseEventType.complete
-        )
-        self.assertEqual(
-            events[-1].stop_reason,
-            StopReason.end_of_turn
+            events[-1].event_type, ChatCompletionResponseEventType.complete
         )
+        self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn)
         # last but one event should be eom with tool call
         self.assertEqual(
-            events[-2].event_type,
-            ChatCompletionResponseEventType.progress
-        )
-        self.assertEqual(
-            events[-2].delta.content.tool_name,
-            "get_boiling_point"
-        )
-        self.assertEqual(
-            events[-2].stop_reason,
-            StopReason.end_of_turn
+            events[-2].event_type, ChatCompletionResponseEventType.progress
         )
+        self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point")
+        self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
 
     def test_resolve_ollama_model(self):
         ollama_model = self.api.resolve_ollama_model(self.valid_supported_model)
@@ -325,7 +311,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
                 sampling_strategy=SamplingStrategy.top_p,
                 top_p=0.99,
                 temperature=1.0,
-            )
+            ),
         )
         options = self.api.get_ollama_chat_options(request)
         self.assertEqual(
@@ -333,5 +319,5 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             {
                 "temperature": 1.0,
                 "top_p": 0.99,
-            }
+            },
         )