Add vLLM raw completions API (#823)
# What does this PR do?

Adds raw completions API to vLLM

## Test Plan

<details>
<summary>Setup</summary>

```bash
# Run vllm server
conda create -n vllm python=3.12 -y
conda activate vllm
pip install vllm

# Run llamastack
conda create --name llamastack-vllm python=3.10
conda activate llamastack-vllm

export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct && \
pip install -e . && \
pip install --no-cache --index-url https://pypi.org/simple/ --extra-index-url https://test.pypi.org/simple/ llama-stack==0.1.0rc7 && \
llama stack build --template remote-vllm --image-type conda && \
llama stack run ./distributions/remote-vllm/run.yaml \
  --port 5000 \
  --env INFERENCE_MODEL=$INFERENCE_MODEL \
  --env VLLM_URL=http://localhost:8000/v1 | tee -a llama-stack.log
```

</details>

<details>
<summary>Integration</summary>

```bash
# Run
conda activate llamastack-vllm
export VLLM_URL=http://localhost:8000/v1
pip install pytest pytest_html pytest_asyncio aiosqlite
pytest llama_stack/providers/tests/inference/test_text_inference.py -v -k vllm

# Results
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[-vllm_remote] PASSED [ 11%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-vllm_remote] PASSED [ 22%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_logprobs[-vllm_remote] SKIPPED [ 33%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[-vllm_remote] SKIPPED [ 44%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[-vllm_remote] PASSED [ 55%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[-vllm_remote] PASSED [ 66%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[-vllm_remote] PASSED [ 77%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-vllm_remote] PASSED [ 88%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[-vllm_remote] PASSED [100%]

====================================== 7 passed, 2 skipped, 99 deselected, 1 warning in 9.80s ======================================
```

</details>

<details>
<summary>Manual</summary>

```bash
# Install
pip install --no-cache --index-url https://pypi.org/simple/ --extra-index-url https://test.pypi.org/simple/ llama-stack==0.1.0rc7
```

Apply this diff:

```diff
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 8dbb193..95173e2 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -250,7 +250,7 @@ class ClientVersionMiddleware:
                     server_version_parts = tuple(
                         map(int, self.server_version.split(".")[:2])
                     )
-                    if client_version_parts != server_version_parts:
+                    if False and client_version_parts != server_version_parts:
 
                         async def send_version_error(send):
                             await send(
diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml
index 4eac4da..32eb50e 100644
--- a/llama_stack/templates/remote-vllm/run.yaml
+++ b/llama_stack/templates/remote-vllm/run.yaml
@@ -94,7 +94,8 @@ metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
 models:
-- metadata: {}
+- metadata:
+    llama_model: meta-llama/Llama-3.2-3B-Instruct
   model_id: ${env.INFERENCE_MODEL}
   provider_id: vllm-inference
   model_type: llm
```

Test 1:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:5000",
)

response = client.inference.completion(
    model_id="meta-llama/Llama-3.2-3B-Instruct",
    content="Hello, world client!",
)
print(response)
```

Test 2:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:5000",
)

response = client.inference.completion(
    model_id="meta-llama/Llama-3.2-3B-Instruct",
    content="Hello, world client!",
    stream=True,
)
for chunk in response:
    print(chunk.delta, end="", flush=True)
```

```
I'm excited to introduce you to our latest project, a comprehensive guide to the best coffee shops in [City]. As a coffee connoisseur, you're in luck because we've scoured the city to bring you the top picks for the perfect cup of joe. In this guide, we'll take you on a journey through the city's most iconic coffee shops, highlighting their unique features, must-try drinks, and insider tips from the baristas themselves. From cozy cafes to trendy cafes, we've got you covered.

**Top 5 Coffee Shops in [City]**

1. **The Daily Grind**: This beloved institution has been serving up expertly crafted pour-overs and lattes for over 10 years. Their expert baristas are always happy to guide you through their menu, which features a rotating selection of single-origin beans from around the world...
```

</details>

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
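A useful extra check alongside the test plan above is to call the vLLM server's OpenAI-compatible completions endpoint directly, with llama-stack out of the loop, to separate server-side issues from adapter issues. A minimal sketch, assuming the server from the setup section is still listening on `localhost:8000` and serving the same model (the `api_key` value is just a placeholder; adjust it if your vLLM server enforces one):

```python
# Direct raw-completion call against vLLM's OpenAI-compatible API, bypassing llama-stack.
from openai import OpenAI

vllm_client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

completion = vllm_client.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    prompt="Hello, world client!",
    max_tokens=64,
)
print(completion.choices[0].text)
```

If this call succeeds but the llama-stack route does not, the problem is most likely in the adapter or the `run.yaml` model registration rather than in vLLM itself.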
parent 4d7c8c797f
commit 910717c1fd
2 changed files with 36 additions and 1 deletion
```diff
@@ -41,6 +41,8 @@ from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
@@ -92,7 +94,19 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        raise NotImplementedError("Completion not implemented for vLLM")
+        model = await self.model_store.get_model(model_id)
+        request = CompletionRequest(
+            model=model.provider_resource_id,
+            content=content,
+            sampling_params=sampling_params,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
 
     async def chat_completion(
         self,
@@ -154,6 +168,26 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         ):
             yield chunk
 
+    async def _nonstream_completion(
+        self, request: CompletionRequest
+    ) -> CompletionResponse:
+        params = await self._get_params(request)
+        r = self.client.completions.create(**params)
+        return process_completion_response(r, self.formatter)
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = await self._get_params(request)
+
+        # Wrapper for async generator similar
+        async def _to_async_generator():
+            stream = self.client.completions.create(**params)
+            for chunk in stream:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_completion_stream_response(stream, self.formatter):
+            yield chunk
+
     async def register_model(self, model: Model) -> Model:
         model = await self.register_helper.register_model(model)
         res = self.client.models.list()
```
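The streaming path in the hunk above mirrors the existing chat-completion code: `self.client.completions.create(**params)` returns a blocking iterator of chunks, and the nested `_to_async_generator` re-yields them so `process_completion_stream_response` can consume the stream with `async for`. A self-contained sketch of that wrapping pattern, with a fake synchronous stream standing in for the OpenAI-style client (all names below are illustrative, not part of the diff):

```python
import asyncio
from typing import AsyncGenerator, Iterator


def fake_sync_stream() -> Iterator[str]:
    # Stand-in for a blocking stream such as client.completions.create(..., stream=True).
    yield from ["Hello", ", ", "world", "!"]


async def to_async_generator(sync_iter: Iterator[str]) -> AsyncGenerator[str, None]:
    # Same shape as _to_async_generator in the diff: iterate the blocking
    # stream and re-yield each chunk so callers can use `async for`.
    for chunk in sync_iter:
        yield chunk


async def main() -> None:
    async for chunk in to_async_generator(fake_sync_stream()):
        print(chunk, end="", flush=True)
    print()


asyncio.run(main())
```

The trade-off is that each chunk is still pulled from the blocking iterator on the event loop thread, so a slow response stalls other tasks between yields; the diff keeps this simpler form, consistent with the adapter's existing chat-completion streaming.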