Fixes; make inference tests pass with newer tool call types

Ashwin Bharambe 2025-01-13 23:16:16 -08:00
parent d9d34433fc
commit 2c2969f331
5 changed files with 24 additions and 25 deletions


@@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v4"
+KEY_VERSION = "v5"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
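Bumping KEY_VERSION changes every key the registry reads and writes, so entries persisted under the old "v4" prefix simply stop matching, which effectively invalidates registry state written before this change. A minimal sketch of the key construction, with the type and identifier values made up for illustration:

# Illustration only: how KEY_FORMAT expands after the bump to "v5".
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v5"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

key = KEY_FORMAT.format(type="model", identifier="meta-llama/Llama-3.1-8B-Instruct")
print(key)  # distributions:registry:v5::model:meta-llama/Llama-3.1-8B-Instruct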


@@ -142,7 +142,7 @@ async def process_completion_stream_response(
             text = ""
             continue
         yield CompletionResponseStreamChunk(
-            delta=TextDelta(text=text),
+            delta=text,
             stop_reason=stop_reason,
         )
         if finish_reason:
@@ -153,7 +153,7 @@ async def process_completion_stream_response(
             break
 
     yield CompletionResponseStreamChunk(
-        delta=TextDelta(text=""),
+        delta="",
         stop_reason=stop_reason,
     )
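After this change a completion stream yields plain-string deltas again; the TextDelta wrapper belongs to the chat-completion event path, which is why the updated tests below read chunk.event.delta.text there. A hedged sketch of a consumer of this completion stream, assuming the chunks expose the delta field exactly as constructed above:

# Sketch, not part of the commit: accumulate the text of a completion stream
# whose chunks carry a plain-string `delta`, as produced by the code above.
async def collect_completion_text(stream) -> str:
    text = ""
    async for chunk in stream:  # CompletionResponseStreamChunk instances
        text += chunk.delta     # a plain str after this change, no TextDelta wrapper
    return text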


@@ -265,6 +265,7 @@ def chat_completion_request_to_messages(
     For eg. for llama_3_1, add system message with the appropriate tools or
     add user messsage for custom tools, etc.
     """
+    assert llama_model is not None, "llama_model is required"
     model = resolve_model(llama_model)
     if model is None:
         log.error(f"Could not resolve model {llama_model}")


@@ -12,6 +12,11 @@ from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack_client import LlamaStackClient
 
+
+def pytest_configure(config):
+    config.option.tbstyle = "short"
+    config.option.disable_warnings = True
+
 
 @pytest.fixture(scope="session")
 def provider_data():
     # check env for tavily secret, brave secret and inject all into provider data
@@ -29,6 +34,7 @@ def llama_stack_client(provider_data):
         client = LlamaStackAsLibraryClient(
             get_env_or_fail("LLAMA_STACK_CONFIG"),
             provider_data=provider_data,
+            skip_logger_removal=True,
         )
         client.initialize()
     elif os.environ.get("LLAMA_STACK_BASE_URL"):
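The pytest_configure() hook makes every run use short tracebacks with warnings suppressed, and skip_logger_removal=True keeps the library client from stripping logging handlers that pytest's capture depends on (my reading of the flag; the diff itself does not say). A sketch of the library-client branch of the fixture, with the import path and the provider-data key name assumed rather than taken from this diff:

# Sketch of the LLAMA_STACK_CONFIG branch above; import path and key name are assumptions.
import os

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient(
    os.environ["LLAMA_STACK_CONFIG"],  # template name or path to a run config
    provider_data={"tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", "")},
    skip_logger_removal=True,  # leave pytest's logging handlers in place (assumed purpose)
)
client.initialize()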


@@ -6,9 +6,9 @@
 import pytest
-from llama_stack_client.lib.inference.event_logger import EventLogger
 from pydantic import BaseModel
 
 PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::ollama": "python_list",
     "remote::together": "json",
@@ -39,7 +39,7 @@ def text_model_id(llama_stack_client):
     available_models = [
         model.identifier
         for model in llama_stack_client.models.list()
-        if model.identifier.startswith("meta-llama")
+        if model.identifier.startswith("meta-llama") and "405" not in model.identifier
     ]
     assert len(available_models) > 0
     return available_models[0]
@@ -208,12 +208,9 @@ def test_text_chat_completion_streaming(
         stream=True,
     )
     streamed_content = [
-        str(log.content.lower().strip())
-        for log in EventLogger().log(response)
-        if log is not None
+        str(chunk.event.delta.text.lower().strip()) for chunk in response
     ]
     assert len(streamed_content) > 0
-    assert "assistant>" in streamed_content[0]
     assert expected.lower() in "".join(streamed_content)
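With EventLogger gone, the streaming tests read text straight off each chunk's typed delta. A sketch of the full call pattern this test exercises; the chat_completion arguments and the prompt are assumptions, since the call itself falls outside the hunk:

# Assumed call shape (model_id/messages kwargs and the prompt are not shown in
# this diff); the comprehension matches the updated test above.
response = llama_stack_client.inference.chat_completion(
    model_id=text_model_id,
    messages=[{"role": "user", "content": "What is the name of the US capital?"}],
    stream=True,
)
streamed_content = [str(chunk.event.delta.text.lower().strip()) for chunk in response]
assert "washington" in "".join(streamed_content)
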
@@ -250,17 +247,16 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
 def extract_tool_invocation_content(response):
     text_content: str = ""
     tool_invocation_content: str = ""
-    for log in EventLogger().log(response):
-        if log is None:
-            continue
-        if isinstance(log.content, str):
-            text_content += log.content
-        elif isinstance(log.content, object):
-            if isinstance(log.content.content, str):
-                continue
-            elif isinstance(log.content.content, object):
-                tool_invocation_content += f"[{log.content.content.tool_name}, {log.content.content.arguments}]"
+    for chunk in response:
+        delta = chunk.event.delta
+        if delta.type == "text":
+            text_content += delta.text
+        elif delta.type == "tool_call":
+            if isinstance(delta.content, str):
+                tool_invocation_content += delta.content
+            else:
+                call = delta.content
+                tool_invocation_content += f"[{call.tool_name}, {call.arguments}]"
     return text_content, tool_invocation_content
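A self-contained illustration of the chunk shapes the rewritten helper expects, using plain namespaces as stand-ins for the client's chunk, event, and delta types; the field names come from the diff above, while the sample values are invented:

from types import SimpleNamespace as NS

chunks = [
    NS(event=NS(delta=NS(type="text", text="Checking the weather... "))),
    NS(event=NS(delta=NS(
        type="tool_call",
        content=NS(tool_name="get_weather", arguments={"location": "San Francisco, CA"}),
    ))),
]

text, tool = extract_tool_invocation_content(chunks)
# text == "Checking the weather... "
# tool == "[get_weather, {'location': 'San Francisco, CA'}]"
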
@@ -280,7 +276,6 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
     )
     text_content, tool_invocation_content = extract_tool_invocation_content(response)
 
-    assert "Assistant>" in text_content
     assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
@@ -368,10 +363,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
         stream=True,
     )
     streamed_content = [
-        str(log.content.lower().strip())
-        for log in EventLogger().log(response)
-        if log is not None
+        str(chunk.event.delta.text.lower().strip()) for chunk in response
     ]
     assert len(streamed_content) > 0
-    assert "assistant>" in streamed_content[0]
     assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})