Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
meta reference inference fixes (#797)

Miscellaneous fixes for meta reference inference. Tests for logprobs don't pass because meta reference does not support top_k > 1.
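For context, logprobs are requested per inference call through a top_k setting, and only top_k == 1 is expected to work against the meta reference provider. A minimal sketch of such a request, assuming the llama-stack Python client and its inference.completion call; the base URL and model id below are illustrative and not taken from this commit:

# Sketch: request per-token logprobs from a running llama-stack server.
# Assumes LlamaStackClient and that `logprobs` accepts a top_k setting;
# with the meta reference provider, anything other than top_k == 1 is
# expected to fail, which is why the logprobs tests do not pass.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # illustrative URL

response = client.inference.completion(
    model_id="meta-llama/Llama-3.2-3B-Instruct",  # illustrative model id
    content="The capital of France is",
    logprobs={"top_k": 1},
)
print(response.content)
print(response.logprobs)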
parent cb41848a2a
commit 9f14382d82

5 changed files with 20 additions and 12 deletions
@@ -263,7 +263,7 @@ class ClientVersionMiddleware:
                 error_msg = json.dumps(
                     {
                         "error": {
-                            "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please upgrade your client."
+                            "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please update your client."
                         }
                     }
                 ).encode()

@@ -193,14 +193,14 @@ class MetaReferenceInferenceImpl(
                 ]

             yield CompletionResponseStreamChunk(
-                delta=TextDelta(text=text),
+                delta=text,
                 stop_reason=stop_reason,
                 logprobs=logprobs if request.logprobs else None,
             )

         if stop_reason is None:
             yield CompletionResponseStreamChunk(
-                delta=TextDelta(text=""),
+                delta="",
                 stop_reason=StopReason.out_of_tokens,
             )

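After this change the streamed delta is a plain string rather than a TextDelta wrapper. A minimal sketch of a consumer loop under that assumption; the stream argument stands in for whatever async iterator yields these chunks:

async def collect_completion(stream):
    # `stream` is a stand-in for any async iterator of CompletionResponseStreamChunk.
    pieces = []
    async for chunk in stream:
        pieces.append(chunk.delta)         # delta is a plain str after this fix
        if chunk.stop_reason is not None:  # server signals end_of_turn / out_of_tokens
            break
    return "".join(pieces)
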
@@ -223,10 +223,10 @@ class MetaReferenceInferenceImpl(
             tokenizer = self.generator.formatter.tokenizer
             for token_result in self.generator.completion(request):
                 tokens.append(token_result.token)
-                if token_result.text == "<|eot_id|>":
+                if token_result.token in tokenizer.stop_tokens:
+                    # not quite right semantically
                     stop_reason = StopReason.end_of_turn
-                elif token_result.text == "<|eom_id|>":
-                    stop_reason = StopReason.end_of_message

                 if request.logprobs:
                     assert len(token_result.logprobs) == 1

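The non-streaming path now treats any token in the tokenizer's stop set as the end of generation instead of matching the literal "<|eot_id|>" text. The same check in isolation, with a stand-in tokenizer; the token ids are illustrative, not taken from this commit:

class ToyTokenizer:
    # Stand-in for the real tokenizer: only the stop_tokens set matters here.
    stop_tokens = {128001, 128008, 128009}  # illustrative special-token ids

def find_stop(token_ids, tokenizer=ToyTokenizer()):
    """Return the index of the first stop token, or None if generation ran out of tokens."""
    for i, tok in enumerate(token_ids):
        if tok in tokenizer.stop_tokens:
            return i
    return None

print(find_stop([10, 20, 128009]))  # -> 2
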
@@ -243,6 +243,10 @@ class MetaReferenceInferenceImpl(
                 stop_reason = StopReason.out_of_tokens

             content = self.generator.formatter.tokenizer.decode(tokens)
+            if content.endswith("<|eot_id|>"):
+                content = content[: -len("<|eot_id|>")]
+            elif content.endswith("<|eom_id|>"):
+                content = content[: -len("<|eom_id|>")]
             return CompletionResponse(
                 content=content,
                 stop_reason=stop_reason,

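Since the decoded string can still end with the stop marker itself, the hunk above strips it before building the response. The same idea as a small standalone helper; the function name is hypothetical and not part of this commit:

def strip_stop_marker(content: str) -> str:
    """Drop a trailing <|eot_id|> or <|eom_id|> left over from decoding."""
    for marker in ("<|eot_id|>", "<|eom_id|>"):
        if content.endswith(marker):
            return content[: -len(marker)]
    return content

print(strip_stop_marker("Paris.<|eot_id|>"))  # -> "Paris."
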
@@ -227,9 +227,11 @@ async def completion_request_to_prompt_model_input_info(

 def augment_content_with_response_format_prompt(response_format, content):
     if fmt_prompt := response_format_prompt(response_format):
         if isinstance(content, list):
-            return content + [fmt_prompt]
+            return content + [TextContentItem(text=fmt_prompt)]
+        elif isinstance(content, str):
+            return [TextContentItem(text=content), TextContentItem(text=fmt_prompt)]
         else:
-            return [content, fmt_prompt]
+            return [content, TextContentItem(text=fmt_prompt)]

     return content

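With this change the helper always returns typed content items when a response-format prompt is present, instead of mixing raw strings into the list. A self-contained sketch of that normalization, using a stand-in dataclass rather than llama-stack's TextContentItem:

from dataclasses import dataclass

@dataclass
class TextContentItem:  # stand-in for llama_stack's TextContentItem
    text: str

def augment(content, fmt_prompt):
    # Mirrors the branches in the hunk above: list, plain str, or a single item.
    if isinstance(content, list):
        return content + [TextContentItem(text=fmt_prompt)]
    elif isinstance(content, str):
        return [TextContentItem(text=content), TextContentItem(text=fmt_prompt)]
    else:
        return [content, TextContentItem(text=fmt_prompt)]

print(augment("What is 2 + 2?", "Answer strictly in JSON."))
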
@@ -80,7 +80,7 @@ class TestClientTool(ClientTool):


 @pytest.fixture(scope="session")
-def agent_config(llama_stack_client):
+def model_id(llama_stack_client):
     available_models = [
         model.identifier
         for model in llama_stack_client.models.list()

@@ -88,6 +88,11 @@ def agent_config(llama_stack_client):
     ]
     model_id = available_models[0]
     print(f"Using model: {model_id}")
+    return model_id
+
+
+@pytest.fixture(scope="session")
+def agent_config(llama_stack_client, model_id):
     available_shields = [
         shield.identifier for shield in llama_stack_client.shields.list()
     ]

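The two hunks above split model selection into its own session-scoped fixture that agent_config then depends on, so other tests can request just the model id. The same fixture-chaining pattern in isolation; the fixture bodies here are placeholders, not the suite's real logic:

import pytest

@pytest.fixture(scope="session")
def model_id():
    # Placeholder: the real fixture asks llama_stack_client.models.list() for available models.
    return "meta-llama/Llama-3.2-3B-Instruct"

@pytest.fixture(scope="session")
def agent_config(model_id):
    # Fixtures can depend on other fixtures; pytest resolves model_id first.
    return {"model": model_id}

def test_agent_config_uses_selected_model(agent_config):
    assert agent_config["model"] == "meta-llama/Llama-3.2-3B-Instruct"
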
@@ -246,10 +251,8 @@ def test_custom_tool(llama_stack_client, agent_config):
     client_tool = TestClientTool()
     agent_config = {
         **agent_config,
-        "model": "meta-llama/Llama-3.2-3B-Instruct",
         "toolgroups": ["builtin::websearch"],
         "client_tools": [client_tool.get_tool_definition()],
-        "tool_prompt_format": "python_list",
     }

     agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))

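test_custom_tool no longer hard-codes the model or tool_prompt_format; per-test overrides are layered onto the shared fixture with dict unpacking, where later keys win. A minimal illustration of that merge order, with made-up values:

base_config = {"model": "some-model", "toolgroups": []}

test_config = {
    **base_config,                          # start from the shared fixture
    "toolgroups": ["builtin::websearch"],   # later keys override earlier ones
}

assert test_config["model"] == "some-model"  # inherited unchanged
assert test_config["toolgroups"] == ["builtin::websearch"]
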
@@ -229,7 +229,6 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
     # response to be a tool call
     assert response.completion_message.content == ""
     assert response.completion_message.role == "assistant"
-    assert response.completion_message.stop_reason == "end_of_turn"

     assert len(response.completion_message.tool_calls) == 1
     assert response.completion_message.tool_calls[0].tool_name == "get_weather"