Introduce model_store, shield_store, memory_bank_store

This commit is contained in:
Ashwin Bharambe 2024-10-06 16:29:33 -07:00 committed by Ashwin Bharambe
parent e45a417543
commit 91e0063593
19 changed files with 172 additions and 297 deletions

View file

@ -173,7 +173,13 @@ class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
class ModelStore(Protocol):
    """Structural interface for an object that can look up model definitions."""

    def get_model(self, identifier: str) -> ModelDef:
        """Return the ``ModelDef`` associated with ``identifier``.

        This is a synchronous, signature-only stub. The return annotation is
        not ``Optional``, so how an unknown identifier is reported is not
        specified here.  NOTE(review): confirm against implementations.
        """
        ...
class Inference(Protocol):
model_store: ModelStore
@webmethod(route="/inference/completion")
async def completion(
self,
@ -207,9 +213,3 @@ class Inference(Protocol):
@webmethod(route="/inference/register_model")
async def register_model(
    self,
    model: ModelDef,
) -> None:
    """Declare the ``/inference/register_model`` endpoint.

    Accepts a single ``ModelDef`` and returns nothing. This is a
    signature-only stub; it carries no behavior of its own.
    """
    ...
@webmethod(route="/inference/list_models")
async def list_models(
    self,
) -> List[ModelDef]:
    """Declare the ``/inference/list_models`` endpoint.

    Takes no arguments beyond ``self`` and is annotated to return a list of
    ``ModelDef`` objects. Signature-only stub with no behavior of its own.
    """
    ...
@webmethod(route="/inference/get_model")
async def get_model(
    self,
    identifier: str,
) -> Optional[ModelDef]:
    """Declare the ``/inference/get_model`` endpoint.

    Takes a string ``identifier`` and is annotated to return either a
    ``ModelDef`` or ``None``. Signature-only stub with no behavior of its own.
    NOTE(review): ``None`` presumably means "no such model" — confirm
    against implementations.
    """
    ...

View file

@ -38,7 +38,13 @@ class QueryDocumentsResponse(BaseModel):
scores: List[float]
class MemoryBankStore(Protocol):
    """Structural interface for an object that can look up memory bank definitions."""

    def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]:
        """Return the ``MemoryBankDef`` for ``bank_id``, or ``None``.

        This is a synchronous, signature-only stub.
        NOTE(review): ``None`` presumably means the bank is unknown — confirm
        against implementations.
        """
        ...
class Memory(Protocol):
memory_bank_store: MemoryBankStore
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/memory/insert")
@ -80,9 +86,3 @@ class Memory(Protocol):
@webmethod(route="/memory/register_memory_bank")
async def register_memory_bank(
    self,
    memory_bank: MemoryBankDef,
) -> None:
    """Declare the ``/memory/register_memory_bank`` endpoint.

    Accepts a single ``MemoryBankDef`` and returns nothing. This is a
    signature-only stub; it carries no behavior of its own.
    """
    ...
@webmethod(route="/memory/list_memory_banks")
async def list_memory_banks(
    self,
) -> List[MemoryBankDef]:
    """Declare the ``/memory/list_memory_banks`` endpoint.

    Takes no arguments beyond ``self`` and is annotated to return a list of
    ``MemoryBankDef`` objects. Signature-only stub with no behavior of its own.
    """
    ...
@webmethod(route="/memory/get_memory_bank")
async def get_memory_bank(
    self,
    identifier: str,
) -> Optional[MemoryBankDef]:
    """Declare the ``/memory/get_memory_bank`` endpoint.

    Takes a string ``identifier`` and is annotated to return either a
    ``MemoryBankDef`` or ``None``. Signature-only stub with no behavior of
    its own.
    NOTE(review): ``None`` presumably means "no such bank" — confirm
    against implementations.
    """
    ...

View file

@ -38,7 +38,13 @@ class RunShieldResponse(BaseModel):
violation: Optional[SafetyViolation] = None
class ShieldStore(Protocol):
    """Structural interface for an object that can look up shield definitions."""

    def get_shield(self, identifier: str) -> ShieldDef:
        """Return the ``ShieldDef`` associated with ``identifier``.

        This is a synchronous, signature-only stub. The return annotation is
        not ``Optional``, so how an unknown identifier is reported is not
        specified here.  NOTE(review): confirm against implementations.
        """
        ...
class Safety(Protocol):
shield_store: ShieldStore
@webmethod(route="/safety/run_shield")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
@ -46,9 +52,3 @@ class Safety(Protocol):
@webmethod(route="/safety/register_shield")
async def register_shield(
    self,
    shield: ShieldDef,
) -> None:
    """Declare the ``/safety/register_shield`` endpoint.

    Accepts a single ``ShieldDef`` and returns nothing. This is a
    signature-only stub; it carries no behavior of its own.
    """
    ...
@webmethod(route="/safety/list_shields")
async def list_shields(
    self,
) -> List[ShieldDef]:
    """Declare the ``/safety/list_shields`` endpoint.

    Takes no arguments beyond ``self`` and is annotated to return a list of
    ``ShieldDef`` objects. Signature-only stub with no behavior of its own.
    """
    ...
@webmethod(route="/safety/get_shield")
async def get_shield(
    self,
    identifier: str,
) -> Optional[ShieldDef]:
    """Declare the ``/safety/get_shield`` endpoint.

    Takes a string ``identifier`` and is annotated to return either a
    ``ShieldDef`` or ``None``. Signature-only stub with no behavior of its own.
    NOTE(review): ``None`` presumably means "no such shield" — confirm
    against implementations.
    """
    ...