From 36e7bc6fbfe5eb6e9436d7a2643af2f7a02ca084 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 12 Mar 2025 21:20:53 -0700
Subject: [PATCH] test: add unit test to ensure all config types are instantiable

---
 llama_stack/providers/registry/safety.py    | 21 ---------
 llama_stack/providers/registry/telemetry.py | 11 -----
 .../remote/inference/runpod/__init__.py     |  3 +-
 .../remote/inference/runpod/runpod.py       |  1 -
 tests/unit/providers/test_configs.py        | 43 +++++++++++++++++++
 5 files changed, 45 insertions(+), 34 deletions(-)
 create mode 100644 tests/unit/providers/test_configs.py

diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py
index b9f7b6d78..6a824ae2f 100644
--- a/llama_stack/providers/registry/safety.py
+++ b/llama_stack/providers/registry/safety.py
@@ -27,27 +27,6 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.safety.prompt_guard",
             config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
         ),
-        InlineProviderSpec(
-            api=Api.safety,
-            provider_type="inline::meta-reference",
-            pip_packages=[
-                "transformers",
-                "torch --index-url https://download.pytorch.org/whl/cpu",
-            ],
-            module="llama_stack.providers.inline.safety.meta_reference",
-            config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig",
-            api_dependencies=[
-                Api.inference,
-            ],
-            deprecation_error="""
-Provider `inline::meta-reference` for API `safety` does not work with the latest Llama Stack.
-
-- if you are using Llama Guard v3, please use the `inline::llama-guard` provider instead.
-- if you are using Prompt Guard, please use the `inline::prompt-guard` provider instead.
-- if you are using Code Scanner, please use the `inline::code-scanner` provider instead.
-
-    """,
-        ),
         InlineProviderSpec(
             api=Api.safety,
             provider_type="inline::llama-guard",
diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py
index f3b41374c..fc249f3e2 100644
--- a/llama_stack/providers/registry/telemetry.py
+++ b/llama_stack/providers/registry/telemetry.py
@@ -7,11 +7,9 @@
 from typing import List
 
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
 )
 
 
@@ -28,13 +26,4 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.telemetry.meta_reference",
             config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
         ),
-        remote_provider_spec(
-            api=Api.telemetry,
-            adapter=AdapterSpec(
-                adapter_type="sample",
-                pip_packages=[],
-                module="llama_stack.providers.remote.telemetry.sample",
-                config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig",
-            ),
-        ),
     ]
diff --git a/llama_stack/providers/remote/inference/runpod/__init__.py b/llama_stack/providers/remote/inference/runpod/__init__.py
index dcdfa9a84..69bf95046 100644
--- a/llama_stack/providers/remote/inference/runpod/__init__.py
+++ b/llama_stack/providers/remote/inference/runpod/__init__.py
@@ -5,10 +5,11 @@
 # the root directory of this source tree.
 
 from .config import RunpodImplConfig
-from .runpod import RunpodInferenceAdapter
 
 
 async def get_adapter_impl(config: RunpodImplConfig, _deps):
+    from .runpod import RunpodInferenceAdapter
+
     assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
     impl = RunpodInferenceAdapter(config)
     await impl.initialize()
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index 783842f71..72f858cd8 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -8,7 +8,6 @@ from typing import AsyncGenerator
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.models.llama.datatypes import Message
 
 # from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
diff --git a/tests/unit/providers/test_configs.py b/tests/unit/providers/test_configs.py
new file mode 100644
index 000000000..c284e682f
--- /dev/null
+++ b/tests/unit/providers/test_configs.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+from pydantic import BaseModel
+
+from llama_stack.distribution.distribution import get_provider_registry, providable_apis
+from llama_stack.distribution.utils.dynamic import instantiate_class_type
+
+
+def test_all_provider_configs_can_be_instantiated():
+    """
+    Test that all provider configs can be instantiated.
+    This ensures that all config classes are correctly defined and can be instantiated without errors.
+    """
+    # Get all provider registries
+    provider_registry = get_provider_registry()
+
+    # Track any failures
+    failures = []
+
+    # For each API type
+    for api in providable_apis():
+        providers = provider_registry.get(api, {})
+
+        # For each provider of this API type
+        for provider_type, provider_spec in providers.items():
+            try:
+                # Get the config class
+                config_class_name = provider_spec.config_class
+                config_type = instantiate_class_type(config_class_name)
+
+                assert issubclass(config_type, BaseModel)
+
+            except Exception as e:
+                failures.append(f"Failed to instantiate {provider_type} config: {str(e)}")
+
+    # Report all failures at once
+    if failures:
+        pytest.fail("\n".join(failures))
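
Note on the new test's mechanics: it relies on `instantiate_class_type` to resolve each provider spec's dotted `config_class` string into the actual class object before checking that it is a Pydantic `BaseModel` subclass. A minimal sketch of that resolution step, assuming a conventional importlib-based loader (the real helper lives in `llama_stack.distribution.utils.dynamic` and may differ in details; `resolve_config_class` is a hypothetical stand-in name):

    import importlib

    def resolve_config_class(fully_qualified_name: str) -> type:
        # Split "pkg.module.ClassName" into the module path and the class
        # name, import the module, and return the class object.
        module_name, class_name = fully_qualified_name.rsplit(".", 1)
        module = importlib.import_module(module_name)
        return getattr(module, class_name)

    # Usage (assumes llama_stack is installed):
    # cls = resolve_config_class(
    #     "llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig"
    # )

Because the test merely resolves and type-checks each `config_class`, a typo in any registry entry (such as the stale `sample` telemetry provider removed above) surfaces as a collected failure rather than a hard import error elsewhere in the stack.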