Enable sane naming of registered objects with defaults

This commit is contained in:
Ashwin Bharambe 2024-11-12 10:17:34 -08:00
parent 9e925f43e5
commit 48a6e27de9
13 changed files with 222 additions and 131 deletions

View file

@ -10,15 +10,13 @@ from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import Field
from pydantic import BaseModel, Field
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource
from llama_stack.apis.resource import Resource, ResourceType
@json_schema_type
class Dataset(Resource):
type: Literal["dataset"] = "dataset"
class CommonDatasetFields(BaseModel):
schema: Dict[str, ParamType]
url: URL
metadata: Dict[str, Any] = Field(
@ -27,6 +25,26 @@ class Dataset(Resource):
)
# Registered dataset resource: the shared dataset fields (CommonDatasetFields)
# combined with the generic Resource base class.
@json_schema_type
class Dataset(CommonDatasetFields, Resource):
# Fixed resource-type tag; the only accepted value is ResourceType.dataset.value.
type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value
# Read-only alias exposing the generic `identifier` under a dataset-specific name.
@property
def dataset_id(self) -> str:
return self.identifier
# Read-only alias for `provider_resource_id` (presumably declared on Resource;
# its declaration is not visible in this view).
@property
def provider_dataset_id(self) -> str:
return self.provider_resource_id
# Input-side description of a dataset (used as the element type of the run
# config's `datasets` list). Carries the shared dataset fields plus
# dataset-specific id names instead of the generic Resource field names.
# NOTE(review): BaseModel is already a base via CommonDatasetFields, so listing
# it again looks redundant — confirm before removing.
@json_schema_type
class DatasetInput(CommonDatasetFields, BaseModel):
# Required identifier for the dataset.
dataset_id: str
# Optional; both default to None when omitted. How None is resolved is decided
# by the registration code, which is not visible in this block.
provider_id: Optional[str] = None
provider_dataset_id: Optional[str] = None
class Datasets(Protocol):
@webmethod(route="/datasets/register", method="POST")
async def register_dataset(

View file

@ -7,14 +7,12 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkab
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import Field
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource
from llama_stack.apis.resource import Resource, ResourceType
@json_schema_type
class EvalTask(Resource):
type: Literal["eval_task"] = "eval_task"
class CommonEvalTaskFields(BaseModel):
dataset_id: str
scoring_functions: List[str]
metadata: Dict[str, Any] = Field(
@ -23,6 +21,26 @@ class EvalTask(Resource):
)
# Registered eval task resource: the shared eval task fields
# (CommonEvalTaskFields) combined with the generic Resource base class.
@json_schema_type
class EvalTask(CommonEvalTaskFields, Resource):
# Fixed resource-type tag; the only accepted value is ResourceType.eval_task.value.
type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value
# Read-only alias exposing the generic `identifier` under an eval-task-specific name.
@property
def eval_task_id(self) -> str:
return self.identifier
# Read-only alias for `provider_resource_id` (presumably declared on Resource;
# its declaration is not visible in this view).
@property
def provider_eval_task_id(self) -> str:
return self.provider_resource_id
# Input-side description of an eval task (used as the element type of the run
# config's `eval_tasks` list). Carries the shared eval task fields plus
# eval-task-specific id names instead of the generic Resource field names.
# NOTE(review): BaseModel is already a base via CommonEvalTaskFields, so listing
# it again looks redundant — confirm before removing.
@json_schema_type
class EvalTaskInput(CommonEvalTaskFields, BaseModel):
# Required identifier for the eval task.
eval_task_id: str
# Optional; both default to None when omitted. How None is resolved is decided
# by the registration code, which is not visible in this block.
provider_id: Optional[str] = None
provider_eval_task_id: Optional[str] = None
@runtime_checkable
class EvalTasks(Protocol):
@webmethod(route="/eval_tasks/list", method="GET")

View file

@ -30,37 +30,8 @@ class MemoryBankType(Enum):
graph = "graph"
@json_schema_type
class VectorMemoryBank(Resource):
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
@json_schema_type
class KeyValueMemoryBank(Resource):
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
MemoryBankType.keyvalue.value
)
@json_schema_type
class KeywordMemoryBank(Resource):
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
memory_bank_type: Literal[MemoryBankType.keyword.value] = (
MemoryBankType.keyword.value
)
@json_schema_type
class GraphMemoryBank(Resource):
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
# Define params for each type of memory bank. Together these form a tagged union
# that is accepted as input from the API or from the config.
@json_schema_type
class VectorMemoryBankParams(BaseModel):
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
@ -88,6 +59,58 @@ class GraphMemoryBankParams(BaseModel):
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
BankParams = Annotated[
Union[
VectorMemoryBankParams,
KeyValueMemoryBankParams,
KeywordMemoryBankParams,
GraphMemoryBankParams,
],
Field(discriminator="memory_bank_type"),
]
# Common functionality shared by all memory bank resource classes: the fixed
# resource-type tag and memory-bank-specific aliases for the generic Resource ids.
class MemoryBankResourceMixin(Resource):
# Fixed resource-type tag; the only accepted value is ResourceType.memory_bank.value.
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
# Read-only alias exposing the generic `identifier` under a memory-bank-specific name.
@property
def memory_bank_id(self) -> str:
return self.identifier
# Read-only alias for `provider_resource_id` (presumably declared on Resource;
# its declaration is not visible in this view).
@property
def provider_memory_bank_id(self) -> str:
return self.provider_resource_id
# Memory bank resource of the `vector` kind.
@json_schema_type
class VectorMemoryBank(MemoryBankResourceMixin):
# Tag identifying this variant; `memory_bank_type` is the discriminator field
# of the MemoryBank union.
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
# Required settings for a vector bank.
embedding_model: str
chunk_size_in_tokens: int
# Optional; defaults to None when omitted.
overlap_size_in_tokens: Optional[int] = None
# Memory bank resource of the `keyvalue` kind; it declares no fields beyond
# its variant tag.
@json_schema_type
class KeyValueMemoryBank(MemoryBankResourceMixin):
# Tag identifying this variant; `memory_bank_type` is the discriminator field
# of the MemoryBank union.
memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
MemoryBankType.keyvalue.value
)
# TODO: "KeyValue" and "Keyword" are confusingly similar names; pick a clearer
# naming convention.
# Memory bank resource of the `keyword` kind; it declares no fields beyond
# its variant tag.
@json_schema_type
class KeywordMemoryBank(MemoryBankResourceMixin):
# Tag identifying this variant; `memory_bank_type` is the discriminator field
# of the MemoryBank union.
memory_bank_type: Literal[MemoryBankType.keyword.value] = (
MemoryBankType.keyword.value
)
# Memory bank resource of the `graph` kind; it declares no fields beyond
# its variant tag.
@json_schema_type
class GraphMemoryBank(MemoryBankResourceMixin):
# Tag identifying this variant; `memory_bank_type` is the discriminator field
# of the MemoryBank union.
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBank = Annotated[
Union[
VectorMemoryBank,
@ -98,15 +121,13 @@ MemoryBank = Annotated[
Field(discriminator="memory_bank_type"),
]
BankParams = Annotated[
Union[
VectorMemoryBankParams,
KeyValueMemoryBankParams,
KeywordMemoryBankParams,
GraphMemoryBankParams,
],
Field(discriminator="memory_bank_type"),
]
# Input-side description of a memory bank (used as the element type of the run
# config's `memory_banks` list).
# NOTE(review): unlike the other *Input classes in this change, this one declares
# a `type` field and has no `provider_id`. construct_stack passes
# `**obj.model_dump()` to `register_memory_bank`, whose visible parameters are
# memory_bank_id / params / provider_id / provider_memory_bank_id — a dumped
# `type` key would presumably be rejected as an unexpected keyword, and `params`
# would arrive as a plain dict rather than a BankParams model. Confirm.
@json_schema_type
class MemoryBankInput(BaseModel):
# Fixed resource-type tag; the only accepted value is ResourceType.memory_bank.value.
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
# Required identifier for the memory bank.
memory_bank_id: str
# Per-kind parameters; BankParams is a union discriminated on `memory_bank_type`.
params: BankParams
# Optional; defaults to None when omitted.
provider_memory_bank_id: Optional[str] = None
@runtime_checkable
@ -123,5 +144,5 @@ class MemoryBanks(Protocol):
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memorybank_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank: ...

View file

@ -7,20 +7,38 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import Field
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
@json_schema_type
class Model(Resource):
type: Literal[ResourceType.model.value] = ResourceType.model.value
class CommonModelFields(BaseModel):
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this model",
)
# Registered model resource: the shared model fields (CommonModelFields)
# combined with the generic Resource base class.
@json_schema_type
class Model(CommonModelFields, Resource):
# Fixed resource-type tag; the only accepted value is ResourceType.model.value.
type: Literal[ResourceType.model.value] = ResourceType.model.value
# Read-only alias exposing the generic `identifier` under a model-specific name.
@property
def model_id(self) -> str:
return self.identifier
# Read-only alias for `provider_resource_id` (presumably declared on Resource;
# its declaration is not visible in this view).
@property
def provider_model_id(self) -> str:
return self.provider_resource_id
# Input-side description of a model (used as the element type of the run
# config's `models` list). Only `model_id` is required; the tests in this
# change construct it as ModelInput(model_id=...).
@json_schema_type
class ModelInput(CommonModelFields):
# Required identifier for the model.
model_id: str
# Optional; both default to None when omitted. How None is resolved is decided
# by the registration code, which is not visible in this block.
provider_id: Optional[str] = None
provider_model_id: Optional[str] = None
@runtime_checkable
class Models(Protocol):
@webmethod(route="/models/list", method="GET")

View file

@ -17,14 +17,12 @@ class ResourceType(Enum):
memory_bank = "memory_bank"
dataset = "dataset"
scoring_function = "scoring_function"
eval_task = "eval_task"
class Resource(BaseModel):
"""Base class for all Llama Stack resources"""
# TODO: I think we need to move these into the child classes
# and make them `model_id`, `shield_id`, etc. because otherwise
# the config file has these confusing generic names in there
identifier: str = Field(
description="Unique identifier for this resource in llama stack"
)

View file

@ -66,11 +66,7 @@ ScoringFnParams = Annotated[
]
@json_schema_type
class ScoringFn(Resource):
type: Literal[ResourceType.scoring_function.value] = (
ResourceType.scoring_function.value
)
class CommonScoringFnFields(BaseModel):
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
default_factory=dict,
@ -85,6 +81,28 @@ class ScoringFn(Resource):
)
# Registered scoring function resource: the shared scoring function fields
# (CommonScoringFnFields) combined with the generic Resource base class.
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
# Fixed resource-type tag; the only accepted value is
# ResourceType.scoring_function.value.
type: Literal[ResourceType.scoring_function.value] = (
ResourceType.scoring_function.value
)
# Read-only alias exposing the generic `identifier` under a scoring-fn-specific name.
@property
def scoring_fn_id(self) -> str:
return self.identifier
# Read-only alias for `provider_resource_id` (presumably declared on Resource;
# its declaration is not visible in this view).
@property
def provider_scoring_fn_id(self) -> str:
return self.provider_resource_id
# Input-side description of a scoring function (used as the element type of the
# run config's `scoring_fns` list). Carries the shared scoring function fields
# plus scoring-fn-specific id names instead of the generic Resource field names.
# NOTE(review): BaseModel is already a base via CommonScoringFnFields, so listing
# it again looks redundant — confirm before removing.
@json_schema_type
class ScoringFnInput(CommonScoringFnFields, BaseModel):
# Required identifier for the scoring function.
scoring_fn_id: str
# Optional; both default to None when omitted. How None is resolved is decided
# by the registration code, which is not visible in this block.
provider_id: Optional[str] = None
provider_scoring_fn_id: Optional[str] = None
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring_functions/list", method="GET")

View file

@ -8,6 +8,7 @@ from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
@ -20,13 +21,30 @@ class ShieldType(Enum):
prompt_guard = "prompt_guard"
# Fields shared by the registered Shield resource and its ShieldInput
# counterpart, both of which inherit from this class.
class CommonShieldFields(BaseModel):
# Required kind of shield, drawn from the ShieldType enum.
shield_type: ShieldType
# Optional free-form shield parameters; defaults to None when omitted.
params: Optional[Dict[str, Any]] = None
@json_schema_type
class Shield(Resource):
class Shield(CommonShieldFields, Resource):
"""A safety shield resource that can be used to check content"""
type: Literal[ResourceType.shield.value] = ResourceType.shield.value
shield_type: ShieldType
params: Dict[str, Any] = {}
@property
def shield_id(self) -> str:
return self.identifier
@property
def provider_shield_id(self) -> str:
return self.provider_resource_id
# Input-side description of a shield (used as the element type of the run
# config's `shields` list); inherits `shield_type` and `params` from
# CommonShieldFields.
# NOTE(review): the other *Input classes in this change are decorated with
# @json_schema_type and this one is not — confirm whether that is intentional.
class ShieldInput(CommonShieldFields):
# Required identifier for the shield.
shield_id: str
# Optional; both default to None when omitted. How None is resolved is decided
# by the registration code, which is not visible in this block.
provider_id: Optional[str] = None
provider_shield_id: Optional[str] = None
@runtime_checkable

