From 3725e74906a830dbc121cfc7f235eef86a7e6913 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 6 Oct 2024 22:00:54 -0700 Subject: [PATCH] memory bank registration fixes --- llama_stack/apis/inspect/inspect.py | 4 +- llama_stack/apis/memory_banks/client.py | 62 +++++++++++++++---- llama_stack/cli/stack/build.py | 14 ++--- llama_stack/distribution/configure.py | 7 ++- llama_stack/distribution/inspect.py | 30 +++++---- llama_stack/distribution/resolver.py | 12 +++- llama_stack/distribution/routers/routers.py | 22 +------ .../distribution/routers/routing_tables.py | 19 ++++-- 8 files changed, 108 insertions(+), 62 deletions(-) diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py index ca444098c..a30f39a16 100644 --- a/llama_stack/apis/inspect/inspect.py +++ b/llama_stack/apis/inspect/inspect.py @@ -12,15 +12,15 @@ from pydantic import BaseModel @json_schema_type class ProviderInfo(BaseModel): + provider_id: str provider_type: str - description: str @json_schema_type class RouteInfo(BaseModel): route: str method: str - providers: List[str] + provider_types: List[str] @json_schema_type diff --git a/llama_stack/apis/memory_banks/client.py b/llama_stack/apis/memory_banks/client.py index 78a991374..3b763d1f3 100644 --- a/llama_stack/apis/memory_banks/client.py +++ b/llama_stack/apis/memory_banks/client.py @@ -5,8 +5,9 @@ # the root directory of this source tree. import asyncio +import json -from typing import List, Optional +from typing import Any, Dict, List, Optional import fire import httpx @@ -15,6 +16,25 @@ from termcolor import cprint from .memory_banks import * # noqa: F403 +def deserialize_memory_bank_def(j: Optional[Dict[str, Any]]) -> MemoryBankDef: + if j is None: + return None + + if "type" not in j: + raise ValueError("Memory bank type not specified") + type = j["type"] + if type == MemoryBankType.vector.value: + return VectorMemoryBankDef(**j) + elif type == MemoryBankType.keyvalue.value: + return KeyValueMemoryBankDef(**j) + elif type == MemoryBankType.keyword.value: + return KeywordMemoryBankDef(**j) + elif type == MemoryBankType.graph.value: + return GraphMemoryBankDef(**j) + else: + raise ValueError(f"Unknown memory bank type: {type}") + + class MemoryBanksClient(MemoryBanks): def __init__(self, base_url: str): self.base_url = base_url @@ -25,37 +45,57 @@ class MemoryBanksClient(MemoryBanks): async def shutdown(self) -> None: pass - async def list_available_memory_banks(self) -> List[MemoryBankSpec]: + async def list_memory_banks(self) -> List[MemoryBankDef]: async with httpx.AsyncClient() as client: response = await client.get( f"{self.base_url}/memory_banks/list", headers={"Content-Type": "application/json"}, ) response.raise_for_status() - return [MemoryBankSpec(**x) for x in response.json()] + return [deserialize_memory_bank_def(x) for x in response.json()] - async def get_serving_memory_bank( - self, bank_type: MemoryBankType - ) -> Optional[MemoryBankSpec]: + async def get_memory_bank( + self, + identifier: str, + ) -> Optional[MemoryBankDef]: async with httpx.AsyncClient() as client: response = await client.get( f"{self.base_url}/memory_banks/get", params={ - "bank_type": bank_type.value, + "identifier": identifier, }, headers={"Content-Type": "application/json"}, ) response.raise_for_status() j = response.json() - if j is None: - return None - return MemoryBankSpec(**j) + return deserialize_memory_bank_def(j) + + async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/memory/register_memory_bank", + json={ + "memory_bank": json.loads(memory_bank.json()), + }, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() async def run_main(host: str, port: int, stream: bool): client = MemoryBanksClient(f"http://{host}:{port}") - response = await client.list_available_memory_banks() + await client.register_memory_bank( + VectorMemoryBankDef( + identifier="test_bank", + provider_id="", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ), + ) + + response = await client.list_memory_banks() cprint(f"list_memory_banks response={response}", "green") diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index f6821c8df..f07a0f873 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -241,13 +241,15 @@ class StackBuild(Subcommand): default="conda", ) - cprint(textwrap.dedent( - """ + cprint( + textwrap.dedent( + """ Llama Stack is composed of several APIs working together. Let's select the provider types (implementations) you want to use for these APIs. """, - ), - color="green") + ), + color="green", + ) print("Tip: use to see options for the providers.\n") @@ -257,9 +259,7 @@ class StackBuild(Subcommand): x for x in providers_for_api.keys() if x != "remote" ] api_provider = prompt( - "> Enter provider for API {}: ".format( - api.value - ), + "> Enter provider for API {}: ".format(api.value), completer=WordCompleter(available_providers), complete_while_typing=True, validator=Validator.from_callable( diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index f343c13bb..12f225af2 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -64,8 +64,8 @@ def configure_api_providers( ) -> StackRunConfig: is_nux = len(config.providers) == 0 - apis = set((config.apis or list(build_spec.providers.keys()))) - config.apis = [a for a in apis if a != "telemetry"] + # keep this default so all APIs are served + config.apis = [] if is_nux: print( @@ -79,7 +79,8 @@ def configure_api_providers( provider_registry = get_provider_registry() builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()] - for api_str in config.apis: + apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)] + for api_str in apis_to_serve: api = Api(api_str) if api in builtin_apis: continue diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index 9963fffd8..f5716ef5e 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -8,52 +8,56 @@ from typing import Dict, List from llama_stack.apis.inspect import * # noqa: F403 from pydantic import BaseModel -from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.providers.datatypes import * # noqa: F403 +from llama_stack.distribution.datatypes import * # noqa: F403 class DistributionInspectConfig(BaseModel): - pass + run_config: StackRunConfig -async def get_provider_impl(*args, **kwargs): - impl = DistributionInspectImpl() +async def get_provider_impl(config, deps): + impl = DistributionInspectImpl(config, deps) await impl.initialize() return impl class DistributionInspectImpl(Inspect): - def __init__(self): - pass + def __init__(self, config, deps): + self.config = config + self.deps = deps async def initialize(self) -> None: pass async def list_providers(self) -> Dict[str, List[ProviderInfo]]: + run_config = self.config.run_config + ret = {} - all_providers = get_provider_registry() - for api, providers in all_providers.items(): - ret[api.value] = [ + for api, providers in run_config.providers.items(): + ret[api] = [ ProviderInfo( + provider_id=p.provider_id, provider_type=p.provider_type, - description="Passthrough" if is_passthrough(p) else "", ) - for p in providers.values() + for p in providers ] return ret async def list_routes(self) -> Dict[str, List[RouteInfo]]: + run_config = self.config.run_config + ret = {} all_endpoints = get_all_api_endpoints() - for api, endpoints in all_endpoints.items(): + providers = run_config.providers.get(api.value, []) ret[api.value] = [ RouteInfo( route=e.route, method=e.method, - providers=[], + provider_types=[p.provider_type for p in providers], ) for e in endpoints ] diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 0adb42915..0fc9bd72e 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -60,8 +60,11 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An providers_with_specs[key] = specs apis_to_serve = run_config.apis or set( - list(providers_with_specs.keys()) + list(routing_table_apis) + list(providers_with_specs.keys()) + + [x.value for x in routing_table_apis] + + [x.value for x in router_apis] ) + print(f"{apis_to_serve=}") for info in builtin_automatically_routed_apis(): if info.router_api.value not in apis_to_serve: @@ -112,18 +115,22 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An sorted_providers = topological_sort( {k: v.values() for k, v in providers_with_specs.items()} ) + apis = [x[1].spec.api for x in sorted_providers] sorted_providers.append( ( "inspect", ProviderWithSpec( provider_id="__builtin__", provider_type="__builtin__", - config={}, + config={ + "run_config": run_config.dict(), + }, spec=InlineProviderSpec( api=Api.inspect, provider_type="__builtin__", config_class="llama_stack.distribution.inspect.DistributionInspectConfig", module="llama_stack.distribution.inspect", + api_dependencies=apis, ), ), ) @@ -233,6 +240,7 @@ async def instantiate_provider( fn = getattr(module, method) impl = await fn(*args) + impl.__provider_id__ = provider.provider_id impl.__provider_spec__ = provider_spec impl.__provider_config__ = config return impl diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index c56b33f21..361cee3f3 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -28,14 +28,8 @@ class MemoryRouter(Memory): async def shutdown(self) -> None: pass - async def list_memory_banks(self) -> List[MemoryBankDef]: - return self.routing_table.list_memory_banks() - - async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: - return self.routing_table.get_memory_bank(identifier) - - async def register_memory_bank(self, bank: MemoryBankDef) -> None: - await self.routing_table.register_memory_bank(bank) + async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: + await self.routing_table.register_memory_bank(memory_bank) async def insert_documents( self, @@ -73,12 +67,6 @@ class InferenceRouter(Inference): async def shutdown(self) -> None: pass - async def list_models(self) -> List[ModelDef]: - return self.routing_table.list_models() - - async def get_model(self, identifier: str) -> Optional[ModelDef]: - return self.routing_table.get_model(identifier) - async def register_model(self, model: ModelDef) -> None: await self.routing_table.register_model(model) @@ -149,12 +137,6 @@ class SafetyRouter(Safety): async def shutdown(self) -> None: pass - async def list_shields(self) -> List[ShieldDef]: - return self.routing_table.list_shields() - - async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: - return self.routing_table.get_shield(shield_type) - async def register_shield(self, shield: ShieldDef) -> None: await self.routing_table.register_shield(shield) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index ef38b6391..3d89aa19f 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -80,12 +80,21 @@ class CommonRoutingTableImpl(RoutingTable): return obj return None - async def register_object(self, obj: RoutableObject) -> Any: + async def register_object(self, obj: RoutableObject): if obj.identifier in self.routing_key_to_object: - raise ValueError(f"Object `{obj.identifier}` already registered") + print(f"Object `{obj.identifier}` is already registered") + return - if obj.provider_id not in self.impls_by_provider_id: - raise ValueError(f"Provider `{obj.provider_id}` not found") + if not obj.provider_id: + provider_ids = list(self.impls_by_provider_id.keys()) + if not provider_ids: + raise ValueError("No providers found") + + print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}") + obj.provider_id = provider_ids[0] + else: + if obj.provider_id not in self.impls_by_provider_id: + raise ValueError(f"Provider `{obj.provider_id}` not found") p = self.impls_by_provider_id[obj.provider_id] await register_object_with_provider(obj, p) @@ -93,6 +102,8 @@ class CommonRoutingTableImpl(RoutingTable): self.routing_key_to_object[obj.identifier] = obj self.registry.append(obj) + # TODO: persist this to a store + class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> List[ModelDef]: