mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 10:54:19 +00:00
# What does this PR do? The remote-vllm provider was not passing logprobs options from CompletionRequest or ChatCompletionRequests through to the OpenAI client parameters. I manually verified this, as well as observed this provider failing `TestInference::test_completion_logprobs`. This was filed as issue #1073. This fixes that by passing the `logprobs.top_k` value through to the parameters we pass into the OpenAI client. Additionally, this fixes a bug in `test_text_inference.py` where it mistakenly assumed chunk.delta were of type `ContentDelta` for completion requests. The deltas are of type `ContentDelta` for chat completion requests, but for basic completion requests the deltas are of type string. This test was likely failing for other providers that did properly support logprobs because of this latter issue in the test, which was hit while fixing the above issue with the remote-vllm provider. (Closes #1073) ## Test Plan First, you need a vllm running. I ran one locally like this: ``` vllm serve meta-llama/Llama-3.2-3B-Instruct --port 8001 --enable-auto-tool-choice --tool-call-parser llama3_json ``` Next, run test_text_inference.py against this vllm using the remote vllm provider like this: ``` VLLM_URL="http://localhost:8001/v1" python -m pytest -s -v llama_stack/providers/tests/inference/test_text_inference.py --providers "inference=vllm_remote" ``` Before my change, the test failed with this error: ``` llama_stack/providers/tests/inference/test_text_inference.py:155: in test_completion_logprobs assert 1 <= len(response.logprobs) <= 5 E TypeError: object of type 'NoneType' has no len() ``` After my change, the test passes. [//]: # (## Documentation) Signed-off-by: Ben Browning <bbrownin@redhat.com>
405 lines
14 KiB
Python
405 lines
14 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
|
|
import pytest
|
|
|
|
from llama_models.llama3.api.datatypes import (
|
|
SamplingParams,
|
|
StopReason,
|
|
ToolCall,
|
|
ToolDefinition,
|
|
ToolParamDefinition,
|
|
ToolPromptFormat,
|
|
)
|
|
|
|
from pydantic import BaseModel, ValidationError
|
|
|
|
from llama_stack.apis.common.content_types import ToolCallParseStatus
|
|
from llama_stack.apis.inference import (
|
|
ChatCompletionResponse,
|
|
ChatCompletionResponseEventType,
|
|
ChatCompletionResponseStreamChunk,
|
|
CompletionResponse,
|
|
CompletionResponseStreamChunk,
|
|
JsonSchemaResponseFormat,
|
|
LogProbConfig,
|
|
SystemMessage,
|
|
ToolChoice,
|
|
UserMessage,
|
|
)
|
|
from llama_stack.apis.models import ListModelsResponse, Model
|
|
|
|
from .utils import group_chunks
|
|
|
|
|
|
# How to run this test:
|
|
#
|
|
# pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py
|
|
# -m "(fireworks or ollama) and llama_3b"
|
|
# --env FIREWORKS_API_KEY=<your_api_key>
|
|
|
|
|
|
def get_expected_stop_reason(model: str):
|
|
return StopReason.end_of_message if ("Llama3.1" in model or "Llama-3.1" in model) else StopReason.end_of_turn
|
|
|
|
|
|
@pytest.fixture
|
|
def common_params(inference_model):
|
|
return {
|
|
"tool_choice": ToolChoice.auto,
|
|
"tool_prompt_format": (
|
|
ToolPromptFormat.json
|
|
if ("Llama3.1" in inference_model or "Llama-3.1" in inference_model)
|
|
else ToolPromptFormat.python_list
|
|
),
|
|
}
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_messages():
|
|
return [
|
|
SystemMessage(content="You are a helpful assistant."),
|
|
UserMessage(content="What's the weather like today?"),
|
|
]
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_tool_definition():
|
|
return ToolDefinition(
|
|
tool_name="get_weather",
|
|
description="Get the current weather",
|
|
parameters={
|
|
"location": ToolParamDefinition(
|
|
param_type="string",
|
|
description="The city and state, e.g. San Francisco, CA",
|
|
),
|
|
},
|
|
)
|
|
|
|
|
|
class TestInference:
|
|
# Session scope for asyncio because the tests in this class all
|
|
# share the same provider instance.
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_model_list(self, inference_model, inference_stack):
|
|
_, models_impl = inference_stack
|
|
response = await models_impl.list_models()
|
|
assert isinstance(response, ListModelsResponse)
|
|
assert isinstance(response.data, list)
|
|
assert len(response.data) >= 1
|
|
assert all(isinstance(model, Model) for model in response.data)
|
|
|
|
model_def = None
|
|
for model in response.data:
|
|
if model.identifier == inference_model:
|
|
model_def = model
|
|
break
|
|
|
|
assert model_def is not None
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_completion(self, inference_model, inference_stack):
|
|
inference_impl, _ = inference_stack
|
|
|
|
response = await inference_impl.completion(
|
|
content="Micheael Jordan is born in ",
|
|
stream=False,
|
|
model_id=inference_model,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=50,
|
|
),
|
|
)
|
|
|
|
assert isinstance(response, CompletionResponse)
|
|
assert "1963" in response.content
|
|
|
|
chunks = [
|
|
r
|
|
async for r in await inference_impl.completion(
|
|
content="Roses are red,",
|
|
stream=True,
|
|
model_id=inference_model,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=50,
|
|
),
|
|
)
|
|
]
|
|
|
|
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
|
|
assert len(chunks) >= 1
|
|
last = chunks[-1]
|
|
assert last.stop_reason == StopReason.out_of_tokens
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_completion_logprobs(self, inference_model, inference_stack):
|
|
inference_impl, _ = inference_stack
|
|
|
|
response = await inference_impl.completion(
|
|
content="Micheael Jordan is born in ",
|
|
stream=False,
|
|
model_id=inference_model,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=5,
|
|
),
|
|
logprobs=LogProbConfig(
|
|
top_k=3,
|
|
),
|
|
)
|
|
|
|
assert isinstance(response, CompletionResponse)
|
|
assert 1 <= len(response.logprobs) <= 5
|
|
assert response.logprobs, "Logprobs should not be empty"
|
|
assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
|
|
|
|
chunks = [
|
|
r
|
|
async for r in await inference_impl.completion(
|
|
content="Roses are red,",
|
|
stream=True,
|
|
model_id=inference_model,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=5,
|
|
),
|
|
logprobs=LogProbConfig(
|
|
top_k=3,
|
|
),
|
|
)
|
|
]
|
|
|
|
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
|
|
assert (
|
|
1 <= len(chunks) <= 6
|
|
) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
|
|
for chunk in chunks:
|
|
if chunk.delta: # if there's a token, we expect logprobs
|
|
assert chunk.logprobs, "Logprobs should not be empty"
|
|
assert all(len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs)
|
|
else: # no token, no logprobs
|
|
assert not chunk.logprobs, "Logprobs should be empty"
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_completion_structured_output(self, inference_model, inference_stack):
|
|
inference_impl, _ = inference_stack
|
|
|
|
class Output(BaseModel):
|
|
name: str
|
|
year_born: str
|
|
year_retired: str
|
|
|
|
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
|
|
response = await inference_impl.completion(
|
|
model_id=inference_model,
|
|
content=user_input,
|
|
stream=False,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=50,
|
|
),
|
|
response_format=JsonSchemaResponseFormat(
|
|
json_schema=Output.model_json_schema(),
|
|
),
|
|
)
|
|
assert isinstance(response, CompletionResponse)
|
|
assert isinstance(response.content, str)
|
|
|
|
answer = Output.model_validate_json(response.content)
|
|
assert answer.name == "Michael Jordan"
|
|
assert answer.year_born == "1963"
|
|
assert answer.year_retired == "2003"
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_chat_completion_non_streaming(
|
|
self, inference_model, inference_stack, common_params, sample_messages
|
|
):
|
|
inference_impl, _ = inference_stack
|
|
response = await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=sample_messages,
|
|
stream=False,
|
|
**common_params,
|
|
)
|
|
|
|
assert isinstance(response, ChatCompletionResponse)
|
|
assert response.completion_message.role == "assistant"
|
|
assert isinstance(response.completion_message.content, str)
|
|
assert len(response.completion_message.content) > 0
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_structured_output(self, inference_model, inference_stack, common_params):
|
|
inference_impl, _ = inference_stack
|
|
|
|
class AnswerFormat(BaseModel):
|
|
first_name: str
|
|
last_name: str
|
|
year_of_birth: int
|
|
num_seasons_in_nba: int
|
|
|
|
response = await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=[
|
|
# we include context about Michael Jordan in the prompt so that the test is
|
|
# focused on the funtionality of the model and not on the information embedded
|
|
# in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
|
|
SystemMessage(
|
|
content=(
|
|
"You are a helpful assistant.\n\n"
|
|
"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
|
|
)
|
|
),
|
|
UserMessage(content="Please give me information about Michael Jordan."),
|
|
],
|
|
stream=False,
|
|
response_format=JsonSchemaResponseFormat(
|
|
json_schema=AnswerFormat.model_json_schema(),
|
|
),
|
|
**common_params,
|
|
)
|
|
|
|
assert isinstance(response, ChatCompletionResponse)
|
|
assert response.completion_message.role == "assistant"
|
|
assert isinstance(response.completion_message.content, str)
|
|
|
|
answer = AnswerFormat.model_validate_json(response.completion_message.content)
|
|
assert answer.first_name == "Michael"
|
|
assert answer.last_name == "Jordan"
|
|
assert answer.year_of_birth == 1963
|
|
assert answer.num_seasons_in_nba == 15
|
|
|
|
response = await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=[
|
|
SystemMessage(content="You are a helpful assistant."),
|
|
UserMessage(content="Please give me information about Michael Jordan."),
|
|
],
|
|
stream=False,
|
|
**common_params,
|
|
)
|
|
|
|
assert isinstance(response, ChatCompletionResponse)
|
|
assert isinstance(response.completion_message.content, str)
|
|
|
|
with pytest.raises(ValidationError):
|
|
AnswerFormat.model_validate_json(response.completion_message.content)
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_chat_completion_streaming(self, inference_model, inference_stack, common_params, sample_messages):
|
|
inference_impl, _ = inference_stack
|
|
response = [
|
|
r
|
|
async for r in await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=sample_messages,
|
|
stream=True,
|
|
**common_params,
|
|
)
|
|
]
|
|
|
|
assert len(response) > 0
|
|
assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response)
|
|
grouped = group_chunks(response)
|
|
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
|
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
|
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
|
|
|
end = grouped[ChatCompletionResponseEventType.complete][0]
|
|
assert end.event.stop_reason == StopReason.end_of_turn
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_chat_completion_with_tool_calling(
|
|
self,
|
|
inference_model,
|
|
inference_stack,
|
|
common_params,
|
|
sample_messages,
|
|
sample_tool_definition,
|
|
):
|
|
inference_impl, _ = inference_stack
|
|
messages = sample_messages + [
|
|
UserMessage(
|
|
content="What's the weather like in San Francisco?",
|
|
)
|
|
]
|
|
|
|
response = await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=messages,
|
|
tools=[sample_tool_definition],
|
|
stream=False,
|
|
**common_params,
|
|
)
|
|
|
|
assert isinstance(response, ChatCompletionResponse)
|
|
|
|
message = response.completion_message
|
|
|
|
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
|
# stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
|
|
# assert message.stop_reason == stop_reason
|
|
assert message.tool_calls is not None
|
|
assert len(message.tool_calls) > 0
|
|
|
|
call = message.tool_calls[0]
|
|
assert call.tool_name == "get_weather"
|
|
assert "location" in call.arguments
|
|
assert "San Francisco" in call.arguments["location"]
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_chat_completion_with_tool_calling_streaming(
|
|
self,
|
|
inference_model,
|
|
inference_stack,
|
|
common_params,
|
|
sample_messages,
|
|
sample_tool_definition,
|
|
):
|
|
inference_impl, _ = inference_stack
|
|
messages = sample_messages + [
|
|
UserMessage(
|
|
content="What's the weather like in San Francisco?",
|
|
)
|
|
]
|
|
|
|
response = [
|
|
r
|
|
async for r in await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=messages,
|
|
tools=[sample_tool_definition],
|
|
stream=True,
|
|
**common_params,
|
|
)
|
|
]
|
|
assert len(response) > 0
|
|
assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response)
|
|
grouped = group_chunks(response)
|
|
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
|
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
|
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
|
|
|
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
|
# expected_stop_reason = get_expected_stop_reason(
|
|
# inference_settings["common_params"]["model"]
|
|
# )
|
|
# end = grouped[ChatCompletionResponseEventType.complete][0]
|
|
# assert end.event.stop_reason == expected_stop_reason
|
|
|
|
if "Llama3.1" in inference_model:
|
|
assert all(
|
|
chunk.event.delta.type == "tool_call" for chunk in grouped[ChatCompletionResponseEventType.progress]
|
|
)
|
|
first = grouped[ChatCompletionResponseEventType.progress][0]
|
|
if not isinstance(first.event.delta.tool_call, ToolCall): # first chunk may contain entire call
|
|
assert first.event.delta.parse_status == ToolCallParseStatus.started
|
|
|
|
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
|
# assert last.event.stop_reason == expected_stop_reason
|
|
assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
|
|
assert isinstance(last.event.delta.tool_call, ToolCall)
|
|
|
|
call = last.event.delta.tool_call
|
|
assert call.tool_name == "get_weather"
|
|
assert "location" in call.arguments
|
|
assert "San Francisco" in call.arguments["location"]
|