test: add unit test to ensure all config types are instantiable (#1601)

This commit is contained in:
Ashwin Bharambe 2025-03-12 22:29:58 -07:00 committed by GitHub
parent 0a0d6cb96e
commit d072b5fa0c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
71 changed files with 662 additions and 465 deletions

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleAgentsImpl
impl = SampleAgentsImpl(config)
await impl.initialize()
return impl

View file

@ -1,12 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.agents import Agents
from .config import SampleConfig
class SampleAgentsImpl(Agents):
def __init__(self, config: SampleConfig):
self.config = config
async def initialize(self):
pass

View file

@ -3,9 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
class HuggingfaceDatasetIOConfig(BaseModel):
kvstore: KVStoreConfig = SqliteKVStoreConfig(
db_path=(RUNTIME_BASE_DIR / "huggingface_datasetio.db").as_posix()
) # Uses SQLite config specific to HF storage
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="huggingface_datasetio.db",
)
}

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel, Field
@ -20,3 +21,15 @@ class DatabricksImplConfig(BaseModel):
default=None,
description="The Databricks API token",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_URL}",
api_token: str = "${env.DATABRICKS_API_TOKEN}",
**kwargs: Any,
) -> Dict[str, Any]:
return {
"url": url,
"api_token": api_token,
}

View file

@ -5,10 +5,11 @@
# the root directory of this source tree.
from .config import RunpodImplConfig
from .runpod import RunpodInferenceAdapter
async def get_adapter_impl(config: RunpodImplConfig, _deps):
from .runpod import RunpodInferenceAdapter
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
impl = RunpodInferenceAdapter(config)
await impl.initialize()

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
@ -21,3 +21,10 @@ class RunpodImplConfig(BaseModel):
default=None,
description="The API token",
)
@classmethod
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
return {
"url": "${env.RUNPOD_URL:}",
"api_token": "${env.RUNPOD_API_TOKEN:}",
}

View file

@ -8,7 +8,6 @@ from typing import AsyncGenerator
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.models.llama.datatypes import Message
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleInferenceImpl
impl = SampleInferenceImpl(config)
await impl.initialize()
return impl

View file

@ -1,12 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model
from .config import SampleConfig
class SampleInferenceImpl(Inference):
def __init__(self, config: SampleConfig):
self.config = config
async def register_model(self, model: Model) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def initialize(self):
pass

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleSafetyImpl
impl = SampleSafetyImpl(config)
await impl.initialize()
return impl

View file

@ -1,12 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shield
from .config import SampleConfig
class SampleSafetyImpl(Safety):
def __init__(self, config: SampleConfig):
self.config = config
async def register_shield(self, shield: Shield) -> None:
# these are the safety shields the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def initialize(self):
pass

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, Optional
from pydantic import BaseModel
@ -14,3 +14,9 @@ class BingSearchToolConfig(BaseModel):
api_key: Optional[str] = None
top_k: int = 3
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {
"api_key": "${env.BING_API_KEY:}",
}

View file

@ -4,8 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class ModelContextProtocolConfig(BaseModel):
pass
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {}

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, Optional
from pydantic import BaseModel
@ -13,3 +13,9 @@ class WolframAlphaToolConfig(BaseModel):
"""Configuration for WolframAlpha Tool Runtime"""
api_key: Optional[str] = None
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {
"api_key": "${env.WOLFRAM_ALPHA_API_KEY:}",
}

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, Optional
from pydantic import BaseModel
@ -24,3 +24,9 @@ class QdrantVectorIOConfig(BaseModel):
timeout: Optional[int] = None
host: Optional[str] = None
path: Optional[str] = None
@classmethod
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
return {
"api_key": "${env.QDRANT_API_KEY}",
}

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleVectorIOConfig
async def get_adapter_impl(config: SampleVectorIOConfig, _deps) -> Any:
from .sample import SampleVectorIOImpl
impl = SampleVectorIOImpl(config)
await impl.initialize()
return impl

View file

@ -1,12 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleVectorIOConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -1,26 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import VectorIO
from .config import SampleVectorIOConfig
class SampleVectorIOImpl(VectorIO):
def __init__(self, config: SampleVectorIOConfig):
self.config = config
async def register_vector_db(self, vector_db: VectorDB) -> None:
# these are the vector dbs the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def initialize(self):
pass
async def shutdown(self):
pass

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
@ -13,4 +15,6 @@ class WeaviateRequestProviderData(BaseModel):
class WeaviateVectorIOConfig(BaseModel):
pass
@classmethod
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
return {}