diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 896fd818e..f0f02b3c5 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -10,15 +10,13 @@ from llama_models.llama3.api.datatypes import URL from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import Field +from pydantic import BaseModel, Field from llama_stack.apis.common.type_system import ParamType -from llama_stack.apis.resource import Resource +from llama_stack.apis.resource import Resource, ResourceType -@json_schema_type -class Dataset(Resource): - type: Literal["dataset"] = "dataset" +class CommonDatasetFields(BaseModel): schema: Dict[str, ParamType] url: URL metadata: Dict[str, Any] = Field( @@ -27,6 +25,26 @@ class Dataset(Resource): ) +@json_schema_type +class Dataset(CommonDatasetFields, Resource): + type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value + + @property + def dataset_id(self) -> str: + return self.identifier + + @property + def provider_dataset_id(self) -> str: + return self.provider_resource_id + + +@json_schema_type +class DatasetInput(CommonDatasetFields, BaseModel): + dataset_id: str + provider_id: Optional[str] = None + provider_dataset_id: Optional[str] = None + + class Datasets(Protocol): @webmethod(route="/datasets/register", method="POST") async def register_dataset( diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py index 870673e58..10c35c3ee 100644 --- a/llama_stack/apis/eval_tasks/eval_tasks.py +++ b/llama_stack/apis/eval_tasks/eval_tasks.py @@ -7,14 +7,12 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkab from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import Field +from pydantic import BaseModel, Field -from llama_stack.apis.resource import Resource +from llama_stack.apis.resource import Resource, ResourceType -@json_schema_type -class EvalTask(Resource): - type: Literal["eval_task"] = "eval_task" +class CommonEvalTaskFields(BaseModel): dataset_id: str scoring_functions: List[str] metadata: Dict[str, Any] = Field( @@ -23,6 +21,26 @@ class EvalTask(Resource): ) +@json_schema_type +class EvalTask(CommonEvalTaskFields, Resource): + type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value + + @property + def eval_task_id(self) -> str: + return self.identifier + + @property + def provider_eval_task_id(self) -> str: + return self.provider_resource_id + + +@json_schema_type +class EvalTaskInput(CommonEvalTaskFields, BaseModel): + eval_task_id: str + provider_id: Optional[str] = None + provider_eval_task_id: Optional[str] = None + + @runtime_checkable class EvalTasks(Protocol): @webmethod(route="/eval_tasks/list", method="GET") diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 303104f25..c432f094a 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -30,37 +30,8 @@ class MemoryBankType(Enum): graph = "graph" -@json_schema_type -class VectorMemoryBank(Resource): - type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value - memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value - embedding_model: str - chunk_size_in_tokens: int - overlap_size_in_tokens: Optional[int] = None - - -@json_schema_type -class KeyValueMemoryBank(Resource): - type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value - memory_bank_type: Literal[MemoryBankType.keyvalue.value] = ( - MemoryBankType.keyvalue.value - ) - - -@json_schema_type -class KeywordMemoryBank(Resource): - type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value - memory_bank_type: Literal[MemoryBankType.keyword.value] = ( - MemoryBankType.keyword.value - ) - - -@json_schema_type -class GraphMemoryBank(Resource): - type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value - memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value - - +# define params for each type of memory bank, this leads to a tagged union +# accepted as input from the API or from the config. @json_schema_type class VectorMemoryBankParams(BaseModel): memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value @@ -88,6 +59,58 @@ class GraphMemoryBankParams(BaseModel): memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value +BankParams = Annotated[ + Union[ + VectorMemoryBankParams, + KeyValueMemoryBankParams, + KeywordMemoryBankParams, + GraphMemoryBankParams, + ], + Field(discriminator="memory_bank_type"), +] + + +# Some common functionality for memory banks. +class MemoryBankResourceMixin(Resource): + type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value + + @property + def memory_bank_id(self) -> str: + return self.identifier + + @property + def provider_memory_bank_id(self) -> str: + return self.provider_resource_id + + +@json_schema_type +class VectorMemoryBank(MemoryBankResourceMixin): + memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value + embedding_model: str + chunk_size_in_tokens: int + overlap_size_in_tokens: Optional[int] = None + + +@json_schema_type +class KeyValueMemoryBank(MemoryBankResourceMixin): + memory_bank_type: Literal[MemoryBankType.keyvalue.value] = ( + MemoryBankType.keyvalue.value + ) + + +# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention. +@json_schema_type +class KeywordMemoryBank(MemoryBankResourceMixin): + memory_bank_type: Literal[MemoryBankType.keyword.value] = ( + MemoryBankType.keyword.value + ) + + +@json_schema_type +class GraphMemoryBank(MemoryBankResourceMixin): + memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value + + MemoryBank = Annotated[ Union[ VectorMemoryBank, @@ -98,15 +121,13 @@ MemoryBank = Annotated[ Field(discriminator="memory_bank_type"), ] -BankParams = Annotated[ - Union[ - VectorMemoryBankParams, - KeyValueMemoryBankParams, - KeywordMemoryBankParams, - GraphMemoryBankParams, - ], - Field(discriminator="memory_bank_type"), -] + +@json_schema_type +class MemoryBankInput(BaseModel): + type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value + memory_bank_id: str + params: BankParams + provider_memory_bank_id: Optional[str] = None @runtime_checkable @@ -123,5 +144,5 @@ class MemoryBanks(Protocol): memory_bank_id: str, params: BankParams, provider_id: Optional[str] = None, - provider_memorybank_id: Optional[str] = None, + provider_memory_bank_id: Optional[str] = None, ) -> MemoryBank: ... diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index bb8d2c4ea..a5d226886 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -7,20 +7,38 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import Field +from pydantic import BaseModel, Field from llama_stack.apis.resource import Resource, ResourceType -@json_schema_type -class Model(Resource): - type: Literal[ResourceType.model.value] = ResourceType.model.value +class CommonModelFields(BaseModel): metadata: Dict[str, Any] = Field( default_factory=dict, description="Any additional metadata for this model", ) +@json_schema_type +class Model(CommonModelFields, Resource): + type: Literal[ResourceType.model.value] = ResourceType.model.value + + @property + def model_id(self) -> str: + return self.identifier + + @property + def provider_model_id(self) -> str: + return self.provider_resource_id + + +@json_schema_type +class ModelInput(CommonModelFields): + model_id: str + provider_id: Optional[str] = None + provider_model_id: Optional[str] = None + + @runtime_checkable class Models(Protocol): @webmethod(route="/models/list", method="GET") diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index 0e488190b..93a3718a0 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -17,14 +17,12 @@ class ResourceType(Enum): memory_bank = "memory_bank" dataset = "dataset" scoring_function = "scoring_function" + eval_task = "eval_task" class Resource(BaseModel): """Base class for all Llama Stack resources""" - # TODO: I think we need to move these into the child classes - # and make them `model_id`, `shield_id`, etc. because otherwise - # the config file has these confusing generic names in there identifier: str = Field( description="Unique identifier for this resource in llama stack" ) diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 6b2408e0d..7a2a83c72 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -66,11 +66,7 @@ ScoringFnParams = Annotated[ ] -@json_schema_type -class ScoringFn(Resource): - type: Literal[ResourceType.scoring_function.value] = ( - ResourceType.scoring_function.value - ) +class CommonScoringFnFields(BaseModel): description: Optional[str] = None metadata: Dict[str, Any] = Field( default_factory=dict, @@ -85,6 +81,28 @@ class ScoringFn(Resource): ) +@json_schema_type +class ScoringFn(CommonScoringFnFields, Resource): + type: Literal[ResourceType.scoring_function.value] = ( + ResourceType.scoring_function.value + ) + + @property + def scoring_fn_id(self) -> str: + return self.identifier + + @property + def provider_scoring_fn_id(self) -> str: + return self.provider_resource_id + + +@json_schema_type +class ScoringFnInput(CommonScoringFnFields, BaseModel): + scoring_fn_id: str + provider_id: Optional[str] = None + provider_scoring_fn_id: Optional[str] = None + + @runtime_checkable class ScoringFunctions(Protocol): @webmethod(route="/scoring_functions/list", method="GET") diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 42fe717fa..1dcfd4f4c 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -8,6 +8,7 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod +from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType @@ -20,13 +21,30 @@ class ShieldType(Enum): prompt_guard = "prompt_guard" +class CommonShieldFields(BaseModel): + shield_type: ShieldType + params: Optional[Dict[str, Any]] = None + + @json_schema_type -class Shield(Resource): +class Shield(CommonShieldFields, Resource): """A safety shield resource that can be used to check content""" type: Literal[ResourceType.shield.value] = ResourceType.shield.value - shield_type: ShieldType - params: Dict[str, Any] = {} + + @property + def shield_id(self) -> str: + return self.identifier + + @property + def provider_shield_id(self) -> str: + return self.provider_resource_id + + +class ShieldInput(CommonShieldFields): + shield_id: str + provider_id: Optional[str] = None + provider_shield_id: Optional[str] = None @runtime_checkable diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 2cba5b052..4aaf9c38a 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -18,7 +18,7 @@ from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.eval import Eval -from llama_stack.apis.eval_tasks import EvalTask +from llama_stack.apis.eval_tasks import EvalTaskInput from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory from llama_stack.apis.safety import Safety @@ -152,12 +152,12 @@ a default SQLite store will be used.""", ) # registry of "resources" in the distribution - models: List[Model] = Field(default_factory=list) - shields: List[Shield] = Field(default_factory=list) - memory_banks: List[MemoryBank] = Field(default_factory=list) - datasets: List[Dataset] = Field(default_factory=list) - scoring_fns: List[ScoringFn] = Field(default_factory=list) - eval_tasks: List[EvalTask] = Field(default_factory=list) + models: List[ModelInput] = Field(default_factory=list) + shields: List[ShieldInput] = Field(default_factory=list) + memory_banks: List[MemoryBankInput] = Field(default_factory=list) + datasets: List[DatasetInput] = Field(default_factory=list) + scoring_fns: List[ScoringFnInput] = Field(default_factory=list) + eval_tasks: List[EvalTaskInput] = Field(default_factory=list) class BuildConfig(BaseModel): diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index efed54ab8..1aba4884b 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -277,10 +277,10 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): memory_bank_id: str, params: BankParams, provider_id: Optional[str] = None, - provider_memorybank_id: Optional[str] = None, + provider_memory_bank_id: Optional[str] = None, ) -> MemoryBank: - if provider_memorybank_id is None: - provider_memorybank_id = memory_bank_id + if provider_memory_bank_id is None: + provider_memory_bank_id = memory_bank_id if provider_id is None: # If provider_id not specified, use the only provider if it supports this shield type if len(self.impls_by_provider_id) == 1: @@ -295,7 +295,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): "identifier": memory_bank_id, "type": ResourceType.memory_bank.value, "provider_id": provider_id, - "provider_resource_id": provider_memorybank_id, + "provider_resource_id": provider_memory_bank_id, **params.model_dump(), }, ) diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index b496d9f36..3afd51304 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -68,30 +68,29 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]: impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) - objects = [ - *run_config.models, - *run_config.shields, - *run_config.memory_banks, - *run_config.datasets, - *run_config.scoring_fns, - *run_config.eval_tasks, - ] - for obj in objects: - await dist_registry.register(obj) - resources = [ - ("models", Api.models), - ("shields", Api.shields), - ("memory_banks", Api.memory_banks), - ("datasets", Api.datasets), - ("scoring_fns", Api.scoring_functions), - ("eval_tasks", Api.eval_tasks), + ("models", Api.models, "register_model", "list_models"), + ("shields", Api.shields, "register_shield", "list_shields"), + ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"), + ("datasets", Api.datasets, "register_dataset", "list_datasets"), + ( + "scoring_fns", + Api.scoring_functions, + "register_scoring_function", + "list_scoring_functions", + ), + ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"), ] - for rsrc, api in resources: + for rsrc, api, register_method, list_method in resources: + objects = getattr(run_config, rsrc) if api not in impls: continue - method = getattr(impls[api], f"list_{api.value}") + method = getattr(impls[api], register_method) + for obj in objects: + await method(**obj.model_dump()) + + method = getattr(impls[api], list_method) for obj in await method(): print( f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}", diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 6ee17ff1f..64f493b88 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -9,7 +9,7 @@ import tempfile import pytest import pytest_asyncio -from llama_stack.apis.models import Model +from llama_stack.apis.models import ModelInput from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.agents.meta_reference import ( @@ -71,13 +71,9 @@ async def agents_stack(request, inference_model, safety_model): if fixture.provider_data: provider_data.update(fixture.provider_data) - inf_provider_id = providers["inference"][0].provider_id - safety_provider_id = providers["safety"][0].provider_id - - shield = get_shield_to_register( - providers["safety"][0].provider_type, safety_provider_id, safety_model + shield_input = get_shield_to_register( + providers["safety"][0].provider_type, safety_model ) - inference_models = ( inference_model if isinstance(inference_model, list) else [inference_model] ) @@ -86,13 +82,11 @@ async def agents_stack(request, inference_model, safety_model): providers, provider_data, models=[ - Model( - identifier=model, - provider_id=inf_provider_id, - provider_resource_id=model, + ModelInput( + model_id=model, ) for model in inference_models ], - shields=[shield], + shields=[shield_input], ) return impls[Api.agents], impls[Api.memory] diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index fe91c6e03..d35ebab28 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -9,7 +9,7 @@ import os import pytest import pytest_asyncio -from llama_stack.apis.models import Model +from llama_stack.apis.models import ModelInput from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.inference.meta_reference import ( @@ -162,10 +162,8 @@ async def inference_stack(request, inference_model): {"inference": inference_fixture.providers}, inference_fixture.provider_data, models=[ - Model( - identifier=inference_model, - provider_resource_id=inference_model, - provider_id=inference_fixture.providers[0].provider_id, + ModelInput( + model_id=inference_model, ) ], ) diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 5e553830c..66576e9d7 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -7,9 +7,9 @@ import pytest import pytest_asyncio -from llama_stack.apis.models import Model +from llama_stack.apis.models import ModelInput -from llama_stack.apis.shields import Shield, ShieldType +from llama_stack.apis.shields import ShieldInput, ShieldType from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig @@ -99,28 +99,21 @@ async def safety_stack(inference_model, safety_model, request): provider_data.update(safety_fixture.provider_data) shield_provider_type = safety_fixture.providers[0].provider_type - shield = get_shield_to_register( - shield_provider_type, safety_fixture.providers[0].provider_id, safety_model - ) + shield_input = get_shield_to_register(shield_provider_type, safety_model) impls = await resolve_impls_for_test_v2( [Api.safety, Api.shields, Api.inference], providers, provider_data, - models=[ - Model( - identifier=inference_model, - provider_id=inference_fixture.providers[0].provider_id, - provider_resource_id=inference_model, - ) - ], - shields=[shield], + models=[ModelInput(model_id=inference_model)], + shields=[shield_input], ) + shield = await impls[Api.shields].get_shield(shield_input.shield_id) return impls[Api.safety], impls[Api.shields], shield -def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str): +def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput: shield_config = {} shield_type = ShieldType.llama_guard identifier = "llama_guard" @@ -133,10 +126,8 @@ def get_shield_to_register(provider_type: str, provider_id: str, safety_model: s shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION") shield_type = ShieldType.generic_content_shield - return Shield( - identifier=identifier, + return ShieldInput( + shield_id=identifier, shield_type=shield_type, params=shield_config, - provider_id=provider_id, - provider_resource_id=identifier, )