mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 04:49:59 +00:00
Resolved merge conflicts
This commit is contained in:
parent
3298e50105
commit
967dd0aa08
82 changed files with 66055 additions and 0 deletions
66
tests/unit/server/test_replace_env_vars.py
Normal file
66
tests/unit/server/test_replace_env_vars.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from llama_stack.distribution.stack import replace_env_vars
|
||||
|
||||
|
||||
class TestReplaceEnvVars(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Clear any existing environment variables we'll use in tests
|
||||
for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
# Set up test environment variables
|
||||
os.environ["TEST_VAR"] = "test_value"
|
||||
os.environ["EMPTY_VAR"] = ""
|
||||
os.environ["ZERO_VAR"] = "0"
|
||||
|
||||
def test_simple_replacement(self):
|
||||
self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value")
|
||||
|
||||
def test_default_value_when_not_set(self):
|
||||
self.assertEqual(replace_env_vars("${env.NOT_SET:default}"), "default")
|
||||
|
||||
def test_default_value_when_set(self):
|
||||
self.assertEqual(replace_env_vars("${env.TEST_VAR:default}"), "test_value")
|
||||
|
||||
def test_default_value_when_empty(self):
|
||||
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:default}"), "default")
|
||||
|
||||
def test_conditional_value_when_set(self):
|
||||
self.assertEqual(replace_env_vars("${env.TEST_VAR+conditional}"), "conditional")
|
||||
|
||||
def test_conditional_value_when_not_set(self):
|
||||
self.assertEqual(replace_env_vars("${env.NOT_SET+conditional}"), "")
|
||||
|
||||
def test_conditional_value_when_empty(self):
|
||||
self.assertEqual(replace_env_vars("${env.EMPTY_VAR+conditional}"), "")
|
||||
|
||||
def test_conditional_value_with_zero(self):
|
||||
self.assertEqual(replace_env_vars("${env.ZERO_VAR+conditional}"), "conditional")
|
||||
|
||||
def test_mixed_syntax(self):
|
||||
self.assertEqual(replace_env_vars("${env.TEST_VAR:default} and ${env.NOT_SET+conditional}"), "test_value and ")
|
||||
self.assertEqual(
|
||||
replace_env_vars("${env.NOT_SET:default} and ${env.TEST_VAR+conditional}"), "default and conditional"
|
||||
)
|
||||
|
||||
def test_nested_structures(self):
|
||||
data = {
|
||||
"key1": "${env.TEST_VAR:default}",
|
||||
"key2": ["${env.NOT_SET:default}", "${env.TEST_VAR+conditional}"],
|
||||
"key3": {"nested": "${env.NOT_SET+conditional}"},
|
||||
}
|
||||
expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": ""}}
|
||||
self.assertEqual(replace_env_vars(data), expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
117
tests/unit/server/test_resolver.py
Normal file
117
tests/unit/server/test_resolver.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
from typing import Any, Dict, Protocol
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.distribution.datatypes import (
|
||||
Api,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
from llama_stack.distribution.routers.routers import InferenceRouter
|
||||
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
|
||||
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
|
||||
"""Dynamically add protocol methods to a class by inspecting the protocol."""
|
||||
for name, value in inspect.getmembers(protocol):
|
||||
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
|
||||
# Get the signature
|
||||
sig = inspect.signature(value)
|
||||
|
||||
# Create an async function with the same signature that returns a MagicMock
|
||||
async def mock_impl(*args, **kwargs):
|
||||
return MagicMock()
|
||||
|
||||
# Set the signature on our mock implementation
|
||||
mock_impl.__signature__ = sig
|
||||
# Add it to the class
|
||||
setattr(cls, name, mock_impl)
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
foo: str = Field(
|
||||
default="bar",
|
||||
description="foo",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"foo": "baz",
|
||||
}
|
||||
|
||||
|
||||
class SampleImpl:
|
||||
def __init__(self, config: SampleConfig, deps: Dict[Api, Any], provider_spec: ProviderSpec = None):
|
||||
self.__provider_id__ = "test_provider"
|
||||
self.__provider_spec__ = provider_spec
|
||||
self.__provider_config__ = config
|
||||
self.__deps__ = deps
|
||||
self.foo = config.foo
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_impls_basic():
|
||||
# Create a real provider spec
|
||||
provider_spec = InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="sample",
|
||||
module="test_module",
|
||||
config_class="test_resolver.SampleConfig",
|
||||
api_dependencies=[],
|
||||
)
|
||||
|
||||
# Create provider registry with our provider
|
||||
provider_registry = {Api.inference: {provider_spec.provider_type: provider_spec}}
|
||||
|
||||
run_config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="sample_provider",
|
||||
provider_type="sample",
|
||||
config=SampleConfig.sample_run_config(),
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
dist_registry = MagicMock()
|
||||
|
||||
mock_module = MagicMock()
|
||||
impl = SampleImpl(SampleConfig(foo="baz"), {}, provider_spec)
|
||||
add_protocol_methods(SampleImpl, Inference)
|
||||
|
||||
mock_module.get_provider_impl = AsyncMock(return_value=impl)
|
||||
sys.modules["test_module"] = mock_module
|
||||
|
||||
impls = await resolve_impls(run_config, provider_registry, dist_registry)
|
||||
|
||||
assert Api.inference in impls
|
||||
assert isinstance(impls[Api.inference], InferenceRouter)
|
||||
|
||||
table = impls[Api.inference].routing_table
|
||||
assert isinstance(table, ModelsRoutingTable)
|
||||
|
||||
impl = table.impls_by_provider_id["sample_provider"]
|
||||
assert isinstance(impl, SampleImpl)
|
||||
assert impl.foo == "baz"
|
||||
assert impl.__provider_id__ == "sample_provider"
|
||||
assert impl.__provider_spec__ == provider_spec
|
||||
Loading…
Add table
Add a link
Reference in a new issue