From ad4e65e8764ece0290fb655312bc0732855d861e Mon Sep 17 00:00:00 2001
From: Yuan Tang
Date: Thu, 3 Oct 2024 22:05:07 -0400
Subject: [PATCH] Fixes

---
 .../templates/build_configs/local-vllm-build.yaml |  2 +-
 .../providers/adapters/inference/vllm/__init__.py | 12 +++++-------
 .../providers/adapters/inference/vllm/vllm.py     |  4 ++--
 llama_stack/providers/registry/inference.py       |  1 +
 4 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/llama_stack/distribution/templates/build_configs/local-vllm-build.yaml b/llama_stack/distribution/templates/build_configs/local-vllm-build.yaml
index e907cb7c9..e333a137b 100644
--- a/llama_stack/distribution/templates/build_configs/local-vllm-build.yaml
+++ b/llama_stack/distribution/templates/build_configs/local-vllm-build.yaml
@@ -7,4 +7,4 @@ distribution_spec:
     safety: meta-reference
     agents: meta-reference
     telemetry: meta-reference
-image_type: conda
+image_type: conda
\ No newline at end of file
diff --git a/llama_stack/providers/adapters/inference/vllm/__init__.py b/llama_stack/providers/adapters/inference/vllm/__init__.py
index 146020d97..bf3f671a1 100644
--- a/llama_stack/providers/adapters/inference/vllm/__init__.py
+++ b/llama_stack/providers/adapters/inference/vllm/__init__.py
@@ -4,17 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import DatabricksImplConfig
-from .vllm import InferenceEndpointAdapter, VLLMAdapter
+from .config import VLLMImplConfig
+from .vllm import VLLMInferenceAdapter


-async def get_adapter_impl(config: DatabricksImplConfig, _deps):
-    assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
+async def get_adapter_impl(config: VLLMImplConfig, _deps):
+    assert isinstance(config, VLLMImplConfig), f"Unexpected config type: {type(config)}"
     if config.url is not None:
-        impl = VLLMAdapter(config)
-    elif config.is_inference_endpoint():
-        impl = InferenceEndpointAdapter(config)
+        impl = VLLMInferenceAdapter(config)
     else:
         raise ValueError(
             "Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
         )
diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py
index 9df94d94d..050f173a3 100644
--- a/llama_stack/providers/adapters/inference/vllm/vllm.py
+++ b/llama_stack/providers/adapters/inference/vllm/vllm.py
@@ -22,8 +22,8 @@ from .config import VLLMImplConfig
 # Reference: https://docs.vllm.ai/en/latest/models/supported_models.html
 VLLM_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
-    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
-    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
+    # "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
+    # "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
 }

diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index b4f3b137a..8885d135a 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -66,6 +66,7 @@ def available_providers() -> List[ProviderSpec]:
             adapter_type="vllm",
             pip_packages=["openai"],
             module="llama_stack.providers.adapters.inference.vllm",
+            config_class="llama_stack.providers.adapters.inference.vllm.VLLMImplConfig",
         ),
     ),
     remote_provider_spec(