Another round of simplification and clarification for the models/shields/memory_banks APIs

This commit is contained in:
Ashwin Bharambe 2024-10-09 19:19:26 -07:00
parent 73a0a34e39
commit b55034c0de
27 changed files with 454 additions and 444 deletions

View file

@ -6,7 +6,16 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
@ -404,6 +413,7 @@ class AgentStepResponse(BaseModel):
step: Step
@runtime_checkable
class Agents(Protocol):
@webmethod(route="/agents/create")
async def create_agent(

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from typing import List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
@ -47,6 +47,7 @@ class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
@runtime_checkable
class BatchInference(Protocol):
@webmethod(route="/batch_inference/completion")
async def batch_completion(

View file

@ -6,7 +6,7 @@
from enum import Enum
from typing import List, Literal, Optional, Protocol, Union
from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
from llama_models.schema_utils import json_schema_type, webmethod
@ -177,6 +177,7 @@ class ModelStore(Protocol):
def get_model(self, identifier: str) -> ModelDef: ...
@runtime_checkable
class Inference(Protocol):
model_store: ModelStore
@ -214,6 +215,3 @@ class Inference(Protocol):
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ...
@webmethod(route="/inference/register_model")
async def register_model(self, model: ModelDef) -> None: ...

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List, Protocol
from typing import Dict, List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
@ -29,6 +29,7 @@ class HealthInfo(BaseModel):
# TODO: add a provider-level status
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/providers/list", method="GET")
async def list_providers(self) -> Dict[str, ProviderInfo]: ...

View file

@ -8,7 +8,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from typing import List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
@ -42,6 +42,7 @@ class MemoryBankStore(Protocol):
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
@runtime_checkable
class Memory(Protocol):
memory_bank_store: MemoryBankStore
@ -55,13 +56,6 @@ class Memory(Protocol):
ttl_seconds: Optional[int] = None,
) -> None: ...
@webmethod(route="/memory/update")
async def update_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory/query")
async def query_documents(
self,
@ -69,20 +63,3 @@ class Memory(Protocol):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@webmethod(route="/memory/documents/get", method="GET")
async def get_documents(
self,
bank_id: str,
document_ids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory/documents/delete", method="DELETE")
async def delete_documents(
self,
bank_id: str,
document_ids: List[str],
) -> None: ...
@webmethod(route="/memory/register_memory_bank")
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import List, Literal, Optional, Protocol, Union
from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
@ -22,7 +22,8 @@ class MemoryBankType(Enum):
class CommonDef(BaseModel):
identifier: str
provider_id: Optional[str] = None
# HACK: provider_id is kept on this shared definition only for now; move it out later
provider_id: str = ""
@json_schema_type
@ -58,13 +59,20 @@ MemoryBankDef = Annotated[
Field(discriminator="type"),
]
# Plain alias, not a new type: this name refers to the same object as
# MemoryBankDef, so the two are interchangeable in annotations.
MemoryBankDefWithProvider = MemoryBankDef
@runtime_checkable
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: ...
@webmethod(route="/memory_banks/get", method="GET")
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
async def get_memory_bank(
self, identifier: str
) -> Optional[MemoryBankDefWithProvider]: ...
@webmethod(route="/memory_banks/register", method="POST")
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider
) -> None: ...

View file

@ -4,34 +4,39 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
@json_schema_type
class ModelDef(BaseModel):
identifier: str = Field(
description="A unique identifier for the model type",
description="A unique name for the model type",
)
llama_model: str = Field(
description="Pointer to the core Llama family model",
description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
)
provider_id: Optional[str] = Field(
default=None, description="The provider instance which serves this model"
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this model",
)
# For now, we are only supporting core llama models but as soon as finetuned
# and other custom models (for example various quantizations) are allowed, there
# will be more metadata fields here
# Extends ModelDef with a `provider_id` field. The field is declared without a
# default, so pydantic treats it as required when constructing this type.
@json_schema_type
class ModelDefWithProvider(ModelDef):
provider_id: str = Field(
description="The provider ID for this model",
)
@runtime_checkable
class Models(Protocol):
@webmethod(route="/models/list", method="GET")
async def list_models(self) -> List[ModelDef]: ...
async def list_models(self) -> List[ModelDefWithProvider]: ...
@webmethod(route="/models/get", method="GET")
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: ...
@webmethod(route="/models/register", method="POST")
async def register_model(self, model: ModelDef) -> None: ...
async def register_model(self, model: ModelDefWithProvider) -> None: ...

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Protocol
from typing import Any, Dict, List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
@ -42,6 +42,7 @@ class ShieldStore(Protocol):
def get_shield(self, identifier: str) -> ShieldDef: ...
@runtime_checkable
class Safety(Protocol):
shield_store: ShieldStore
@ -49,6 +50,3 @@ class Safety(Protocol):
async def run_shield(
# NOTE(review): `params` is annotated Dict[str, Any] but defaults to None;
# Optional[Dict[str, Any]] would match the default (requires importing Optional).
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ...
@webmethod(route="/safety/register_shield")
async def register_shield(self, shield: ShieldDef) -> None: ...

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
@ -26,21 +26,26 @@ class ShieldDef(BaseModel):
type: str = Field(
description="The type of shield this is; the value is one of the ShieldType enum"
)
provider_id: Optional[str] = Field(
default=None, description="The provider instance which serves this shield"
)
params: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional parameters needed for this shield",
)
# Extends ShieldDef with a `provider_id` field. The field is declared without a
# default, so pydantic treats it as required when constructing this type.
@json_schema_type
class ShieldDefWithProvider(ShieldDef):
provider_id: str = Field(
description="The provider ID for this shield type",
)
@runtime_checkable
class Shields(Protocol):
@webmethod(route="/shields/list", method="GET")
async def list_shields(self) -> List[ShieldDef]: ...
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
@webmethod(route="/shields/get", method="GET")
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: ...
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ...
@webmethod(route="/shields/register", method="POST")
async def register_shield(self, shield: ShieldDef) -> None: ...
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...

View file

@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Literal, Optional, Protocol, Union
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
@ -123,6 +123,7 @@ Event = Annotated[
]
@runtime_checkable
class Telemetry(Protocol):
@webmethod(route="/telemetry/log_event")
async def log_event(self, event: Event) -> None: ...