mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 19:04:19 +00:00)
Fixes; make inference tests pass with newer tool call types
parent d9d34433fc
commit 2c2969f331
5 changed files with 24 additions and 25 deletions
@@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):


 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v4"
+KEY_VERSION = "v5"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

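For reference, a quick sketch of how the key format expands once KEY_VERSION is bumped; the type and identifier values below are illustrative, not taken from this diff. Since the version is baked into the key string, entries written under the old "v4" prefix are simply not looked up anymore.

# Worked example with illustrative values; mirrors the constants above.
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v5"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

key = KEY_FORMAT.format(type="model", identifier="meta-llama/Llama-3.1-8B-Instruct")
# key == "distributions:registry:v5::model:meta-llama/Llama-3.1-8B-Instruct"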
@@ -142,7 +142,7 @@ async def process_completion_stream_response(
             text = ""
             continue
         yield CompletionResponseStreamChunk(
-            delta=TextDelta(text=text),
+            delta=text,
             stop_reason=stop_reason,
         )
         if finish_reason:
@@ -153,7 +153,7 @@ async def process_completion_stream_response(
            break

    yield CompletionResponseStreamChunk(
-        delta=TextDelta(text=""),
+        delta="",
        stop_reason=stop_reason,
    )

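A minimal consumer sketch, not part of this commit: it assumes the completion stream yields chunks whose delta is a plain string and whose stop_reason stays None until the stream ends, as on the new side of the two hunks above.

# Illustrative consumer of a completion stream after this change.
async def collect_completion_text(stream) -> str:
    text = ""
    async for chunk in stream:
        text += chunk.delta          # plain string, no TextDelta wrapper here
        if chunk.stop_reason is not None:
            break
    return text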
@@ -265,6 +265,7 @@ def chat_completion_request_to_messages(
     For eg. for llama_3_1, add system message with the appropriate tools or
     add user messsage for custom tools, etc.
     """
+    assert llama_model is not None, "llama_model is required"
     model = resolve_model(llama_model)
     if model is None:
         log.error(f"Could not resolve model {llama_model}")
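The added assertion makes a missing model id fail loudly at the call site instead of falling through to resolve_model() and the "Could not resolve model None" error path. A self-contained sketch of the guard pattern, using a stand-in helper rather than the real function:

# Stand-in sketch only; the real check lives inside
# chat_completion_request_to_messages as shown in the hunk above.
def _require_llama_model(llama_model):
    assert llama_model is not None, "llama_model is required"
    return llama_model

_require_llama_model("meta-llama/Llama-3.1-8B-Instruct")   # passes
# _require_llama_model(None)  # raises AssertionError: llama_model is required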
@@ -12,6 +12,11 @@ from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack_client import LlamaStackClient


+def pytest_configure(config):
+    config.option.tbstyle = "short"
+    config.option.disable_warnings = True
+
+
 @pytest.fixture(scope="session")
 def provider_data():
     # check env for tavily secret, brave secret and inject all into provider data
@@ -29,6 +34,7 @@ def llama_stack_client(provider_data):
         client = LlamaStackAsLibraryClient(
             get_env_or_fail("LLAMA_STACK_CONFIG"),
             provider_data=provider_data,
+            skip_logger_removal=True,
         )
         client.initialize()
     elif os.environ.get("LLAMA_STACK_BASE_URL"):
@@ -6,9 +6,9 @@

 import pytest

-from llama_stack_client.lib.inference.event_logger import EventLogger
 from pydantic import BaseModel


 PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::ollama": "python_list",
     "remote::together": "json",
@@ -39,7 +39,7 @@ def text_model_id(llama_stack_client):
     available_models = [
         model.identifier
         for model in llama_stack_client.models.list()
-        if model.identifier.startswith("meta-llama")
+        if model.identifier.startswith("meta-llama") and "405" not in model.identifier
     ]
     assert len(available_models) > 0
     return available_models[0]
@@ -208,12 +208,9 @@ def test_text_chat_completion_streaming(
         stream=True,
     )
     streamed_content = [
-        str(log.content.lower().strip())
-        for log in EventLogger().log(response)
-        if log is not None
+        str(chunk.event.delta.text.lower().strip()) for chunk in response
     ]
     assert len(streamed_content) > 0
-    assert "assistant>" in streamed_content[0]
     assert expected.lower() in "".join(streamed_content)

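For orientation, a rough shape sketch of the delta types the updated assertions rely on. These are not the real llama-stack-client classes; they only mirror the attributes the tests read (`.type` and `.text` on text deltas, `.content` on tool-call deltas, which the parser below accepts either as a string or as a parsed tool call).

# Shape sketch only -- stand-ins for the newer event delta types.
from dataclasses import dataclass
from typing import Any, Union

@dataclass
class TextDelta:
    text: str
    type: str = "text"

@dataclass
class ToolCallDelta:
    content: Union[str, Any]  # string or parsed tool call, both handled below
    type: str = "tool_call"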
@@ -250,17 +247,16 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
 def extract_tool_invocation_content(response):
     text_content: str = ""
     tool_invocation_content: str = ""
-    for log in EventLogger().log(response):
-        if log is None:
-            continue
-        if isinstance(log.content, str):
-            text_content += log.content
-        elif isinstance(log.content, object):
-            if isinstance(log.content.content, str):
-                continue
-            elif isinstance(log.content.content, object):
-                tool_invocation_content += f"[{log.content.content.tool_name}, {log.content.content.arguments}]"
+    for chunk in response:
+        delta = chunk.event.delta
+        if delta.type == "text":
+            text_content += delta.text
+        elif delta.type == "tool_call":
+            if isinstance(delta.content, str):
+                tool_invocation_content += delta.content
+            else:
+                call = delta.content
+                tool_invocation_content += f"[{call.tool_name}, {call.arguments}]"

     return text_content, tool_invocation_content

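A quick, self-contained way to sanity-check the rewritten parser, using SimpleNamespace stand-ins instead of real client response chunks; purely illustrative, no server or client required.

# Illustrative check of extract_tool_invocation_content with fake chunks.
from types import SimpleNamespace as NS

chunks = [
    NS(event=NS(delta=NS(type="text", text="Checking the weather... "))),
    NS(event=NS(delta=NS(
        type="tool_call",
        content=NS(tool_name="get_weather", arguments={"location": "San Francisco, CA"}),
    ))),
]

text, tool_call = extract_tool_invocation_content(chunks)
assert text == "Checking the weather... "
assert tool_call == "[get_weather, {'location': 'San Francisco, CA'}]"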
@@ -280,7 +276,6 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
     )
     text_content, tool_invocation_content = extract_tool_invocation_content(response)

-    assert "Assistant>" in text_content
     assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"

@@ -368,10 +363,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
         stream=True,
     )
     streamed_content = [
-        str(log.content.lower().strip())
-        for log in EventLogger().log(response)
-        if log is not None
+        str(chunk.event.delta.text.lower().strip()) for chunk in response
     ]
     assert len(streamed_content) > 0
-    assert "assistant>" in streamed_content[0]
     assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})