llama-stack/llama_stack/providers/tests/inference/test_text_inference.py
Aidan Do 095125e463
[#391] Add support for json structured output for vLLM (#528)
# What does this PR do?

Addresses issue (#391)

- Adds json structured output for vLLM
- Enables structured output tests for vLLM

> Give me a recipe for Spaghetti Bolognaise:

```json
{
  "recipe_name": "Spaghetti Bolognaise",
  "preamble": "Ah, spaghetti bolognaise - the quintessential Italian dish that fills my kitchen with the aromas of childhood nostalgia. As a child, I would watch my nonna cook up a big pot of spaghetti bolognaise every Sunday, filling our small Italian household with the savory scent of simmering meat and tomatoes. The way the sauce would thicken and the spaghetti would al dente - it was love at first bite. And now, as a chef, I want to share that same love with you, so you can recreate these warm, comforting memories at home.",
  "ingredients": [
    "500g minced beef",
    "1 medium onion, finely chopped",
    "2 cloves garlic, minced",
    "1 carrot, finely chopped",
    " celery, finely chopped",
    "1 (28 oz) can whole peeled tomatoes",
    "1 tbsp tomato paste",
    "1 tsp dried basil",
    "1 tsp dried oregano",
    "1 tsp salt",
    "1/2 tsp black pepper",
    "1/2 tsp sugar",
    "1 lb spaghetti",
    "Grated Parmesan cheese, for serving",
    "Extra virgin olive oil, for serving"
  ],
  "steps": [
    "Heat a large pot over medium heat and add a generous drizzle of extra virgin olive oil.",
    "Add the chopped onion, garlic, carrot, and celery and cook until the vegetables are soft and translucent, about 5-7 minutes.",
    "Add the minced beef and cook until browned, breaking it up with a spoon as it cooks.",
    "Add the tomato paste and cook for 1-2 minutes, stirring constantly.",
    "Add the canned tomatoes, dried basil, dried oregano, salt, black pepper, and sugar. Stir well to combine.",
    "Bring the sauce to a simmer and let it cook for 20-30 minutes, stirring occasionally, until the sauce has thickened and the flavors have melded together.",
    "While the sauce cooks, bring a large pot of salted water to a boil and cook the spaghetti according to the package instructions until al dente. Reserve 1 cup of pasta water before draining the spaghetti.",
    "Add the reserved pasta water to the sauce and stir to combine.",
    "Combine the cooked spaghetti and sauce, tossing to coat the pasta evenly.",
    "Serve hot, topped with grated Parmesan cheese and a drizzle of extra virgin olive oil.",
    "Enjoy!"
  ]
}
```

Generated with the Llama-3.2-3B-Instruct model - pretty good for a 3B
parameter model 👍
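
For reference, output like the above can be requested through the same `response_format` path that the tests in this file exercise. This is a minimal, illustrative sketch (the `Recipe` model and `generate_recipe` helper are hypothetical; `inference_impl` is assumed to be a resolved inference provider such as the one produced by the `inference_stack` test fixture):

```python
from pydantic import BaseModel

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403


# Hypothetical schema matching the recipe output shown above.
class Recipe(BaseModel):
    recipe_name: str
    preamble: str
    ingredients: list[str]
    steps: list[str]


async def generate_recipe(inference_impl, model_id: str) -> Recipe:
    # inference_impl is assumed to be a resolved inference provider,
    # e.g. obtained the same way as in the inference_stack test fixture below.
    response = await inference_impl.chat_completion(
        model_id=model_id,
        messages=[UserMessage(content="Give me a recipe for Spaghetti Bolognaise.")],
        stream=False,
        # Constrain decoding to the Recipe JSON schema.
        response_format=JsonSchemaResponseFormat(
            json_schema=Recipe.model_json_schema(),
        ),
    )
    return Recipe.model_validate_json(response.completion_message.content)
```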

## Test Plan

```bash
pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py -k llama_3b-vllm_remote
```

With the following setup:

```bash
# Environment
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export INFERENCE_PORT=8000
export VLLM_URL=http://localhost:8000/v1

# vLLM server
sudo docker run --gpus all \
    -v $STORAGE_DIR/.cache/huggingface:/root/.cache/huggingface \
    --env "HUGGING_FACE_HUB_TOKEN=$(cat ~/.cache/huggingface/token)" \
    -p 8000:$INFERENCE_PORT \
    --ipc=host \
    --net=host \
    vllm/vllm-openai:v0.6.3.post1 \
    --model $INFERENCE_MODEL

# llama-stack server
llama stack build --template remote-vllm --image-type conda && llama stack run distributions/remote-vllm/run.yaml \
  --port 5001 \
  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```

Results:

```
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_3b-vllm_remote] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_3b-vllm_remote] SKIPPED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completions_structured_output[llama_3b-vllm_remote] SKIPPED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_3b-vllm_remote] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_3b-vllm_remote] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_3b-vllm_remote] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_3b-vllm_remote] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_3b-vllm_remote] PASSED

================================ 6 passed, 2 skipped, 120 deselected, 2 warnings in 13.26s ================================
```

## Sources

- https://github.com/vllm-project/vllm/discussions/8300
- By default, vLLM uses https://github.com/dottxt-ai/outlines for
structured outputs
[[1](32e7db2536/vllm/engine/arg_utils.py (L279-L280))]
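
For context only (not part of this PR's code), the same outlines-backed guided decoding can be exercised directly against the vLLM OpenAI-compatible server from the Test Plan above. A rough sketch, assuming vLLM's documented `guided_json` extra parameter and the server/model settings shown earlier:

```python
from openai import OpenAI
from pydantic import BaseModel


class Output(BaseModel):
    name: str
    year_born: str
    year_retired: str


# Points at the vLLM server started in the Test Plan (VLLM_URL).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Michael Jordan was born in 1963. He retired in 2003.",
        }
    ],
    # vLLM-specific extra parameter: constrain output to this JSON schema.
    extra_body={"guided_json": Output.model_json_schema()},
)
print(completion.choices[0].message.content)
```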

## Before submitting

- [N/A] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case)
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [N/A?] Updated relevant documentation. Couldn't find any relevant documentation. Lmk if I've missed anything.
- [x] Wrote necessary unit or integration tests.
2024-12-08 15:02:51 -08:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from pydantic import BaseModel, ValidationError

from llama_models.llama3.api.datatypes import *  # noqa: F403

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.distribution.datatypes import *  # noqa: F403

from .utils import group_chunks

# How to run this test:
#
# pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py
#   -m "(fireworks or ollama) and llama_3b"
#   --env FIREWORKS_API_KEY=<your_api_key>


def get_expected_stop_reason(model: str):
    return (
        StopReason.end_of_message
        if ("Llama3.1" in model or "Llama-3.1" in model)
        else StopReason.end_of_turn
    )


@pytest.fixture
def common_params(inference_model):
    return {
        "tool_choice": ToolChoice.auto,
        "tool_prompt_format": (
            ToolPromptFormat.json
            if ("Llama3.1" in inference_model or "Llama-3.1" in inference_model)
            else ToolPromptFormat.python_list
        ),
    }


@pytest.fixture
def sample_messages():
    return [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What's the weather like today?"),
    ]


@pytest.fixture
def sample_tool_definition():
    return ToolDefinition(
        tool_name="get_weather",
        description="Get the current weather",
        parameters={
            "location": ToolParamDefinition(
                param_type="string",
                description="The city and state, e.g. San Francisco, CA",
            ),
        },
    )


class TestInference:
    @pytest.mark.asyncio
    async def test_model_list(self, inference_model, inference_stack):
        _, models_impl = inference_stack
        response = await models_impl.list_models()
        assert isinstance(response, list)
        assert len(response) >= 1
        assert all(isinstance(model, Model) for model in response)

        model_def = None
        for model in response:
            if model.identifier == inference_model:
                model_def = model
                break

        assert model_def is not None

    @pytest.mark.asyncio
    async def test_completion(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::tgi",
            "remote::together",
            "remote::fireworks",
            "remote::cerebras",
        ):
            pytest.skip("Other inference providers don't support completion() yet")

        response = await inference_impl.completion(
            content="Michael Jordan is born in ",
            stream=False,
            model_id=inference_model,
            sampling_params=SamplingParams(
                max_tokens=50,
            ),
        )

        assert isinstance(response, CompletionResponse)
        assert "1963" in response.content

        chunks = [
            r
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
                model_id=inference_model,
                sampling_params=SamplingParams(
                    max_tokens=50,
                ),
            )
        ]

        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
        assert len(chunks) >= 1
        last = chunks[-1]
        assert last.stop_reason == StopReason.out_of_tokens

    @pytest.mark.asyncio
    @pytest.mark.skip("This test is not quite robust")
    async def test_completions_structured_output(
        self, inference_model, inference_stack
    ):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::tgi",
            "remote::together",
            "remote::fireworks",
            "remote::vllm",
            "remote::cerebras",
        ):
            pytest.skip(
                "Other inference providers don't support structured output in completions yet"
            )

        class Output(BaseModel):
            name: str
            year_born: str
            year_retired: str

        user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
        response = await inference_impl.completion(
            model_id=inference_model,
            content=user_input,
            stream=False,
            sampling_params=SamplingParams(
                max_tokens=50,
            ),
            response_format=JsonSchemaResponseFormat(
                json_schema=Output.model_json_schema(),
            ),
        )
        assert isinstance(response, CompletionResponse)
        assert isinstance(response.content, str)

        answer = Output.model_validate_json(response.content)
        assert answer.name == "Michael Jordan"
        assert answer.year_born == "1963"
        assert answer.year_retired == "2003"

    @pytest.mark.asyncio
    async def test_chat_completion_non_streaming(
        self, inference_model, inference_stack, common_params, sample_messages
    ):
        inference_impl, _ = inference_stack
        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=sample_messages,
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)
        assert len(response.completion_message.content) > 0

    @pytest.mark.asyncio
    async def test_structured_output(
        self, inference_model, inference_stack, common_params
    ):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::fireworks",
            "remote::tgi",
            "remote::together",
            "remote::vllm",
            "remote::nvidia",
        ):
            pytest.skip("Other inference providers don't support structured output yet")

        class AnswerFormat(BaseModel):
            first_name: str
            last_name: str
            year_of_birth: int
            num_seasons_in_nba: int

        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=[
                # we include context about Michael Jordan in the prompt so that the test is
                # focused on the functionality of the model and not on the information embedded
                # in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
                SystemMessage(
                    content=(
                        "You are a helpful assistant.\n\n"
                        "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
                    )
                ),
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            response_format=JsonSchemaResponseFormat(
                json_schema=AnswerFormat.model_json_schema(),
            ),
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)

        answer = AnswerFormat.model_validate_json(response.completion_message.content)
        assert answer.first_name == "Michael"
        assert answer.last_name == "Jordan"
        assert answer.year_of_birth == 1963
        assert answer.num_seasons_in_nba == 15

        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=[
                SystemMessage(content="You are a helpful assistant."),
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert isinstance(response.completion_message.content, str)

        with pytest.raises(ValidationError):
            AnswerFormat.model_validate_json(response.completion_message.content)

    @pytest.mark.asyncio
    async def test_chat_completion_streaming(
        self, inference_model, inference_stack, common_params, sample_messages
    ):
        inference_impl, _ = inference_stack
        response = [
            r
            async for r in await inference_impl.chat_completion(
                model_id=inference_model,
                messages=sample_messages,
                stream=True,
                **common_params,
            )
        ]

        assert len(response) > 0
        assert all(
            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
        )
        grouped = group_chunks(response)
        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

        end = grouped[ChatCompletionResponseEventType.complete][0]
        assert end.event.stop_reason == StopReason.end_of_turn

    @pytest.mark.asyncio
    async def test_chat_completion_with_tool_calling(
        self,
        inference_model,
        inference_stack,
        common_params,
        sample_messages,
        sample_tool_definition,
    ):
        inference_impl, _ = inference_stack
        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
            )
        ]

        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=messages,
            tools=[sample_tool_definition],
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)

        message = response.completion_message

        # This is not supported in most providers :/ they don't return eom_id / eot_id
        # stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
        # assert message.stop_reason == stop_reason

        assert message.tool_calls is not None
        assert len(message.tool_calls) > 0

        call = message.tool_calls[0]
        assert call.tool_name == "get_weather"
        assert "location" in call.arguments
        assert "San Francisco" in call.arguments["location"]

    @pytest.mark.asyncio
    async def test_chat_completion_with_tool_calling_streaming(
        self,
        inference_model,
        inference_stack,
        common_params,
        sample_messages,
        sample_tool_definition,
    ):
        inference_impl, _ = inference_stack
        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
            )
        ]

        response = [
            r
            async for r in await inference_impl.chat_completion(
                model_id=inference_model,
                messages=messages,
                tools=[sample_tool_definition],
                stream=True,
                **common_params,
            )
        ]

        assert len(response) > 0
        assert all(
            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
        )
        grouped = group_chunks(response)
        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

        # This is not supported in most providers :/ they don't return eom_id / eot_id
        # expected_stop_reason = get_expected_stop_reason(
        #     inference_settings["common_params"]["model"]
        # )
        # end = grouped[ChatCompletionResponseEventType.complete][0]
        # assert end.event.stop_reason == expected_stop_reason

        if "Llama3.1" in inference_model:
            assert all(
                isinstance(chunk.event.delta, ToolCallDelta)
                for chunk in grouped[ChatCompletionResponseEventType.progress]
            )
            first = grouped[ChatCompletionResponseEventType.progress][0]
            if not isinstance(
                first.event.delta.content, ToolCall
            ):  # first chunk may contain entire call
                assert first.event.delta.parse_status == ToolCallParseStatus.started

        last = grouped[ChatCompletionResponseEventType.progress][-1]
        # assert last.event.stop_reason == expected_stop_reason
        assert last.event.delta.parse_status == ToolCallParseStatus.success
        assert isinstance(last.event.delta.content, ToolCall)

        call = last.event.delta.content
        assert call.tool_name == "get_weather"
        assert "location" in call.arguments
        assert "San Francisco" in call.arguments["location"]