mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
memory bank registration fixes
This commit is contained in:
parent
099a95b614
commit
3725e74906
8 changed files with 108 additions and 62 deletions
|
|
@ -64,8 +64,8 @@ def configure_api_providers(
|
|||
) -> StackRunConfig:
|
||||
is_nux = len(config.providers) == 0
|
||||
|
||||
apis = set((config.apis or list(build_spec.providers.keys())))
|
||||
config.apis = [a for a in apis if a != "telemetry"]
|
||||
# keep this default so all APIs are served
|
||||
config.apis = []
|
||||
|
||||
if is_nux:
|
||||
print(
|
||||
|
|
@ -79,7 +79,8 @@ def configure_api_providers(
|
|||
|
||||
provider_registry = get_provider_registry()
|
||||
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
|
||||
for api_str in config.apis:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
if api in builtin_apis:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -8,52 +8,56 @@ from typing import Dict, List
|
|||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
pass
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(*args, **kwargs):
|
||||
impl = DistributionInspectImpl()
|
||||
async def get_provider_impl(config, deps):
|
||||
impl = DistributionInspectImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class DistributionInspectImpl(Inspect):
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, config, deps):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = {}
|
||||
all_providers = get_provider_registry()
|
||||
for api, providers in all_providers.items():
|
||||
ret[api.value] = [
|
||||
for api, providers in run_config.providers.items():
|
||||
ret[api] = [
|
||||
ProviderInfo(
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
description="Passthrough" if is_passthrough(p) else "",
|
||||
)
|
||||
for p in providers.values()
|
||||
for p in providers
|
||||
]
|
||||
|
||||
return ret
|
||||
|
||||
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = {}
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
for api, endpoints in all_endpoints.items():
|
||||
providers = run_config.providers.get(api.value, [])
|
||||
ret[api.value] = [
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
providers=[],
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
]
|
||||
|
|
|
|||
|
|
@ -60,8 +60,11 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
providers_with_specs[key] = specs
|
||||
|
||||
apis_to_serve = run_config.apis or set(
|
||||
list(providers_with_specs.keys()) + list(routing_table_apis)
|
||||
list(providers_with_specs.keys())
|
||||
+ [x.value for x in routing_table_apis]
|
||||
+ [x.value for x in router_apis]
|
||||
)
|
||||
print(f"{apis_to_serve=}")
|
||||
|
||||
for info in builtin_automatically_routed_apis():
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
|
|
@ -112,18 +115,22 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
sorted_providers = topological_sort(
|
||||
{k: v.values() for k, v in providers_with_specs.items()}
|
||||
)
|
||||
apis = [x[1].spec.api for x in sorted_providers]
|
||||
sorted_providers.append(
|
||||
(
|
||||
"inspect",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={},
|
||||
config={
|
||||
"run_config": run_config.dict(),
|
||||
},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||
module="llama_stack.distribution.inspect",
|
||||
api_dependencies=apis,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
|
@ -233,6 +240,7 @@ async def instantiate_provider(
|
|||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
impl.__provider_id__ = provider.provider_id
|
||||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -28,14 +28,8 @@ class MemoryRouter(Memory):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
return self.routing_table.list_memory_banks()
|
||||
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
|
||||
return self.routing_table.get_memory_bank(identifier)
|
||||
|
||||
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
|
||||
await self.routing_table.register_memory_bank(bank)
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
||||
await self.routing_table.register_memory_bank(memory_bank)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
|
@ -73,12 +67,6 @@ class InferenceRouter(Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
return self.routing_table.list_models()
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
return self.routing_table.get_model(identifier)
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
await self.routing_table.register_model(model)
|
||||
|
||||
|
|
@ -149,12 +137,6 @@ class SafetyRouter(Safety):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return self.routing_table.list_shields()
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
|
||||
return self.routing_table.get_shield(shield_type)
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
await self.routing_table.register_shield(shield)
|
||||
|
||||
|
|
|
|||
|
|
@ -80,12 +80,21 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
return obj
|
||||
return None
|
||||
|
||||
async def register_object(self, obj: RoutableObject) -> Any:
|
||||
async def register_object(self, obj: RoutableObject):
|
||||
if obj.identifier in self.routing_key_to_object:
|
||||
raise ValueError(f"Object `{obj.identifier}` already registered")
|
||||
print(f"Object `{obj.identifier}` is already registered")
|
||||
return
|
||||
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
if not obj.provider_id:
|
||||
provider_ids = list(self.impls_by_provider_id.keys())
|
||||
if not provider_ids:
|
||||
raise ValueError("No providers found")
|
||||
|
||||
print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
|
||||
obj.provider_id = provider_ids[0]
|
||||
else:
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
await register_object_with_provider(obj, p)
|
||||
|
|
@ -93,6 +102,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
self.routing_key_to_object[obj.identifier] = obj
|
||||
self.registry.append(obj)
|
||||
|
||||
# TODO: persist this to a store
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue