mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Add vLLM inference provider for OpenAI compatible vLLM server (#178)
This PR adds vLLM inference provider for OpenAI compatible vLLM server.
This commit is contained in:
parent
59c43736e8
commit
a27a2cd2af
6 changed files with 209 additions and 1 deletions
|
@ -7,4 +7,4 @@ distribution_spec:
|
|||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
image_type: conda
|
10
llama_stack/distribution/templates/remote-vllm-build.yaml
Normal file
10
llama_stack/distribution/templates/remote-vllm-build.yaml
Normal file
|
@ -0,0 +1,10 @@
|
|||
name: remote-vllm
|
||||
distribution_spec:
|
||||
description: Use remote vLLM for running LLM inference
|
||||
providers:
|
||||
inference: remote::vllm
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: docker
|
15
llama_stack/providers/adapters/inference/vllm/__init__.py
Normal file
15
llama_stack/providers/adapters/inference/vllm/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import VLLMImplConfig
|
||||
from .vllm import VLLMInferenceAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(config: VLLMImplConfig, _deps):
|
||||
assert isinstance(config, VLLMImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = VLLMInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
22
llama_stack/providers/adapters/inference/vllm/config.py
Normal file
22
llama_stack/providers/adapters/inference/vllm/config.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMImplConfig(BaseModel):
|
||||
url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The URL for the vLLM model serving endpoint",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The API token",
|
||||
)
|
152
llama_stack/providers/adapters/inference/vllm/vllm.py
Normal file
152
llama_stack/providers/adapters/inference/vllm/vllm.py
Normal file
|
@ -0,0 +1,152 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import VLLMImplConfig
|
||||
|
||||
VLLM_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
|
||||
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
|
||||
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
|
||||
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
|
||||
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
|
||||
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
|
||||
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
|
||||
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
|
||||
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
|
||||
"Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
|
||||
"Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
|
||||
"Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
|
||||
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
|
||||
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
|
||||
"Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
|
||||
"Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
|
||||
"Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
|
||||
"Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
|
||||
"Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
|
||||
}
|
||||
|
||||
|
||||
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
def __init__(self, config: VLLMImplConfig) -> None:
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
self.client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
raise ValueError("Model registration is not supported for vLLM models")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
return [
|
||||
ModelDef(identifier=model.id, llama_model=model.id)
|
||||
for model in self.client.models.list()
|
||||
]
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, self.client)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request, self.client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
# TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
|
||||
# generator so this wrapper is not necessary?
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": VLLM_SUPPORTED_MODELS[request.model],
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
|
@ -60,6 +60,15 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.adapters.inference.ollama",
|
||||
),
|
||||
),
|
||||
# remote_provider_spec(
|
||||
# api=Api.inference,
|
||||
# adapter=AdapterSpec(
|
||||
# adapter_type="vllm",
|
||||
# pip_packages=["openai"],
|
||||
# module="llama_stack.providers.adapters.inference.vllm",
|
||||
# config_class="llama_stack.providers.adapters.inference.vllm.VLLMImplConfig",
|
||||
# ),
|
||||
# ),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue