update memory endpoints

This commit is contained in:
Xi Yan 2024-09-22 23:59:53 -07:00
parent 1ac188e1b3
commit f467a0482c
5 changed files with 42 additions and 21 deletions

View file

@ -38,7 +38,7 @@ class MemoryClient(Memory):
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client:
r = await client.get(
f"{self.base_url}/memory_banks/get",
f"{self.base_url}/memory/get",
params={
"bank_id": bank_id,
},
@ -59,7 +59,7 @@ class MemoryClient(Memory):
) -> MemoryBank:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory_banks/create",
f"{self.base_url}/memory/create",
json={
"name": name,
"config": config.dict(),
@ -81,7 +81,7 @@ class MemoryClient(Memory):
) -> None:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory_bank/insert",
f"{self.base_url}/memory/insert",
json={
"bank_id": bank_id,
"documents": [d.dict() for d in documents],
@ -99,7 +99,7 @@ class MemoryClient(Memory):
) -> QueryDocumentsResponse:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory_bank/query",
f"{self.base_url}/memory/query",
json={
"bank_id": bank_id,
"query": query,

View file

@ -96,7 +96,7 @@ class MemoryBank(BaseModel):
class Memory(Protocol):
@webmethod(route="/memory_banks/create")
@webmethod(route="/memory/create")
async def create_memory_bank(
self,
name: str,
@ -104,13 +104,13 @@ class Memory(Protocol):
url: Optional[URL] = None,
) -> MemoryBank: ...
@webmethod(route="/memory_banks/list", method="GET")
@webmethod(route="/memory/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get", method="GET")
@webmethod(route="/memory/get", method="GET")
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory_banks/drop", method="DELETE")
@webmethod(route="/memory/drop", method="DELETE")
async def drop_memory_bank(
self,
bank_id: str,
@ -118,7 +118,7 @@ class Memory(Protocol):
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/memory_bank/insert")
@webmethod(route="/memory/insert")
async def insert_documents(
self,
bank_id: str,
@ -126,14 +126,14 @@ class Memory(Protocol):
ttl_seconds: Optional[int] = None,
) -> None: ...
@webmethod(route="/memory_bank/update")
@webmethod(route="/memory/update")
async def update_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/query")
@webmethod(route="/memory/query")
async def query_documents(
self,
bank_id: str,
@ -141,14 +141,14 @@ class Memory(Protocol):
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@webmethod(route="/memory_bank/documents/get", method="GET")
@webmethod(route="/memory/documents/get", method="GET")
async def get_documents(
self,
bank_id: str,
document_ids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/documents/delete", method="DELETE")
@webmethod(route="/memory/documents/delete", method="DELETE")
async def delete_documents(
self,
bank_id: str,

View file

@ -23,10 +23,10 @@ class MemoryBankSpec(BaseModel):
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks_router/list", method="GET")
@webmethod(route="/memory_banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBankSpec]: ...
@webmethod(route="/memory_banks_router/get", method="GET")
@webmethod(route="/memory_banks/get", method="GET")
async def get_memory_bank(
self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]: ...

View file

@ -9,6 +9,7 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.apis.safety.safety import BuiltinShield # noqa: F403
from llama_stack.distribution.distribution import (
api_providers,
builtin_automatically_routed_apis,
@ -54,7 +55,7 @@ def configure_api_providers(
for inf in builtin_automatically_routed_apis()
}
config.apis_to_serve = [a for a in apis if a != "telemetry"]
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
apis = [v.value for v in stack_apis()]
all_providers = api_providers()
@ -84,10 +85,31 @@ def configure_api_providers(
if api_str in router_api2builtin_api:
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
if api_str == "inference":
if hasattr(cfg, "model"):
routing_key = cfg.model
else:
routing_key = prompt(
"> Enter routing key for the {} provider: ".format(api_str),
"> Please enter the supported model your provider has for inference: ",
default="Meta-Llama3.1-8B-Instruct",
)
config.routing_table[]
if api_str == "safety":
# check all supported shields
for shield_type in BuiltinShield:
print(shield_type.value)
# if api_str == "memory":
# # check all supported memory_banks
config.routing_table[api_str] = [
RoutableProviderConfig(
routing_key=routing_key,
provider_id=p,
config=cfg.dict(),
)
]
else:
config.api_providers[api_str] = GenericProviderConfig(
provider_id=p,

View file

@ -23,7 +23,6 @@ class CommonRoutingTableImpl(RoutingTable):
routing_table_config: Dict[str, List[RoutableProviderConfig]],
) -> None:
self.providers = {k: v for k, v in inner_impls}
print("routing table providers", self.providers)
self.routing_keys = list(self.providers.keys())
self.routing_table_config = routing_table_config