rename augment_messages

Ashwin Bharambe 2024-10-08 13:48:44 -07:00 committed by Ashwin Bharambe
parent 336cf7a674
commit 640c5c54f7
10 changed files with 34 additions and 33 deletions

@@ -15,15 +15,15 @@ from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
 from .config import DatabricksImplConfig

@@ -15,15 +15,15 @@ from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
 from .config import FireworksImplConfig

@@ -15,9 +15,6 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,
@@ -25,6 +22,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)

 OLLAMA_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",

@@ -14,9 +14,6 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_model_input_info,
-)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,
@@ -24,6 +21,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_model_input_info,
+)
 from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig

@@ -15,15 +15,15 @@ from together import Together

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
 from .config import TogetherImplConfig

@@ -12,8 +12,8 @@ from llama_models.sku_list import resolve_model

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
-)
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
+)
 from .config import MetaReferenceImplConfig
@@ -94,7 +94,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)

         tokens = []
         logprobs = []
@@ -136,7 +136,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)

         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(

@@ -19,10 +19,6 @@ from vllm.sampling_params import SamplingParams

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
@@ -30,6 +26,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
 from .config import VLLMConfig

@@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str):
     scope="session",
     params=[
         {"model": Llama_8B},
-        # {"model": Llama_3B},
+        {"model": Llama_3B},
     ],
     ids=lambda d: d["model"],
 )

@@ -8,7 +8,7 @@ import unittest

 from llama_models.llama3.api import *  # noqa: F403
 from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.augment_messages import augment_messages_for_tools
+from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages

 MODEL = "Llama3.1-8B-Instruct"
@@ -22,7 +22,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 UserMessage(content=content),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -39,7 +39,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.brave_search),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -67,7 +67,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
             ],
             tool_prompt_format=ToolPromptFormat.json,
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -97,7 +97,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -119,7 +119,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2, messages)
         self.assertTrue(messages[0].content.endswith(system_prompt))

@@ -26,7 +26,7 @@ from llama_stack.providers.utils.inference import supported_inference_models

 def chat_completion_request_to_prompt(
     request: ChatCompletionRequest, formatter: ChatFormat
 ) -> str:
-    messages = augment_messages_for_tools(request)
+    messages = chat_completion_request_to_messages(request)
     model_input = formatter.encode_dialog_prompt(messages)
     return formatter.tokenizer.decode(model_input.tokens)
@@ -34,7 +34,7 @@ def chat_completion_request_to_prompt(

 def chat_completion_request_to_model_input_info(
     request: ChatCompletionRequest, formatter: ChatFormat
 ) -> Tuple[str, int]:
-    messages = augment_messages_for_tools(request)
+    messages = chat_completion_request_to_messages(request)
     model_input = formatter.encode_dialog_prompt(messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),
@@ -42,7 +42,9 @@ def chat_completion_request_to_model_input_info(
     )


-def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
+def chat_completion_request_to_messages(
+    request: ChatCompletionRequest,
+) -> List[Message]:
     """Reads chat completion request and augments the messages to handle tools.
     For example, for llama_3_1, add a system message with the appropriate tools,
     or add a user message for custom tools, etc.
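
For reference, a minimal sketch of the renamed helper in use (not part of this commit). The import path and function name come from the diff above; the exact ChatCompletionRequest fields are assumptions modeled on the unit tests.

# Hypothetical usage sketch. ChatCompletionRequest, UserMessage,
# ToolDefinition, and BuiltinTool come from the star import, as in the tests;
# the constructor fields shown are assumptions based on the tests above.
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

request = ChatCompletionRequest(
    model="Llama3.1-8B-Instruct",
    messages=[UserMessage(content="What is the capital of France?")],
    tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
)

# As exercised in the tests: the original user message is preserved as the
# last message, and a system message describing the tools is prepended.
messages = chat_completion_request_to_messages(request)
assert len(messages) == 2
assert messages[-1].content == "What is the capital of France?"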