From a01901132654ef6bb8acf82c53dd9e2de38b0c72 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 12:03:43 -0800 Subject: [PATCH] fix ollama registry --- .../providers/remote/inference/ollama/ollama.py | 15 +++++++++++++++ .../providers/utils/inference/model_registry.py | 5 ++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 3a32125b2..e0f75fdb0 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -282,6 +282,21 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva ) -> EmbeddingsResponse: raise NotImplementedError() + async def register_model(self, model: Model) -> Model: + # First perform the parent class's registration check + model = await super().register_model(model) + + # Additional Ollama-specific check + models = await self.client.ps() + available_models = [m["model"] for m in models["models"]] + if model.provider_resource_id not in available_models: + raise ValueError( + f"Model '{model.provider_resource_id}' is not available in Ollama. " + f"Available models: {', '.join(available_models)}" + ) + + return model + async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]: async def _convert_content(content) -> dict: diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 7120e9e97..ae0836baa 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -54,7 +54,10 @@ class ModelRegistryHelper(ModelsProtocolPrivate): raise ValueError(f"Unknown model: `{identifier}`") def get_llama_model(self, provider_model_id: str) -> str: - return self.provider_id_to_llama_model_map[provider_model_id] + if provider_model_id in self.provider_id_to_llama_model_map: + return self.provider_id_to_llama_model_map[provider_model_id] + else: + None async def register_model(self, model: Model) -> Model: model.provider_resource_id = self.get_provider_model_id(