# What does this PR do?

Tool format depends on the model. @ehhuang introduced a `get_default_tool_prompt_format` function for this purpose. We should use that instead of the hacky model ID matching we had before.

Secondly, non-Llama models don't have this concept, so testing with those models should work as-is.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

```bash
for distro in fireworks ollama; do
  LLAMA_STACK_CONFIG=$distro \
    pytest -s -v tests/client-sdk/inference/test_text_inference.py \
      --inference-model=meta-llama/Llama-3.2-3B-Instruct \
      --vision-inference-model=""
done

LLAMA_STACK_CONFIG=dev \
  pytest -s -v tests/client-sdk/inference/test_text_inference.py \
    --inference-model=openai/gpt-4o \
    --vision-inference-model=""
```

[//]: # (## Documentation)
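
Concretely, the lookup in the updated tests amounts to: resolve the registered model to a known Llama model id and, only if that succeeds, ask `get_default_tool_prompt_format` for its format. Below is a condensed sketch of that flow; the helper name and the bare `model` argument are illustrative, and the real helpers are `get_llama_model` and `get_tool_prompt_format` in the test file below (which additionally falls back to `model.metadata["llama_model"]`):

```python
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format


def tool_prompt_format_for(model):
    # Try both the Llama Stack identifier and the provider's resource id.
    for mid in (model.identifier, model.provider_resource_id):
        if resolve_model(mid):  # known Llama model id
            return get_default_tool_prompt_format(mid)
    # Non-Llama model: no tool prompt format concept, let the provider use its default.
    return None
```
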
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from pydantic import BaseModel

from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.tests.test_cases.test_case import TestCase
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format

PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}


def skip_if_model_doesnt_support_completion(client_with_models, model_id):
    models = {m.identifier: m for m in client_with_models.models.list()}
    provider_id = models[model_id].provider_id
    providers = {p.provider_id: p for p in client_with_models.providers.list()}
    provider = providers[provider_id]
    if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")


def get_llama_model(client_with_models, model_id):
    # Map the registered model to its underlying Llama model id (via SKU resolution or model
    # metadata); returns None for non-Llama models.
    models = {}
    for m in client_with_models.models.list():
        models[m.identifier] = m
        models[m.provider_resource_id] = m

    assert model_id in models, f"Model {model_id} not found"

    model = models[model_id]
    ids = (model.identifier, model.provider_resource_id)
    for mid in ids:
        if resolve_model(mid):
            return mid

    return model.metadata.get("llama_model", None)


def get_tool_prompt_format(client_with_models, model_id):
    # Non-Llama models have no tool prompt format concept, so return None and let the provider
    # use its default.
    llama_model = get_llama_model(client_with_models, model_id)
    if not llama_model:
        return None
    return get_default_tool_prompt_format(llama_model)


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:sanity",
    ],
)
def test_text_completion_non_streaming(client_with_models, text_model_id, test_case):
    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
    tc = TestCase(test_case)

    response = client_with_models.inference.completion(
        content=tc["content"],
        stream=False,
        model_id=text_model_id,
        sampling_params={
            "max_tokens": 50,
        },
    )
    assert len(response.content) > 10
    # assert "blue" in response.content.lower().strip()


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:sanity",
    ],
)
def test_text_completion_streaming(client_with_models, text_model_id, test_case):
    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
    tc = TestCase(test_case)

    response = client_with_models.inference.completion(
        content=tc["content"],
        stream=True,
        model_id=text_model_id,
        sampling_params={
            "max_tokens": 50,
        },
    )
    streamed_content = [chunk.delta for chunk in response]
    content_str = "".join(streamed_content).lower().strip()
    # assert "blue" in content_str
    assert len(content_str) > 10


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:log_probs",
    ],
)
def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")

    tc = TestCase(test_case)

    response = client_with_models.inference.completion(
        content=tc["content"],
        stream=False,
        model_id=text_model_id,
        sampling_params={
            "max_tokens": 5,
        },
        logprobs={
            "top_k": 1,
        },
    )
    assert response.logprobs, "Logprobs should not be empty"
    assert 1 <= len(response.logprobs) <= 5  # each token has 1 logprob and here max_tokens=5
    assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:log_probs",
    ],
)
def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")

    tc = TestCase(test_case)

    response = client_with_models.inference.completion(
        content=tc["content"],
        stream=True,
        model_id=text_model_id,
        sampling_params={
            "max_tokens": 5,
        },
        logprobs={
            "top_k": 1,
        },
    )
    streamed_content = [chunk for chunk in response]
    for chunk in streamed_content:
        if chunk.delta:  # if there's a token, we expect logprobs
            assert chunk.logprobs, "Logprobs should not be empty"
            assert all(len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs)
        else:  # no token, no logprobs
            assert not chunk.logprobs, "Logprobs should be empty"


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:structured_output",
    ],
)
def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)

    class AnswerFormat(BaseModel):
        name: str
        year_born: str
        year_retired: str

    tc = TestCase(test_case)

    user_input = tc["user_input"]
    response = client_with_models.inference.completion(
        model_id=text_model_id,
        content=user_input,
        stream=False,
        sampling_params={
            "max_tokens": 50,
        },
        response_format={
            "type": "json_schema",
            "json_schema": AnswerFormat.model_json_schema(),
        },
    )
    answer = AnswerFormat.model_validate_json(response.content)
    expected = tc["expected"]
    assert answer.name == expected["name"]
    assert answer.year_born == expected["year_born"]
    assert answer.year_retired == expected["year_retired"]


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:non_streaming_01",
        "inference:chat_completion:non_streaming_02",
    ],
)
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
    tc = TestCase(test_case)
    question = tc["question"]
    expected = tc["expected"]

    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=[
            {
                "role": "user",
                "content": question,
            }
        ],
        stream=False,
    )
    message_content = response.completion_message.content.lower().strip()
    assert len(message_content) > 0
    assert expected.lower() in message_content


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:streaming_01",
        "inference:chat_completion:streaming_02",
    ],
)
def test_text_chat_completion_streaming(client_with_models, text_model_id, test_case):
    tc = TestCase(test_case)
    question = tc["question"]
    expected = tc["expected"]

    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=[{"role": "user", "content": question}],
        stream=True,
    )
    streamed_content = [str(chunk.event.delta.text.lower().strip()) for chunk in response]
    assert len(streamed_content) > 0
    assert expected.lower() in "".join(streamed_content)


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:tool_calling",
    ],
)
def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
    tc = TestCase(test_case)

    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=tc["messages"],
        tools=tc["tools"],
        tool_choice="auto",
        tool_prompt_format=tool_prompt_format,
        stream=False,
    )
    # some models can return content for the response in addition to the tool call
    assert response.completion_message.role == "assistant"

    assert len(response.completion_message.tool_calls) == 1
    assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
    assert response.completion_message.tool_calls[0].arguments == tc["expected"]


# Will extract streamed text and separate it from tool invocation content
# The returned tool invocation content will be a string so it's easy to compare with the expected value
# e.g. "[get_weather, {'location': 'San Francisco, CA'}]"
def extract_tool_invocation_content(response):
    tool_invocation_content: str = ""
    for chunk in response:
        delta = chunk.event.delta
        if delta.type == "tool_call" and delta.parse_status == "succeeded":
            call = delta.tool_call
            tool_invocation_content += f"[{call.tool_name}, {call.arguments}]"
    return tool_invocation_content


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:tool_calling",
    ],
)
def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
    tc = TestCase(test_case)

    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=tc["messages"],
        tools=tc["tools"],
        tool_choice="auto",
        tool_prompt_format=tool_prompt_format,
        stream=True,
    )
    tool_invocation_content = extract_tool_invocation_content(response)
    expected_tool_name = tc["tools"][0]["tool_name"]
    expected_argument = tc["expected"]
    assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:tool_calling",
    ],
)
def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case):
    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)

    tc = TestCase(test_case)

    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=tc["messages"],
        tools=tc["tools"],
        tool_config={
            "tool_choice": "required",
            "tool_prompt_format": tool_prompt_format,
        },
        stream=True,
    )
    tool_invocation_content = extract_tool_invocation_content(response)
    expected_tool_name = tc["tools"][0]["tool_name"]
    expected_argument = tc["expected"]
    assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:tool_calling",
    ],
)
def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
    tc = TestCase(test_case)

    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=tc["messages"],
        tools=tc["tools"],
        tool_config={"tool_choice": "none", "tool_prompt_format": tool_prompt_format},
        stream=True,
    )
    tool_invocation_content = extract_tool_invocation_content(response)
    assert tool_invocation_content == ""


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:structured_output",
    ],
)
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
    class AnswerFormat(BaseModel):
        first_name: str
        last_name: str
        year_of_birth: int
        num_seasons_in_nba: int

    tc = TestCase(test_case)

    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=tc["messages"],
        response_format={
            "type": "json_schema",
            "json_schema": AnswerFormat.model_json_schema(),
        },
        stream=False,
    )
    answer = AnswerFormat.model_validate_json(response.completion_message.content)
    expected = tc["expected"]
    assert answer.first_name == expected["first_name"]
    assert answer.last_name == expected["last_name"]
    assert answer.year_of_birth == expected["year_of_birth"]
    assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"]


@pytest.mark.parametrize("streaming", [True, False])
|
|
@pytest.mark.parametrize(
|
|
"test_case",
|
|
[
|
|
"inference:chat_completion:tool_calling_tools_absent",
|
|
],
|
|
)
|
|
def test_text_chat_completion_tool_calling_tools_not_in_request(
|
|
client_with_models, text_model_id, test_case, streaming
|
|
):
|
|
tc = TestCase(test_case)
|
|
|
|
# TODO: more dynamic lookup on tool_prompt_format for model family
|
|
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
|
|
request = {
|
|
"model_id": text_model_id,
|
|
"messages": tc["messages"],
|
|
"tools": tc["tools"],
|
|
"tool_choice": "auto",
|
|
"tool_prompt_format": tool_prompt_format,
|
|
"stream": streaming,
|
|
}
|
|
|
|
response = client_with_models.inference.chat_completion(**request)
|
|
|
|
if streaming:
|
|
for chunk in response:
|
|
delta = chunk.event.delta
|
|
if delta.type == "tool_call" and delta.parse_status == "succeeded":
|
|
assert delta.tool_call.tool_name == "get_object_namespace_list"
|
|
if delta.type == "tool_call" and delta.parse_status == "failed":
|
|
# expect raw message that failed to parse in tool_call
|
|
assert type(delta.tool_call) == str
|
|
assert len(delta.tool_call) > 0
|
|
else:
|
|
for tc in response.completion_message.tool_calls:
|
|
assert tc.tool_name == "get_object_namespace_list"
|