forked from phoenix-oss/llama-stack-mirror
# What does this PR do? This is a follow-up to #425. That PR allows for specifying models in the registry, but each entry needs to look like: ```yaml - identifier: ... provider_id: ... provider_resource_identifier: ... ``` This is headache-inducing. The current PR makes this situation better by adopting the shape of our APIs. Namely, we need the user to only specify `model-id`. The rest should be optional and figured out by the Stack. You can always override it. Here's what example `ollama` "full stack" registry looks like (we still need to kill or simplify shield_type crap): ```yaml models: - model_id: Llama3.2-3B-Instruct - model_id: Llama-Guard-3-1B shields: - shield_id: llama_guard shield_type: llama_guard ``` ## Test Plan See test plan for #425. Re-ran it.
147 lines
3.8 KiB
Python
147 lines
3.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from enum import Enum
|
|
from typing import (
|
|
Annotated,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Protocol,
|
|
runtime_checkable,
|
|
Union,
|
|
)
|
|
|
|
from llama_models.schema_utils import json_schema_type, webmethod
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from llama_stack.apis.resource import Resource, ResourceType
|
|
|
|
|
|
@json_schema_type
|
|
class MemoryBankType(Enum):
|
|
vector = "vector"
|
|
keyvalue = "keyvalue"
|
|
keyword = "keyword"
|
|
graph = "graph"
|
|
|
|
|
|
# define params for each type of memory bank, this leads to a tagged union
|
|
# accepted as input from the API or from the config.
|
|
@json_schema_type
|
|
class VectorMemoryBankParams(BaseModel):
|
|
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
|
embedding_model: str
|
|
chunk_size_in_tokens: int
|
|
overlap_size_in_tokens: Optional[int] = None
|
|
|
|
|
|
@json_schema_type
|
|
class KeyValueMemoryBankParams(BaseModel):
|
|
memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
|
|
MemoryBankType.keyvalue.value
|
|
)
|
|
|
|
|
|
@json_schema_type
|
|
class KeywordMemoryBankParams(BaseModel):
|
|
memory_bank_type: Literal[MemoryBankType.keyword.value] = (
|
|
MemoryBankType.keyword.value
|
|
)
|
|
|
|
|
|
@json_schema_type
|
|
class GraphMemoryBankParams(BaseModel):
|
|
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
|
|
|
|
|
BankParams = Annotated[
|
|
Union[
|
|
VectorMemoryBankParams,
|
|
KeyValueMemoryBankParams,
|
|
KeywordMemoryBankParams,
|
|
GraphMemoryBankParams,
|
|
],
|
|
Field(discriminator="memory_bank_type"),
|
|
]
|
|
|
|
|
|
# Some common functionality for memory banks.
|
|
class MemoryBankResourceMixin(Resource):
|
|
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
|
|
|
|
@property
|
|
def memory_bank_id(self) -> str:
|
|
return self.identifier
|
|
|
|
@property
|
|
def provider_memory_bank_id(self) -> str:
|
|
return self.provider_resource_id
|
|
|
|
|
|
@json_schema_type
|
|
class VectorMemoryBank(MemoryBankResourceMixin):
|
|
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
|
embedding_model: str
|
|
chunk_size_in_tokens: int
|
|
overlap_size_in_tokens: Optional[int] = None
|
|
|
|
|
|
@json_schema_type
|
|
class KeyValueMemoryBank(MemoryBankResourceMixin):
|
|
memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
|
|
MemoryBankType.keyvalue.value
|
|
)
|
|
|
|
|
|
# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention.
|
|
@json_schema_type
|
|
class KeywordMemoryBank(MemoryBankResourceMixin):
|
|
memory_bank_type: Literal[MemoryBankType.keyword.value] = (
|
|
MemoryBankType.keyword.value
|
|
)
|
|
|
|
|
|
@json_schema_type
|
|
class GraphMemoryBank(MemoryBankResourceMixin):
|
|
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
|
|
|
|
|
MemoryBank = Annotated[
|
|
Union[
|
|
VectorMemoryBank,
|
|
KeyValueMemoryBank,
|
|
KeywordMemoryBank,
|
|
GraphMemoryBank,
|
|
],
|
|
Field(discriminator="memory_bank_type"),
|
|
]
|
|
|
|
|
|
@json_schema_type
|
|
class MemoryBankInput(BaseModel):
|
|
memory_bank_id: str
|
|
params: BankParams
|
|
provider_memory_bank_id: Optional[str] = None
|
|
|
|
|
|
@runtime_checkable
|
|
class MemoryBanks(Protocol):
|
|
@webmethod(route="/memory_banks/list", method="GET")
|
|
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
|
|
|
@webmethod(route="/memory_banks/get", method="GET")
|
|
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ...
|
|
|
|
@webmethod(route="/memory_banks/register", method="POST")
|
|
async def register_memory_bank(
|
|
self,
|
|
memory_bank_id: str,
|
|
params: BankParams,
|
|
provider_id: Optional[str] = None,
|
|
provider_memory_bank_id: Optional[str] = None,
|
|
) -> MemoryBank: ...
|