diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py index 2c222ffa1..4826e89d5 100644 --- a/llama_stack/providers/tests/inference/test_prompt_adapter.py +++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py @@ -6,8 +6,14 @@ import unittest -from llama_models.llama3.api import * # noqa: F403 -from llama_stack.apis.inference.inference import * # noqa: F403 +from llama_models.llama3.api.datatypes import ( + BuiltinTool, + ToolDefinition, + ToolParamDefinition, + ToolPromptFormat, +) + +from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, ) @@ -24,7 +30,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): UserMessage(content=content), ], ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 2) self.assertEqual(messages[-1].content, content) self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) @@ -41,7 +47,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): ToolDefinition(tool_name=BuiltinTool.brave_search), ], ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 2) self.assertEqual(messages[-1].content, content) self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) @@ -69,7 +75,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): ], tool_prompt_format=ToolPromptFormat.json, ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 3) self.assertTrue("Environment: ipython" in messages[0].content) @@ -99,7 +105,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): ), ], ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 3) self.assertTrue("Environment: ipython" in messages[0].content) @@ -121,7 +127,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): ToolDefinition(tool_name=BuiltinTool.code_interpreter), ], ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 2, messages) self.assertTrue(messages[0].content.endswith(system_prompt))