# What does this PR do?

1) Enables structured output for the Ollama `/completion` API. It seems we missed this one.
2) Fixes the Ollama structured output test in the client SDK - Ollama does not support list-format content for structured output.
3) Enables the structured output unit test, since the result was stable on Llama-3.1-8B-Instruct with Ollama, Fireworks, and Together.

## Test Plan

1) Ran `test_completion_structured_output` against the `/completion` API with three providers (Ollama, Fireworks, Together):

`pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output`

```
(base) sxyi@sxyi-mbp llama-stack % pytest -s -v llama_stack/providers/tests/inference --config=ci_test_config.yaml
/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"
  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
=============================================== test session starts ===============================================
platform darwin -- Python 3.13.0, pytest-8.3.4, pluggy-1.5.0 -- /Library/Frameworks/Python.framework/Versions/3.13/bin/python3.13
cachedir: .pytest_cache
metadata: {'Python': '3.13.0', 'Platform': 'macOS-15.1.1-arm64-arm-64bit-Mach-O', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'asyncio': '0.24.0', 'html': '4.1.1', 'metadata': '3.1.1', 'md': '0.2.0', 'dependency': '0.6.0', 'md-report': '0.6.3', 'anyio': '4.6.2.post1'}}
rootdir: /Users/sxyi/llama-stack
configfile: pyproject.toml
plugins: asyncio-0.24.0, html-4.1.1, metadata-3.1.1, md-0.2.0, dependency-0.6.0, md-report-0.6.3, anyio-4.6.2.post1
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 85 items / 82 deselected / 3 selected

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct-ollama] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct-fireworks] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct-together] PASSED

============================== 3 passed, 82 deselected, 8 warnings in 5.67s ==============================
```

2) Ran `LLAMA_STACK_CONFIG="./llama_stack/templates/ollama/run.yaml" /opt/miniconda3/envs/stack/bin/pytest -s -v tests/client-sdk/inference`

Before:

```
_________________________________ test_completion_structured_output _________________________________
tests/client-sdk/inference/test_inference.py:174: in test_completion_structured_output
    answer = AnswerFormat.model_validate_json(response.content)
E   pydantic_core._pydantic_core.ValidationError: 1 validation error for AnswerFormat
E   Invalid JSON: expected value at line 1 column 2 [type=json_invalid, input_value=' The year he retired, he...5\n\nThe best answer is', input_type=str]
E     For further information visit https://errors.pydantic.dev/2.10/v/json_invalid
```

After: the test consistently passes.

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
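For context on item 1 above, here is a minimal sketch of what wiring structured output into the Ollama `/completion` path might look like. This is not the PR's actual diff: the function name `build_ollama_completion_params` and the attributes on `request` are illustrative assumptions, and it presumes Ollama's generate API accepts a JSON schema (or the literal `"json"`) in its `format` option.

```python
# Hypothetical sketch (not the code from this PR): when the Ollama provider
# builds a /completion request, forward a json_schema response_format through
# the `format` field that Ollama's generate API accepts.
def build_ollama_completion_params(request) -> dict:
    """Assumes `request` carries .model, .content, .stream and an optional
    .response_format with .type / .json_schema attributes."""
    params = {
        "model": request.model,
        "prompt": request.content,
        "raw": True,
        "stream": request.stream,
    }
    fmt = getattr(request, "response_format", None)
    if fmt is not None:
        if fmt.type == "json_schema":
            # Ollama takes the schema itself (or "json") in `format`.
            params["format"] = fmt.json_schema
        else:
            raise NotImplementedError(f"Unsupported response format: {fmt.type}")
    return params
```

The client-side usage of this path is exercised by `test_completion_structured_output` in the test file below.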
The updated client SDK test file, `tests/client-sdk/inference/test_inference.py` (399 lines, 13 KiB, Python):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import os

import pytest
from pydantic import BaseModel

PROVIDER_TOOL_PROMPT_FORMAT = {
    "remote::ollama": "json",
    "remote::together": "json",
    "remote::fireworks": "json",
}


@pytest.fixture(scope="session")
def provider_tool_format(inference_provider_type):
    return (
        PROVIDER_TOOL_PROMPT_FORMAT[inference_provider_type]
        if inference_provider_type in PROVIDER_TOOL_PROMPT_FORMAT
        else None
    )


@pytest.fixture(scope="session")
def inference_provider_type(llama_stack_client):
    providers = llama_stack_client.providers.list()
    inference_providers = [p for p in providers if p.api == "inference"]
    assert len(inference_providers) > 0, "No inference providers found"
    return inference_providers[0].provider_type


@pytest.fixture(scope="session")
def text_model_id(llama_stack_client):
    available_models = [
        model.identifier
        for model in llama_stack_client.models.list()
        if model.identifier.startswith("meta-llama") and "405" not in model.identifier
    ]
    assert len(available_models) > 0
    return available_models[0]


@pytest.fixture(scope="session")
def vision_model_id(llama_stack_client):
    available_models = [
        model.identifier
        for model in llama_stack_client.models.list()
        if "vision" in model.identifier.lower()
    ]
    if len(available_models) == 0:
        pytest.skip("No vision models available")

    return available_models[0]


@pytest.fixture
def get_weather_tool_definition():
    return {
        "tool_name": "get_weather",
        "description": "Get the current weather",
        "parameters": {
            "location": {
                "param_type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
        },
    }


@pytest.fixture
def base64_image_url():
    image_path = os.path.join(os.path.dirname(__file__), "dog.png")
    with open(image_path, "rb") as image_file:
        # Convert the image to base64
        base64_string = base64.b64encode(image_file.read()).decode("utf-8")
        base64_url = f"data:image;base64,{base64_string}"
        return base64_url


def test_completion_non_streaming(llama_stack_client, text_model_id):
    response = llama_stack_client.inference.completion(
        content="Complete the sentence using one word: Roses are red, violets are ",
        stream=False,
        model_id=text_model_id,
        sampling_params={
            "max_tokens": 50,
        },
    )
    assert "blue" in response.content.lower().strip()


def test_completion_streaming(llama_stack_client, text_model_id):
    response = llama_stack_client.inference.completion(
        content="Complete the sentence using one word: Roses are red, violets are ",
        stream=True,
        model_id=text_model_id,
        sampling_params={
            "max_tokens": 50,
        },
    )
    streamed_content = [chunk.delta for chunk in response]
    assert "blue" in "".join(streamed_content).lower().strip()


def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
    response = llama_stack_client.inference.completion(
        content="Complete the sentence: Michael Jordan is born in ",
        stream=False,
        model_id=text_model_id,
        sampling_params={
            "max_tokens": 5,
        },
        logprobs={
            "top_k": 3,
        },
    )
    assert response.logprobs, "Logprobs should not be empty"
    assert 1 <= len(response.logprobs) <= 5
    assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)


def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
    response = llama_stack_client.inference.completion(
        content="Complete the sentence: Michael Jordan is born in ",
        stream=True,
        model_id=text_model_id,
        sampling_params={
            "max_tokens": 5,
        },
        logprobs={
            "top_k": 3,
        },
    )
    streamed_content = [chunk for chunk in response]
    for chunk in streamed_content:
        if chunk.delta:  # if there's a token, we expect logprobs
            assert chunk.logprobs, "Logprobs should not be empty"
            assert all(
                len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
            )
        else:  # no token, no logprobs
            assert not chunk.logprobs, "Logprobs should be empty"


def test_completion_structured_output(
    llama_stack_client, text_model_id, inference_provider_type
):
    user_input = """
    Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003.
    """

    class AnswerFormat(BaseModel):
        name: str
        year_born: str
        year_retired: str

    response = llama_stack_client.inference.completion(
        model_id=text_model_id,
        content=user_input,
        stream=False,
        sampling_params={
            "max_tokens": 50,
        },
        response_format={
            "type": "json_schema",
            "json_schema": AnswerFormat.model_json_schema(),
        },
    )
    answer = AnswerFormat.model_validate_json(response.content)
    assert answer.name == "Michael Jordan"
    assert answer.year_born == "1963"
    assert answer.year_retired == "2003"


@pytest.mark.parametrize(
    "question,expected",
    [
        ("What are the names of planets in our solar system?", "Earth"),
        ("What are the names of the planets that have rings around them?", "Saturn"),
    ],
)
def test_text_chat_completion_non_streaming(
    llama_stack_client, text_model_id, question, expected
):
    response = llama_stack_client.inference.chat_completion(
        model_id=text_model_id,
        messages=[
            {
                "role": "user",
                "content": question,
            }
        ],
        stream=False,
    )
    message_content = response.completion_message.content.lower().strip()
    assert len(message_content) > 0
    assert expected.lower() in message_content


@pytest.mark.parametrize(
    "question,expected",
    [
        ("What's the name of the Sun in latin?", "Sol"),
        ("What is the name of the US capital?", "Washington"),
    ],
)
def test_text_chat_completion_streaming(
    llama_stack_client, text_model_id, question, expected
):
    response = llama_stack_client.inference.chat_completion(
        model_id=text_model_id,
        messages=[{"role": "user", "content": question}],
        stream=True,
    )
    streamed_content = [
        str(chunk.event.delta.text.lower().strip()) for chunk in response
    ]
    assert len(streamed_content) > 0
    assert expected.lower() in "".join(streamed_content)


def test_text_chat_completion_with_tool_calling_and_non_streaming(
    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
):
    response = llama_stack_client.inference.chat_completion(
        model_id=text_model_id,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What's the weather like in San Francisco?"},
        ],
        tools=[get_weather_tool_definition],
        tool_choice="auto",
        tool_prompt_format=provider_tool_format,
        stream=False,
    )
    # No text content is returned since we expect the response to be a tool call
    assert response.completion_message.content == ""
    assert response.completion_message.role == "assistant"

    assert len(response.completion_message.tool_calls) == 1
    assert response.completion_message.tool_calls[0].tool_name == "get_weather"
    assert response.completion_message.tool_calls[0].arguments == {
        "location": "San Francisco, CA"
    }


# Extracts streamed text and separates it from the tool invocation content.
# The returned tool invocation content is a string, so it is easy to compare with the expected value,
# e.g. "[get_weather, {'location': 'San Francisco, CA'}]"
def extract_tool_invocation_content(response):
    tool_invocation_content: str = ""
    for chunk in response:
        delta = chunk.event.delta
        if delta.type == "tool_call" and delta.parse_status == "succeeded":
            call = delta.content
            tool_invocation_content += f"[{call.tool_name}, {call.arguments}]"
    return tool_invocation_content


def test_text_chat_completion_with_tool_calling_and_streaming(
    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
):
    response = llama_stack_client.inference.chat_completion(
        model_id=text_model_id,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What's the weather like in San Francisco?"},
        ],
        tools=[get_weather_tool_definition],
        tool_choice="auto",
        tool_prompt_format=provider_tool_format,
        stream=True,
    )
    tool_invocation_content = extract_tool_invocation_content(response)
    assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"


def test_text_chat_completion_structured_output(
    llama_stack_client, text_model_id, inference_provider_type
):
    class AnswerFormat(BaseModel):
        first_name: str
        last_name: str
        year_of_birth: int
        num_seasons_in_nba: int

    response = llama_stack_client.inference.chat_completion(
        model_id=text_model_id,
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant. Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons.",
            },
            {
                "role": "user",
                "content": "Please give me information about Michael Jordan.",
            },
        ],
        response_format={
            "type": "json_schema",
            "json_schema": AnswerFormat.model_json_schema(),
        },
        stream=False,
    )
    answer = AnswerFormat.model_validate_json(response.completion_message.content)
    assert answer.first_name == "Michael"
    assert answer.last_name == "Jordan"
    assert answer.year_of_birth == 1963
    assert answer.num_seasons_in_nba == 15


def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id):
    message = {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": {
                    # TODO: Replace with Github based URI to resources/sample1.jpg
                    "uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
                },
            },
            {
                "type": "text",
                "text": "Describe what is in this image.",
            },
        ],
    }
    response = llama_stack_client.inference.chat_completion(
        model_id=vision_model_id,
        messages=[message],
        stream=False,
    )
    message_content = response.completion_message.content.lower().strip()
    assert len(message_content) > 0
    assert any(expected in message_content for expected in {"dog", "puppy", "pup"})


def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
    message = {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": {
                    # TODO: Replace with Github based URI to resources/sample1.jpg
                    "uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
                },
            },
            {
                "type": "text",
                "text": "Describe what is in this image.",
            },
        ],
    }
    response = llama_stack_client.inference.chat_completion(
        model_id=vision_model_id,
        messages=[message],
        stream=True,
    )
    streamed_content = ""
    for chunk in response:
        streamed_content += chunk.event.delta.text.lower()
    assert len(streamed_content) > 0
    assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})


def test_image_chat_completion_base64_url(
    llama_stack_client, vision_model_id, base64_image_url
):
    message = {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": {
                    "uri": base64_image_url,
                },
            },
            {
                "type": "text",
                "text": "Describe what is in this image.",
            },
        ],
    }
    response = llama_stack_client.inference.chat_completion(
        model_id=vision_model_id,
        messages=[message],
        stream=False,
    )
    message_content = response.completion_message.content.lower().strip()
    assert len(message_content) > 0
```