From 92ee627e89edc2070259194477674696b6af524e Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Tue, 12 Nov 2024 13:59:46 -0800
Subject: [PATCH] vllm

---
 .../providers/remote/inference/vllm/vllm.py   | 68 +++++++++----------
 1 file changed, 32 insertions(+), 36 deletions(-)

diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 3a8b8c326..c49541fd9 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -8,13 +8,17 @@ from typing import AsyncGenerator
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import all_registered_models, resolve_model
+from llama_models.sku_list import all_registered_models
 
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelAlias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -30,8 +34,24 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import VLLMInferenceAdapterConfig
 
 
-class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+def build_model_aliases():
+    return [
+        ModelAlias(
+            provider_model_id=model.huggingface_repo,
+            aliases=[model.descriptor()],
+            llama_model=model.descriptor(),
+        )
+        for model in all_registered_models()
+        if model.huggingface_repo
+    ]
+
+
+class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self,
+            model_aliases=build_model_aliases(),
+        )
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None
@@ -44,31 +64,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def initialize(self) -> None:
         self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
 
-    async def register_model(self, model: Model) -> None:
-        pass
-
-    async def shutdown(self) -> None:
-        pass
-
-    async def list_models(self) -> List[Model]:
-        models = []
-        for model in self.client.models.list():
-            repo = model.id
-            if repo not in self.huggingface_repo_to_llama_model_id:
-                print(f"Unknown model served by vllm: {repo}")
-                continue
-
-            identifier = self.huggingface_repo_to_llama_model_id[repo]
-            if identifier == model.provider_resource_id:
-                print(
-                    f"Verified that model {model.provider_resource_id} is being served by vLLM"
-                )
-                return
-
-        raise ValueError(
-            f"Model {model.provider_resource_id} is not being served by vLLM"
-        )
-
     async def shutdown(self) -> None:
         pass
 
@@ -95,8 +90,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
@@ -148,10 +144,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens
 
-        model = resolve_model(request.model)
-        if model is None:
-            raise ValueError(f"Unknown model: {request.model}")
-
         input_dict = {}
         media_present = request_has_media(request)
         if isinstance(request, ChatCompletionRequest):
@@ -163,16 +155,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.formatter
+                    request, self.get_llama_model(request.model), self.formatter
                 )
         else:
             assert (
                 not media_present
             ), "Together does not support media for Completion requests"
-            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = completion_request_to_prompt(
+                request,
+                self.get_llama_model(request.model),
+                self.formatter,
+            )
 
         return {
-            "model": model.huggingface_repo,
+            "model": request.model,
             **input_dict,
             "stream": request.stream,
             **options,
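
Note (illustrative, not part of the patch): the ModelRegistryHelper wiring above
replaces the old hand-rolled huggingface_repo_to_llama_model_id dict and the
resolve_model() lookup. A minimal sketch of the intended resolution path,
assuming get_llama_model() returns the llama_model registered for a given
provider_model_id; the example repo id is hypothetical and the real entries come
from the installed llama_models sku list:

    # Sketch only: build the same alias table the adapter now uses and resolve
    # a vLLM-served HF repo id back to the Llama descriptor used for prompt
    # formatting in _get_params().
    from llama_stack.providers.remote.inference.vllm.vllm import build_model_aliases
    from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

    helper = ModelRegistryHelper(model_aliases=build_model_aliases())
    # Hypothetical provider model id; prints the registered llama_model descriptor.
    print(helper.get_llama_model("meta-llama/Llama-3.1-8B-Instruct"))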