mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 07:14:20 +00:00)

rename augment_messages

parent 336cf7a674
commit 640c5c54f7
10 changed files with 34 additions and 33 deletions
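In short, the shared helper module llama_stack.providers.utils.inference.augment_messages becomes prompt_adapter, and augment_messages_for_tools becomes chat_completion_request_to_messages; the providers and tests below only update imports and call sites. As a minimal before/after sketch of an import site (paths taken directly from the hunks below):

# Before: helper imported from augment_messages
from llama_stack.providers.utils.inference.augment_messages import (
    chat_completion_request_to_prompt,
)

# After: same helper, imported from the renamed prompt_adapter module
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)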
Databricks adapter:

@@ -15,15 +15,15 @@ from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403

-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)

 from .config import DatabricksImplConfig
Fireworks adapter:

@@ -15,15 +15,15 @@ from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.inference import *  # noqa: F403

-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)

 from .config import FireworksImplConfig
Ollama adapter:

@@ -15,9 +15,6 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,

@@ -25,6 +22,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)

 OLLAMA_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
TGI adapter:

@@ -14,9 +14,6 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_model_input_info,
-)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,

@@ -24,6 +21,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_model_input_info,
+)

 from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
Together adapter:

@@ -15,15 +15,15 @@ from together import Together

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)

 from .config import TogetherImplConfig
Meta reference inference (MetaReferenceInferenceImpl):

@@ -12,8 +12,8 @@ from llama_models.sku_list import resolve_model

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
 )

 from .config import MetaReferenceImplConfig

@@ -94,7 +94,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)

         tokens = []
         logprobs = []

@@ -136,7 +136,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)

         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
vLLM provider:

@@ -19,10 +19,6 @@ from vllm.sampling_params import SamplingParams

 from llama_stack.apis.inference import *  # noqa: F403

-from llama_stack.providers.utils.inference.augment_messages import (
-    chat_completion_request_to_prompt,
-)
-
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,

@@ -30,6 +26,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)

 from .config import VLLMConfig
Inference test fixture (re-enables the Llama_3B parametrization):

@@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str):
     scope="session",
     params=[
         {"model": Llama_8B},
-        # {"model": Llama_3B},
+        {"model": Llama_3B},
     ],
     ids=lambda d: d["model"],
 )
Prompt preparation unit tests (PrepareMessagesTests):

@@ -8,7 +8,7 @@ import unittest

 from llama_models.llama3.api import *  # noqa: F403
 from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.augment_messages import augment_messages_for_tools
+from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages

 MODEL = "Llama3.1-8B-Instruct"

@@ -22,7 +22,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 UserMessage(content=content),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)

@@ -39,7 +39,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.brave_search),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)

@@ -67,7 +67,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
             ],
             tool_prompt_format=ToolPromptFormat.json,
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)

@@ -97,7 +97,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 3)

         self.assertTrue("Environment: ipython" in messages[0].content)

@@ -119,7 +119,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
             ],
        )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2, messages)
         self.assertTrue(messages[0].content.endswith(system_prompt))
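The tests above exercise the renamed helper through unittest; as a standalone sketch (the model name and message content are illustrative, and the request construction assumes the same ChatCompletionRequest fields the tests use, with the star imports mirroring the rest of this commit):

from llama_models.llama3.api import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

request = ChatCompletionRequest(
    model="Llama3.1-8B-Instruct",
    messages=[UserMessage(content="What is the boiling point of water?")],
    tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
)
# Returns the original messages plus an injected system message describing the
# built-in tool (hence the length-2 assertions in the tests above).
messages = chat_completion_request_to_messages(request)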
prompt_adapter helpers (formerly augment_messages):

@@ -26,7 +26,7 @@ from llama_stack.providers.utils.inference import supported_inference_models
 def chat_completion_request_to_prompt(
     request: ChatCompletionRequest, formatter: ChatFormat
 ) -> str:
-    messages = augment_messages_for_tools(request)
+    messages = chat_completion_request_to_messages(request)
     model_input = formatter.encode_dialog_prompt(messages)
     return formatter.tokenizer.decode(model_input.tokens)

@@ -34,7 +34,7 @@ def chat_completion_request_to_prompt(
 def chat_completion_request_to_model_input_info(
     request: ChatCompletionRequest, formatter: ChatFormat
 ) -> Tuple[str, int]:
-    messages = augment_messages_for_tools(request)
+    messages = chat_completion_request_to_messages(request)
     model_input = formatter.encode_dialog_prompt(messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),

@@ -42,7 +42,9 @@ def chat_completion_request_to_model_input_info(
     )


-def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
+def chat_completion_request_to_messages(
+    request: ChatCompletionRequest,
+) -> List[Message]:
     """Reads chat completion request and augments the messages to handle tools.
     For eg. for llama_3_1, add system message with the appropriate tools or
     add user messsage for custom tools, etc.
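For the string-building path defined just above, a minimal sketch of how a caller might use chat_completion_request_to_prompt; the ChatFormat/Tokenizer wiring and the render_prompt wrapper are assumptions for illustration only and may not match how the providers actually construct their formatter:

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)

def render_prompt(request: ChatCompletionRequest, tokenizer_path: str) -> str:
    # Hypothetical wrapper: augment the request's messages for tools, encode
    # the dialog with the Llama 3 chat format, then decode the tokens back
    # into a single prompt string for text-completion style backends.
    formatter = ChatFormat(Tokenizer(model_path=tokenizer_path))  # assumed setup
    return chat_completion_request_to_prompt(request, formatter)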