View file

@ -18,7 +18,7 @@ from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.eval_tasks import EvalTaskInput
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
@ -152,12 +152,12 @@ a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
models: List[Model] = Field(default_factory=list)
shields: List[Shield] = Field(default_factory=list)
memory_banks: List[MemoryBank] = Field(default_factory=list)
datasets: List[Dataset] = Field(default_factory=list)
scoring_fns: List[ScoringFn] = Field(default_factory=list)
eval_tasks: List[EvalTask] = Field(default_factory=list)
models: List[ModelInput] = Field(default_factory=list)
shields: List[ShieldInput] = Field(default_factory=list)
memory_banks: List[MemoryBankInput] = Field(default_factory=list)
datasets: List[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
class BuildConfig(BaseModel):

View file

@ -277,10 +277,10 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memorybank_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank:
if provider_memorybank_id is None:
provider_memorybank_id = memory_bank_id
if provider_memory_bank_id is None:
provider_memory_bank_id = memory_bank_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this memory bank type
if len(self.impls_by_provider_id) == 1:
@ -295,7 +295,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memorybank_id,
"provider_resource_id": provider_memory_bank_id,
**params.model_dump(),
},
)

View file

@ -68,30 +68,29 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
objects = [
*run_config.models,
*run_config.shields,
*run_config.memory_banks,
*run_config.datasets,
*run_config.scoring_fns,
*run_config.eval_tasks,
]
for obj in objects:
await dist_registry.register(obj)
resources = [
("models", Api.models),
("shields", Api.shields),
("memory_banks", Api.memory_banks),
("datasets", Api.datasets),
("scoring_fns", Api.scoring_functions),
("eval_tasks", Api.eval_tasks),
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
Api.scoring_functions,
"register_scoring_function",
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
]
for rsrc, api in resources:
for rsrc, api, register_method, list_method in resources:
objects = getattr(run_config, rsrc)
if api not in impls:
continue
method = getattr(impls[api], f"list_{api.value}")
method = getattr(impls[api], register_method)
for obj in objects:
await method(**obj.model_dump())
method = getattr(impls[api], list_method)
for obj in await method():
print(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",

View file

@ -9,7 +9,7 @@ import tempfile
import pytest
import pytest_asyncio
from llama_stack.apis.models import Model
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.agents.meta_reference import (
@ -71,13 +71,9 @@ async def agents_stack(request, inference_model, safety_model):
if fixture.provider_data:
provider_data.update(fixture.provider_data)
inf_provider_id = providers["inference"][0].provider_id
safety_provider_id = providers["safety"][0].provider_id
shield = get_shield_to_register(
providers["safety"][0].provider_type, safety_provider_id, safety_model
shield_input = get_shield_to_register(
providers["safety"][0].provider_type, safety_model
)
inference_models = (
inference_model if isinstance(inference_model, list) else [inference_model]
)
@ -86,13 +82,11 @@ async def agents_stack(request, inference_model, safety_model):
providers,
provider_data,
models=[
Model(
identifier=model,
provider_id=inf_provider_id,
provider_resource_id=model,
ModelInput(
model_id=model,
)
for model in inference_models
],
shields=[shield],
shields=[shield_input],
)
return impls[Api.agents], impls[Api.memory]

View file

@ -9,7 +9,7 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import Model
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import (
@ -162,10 +162,8 @@ async def inference_stack(request, inference_model):
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
models=[
Model(
identifier=inference_model,
provider_resource_id=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
ModelInput(
model_id=inference_model,
)
],
)

View file

@ -7,9 +7,9 @@
import pytest
import pytest_asyncio
from llama_stack.apis.models import Model
from llama_stack.apis.models import ModelInput
from llama_stack.apis.shields import Shield, ShieldType
from llama_stack.apis.shields import ShieldInput, ShieldType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@ -99,28 +99,21 @@ async def safety_stack(inference_model, safety_model, request):
provider_data.update(safety_fixture.provider_data)
shield_provider_type = safety_fixture.providers[0].provider_type
shield = get_shield_to_register(
shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
)
shield_input = get_shield_to_register(shield_provider_type, safety_model)
impls = await resolve_impls_for_test_v2(
[Api.safety, Api.shields, Api.inference],
providers,
provider_data,
models=[
Model(
identifier=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
provider_resource_id=inference_model,
)
],
shields=[shield],
models=[ModelInput(model_id=inference_model)],
shields=[shield_input],
)
shield = await impls[Api.shields].get_shield(shield_input.shield_id)
return impls[Api.safety], impls[Api.shields], shield
def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
shield_config = {}
shield_type = ShieldType.llama_guard
identifier = "llama_guard"
@ -133,10 +126,8 @@ def get_shield_to_register(provider_type: str, provider_id: str, safety_model: s
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
shield_type = ShieldType.generic_content_shield
return Shield(
identifier=identifier,
return ShieldInput(
shield_id=identifier,
shield_type=shield_type,
params=shield_config,
provider_id=provider_id,
provider_resource_id=identifier,
)