From c085886cdb8fd2c4cd9f3ba546b0703e157a1391 Mon Sep 17 00:00:00 2001
From: Yuan Tang
Date: Fri, 11 Oct 2024 16:08:45 -0400
Subject: [PATCH] Address feedback

Signed-off-by: Yuan Tang
---
 .../providers/adapters/inference/vllm/vllm.py | 62 ++++++++++++-------
 1 file changed, 39 insertions(+), 23 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py
index dcc75ccba..a5934928a 100644
--- a/llama_stack/providers/adapters/inference/vllm/vllm.py
+++ b/llama_stack/providers/adapters/inference/vllm/vllm.py
@@ -8,11 +8,11 @@ from typing import AsyncGenerator
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import all_registered_models
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
 
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -25,36 +25,54 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import VLLMImplConfig
 
+VLLM_SUPPORTED_MODELS = {
+    "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
+    "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
+    "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
+    "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
+    "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
+    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
+    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
+    "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
+    "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
+    "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
+    "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
+    "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
+    "Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
+    "Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
+    "Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
+    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
+    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
+    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
+    "Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
+    "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
+    "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
+    "Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
+    "Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
+    "Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
+}
 
 
-class VLLMInferenceAdapter(Inference):
-    model_id: str
+class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMImplConfig) -> None:
-        self.huggingface_repo_to_llama_model_id = {
-            model.huggingface_repo: model.descriptor()
-            for model in all_registered_models()
-            if model.huggingface_repo
-        }
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
+        self.client = None
 
     async def initialize(self) -> None:
-        return
+        self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+
+    async def register_model(self, model: ModelDef) -> None:
+        raise ValueError("Model registration is not supported for vLLM models")
 
     async def shutdown(self) -> None:
         pass
 
     async def list_models(self) -> List[ModelDef]:
-        repo = self.model_id
-        identifier = self.huggingface_repo_to_llama_model_id[repo]
         return [
-            ModelDef(
-                identifier=identifier,
-                llama_model=identifier,
-                metadata={
-                    "huggingface_repo": repo,
-                },
-            )
+            ModelDef(identifier=model.id, llama_model=model.id)
+            for model in self.client.models.list()
         ]
 
     def completion(
@@ -88,12 +106,10 @@ class VLLMInferenceAdapter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-
-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
         if stream:
-            return self._stream_chat_completion(request, client)
+            return self._stream_chat_completion(request, self.client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return self._nonstream_chat_completion(request, self.client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI
@@ -122,7 +138,7 @@ class VLLMInferenceAdapter(Inference):
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
-            "model": request.model,
+            "model": VLLM_SUPPORTED_MODELS[request.model],
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
             **get_sampling_options(request),
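
For reviewers who want to exercise the change locally, below is a minimal usage sketch; it is not part of the patch. It assumes a vLLM server with an OpenAI-compatible endpoint is already running, and that VLLMImplConfig accepts url and api_token keyword arguments matching the config attributes referenced in the diff; the URL and token values are placeholders.

# Usage sketch (not part of the patch). Assumes VLLMImplConfig exposes `url`
# and `api_token`, as referenced by the adapter above; values are placeholders.
import asyncio

from llama_stack.providers.adapters.inference.vllm.config import VLLMImplConfig
from llama_stack.providers.adapters.inference.vllm.vllm import VLLMInferenceAdapter


async def main() -> None:
    # Point the adapter at a running vLLM OpenAI-compatible server.
    config = VLLMImplConfig(url="http://localhost:8000/v1", api_token="dummy")
    adapter = VLLMInferenceAdapter(config)

    # initialize() now creates the OpenAI client once, instead of per request.
    await adapter.initialize()

    # list_models() reflects whatever the vLLM server is actually serving.
    for model_def in await adapter.list_models():
        print(model_def.identifier)


asyncio.run(main())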