memory bank registration fixes

This commit is contained in:
Ashwin Bharambe 2024-10-06 22:00:54 -07:00 committed by Ashwin Bharambe
parent 099a95b614
commit 3725e74906
8 changed files with 108 additions and 62 deletions

View file

@ -12,15 +12,15 @@ from pydantic import BaseModel
@json_schema_type @json_schema_type
class ProviderInfo(BaseModel): class ProviderInfo(BaseModel):
provider_id: str
provider_type: str provider_type: str
description: str
@json_schema_type @json_schema_type
class RouteInfo(BaseModel): class RouteInfo(BaseModel):
route: str route: str
method: str method: str
providers: List[str] provider_types: List[str]
@json_schema_type @json_schema_type

View file

@ -5,8 +5,9 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import json
from typing import List, Optional from typing import Any, Dict, List, Optional
import fire import fire
import httpx import httpx
@ -15,6 +16,25 @@ from termcolor import cprint
from .memory_banks import * # noqa: F403 from .memory_banks import * # noqa: F403
def deserialize_memory_bank_def(
    j: Optional[Dict[str, Any]],
) -> "Optional[MemoryBankDef]":
    """Build the concrete memory bank definition described by a JSON dict.

    The ``"type"`` entry of ``j`` selects which definition class is
    instantiated; the whole dict is then passed to that class as keyword
    arguments.

    Args:
        j: Decoded JSON object for a memory bank, or ``None``.

    Returns:
        ``None`` when ``j`` is ``None``; otherwise an instance of
        ``VectorMemoryBankDef``, ``KeyValueMemoryBankDef``,
        ``KeywordMemoryBankDef`` or ``GraphMemoryBankDef``.

    Raises:
        ValueError: If ``j`` has no ``"type"`` key, or its value matches none
            of the known ``MemoryBankType`` members.
    """
    if j is None:
        return None

    if "type" not in j:
        raise ValueError("Memory bank type not specified")

    # Named ``bank_type`` rather than ``type`` so the builtin is not shadowed.
    bank_type = j["type"]

    # Ordered (enum member, definition class) pairs; checked with ``==`` in
    # this order so matching behaves like an if/elif chain, including for
    # unhashable ``bank_type`` values.
    # NOTE(review): these names presumably come from the star import of
    # ``.memory_banks`` -- confirm.
    candidates = (
        (MemoryBankType.vector, VectorMemoryBankDef),
        (MemoryBankType.keyvalue, KeyValueMemoryBankDef),
        (MemoryBankType.keyword, KeywordMemoryBankDef),
        (MemoryBankType.graph, GraphMemoryBankDef),
    )
    for member, def_cls in candidates:
        if bank_type == member.value:
            return def_cls(**j)

    raise ValueError(f"Unknown memory bank type: {bank_type}")
