test: first unit test for resolver (#1475)

Starting to create unit tests to cover critical (and mostly
undocumented) provider resolution and routing logic.

## Test Plan

Unit tests
This commit is contained in:
Ashwin Bharambe 2025-03-07 10:20:51 -08:00 committed by GitHub
parent 60e7f3d705
commit 290cc843fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 126 additions and 0 deletions

View file

@ -0,0 +1,9 @@
---
description: General rules always applicable across the project
globs:
alwaysApply: true
---
# Style
- Comments must add value to code. Don't write filler comments explaining what you are doing next; they just add noise.
- Add a comment to clarify surprising behavior which would not be obvious. Good variable naming and clear code organization is more important.

View file

@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import sys
from typing import Any, Dict, Protocol
from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Inference
from llama_stack.distribution.datatypes import (
Api,
Provider,
StackRunConfig,
)
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.routers.routers import InferenceRouter
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
"""Dynamically add protocol methods to a class by inspecting the protocol."""
for name, value in inspect.getmembers(protocol):
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
# Get the signature
sig = inspect.signature(value)
# Create an async function with the same signature that returns a MagicMock
async def mock_impl(*args, **kwargs):
return MagicMock()
# Set the signature on our mock implementation
mock_impl.__signature__ = sig
# Add it to the class
setattr(cls, name, mock_impl)
class SampleConfig(BaseModel):
foo: str = Field(
default="bar",
description="foo",
)
@classmethod
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
return {
"foo": "baz",
}
class SampleImpl:
def __init__(self, config: SampleConfig, deps: Dict[Api, Any], provider_spec: ProviderSpec = None):
self.__provider_id__ = "test_provider"
self.__provider_spec__ = provider_spec
self.__provider_config__ = config
self.__deps__ = deps
self.foo = config.foo
async def initialize(self):
pass
@pytest.mark.asyncio
async def test_resolve_impls_basic():
# Create a real provider spec
provider_spec = InlineProviderSpec(
api=Api.inference,
provider_type="sample",
module="test_module",
config_class="test_resolver.SampleConfig",
api_dependencies=[],
)
# Create provider registry with our provider
provider_registry = {Api.inference: {provider_spec.provider_type: provider_spec}}
run_config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="sample_provider",
provider_type="sample",
config=SampleConfig.sample_run_config(),
)
]
},
)
dist_registry = MagicMock()
mock_module = MagicMock()
impl = SampleImpl(SampleConfig(foo="baz"), {}, provider_spec)
add_protocol_methods(SampleImpl, Inference)
mock_module.get_provider_impl = AsyncMock(return_value=impl)
sys.modules["test_module"] = mock_module
impls = await resolve_impls(run_config, provider_registry, dist_registry)
assert Api.inference in impls
assert isinstance(impls[Api.inference], InferenceRouter)
table = impls[Api.inference].routing_table
assert isinstance(table, ModelsRoutingTable)
impl = table.impls_by_provider_id["sample_provider"]
assert isinstance(impl, SampleImpl)
assert impl.foo == "baz"
assert impl.__provider_id__ == "sample_provider"
assert impl.__provider_spec__ == provider_spec