fix(test): update client-sdk tests to handle tool format parametrization better (#1287)

# What does this PR do?

Tool format depends on the model. @ehhuang introduced a
`get_default_tool_prompt_format` function for this purpose. We should
use that instead of the hacky model ID matching we had before.

Second, non-Llama models don't have this concept, so testing with those
models should work as-is.
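
Concretely, the tests now resolve the format per model rather than per provider. A minimal sketch of that lookup (imports match the test diff below; `tool_prompt_format_for` is just an illustrative name, and the real test helper also checks provider resource IDs and model metadata):

```python
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format


def tool_prompt_format_for(model_id: str):
    # Non-Llama models have no Llama tool prompt format; the tests just pass None through.
    if resolve_model(model_id) is None:
        return None
    return get_default_tool_prompt_format(model_id)
```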


## Test Plan

```bash
for distro in fireworks ollama; do
  LLAMA_STACK_CONFIG=$distro \
    pytest -s -v tests/client-sdk/inference/test_text_inference.py \
      --inference-model=meta-llama/Llama-3.2-3B-Instruct \
      --vision-inference-model=""
done

LLAMA_STACK_CONFIG=dev \
  pytest -s -v tests/client-sdk/inference/test_text_inference.py \
    --inference-model=openai/gpt-4o \
    --vision-inference-model=""
```

Authored by Ashwin Bharambe on 2025-02-26 21:16:00 -08:00, committed by GitHub
parent 30ef1c3680 · commit 23b65b6cee
3 changed files with 64 additions and 63 deletions


```diff
@@ -324,6 +324,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 await end_trace()
         json_content = json.dumps(convert_pydantic_to_json_value(result))
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
             content=json_content.encode("utf-8"),
@@ -335,7 +336,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=options.json_data,
+                json=convert_pydantic_to_json_value(body),
             ),
         )
         response = APIResponse(
@@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=options.json_data,
+                json=convert_pydantic_to_json_value(body),
             ),
         )
```

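The library-client hunks above swap `options.json_data` for the converted request body when building the mock `httpx.Request`. The underlying constraint, illustrated with plain `pydantic` (the model and URL here are made up; `convert_pydantic_to_json_value` is the stack's own helper, approximated below with `to_jsonable_python`):

```python
import httpx
from pydantic import BaseModel
from pydantic_core import to_jsonable_python


class ChatCompletionBody(BaseModel):  # hypothetical request model for illustration
    model_id: str
    messages: list[dict]


body = ChatCompletionBody(
    model_id="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "hello"}],
)

# httpx.Request(json=...) needs JSON-serializable data, not Pydantic objects,
# so the body is converted to plain dicts/lists before the mock request is built.
request = httpx.Request(
    method="POST",
    url="http://localhost/v1/inference/chat-completion",  # placeholder URL
    json=to_jsonable_python(body),
)
print(request.content)  # serialized JSON body bytes
```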

```diff
@@ -518,15 +518,13 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIChatCompletionMessage:
     # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
     # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
     # List[...] -> List[...]
-    async def _convert_user_message_content(
+    async def _convert_message_content(
         content: InterleavedContent,
     ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
+        async def impl():
             # Llama Stack and OpenAI spec match for str and text input
             if isinstance(content, str):
-                return OpenAIChatCompletionContentPartTextParam(
-                    type="text",
-                    text=content,
-                )
+                return content
             elif isinstance(content, TextContentItem):
                 return OpenAIChatCompletionContentPartTextParam(
                     type="text",
@@ -537,21 +535,27 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIChatCompletionMessage:
                     type="image_url",
                     image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
                 )
-            elif isinstance(content, List):
-                return [await _convert_user_message_content(item) for item in content]
+            elif isinstance(content, list):
+                return [await _convert_message_content(item) for item in content]
             else:
                 raise ValueError(f"Unsupported content type: {type(content)}")
+        ret = await impl()
+        if isinstance(ret, str) or isinstance(ret, list):
+            return ret
+        else:
+            return [ret]
     out: OpenAIChatCompletionMessage = None
     if isinstance(message, UserMessage):
         out = OpenAIChatCompletionUserMessage(
             role="user",
-            content=await _convert_user_message_content(message.content),
+            content=await _convert_message_content(message.content),
         )
     elif isinstance(message, CompletionMessage):
         out = OpenAIChatCompletionAssistantMessage(
             role="assistant",
-            content=message.content,
+            content=await _convert_message_content(message.content),
             tool_calls=[
                 OpenAIChatCompletionMessageToolCall(
                     id=tool.call_id,
@@ -568,12 +572,12 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIChatCompletionMessage:
         out = OpenAIChatCompletionToolMessage(
             role="tool",
             tool_call_id=message.call_id,
-            content=message.content,
+            content=await _convert_message_content(message.content),
         )
     elif isinstance(message, SystemMessage):
         out = OpenAIChatCompletionSystemMessage(
             role="system",
-            content=message.content,
+            content=await _convert_message_content(message.content),
         )
     else:
         raise ValueError(f"Unsupported message type: {type(message)}")
```

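The second file makes assistant, tool, and system content go through the same `_convert_message_content` normalization that user content already used. A stripped-down, synchronous sketch of that normalization (the real converter is async and also handles image items and URL conversion; dict-shaped text items stand in for `TextContentItem`):

```python
from typing import Union


def convert_content(content) -> Union[str, list]:
    # Normalize Llama Stack interleaved content into OpenAI-style message content:
    # a plain string stays a string, anything else becomes a list of content parts.
    def convert_one(item):
        if isinstance(item, str):
            return item  # str matches the OpenAI spec as-is
        if isinstance(item, dict) and item.get("type") == "text":
            return {"type": "text", "text": item["text"]}
        raise ValueError(f"Unsupported content type: {type(item)}")

    if isinstance(content, list):
        return [convert_one(item) for item in content]
    converted = convert_one(content)
    # A single non-str part gets wrapped in a list, matching OpenAI's content-parts shape.
    return converted if isinstance(converted, str) else [converted]


assert convert_content("hi") == "hi"
assert convert_content([{"type": "text", "text": "hi"}]) == [{"type": "text", "text": "hi"}]
```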

```diff
@@ -7,14 +7,9 @@
 import pytest
 from pydantic import BaseModel
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.tests.test_cases.test_case import TestCase
-PROVIDER_TOOL_PROMPT_FORMAT = {
-    "remote::ollama": "json",
-    "remote::together": "json",
-    "remote::fireworks": "json",
-    "remote::vllm": "json",
-}
+from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
 
 PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
@@ -24,18 +19,32 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
     provider = providers[provider_id]
     print(f"Provider: {provider.provider_type} for model {model_id}")
     if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
 
-@pytest.fixture(scope="session")
-def provider_tool_format(inference_provider_type):
-    return (
-        PROVIDER_TOOL_PROMPT_FORMAT[inference_provider_type]
-        if inference_provider_type in PROVIDER_TOOL_PROMPT_FORMAT
-        else None
-    )
+def get_llama_model(client_with_models, model_id):
+    models = {}
+    for m in client_with_models.models.list():
+        models[m.identifier] = m
+        models[m.provider_resource_id] = m
+    assert model_id in models, f"Model {model_id} not found"
+    model = models[model_id]
+    ids = (model.identifier, model.provider_resource_id)
+    for mid in ids:
+        if resolve_model(mid):
+            return mid
+    return model.metadata.get("llama_model", None)
+
+
+def get_tool_prompt_format(client_with_models, model_id):
+    llama_model = get_llama_model(client_with_models, model_id)
+    if not llama_model:
+        return None
+    return get_default_tool_prompt_format(llama_model)
 
 
 @pytest.mark.parametrize(
@@ -237,12 +246,8 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_case):
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    client_with_models, text_model_id, provider_tool_format, test_case
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
@@ -280,12 +285,8 @@ def extract_tool_invocation_content(response):
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_streaming(
-    client_with_models, text_model_id, provider_tool_format, test_case
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
@@ -308,14 +309,8 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_required(
-    client_with_models,
-    text_model_id,
-    provider_tool_format,
-    test_case,
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
 
     tc = TestCase(test_case)
@@ -341,14 +336,15 @@ def test_text_chat_completion_with_tool_choice_required(
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case):
+def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=tc["messages"],
         tools=tc["tools"],
-        tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
+        tool_config={"tool_choice": "none", "tool_prompt_format": tool_prompt_format},
         stream=True,
     )
     tool_invocation_content = extract_tool_invocation_content(response)
```
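
Net effect on the test file: every tool-calling test now resolves the format per model and passes it straight into `tool_config`, with `None` for non-Llama models. A condensed sketch of the resulting test shape (the test name is made up and the payloads are placeholders; the real tests are parametrized and read messages/tools from `TestCase` fixtures):

```python
def test_tool_calling_uses_model_default_format(client_with_models, text_model_id, test_case):
    # None for non-Llama models such as openai/gpt-4o; a concrete format for Llama models
    # (the old heuristic hard-coded "json" for 3.1 and "python_list" otherwise).
    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
    tc = TestCase(test_case)

    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=tc["messages"],
        tools=tc["tools"],
        tool_config={"tool_prompt_format": tool_prompt_format},
        stream=False,
    )
    assert response is not None
```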