class MemoryBanksClient(MemoryBanks): class MemoryBanksClient(MemoryBanks):
def __init__(self, base_url: str): def __init__(self, base_url: str):
self.base_url = base_url self.base_url = base_url
@ -25,37 +45,57 @@ class MemoryBanksClient(MemoryBanks):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def list_available_memory_banks(self) -> List[MemoryBankSpec]: async def list_memory_banks(self) -> List[MemoryBankDef]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/memory_banks/list", f"{self.base_url}/memory_banks/list",
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
return [MemoryBankSpec(**x) for x in response.json()] return [deserialize_memory_bank_def(x) for x in response.json()]
async def get_serving_memory_bank( async def get_memory_bank(
self, bank_type: MemoryBankType self,
) -> Optional[MemoryBankSpec]: identifier: str,
) -> Optional[MemoryBankDef]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/memory_banks/get", f"{self.base_url}/memory_banks/get",
params={ params={
"bank_type": bank_type.value, "identifier": identifier,
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
j = response.json() j = response.json()
if j is None: return deserialize_memory_bank_def(j)
return None
return MemoryBankSpec(**j) async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/memory/register_memory_bank",
json={
"memory_bank": json.loads(memory_bank.json()),
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
async def run_main(host: str, port: int, stream: bool): async def run_main(host: str, port: int, stream: bool):
client = MemoryBanksClient(f"http://{host}:{port}") client = MemoryBanksClient(f"http://{host}:{port}")
response = await client.list_available_memory_banks() await client.register_memory_bank(
VectorMemoryBankDef(
identifier="test_bank",
provider_id="",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green") cprint(f"list_memory_banks response={response}", "green")

View file

@ -241,13 +241,15 @@ class StackBuild(Subcommand):
default="conda", default="conda",
) )
cprint(textwrap.dedent( cprint(
textwrap.dedent(
""" """
Llama Stack is composed of several APIs working together. Let's select Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs. the provider types (implementations) you want to use for these APIs.
""", """,
), ),
color="green") color="green",
)
print("Tip: use <TAB> to see options for the providers.\n") print("Tip: use <TAB> to see options for the providers.\n")
@ -257,9 +259,7 @@ class StackBuild(Subcommand):
x for x in providers_for_api.keys() if x != "remote" x for x in providers_for_api.keys() if x != "remote"
] ]
api_provider = prompt( api_provider = prompt(
"> Enter provider for API {}: ".format( "> Enter provider for API {}: ".format(api.value),
api.value
),
completer=WordCompleter(available_providers), completer=WordCompleter(available_providers),
complete_while_typing=True, complete_while_typing=True,
validator=Validator.from_callable( validator=Validator.from_callable(

View file

@ -64,8 +64,8 @@ def configure_api_providers(
) -> StackRunConfig: ) -> StackRunConfig:
is_nux = len(config.providers) == 0 is_nux = len(config.providers) == 0
apis = set((config.apis or list(build_spec.providers.keys()))) # keep this default so all APIs are served
config.apis = [a for a in apis if a != "telemetry"] config.apis = []
if is_nux: if is_nux:
print( print(
@ -79,7 +79,8 @@ def configure_api_providers(
provider_registry = get_provider_registry() provider_registry = get_provider_registry()
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()] builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
for api_str in config.apis: apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
for api_str in apis_to_serve:
api = Api(api_str) api = Api(api_str)
if api in builtin_apis: if api in builtin_apis:
continue continue

View file

@ -8,52 +8,56 @@ from typing import Dict, List
from llama_stack.apis.inspect import * # noqa: F403 from llama_stack.apis.inspect import * # noqa: F403
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
class DistributionInspectConfig(BaseModel): class DistributionInspectConfig(BaseModel):
pass run_config: StackRunConfig
async def get_provider_impl(*args, **kwargs): async def get_provider_impl(config, deps):
impl = DistributionInspectImpl() impl = DistributionInspectImpl(config, deps)
await impl.initialize() await impl.initialize()
return impl return impl
class DistributionInspectImpl(Inspect): class DistributionInspectImpl(Inspect):
def __init__(self): def __init__(self, config, deps):
pass self.config = config
self.deps = deps
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
async def list_providers(self) -> Dict[str, List[ProviderInfo]]: async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
run_config = self.config.run_config
ret = {} ret = {}
all_providers = get_provider_registry() for api, providers in run_config.providers.items():
for api, providers in all_providers.items(): ret[api] = [
ret[api.value] = [
ProviderInfo( ProviderInfo(
provider_id=p.provider_id,
provider_type=p.provider_type, provider_type=p.provider_type,
description="Passthrough" if is_passthrough(p) else "",
) )
for p in providers.values() for p in providers
] ]
return ret return ret
async def list_routes(self) -> Dict[str, List[RouteInfo]]: async def list_routes(self) -> Dict[str, List[RouteInfo]]:
run_config = self.config.run_config
ret = {} ret = {}
all_endpoints = get_all_api_endpoints() all_endpoints = get_all_api_endpoints()
for api, endpoints in all_endpoints.items(): for api, endpoints in all_endpoints.items():
providers = run_config.providers.get(api.value, [])
ret[api.value] = [ ret[api.value] = [
RouteInfo( RouteInfo(
route=e.route, route=e.route,
method=e.method, method=e.method,
providers=[], provider_types=[p.provider_type for p in providers],
) )
for e in endpoints for e in endpoints
] ]

View file

@ -60,8 +60,11 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
providers_with_specs[key] = specs providers_with_specs[key] = specs
apis_to_serve = run_config.apis or set( apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + list(routing_table_apis) list(providers_with_specs.keys())
+ [x.value for x in routing_table_apis]
+ [x.value for x in router_apis]
) )
print(f"{apis_to_serve=}")
for info in builtin_automatically_routed_apis(): for info in builtin_automatically_routed_apis():
if info.router_api.value not in apis_to_serve: if info.router_api.value not in apis_to_serve:
@ -112,18 +115,22 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
sorted_providers = topological_sort( sorted_providers = topological_sort(
{k: v.values() for k, v in providers_with_specs.items()} {k: v.values() for k, v in providers_with_specs.items()}
) )
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append( sorted_providers.append(
( (
"inspect", "inspect",
ProviderWithSpec( ProviderWithSpec(
provider_id="__builtin__", provider_id="__builtin__",
provider_type="__builtin__", provider_type="__builtin__",
config={}, config={
"run_config": run_config.dict(),
},
spec=InlineProviderSpec( spec=InlineProviderSpec(
api=Api.inspect, api=Api.inspect,
provider_type="__builtin__", provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig", config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect", module="llama_stack.distribution.inspect",
api_dependencies=apis,
), ),
), ),
) )
@ -233,6 +240,7 @@ async def instantiate_provider(
fn = getattr(module, method) fn = getattr(module, method)
impl = await fn(*args) impl = await fn(*args)
impl.__provider_id__ = provider.provider_id
impl.__provider_spec__ = provider_spec impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config impl.__provider_config__ = config
return impl return impl

View file

@ -28,14 +28,8 @@ class MemoryRouter(Memory):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def list_memory_banks(self) -> List[MemoryBankDef]: async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
return self.routing_table.list_memory_banks() await self.routing_table.register_memory_bank(memory_bank)
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
return self.routing_table.get_memory_bank(identifier)
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
await self.routing_table.register_memory_bank(bank)
async def insert_documents( async def insert_documents(
self, self,
@ -73,12 +67,6 @@ class InferenceRouter(Inference):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def list_models(self) -> List[ModelDef]:
return self.routing_table.list_models()
async def get_model(self, identifier: str) -> Optional[ModelDef]:
return self.routing_table.get_model(identifier)
async def register_model(self, model: ModelDef) -> None: async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model) await self.routing_table.register_model(model)
@ -149,12 +137,6 @@ class SafetyRouter(Safety):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def list_shields(self) -> List[ShieldDef]:
return self.routing_table.list_shields()
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
return self.routing_table.get_shield(shield_type)
async def register_shield(self, shield: ShieldDef) -> None: async def register_shield(self, shield: ShieldDef) -> None:
await self.routing_table.register_shield(shield) await self.routing_table.register_shield(shield)

View file

@ -80,10 +80,19 @@ class CommonRoutingTableImpl(RoutingTable):
return obj return obj
return None return None
async def register_object(self, obj: RoutableObject) -> Any: async def register_object(self, obj: RoutableObject):
if obj.identifier in self.routing_key_to_object: if obj.identifier in self.routing_key_to_object:
raise ValueError(f"Object `{obj.identifier}` already registered") print(f"Object `{obj.identifier}` is already registered")
return
if not obj.provider_id:
provider_ids = list(self.impls_by_provider_id.keys())
if not provider_ids:
raise ValueError("No providers found")
print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
obj.provider_id = provider_ids[0]
else:
if obj.provider_id not in self.impls_by_provider_id: if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found") raise ValueError(f"Provider `{obj.provider_id}` not found")
@ -93,6 +102,8 @@ class CommonRoutingTableImpl(RoutingTable):
self.routing_key_to_object[obj.identifier] = obj self.routing_key_to_object[obj.identifier] = obj
self.registry.append(obj) self.registry.append(obj)
# TODO: persist this to a store
class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelDef]: async def list_models(self) -> List[ModelDef]: