forked from phoenix-oss/llama-stack-mirror
Fixes; make inference tests pass with newer tool call types
parent d9d34433fc
commit 2c2969f331
5 changed files with 24 additions and 25 deletions
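The tests stop going through llama_stack_client's EventLogger and instead iterate the streaming response directly, reading typed deltas from chunk.event.delta and branching on delta.type for text versus tool calls. A hedged sketch of the new consumption pattern follows the diff below.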
@@ -6,9 +6,9 @@
 
 import pytest
-from llama_stack_client.lib.inference.event_logger import EventLogger
+
 from pydantic import BaseModel
 
 
 PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::ollama": "python_list",
     "remote::together": "json",
@@ -39,7 +39,7 @@ def text_model_id(llama_stack_client):
     available_models = [
         model.identifier
         for model in llama_stack_client.models.list()
-        if model.identifier.startswith("meta-llama")
+        if model.identifier.startswith("meta-llama") and "405" not in model.identifier
     ]
     assert len(available_models) > 0
     return available_models[0]
@@ -208,12 +208,9 @@ def test_text_chat_completion_streaming(
         stream=True,
     )
     streamed_content = [
-        str(log.content.lower().strip())
-        for log in EventLogger().log(response)
-        if log is not None
+        str(chunk.event.delta.text.lower().strip()) for chunk in response
     ]
     assert len(streamed_content) > 0
-    assert "assistant>" in streamed_content[0]
     assert expected.lower() in "".join(streamed_content)
 
 
@@ -250,17 +247,16 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
 def extract_tool_invocation_content(response):
     text_content: str = ""
     tool_invocation_content: str = ""
-    for log in EventLogger().log(response):
-        if log is None:
-            continue
-        if isinstance(log.content, str):
-            text_content += log.content
-        elif isinstance(log.content, object):
-            if isinstance(log.content.content, str):
-                continue
-            elif isinstance(log.content.content, object):
-                tool_invocation_content += f"[{log.content.content.tool_name}, {log.content.content.arguments}]"
-
+    for chunk in response:
+        delta = chunk.event.delta
+        if delta.type == "text":
+            text_content += delta.text
+        elif delta.type == "tool_call":
+            if isinstance(delta.content, str):
+                tool_invocation_content += delta.content
+            else:
+                call = delta.content
+                tool_invocation_content += f"[{call.tool_name}, {call.arguments}]"
     return text_content, tool_invocation_content
 
 
@@ -280,7 +276,6 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
     )
     text_content, tool_invocation_content = extract_tool_invocation_content(response)
 
-    assert "Assistant>" in text_content
     assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
 
 
@@ -368,10 +363,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
         stream=True,
     )
     streamed_content = [
-        str(log.content.lower().strip())
-        for log in EventLogger().log(response)
-        if log is not None
+        str(chunk.event.delta.text.lower().strip()) for chunk in response
     ]
     assert len(streamed_content) > 0
-    assert "assistant>" in streamed_content[0]
     assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})
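For context, the consumption pattern the updated tests rely on looks roughly like this. It is a minimal sketch, assuming a locally running distribution: the base_url, model id, and user message are illustrative placeholders, while the delta handling mirrors extract_tool_invocation_content above.

from llama_stack_client import LlamaStackClient

# Illustrative endpoint; substitute whatever your distribution serves.
client = LlamaStackClient(base_url="http://localhost:5000")

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model identifier
    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
    stream=True,
)

text_content = ""
tool_invocation_content = ""
for chunk in response:
    delta = chunk.event.delta
    if delta.type == "text":
        # Text now arrives directly on the delta; no EventLogger pass is needed.
        text_content += delta.text
    elif delta.type == "tool_call":
        if isinstance(delta.content, str):
            # A tool call still being parsed streams as a raw string fragment.
            tool_invocation_content += delta.content
        else:
            # A fully parsed tool call carries structured fields.
            call = delta.content
            tool_invocation_content += f"[{call.tool_name}, {call.arguments}]"

print(text_content, tool_invocation_content)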