fix(test): update client-sdk tests to handle tool format parametrization better

Ashwin Bharambe 2025-02-26 20:36:57 -08:00
parent 30ef1c3680
commit 4c0f122a4b
2 changed files with 36 additions and 39 deletions


@@ -324,6 +324,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
await end_trace()
json_content = json.dumps(convert_pydantic_to_json_value(result))
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=json_content.encode("utf-8"),
@@ -335,7 +336,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
url=options.url,
params=options.params,
headers=options.headers or {},
-json=options.json_data,
+json=convert_pydantic_to_json_value(body),
),
)
response = APIResponse(
@@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
url=options.url,
params=options.params,
headers=options.headers or {},
-json=options.json_data,
+json=convert_pydantic_to_json_value(body),
),
)


@@ -7,14 +7,9 @@
import pytest
from pydantic import BaseModel
+from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.tests.test_cases.test_case import TestCase
-PROVIDER_TOOL_PROMPT_FORMAT = {
-    "remote::ollama": "json",
-    "remote::together": "json",
-    "remote::fireworks": "json",
-    "remote::vllm": "json",
-}
+from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
@@ -24,18 +19,32 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
provider_id = models[model_id].provider_id
providers = {p.provider_id: p for p in client_with_models.providers.list()}
provider = providers[provider_id]
print(f"Provider: {provider.provider_type} for model {model_id}")
if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
-@pytest.fixture(scope="session")
-def provider_tool_format(inference_provider_type):
-    return (
-        PROVIDER_TOOL_PROMPT_FORMAT[inference_provider_type]
-        if inference_provider_type in PROVIDER_TOOL_PROMPT_FORMAT
-        else None
-    )
+def get_llama_model(client_with_models, model_id):
+    models = {}
+    for m in client_with_models.models.list():
+        models[m.identifier] = m
+        models[m.provider_resource_id] = m
+    assert model_id in models, f"Model {model_id} not found"
+    model = models[model_id]
+    ids = (model.identifier, model.provider_resource_id)
+    for mid in ids:
+        if resolve_model(mid):
+            return mid
+    return model.metadata.get("llama_model", None)
+def get_tool_prompt_format(client_with_models, model_id):
+    llama_model = get_llama_model(client_with_models, model_id)
+    if not llama_model:
+        return None
+    return get_default_tool_prompt_format(llama_model)
@pytest.mark.parametrize(
@@ -237,12 +246,8 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
"inference:chat_completion:tool_calling",
],
)
-def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    client_with_models, text_model_id, provider_tool_format, test_case
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
@@ -280,12 +285,8 @@ def extract_tool_invocation_content(response):
"inference:chat_completion:tool_calling",
],
)
-def test_text_chat_completion_with_tool_calling_and_streaming(
-    client_with_models, text_model_id, provider_tool_format, test_case
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
@@ -308,14 +309,8 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
"inference:chat_completion:tool_calling",
],
)
-def test_text_chat_completion_with_tool_choice_required(
-    client_with_models,
-    text_model_id,
-    provider_tool_format,
-    test_case,
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
tc = TestCase(test_case)
@@ -341,14 +336,15 @@ def test_text_chat_completion_with_tool_choice_required(
"inference:chat_completion:tool_calling",
],
)
-def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case):
+def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=tc["messages"],
tools=tc["tools"],
tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
tool_config={"tool_choice": "none", "tool_prompt_format": tool_prompt_format},
stream=True,
)
tool_invocation_content = extract_tool_invocation_content(response)
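
As a usage sketch (not part of the commit), the reworked tests resolve the tool prompt format from the model itself rather than from the removed per-provider table or the old inline heuristic ('json' if "3.1" in text_model_id else 'python_list'). Everything below is hypothetical: the tool definition and messages are made up, and the expectation that a None format defers to the provider default is an inference from get_tool_prompt_format returning None when no Llama model can be resolved.

def test_tool_calling_sketch(client_with_models, text_model_id):
    # Derive the tool prompt format from the model family instead of hardcoding it.
    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)

    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=[{"role": "user", "content": "What is the weather like in Tokyo?"}],
        tools=[
            {
                "tool_name": "get_weather",  # made-up tool; the real tests load tools from TestCase fixtures
                "description": "Get the current weather for a city",
                "parameters": {
                    "city": {"param_type": "string", "description": "City name", "required": True},
                },
            }
        ],
        # A None tool_prompt_format presumably lets the provider pick its default.
        tool_config={"tool_choice": "auto", "tool_prompt_format": tool_prompt_format},
        stream=False,
    )
    assert response.completion_message.tool_calls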