Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-13 05:17:26 +00:00)
fix(test): update client-sdk tests to handle tool format parametrization better
Commit: 4c0f122a4b
Parent: 30ef1c3680
2 changed files with 36 additions and 39 deletions
@@ -324,6 +324,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 await end_trace()
 
         json_content = json.dumps(convert_pydantic_to_json_value(result))
+
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
             content=json_content.encode("utf-8"),
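Why the JSON round trip: the library client never talks to a real server, so the pydantic object returned by the in-process implementation has to be flattened into plain JSON before it can serve as the body of the mocked httpx.Response. Below is a minimal sketch of that idea, using pydantic's model_dump(mode="json") as a stand-in for convert_pydantic_to_json_value; the Result model and make_mock_response helper are hypothetical, not part of this commit.

import json

import httpx
from pydantic import BaseModel


class Result(BaseModel):  # hypothetical stand-in for an API return type
    completion: str
    stop_reason: str


def make_mock_response(result: Result) -> httpx.Response:
    # model_dump(mode="json") plays the role of convert_pydantic_to_json_value here:
    # it reduces the pydantic object to JSON-safe dicts, lists, and scalars.
    json_content = json.dumps(result.model_dump(mode="json"))
    return httpx.Response(
        status_code=httpx.codes.OK,
        content=json_content.encode("utf-8"),
    )


print(make_mock_response(Result(completion="hi", stop_reason="end_of_turn")).json())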
@@ -335,7 +336,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=options.json_data,
+                json=convert_pydantic_to_json_value(body),
             ),
         )
         response = APIResponse(
@@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=options.json_data,
+                json=convert_pydantic_to_json_value(body),
             ),
         )
 
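The request side gets the same treatment. httpx encodes a json= payload with the standard json module, so handing it a raw pydantic body (which options.json_data presumably carried when the caller passed a model) would fail to serialize; converting to plain JSON values first keeps the mocked httpx.Request well-formed. A small sketch under those assumptions follows; the ChatRequest type and the URL are illustrative only.

import httpx
from pydantic import BaseModel


class ChatRequest(BaseModel):  # hypothetical request body type
    model_id: str
    stream: bool = False


body = ChatRequest(model_id="meta-llama/Llama-3.1-8B-Instruct")

# httpx serializes `json=` with the stdlib encoder, so a raw pydantic model fails:
try:
    httpx.Request("POST", "http://localhost/v1/inference/chat-completion", json=body)
except TypeError as exc:
    print(f"raw pydantic body rejected: {exc}")

# Flattening to plain JSON values first (what convert_pydantic_to_json_value does
# in spirit) produces a valid mocked request:
request = httpx.Request(
    "POST",
    "http://localhost/v1/inference/chat-completion",
    json=body.model_dump(mode="json"),
)
print(request.read())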
@@ -7,14 +7,9 @@
 import pytest
 from pydantic import BaseModel
 
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.tests.test_cases.test_case import TestCase
-
-PROVIDER_TOOL_PROMPT_FORMAT = {
-    "remote::ollama": "json",
-    "remote::together": "json",
-    "remote::fireworks": "json",
-    "remote::vllm": "json",
-}
+from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
 
 PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
 
@@ -24,18 +19,32 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
     provider = providers[provider_id]
-    print(f"Provider: {provider.provider_type} for model {model_id}")
     if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
 
 
-@pytest.fixture(scope="session")
-def provider_tool_format(inference_provider_type):
-    return (
-        PROVIDER_TOOL_PROMPT_FORMAT[inference_provider_type]
-        if inference_provider_type in PROVIDER_TOOL_PROMPT_FORMAT
-        else None
-    )
+def get_llama_model(client_with_models, model_id):
+    models = {}
+    for m in client_with_models.models.list():
+        models[m.identifier] = m
+        models[m.provider_resource_id] = m
+
+    assert model_id in models, f"Model {model_id} not found"
+
+    model = models[model_id]
+    ids = (model.identifier, model.provider_resource_id)
+    for mid in ids:
+        if resolve_model(mid):
+            return mid
+
+    return model.metadata.get("llama_model", None)
+
+
+def get_tool_prompt_format(client_with_models, model_id):
+    llama_model = get_llama_model(client_with_models, model_id)
+    if not llama_model:
+        return None
+    return get_default_tool_prompt_format(llama_model)
 
 
 @pytest.mark.parametrize(
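The two helpers replace both the provider-keyed PROVIDER_TOOL_PROMPT_FORMAT map and the old '"json" if "3.1" in text_model_id else "python_list"' heuristic: the served model id is mapped back to a canonical Llama model (via resolve_model or the model's llama_model metadata), and get_default_tool_prompt_format picks the format for that family. Here is a simplified, self-contained sketch of the same lookup; the family rule shown is only the heuristic the old tests hard-coded, not necessarily the full table in prompt_adapter.

from typing import Optional


def default_tool_prompt_format(llama_model: str) -> str:
    # Stand-in for get_default_tool_prompt_format, using the old tests' rule of thumb:
    # Llama 3.1 models default to "json", later families to "python_list".
    return "json" if "3.1" in llama_model else "python_list"


def tool_prompt_format_for(models_by_id: dict, model_id: str) -> Optional[str]:
    # Mirrors get_tool_prompt_format: find the model record, recover the canonical
    # Llama model name, then look up that family's default format.
    model = models_by_id.get(model_id)
    if model is None:
        return None
    llama_model = model.get("llama_model")  # the real helper also tries resolve_model()
    if not llama_model:
        return None
    return default_tool_prompt_format(llama_model)


models = {"my-ollama-alias": {"llama_model": "Llama-3.2-3B-Instruct"}}
print(tool_prompt_format_for(models, "my-ollama-alias"))  # -> python_list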
@@ -237,12 +246,8 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    client_with_models, text_model_id, provider_tool_format, test_case
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
-
+def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
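Assembled, the non-streaming tool-calling test now reads roughly as below: the format comes from the per-model helper rather than the removed provider_tool_format fixture, and flows into the request through tool_config. The tool_config keys mirror the tool_choice test later in this diff; the final assertion is only illustrative, since the visible hunk cuts off before the real checks.

def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
    # Per-model lookup instead of the provider-keyed fixture.
    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
    tc = TestCase(test_case)

    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=tc["messages"],
        tools=tc["tools"],
        tool_config={"tool_prompt_format": tool_prompt_format},
        stream=False,
    )
    # Illustrative only: the real test inspects the returned tool call in detail.
    assert response.completion_message.tool_calls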
@@ -280,12 +285,8 @@ def extract_tool_invocation_content(response):
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_streaming(
-    client_with_models, text_model_id, provider_tool_format, test_case
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
-
+def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
@@ -308,14 +309,8 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_required(
-    client_with_models,
-    text_model_id,
-    provider_tool_format,
-    test_case,
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
 
     tc = TestCase(test_case)
 
|
@ -341,14 +336,15 @@ def test_text_chat_completion_with_tool_choice_required(
|
||||||
"inference:chat_completion:tool_calling",
|
"inference:chat_completion:tool_calling",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case):
|
def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
|
||||||
|
tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
|
||||||
tc = TestCase(test_case)
|
tc = TestCase(test_case)
|
||||||
|
|
||||||
response = client_with_models.inference.chat_completion(
|
response = client_with_models.inference.chat_completion(
|
||||||
model_id=text_model_id,
|
model_id=text_model_id,
|
||||||
messages=tc["messages"],
|
messages=tc["messages"],
|
||||||
tools=tc["tools"],
|
tools=tc["tools"],
|
||||||
tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
|
tool_config={"tool_choice": "none", "tool_prompt_format": tool_prompt_format},
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
tool_invocation_content = extract_tool_invocation_content(response)
|
tool_invocation_content = extract_tool_invocation_content(response)
|
||||||
|
|