mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
# What does this PR do? This PR adds the inline vLLM inference provider to the regression tests for inference providers. The PR also fixes some regressions in that inference provider in order to make the tests pass. ## Test Plan Command to run the new tests (from root of project): ``` pytest \ -vvv \ llama_stack/providers/tests/inference/test_text_inference.py \ --providers inference=vllm \ --inference-model meta-llama/Llama-3.2-3B-Instruct \ ``` Output of the above command after these changes: ``` /mnt/datadisk1/freiss/llama/env/lib/python3.12/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) =================================================================== test session starts =================================================================== platform linux -- Python 3.12.7, pytest-8.3.4, pluggy-1.5.0 -- /mnt/datadisk1/freiss/llama/env/bin/python3.12 cachedir: .pytest_cache rootdir: /mnt/datadisk1/freiss/llama/llama-stack configfile: pyproject.toml plugins: asyncio-0.25.0, anyio-4.6.2.post1 asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None collected 9 items llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[-vllm] PASSED [ 11%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-vllm] SKIPPED (Other inference providers don't support completion() yet) [ 22%] 
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_logprobs[-vllm] SKIPPED (Other inference providers don't support completion() yet) [ 33%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[-vllm] SKIPPED (This test is not quite robust) [ 44%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[-vllm] PASSED [ 55%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[-vllm] SKIPPED (Other inference providers don't support structured output yet) [ 66%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[-vllm] PASSED [ 77%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-vllm] PASSED [ 88%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[-vllm] PASSED [100%] ======================================================== 5 passed, 4 skipped, 2 warnings in 25.56s ======================================================== Task was destroyed but it is pending! task: <Task pending name='Task-6' coro=<AsyncLLMEngine.run_engine_loop() running at /mnt/datadisk1/freiss/llama/env/lib/python3.12/site-packages/vllm/engine/async_llm_engine.py:848> cb=[_log_task_completion(error_callback=<bound method...7cfc479440b0>>)() at /mnt/datadisk1/freiss/llama/env/lib/python3.12/site-packages/vllm/engine/async_llm_engine.py:45, shield.<locals>._inner_done_callback() at /mnt/datadisk1/freiss/llama/env/lib/python3.12/asyncio/tasks.py:905]> [rank0]:[W1219 11:38:34.689424319 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. 
On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator()) ``` The warning about "asyncio_default_fixture_loop_scope" appears to be due to my environment having a newer version of pytest-asyncio. The warning about a pending task appears to be due to a bug in `vllm.AsyncLLMEngine.shutdown_background_loop()`. It looks like that method returns without stopping a pending task. I will look into that issue separately. ## Sources ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [X] Ran pre-commit to handle lint / formatting issues. - [X] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [X] Wrote necessary unit or integration tests.
483 lines
17 KiB
Python
483 lines
17 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
|
|
import pytest
|
|
|
|
from llama_models.llama3.api.datatypes import (
|
|
SamplingParams,
|
|
StopReason,
|
|
ToolCall,
|
|
ToolDefinition,
|
|
ToolParamDefinition,
|
|
ToolPromptFormat,
|
|
)
|
|
|
|
from pydantic import BaseModel, ValidationError
|
|
|
|
from llama_stack.apis.inference import (
|
|
ChatCompletionResponse,
|
|
ChatCompletionResponseEventType,
|
|
ChatCompletionResponseStreamChunk,
|
|
CompletionResponse,
|
|
CompletionResponseStreamChunk,
|
|
JsonSchemaResponseFormat,
|
|
LogProbConfig,
|
|
SystemMessage,
|
|
ToolCallDelta,
|
|
ToolCallParseStatus,
|
|
ToolChoice,
|
|
UserMessage,
|
|
)
|
|
from llama_stack.apis.models import Model
|
|
from .utils import group_chunks
|
|
|
|
|
|
# How to run this test:
|
|
#
|
|
# pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py
|
|
# -m "(fireworks or ollama) and llama_3b"
|
|
# --env FIREWORKS_API_KEY=<your_api_key>
|
|
|
|
|
|
def get_expected_stop_reason(model: str):
    """Return the stop reason a given model family is expected to emit.

    Llama 3.1 models end tool-call turns with an end-of-message token;
    every other model family ends with end-of-turn.
    """
    is_llama_3_1 = any(tag in model for tag in ("Llama3.1", "Llama-3.1"))
    if is_llama_3_1:
        return StopReason.end_of_message
    return StopReason.end_of_turn
|
|
|
|
|
|
@pytest.fixture
def common_params(inference_model):
    """Keyword arguments shared by every chat_completion call in this module.

    Llama 3.1 models use the JSON tool-prompt format; all other model
    families use the python-list format.
    """
    uses_json_format = "Llama3.1" in inference_model or "Llama-3.1" in inference_model
    tool_prompt_format = (
        ToolPromptFormat.json if uses_json_format else ToolPromptFormat.python_list
    )
    return {
        "tool_choice": ToolChoice.auto,
        "tool_prompt_format": tool_prompt_format,
    }
|
|
|
|
|
|
@pytest.fixture
def sample_messages():
    """A minimal system + user conversation used as the base prompt."""
    system = SystemMessage(content="You are a helpful assistant.")
    user = UserMessage(content="What's the weather like today?")
    return [system, user]
|
|
|
|
|
|
@pytest.fixture
def sample_tool_definition():
    """A single-parameter weather-lookup tool for the tool-calling tests."""
    location_param = ToolParamDefinition(
        param_type="string",
        description="The city and state, e.g. San Francisco, CA",
    )
    return ToolDefinition(
        tool_name="get_weather",
        description="Get the current weather",
        parameters={"location": location_param},
    )
|
|
|
|
|
|
class TestInference:
    # Session scope for asyncio because the tests in this class all
    # share the same provider instance.
    @pytest.mark.asyncio(loop_scope="session")
    async def test_model_list(self, inference_model, inference_stack):
        """The models API lists at least one Model, including the one under test."""
        _, models_impl = inference_stack
        listed = await models_impl.list_models()

        assert isinstance(listed, list)
        assert len(listed) >= 1
        assert all(isinstance(entry, Model) for entry in listed)

        # The model we are testing against must appear in the listing.
        matching = next(
            (entry for entry in listed if entry.identifier == inference_model), None
        )
        assert matching is not None
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_completion(self, inference_model, inference_stack):
|
|
inference_impl, _ = inference_stack
|
|
|
|
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
|
if provider.__provider_spec__.provider_type not in (
|
|
"inline::meta-reference",
|
|
"remote::ollama",
|
|
"remote::tgi",
|
|
"remote::together",
|
|
"remote::fireworks",
|
|
"remote::nvidia",
|
|
"remote::cerebras",
|
|
):
|
|
pytest.skip("Other inference providers don't support completion() yet")
|
|
|
|
response = await inference_impl.completion(
|
|
content="Micheael Jordan is born in ",
|
|
stream=False,
|
|
model_id=inference_model,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=50,
|
|
),
|
|
)
|
|
|
|
assert isinstance(response, CompletionResponse)
|
|
assert "1963" in response.content
|
|
|
|
chunks = [
|
|
r
|
|
async for r in await inference_impl.completion(
|
|
content="Roses are red,",
|
|
stream=True,
|
|
model_id=inference_model,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=50,
|
|
),
|
|
)
|
|
]
|
|
|
|
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
|
|
assert len(chunks) >= 1
|
|
last = chunks[-1]
|
|
assert last.stop_reason == StopReason.out_of_tokens
|
|
|
|
    @pytest.mark.asyncio(loop_scope="session")
    async def test_completion_logprobs(self, inference_model, inference_stack):
        """completion() returns top-k logprobs per token in both modes.

        NOTE(review): the provider allow-list tuple below is empty (its only
        entry is commented out), so ``not in ()`` is always true and this test
        currently skips for every provider. Re-enable by uncommenting entries.
        """
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            # "remote::nvidia", -- provider doesn't provide all logprobs
        ):
            pytest.skip("Other inference providers don't support completion() yet")

        # Non-streaming request with top_k=3 logprobs and a tiny token budget.
        response = await inference_impl.completion(
            content="Micheael Jordan is born in ",
            stream=False,
            model_id=inference_model,
            sampling_params=SamplingParams(
                max_tokens=5,
            ),
            logprobs=LogProbConfig(
                top_k=3,
            ),
        )

        assert isinstance(response, CompletionResponse)
        # One logprob entry per generated token, bounded by max_tokens=5.
        assert 1 <= len(response.logprobs) <= 5
        assert response.logprobs, "Logprobs should not be empty"
        # Each token carries exactly top_k=3 candidate logprobs.
        assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)

        # Streaming variant of the same request.
        chunks = [
            r
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
                model_id=inference_model,
                sampling_params=SamplingParams(
                    max_tokens=5,
                ),
                logprobs=LogProbConfig(
                    top_k=3,
                ),
            )
        ]

        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
        assert (
            1 <= len(chunks) <= 6
        )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
        for chunk in chunks:
            if chunk.delta:  # if there's a token, we expect logprobs
                assert chunk.logprobs, "Logprobs should not be empty"
                assert all(
                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
                )
            else:  # no token, no logprobs
                assert not chunk.logprobs, "Logprobs should be empty"
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
@pytest.mark.skip("This test is not quite robust")
|
|
async def test_completion_structured_output(self, inference_model, inference_stack):
|
|
inference_impl, _ = inference_stack
|
|
|
|
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
|
if provider.__provider_spec__.provider_type not in (
|
|
"inline::meta-reference",
|
|
"remote::ollama",
|
|
"remote::tgi",
|
|
"remote::together",
|
|
"remote::fireworks",
|
|
"remote::nvidia",
|
|
"remote::vllm",
|
|
"remote::cerebras",
|
|
):
|
|
pytest.skip(
|
|
"Other inference providers don't support structured output in completions yet"
|
|
)
|
|
|
|
class Output(BaseModel):
|
|
name: str
|
|
year_born: str
|
|
year_retired: str
|
|
|
|
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
|
|
response = await inference_impl.completion(
|
|
model_id=inference_model,
|
|
content=user_input,
|
|
stream=False,
|
|
sampling_params=SamplingParams(
|
|
max_tokens=50,
|
|
),
|
|
response_format=JsonSchemaResponseFormat(
|
|
json_schema=Output.model_json_schema(),
|
|
),
|
|
)
|
|
assert isinstance(response, CompletionResponse)
|
|
assert isinstance(response.content, str)
|
|
|
|
answer = Output.model_validate_json(response.content)
|
|
assert answer.name == "Michael Jordan"
|
|
assert answer.year_born == "1963"
|
|
assert answer.year_retired == "2003"
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_chat_completion_non_streaming(
|
|
self, inference_model, inference_stack, common_params, sample_messages
|
|
):
|
|
inference_impl, _ = inference_stack
|
|
response = await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=sample_messages,
|
|
stream=False,
|
|
**common_params,
|
|
)
|
|
|
|
assert isinstance(response, ChatCompletionResponse)
|
|
assert response.completion_message.role == "assistant"
|
|
assert isinstance(response.completion_message.content, str)
|
|
assert len(response.completion_message.content) > 0
|
|
|
|
    @pytest.mark.asyncio(loop_scope="session")
    async def test_structured_output(
        self, inference_model, inference_stack, common_params
    ):
        """chat_completion() honors a JSON-schema response_format.

        Two-phase check: with the schema the output must validate against
        AnswerFormat; without it, a plain-prose answer must NOT validate.
        """
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::fireworks",
            "remote::tgi",
            "remote::together",
            "remote::vllm",
            "remote::nvidia",
        ):
            pytest.skip("Other inference providers don't support structured output yet")

        # Schema the constrained response must satisfy.
        class AnswerFormat(BaseModel):
            first_name: str
            last_name: str
            year_of_birth: int
            num_seasons_in_nba: int

        # Phase 1: constrained generation.
        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=[
                # we include context about Michael Jordan in the prompt so that the test is
                # focused on the functionality of the model and not on the information embedded
                # in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
                SystemMessage(
                    content=(
                        "You are a helpful assistant.\n\n"
                        "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
                    )
                ),
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            response_format=JsonSchemaResponseFormat(
                json_schema=AnswerFormat.model_json_schema(),
            ),
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)

        # Facts from the system prompt must survive the schema round-trip.
        answer = AnswerFormat.model_validate_json(response.completion_message.content)
        assert answer.first_name == "Michael"
        assert answer.last_name == "Jordan"
        assert answer.year_of_birth == 1963
        assert answer.num_seasons_in_nba == 15

        # Phase 2: same question WITHOUT response_format — output is free-form.
        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=[
                SystemMessage(content="You are a helpful assistant."),
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert isinstance(response.completion_message.content, str)

        # Unconstrained prose should not accidentally parse as the schema.
        with pytest.raises(ValidationError):
            AnswerFormat.model_validate_json(response.completion_message.content)
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_chat_completion_streaming(
|
|
self, inference_model, inference_stack, common_params, sample_messages
|
|
):
|
|
inference_impl, _ = inference_stack
|
|
response = [
|
|
r
|
|
async for r in await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=sample_messages,
|
|
stream=True,
|
|
**common_params,
|
|
)
|
|
]
|
|
|
|
assert len(response) > 0
|
|
assert all(
|
|
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
|
)
|
|
grouped = group_chunks(response)
|
|
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
|
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
|
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
|
|
|
end = grouped[ChatCompletionResponseEventType.complete][0]
|
|
assert end.event.stop_reason == StopReason.end_of_turn
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
|
|
async def test_chat_completion_with_tool_calling(
|
|
self,
|
|
inference_model,
|
|
inference_stack,
|
|
common_params,
|
|
sample_messages,
|
|
sample_tool_definition,
|
|
):
|
|
inference_impl, _ = inference_stack
|
|
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
|
if provider.__provider_spec__.provider_type in ("remote::groq",):
|
|
pytest.skip(
|
|
provider.__provider_spec__.provider_type
|
|
+ " doesn't support tool calling yet"
|
|
)
|
|
|
|
inference_impl, _ = inference_stack
|
|
messages = sample_messages + [
|
|
UserMessage(
|
|
content="What's the weather like in San Francisco?",
|
|
)
|
|
]
|
|
|
|
response = await inference_impl.chat_completion(
|
|
model_id=inference_model,
|
|
messages=messages,
|
|
tools=[sample_tool_definition],
|
|
stream=False,
|
|
**common_params,
|
|
)
|
|
|
|
assert isinstance(response, ChatCompletionResponse)
|
|
|
|
message = response.completion_message
|
|
|
|
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
|
# stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
|
|
# assert message.stop_reason == stop_reason
|
|
assert message.tool_calls is not None
|
|
assert len(message.tool_calls) > 0
|
|
|
|
call = message.tool_calls[0]
|
|
assert call.tool_name == "get_weather"
|
|
assert "location" in call.arguments
|
|
assert "San Francisco" in call.arguments["location"]
|
|
|
|
    @pytest.mark.asyncio(loop_scope="session")
    async def test_chat_completion_with_tool_calling_streaming(
        self,
        inference_model,
        inference_stack,
        common_params,
        sample_messages,
        sample_tool_definition,
    ):
        """Streaming chat completion with a tool emits a parseable tool-call delta.

        Skipped for providers (currently groq) that do not support tool
        calling yet.
        """
        inference_impl, _ = inference_stack
        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type in ("remote::groq",):
            pytest.skip(
                provider.__provider_spec__.provider_type
                + " doesn't support tool calling yet"
            )

        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
            )
        ]

        # Drain the stream into a list of chunks.
        response = [
            r
            async for r in await inference_impl.chat_completion(
                model_id=inference_model,
                messages=messages,
                tools=[sample_tool_definition],
                stream=True,
                **common_params,
            )
        ]

        assert len(response) > 0
        assert all(
            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
        )
        grouped = group_chunks(response)
        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

        # This is not supported in most providers :/ they don't return eom_id / eot_id
        # expected_stop_reason = get_expected_stop_reason(
        #     inference_settings["common_params"]["model"]
        # )
        # end = grouped[ChatCompletionResponseEventType.complete][0]
        # assert end.event.stop_reason == expected_stop_reason

        # Llama 3.1 streams the tool call incrementally as ToolCallDelta chunks.
        if "Llama3.1" in inference_model:
            assert all(
                isinstance(chunk.event.delta, ToolCallDelta)
                for chunk in grouped[ChatCompletionResponseEventType.progress]
            )
            first = grouped[ChatCompletionResponseEventType.progress][0]
            if not isinstance(
                first.event.delta.content, ToolCall
            ):  # first chunk may contain entire call
                assert first.event.delta.parse_status == ToolCallParseStatus.started

        # The final progress chunk must carry the fully-parsed tool call.
        last = grouped[ChatCompletionResponseEventType.progress][-1]
        # assert last.event.stop_reason == expected_stop_reason
        assert last.event.delta.parse_status == ToolCallParseStatus.success
        assert isinstance(last.event.delta.content, ToolCall)

        call = last.event.delta.content
        assert call.tool_name == "get_weather"
        assert "location" in call.arguments
        assert "San Francisco" in call.arguments["location"]
|