update memory endpoints

Xi Yan 2024-09-22 23:59:53 -07:00
parent 1ac188e1b3
commit f467a0482c
5 changed files with 42 additions and 21 deletions

@@ -9,6 +9,7 @@ from typing import Any
 from pydantic import BaseModel
 from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.apis.safety.safety import BuiltinShield # noqa: F403
 from llama_stack.distribution.distribution import (
     api_providers,
     builtin_automatically_routed_apis,
@@ -54,7 +55,7 @@ def configure_api_providers(
         for inf in builtin_automatically_routed_apis()
     }
-    config.apis_to_serve = [a for a in apis if a != "telemetry"]
+    config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
     apis = [v.value for v in stack_apis()]
     all_providers = api_providers()
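The apis_to_serve change wraps the filtered list in set() so an API listed more than once is only served once. A minimal standalone sketch of the effect, using a hypothetical apis list (in the real code apis comes from the distribution spec):

```python
# Hypothetical input list; in configure_api_providers it is derived from the spec.
apis = ["inference", "safety", "memory", "telemetry", "inference"]

# Old behaviour: the telemetry filter keeps duplicates.
old = [a for a in apis if a != "telemetry"]
assert old == ["inference", "safety", "memory", "inference"]

# New behaviour: set() drops duplicates (at the cost of list ordering).
new = list(set(a for a in apis if a != "telemetry"))
assert sorted(new) == ["inference", "memory", "safety"]
```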
@@ -84,10 +85,31 @@ def configure_api_providers(
         if api_str in router_api2builtin_api:
             # a routing api, we need to infer and assign it a routing_key and put it in the routing_table
-            routing_key = prompt(
-                "> Enter routing key for the {} provider: ".format(api_str),
-            )
-            config.routing_table[]
+            routing_key = "<PLEASE_FILL_ROUTING_KEY>"
+            if api_str == "inference":
+                if hasattr(cfg, "model"):
+                    routing_key = cfg.model
+                else:
+                    routing_key = prompt(
+                        "> Please enter the supported model your provider has for inference: ",
+                        default="Meta-Llama3.1-8B-Instruct",
+                    )
+            if api_str == "safety":
+                # check all supported shields
+                for shield_type in BuiltinShield:
+                    print(shield_type.value)
+            # if api_str == "memory":
+            #     # check all supported memory_banks
             config.routing_table[api_str] = [
                 RoutableProviderConfig(
                     routing_key=routing_key,
                     provider_id=p,
                     config=cfg.dict(),
                 )
             ]
         else:
             config.api_providers[api_str] = GenericProviderConfig(
                 provider_id=p,
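
A self-contained sketch of the routing-key inference introduced above. MockInferenceConfig, MockSafetyConfig, infer_routing_key, and the sample values are hypothetical stand-ins; in the real configure flow, cfg is the provider config collected by the wizard (and the non-inference branches still fall back to the placeholder key or a prompt).

```python
# Standalone sketch of the routing-key inference added in this commit.
# The mock config classes and helper below are hypothetical; the real code
# works on the provider config returned by prompt_for_config().
from dataclasses import dataclass


@dataclass
class MockInferenceConfig:
    # Default model name taken from the prompt default in the diff above.
    model: str = "Meta-Llama3.1-8B-Instruct"


@dataclass
class MockSafetyConfig:
    # No `model` attribute, so the routing key cannot be inferred from it.
    pass


def infer_routing_key(api_str: str, cfg) -> str:
    routing_key = "<PLEASE_FILL_ROUTING_KEY>"
    if api_str == "inference" and hasattr(cfg, "model"):
        # Inference providers already know which model they serve,
        # so reuse it as the routing key instead of prompting the user.
        routing_key = cfg.model
    return routing_key


assert infer_routing_key("inference", MockInferenceConfig()) == "Meta-Llama3.1-8B-Instruct"
assert infer_routing_key("safety", MockSafetyConfig()) == "<PLEASE_FILL_ROUTING_KEY>"
```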