From 640c5c54f7761f365a5358a4998b2ecebc872651 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Tue, 8 Oct 2024 13:48:44 -0700
Subject: [PATCH] rename augment_messages

---
 .../adapters/inference/databricks/databricks.py      |  6 +++---
 .../adapters/inference/fireworks/fireworks.py        |  6 +++---
 .../providers/adapters/inference/ollama/ollama.py    |  6 +++---
 llama_stack/providers/adapters/inference/tgi/tgi.py  |  6 +++---
 .../adapters/inference/together/together.py          |  6 +++---
 .../impls/meta_reference/inference/inference.py      |  8 ++++----
 llama_stack/providers/impls/vllm/vllm.py             |  7 +++----
 .../providers/tests/inference/test_inference.py      |  2 +-
 .../providers/tests/inference/test_prompt_adapter.py | 12 ++++++------
 .../{augment_messages.py => prompt_adapter.py}       |  8 +++++---
 10 files changed, 34 insertions(+), 33 deletions(-)
 rename tests/test_augment_messages.py => llama_stack/providers/tests/inference/test_prompt_adapter.py (91%)
 rename llama_stack/providers/utils/inference/{augment_messages.py => prompt_adapter.py} (96%)
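Reviewer note (this sits below the "---" fold, alongside the diffstat, so it is
ignored when the patch is applied): augment_messages.py becomes
prompt_adapter.py and augment_messages_for_tools() becomes
chat_completion_request_to_messages(). The provider changes below are
import-path swaps, plus the renamed call sites in the meta-reference and vLLM
implementations and an uncommented Llama_3B test fixture. The sketch below is
illustrative rather than authoritative: it assumes the helper signatures
visible in the hunks further down and the ChatFormat(Tokenizer.get_instance())
construction used by the adapters in this repo; the model name and message
content are placeholders.

# Hypothetical usage of the renamed helpers (not part of the applied diff).
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.inference import *  # noqa: F403  (ChatCompletionRequest, UserMessage)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,  # formerly augment_messages_for_tools
    chat_completion_request_to_prompt,
)

request = ChatCompletionRequest(
    model="Llama3.1-8B-Instruct",
    messages=[UserMessage(content="What is the capital of France?")],
)

# New name, same behavior: returns the request's messages augmented with
# whatever tool-related system/user messages the model family expects.
messages = chat_completion_request_to_messages(request)

# Providers that need a flat prompt string keep calling this wrapper; it now
# delegates to chat_completion_request_to_messages() internally (see the
# prompt_adapter.py hunks at the bottom of this patch).
formatter = ChatFormat(Tokenizer.get_instance())
prompt = chat_completion_request_to_prompt(request, formatter)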
diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py
index f318e6180..847c85eba 100644
--- a/llama_stack/providers/adapters/inference/databricks/databricks.py
+++ b/llama_stack/providers/adapters/inference/databricks/databricks.py
@@ -15,15 +15,15 @@ from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
 
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
 
 from .config import DatabricksImplConfig
 
diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index ce57480a0..c0edc836a 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -15,15 +15,15 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.inference import *  # noqa: F403
 
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
 
 from .config import FireworksImplConfig
 
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index 86d72ca7f..fe5e39c30 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -15,9 +15,6 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
 
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,
@@ -25,6 +22,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
 
 OLLAMA_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index bd05f98bb..59eb7f3f1 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -14,9 +14,6 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_model_input_info,
-)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,
@@ -24,6 +21,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_model_input_info,
+)
 
 from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
 
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index adea696fb..0ef5bc593 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -15,15 +15,15 @@ from together import Together
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
 
 from .config import TogetherImplConfig
 
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index ad8cc31fd..9e31f0834 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -12,8 +12,8 @@ from llama_models.sku_list import resolve_model
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
 )
 
 from .config import MetaReferenceImplConfig
@@ -94,7 +94,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
 
         tokens = []
         logprobs = []
@@ -136,7 +136,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
 
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py
index 748871b4e..e0b063ac9 100644
--- a/llama_stack/providers/impls/vllm/vllm.py
+++ b/llama_stack/providers/impls/vllm/vllm.py
@@ -19,10 +19,6 @@ from vllm.sampling_params import SamplingParams
 
 from llama_stack.apis.inference import *  # noqa: F403
 
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
-
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
@@ -30,6 +26,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
 
 from .config import VLLMConfig
 
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 107a534d5..b864c2ef4 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str):
     scope="session",
     params=[
         {"model": Llama_8B},
-        # {"model": Llama_3B},
+        {"model": Llama_3B},
     ],
     ids=lambda d: d["model"],
 )
diff --git a/tests/test_augment_messages.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py
similarity index 91%
rename from tests/test_augment_messages.py
rename to llama_stack/providers/tests/inference/test_prompt_adapter.py
index 1c2eb62b4..3a1e25d65 100644
--- a/tests/test_augment_messages.py
+++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py
@@ -8,7 +8,7 @@ import unittest
 
 from llama_models.llama3.api import *  # noqa: F403
 from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.augment_messages import augment_messages_for_tools
+from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages
 
 MODEL = "Llama3.1-8B-Instruct"
 
@@ -22,7 +22,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 UserMessage(content=content),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -39,7 +39,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.brave_search),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -67,7 +67,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
             ],
             tool_prompt_format=ToolPromptFormat.json,
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
 
@@ -97,7 +97,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
 
@@ -119,7 +119,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2, messages)
         self.assertTrue(messages[0].content.endswith(system_prompt))
 
diff --git a/llama_stack/providers/utils/inference/augment_messages.py b/llama_stack/providers/utils/inference/prompt_adapter.py
similarity index 96%
rename from llama_stack/providers/utils/inference/augment_messages.py
rename to llama_stack/providers/utils/inference/prompt_adapter.py
index 8f59b5295..5b8ded52c 100644
--- a/llama_stack/providers/utils/inference/augment_messages.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -26,7 +26,7 @@ from llama_stack.providers.utils.inference import supported_inference_models
 def chat_completion_request_to_prompt(
     request: ChatCompletionRequest, formatter: ChatFormat
 ) -> str:
-    messages = augment_messages_for_tools(request)
+    messages = chat_completion_request_to_messages(request)
     model_input = formatter.encode_dialog_prompt(messages)
     return formatter.tokenizer.decode(model_input.tokens)
 
@@ -34,7 +34,7 @@ def chat_completion_request_to_prompt(
 def chat_completion_request_to_model_input_info(
     request: ChatCompletionRequest, formatter: ChatFormat
 ) -> Tuple[str, int]:
-    messages = augment_messages_for_tools(request)
+    messages = chat_completion_request_to_messages(request)
     model_input = formatter.encode_dialog_prompt(messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),
@@ -42,7 +42,9 @@ def chat_completion_request_to_model_input_info(
     )
 
 
-def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
+def chat_completion_request_to_messages(
+    request: ChatCompletionRequest,
+) -> List[Message]:
     """Reads chat completion request and augments the messages to handle tools.
     For eg. for llama_3_1, add system message with the appropriate tools or
     add user messsage for custom tools, etc.
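
Reviewer note, appended after the final hunk: for quick reference, a minimal
sketch of the renamed entry point in action, mirroring the relocated tests.
The tool definition and the "Environment: ipython" marker come straight from
the test assertions above; the model name and message content are
placeholders, and the direct (non-wildcard) imports are an assumption about
what these modules export.

# Illustrative sketch (not part of the applied diff).
from llama_models.llama3.api.datatypes import (  # wildcard-imported in the tests
    BuiltinTool,
    ToolDefinition,
    UserMessage,
)

from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

request = ChatCompletionRequest(
    model="Llama3.1-8B-Instruct",
    messages=[UserMessage(content="Write code to check if a number is prime.")],
    tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)],
)

# For Llama 3.1 style models, builtin tools are advertised through a generated
# system message; the tests above assert it contains "Environment: ipython".
messages = chat_completion_request_to_messages(request)
assert "Environment: ipython" in messages[0].content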