llama-stack-mirror/llama_stack/providers/tests/inference/test_text_inference.py
Aidan Do 910717c1fd
Add vLLM raw completions API (#823)
# What does this PR do?

Adds the raw completions API (`completion()`) to the remote vLLM inference provider.
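For orientation, a raw completion differs from a chat completion in that the prompt is sent as plain text rather than as a list of messages. The sketch below builds the kind of request body that an OpenAI-style `/v1/completions` endpoint (such as the one vLLM serves at `VLLM_URL`) accepts. The helper name `build_completion_payload` is hypothetical and is not the code added by this PR; it only illustrates the shape of the request.

```python
import json
from typing import Any


def build_completion_payload(
    model: str, prompt: str, max_tokens: int = 50, stream: bool = False
) -> dict[str, Any]:
    """Build a request body for an OpenAI-style /v1/completions endpoint.

    Hypothetical helper for illustration only; not the provider code in this PR.
    """
    return {
        "model": model,
        "prompt": prompt,  # raw text, sent as-is (no chat messages)
        "max_tokens": max_tokens,
        "stream": stream,
    }


payload = build_completion_payload(
    "meta-llama/Llama-3.2-3B-Instruct", "Roses are red,", max_tokens=50
)
print(json.dumps(payload, indent=2))
```

The prompt and `max_tokens` value mirror the streaming case in `test_completion`.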

## Test Plan

<details>
<summary>Setup</summary>

```bash
# Run vllm server
conda create -n vllm python=3.12 -y
conda activate vllm
pip install vllm

# Run llamastack
conda create --name llamastack-vllm python=3.10
conda activate llamastack-vllm

export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct && \
pip install -e . && \
pip install --no-cache --index-url https://pypi.org/simple/ --extra-index-url https://test.pypi.org/simple/ llama-stack==0.1.0rc7 && \
llama stack build --template remote-vllm --image-type conda && \
llama stack run ./distributions/remote-vllm/run.yaml \
  --port 5000 \
  --env INFERENCE_MODEL=$INFERENCE_MODEL \
  --env VLLM_URL=http://localhost:8000/v1 | tee -a llama-stack.log
```
</details>

<details>
<summary>Integration</summary>

```bash
# Run
conda activate llamastack-vllm
export VLLM_URL=http://localhost:8000/v1
pip install pytest pytest_html pytest_asyncio aiosqlite
pytest llama_stack/providers/tests/inference/test_text_inference.py -v -k vllm

# Results
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[-vllm_remote] PASSED            [ 11%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-vllm_remote] PASSED            [ 22%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_logprobs[-vllm_remote] SKIPPED  [ 33%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[-vllm_remote] SKIPPED [ 44%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[-vllm_remote] PASSED [ 55%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[-vllm_remote] PASSED     [ 66%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[-vllm_remote] PASSED [ 77%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-vllm_remote] PASSED [ 88%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[-vllm_remote] PASSED [100%]

====================================== 7 passed, 2 skipped, 99 deselected, 1 warning in 9.80s ======================================
```
</details>
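The `SKIPPED` entries come from the provider gating used in the test file: each test compares the provider type against a tuple of supported providers and skips when there is no match. A minimal sketch of that pattern follows; the helper `should_skip` and the shortened tuples are illustrative assumptions, not the test file's actual code.

```python
def should_skip(provider_type: str, supported: tuple[str, ...]) -> bool:
    """Return True when a test should be skipped for this provider type."""
    return provider_type not in supported


# Shortened, illustrative allow-lists (the real tuples live in the test file).
COMPLETION_PROVIDERS = ("remote::ollama", "remote::vllm")
LOGPROBS_PROVIDERS = ()  # empty tuple: nothing is "in" it, so every provider skips

skip_completion = should_skip("remote::vllm", COMPLETION_PROVIDERS)
skip_logprobs = should_skip("remote::vllm", LOGPROBS_PROVIDERS)
print(skip_completion, skip_logprobs)
```

An empty allow-list is consistent with `test_completion_logprobs` being reported as skipped for `vllm_remote`.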

<details>
<summary>Manual</summary>

```bash
# Install
pip install --no-cache --index-url https://pypi.org/simple/ --extra-index-url https://test.pypi.org/simple/ llama-stack==0.1.0rc7
```

Apply this diff:
```diff
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 8dbb193..95173e2 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -250,7 +250,7 @@ class ClientVersionMiddleware:
                     server_version_parts = tuple(
                         map(int, self.server_version.split(".")[:2])
                     )
-                    if client_version_parts != server_version_parts:
+                    if False and client_version_parts != server_version_parts:
 
                         async def send_version_error(send):
                             await send(
diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml
index 4eac4da..32eb50e 100644
--- a/llama_stack/templates/remote-vllm/run.yaml
+++ b/llama_stack/templates/remote-vllm/run.yaml
@@ -94,7 +94,8 @@ metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
 models:
-- metadata: {}
+- metadata: 
+    llama_model: meta-llama/Llama-3.2-3B-Instruct
   model_id: ${env.INFERENCE_MODEL}
   provider_id: vllm-inference
   model_type: llm
```
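The first hunk prefixes the condition with `False and`, which short-circuits the client/server version check so it never fires. As a rough sketch of what is being bypassed, the snippet below reproduces the major.minor parsing visible in the diff context; the function name `major_minor` and the sample client version are assumptions made for illustration.

```python
def major_minor(version: str) -> tuple[int, ...]:
    """Parse the first two dot-separated components, as in the diff context."""
    return tuple(map(int, version.split(".")[:2]))


server_parts = major_minor("0.1.0rc7")  # the rc suffix sits in the third component
client_parts = major_minor("0.0.63")    # hypothetical client version

mismatch = client_parts != server_parts
bypassed = False and mismatch  # `False and ...` always evaluates to False
print(server_parts, client_parts, mismatch, bypassed)
```

Because only the first two components are converted to `int`, a release-candidate suffix in the third component does not affect the comparison.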

Test 1:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:5000",
)

response = client.inference.completion(
    model_id="meta-llama/Llama-3.2-3B-Instruct",
    content="Hello, world client!",
)

print(response)
```

Test 2:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:5000",
)

response = client.inference.completion(
    model_id="meta-llama/Llama-3.2-3B-Instruct",
    content="Hello, world client!",
    stream=True,
)

for chunk in response:
    print(chunk.delta, end="", flush=True)
```

```
 I'm excited to introduce you to our latest project, a comprehensive guide to the best coffee shops in [City]. As a coffee connoisseur, you're in luck because we've scoured the city to bring you the top picks for the perfect cup of joe.

In this guide, we'll take you on a journey through the city's most iconic coffee shops, highlighting their unique features, must-try drinks, and insider tips from the baristas themselves. From cozy cafes to trendy cafes, we've got you covered.

**Top 5 Coffee Shops in [City]**

1. **The Daily Grind**: This beloved institution has been serving up expertly crafted pour-overs and lattes for over 10 years. Their expert baristas are always happy to guide you through their menu, which features a rotating selection of single-origin beans from around the world...
```
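Test 2 prints each streamed piece as it arrives. If the full text is needed afterwards, the pieces can be joined instead. The sketch below uses a stand-in `FakeChunk` class with a string `delta` attribute, an assumption that mirrors how Test 2 reads `chunk.delta`; the real client's chunk type may differ.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class FakeChunk:
    """Stand-in for a streamed completion chunk (assumed shape)."""

    delta: str


def collect_text(chunks: Iterable[FakeChunk]) -> str:
    """Concatenate the text deltas of a stream into one string."""
    return "".join(chunk.delta for chunk in chunks)


stream = [FakeChunk(" I'm"), FakeChunk(" excited"), FakeChunk(" to")]
text = collect_text(stream)
print(text)
```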

</details>

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
2025-01-22 22:58:27 -08:00

486 lines
17 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_models.llama3.api.datatypes import (
    SamplingParams,
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolParamDefinition,
    ToolPromptFormat,
)
from pydantic import BaseModel, ValidationError

from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    JsonSchemaResponseFormat,
    LogProbConfig,
    SystemMessage,
    ToolChoice,
    UserMessage,
)
from llama_stack.apis.models import ListModelsResponse, Model

from .utils import group_chunks


# How to run this test:
#
# pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py
#   -m "(fireworks or ollama) and llama_3b"
#   --env FIREWORKS_API_KEY=<your_api_key>


def get_expected_stop_reason(model: str):
    return (
        StopReason.end_of_message
        if ("Llama3.1" in model or "Llama-3.1" in model)
        else StopReason.end_of_turn
    )


@pytest.fixture
def common_params(inference_model):
    return {
        "tool_choice": ToolChoice.auto,
        "tool_prompt_format": (
            ToolPromptFormat.json
            if ("Llama3.1" in inference_model or "Llama-3.1" in inference_model)
            else ToolPromptFormat.python_list
        ),
    }


@pytest.fixture
def sample_messages():
    return [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What's the weather like today?"),
    ]


@pytest.fixture
def sample_tool_definition():
    return ToolDefinition(
        tool_name="get_weather",
        description="Get the current weather",
        parameters={
            "location": ToolParamDefinition(
                param_type="string",
                description="The city and state, e.g. San Francisco, CA",
            ),
        },
    )


class TestInference:
    # Session scope for asyncio because the tests in this class all
    # share the same provider instance.
    @pytest.mark.asyncio(loop_scope="session")
    async def test_model_list(self, inference_model, inference_stack):
        _, models_impl = inference_stack
        response = await models_impl.list_models()
        assert isinstance(response, ListModelsResponse)
        assert isinstance(response.data, list)
        assert len(response.data) >= 1
        assert all(isinstance(model, Model) for model in response.data)

        model_def = None
        for model in response.data:
            if model.identifier == inference_model:
                model_def = model
                break

        assert model_def is not None

    @pytest.mark.asyncio(loop_scope="session")
    async def test_completion(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::tgi",
            "remote::together",
            "remote::fireworks",
            "remote::nvidia",
            "remote::cerebras",
            "remote::vllm",
        ):
            pytest.skip("Other inference providers don't support completion() yet")

        response = await inference_impl.completion(
            content="Micheael Jordan is born in ",
            stream=False,
            model_id=inference_model,
            sampling_params=SamplingParams(
                max_tokens=50,
            ),
        )

        assert isinstance(response, CompletionResponse)
        assert "1963" in response.content

        chunks = [
            r
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
                model_id=inference_model,
                sampling_params=SamplingParams(
                    max_tokens=50,
                ),
            )
        ]

        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
        assert len(chunks) >= 1
        last = chunks[-1]
        assert last.stop_reason == StopReason.out_of_tokens

    @pytest.mark.asyncio(loop_scope="session")
    async def test_completion_logprobs(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            # "remote::nvidia", -- provider doesn't provide all logprobs
        ):
            pytest.skip("Other inference providers don't support completion() yet")

        response = await inference_impl.completion(
            content="Micheael Jordan is born in ",
            stream=False,
            model_id=inference_model,
            sampling_params=SamplingParams(
                max_tokens=5,
            ),
            logprobs=LogProbConfig(
                top_k=3,
            ),
        )

        assert isinstance(response, CompletionResponse)
        assert 1 <= len(response.logprobs) <= 5
        assert response.logprobs, "Logprobs should not be empty"
        assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)

        chunks = [
            r
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
                model_id=inference_model,
                sampling_params=SamplingParams(
                    max_tokens=5,
                ),
                logprobs=LogProbConfig(
                    top_k=3,
                ),
            )
        ]

        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
        assert (
            1 <= len(chunks) <= 6
        )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
        for chunk in chunks:
            if (
                chunk.delta.type == "text" and chunk.delta.text
            ):  # if there's a token, we expect logprobs
                assert chunk.logprobs, "Logprobs should not be empty"
                assert all(
                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
                )
            else:  # no token, no logprobs
                assert not chunk.logprobs, "Logprobs should be empty"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_completion_structured_output(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::tgi",
            "remote::together",
            "remote::fireworks",
            "remote::nvidia",
            "remote::vllm",
            "remote::cerebras",
        ):
            pytest.skip(
                "Other inference providers don't support structured output in completions yet"
            )

        class Output(BaseModel):
            name: str
            year_born: str
            year_retired: str

        user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
        response = await inference_impl.completion(
            model_id=inference_model,
            content=user_input,
            stream=False,
            sampling_params=SamplingParams(
                max_tokens=50,
            ),
            response_format=JsonSchemaResponseFormat(
                json_schema=Output.model_json_schema(),
            ),
        )
        assert isinstance(response, CompletionResponse)
        assert isinstance(response.content, str)

        answer = Output.model_validate_json(response.content)
        assert answer.name == "Michael Jordan"
        assert answer.year_born == "1963"
        assert answer.year_retired == "2003"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_chat_completion_non_streaming(
        self, inference_model, inference_stack, common_params, sample_messages
    ):
        inference_impl, _ = inference_stack
        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=sample_messages,
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)
        assert len(response.completion_message.content) > 0

    @pytest.mark.asyncio(loop_scope="session")
    async def test_structured_output(
        self, inference_model, inference_stack, common_params
    ):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::fireworks",
            "remote::tgi",
            "remote::together",
            "remote::vllm",
            "remote::nvidia",
        ):
            pytest.skip("Other inference providers don't support structured output yet")

        class AnswerFormat(BaseModel):
            first_name: str
            last_name: str
            year_of_birth: int
            num_seasons_in_nba: int

        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=[
                # We include context about Michael Jordan in the prompt so that the test
                # focuses on the functionality of the model and not on the information
                # embedded in the model. Llama 3.2 3B Instruct tends to think MJ played
                # for 14 seasons.
                SystemMessage(
                    content=(
                        "You are a helpful assistant.\n\n"
                        "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
                    )
                ),
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            response_format=JsonSchemaResponseFormat(
                json_schema=AnswerFormat.model_json_schema(),
            ),
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)

        answer = AnswerFormat.model_validate_json(response.completion_message.content)
        assert answer.first_name == "Michael"
        assert answer.last_name == "Jordan"
        assert answer.year_of_birth == 1963
        assert answer.num_seasons_in_nba == 15

        # Without a response_format, the free-form reply should not parse as AnswerFormat.
        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=[
                SystemMessage(content="You are a helpful assistant."),
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert isinstance(response.completion_message.content, str)

        with pytest.raises(ValidationError):
            AnswerFormat.model_validate_json(response.completion_message.content)

    @pytest.mark.asyncio(loop_scope="session")
    async def test_chat_completion_streaming(
        self, inference_model, inference_stack, common_params, sample_messages
    ):
        inference_impl, _ = inference_stack
        response = [
            r
            async for r in await inference_impl.chat_completion(
                model_id=inference_model,
                messages=sample_messages,
                stream=True,
                **common_params,
            )
        ]

        assert len(response) > 0
        assert all(
            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
        )
        grouped = group_chunks(response)
        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

        end = grouped[ChatCompletionResponseEventType.complete][0]
        assert end.event.stop_reason == StopReason.end_of_turn

    @pytest.mark.asyncio(loop_scope="session")
    async def test_chat_completion_with_tool_calling(
        self,
        inference_model,
        inference_stack,
        common_params,
        sample_messages,
        sample_tool_definition,
    ):
        inference_impl, _ = inference_stack
        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if (
            provider.__provider_spec__.provider_type == "remote::groq"
            and "Llama-3.2" in inference_model
        ):
            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")

        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
            )
        ]

        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=messages,
            tools=[sample_tool_definition],
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)

        message = response.completion_message

        # This is not supported in most providers :/ they don't return eom_id / eot_id
        # stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
        # assert message.stop_reason == stop_reason
        assert message.tool_calls is not None
        assert len(message.tool_calls) > 0

        call = message.tool_calls[0]
        assert call.tool_name == "get_weather"
        assert "location" in call.arguments
        assert "San Francisco" in call.arguments["location"]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_chat_completion_with_tool_calling_streaming(
        self,
        inference_model,
        inference_stack,
        common_params,
        sample_messages,
        sample_tool_definition,
    ):
        inference_impl, _ = inference_stack
        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if (
            provider.__provider_spec__.provider_type == "remote::groq"
            and "Llama-3.2" in inference_model
        ):
            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")

        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
            )
        ]

        response = [
            r
            async for r in await inference_impl.chat_completion(
                model_id=inference_model,
                messages=messages,
                tools=[sample_tool_definition],
                stream=True,
                **common_params,
            )
        ]

        assert len(response) > 0
        assert all(
            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
        )
        grouped = group_chunks(response)
        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

        # This is not supported in most providers :/ they don't return eom_id / eot_id
        # expected_stop_reason = get_expected_stop_reason(
        #     inference_settings["common_params"]["model"]
        # )
        # end = grouped[ChatCompletionResponseEventType.complete][0]
        # assert end.event.stop_reason == expected_stop_reason

        if "Llama3.1" in inference_model:
            assert all(
                chunk.event.delta.type == "tool_call"
                for chunk in grouped[ChatCompletionResponseEventType.progress]
            )
        first = grouped[ChatCompletionResponseEventType.progress][0]
        if not isinstance(
            first.event.delta.tool_call, ToolCall
        ):  # first chunk may contain entire call
            assert first.event.delta.parse_status == ToolCallParseStatus.started

        last = grouped[ChatCompletionResponseEventType.progress][-1]
        # assert last.event.stop_reason == expected_stop_reason
        assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
        assert isinstance(last.event.delta.tool_call, ToolCall)

        call = last.event.delta.tool_call
        assert call.tool_name == "get_weather"
        assert "location" in call.arguments
        assert "San Francisco" in call.arguments["location"]