mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Enable sane naming of registered objects with defaults (#429)
# What does this PR do? This is a follow-up to #425. That PR allows for specifying models in the registry, but each entry needs to look like: ```yaml - identifier: ... provider_id: ... provider_resource_identifier: ... ``` This is headache-inducing. The current PR makes this situation better by adopting the shape of our APIs. Namely, we need the user to only specify `model-id`. The rest should be optional and figured out by the Stack. You can always override it. Here's what example `ollama` "full stack" registry looks like (we still need to kill or simplify shield_type crap): ```yaml models: - model_id: Llama3.2-3B-Instruct - model_id: Llama-Guard-3-1B shields: - shield_id: llama_guard shield_type: llama_guard ``` ## Test Plan See test plan for #425. Re-ran it.
This commit is contained in:
parent
d9d271a684
commit
09269e2a44
17 changed files with 295 additions and 207 deletions
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from termcolor import colored
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
|
@ -67,30 +68,29 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
|||
|
||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
||||
|
||||
objects = [
|
||||
*run_config.models,
|
||||
*run_config.shields,
|
||||
*run_config.memory_banks,
|
||||
*run_config.datasets,
|
||||
*run_config.scoring_fns,
|
||||
*run_config.eval_tasks,
|
||||
]
|
||||
for obj in objects:
|
||||
await dist_registry.register(obj)
|
||||
|
||||
resources = [
|
||||
("models", Api.models),
|
||||
("shields", Api.shields),
|
||||
("memory_banks", Api.memory_banks),
|
||||
("datasets", Api.datasets),
|
||||
("scoring_fns", Api.scoring_functions),
|
||||
("eval_tasks", Api.eval_tasks),
|
||||
("models", Api.models, "register_model", "list_models"),
|
||||
("shields", Api.shields, "register_shield", "list_shields"),
|
||||
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
|
||||
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
||||
(
|
||||
"scoring_fns",
|
||||
Api.scoring_functions,
|
||||
"register_scoring_function",
|
||||
"list_scoring_functions",
|
||||
),
|
||||
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
|
||||
]
|
||||
for rsrc, api in resources:
|
||||
for rsrc, api, register_method, list_method in resources:
|
||||
objects = getattr(run_config, rsrc)
|
||||
if api not in impls:
|
||||
continue
|
||||
|
||||
method = getattr(impls[api], f"list_{api.value}")
|
||||
method = getattr(impls[api], register_method)
|
||||
for obj in objects:
|
||||
await method(**obj.model_dump())
|
||||
|
||||
method = getattr(impls[api], list_method)
|
||||
for obj in await method():
|
||||
print(
|
||||
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue