forked from phoenix-oss/llama-stack-mirror
test: add unit test to ensure all config types are instantiable (#1601)
This commit is contained in:
parent
0a0d6cb96e
commit
d072b5fa0c
71 changed files with 662 additions and 465 deletions
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class LocalFSDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "localfs_datasetio.db").as_posix()
|
||||
) # Uses SQLite config specific to localfs storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="localfs_datasetio.db",
|
||||
)
|
||||
}
|
||||
|
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class MetaReferenceEvalConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "meta_reference_eval.db").as_posix()
|
||||
) # Uses SQLite config specific to Meta Reference Eval storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="meta_reference_eval.db",
|
||||
)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -40,7 +42,7 @@ class VLLMConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.MAX_TOKENS:4096}",
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -12,3 +12,9 @@ from pydantic import BaseModel
|
|||
class TorchtunePostTrainingConfig(BaseModel):
|
||||
torch_seed: Optional[int] = None
|
||||
checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"checkpoint_format": "meta",
|
||||
}
|
||||
|
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeScannerConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -4,10 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LlamaGuardConfig(BaseModel):
|
||||
excluded_categories: List[str] = []
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"excluded_categories": [],
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
@ -23,3 +24,9 @@ class PromptGuardConfig(BaseModel):
|
|||
if v not in [t.value for t in PromptGuardType]:
|
||||
raise ValueError(f"Unknown prompt guard type: {v}")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"guard_type": "injection",
|
||||
}
|
||||
|
|
|
@ -3,7 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BasicScoringConfig(BaseModel): ...
|
||||
class BasicScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -3,7 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LlmAsJudgeScoringConfig(BaseModel): ...
|
||||
class LlmAsJudgeScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleTelemetryImpl
|
||||
|
||||
impl = SampleTelemetryImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleTelemetryImpl(Telemetry):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeInterpreterToolConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RagToolRuntimeConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel):
|
|||
db_path: str
|
||||
|
||||
@classmethod
|
||||
def sample_config(cls) -> Dict[str, Any]:
|
||||
return {"db_path": "{env.CHROMADB_PATH}"}
|
||||
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> Dict[str, Any]:
|
||||
return {"db_path": db_path}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue