memory bank registration fixes

This commit is contained in:
Ashwin Bharambe 2024-10-06 22:00:54 -07:00 committed by Ashwin Bharambe
parent 099a95b614
commit 3725e74906
8 changed files with 108 additions and 62 deletions

View file

@ -12,15 +12,15 @@ from pydantic import BaseModel
@json_schema_type
class ProviderInfo(BaseModel):
provider_id: str
provider_type: str
description: str
@json_schema_type
class RouteInfo(BaseModel):
route: str
method: str
providers: List[str]
provider_types: List[str]
@json_schema_type

View file

@ -5,8 +5,9 @@
# the root directory of this source tree.
import asyncio
import json
from typing import List, Optional
from typing import Any, Dict, List, Optional
import fire
import httpx
@ -15,6 +16,25 @@ from termcolor import cprint
from .memory_banks import * # noqa: F403
def deserialize_memory_bank_def(j: Optional[Dict[str, Any]]) -> MemoryBankDef:
if j is None:
return None
if "type" not in j:
raise ValueError("Memory bank type not specified")
type = j["type"]
if type == MemoryBankType.vector.value:
return VectorMemoryBankDef(**j)
elif type == MemoryBankType.keyvalue.value:
return KeyValueMemoryBankDef(**j)
elif type == MemoryBankType.keyword.value:
return KeywordMemoryBankDef(**j)
elif type == MemoryBankType.graph.value:
return GraphMemoryBankDef(**j)
else:
raise ValueError(f"Unknown memory bank type: {type}")
class MemoryBanksClient(MemoryBanks):
def __init__(self, base_url: str):
self.base_url = base_url
@ -25,37 +45,57 @@ class MemoryBanksClient(MemoryBanks):
async def shutdown(self) -> None:
pass
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
async def list_memory_banks(self) -> List[MemoryBankDef]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/list",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return [MemoryBankSpec(**x) for x in response.json()]
return [deserialize_memory_bank_def(x) for x in response.json()]
async def get_serving_memory_bank(
self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]:
async def get_memory_bank(
self,
identifier: str,
) -> Optional[MemoryBankDef]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/get",
params={
"bank_type": bank_type.value,
"identifier": identifier,
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
j = response.json()
if j is None:
return None
return MemoryBankSpec(**j)
return deserialize_memory_bank_def(j)
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/memory/register_memory_bank",
json={
"memory_bank": json.loads(memory_bank.json()),
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
async def run_main(host: str, port: int, stream: bool):
client = MemoryBanksClient(f"http://{host}:{port}")
response = await client.list_available_memory_banks()
await client.register_memory_bank(
VectorMemoryBankDef(
identifier="test_bank",
provider_id="",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")

View file

@ -241,13 +241,15 @@ class StackBuild(Subcommand):
default="conda",
)
cprint(textwrap.dedent(
cprint(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs.
""",
),
color="green")
color="green",
)
print("Tip: use <TAB> to see options for the providers.\n")
@ -257,9 +259,7 @@ class StackBuild(Subcommand):
x for x in providers_for_api.keys() if x != "remote"
]
api_provider = prompt(
"> Enter provider for API {}: ".format(
api.value
),
"> Enter provider for API {}: ".format(api.value),
completer=WordCompleter(available_providers),
complete_while_typing=True,
validator=Validator.from_callable(

View file

@ -64,8 +64,8 @@ def configure_api_providers(
) -> StackRunConfig:
is_nux = len(config.providers) == 0
apis = set((config.apis or list(build_spec.providers.keys())))
config.apis = [a for a in apis if a != "telemetry"]
# keep this default so all APIs are served
config.apis = []
if is_nux:
print(
@ -79,7 +79,8 @@ def configure_api_providers(
provider_registry = get_provider_registry()
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
for api_str in config.apis:
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
for api_str in apis_to_serve:
api = Api(api_str)
if api in builtin_apis:
continue

View file

@ -8,52 +8,56 @@ from typing import Dict, List
from llama_stack.apis.inspect import * # noqa: F403
from pydantic import BaseModel
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
class DistributionInspectConfig(BaseModel):
pass
run_config: StackRunConfig
async def get_provider_impl(*args, **kwargs):
impl = DistributionInspectImpl()
async def get_provider_impl(config, deps):
impl = DistributionInspectImpl(config, deps)
await impl.initialize()
return impl
class DistributionInspectImpl(Inspect):
def __init__(self):
pass
def __init__(self, config, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
run_config = self.config.run_config
ret = {}
all_providers = get_provider_registry()
for api, providers in all_providers.items():
ret[api.value] = [
for api, providers in run_config.providers.items():
ret[api] = [
ProviderInfo(
provider_id=p.provider_id,
provider_type=p.provider_type,
description="Passthrough" if is_passthrough(p) else "",
)
for p in providers.values()
for p in providers
]
return ret
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
run_config = self.config.run_config
ret = {}
all_endpoints = get_all_api_endpoints()
for api, endpoints in all_endpoints.items():
providers = run_config.providers.get(api.value, [])
ret[api.value] = [
RouteInfo(
route=e.route,
method=e.method,
providers=[],
provider_types=[p.provider_type for p in providers],
)
for e in endpoints
]

View file

@ -60,8 +60,11 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
providers_with_specs[key] = specs
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + list(routing_table_apis)
list(providers_with_specs.keys())
+ [x.value for x in routing_table_apis]
+ [x.value for x in router_apis]
)
print(f"{apis_to_serve=}")
for info in builtin_automatically_routed_apis():
if info.router_api.value not in apis_to_serve:
@ -112,18 +115,22 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
sorted_providers = topological_sort(
{k: v.values() for k, v in providers_with_specs.items()}
)
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
(
"inspect",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={},
config={
"run_config": run_config.dict(),
},
spec=InlineProviderSpec(
api=Api.inspect,
provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect",
api_dependencies=apis,
),
),
)
@ -233,6 +240,7 @@ async def instantiate_provider(
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_id__ = provider.provider_id
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl

View file

@ -28,14 +28,8 @@ class MemoryRouter(Memory):
async def shutdown(self) -> None:
pass
async def list_memory_banks(self) -> List[MemoryBankDef]:
return self.routing_table.list_memory_banks()
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
return self.routing_table.get_memory_bank(identifier)
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
await self.routing_table.register_memory_bank(bank)
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
await self.routing_table.register_memory_bank(memory_bank)
async def insert_documents(
self,
@ -73,12 +67,6 @@ class InferenceRouter(Inference):
async def shutdown(self) -> None:
pass
async def list_models(self) -> List[ModelDef]:
return self.routing_table.list_models()
async def get_model(self, identifier: str) -> Optional[ModelDef]:
return self.routing_table.get_model(identifier)
async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)
@ -149,12 +137,6 @@ class SafetyRouter(Safety):
async def shutdown(self) -> None:
pass
async def list_shields(self) -> List[ShieldDef]:
return self.routing_table.list_shields()
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
return self.routing_table.get_shield(shield_type)
async def register_shield(self, shield: ShieldDef) -> None:
await self.routing_table.register_shield(shield)

View file

@ -80,10 +80,19 @@ class CommonRoutingTableImpl(RoutingTable):
return obj
return None
async def register_object(self, obj: RoutableObject) -> Any:
async def register_object(self, obj: RoutableObject):
if obj.identifier in self.routing_key_to_object:
raise ValueError(f"Object `{obj.identifier}` already registered")
print(f"Object `{obj.identifier}` is already registered")
return
if not obj.provider_id:
provider_ids = list(self.impls_by_provider_id.keys())
if not provider_ids:
raise ValueError("No providers found")
print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
obj.provider_id = provider_ids[0]
else:
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
@ -93,6 +102,8 @@ class CommonRoutingTableImpl(RoutingTable):
self.routing_key_to_object[obj.identifier] = obj
self.registry.append(obj)
# TODO: persist this to a store
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelDef]: