Enable sane naming of registered objects with defaults

This commit is contained in:
Ashwin Bharambe 2024-11-12 10:17:34 -08:00
parent 9e925f43e5
commit 48a6e27de9
13 changed files with 222 additions and 131 deletions

View file

@ -18,7 +18,7 @@ from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.eval_tasks import EvalTaskInput
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
@ -152,12 +152,12 @@ a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
models: List[Model] = Field(default_factory=list)
shields: List[Shield] = Field(default_factory=list)
memory_banks: List[MemoryBank] = Field(default_factory=list)
datasets: List[Dataset] = Field(default_factory=list)
scoring_fns: List[ScoringFn] = Field(default_factory=list)
eval_tasks: List[EvalTask] = Field(default_factory=list)
models: List[ModelInput] = Field(default_factory=list)
shields: List[ShieldInput] = Field(default_factory=list)
memory_banks: List[MemoryBankInput] = Field(default_factory=list)
datasets: List[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
class BuildConfig(BaseModel):

View file

@ -277,10 +277,10 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memorybank_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank:
if provider_memorybank_id is None:
provider_memorybank_id = memory_bank_id
if provider_memory_bank_id is None:
provider_memory_bank_id = memory_bank_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type
if len(self.impls_by_provider_id) == 1:
@ -295,7 +295,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memorybank_id,
"provider_resource_id": provider_memory_bank_id,
**params.model_dump(),
},
)

View file

@ -68,30 +68,29 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
objects = [
*run_config.models,
*run_config.shields,
*run_config.memory_banks,
*run_config.datasets,
*run_config.scoring_fns,
*run_config.eval_tasks,
]
for obj in objects:
await dist_registry.register(obj)
resources = [
("models", Api.models),
("shields", Api.shields),
("memory_banks", Api.memory_banks),
("datasets", Api.datasets),
("scoring_fns", Api.scoring_functions),
("eval_tasks", Api.eval_tasks),
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
Api.scoring_functions,
"register_scoring_function",
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
]
for rsrc, api in resources:
for rsrc, api, register_method, list_method in resources:
objects = getattr(run_config, rsrc)
if api not in impls:
continue
method = getattr(impls[api], f"list_{api.value}")
method = getattr(impls[api], register_method)
for obj in objects:
await method(**obj.model_dump())
method = getattr(impls[api], list_method)
for obj in await method():
print(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",