mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
484 lines
17 KiB
Python
484 lines
17 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
|
|
import pytest
|
|
|
|
from llama_models.llama3.api.datatypes import (
|
|
SamplingParams,
|
|
StopReason,
|
|
ToolCall,
|
|
ToolDefinition,
|
|
ToolParamDefinition,
|
|
ToolPromptFormat,
|
|
)
|
|
|
|
from pydantic import BaseModel, ValidationError
|
|
|
|
from llama_stack.apis.common.content_types import ToolCallParseStatus
|
|
from llama_stack.apis.inference import (
|
|
ChatCompletionResponse,
|
|
ChatCompletionResponseEventType,
|
|
ChatCompletionResponseStreamChunk,
|
|
CompletionResponse,
|
|
CompletionResponseStreamChunk,
|
|
JsonSchemaResponseFormat,
|
|
LogProbConfig,
|
|
SystemMessage,
|
|
ToolChoice,
|
|
UserMessage,
|
|
)
|
|
from llama_stack.apis.models import Model
|
|
from .utils import group_chunks
|
|
|
|
|
|
# How to run this test:
|
|
#
|
|
# pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py
|
|
# -m "(fireworks or ollama) and llama_3b"
|
|
# --env FIREWORKS_API_KEY=<your_api_key>
|
|
|
|
|
|
def get_expected_stop_reason(model: str):
|
|
return (
|
|
StopReason.end_of_message
|
|
if ("Llama3.1" in model or "Llama-3.1" in model)
|
|
else StopReason.end_of_turn
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def common_params(inference_model):
|
|
return {
|
|
"tool_choice": ToolChoice.auto,
|
|
"tool_prompt_format": (
|
|
ToolPromptFormat.json
|
|
if ("Llama3.1" in inference_model or "Llama-3.1" in inference_model)
|
|
else ToolPromptFormat.python_list
|
|
),
|
|
}
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_messages():
|
|
return [
|
|
SystemMessage(content="You are a helpful assistant."),
|
|
UserMessage(content="What's the weather like today?"),
|
|
]
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_tool_definition():
|
|
return ToolDefinition(
|
|
tool_name="get_weather",
|
|
description="Get the current weather",
|
|
parameters={
|
|
"location": ToolParamDefinition(
|
|
param_type="string",
|
|
description="The city and state, e.g. San Francisco, CA",
|
|
),
|
|
},
|
|
)
|
|
|
|
|
|
class TestInference:
|
|
# Session scope for asyncio because the tests in this class all
|
|
# share the same provider instance.
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_model_list(self, inference_model, inference_stack):
|
|
_, models_impl = inference_stack
|
|
response = await models_impl.list_models()
|
|
assert isinstance(response, list)
|
|
assert len(response) >= 1
|
|
assert all(isinstance(model, Model) for model in response)
|
|
|
|
model_def = None
|
|
for model in response:
|
|
if model.identifier == inference_model:
|
|
model_def = model
|
|
break
|
|
|
|
assert model_def is not None
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_completion(self, inference_model, inference_stack):
|
|
inference_impl, _ = inference_stack
|
|
|
|
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
|
if provider.__provider_spec__.provider_type not in (
|
|
"inline::meta-reference",
|
|
"remote::ollama",
|
|
"remote::tgi",
|
|
"remote::together",
|
|
"remote::fireworks",
|
|
"remote::nvidia",
|
|
"remote::cerebras",
|
|
):
|
|
pytest.skip("Other inference providers don't support completion() yet")
|
|
|
|
response = await inference_impl.completion(
|
|
content="Micheael Jordan is born in ",
|
|
stream=False,
|
|
model_id=inference_model,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=50,
|
|
),
|
|
)
|
|
|
|
assert isinstance(response, CompletionResponse)
|
|
assert "1963" in response.content
|
|
|
|
chunks = [
|
|
r
|
|
async for r in await inference_impl.completion(
|
|
content="Roses are red,",
|
|
stream=True,
|
|
model_id=inference_model,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=50,
|
|
),
|
|
)
|
|
]
|
|
|
|
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
|
|
assert len(chunks) >= 1
|
|
last = chunks[-1]
|
|
assert last.stop_reason == StopReason.out_of_tokens
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_completion_logprobs(self, inference_model, inference_stack):
|
|
inference_impl, _ = inference_stack
|
|
|
|
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
|
if provider.__provider_spec__.provider_type not in (
|
|
# "remote::nvidia", -- provider doesn't provide all logprobs
|
|
):
|
|
pytest.skip("Other inference providers don't support completion() yet")
|
|
|
|
response = await inference_impl.completion(
|
|
content="Micheael Jordan is born in ",
|
|
stream=False,
|
|
model_id=inference_model,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=5,
|
|
),
|
|
logprobs=LogProbConfig(
|
|
top_k=3,
|
|
),
|
|
)
|
|
|
|
assert isinstance(response, CompletionResponse)
|
|
assert 1 <= len(response.logprobs) <= 5
|
|
assert response.logprobs, "Logprobs should not be empty"
|
|
assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
|
|
|
|
chunks = [
|
|
r
|
|
async for r in await inference_impl.completion(
|
|
content="Roses are red,",
|
|
stream=True,
|
|
model_id=inference_model,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=5,
|
|
),
|
|
logprobs=LogProbConfig(
|
|
top_k=3,
|
|
),
|
|
)
|
|
]
|
|
|
|
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
|
|
assert (
|
|
1 <= len(chunks) <= 6
|
|
) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
|
|
for chunk in chunks:
|
|
if (
|
|
chunk.delta.type == "text" and chunk.delta.text
|
|
): # if there's a token, we expect logprobs
|
|
assert chunk.logprobs, "Logprobs should not be empty"
|
|
assert all(
|
|
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
|
|
)
|
|
else: # no token, no logprobs
|
|
assert not chunk.logprobs, "Logprobs should be empty"
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
@pytest.mark.skip("This test is not quite robust")
|
|
async def test_completion_structured_output(self, inference_model, inference_stack):
|
|
inference_impl, _ = inference_stack
|
|
|
|
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
|
if provider.__provider_spec__.provider_type not in (
|
|
"inline::meta-reference",
|
|
"remote::ollama",
|
|
"remote::tgi",
|
|
"remote::together",
|
|
"remote::fireworks",
|
|
"remote::nvidia",
|
|
"remote::vllm",
|
|
"remote::cerebras",
|
|
):
|
|
pytest.skip(
|
|
"Other inference providers don't support structured output in completions yet"
|
|
)
|
|
|
|
class Output(BaseModel):
|
|
name: str
|
|
year_born: str
|
|
year_retired: str
|
|
|
|
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
|
|
response = await inference_impl.completion(
|
|
model_id=inference_model,
|
|
content=user_input,
|
|
stream=False,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=50,
|
|
),
|
|
response_format=JsonSchemaResponseFormat(
|
|
json_schema=Output.model_json_schema(),
|
|
),
|
|
)
|
|
assert isinstance(response, CompletionResponse)
|
|
assert isinstance(response.content, str)
|
|
|
|
answer = Output.model_validate_json(response.content)
|
|
assert answer.name == "Michael Jordan"
|
|
assert answer.year_born == "1963"
|
|
assert answer.year_retired == "2003"
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_chat_completion_non_streaming(
|
|
self, inference_model, inference_stack, common_params, sample_messages
|
|
):
|
|
inference_impl, _ = inference_stack
|
|
response = await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=sample_messages,
|
|
stream=False,
|
|
**common_params,
|
|
)
|
|
|
|
assert isinstance(response, ChatCompletionResponse)
|
|
assert response.completion_message.role == "assistant"
|
|
assert isinstance(response.completion_message.content, str)
|
|
assert len(response.completion_message.content) > 0
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_structured_output(
|
|
self, inference_model, inference_stack, common_params
|
|
):
|
|
inference_impl, _ = inference_stack
|
|
|
|
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
|
if provider.__provider_spec__.provider_type not in (
|
|
"inline::meta-reference",
|
|
"remote::ollama",
|
|
"remote::fireworks",
|
|
"remote::tgi",
|
|
"remote::together",
|
|
"remote::vllm",
|
|
"remote::nvidia",
|
|
):
|
|
pytest.skip("Other inference providers don't support structured output yet")
|
|
|
|
class AnswerFormat(BaseModel):
|
|
first_name: str
|
|
last_name: str
|
|
year_of_birth: int
|
|
num_seasons_in_nba: int
|
|
|
|
response = await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=[
|
|
# we include context about Michael Jordan in the prompt so that the test is
|
|
# focused on the funtionality of the model and not on the information embedded
|
|
# in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
|
|
SystemMessage(
|
|
content=(
|
|
"You are a helpful assistant.\n\n"
|
|
"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
|
|
)
|
|
),
|
|
UserMessage(content="Please give me information about Michael Jordan."),
|
|
],
|
|
stream=False,
|
|
response_format=JsonSchemaResponseFormat(
|
|
json_schema=AnswerFormat.model_json_schema(),
|
|
),
|
|
**common_params,
|
|
)
|
|
|
|
assert isinstance(response, ChatCompletionResponse)
|
|
assert response.completion_message.role == "assistant"
|
|
assert isinstance(response.completion_message.content, str)
|
|
|
|
answer = AnswerFormat.model_validate_json(response.completion_message.content)
|
|
assert answer.first_name == "Michael"
|
|
assert answer.last_name == "Jordan"
|
|
assert answer.year_of_birth == 1963
|
|
assert answer.num_seasons_in_nba == 15
|
|
|
|
response = await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=[
|
|
SystemMessage(content="You are a helpful assistant."),
|
|
UserMessage(content="Please give me information about Michael Jordan."),
|
|
],
|
|
stream=False,
|
|
**common_params,
|
|
)
|
|
|
|
assert isinstance(response, ChatCompletionResponse)
|
|
assert isinstance(response.completion_message.content, str)
|
|
|
|
with pytest.raises(ValidationError):
|
|
AnswerFormat.model_validate_json(response.completion_message.content)
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_chat_completion_streaming(
|
|
self, inference_model, inference_stack, common_params, sample_messages
|
|
):
|
|
inference_impl, _ = inference_stack
|
|
response = [
|
|
r
|
|
async for r in await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=sample_messages,
|
|
stream=True,
|
|
**common_params,
|
|
)
|
|
]
|
|
|
|
assert len(response) > 0
|
|
assert all(
|
|
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
|
)
|
|
grouped = group_chunks(response)
|
|
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
|
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
|
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
|
|
|
end = grouped[ChatCompletionResponseEventType.complete][0]
|
|
assert end.event.stop_reason == StopReason.end_of_turn
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_chat_completion_with_tool_calling(
|
|
self,
|
|
inference_model,
|
|
inference_stack,
|
|
common_params,
|
|
sample_messages,
|
|
sample_tool_definition,
|
|
):
|
|
inference_impl, _ = inference_stack
|
|
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
|
if (
|
|
provider.__provider_spec__.provider_type == "remote::groq"
|
|
and "Llama-3.2" in inference_model
|
|
):
|
|
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
|
|
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
|
|
|
|
messages = sample_messages + [
|
|
UserMessage(
|
|
content="What's the weather like in San Francisco?",
|
|
)
|
|
]
|
|
|
|
response = await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=messages,
|
|
tools=[sample_tool_definition],
|
|
stream=False,
|
|
**common_params,
|
|
)
|
|
|
|
assert isinstance(response, ChatCompletionResponse)
|
|
|
|
message = response.completion_message
|
|
|
|
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
|
# stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
|
|
# assert message.stop_reason == stop_reason
|
|
assert message.tool_calls is not None
|
|
assert len(message.tool_calls) > 0
|
|
|
|
call = message.tool_calls[0]
|
|
assert call.tool_name == "get_weather"
|
|
assert "location" in call.arguments
|
|
assert "San Francisco" in call.arguments["location"]
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_chat_completion_with_tool_calling_streaming(
|
|
self,
|
|
inference_model,
|
|
inference_stack,
|
|
common_params,
|
|
sample_messages,
|
|
sample_tool_definition,
|
|
):
|
|
inference_impl, _ = inference_stack
|
|
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
|
if (
|
|
provider.__provider_spec__.provider_type == "remote::groq"
|
|
and "Llama-3.2" in inference_model
|
|
):
|
|
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
|
|
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
|
|
|
|
messages = sample_messages + [
|
|
UserMessage(
|
|
content="What's the weather like in San Francisco?",
|
|
)
|
|
]
|
|
|
|
response = [
|
|
r
|
|
async for r in await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=messages,
|
|
tools=[sample_tool_definition],
|
|
stream=True,
|
|
**common_params,
|
|
)
|
|
]
|
|
assert len(response) > 0
|
|
assert all(
|
|
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
|
)
|
|
grouped = group_chunks(response)
|
|
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
|
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
|
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
|
|
|
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
|
# expected_stop_reason = get_expected_stop_reason(
|
|
# inference_settings["common_params"]["model"]
|
|
# )
|
|
# end = grouped[ChatCompletionResponseEventType.complete][0]
|
|
# assert end.event.stop_reason == expected_stop_reason
|
|
|
|
if "Llama3.1" in inference_model:
|
|
assert all(
|
|
chunk.event.delta.type == "tool_call"
|
|
for chunk in grouped[ChatCompletionResponseEventType.progress]
|
|
)
|
|
first = grouped[ChatCompletionResponseEventType.progress][0]
|
|
if not isinstance(
|
|
first.event.delta.content, ToolCall
|
|
): # first chunk may contain entire call
|
|
assert first.event.delta.parse_status == ToolCallParseStatus.started
|
|
|
|
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
|
# assert last.event.stop_reason == expected_stop_reason
|
|
assert last.event.delta.parse_status == ToolCallParseStatus.success
|
|
assert last.event.delta.content.type == "tool_call"
|
|
|
|
call = last.event.delta.content
|
|
assert call.tool_name == "get_weather"
|
|
assert "location" in call.arguments
|
|
assert "San Francisco" in call.arguments["location"]
|