Enable remote::vllm

This commit is contained in:
Ashwin Bharambe 2024-11-06 14:11:31 -08:00
parent 6ebd553da5
commit 6deeee9b87
5 changed files with 70 additions and 24 deletions

View file

@ -4,12 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import VLLMImplConfig
from .vllm import VLLMInferenceAdapter
from .config import VLLMInferenceAdapterConfig
async def get_adapter_impl(config: VLLMImplConfig, _deps):
assert isinstance(config, VLLMImplConfig), f"Unexpected config type: {type(config)}"
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
from .vllm import VLLMInferenceAdapter
assert isinstance(
config, VLLMInferenceAdapterConfig
), f"Unexpected config type: {type(config)}"
impl = VLLMInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -11,12 +11,16 @@ from pydantic import BaseModel, Field
@json_schema_type
class VLLMImplConfig(BaseModel):
class VLLMInferenceAdapterConfig(BaseModel):
url: Optional[str] = Field(
default=None,
description="The URL for the vLLM model serving endpoint",
)
max_tokens: int = Field(
default=4096,
description="Maximum number of tokens to generate.",
)
api_token: Optional[str] = Field(
default=None,
default="fake",
description="The API token",
)

View file

@ -23,7 +23,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import VLLMImplConfig
from .config import VLLMInferenceAdapterConfig
VLLM_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
@ -55,7 +56,7 @@ VLLM_SUPPORTED_MODELS = {
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMImplConfig) -> None:
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = None
@ -70,10 +71,22 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
pass
async def list_models(self) -> List[ModelDef]:
return [
ModelDef(identifier=model.id, llama_model=model.id)
for model in self.client.models.list()
]
vllm_to_llama_map = {v: k for k, v in VLLM_SUPPORTED_MODELS.items()}
models = []
for model in self.client.models.list():
if model.id not in vllm_to_llama_map:
print(f"Unknown model served by vllm: {model.id}")
continue
identifier = vllm_to_llama_map[model.id]
models.append(
ModelDef(
identifier=identifier,
llama_model=identifier,
)
)
return models
async def completion(
self,
@ -118,7 +131,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(request, r, self.formatter)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
@ -139,11 +152,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
options = get_sampling_options(request.sampling_params)
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
return {
"model": VLLM_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
**options,
}
async def embeddings(