diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 3a32125b2..297eecbdc 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -71,12 +71,9 @@ model_aliases = [
 ]
 
 
-class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=model_aliases,
-        )
+        self.register_helper = ModelRegistryHelper(model_aliases)
         self.url = url
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
@@ -203,7 +200,9 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
             else:
                 input_dict["raw"] = True
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request,
+                    self.register_helper.get_llama_model(request.model),
+                    self.formatter,
                 )
         else:
             assert (
@@ -282,6 +281,18 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
 
+    async def register_model(self, model: Model) -> Model:
+        model = await self.register_helper.register_model(model)
+        models = await self.client.ps()
+        available_models = [m["model"] for m in models["models"]]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model '{model.provider_resource_id}' is not available in Ollama. "
+                f"Available models: {', '.join(available_models)}"
+            )
+
+        return model
+
 
 async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index e5eb6e1ea..696cfb15d 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -45,12 +45,9 @@ def build_model_aliases():
     ]
 
 
-class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=build_model_aliases(),
-        )
+        self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None
@@ -131,6 +128,17 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
         ):
             yield chunk
 
+    async def register_model(self, model: Model) -> Model:
+        model = await self.register_helper.register_model(model)
+        res = self.client.models.list()
+        available_models = [m.id for m in res]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model {model.provider_resource_id} is not being served by vLLM. "
+                f"Available models: {', '.join(available_models)}"
+            )
+        return model
+
     async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
@@ -149,7 +157,9 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request,
+                    self.register_helper.get_llama_model(request.model),
+                    self.formatter,
                 )
         else:
             assert (
@@ -157,7 +167,7 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
             ), "Together does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(
                 request,
-                self.get_llama_model(request.model),
+                self.register_helper.get_llama_model(request.model),
                 self.formatter,
             )
 
diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py
new file mode 100644
index 000000000..4b20e519c
--- /dev/null
+++ b/llama_stack/providers/tests/inference/test_model_registration.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+# How to run this test:
+#
+# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
+#   -m "meta_reference"
+#   --env TOGETHER_API_KEY=
+
+
+class TestModelRegistration:
+    @pytest.mark.asyncio
+    async def test_register_unsupported_model(self, inference_stack):
+        _, models_impl = inference_stack
+
+        # Try to register a model that's too large for local inference
+        with pytest.raises(Exception) as exc_info:
+            await models_impl.register_model(
+                model_id="Llama3.1-70B-Instruct",
+            )
+
+    @pytest.mark.asyncio
+    async def test_register_nonexistent_model(self, inference_stack):
+        _, models_impl = inference_stack
+
+        # Try to register a non-existent model
+        with pytest.raises(Exception) as exc_info:
+            await models_impl.register_model(
+                model_id="Llama3-NonExistent-Model",
+            )
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 7120e9e97..77eb5b415 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -54,7 +54,10 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         raise ValueError(f"Unknown model: `{identifier}`")
 
     def get_llama_model(self, provider_model_id: str) -> str:
-        return self.provider_id_to_llama_model_map[provider_model_id]
+        if provider_model_id in self.provider_id_to_llama_model_map:
+            return self.provider_id_to_llama_model_map[provider_model_id]
+        else:
+            return None
 
     async def register_model(self, model: Model) -> Model:
         model.provider_resource_id = self.get_provider_model_id(