From f467a0482c290d0444c41fabdc47d512732a0ec5 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sun, 22 Sep 2024 23:59:53 -0700 Subject: [PATCH] update memory endpoints --- llama_stack/apis/memory/client.py | 8 ++--- llama_stack/apis/memory/memory.py | 18 +++++------ llama_stack/apis/memory_banks/memory_banks.py | 4 +-- llama_stack/distribution/configure.py | 32 ++++++++++++++++--- .../distribution/routers/routing_tables.py | 1 - 5 files changed, 42 insertions(+), 21 deletions(-) diff --git a/llama_stack/apis/memory/client.py b/llama_stack/apis/memory/client.py index 0cddf0d0e..b4bfcb34d 100644 --- a/llama_stack/apis/memory/client.py +++ b/llama_stack/apis/memory/client.py @@ -38,7 +38,7 @@ class MemoryClient(Memory): async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: async with httpx.AsyncClient() as client: r = await client.get( - f"{self.base_url}/memory_banks/get", + f"{self.base_url}/memory/get", params={ "bank_id": bank_id, }, @@ -59,7 +59,7 @@ class MemoryClient(Memory): ) -> MemoryBank: async with httpx.AsyncClient() as client: r = await client.post( - f"{self.base_url}/memory_banks/create", + f"{self.base_url}/memory/create", json={ "name": name, "config": config.dict(), @@ -81,7 +81,7 @@ class MemoryClient(Memory): ) -> None: async with httpx.AsyncClient() as client: r = await client.post( - f"{self.base_url}/memory_bank/insert", + f"{self.base_url}/memory/insert", json={ "bank_id": bank_id, "documents": [d.dict() for d in documents], @@ -99,7 +99,7 @@ class MemoryClient(Memory): ) -> QueryDocumentsResponse: async with httpx.AsyncClient() as client: r = await client.post( - f"{self.base_url}/memory_bank/query", + f"{self.base_url}/memory/query", json={ "bank_id": bank_id, "query": query, diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py index a26ff67ea..261dd93ee 100644 --- a/llama_stack/apis/memory/memory.py +++ b/llama_stack/apis/memory/memory.py @@ -96,7 +96,7 @@ class MemoryBank(BaseModel): class Memory(Protocol): - @webmethod(route="/memory_banks/create") + @webmethod(route="/memory/create") async def create_memory_bank( self, name: str, @@ -104,13 +104,13 @@ class Memory(Protocol): url: Optional[URL] = None, ) -> MemoryBank: ... - @webmethod(route="/memory_banks/list", method="GET") + @webmethod(route="/memory/list", method="GET") async def list_memory_banks(self) -> List[MemoryBank]: ... - @webmethod(route="/memory_banks/get", method="GET") + @webmethod(route="/memory/get", method="GET") async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ... - @webmethod(route="/memory_banks/drop", method="DELETE") + @webmethod(route="/memory/drop", method="DELETE") async def drop_memory_bank( self, bank_id: str, @@ -118,7 +118,7 @@ class Memory(Protocol): # this will just block now until documents are inserted, but it should # probably return a Job instance which can be polled for completion - @webmethod(route="/memory_bank/insert") + @webmethod(route="/memory/insert") async def insert_documents( self, bank_id: str, @@ -126,14 +126,14 @@ class Memory(Protocol): ttl_seconds: Optional[int] = None, ) -> None: ... - @webmethod(route="/memory_bank/update") + @webmethod(route="/memory/update") async def update_documents( self, bank_id: str, documents: List[MemoryBankDocument], ) -> None: ... - @webmethod(route="/memory_bank/query") + @webmethod(route="/memory/query") async def query_documents( self, bank_id: str, @@ -141,14 +141,14 @@ class Memory(Protocol): params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: ... - @webmethod(route="/memory_bank/documents/get", method="GET") + @webmethod(route="/memory/documents/get", method="GET") async def get_documents( self, bank_id: str, document_ids: List[str], ) -> List[MemoryBankDocument]: ... - @webmethod(route="/memory_bank/documents/delete", method="DELETE") + @webmethod(route="/memory/documents/delete", method="DELETE") async def delete_documents( self, bank_id: str, diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 7c0e981a3..721983b19 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -23,10 +23,10 @@ class MemoryBankSpec(BaseModel): class MemoryBanks(Protocol): - @webmethod(route="/memory_banks_router/list", method="GET") + @webmethod(route="/memory_banks/list", method="GET") async def list_memory_banks(self) -> List[MemoryBankSpec]: ... - @webmethod(route="/memory_banks_router/get", method="GET") + @webmethod(route="/memory_banks/get", method="GET") async def get_memory_bank( self, bank_type: MemoryBankType ) -> Optional[MemoryBankSpec]: ... diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 361c24416..3e9a0fbeb 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -9,6 +9,7 @@ from typing import Any from pydantic import BaseModel from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.apis.safety.safety import BuiltinShield # noqa: F403 from llama_stack.distribution.distribution import ( api_providers, builtin_automatically_routed_apis, @@ -54,7 +55,7 @@ def configure_api_providers( for inf in builtin_automatically_routed_apis() } - config.apis_to_serve = [a for a in apis if a != "telemetry"] + config.apis_to_serve = list(set([a for a in apis if a != "telemetry"])) apis = [v.value for v in stack_apis()] all_providers = api_providers() @@ -84,10 +85,31 @@ def configure_api_providers( if api_str in router_api2builtin_api: # a routing api, we need to infer and assign it a routing_key and put it in the routing_table - routing_key = prompt( - "> Enter routing key for the {} provider: ".format(api_str), - ) - config.routing_table[] + routing_key = "" + if api_str == "inference": + if hasattr(cfg, "model"): + routing_key = cfg.model + else: + routing_key = prompt( + "> Please enter the supported model your provider has for inference: ", + default="Meta-Llama3.1-8B-Instruct", + ) + + if api_str == "safety": + # check all supported shields + for shield_type in BuiltinShield: + print(shield_type.value) + + # if api_str == "memory": + # # check all supported memory_banks + + config.routing_table[api_str] = [ + RoutableProviderConfig( + routing_key=routing_key, + provider_id=p, + config=cfg.dict(), + ) + ] else: config.api_providers[api_str] = GenericProviderConfig( provider_id=p, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index b2e4f01eb..fcd4d2b2b 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -23,7 +23,6 @@ class CommonRoutingTableImpl(RoutingTable): routing_table_config: Dict[str, List[RoutableProviderConfig]], ) -> None: self.providers = {k: v for k, v in inner_impls} - print("routing table providers", self.providers) self.routing_keys = list(self.providers.keys()) self.routing_table_config = routing_table_config