From 9f14382d82a266104825c53fbcff221676b19b64 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 16 Jan 2025 18:17:46 -0800
Subject: [PATCH] meta reference inference fixes (#797)

Miscellaneous fixes for meta reference inference

Tests for log probs don't pass because meta reference does not support
top_k > 1
---
 llama_stack/distribution/server/server.py        |  2 +-
 .../inline/inference/meta_reference/inference.py | 14 +++++++++-----
 .../providers/utils/inference/prompt_adapter.py  |  6 ++++--
 tests/client-sdk/agents/test_agents.py           |  9 ++++++---
 tests/client-sdk/inference/test_inference.py     |  1 -
 5 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index bb9ef0361..8dbb193b9 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -263,7 +263,7 @@ class ClientVersionMiddleware:
                             error_msg = json.dumps(
                                 {
                                     "error": {
-                                        "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please upgrade your client."
+                                        "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please update your client."
                                     }
                                 }
                             ).encode()
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index d64d32f03..31ad6fa28 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -193,14 +193,14 @@ class MetaReferenceInferenceImpl(
             ]

             yield CompletionResponseStreamChunk(
-                delta=TextDelta(text=text),
+                delta=text,
                 stop_reason=stop_reason,
                 logprobs=logprobs if request.logprobs else None,
             )

         if stop_reason is None:
             yield CompletionResponseStreamChunk(
-                delta=TextDelta(text=""),
+                delta="",
                 stop_reason=StopReason.out_of_tokens,
             )

@@ -223,10 +223,10 @@
         tokenizer = self.generator.formatter.tokenizer
         for token_result in self.generator.completion(request):
             tokens.append(token_result.token)
-
-            if token_result.token in tokenizer.stop_tokens:
-                # not quite right semantically
+            if token_result.text == "<|eot_id|>":
                 stop_reason = StopReason.end_of_turn
+            elif token_result.text == "<|eom_id|>":
+                stop_reason = StopReason.end_of_message

             if request.logprobs:
                 assert len(token_result.logprobs) == 1
@@ -243,6 +243,10 @@
             stop_reason = StopReason.out_of_tokens

         content = self.generator.formatter.tokenizer.decode(tokens)
+        if content.endswith("<|eot_id|>"):
+            content = content[: -len("<|eot_id|>")]
+        elif content.endswith("<|eom_id|>"):
+            content = content[: -len("<|eom_id|>")]
         return CompletionResponse(
             content=content,
             stop_reason=stop_reason,
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index de4918f5c..7ee19fd7b 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -227,9 +227,11 @@ async def completion_request_to_prompt_model_input_info(

 def augment_content_with_response_format_prompt(response_format, content):
     if fmt_prompt := response_format_prompt(response_format):
         if isinstance(content, list):
-            return content + [fmt_prompt]
+            return content + [TextContentItem(text=fmt_prompt)]
+        elif isinstance(content, str):
+            return [TextContentItem(text=content), TextContentItem(text=fmt_prompt)]
         else:
-            return [content, fmt_prompt]
+            return [content, TextContentItem(text=fmt_prompt)]

     return content
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index f9b55b5cd..d6d88a34f 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -80,7 +80,7 @@ class TestClientTool(ClientTool):


 @pytest.fixture(scope="session")
-def agent_config(llama_stack_client):
+def model_id(llama_stack_client):
     available_models = [
         model.identifier
         for model in llama_stack_client.models.list()
@@ -88,6 +88,11 @@
     ]
     model_id = available_models[0]
     print(f"Using model: {model_id}")
+    return model_id
+
+
+@pytest.fixture(scope="session")
+def agent_config(llama_stack_client, model_id):
     available_shields = [
         shield.identifier for shield in llama_stack_client.shields.list()
     ]
@@ -246,10 +251,8 @@ def test_custom_tool(llama_stack_client, agent_config):
     client_tool = TestClientTool()
     agent_config = {
         **agent_config,
-        "model": "meta-llama/Llama-3.2-3B-Instruct",
         "toolgroups": ["builtin::websearch"],
         "client_tools": [client_tool.get_tool_definition()],
-        "tool_prompt_format": "python_list",
     }

     agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py
index 671a37926..19314e4ab 100644
--- a/tests/client-sdk/inference/test_inference.py
+++ b/tests/client-sdk/inference/test_inference.py
@@ -229,7 +229,6 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
     # response to be a tool call
     assert response.completion_message.content == ""
     assert response.completion_message.role == "assistant"
-    assert response.completion_message.stop_reason == "end_of_turn"
     assert len(response.completion_message.tool_calls) == 1
     assert response.completion_message.tool_calls[0].tool_name == "get_weather"