Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
fix(test): update client-sdk tests to handle tool format parametrization better (#1287)
# What does this PR do?

Tool format depends on the model. @ehhuang introduced a `get_default_tool_prompt_format` function for this purpose. We should use that instead of the hacky model-ID matching we had before. Secondly, non-Llama models don't have this concept, so testing with those models should work as-is.

## Test Plan

```bash
for distro in fireworks ollama; do
  LLAMA_STACK_CONFIG=$distro \
    pytest -s -v tests/client-sdk/inference/test_text_inference.py \
      --inference-model=meta-llama/Llama-3.2-3B-Instruct \
      --vision-inference-model=""
done

LLAMA_STACK_CONFIG=dev \
  pytest -s -v tests/client-sdk/inference/test_text_inference.py \
    --inference-model=openai/gpt-4o \
    --vision-inference-model=""
```
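For context, the change boils down to the following lookup; this is a minimal sketch using only the `resolve_model` and `get_default_tool_prompt_format` helpers that appear in the diff below, and the wrapper function name here is hypothetical:

```python
# Minimal sketch of the per-model lookup, not the exact test helper.
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format


def pick_tool_prompt_format(model_id: str):
    # Hypothetical wrapper: Llama models get their family's default tool prompt
    # format; non-Llama models (e.g. openai/gpt-4o) don't resolve to a Llama SKU,
    # so no format is forced and the provider's default behavior applies.
    if resolve_model(model_id) is None:
        return None
    return get_default_tool_prompt_format(model_id)
```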
Parent: 30ef1c3680
Commit: 23b65b6cee
3 changed files with 64 additions and 63 deletions
```diff
@@ -7,14 +7,9 @@
 import pytest
 from pydantic import BaseModel
 
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.tests.test_cases.test_case import TestCase
-
-PROVIDER_TOOL_PROMPT_FORMAT = {
-    "remote::ollama": "json",
-    "remote::together": "json",
-    "remote::fireworks": "json",
-    "remote::vllm": "json",
-}
+from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
 
 PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
 
@@ -24,18 +19,32 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
     provider = providers[provider_id]
     print(f"Provider: {provider.provider_type} for model {model_id}")
     if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
 
 
-@pytest.fixture(scope="session")
-def provider_tool_format(inference_provider_type):
-    return (
-        PROVIDER_TOOL_PROMPT_FORMAT[inference_provider_type]
-        if inference_provider_type in PROVIDER_TOOL_PROMPT_FORMAT
-        else None
-    )
+def get_llama_model(client_with_models, model_id):
+    models = {}
+    for m in client_with_models.models.list():
+        models[m.identifier] = m
+        models[m.provider_resource_id] = m
+
+    assert model_id in models, f"Model {model_id} not found"
+
+    model = models[model_id]
+    ids = (model.identifier, model.provider_resource_id)
+    for mid in ids:
+        if resolve_model(mid):
+            return mid
+
+    return model.metadata.get("llama_model", None)
+
+
+def get_tool_prompt_format(client_with_models, model_id):
+    llama_model = get_llama_model(client_with_models, model_id)
+    if not llama_model:
+        return None
+    return get_default_tool_prompt_format(llama_model)
 
 
 @pytest.mark.parametrize(
@@ -237,12 +246,8 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    client_with_models, text_model_id, provider_tool_format, test_case
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
-
+def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
@@ -280,12 +285,8 @@ def extract_tool_invocation_content(response):
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_streaming(
-    client_with_models, text_model_id, provider_tool_format, test_case
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
-
+def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
@@ -308,14 +309,8 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_required(
-    client_with_models,
-    text_model_id,
-    provider_tool_format,
-    test_case,
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
 
     tc = TestCase(test_case)
 
@@ -341,14 +336,15 @@ def test_text_chat_completion_with_tool_choice_required(
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case):
+def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=tc["messages"],
         tools=tc["tools"],
-        tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
+        tool_config={"tool_choice": "none", "tool_prompt_format": tool_prompt_format},
         stream=True,
     )
     tool_invocation_content = extract_tool_invocation_content(response)
```