Add vLLM raw completions API (#823)

# What does this PR do?

Adds raw completions API to vLLM

## Test Plan

<details>
<summary>Setup</summary>

```bash
# Run vllm server
conda create -n vllm python=3.12 -y
conda activate vllm
pip install vllm

# Run llamastack
conda create --name llamastack-vllm python=3.10
conda activate llamastack-vllm

export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct && \
pip install -e . && \
pip install --no-cache --index-url https://pypi.org/simple/ --extra-index-url https://test.pypi.org/simple/ llama-stack==0.1.0rc7 && \
llama stack build --template remote-vllm --image-type conda && \
llama stack run ./distributions/remote-vllm/run.yaml \
  --port 5000 \
  --env INFERENCE_MODEL=$INFERENCE_MODEL \
  --env VLLM_URL=http://localhost:8000/v1 | tee -a llama-stack.log
```
</details>

<details>
<summary>Integration</summary>

```bash
# Run
conda activate llamastack-vllm
export VLLM_URL=http://localhost:8000/v1
pip install pytest pytest_html pytest_asyncio aiosqlite
pytest llama_stack/providers/tests/inference/test_text_inference.py -v -k vllm

# Results
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[-vllm_remote] PASSED            [ 11%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-vllm_remote] PASSED            [ 22%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_logprobs[-vllm_remote] SKIPPED  [ 33%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[-vllm_remote] SKIPPED [ 44%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[-vllm_remote] PASSED [ 55%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[-vllm_remote] PASSED     [ 66%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[-vllm_remote] PASSED [ 77%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-vllm_remote] PASSED [ 88%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[-vllm_remote] PASSED [100%]

====================================== 7 passed, 2 skipped, 99 deselected, 1 warning in 9.80s ======================================
```
</details>

<details>
<summary>Manual</summary>

```bash
# Install
pip install --no-cache --index-url https://pypi.org/simple/ --extra-index-url https://test.pypi.org/simple/ llama-stack==0.1.0rc7
```

Apply this diff
```diff
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 8dbb193..95173e2 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -250,7 +250,7 @@ class ClientVersionMiddleware:
                     server_version_parts = tuple(
                         map(int, self.server_version.split(".")[:2])
                     )
-                    if client_version_parts != server_version_parts:
+                    if False and client_version_parts != server_version_parts:
 
                         async def send_version_error(send):
                             await send(
diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml
index 4eac4da..32eb50e 100644
--- a/llama_stack/templates/remote-vllm/run.yaml
+++ b/llama_stack/templates/remote-vllm/run.yaml
@@ -94,7 +94,8 @@ metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
 models:
-- metadata: {}
+- metadata: 
+    llama_model: meta-llama/Llama-3.2-3B-Instruct
   model_id: ${env.INFERENCE_MODEL}
   provider_id: vllm-inference
   model_type: llm
```

Test 1:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:5000",
)

response = client.inference.completion(
    model_id="meta-llama/Llama-3.2-3B-Instruct",
    content="Hello, world client!",
)

print(response)
```

Test 2

```
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:5000",
)

response = client.inference.completion(
    model_id="meta-llama/Llama-3.2-3B-Instruct",
    content="Hello, world client!",
    stream=True,
)

for chunk in response:
    print(chunk.delta, end="", flush=True)
```

```
 I'm excited to introduce you to our latest project, a comprehensive guide to the best coffee shops in [City]. As a coffee connoisseur, you're in luck because we've scoured the city to bring you the top picks for the perfect cup of joe.

In this guide, we'll take you on a journey through the city's most iconic coffee shops, highlighting their unique features, must-try drinks, and insider tips from the baristas themselves. From cozy cafes to trendy cafes, we've got you covered.

**Top 5 Coffee Shops in [City]**

1. **The Daily Grind**: This beloved institution has been serving up expertly crafted pour-overs and lattes for over 10 years. Their expert baristas are always happy to guide you through their menu, which features a rotating selection of single-origin beans from around the world...
```

</details>

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
This commit is contained in:
Aidan Do 2025-01-23 17:58:27 +11:00 committed by GitHub
parent 4d7c8c797f
commit 910717c1fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 36 additions and 1 deletions

View file

@ -41,6 +41,8 @@ from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
@ -92,7 +94,19 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError("Completion not implemented for vLLM")
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
async def chat_completion(
self,
@ -154,6 +168,26 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
):
yield chunk
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
params = await self._get_params(request)
r = self.client.completions.create(**params)
return process_completion_response(r, self.formatter)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
# Wrapper for async generator similar
async def _to_async_generator():
stream = self.client.completions.create(**params)
for chunk in stream:
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream, self.formatter):
yield chunk
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
res = self.client.models.list()