mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
Introduce model_store, shield_store, memory_bank_store
This commit is contained in:
parent
e45a417543
commit
91e0063593
19 changed files with 172 additions and 297 deletions
|
|
@@ -173,7 +173,13 @@ class EmbeddingsResponse(BaseModel):
|
|||
embeddings: List[List[float]]
|
||||
|
||||
|
||||
class ModelStore(Protocol):
|
||||
def get_model(self, identifier: str) -> ModelDef: ...
|
||||
|
||||
|
||||
class Inference(Protocol):
|
||||
model_store: ModelStore
|
||||
|
||||
@webmethod(route="/inference/completion")
|
||||
async def completion(
|
||||
self,
|
||||
|
|
@@ -207,9 +213,3 @@ class Inference(Protocol):
|
|||
|
||||
@webmethod(route="/inference/register_model")
|
||||
async def register_model(self, model: ModelDef) -> None: ...
|
||||
|
||||
@webmethod(route="/inference/list_models")
|
||||
async def list_models(self) -> List[ModelDef]: ...
|
||||
|
||||
@webmethod(route="/inference/get_model")
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
|
||||
|
|
|
|||
|
|
@@ -38,7 +38,13 @@ class QueryDocumentsResponse(BaseModel):
|
|||
scores: List[float]
|
||||
|
||||
|
||||
class MemoryBankStore(Protocol):
|
||||
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
|
||||
|
||||
|
||||
class Memory(Protocol):
|
||||
memory_bank_store: MemoryBankStore
|
||||
|
||||
# this will just block now until documents are inserted, but it should
|
||||
# probably return a Job instance which can be polled for completion
|
||||
@webmethod(route="/memory/insert")
|
||||
|
|
@@ -80,9 +86,3 @@ class Memory(Protocol):
|
|||
|
||||
@webmethod(route="/memory/register_memory_bank")
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
||||
|
||||
@webmethod(route="/memory/list_memory_banks")
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
|
||||
|
||||
@webmethod(route="/memory/get_memory_bank")
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
|
||||
|
|
|
|||
|
|
@@ -38,7 +38,13 @@ class RunShieldResponse(BaseModel):
|
|||
violation: Optional[SafetyViolation] = None
|
||||
|
||||
|
||||
class ShieldStore(Protocol):
|
||||
def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||
|
||||
|
||||
class Safety(Protocol):
|
||||
shield_store: ShieldStore
|
||||
|
||||
@webmethod(route="/safety/run_shield")
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
|
|
@@ -46,9 +52,3 @@ class Safety(Protocol):
|
|||
|
||||
@webmethod(route="/safety/register_shield")
|
||||
async def register_shield(self, shield: ShieldDef) -> None: ...
|
||||
|
||||
@webmethod(route="/safety/list_shields")
|
||||
async def list_shields(self) -> List[ShieldDef]: ...
|
||||
|
||||
@webmethod(route="/safety/get_shield")
|
||||
async def get_shield(self, identifier: str) -> Optional[ShieldDef]: ...
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue