From e272f8aa62ffb195216c1bd7fc013da3aa1e74d2 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Wed, 13 Nov 2024 12:12:22 -0800
Subject: [PATCH] fix vllm registry

---
 llama_stack/providers/remote/inference/ollama/ollama.py |  3 ---
 llama_stack/providers/remote/inference/vllm/vllm.py     | 10 ++++++++++
 llama_stack/providers/utils/inference/model_registry.py |  2 +-
 3 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index e0f75fdb0..ed30b7016 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -283,10 +283,7 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
         raise NotImplementedError()
 
     async def register_model(self, model: Model) -> Model:
-        # First perform the parent class's registration check
         model = await super().register_model(model)
-
-        # Additional Ollama-specific check
         models = await self.client.ps()
         available_models = [m["model"] for m in models["models"]]
         if model.provider_resource_id not in available_models:
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index e5eb6e1ea..8fc3451e9 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -131,6 +131,16 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
         ):
             yield chunk
 
+    async def register_model(self, model: Model) -> Model:
+        model = await super().register_model(model)
+        res = self.client.models.list()
+        available_models = [m.id for m in res]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model {model.provider_resource_id} is not being served by vLLM"
+            )
+        return model
+
     async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index ae0836baa..77eb5b415 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -57,7 +57,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         if provider_model_id in self.provider_id_to_llama_model_map:
             return self.provider_id_to_llama_model_map[provider_model_id]
         else:
-            None
+            return None
 
     async def register_model(self, model: Model) -> Model:
         model.provider_resource_id = self.get_provider_model_id(
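-- 
For reviewers, a standalone sketch of the availability check this patch adds
to VLLMInferenceAdapter.register_model, assuming a vLLM server that exposes
the OpenAI-compatible API. The base_url, api_key, and model id below are
illustrative placeholders, not values taken from this patch:

    # sketch.py - mimics the check register_model performs above
    from openai import OpenAI

    # hypothetical endpoint; vLLM serves an OpenAI-compatible API,
    # so the stock OpenAI client can talk to it
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake")

    def check_served(provider_resource_id: str) -> None:
        # models.list() reports the models the server is actually serving;
        # registration should fail fast when the requested id is missing
        available_models = [m.id for m in client.models.list()]
        if provider_resource_id not in available_models:
            raise ValueError(
                f"Model {provider_resource_id} is not being served by vLLM"
            )

    # raises ValueError unless the server was launched with this model
    check_served("meta-llama/Llama-3.1-8B-Instruct")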