mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
update memory endpoints
This commit is contained in:
parent
1ac188e1b3
commit
f467a0482c
5 changed files with 42 additions and 21 deletions
|
@ -9,6 +9,7 @@ from typing import Any
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.safety.safety import BuiltinShield # noqa: F403
|
||||
from llama_stack.distribution.distribution import (
|
||||
api_providers,
|
||||
builtin_automatically_routed_apis,
|
||||
|
@ -54,7 +55,7 @@ def configure_api_providers(
|
|||
for inf in builtin_automatically_routed_apis()
|
||||
}
|
||||
|
||||
config.apis_to_serve = [a for a in apis if a != "telemetry"]
|
||||
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
|
||||
|
||||
apis = [v.value for v in stack_apis()]
|
||||
all_providers = api_providers()
|
||||
|
@ -84,10 +85,31 @@ def configure_api_providers(
|
|||
|
||||
if api_str in router_api2builtin_api:
|
||||
# This is a routing API: infer a routing_key for its provider and register the provider in the routing_table.
|
||||
routing_key = prompt(
|
||||
"> Enter routing key for the {} provider: ".format(api_str),
|
||||
)
|
||||
config.routing_table[]
|
||||
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
|
||||
if api_str == "inference":
|
||||
if hasattr(cfg, "model"):
|
||||
routing_key = cfg.model
|
||||
else:
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported model your provider has for inference: ",
|
||||
default="Meta-Llama3.1-8B-Instruct",
|
||||
)
|
||||
|
||||
if api_str == "safety":
|
||||
# check all supported shields
|
||||
for shield_type in BuiltinShield:
|
||||
print(shield_type.value)
|
||||
|
||||
# if api_str == "memory":
|
||||
# # check all supported memory_banks
|
||||
|
||||
config.routing_table[api_str] = [
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_id=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
]
|
||||
else:
|
||||
config.api_providers[api_str] = GenericProviderConfig(
|
||||
provider_id=p,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue