Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
update memory endpoints
This commit is contained in:
Parent: 1ac188e1b3
Commit: f467a0482c
5 changed files with 42 additions and 21 deletions
|
@ -38,7 +38,7 @@ class MemoryClient(Memory):
|
||||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.get(
|
r = await client.get(
|
||||||
f"{self.base_url}/memory_banks/get",
|
f"{self.base_url}/memory/get",
|
||||||
params={
|
params={
|
||||||
"bank_id": bank_id,
|
"bank_id": bank_id,
|
||||||
},
|
},
|
||||||
|
@ -59,7 +59,7 @@ class MemoryClient(Memory):
|
||||||
) -> MemoryBank:
|
) -> MemoryBank:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.post(
|
r = await client.post(
|
||||||
f"{self.base_url}/memory_banks/create",
|
f"{self.base_url}/memory/create",
|
||||||
json={
|
json={
|
||||||
"name": name,
|
"name": name,
|
||||||
"config": config.dict(),
|
"config": config.dict(),
|
||||||
|
@ -81,7 +81,7 @@ class MemoryClient(Memory):
|
||||||
) -> None:
|
) -> None:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.post(
|
r = await client.post(
|
||||||
f"{self.base_url}/memory_bank/insert",
|
f"{self.base_url}/memory/insert",
|
||||||
json={
|
json={
|
||||||
"bank_id": bank_id,
|
"bank_id": bank_id,
|
||||||
"documents": [d.dict() for d in documents],
|
"documents": [d.dict() for d in documents],
|
||||||
|
@ -99,7 +99,7 @@ class MemoryClient(Memory):
|
||||||
) -> QueryDocumentsResponse:
|
) -> QueryDocumentsResponse:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.post(
|
r = await client.post(
|
||||||
f"{self.base_url}/memory_bank/query",
|
f"{self.base_url}/memory/query",
|
||||||
json={
|
json={
|
||||||
"bank_id": bank_id,
|
"bank_id": bank_id,
|
||||||
"query": query,
|
"query": query,
|
||||||
|
|
|
@ -96,7 +96,7 @@ class MemoryBank(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Memory(Protocol):
|
class Memory(Protocol):
|
||||||
@webmethod(route="/memory_banks/create")
|
@webmethod(route="/memory/create")
|
||||||
async def create_memory_bank(
|
async def create_memory_bank(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
|
@ -104,13 +104,13 @@ class Memory(Protocol):
|
||||||
url: Optional[URL] = None,
|
url: Optional[URL] = None,
|
||||||
) -> MemoryBank: ...
|
) -> MemoryBank: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/list", method="GET")
|
@webmethod(route="/memory/list", method="GET")
|
||||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/get", method="GET")
|
@webmethod(route="/memory/get", method="GET")
|
||||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
|
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/drop", method="DELETE")
|
@webmethod(route="/memory/drop", method="DELETE")
|
||||||
async def drop_memory_bank(
|
async def drop_memory_bank(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
|
@ -118,7 +118,7 @@ class Memory(Protocol):
|
||||||
|
|
||||||
# this will just block now until documents are inserted, but it should
|
# this will just block now until documents are inserted, but it should
|
||||||
# probably return a Job instance which can be polled for completion
|
# probably return a Job instance which can be polled for completion
|
||||||
@webmethod(route="/memory_bank/insert")
|
@webmethod(route="/memory/insert")
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
|
@ -126,14 +126,14 @@ class Memory(Protocol):
|
||||||
ttl_seconds: Optional[int] = None,
|
ttl_seconds: Optional[int] = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/update")
|
@webmethod(route="/memory/update")
|
||||||
async def update_documents(
|
async def update_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/query")
|
@webmethod(route="/memory/query")
|
||||||
async def query_documents(
|
async def query_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
|
@ -141,14 +141,14 @@ class Memory(Protocol):
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryDocumentsResponse: ...
|
) -> QueryDocumentsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/documents/get", method="GET")
|
@webmethod(route="/memory/documents/get", method="GET")
|
||||||
async def get_documents(
|
async def get_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
document_ids: List[str],
|
document_ids: List[str],
|
||||||
) -> List[MemoryBankDocument]: ...
|
) -> List[MemoryBankDocument]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/documents/delete", method="DELETE")
|
@webmethod(route="/memory/documents/delete", method="DELETE")
|
||||||
async def delete_documents(
|
async def delete_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
|
|
|
@ -23,10 +23,10 @@ class MemoryBankSpec(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class MemoryBanks(Protocol):
|
class MemoryBanks(Protocol):
|
||||||
@webmethod(route="/memory_banks_router/list", method="GET")
|
@webmethod(route="/memory_banks/list", method="GET")
|
||||||
async def list_memory_banks(self) -> List[MemoryBankSpec]: ...
|
async def list_memory_banks(self) -> List[MemoryBankSpec]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks_router/get", method="GET")
|
@webmethod(route="/memory_banks/get", method="GET")
|
||||||
async def get_memory_bank(
|
async def get_memory_bank(
|
||||||
self, bank_type: MemoryBankType
|
self, bank_type: MemoryBankType
|
||||||
) -> Optional[MemoryBankSpec]: ...
|
) -> Optional[MemoryBankSpec]: ...
|
||||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
from llama_stack.apis.safety.safety import BuiltinShield # noqa: F403
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import (
|
||||||
api_providers,
|
api_providers,
|
||||||
builtin_automatically_routed_apis,
|
builtin_automatically_routed_apis,
|
||||||
|
@ -54,7 +55,7 @@ def configure_api_providers(
|
||||||
for inf in builtin_automatically_routed_apis()
|
for inf in builtin_automatically_routed_apis()
|
||||||
}
|
}
|
||||||
|
|
||||||
config.apis_to_serve = [a for a in apis if a != "telemetry"]
|
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
|
||||||
|
|
||||||
apis = [v.value for v in stack_apis()]
|
apis = [v.value for v in stack_apis()]
|
||||||
all_providers = api_providers()
|
all_providers = api_providers()
|
||||||
|
@ -84,10 +85,31 @@ def configure_api_providers(
|
||||||
|
|
||||||
if api_str in router_api2builtin_api:
|
if api_str in router_api2builtin_api:
|
||||||
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table
|
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table
|
||||||
routing_key = prompt(
|
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
|
||||||
"> Enter routing key for the {} provider: ".format(api_str),
|
if api_str == "inference":
|
||||||
)
|
if hasattr(cfg, "model"):
|
||||||
config.routing_table[]
|
routing_key = cfg.model
|
||||||
|
else:
|
||||||
|
routing_key = prompt(
|
||||||
|
"> Please enter the supported model your provider has for inference: ",
|
||||||
|
default="Meta-Llama3.1-8B-Instruct",
|
||||||
|
)
|
||||||
|
|
||||||
|
if api_str == "safety":
|
||||||
|
# check all supported shields
|
||||||
|
for shield_type in BuiltinShield:
|
||||||
|
print(shield_type.value)
|
||||||
|
|
||||||
|
# if api_str == "memory":
|
||||||
|
# # check all supported memory_banks
|
||||||
|
|
||||||
|
config.routing_table[api_str] = [
|
||||||
|
RoutableProviderConfig(
|
||||||
|
routing_key=routing_key,
|
||||||
|
provider_id=p,
|
||||||
|
config=cfg.dict(),
|
||||||
|
)
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
config.api_providers[api_str] = GenericProviderConfig(
|
config.api_providers[api_str] = GenericProviderConfig(
|
||||||
provider_id=p,
|
provider_id=p,
|
||||||
|
|
|
@ -23,7 +23,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.providers = {k: v for k, v in inner_impls}
|
self.providers = {k: v for k, v in inner_impls}
|
||||||
print("routing table providers", self.providers)
|
|
||||||
self.routing_keys = list(self.providers.keys())
|
self.routing_keys = list(self.providers.keys())
|
||||||
self.routing_table_config = routing_table_config
|
self.routing_table_config = routing_table_config
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue