mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
Another round of simplification and clarity for models/shields/memory_banks stuff
This commit is contained in:
parent
73a0a34e39
commit
b55034c0de
27 changed files with 454 additions and 444 deletions
|
|
@ -6,7 +6,16 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -404,6 +413,7 @@ class AgentStepResponse(BaseModel):
|
|||
step: Step
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agents(Protocol):
|
||||
@webmethod(route="/agents/create")
|
||||
async def create_agent(
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
from typing import List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -47,6 +47,7 @@ class BatchChatCompletionResponse(BaseModel):
|
|||
completion_message_batch: List[CompletionMessage]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BatchInference(Protocol):
|
||||
@webmethod(route="/batch_inference/completion")
|
||||
async def batch_completion(
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from enum import Enum
|
||||
|
||||
from typing import List, Literal, Optional, Protocol, Union
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -177,6 +177,7 @@ class ModelStore(Protocol):
|
|||
def get_model(self, identifier: str) -> ModelDef: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Inference(Protocol):
|
||||
model_store: ModelStore
|
||||
|
||||
|
|
@ -214,6 +215,3 @@ class Inference(Protocol):
|
|||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse: ...
|
||||
|
||||
@webmethod(route="/inference/register_model")
|
||||
async def register_model(self, model: ModelDef) -> None: ...
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List, Protocol
|
||||
from typing import Dict, List, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -29,6 +29,7 @@ class HealthInfo(BaseModel):
|
|||
# TODO: add a provider level status
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/providers/list", method="GET")
|
||||
async def list_providers(self) -> Dict[str, ProviderInfo]: ...
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import List, Optional, Protocol
|
||||
from typing import List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -42,6 +42,7 @@ class MemoryBankStore(Protocol):
|
|||
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Memory(Protocol):
|
||||
memory_bank_store: MemoryBankStore
|
||||
|
||||
|
|
@ -55,13 +56,6 @@ class Memory(Protocol):
|
|||
ttl_seconds: Optional[int] = None,
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/memory/update")
|
||||
async def update_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/memory/query")
|
||||
async def query_documents(
|
||||
self,
|
||||
|
|
@ -69,20 +63,3 @@ class Memory(Protocol):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse: ...
|
||||
|
||||
@webmethod(route="/memory/documents/get", method="GET")
|
||||
async def get_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
document_ids: List[str],
|
||||
) -> List[MemoryBankDocument]: ...
|
||||
|
||||
@webmethod(route="/memory/documents/delete", method="DELETE")
|
||||
async def delete_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
document_ids: List[str],
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/memory/register_memory_bank")
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional, Protocol, Union
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -22,7 +22,8 @@ class MemoryBankType(Enum):
|
|||
|
||||
class CommonDef(BaseModel):
|
||||
identifier: str
|
||||
provider_id: Optional[str] = None
|
||||
# Hack: move this out later
|
||||
provider_id: str = ""
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -58,13 +59,20 @@ MemoryBankDef = Annotated[
|
|||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
MemoryBankDefWithProvider = MemoryBankDef
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MemoryBanks(Protocol):
|
||||
@webmethod(route="/memory_banks/list", method="GET")
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
|
||||
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/memory_banks/get", method="GET")
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
|
||||
async def get_memory_bank(
|
||||
self, identifier: str
|
||||
) -> Optional[MemoryBankDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/memory_banks/register", method="POST")
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
||||
async def register_memory_bank(
|
||||
self, memory_bank: MemoryBankDefWithProvider
|
||||
) -> None: ...
|
||||
|
|
|
|||
|
|
@ -4,34 +4,39 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelDef(BaseModel):
|
||||
identifier: str = Field(
|
||||
description="A unique identifier for the model type",
|
||||
description="A unique name for the model type",
|
||||
)
|
||||
llama_model: str = Field(
|
||||
description="Pointer to the core Llama family model",
|
||||
description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
|
||||
)
|
||||
provider_id: Optional[str] = Field(
|
||||
default=None, description="The provider instance which serves this model"
|
||||
metadata: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this model",
|
||||
)
|
||||
# For now, we are only supporting core llama models but as soon as finetuned
|
||||
# and other custom models (for example various quantizations) are allowed, there
|
||||
# will be more metadata fields here
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelDefWithProvider(ModelDef):
|
||||
provider_id: str = Field(
|
||||
description="The provider ID for this model",
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models/list", method="GET")
|
||||
async def list_models(self) -> List[ModelDef]: ...
|
||||
async def list_models(self) -> List[ModelDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/models/get", method="GET")
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/models/register", method="POST")
|
||||
async def register_model(self, model: ModelDef) -> None: ...
|
||||
async def register_model(self, model: ModelDefWithProvider) -> None: ...
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Protocol
|
||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -42,6 +42,7 @@ class ShieldStore(Protocol):
|
|||
def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Safety(Protocol):
|
||||
shield_store: ShieldStore
|
||||
|
||||
|
|
@ -49,6 +50,3 @@ class Safety(Protocol):
|
|||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse: ...
|
||||
|
||||
@webmethod(route="/safety/register_shield")
|
||||
async def register_shield(self, shield: ShieldDef) -> None: ...
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -26,21 +26,26 @@ class ShieldDef(BaseModel):
|
|||
type: str = Field(
|
||||
description="The type of shield this is; the value is one of the ShieldType enum"
|
||||
)
|
||||
provider_id: Optional[str] = Field(
|
||||
default=None, description="The provider instance which serves this shield"
|
||||
)
|
||||
params: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional parameters needed for this shield",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldDefWithProvider(ShieldDef):
|
||||
provider_id: str = Field(
|
||||
description="The provider ID for this shield type",
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields/list", method="GET")
|
||||
async def list_shields(self) -> List[ShieldDef]: ...
|
||||
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/shields/get", method="GET")
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: ...
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/shields/register", method="POST")
|
||||
async def register_shield(self, shield: ShieldDef) -> None: ...
|
||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional, Protocol, Union
|
||||
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -123,6 +123,7 @@ Event = Annotated[
|
|||
]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Telemetry(Protocol):
|
||||
@webmethod(route="/telemetry/log_event")
|
||||
async def log_event(self, event: Event) -> None: ...
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue