Another round of simplification and clarity for models/shields/memory_banks stuff

This commit is contained in:
Ashwin Bharambe 2024-10-09 19:19:26 -07:00
parent 73a0a34e39
commit b55034c0de
27 changed files with 454 additions and 444 deletions

View file

@ -18,9 +18,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
async def resolve_impls_for_test(
api: Api,
models: List[ModelDef] = None,
memory_banks: List[MemoryBankDef] = None,
shields: List[ShieldDef] = None,
):
if "PROVIDER_CONFIG" not in os.environ:
raise ValueError(
@ -47,45 +44,11 @@ async def resolve_impls_for_test(
provider_id = provider["provider_id"]
print(f"No provider ID specified, picking first `{provider_id}`")
models = models or []
shields = shields or []
memory_banks = memory_banks or []
models = [
ModelDef(
**{
**m.dict(),
"provider_id": provider_id,
}
)
for m in models
]
shields = [
ShieldDef(
**{
**s.dict(),
"provider_id": provider_id,
}
)
for s in shields
]
memory_banks = [
MemoryBankDef(
**{
**m.dict(),
"provider_id": provider_id,
}
)
for m in memory_banks
]
run_config = dict(
built_at=datetime.now(),
image_name="test-fixture",
apis=[api],
providers={api.value: [Provider(**provider)]},
models=models,
memory_banks=memory_banks,
shields=shields,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls_with_routing(run_config)