diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py index 09c17ee57..0259c7061 100644 --- a/llama_stack/providers/adapters/inference/vllm/vllm.py +++ b/llama_stack/providers/adapters/inference/vllm/vllm.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer +from llama_models.sku_list import all_registered_models, resolve_model from openai import OpenAI @@ -26,40 +27,16 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import VLLMInferenceAdapterConfig -VLLM_SUPPORTED_MODELS = { - "Llama3.1-8B": "meta-llama/Llama-3.1-8B", - "Llama3.1-70B": "meta-llama/Llama-3.1-70B", - "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B", - "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8", - "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B", - "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct", - "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct", - "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct", - "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8", - "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct", - "Llama3.2-1B": "meta-llama/Llama-3.2-1B", - "Llama3.2-3B": "meta-llama/Llama-3.2-3B", - "Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision", - "Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision", - "Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct", - "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct", - "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct", - "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct", - "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision", - "Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4", - "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B", - "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B", - "Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8", - "Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M", - "Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B", -} - - class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): def __init__(self, config: VLLMInferenceAdapterConfig) -> None: self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) self.client = None + self.huggingface_repo_to_llama_model_id = { + model.huggingface_repo: model.descriptor() + for model in all_registered_models() + if model.huggingface_repo + } async def initialize(self) -> None: self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) @@ -71,15 +48,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): pass async def list_models(self) -> List[ModelDef]: - vllm_to_llama_map = {v: k for k, v in VLLM_SUPPORTED_MODELS.items()} - models = [] for model in self.client.models.list(): - if model.id not in vllm_to_llama_map: - print(f"Unknown model served by vllm: {model.id}") + repo = model.id + if repo not in self.huggingface_repo_to_llama_model_id: + print(f"Unknown model served by vllm: {repo}") continue - identifier = vllm_to_llama_map[model.id] + identifier = self.huggingface_repo_to_llama_model_id[repo] models.append( ModelDef( identifier=identifier, @@ -155,8 +131,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): options = get_sampling_options(request.sampling_params) if "max_tokens" not in options: options["max_tokens"] = self.config.max_tokens + + model = resolve_model(request.model) + if model is None: + raise ValueError(f"Unknown model: {request.model}") + return { - "model": VLLM_SUPPORTED_MODELS[request.model], + "model": model.huggingface_repo, "prompt": chat_completion_request_to_prompt(request, self.formatter), "stream": request.stream, **options,