Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 07:14:20 +00:00)

commit 640c5c54f7 (parent 336cf7a674)

    rename augment_messages

10 changed files with 34 additions and 33 deletions
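In short, the helper module llama_stack.providers.utils.inference.augment_messages becomes prompt_adapter, and augment_messages_for_tools becomes chat_completion_request_to_messages; the provider-facing helpers chat_completion_request_to_prompt and chat_completion_request_to_model_input_info keep their names but are now imported from prompt_adapter. A minimal before/after sketch for downstream callers (the request variable is illustrative, not part of the commit):

    # Before this commit (old module and function names):
    # from llama_stack.providers.utils.inference.augment_messages import (
    #     augment_messages_for_tools,
    # )
    # messages = augment_messages_for_tools(request)

    # After this commit (renamed module and function, same behavior):
    from llama_stack.providers.utils.inference.prompt_adapter import (
        chat_completion_request_to_messages,
    )

    messages = chat_completion_request_to_messages(request)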
@@ -15,15 +15,15 @@ from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403

-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)

 from .config import DatabricksImplConfig

@@ -15,15 +15,15 @@ from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.inference import *  # noqa: F403

-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)

 from .config import FireworksImplConfig

@@ -15,9 +15,6 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,
@@ -25,6 +22,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)

 OLLAMA_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
@@ -14,9 +14,6 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_model_input_info,
-)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,
@@ -24,6 +21,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_model_input_info,
+)

 from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig

@@ -15,15 +15,15 @@ from together import Together

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)

 from .config import TogetherImplConfig

@@ -12,8 +12,8 @@ from llama_models.sku_list import resolve_model

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
 )

 from .config import MetaReferenceImplConfig
@@ -94,7 +94,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)

         tokens = []
         logprobs = []
@@ -136,7 +136,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)

         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
@@ -19,10 +19,6 @@ from vllm.sampling_params import SamplingParams

 from llama_stack.apis.inference import *  # noqa: F403

-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
-
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
@@ -30,6 +26,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)

 from .config import VLLMConfig

@@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str):
     scope="session",
     params=[
         {"model": Llama_8B},
-        # {"model": Llama_3B},
+        {"model": Llama_3B},
     ],
     ids=lambda d: d["model"],
 )
@@ -8,7 +8,7 @@ import unittest

 from llama_models.llama3.api import *  # noqa: F403
 from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.augment_messages import augment_messages_for_tools
+from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages

 MODEL = "Llama3.1-8B-Instruct"

@@ -22,7 +22,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 UserMessage(content=content),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -39,7 +39,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.brave_search),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -67,7 +67,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
             ],
             tool_prompt_format=ToolPromptFormat.json,
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)

@@ -97,7 +97,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 3)

         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -119,7 +119,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2, messages)
         self.assertTrue(messages[0].content.endswith(system_prompt))
@@ -26,7 +26,7 @@ from llama_stack.providers.utils.inference import supported_inference_models
 def chat_completion_request_to_prompt(
     request: ChatCompletionRequest, formatter: ChatFormat
 ) -> str:
-    messages = augment_messages_for_tools(request)
+    messages = chat_completion_request_to_messages(request)
     model_input = formatter.encode_dialog_prompt(messages)
     return formatter.tokenizer.decode(model_input.tokens)

@@ -34,7 +34,7 @@ def chat_completion_request_to_prompt(
 def chat_completion_request_to_model_input_info(
     request: ChatCompletionRequest, formatter: ChatFormat
 ) -> Tuple[str, int]:
-    messages = augment_messages_for_tools(request)
+    messages = chat_completion_request_to_messages(request)
     model_input = formatter.encode_dialog_prompt(messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),
@@ -42,7 +42,9 @@ def chat_completion_request_to_model_input_info(
     )


-def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
+def chat_completion_request_to_messages(
+    request: ChatCompletionRequest,
+) -> List[Message]:
     """Reads chat completion request and augments the messages to handle tools.
     For eg. for llama_3_1, add system message with the appropriate tools or
     add user messsage for custom tools, etc.
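For reference, a minimal sketch of how the renamed helpers are used once this commit lands. The ChatCompletionRequest/UserMessage construction mirrors the unit tests above; how the ChatFormat/Tokenizer formatter is built is an assumption for illustration, not something shown in this diff:

    # Sketch only; names come from llama_stack / llama_models as imported in the files above.
    from llama_models.llama3.api.chat_format import ChatFormat
    from llama_models.llama3.api.tokenizer import Tokenizer

    from llama_stack.apis.inference import ChatCompletionRequest, UserMessage
    from llama_stack.providers.utils.inference.prompt_adapter import (
        chat_completion_request_to_messages,
        chat_completion_request_to_prompt,
    )

    request = ChatCompletionRequest(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Hello!")],
    )

    # New name for what was augment_messages_for_tools(request).
    messages = chat_completion_request_to_messages(request)

    # Renders the augmented messages into a provider-ready prompt string.
    # Assumption: the formatter is built from the Llama 3 tokenizer, as the adapters do.
    formatter = ChatFormat(Tokenizer.get_instance())
    prompt = chat_completion_request_to_prompt(request, formatter)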