From 910717c1fd8033daa4cc53b845beb648a2d71e51 Mon Sep 17 00:00:00 2001 From: Aidan Do Date: Thu, 23 Jan 2025 17:58:27 +1100 Subject: [PATCH] Add vLLM raw completions API (#823) # What does this PR do? Adds raw completions API to vLLM ## Test Plan
Setup ```bash # Run vllm server conda create -n vllm python=3.12 -y conda activate vllm pip install vllm # Run llamastack conda create --name llamastack-vllm python=3.10 conda activate llamastack-vllm export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct && \ pip install -e . && \ pip install --no-cache --index-url https://pypi.org/simple/ --extra-index-url https://test.pypi.org/simple/ llama-stack==0.1.0rc7 && \ llama stack build --template remote-vllm --image-type conda && \ llama stack run ./distributions/remote-vllm/run.yaml \ --port 5000 \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env VLLM_URL=http://localhost:8000/v1 | tee -a llama-stack.log ```
Integration ```bash # Run conda activate llamastack-vllm export VLLM_URL=http://localhost:8000/v1 pip install pytest pytest_html pytest_asyncio aiosqlite pytest llama_stack/providers/tests/inference/test_text_inference.py -v -k vllm # Results llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[-vllm_remote] PASSED [ 11%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-vllm_remote] PASSED [ 22%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_logprobs[-vllm_remote] SKIPPED [ 33%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[-vllm_remote] SKIPPED [ 44%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[-vllm_remote] PASSED [ 55%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[-vllm_remote] PASSED [ 66%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[-vllm_remote] PASSED [ 77%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-vllm_remote] PASSED [ 88%] llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[-vllm_remote] PASSED [100%] ====================================== 7 passed, 2 skipped, 99 deselected, 1 warning in 9.80s ====================================== ```
Manual ```bash # Install pip install --no-cache --index-url https://pypi.org/simple/ --extra-index-url https://test.pypi.org/simple/ llama-stack==0.1.0rc7 ``` Apply this diff ```diff diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 8dbb193..95173e2 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -250,7 +250,7 @@ class ClientVersionMiddleware: server_version_parts = tuple( map(int, self.server_version.split(".")[:2]) ) - if client_version_parts != server_version_parts: + if False and client_version_parts != server_version_parts: async def send_version_error(send): await send( diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml index 4eac4da..32eb50e 100644 --- a/llama_stack/templates/remote-vllm/run.yaml +++ b/llama_stack/templates/remote-vllm/run.yaml @@ -94,7 +94,8 @@ metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db models: -- metadata: {} +- metadata: + llama_model: meta-llama/Llama-3.2-3B-Instruct model_id: ${env.INFERENCE_MODEL} provider_id: vllm-inference model_type: llm ``` Test 1: ```python from llama_stack_client import LlamaStackClient client = LlamaStackClient( base_url="http://localhost:5000", ) response = client.inference.completion( model_id="meta-llama/Llama-3.2-3B-Instruct", content="Hello, world client!", ) print(response) ``` Test 2 ``` from llama_stack_client import LlamaStackClient client = LlamaStackClient( base_url="http://localhost:5000", ) response = client.inference.completion( model_id="meta-llama/Llama-3.2-3B-Instruct", content="Hello, world client!", stream=True, ) for chunk in response: print(chunk.delta, end="", flush=True) ``` ``` I'm excited to introduce you to our latest project, a comprehensive guide to the best coffee shops in [City]. As a coffee connoisseur, you're in luck because we've scoured the city to bring you the top picks for the perfect cup of joe. In this guide, we'll take you on a journey through the city's most iconic coffee shops, highlighting their unique features, must-try drinks, and insider tips from the baristas themselves. From cozy cafes to trendy cafes, we've got you covered. **Top 5 Coffee Shops in [City]** 1. **The Daily Grind**: This beloved institution has been serving up expertly crafted pour-overs and lattes for over 10 years. Their expert baristas are always happy to guide you through their menu, which features a rotating selection of single-origin beans from around the world... ```
## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- .../providers/remote/inference/vllm/vllm.py | 36 ++++++++++++++++++- .../tests/inference/test_text_inference.py | 1 + 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 1dbb4ecfa..0cf16f013 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -41,6 +41,8 @@ from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, process_chat_completion_response, process_chat_completion_stream_response, + process_completion_response, + process_completion_stream_response, ) from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, @@ -92,7 +94,19 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: - raise NotImplementedError("Completion not implemented for vLLM") + model = await self.model_store.get_model(model_id) + request = CompletionRequest( + model=model.provider_resource_id, + content=content, + sampling_params=sampling_params, + response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + if stream: + return self._stream_completion(request) + else: + return await self._nonstream_completion(request) async def chat_completion( self, @@ -154,6 +168,26 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): ): yield chunk + async def _nonstream_completion( + self, request: CompletionRequest + ) -> CompletionResponse: + params = await self._get_params(request) + r = self.client.completions.create(**params) + return process_completion_response(r, self.formatter) + + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: + params = await self._get_params(request) + + # Wrapper for async generator similar + async def _to_async_generator(): + stream = self.client.completions.create(**params) + for chunk in stream: + yield chunk + + stream = _to_async_generator() + async for chunk in process_completion_stream_response(stream, self.formatter): + yield chunk + async def register_model(self, model: Model) -> Model: model = await self.register_helper.register_model(model) res = self.client.models.list() diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index c39556b8e..e1052c289 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -118,6 +118,7 @@ class TestInference: "remote::fireworks", "remote::nvidia", "remote::cerebras", + "remote::vllm", ): pytest.skip("Other inference providers don't support completion() yet")