mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
fix(test): update client-sdk tests to handle tool format parametrization better
This commit is contained in:
parent
30ef1c3680
commit
4c0f122a4b
2 changed files with 36 additions and 39 deletions
|
@ -324,6 +324,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
await end_trace()
|
||||
|
||||
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
||||
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=json_content.encode("utf-8"),
|
||||
|
@ -335,7 +336,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
url=options.url,
|
||||
params=options.params,
|
||||
headers=options.headers or {},
|
||||
json=options.json_data,
|
||||
json=convert_pydantic_to_json_value(body),
|
||||
),
|
||||
)
|
||||
response = APIResponse(
|
||||
|
@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
url=options.url,
|
||||
params=options.params,
|
||||
headers=options.headers or {},
|
||||
json=options.json_data,
|
||||
json=convert_pydantic_to_json_value(body),
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
@ -7,14 +7,9 @@
|
|||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.tests.test_cases.test_case import TestCase
|
||||
|
||||
PROVIDER_TOOL_PROMPT_FORMAT = {
|
||||
"remote::ollama": "json",
|
||||
"remote::together": "json",
|
||||
"remote::fireworks": "json",
|
||||
"remote::vllm": "json",
|
||||
}
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
|
||||
|
||||
PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
|
||||
|
||||
|
@ -24,18 +19,32 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
|
|||
provider_id = models[model_id].provider_id
|
||||
providers = {p.provider_id: p for p in client_with_models.providers.list()}
|
||||
provider = providers[provider_id]
|
||||
print(f"Provider: {provider.provider_type} for model {model_id}")
|
||||
if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
|
||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def provider_tool_format(inference_provider_type):
|
||||
return (
|
||||
PROVIDER_TOOL_PROMPT_FORMAT[inference_provider_type]
|
||||
if inference_provider_type in PROVIDER_TOOL_PROMPT_FORMAT
|
||||
else None
|
||||
)
|
||||
def get_llama_model(client_with_models, model_id):
|
||||
models = {}
|
||||
for m in client_with_models.models.list():
|
||||
models[m.identifier] = m
|
||||
models[m.provider_resource_id] = m
|
||||
|
||||
assert model_id in models, f"Model {model_id} not found"
|
||||
|
||||
model = models[model_id]
|
||||
ids = (model.identifier, model.provider_resource_id)
|
||||
for mid in ids:
|
||||
if resolve_model(mid):
|
||||
return mid
|
||||
|
||||
return model.metadata.get("llama_model", None)
|
||||
|
||||
|
||||
def get_tool_prompt_format(client_with_models, model_id):
|
||||
llama_model = get_llama_model(client_with_models, model_id)
|
||||
if not llama_model:
|
||||
return None
|
||||
return get_default_tool_prompt_format(llama_model)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -237,12 +246,8 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
|
|||
"inference:chat_completion:tool_calling",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_with_tool_calling_and_non_streaming(
|
||||
client_with_models, text_model_id, provider_tool_format, test_case
|
||||
):
|
||||
# TODO: more dynamic lookup on tool_prompt_format for model family
|
||||
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
|
||||
|
||||
def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
|
||||
tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
|
||||
response = client_with_models.inference.chat_completion(
|
||||
|
@ -280,12 +285,8 @@ def extract_tool_invocation_content(response):
|
|||
"inference:chat_completion:tool_calling",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_with_tool_calling_and_streaming(
|
||||
client_with_models, text_model_id, provider_tool_format, test_case
|
||||
):
|
||||
# TODO: more dynamic lookup on tool_prompt_format for model family
|
||||
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
|
||||
|
||||
def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
|
||||
tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
|
||||
response = client_with_models.inference.chat_completion(
|
||||
|
@ -308,14 +309,8 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
|
|||
"inference:chat_completion:tool_calling",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_with_tool_choice_required(
|
||||
client_with_models,
|
||||
text_model_id,
|
||||
provider_tool_format,
|
||||
test_case,
|
||||
):
|
||||
# TODO: more dynamic lookup on tool_prompt_format for model family
|
||||
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
|
||||
def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case):
|
||||
tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
|
||||
|
||||
tc = TestCase(test_case)
|
||||
|
||||
|
@ -341,14 +336,15 @@ def test_text_chat_completion_with_tool_choice_required(
|
|||
"inference:chat_completion:tool_calling",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case):
|
||||
def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
|
||||
tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=tc["messages"],
|
||||
tools=tc["tools"],
|
||||
tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
|
||||
tool_config={"tool_choice": "none", "tool_prompt_format": tool_prompt_format},
|
||||
stream=True,
|
||||
)
|
||||
tool_invocation_content = extract_tool_invocation_content(response)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue