From 7e4765c45bb12e901e0f6a0f4c20a594d6d1500c Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Wed, 13 Nov 2024 13:02:45 -0800
Subject: [PATCH] address feedback

---
 .../remote/inference/ollama/ollama.py         |  6 ++--
 .../providers/remote/inference/vllm/vllm.py   | 11 +++---
 .../inference/test_model_registration.py      | 35 +++++++++++++++++++
 3 files changed, 44 insertions(+), 8 deletions(-)
 create mode 100644 llama_stack/providers/tests/inference/test_model_registration.py

diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 32825e153..297eecbdc 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -73,7 +73,7 @@ model_aliases = [
 
 class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
-        self.model_register_helper = ModelRegistryHelper(model_aliases)
+        self.register_helper = ModelRegistryHelper(model_aliases)
         self.url = url
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
@@ -201,7 +201,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["raw"] = True
                 input_dict["prompt"] = chat_completion_request_to_prompt(
                     request,
-                    self.model_register_helper.get_llama_model(request.model),
+                    self.register_helper.get_llama_model(request.model),
                     self.formatter,
                 )
         else:
@@ -282,7 +282,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         raise NotImplementedError()
 
     async def register_model(self, model: Model) -> Model:
-        model = await self.model_register_helper.register_model(model)
+        model = await self.register_helper.register_model(model)
         models = await self.client.ps()
         available_models = [m["model"] for m in models["models"]]
         if model.provider_resource_id not in available_models:
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 657b5b576..696cfb15d 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -47,7 +47,7 @@ def build_model_aliases():
 
 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
-        self.model_register_helper = ModelRegistryHelper(build_model_aliases())
+        self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None
@@ -129,12 +129,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
-        model = await self.model_register_helper.register_model(model)
+        model = await self.register_helper.register_model(model)
         res = self.client.models.list()
         available_models = [m.id for m in res]
         if model.provider_resource_id not in available_models:
             raise ValueError(
-                f"Model {model.provider_resource_id} is not being served by vLLM"
+                f"Model {model.provider_resource_id} is not being served by vLLM. "
+                f"Available models: {', '.join(available_models)}"
             )
         return model
 
@@ -157,7 +158,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
                     request,
-                    self.model_register_helper.get_llama_model(request.model),
+                    self.register_helper.get_llama_model(request.model),
                     self.formatter,
                 )
         else:
@@ -166,7 +167,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             ), "Together does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(
                 request,
-                self.model_register_helper.get_llama_model(request.model),
+                self.register_helper.get_llama_model(request.model),
                 self.formatter,
             )
 
diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py
new file mode 100644
index 000000000..4b20e519c
--- /dev/null
+++ b/llama_stack/providers/tests/inference/test_model_registration.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+# How to run this test:
+#
+# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
+#   -m "meta_reference"
+#   --env TOGETHER_API_KEY=
+
+
+class TestModelRegistration:
+    @pytest.mark.asyncio
+    async def test_register_unsupported_model(self, inference_stack):
+        _, models_impl = inference_stack
+
+        # Try to register a model that's too large for local inference
+        with pytest.raises(Exception) as exc_info:
+            await models_impl.register_model(
+                model_id="Llama3.1-70B-Instruct",
+            )
+
+    @pytest.mark.asyncio
+    async def test_register_nonexistent_model(self, inference_stack):
+        _, models_impl = inference_stack
+
+        # Try to register a non-existent model
+        with pytest.raises(Exception) as exc_info:
+            await models_impl.register_model(
+                model_id="Llama3-NonExistent-Model",
+            )