# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import inspect import sys from typing import Any, Dict, Protocol from unittest.mock import AsyncMock, MagicMock import pytest from pydantic import BaseModel, Field from llama_stack.apis.inference import Inference from llama_stack.distribution.datatypes import ( Api, Provider, StackRunConfig, ) from llama_stack.distribution.resolver import resolve_impls from llama_stack.distribution.routers.routers import InferenceRouter from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None: """Dynamically add protocol methods to a class by inspecting the protocol.""" for name, value in inspect.getmembers(protocol): if inspect.isfunction(value) and hasattr(value, "__webmethod__"): # Get the signature sig = inspect.signature(value) # Create an async function with the same signature that returns a MagicMock async def mock_impl(*args, **kwargs): return MagicMock() # Set the signature on our mock implementation mock_impl.__signature__ = sig # Add it to the class setattr(cls, name, mock_impl) class SampleConfig(BaseModel): foo: str = Field( default="bar", description="foo", ) @classmethod def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]: return { "foo": "baz", } class SampleImpl: def __init__(self, config: SampleConfig, deps: Dict[Api, Any], provider_spec: ProviderSpec = None): self.__provider_id__ = "test_provider" self.__provider_spec__ = provider_spec self.__provider_config__ = config self.__deps__ = deps self.foo = config.foo async def initialize(self): pass @pytest.mark.asyncio async def test_resolve_impls_basic(): # Create a real provider spec provider_spec = InlineProviderSpec( api=Api.inference, provider_type="sample", module="test_module", config_class="test_resolver.SampleConfig", api_dependencies=[], ) # Create provider registry with our provider provider_registry = {Api.inference: {provider_spec.provider_type: provider_spec}} run_config = StackRunConfig( image_name="test_image", providers={ "inference": [ Provider( provider_id="sample_provider", provider_type="sample", config=SampleConfig.sample_run_config(), ) ] }, ) dist_registry = MagicMock() mock_module = MagicMock() impl = SampleImpl(SampleConfig(foo="baz"), {}, provider_spec) add_protocol_methods(SampleImpl, Inference) mock_module.get_provider_impl = AsyncMock(return_value=impl) sys.modules["test_module"] = mock_module impls = await resolve_impls(run_config, provider_registry, dist_registry) assert Api.inference in impls assert isinstance(impls[Api.inference], InferenceRouter) table = impls[Api.inference].routing_table assert isinstance(table, ModelsRoutingTable) impl = table.impls_by_provider_id["sample_provider"] assert isinstance(impl, SampleImpl) assert impl.foo == "baz" assert impl.__provider_id__ == "sample_provider" assert impl.__provider_spec__ == provider_spec