mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
memory bank registration fixes
This commit is contained in:
parent
099a95b614
commit
3725e74906
8 changed files with 108 additions and 62 deletions
|
@ -12,15 +12,15 @@ from pydantic import BaseModel
|
|||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
description: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouteInfo(BaseModel):
|
||||
route: str
|
||||
method: str
|
||||
providers: List[str]
|
||||
provider_types: List[str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -5,8 +5,9 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
@ -15,6 +16,25 @@ from termcolor import cprint
|
|||
from .memory_banks import * # noqa: F403
|
||||
|
||||
|
||||
def deserialize_memory_bank_def(j: Optional[Dict[str, Any]]) -> MemoryBankDef:
|
||||
if j is None:
|
||||
return None
|
||||
|
||||
if "type" not in j:
|
||||
raise ValueError("Memory bank type not specified")
|
||||
type = j["type"]
|
||||
if type == MemoryBankType.vector.value:
|
||||
return VectorMemoryBankDef(**j)
|
||||
elif type == MemoryBankType.keyvalue.value:
|
||||
return KeyValueMemoryBankDef(**j)
|
||||
elif type == MemoryBankType.keyword.value:
|
||||
return KeywordMemoryBankDef(**j)
|
||||
elif type == MemoryBankType.graph.value:
|
||||
return GraphMemoryBankDef(**j)
|
||||
else:
|
||||
raise ValueError(f"Unknown memory bank type: {type}")
|
||||
|
||||
|
||||
class MemoryBanksClient(MemoryBanks):
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
|
@ -25,37 +45,57 @@ class MemoryBanksClient(MemoryBanks):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/memory_banks/list",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [MemoryBankSpec(**x) for x in response.json()]
|
||||
return [deserialize_memory_bank_def(x) for x in response.json()]
|
||||
|
||||
async def get_serving_memory_bank(
|
||||
self, bank_type: MemoryBankType
|
||||
) -> Optional[MemoryBankSpec]:
|
||||
async def get_memory_bank(
|
||||
self,
|
||||
identifier: str,
|
||||
) -> Optional[MemoryBankDef]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/memory_banks/get",
|
||||
params={
|
||||
"bank_type": bank_type.value,
|
||||
"identifier": identifier,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
j = response.json()
|
||||
if j is None:
|
||||
return None
|
||||
return MemoryBankSpec(**j)
|
||||
return deserialize_memory_bank_def(j)
|
||||
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/memory/register_memory_bank",
|
||||
json={
|
||||
"memory_bank": json.loads(memory_bank.json()),
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
client = MemoryBanksClient(f"http://{host}:{port}")
|
||||
|
||||
response = await client.list_available_memory_banks()
|
||||
await client.register_memory_bank(
|
||||
VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
provider_id="",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
)
|
||||
|
||||
response = await client.list_memory_banks()
|
||||
cprint(f"list_memory_banks response={response}", "green")
|
||||
|
||||
|
||||
|
|
|
@ -241,13 +241,15 @@ class StackBuild(Subcommand):
|
|||
default="conda",
|
||||
)
|
||||
|
||||
cprint(textwrap.dedent(
|
||||
cprint(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
Llama Stack is composed of several APIs working together. Let's select
|
||||
the provider types (implementations) you want to use for these APIs.
|
||||
""",
|
||||
),
|
||||
color="green")
|
||||
color="green",
|
||||
)
|
||||
|
||||
print("Tip: use <TAB> to see options for the providers.\n")
|
||||
|
||||
|
@ -257,9 +259,7 @@ class StackBuild(Subcommand):
|
|||
x for x in providers_for_api.keys() if x != "remote"
|
||||
]
|
||||
api_provider = prompt(
|
||||
"> Enter provider for API {}: ".format(
|
||||
api.value
|
||||
),
|
||||
"> Enter provider for API {}: ".format(api.value),
|
||||
completer=WordCompleter(available_providers),
|
||||
complete_while_typing=True,
|
||||
validator=Validator.from_callable(
|
||||
|
|
|
@ -64,8 +64,8 @@ def configure_api_providers(
|
|||
) -> StackRunConfig:
|
||||
is_nux = len(config.providers) == 0
|
||||
|
||||
apis = set((config.apis or list(build_spec.providers.keys())))
|
||||
config.apis = [a for a in apis if a != "telemetry"]
|
||||
# keep this default so all APIs are served
|
||||
config.apis = []
|
||||
|
||||
if is_nux:
|
||||
print(
|
||||
|
@ -79,7 +79,8 @@ def configure_api_providers(
|
|||
|
||||
provider_registry = get_provider_registry()
|
||||
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
|
||||
for api_str in config.apis:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
if api in builtin_apis:
|
||||
continue
|
||||
|
|
|
@ -8,52 +8,56 @@ from typing import Dict, List
|
|||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
pass
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(*args, **kwargs):
|
||||
impl = DistributionInspectImpl()
|
||||
async def get_provider_impl(config, deps):
|
||||
impl = DistributionInspectImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class DistributionInspectImpl(Inspect):
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, config, deps):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = {}
|
||||
all_providers = get_provider_registry()
|
||||
for api, providers in all_providers.items():
|
||||
ret[api.value] = [
|
||||
for api, providers in run_config.providers.items():
|
||||
ret[api] = [
|
||||
ProviderInfo(
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
description="Passthrough" if is_passthrough(p) else "",
|
||||
)
|
||||
for p in providers.values()
|
||||
for p in providers
|
||||
]
|
||||
|
||||
return ret
|
||||
|
||||
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = {}
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
for api, endpoints in all_endpoints.items():
|
||||
providers = run_config.providers.get(api.value, [])
|
||||
ret[api.value] = [
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
providers=[],
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
]
|
||||
|
|
|
@ -60,8 +60,11 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
providers_with_specs[key] = specs
|
||||
|
||||
apis_to_serve = run_config.apis or set(
|
||||
list(providers_with_specs.keys()) + list(routing_table_apis)
|
||||
list(providers_with_specs.keys())
|
||||
+ [x.value for x in routing_table_apis]
|
||||
+ [x.value for x in router_apis]
|
||||
)
|
||||
print(f"{apis_to_serve=}")
|
||||
|
||||
for info in builtin_automatically_routed_apis():
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
|
@ -112,18 +115,22 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
sorted_providers = topological_sort(
|
||||
{k: v.values() for k, v in providers_with_specs.items()}
|
||||
)
|
||||
apis = [x[1].spec.api for x in sorted_providers]
|
||||
sorted_providers.append(
|
||||
(
|
||||
"inspect",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={},
|
||||
config={
|
||||
"run_config": run_config.dict(),
|
||||
},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||
module="llama_stack.distribution.inspect",
|
||||
api_dependencies=apis,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -233,6 +240,7 @@ async def instantiate_provider(
|
|||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
impl.__provider_id__ = provider.provider_id
|
||||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
return impl
|
||||
|
|
|
@ -28,14 +28,8 @@ class MemoryRouter(Memory):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
return self.routing_table.list_memory_banks()
|
||||
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
|
||||
return self.routing_table.get_memory_bank(identifier)
|
||||
|
||||
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
|
||||
await self.routing_table.register_memory_bank(bank)
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
||||
await self.routing_table.register_memory_bank(memory_bank)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
@ -73,12 +67,6 @@ class InferenceRouter(Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
return self.routing_table.list_models()
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
return self.routing_table.get_model(identifier)
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
await self.routing_table.register_model(model)
|
||||
|
||||
|
@ -149,12 +137,6 @@ class SafetyRouter(Safety):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return self.routing_table.list_shields()
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
|
||||
return self.routing_table.get_shield(shield_type)
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
await self.routing_table.register_shield(shield)
|
||||
|
||||
|
|
|
@ -80,10 +80,19 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
return obj
|
||||
return None
|
||||
|
||||
async def register_object(self, obj: RoutableObject) -> Any:
|
||||
async def register_object(self, obj: RoutableObject):
|
||||
if obj.identifier in self.routing_key_to_object:
|
||||
raise ValueError(f"Object `{obj.identifier}` already registered")
|
||||
print(f"Object `{obj.identifier}` is already registered")
|
||||
return
|
||||
|
||||
if not obj.provider_id:
|
||||
provider_ids = list(self.impls_by_provider_id.keys())
|
||||
if not provider_ids:
|
||||
raise ValueError("No providers found")
|
||||
|
||||
print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
|
||||
obj.provider_id = provider_ids[0]
|
||||
else:
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
|
@ -93,6 +102,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
self.routing_key_to_object[obj.identifier] = obj
|
||||
self.registry.append(obj)
|
||||
|
||||
# TODO: persist this to a store
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue