From b10e9f46bb0854e94ae108499e7e0ba085def890 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 6 Nov 2024 14:42:44 -0800
Subject: [PATCH] Enable remote::vllm (#384)

* Enable remote::vllm

* Kill the giant list of hard coded models
---
 .../adapters/inference/vllm/__init__.py       | 11 +--
 .../adapters/inference/vllm/config.py         |  8 ++-
 .../providers/adapters/inference/vllm/vllm.py | 71 +++++++++----------
 llama_stack/providers/registry/inference.py   | 18 ++---
 .../providers/tests/inference/fixtures.py     | 25 ++++++-
 5 files changed, 80 insertions(+), 53 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/vllm/__init__.py b/llama_stack/providers/adapters/inference/vllm/__init__.py
index f4588a307..78222d7d9 100644
--- a/llama_stack/providers/adapters/inference/vllm/__init__.py
+++ b/llama_stack/providers/adapters/inference/vllm/__init__.py
@@ -4,12 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .config import VLLMImplConfig
-from .vllm import VLLMInferenceAdapter
+from .config import VLLMInferenceAdapterConfig
 
 
-async def get_adapter_impl(config: VLLMImplConfig, _deps):
-    assert isinstance(config, VLLMImplConfig), f"Unexpected config type: {type(config)}"
+async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
+    from .vllm import VLLMInferenceAdapter
+
+    assert isinstance(
+        config, VLLMInferenceAdapterConfig
+    ), f"Unexpected config type: {type(config)}"
     impl = VLLMInferenceAdapter(config)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/adapters/inference/vllm/config.py b/llama_stack/providers/adapters/inference/vllm/config.py
index 65815922c..50a174589 100644
--- a/llama_stack/providers/adapters/inference/vllm/config.py
+++ b/llama_stack/providers/adapters/inference/vllm/config.py
@@ -11,12 +11,16 @@ from pydantic import BaseModel, Field
 
 
 @json_schema_type
-class VLLMImplConfig(BaseModel):
+class VLLMInferenceAdapterConfig(BaseModel):
     url: Optional[str] = Field(
         default=None,
         description="The URL for the vLLM model serving endpoint",
     )
+    max_tokens: int = Field(
+        default=4096,
+        description="Maximum number of tokens to generate.",
+    )
     api_token: Optional[str] = Field(
-        default=None,
+        default="fake",
         description="The API token",
     )
diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py
index aad2fdc1f..0259c7061 100644
--- a/llama_stack/providers/adapters/inference/vllm/vllm.py
+++ b/llama_stack/providers/adapters/inference/vllm/vllm.py
@@ -8,6 +8,7 @@ from typing import AsyncGenerator
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import all_registered_models, resolve_model
 
 from openai import OpenAI
 
@@ -23,42 +24,19 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
 )
 
-from .config import VLLMImplConfig
-
-VLLM_SUPPORTED_MODELS = {
-    "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
-    "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
-    "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
-    "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
-    "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
-    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
-    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
-    "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
-    "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
-    "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
-    "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
-    "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
-    "Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
-    "Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
-    "Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
-    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
-    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
-    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
-    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
-    "Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
-    "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
-    "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
-    "Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
-    "Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
-    "Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
-}
+from .config import VLLMInferenceAdapterConfig
 
 
 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
-    def __init__(self, config: VLLMImplConfig) -> None:
+    def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None
+        self.huggingface_repo_to_llama_model_id = {
+            model.huggingface_repo: model.descriptor()
+            for model in all_registered_models()
+            if model.huggingface_repo
+        }
 
     async def initialize(self) -> None:
         self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
@@ -70,10 +48,21 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         pass
 
     async def list_models(self) -> List[ModelDef]:
-        return [
-            ModelDef(identifier=model.id, llama_model=model.id)
-            for model in self.client.models.list()
-        ]
+        models = []
+        for model in self.client.models.list():
+            repo = model.id
+            if repo not in self.huggingface_repo_to_llama_model_id:
+                print(f"Unknown model served by vllm: {repo}")
+                continue
+
+            identifier = self.huggingface_repo_to_llama_model_id[repo]
+            models.append(
+                ModelDef(
+                    identifier=identifier,
+                    llama_model=identifier,
+                )
+            )
+        return models
 
     async def completion(
         self,
@@ -118,7 +107,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(request, r, self.formatter)
+        return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI
@@ -139,11 +128,19 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
+        options = get_sampling_options(request.sampling_params)
+        if "max_tokens" not in options:
+            options["max_tokens"] = self.config.max_tokens
+
+        model = resolve_model(request.model)
+        if model is None:
+            raise ValueError(f"Unknown model: {request.model}")
+
         return {
-            "model": VLLM_SUPPORTED_MODELS[request.model],
+            "model": model.huggingface_repo,
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
-            **get_sampling_options(request.sampling_params),
+            **options,
         }
 
     async def embeddings(
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 88265f1b4..717ff78a8 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -61,15 +61,15 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.adapters.inference.ollama",
             ),
         ),
-        # remote_provider_spec(
-        #     api=Api.inference,
-        #     adapter=AdapterSpec(
-        #         adapter_type="vllm",
-        #         pip_packages=["openai"],
-        #         module="llama_stack.providers.adapters.inference.vllm",
-        #         config_class="llama_stack.providers.adapters.inference.vllm.VLLMImplConfig",
-        #     ),
-        # ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="vllm",
+                pip_packages=["openai"],
+                module="llama_stack.providers.adapters.inference.vllm",
+                config_class="llama_stack.providers.adapters.inference.vllm.VLLMInferenceAdapterConfig",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index 896acbad8..acff151cf 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -14,6 +14,7 @@ from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.adapters.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.adapters.inference.ollama import OllamaImplConfig
 from llama_stack.providers.adapters.inference.together import TogetherImplConfig
+from llama_stack.providers.adapters.inference.vllm import VLLMInferenceAdapterConfig
 from llama_stack.providers.impls.meta_reference.inference import (
     MetaReferenceInferenceConfig,
 )
@@ -78,6 +79,21 @@ def inference_ollama(inference_model) -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def inference_vllm_remote() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="remote::vllm",
+                provider_type="remote::vllm",
+                config=VLLMInferenceAdapterConfig(
+                    url=get_env_or_fail("VLLM_URL"),
+                ).model_dump(),
+            )
+        ],
+    )
+
+
 @pytest.fixture(scope="session")
 def inference_fireworks() -> ProviderFixture:
     return ProviderFixture(
@@ -109,7 +125,14 @@ def inference_together() -> ProviderFixture:
     )
 
 
-INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together", "remote"]
+INFERENCE_FIXTURES = [
+    "meta_reference",
+    "ollama",
+    "fireworks",
+    "together",
+    "vllm_remote",
+    "remote",
+]
 
 
 @pytest_asyncio.fixture(scope="session")
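
Usage sketch (not part of the patch): the snippet below shows one way to smoke-test the re-enabled remote::vllm adapter against a vLLM server that exposes an OpenAI-compatible endpoint. Only the config fields, the get_adapter_impl() entry point, and the list_models() behavior come from the diff above; the server URL, port, and served model are illustrative assumptions.

# smoke_test_vllm_adapter.py -- illustrative only; assumes a vLLM server is
# already running with its OpenAI-compatible API, e.g.:
#   python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct
import asyncio
import os

from llama_stack.providers.adapters.inference.vllm import (
    VLLMInferenceAdapterConfig,
    get_adapter_impl,
)


async def main() -> None:
    config = VLLMInferenceAdapterConfig(
        url=os.environ.get("VLLM_URL", "http://localhost:8000/v1"),  # assumed endpoint
        max_tokens=512,    # overrides the 4096 default introduced by this patch
        api_token="fake",  # placeholder; matches the new default in config.py
    )
    # get_adapter_impl() builds the adapter and calls initialize(), which
    # creates the OpenAI client pointed at the vLLM endpoint.
    adapter = await get_adapter_impl(config, {})

    # list_models() maps the HuggingFace repos reported by the server back to
    # llama-stack descriptors, skipping anything it does not recognize.
    for model_def in await adapter.list_models():
        print(model_def.identifier, "->", model_def.llama_model)


if __name__ == "__main__":
    asyncio.run(main())

With the same server running, the new vllm_remote fixture can be exercised by exporting VLLM_URL before running the provider inference tests, e.g. something like VLLM_URL=http://localhost:8000/v1 pytest llama_stack/providers/tests/inference/ -k vllm_remote.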