fix(test): update client-sdk tests to handle tool format parametrization better

This commit is contained in:
Ashwin Bharambe 2025-02-26 20:36:57 -08:00
parent 30ef1c3680
commit 4c0f122a4b
2 changed files with 36 additions and 39 deletions

View file

@ -324,6 +324,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
await end_trace() await end_trace()
json_content = json.dumps(convert_pydantic_to_json_value(result)) json_content = json.dumps(convert_pydantic_to_json_value(result))
mock_response = httpx.Response( mock_response = httpx.Response(
status_code=httpx.codes.OK, status_code=httpx.codes.OK,
content=json_content.encode("utf-8"), content=json_content.encode("utf-8"),
@ -335,7 +336,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
url=options.url, url=options.url,
params=options.params, params=options.params,
headers=options.headers or {}, headers=options.headers or {},
json=options.json_data, json=convert_pydantic_to_json_value(body),
), ),
) )
response = APIResponse( response = APIResponse(
@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
url=options.url, url=options.url,
params=options.params, params=options.params,
headers=options.headers or {}, headers=options.headers or {},
json=options.json_data, json=convert_pydantic_to_json_value(body),
), ),
) )

View file

@ -7,14 +7,9 @@
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.tests.test_cases.test_case import TestCase from llama_stack.providers.tests.test_cases.test_case import TestCase
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
PROVIDER_TOOL_PROMPT_FORMAT = {
"remote::ollama": "json",
"remote::together": "json",
"remote::fireworks": "json",
"remote::vllm": "json",
}
PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"} PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
@ -24,18 +19,32 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
provider_id = models[model_id].provider_id provider_id = models[model_id].provider_id
providers = {p.provider_id: p for p in client_with_models.providers.list()} providers = {p.provider_id: p for p in client_with_models.providers.list()}
provider = providers[provider_id] provider = providers[provider_id]
print(f"Provider: {provider.provider_type} for model {model_id}")
if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"): if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion") pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
@pytest.fixture(scope="session") def get_llama_model(client_with_models, model_id):
def provider_tool_format(inference_provider_type): models = {}
return ( for m in client_with_models.models.list():
PROVIDER_TOOL_PROMPT_FORMAT[inference_provider_type] models[m.identifier] = m
if inference_provider_type in PROVIDER_TOOL_PROMPT_FORMAT models[m.provider_resource_id] = m
else None
) assert model_id in models, f"Model {model_id} not found"
model = models[model_id]
ids = (model.identifier, model.provider_resource_id)
for mid in ids:
if resolve_model(mid):
return mid
return model.metadata.get("llama_model", None)
def get_tool_prompt_format(client_with_models, model_id):
llama_model = get_llama_model(client_with_models, model_id)
if not llama_model:
return None
return get_default_tool_prompt_format(llama_model)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -237,12 +246,8 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
"inference:chat_completion:tool_calling", "inference:chat_completion:tool_calling",
], ],
) )
def test_text_chat_completion_with_tool_calling_and_non_streaming( def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
client_with_models, text_model_id, provider_tool_format, test_case tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
tc = TestCase(test_case) tc = TestCase(test_case)
response = client_with_models.inference.chat_completion( response = client_with_models.inference.chat_completion(
@ -280,12 +285,8 @@ def extract_tool_invocation_content(response):
"inference:chat_completion:tool_calling", "inference:chat_completion:tool_calling",
], ],
) )
def test_text_chat_completion_with_tool_calling_and_streaming( def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
client_with_models, text_model_id, provider_tool_format, test_case tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
tc = TestCase(test_case) tc = TestCase(test_case)
response = client_with_models.inference.chat_completion( response = client_with_models.inference.chat_completion(
@ -308,14 +309,8 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
"inference:chat_completion:tool_calling", "inference:chat_completion:tool_calling",
], ],
) )
def test_text_chat_completion_with_tool_choice_required( def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case):
client_with_models, tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
text_model_id,
provider_tool_format,
test_case,
):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
tc = TestCase(test_case) tc = TestCase(test_case)
@ -341,14 +336,15 @@ def test_text_chat_completion_with_tool_choice_required(
"inference:chat_completion:tool_calling", "inference:chat_completion:tool_calling",
], ],
) )
def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case): def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
tc = TestCase(test_case) tc = TestCase(test_case)
response = client_with_models.inference.chat_completion( response = client_with_models.inference.chat_completion(
model_id=text_model_id, model_id=text_model_id,
messages=tc["messages"], messages=tc["messages"],
tools=tc["tools"], tools=tc["tools"],
tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format}, tool_config={"tool_choice": "none", "tool_prompt_format": tool_prompt_format},
stream=True, stream=True,
) )
tool_invocation_content = extract_tool_invocation_content(response) tool_invocation_content = extract_tool_invocation_content(response)