diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 196a400f8..231633464 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
     "info": {
         "title": "[DRAFT] Llama Stack Specification",
         "version": "0.0.1",
-        "description": "This is the specification of the llama stack that provides\n    a set of endpoints and their corresponding interfaces that are tailored to\n    best leverage Llama Models. The specification is still in draft and subject to change.\n    Generated at 2024-11-11 18:44:30.967321"
+        "description": "This is the specification of the llama stack that provides\n    a set of endpoints and their corresponding interfaces that are tailored to\n    best leverage Llama Models. The specification is still in draft and subject to change.\n    Generated at 2024-11-12 11:16:58.657871"
     },
     "servers": [
         {
@@ -5778,8 +5778,7 @@
                 "provider_resource_id",
                 "provider_id",
                 "type",
-                "shield_type",
-                "params"
+                "shield_type"
             ],
             "title": "A safety shield resource that can be used to check content"
         },
@@ -7027,7 +7026,7 @@
                 "provider_id": {
                     "type": "string"
                 },
-                "provider_memorybank_id": {
+                "provider_memory_bank_id": {
                     "type": "string"
                 }
             },
@@ -7854,59 +7853,59 @@
         }
     ],
     "tags": [
-        {
-            "name": "Datasets"
-        },
-        {
-            "name": "Telemetry"
-        },
-        {
-            "name": "PostTraining"
-        },
-        {
-            "name": "MemoryBanks"
-        },
-        {
-            "name": "Eval"
-        },
-        {
-            "name": "Memory"
-        },
-        {
-            "name": "EvalTasks"
-        },
-        {
-            "name": "Models"
-        },
-        {
-            "name": "Scoring"
-        },
         {
             "name": "Inference"
         },
-        {
-            "name": "Shields"
-        },
-        {
-            "name": "DatasetIO"
-        },
-        {
-            "name": "Safety"
-        },
         {
             "name": "Agents"
         },
         {
-            "name": "SyntheticDataGeneration"
+            "name": "Telemetry"
+        },
+        {
+            "name": "Eval"
+        },
+        {
+            "name": "Models"
+        },
+        {
+            "name": "Inspect"
+        },
+        {
+            "name": "EvalTasks"
         },
         {
             "name": "ScoringFunctions"
         },
         {
-            "name": "BatchInference"
+            "name": "Memory"
         },
         {
-            "name": "Inspect"
+            "name": "Safety"
+        },
+        {
+            "name": "DatasetIO"
+        },
+        {
+            "name": "MemoryBanks"
+        },
+        {
+            "name": "Shields"
+        },
+        {
+            "name": "PostTraining"
+        },
+        {
+            "name": "Datasets"
+        },
+        {
+            "name": "Scoring"
+        },
+        {
+            "name": "SyntheticDataGeneration"
+        },
+        {
+            "name": "BatchInference"
         },
         {
             "name": "BuiltinTool",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 164d3168c..4e02e8075 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -2068,7 +2068,7 @@ components:
       - $ref: '#/components/schemas/GraphMemoryBankParams'
       provider_id:
         type: string
-      provider_memorybank_id:
+      provider_memory_bank_id:
        type: string
     required:
     - memory_bank_id
@@ -2710,7 +2710,6 @@ components:
     - provider_id
     - type
     - shield_type
-    - params
     title: A safety shield resource that can be used to check content
     type: object
   ShieldCallStep:
@@ -3398,7 +3397,7 @@ info:
   description: "This is the specification of the llama stack that provides\n\
     \ a set of endpoints and their corresponding interfaces that are tailored to\n\
     \ best leverage Llama Models. The specification is still in\
-    \ draft and subject to change.\n Generated at 2024-11-11 18:44:30.967321"
+    \ draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871"
   title: '[DRAFT] Llama Stack Specification'
   version: 0.0.1
 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -4762,24 +4761,24 @@ security:
 servers:
 - url: http://any-hosted-llama-stack.com
 tags:
-- name: Datasets
-- name: Telemetry
-- name: PostTraining
-- name: MemoryBanks
-- name: Eval
-- name: Memory
-- name: EvalTasks
-- name: Models
-- name: Scoring
 - name: Inference
-- name: Shields
-- name: DatasetIO
-- name: Safety
 - name: Agents
-- name: SyntheticDataGeneration
-- name: ScoringFunctions
-- name: BatchInference
+- name: Telemetry
+- name: Eval
+- name: Models
 - name: Inspect
+- name: EvalTasks
+- name: ScoringFunctions
+- name: Memory
+- name: Safety
+- name: DatasetIO
+- name: MemoryBanks
+- name: Shields
+- name: PostTraining
+- name: Datasets
+- name: Scoring
+- name: SyntheticDataGeneration
+- name: BatchInference
 - description:
   name: BuiltinTool
 - description:
diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py
--- a/llama_stack/apis/datasets/datasets.py
+++ b/llama_stack/apis/datasets/datasets.py
@@ ... @@
+    @property
+    def dataset_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_dataset_id(self) -> str:
+        return self.provider_resource_id
+
+
+@json_schema_type
+class DatasetInput(CommonDatasetFields, BaseModel):
+    dataset_id: str
+    provider_id: Optional[str] = None
+    provider_dataset_id: Optional[str] = None
+
+
 class Datasets(Protocol):
     @webmethod(route="/datasets/register", method="POST")
     async def register_dataset(
diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py
index 870673e58..10c35c3ee 100644
--- a/llama_stack/apis/eval_tasks/eval_tasks.py
+++ b/llama_stack/apis/eval_tasks/eval_tasks.py
@@ -7,14 +7,12 @@
 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
 
-from pydantic import Field
+from pydantic import BaseModel, Field
 
-from llama_stack.apis.resource import Resource
+from llama_stack.apis.resource import Resource, ResourceType
 
 
-@json_schema_type
-class EvalTask(Resource):
-    type: Literal["eval_task"] = "eval_task"
+class CommonEvalTaskFields(BaseModel):
     dataset_id: str
     scoring_functions: List[str]
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
@@ -23,6 +21,26 @@
     )
 
 
+@json_schema_type
+class EvalTask(CommonEvalTaskFields, Resource):
+    type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value
+
+    @property
+    def eval_task_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_eval_task_id(self) -> str:
+        return self.provider_resource_id
+
+
+@json_schema_type
+class EvalTaskInput(CommonEvalTaskFields, BaseModel):
+    eval_task_id: str
+    provider_id: Optional[str] = None
+    provider_eval_task_id: Optional[str] = None
+
+
 @runtime_checkable
 class EvalTasks(Protocol):
     @webmethod(route="/eval_tasks/list", method="GET")
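The same three-layer pattern repeats for every resource API touched by this diff: a plain-pydantic `Common*Fields` base carries the user-facing fields, the registered `Resource` subclass layers on `identifier`/`provider_resource_id` plus convenience properties, and a `*Input` model is what a run config or API caller supplies. A minimal sketch of the intended flow using the eval-task types above (the ids are made up for illustration):

```python
# Sketch only: illustrative ids, not a real distribution's data.
from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput

# What a user writes in a run config or register call:
task_input = EvalTaskInput(
    eval_task_id="mmlu-eval",        # hypothetical id
    dataset_id="mmlu",               # hypothetical dataset
    scoring_functions=["accuracy"],  # hypothetical scoring fn
)

# What the routing table materializes after registration:
task = EvalTask(
    identifier=task_input.eval_task_id,
    provider_id="meta-reference",    # filled in by the stack
    provider_resource_id=task_input.eval_task_id,
    dataset_id=task_input.dataset_id,
    scoring_functions=task_input.scoring_functions,
)
assert task.eval_task_id == task.identifier  # the new property alias
```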
diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py
index 303104f25..83b292612 100644
--- a/llama_stack/apis/memory_banks/memory_banks.py
+++ b/llama_stack/apis/memory_banks/memory_banks.py
@@ -30,37 +30,8 @@ class MemoryBankType(Enum):
     graph = "graph"
 
 
-@json_schema_type
-class VectorMemoryBank(Resource):
-    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
-    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
-    embedding_model: str
-    chunk_size_in_tokens: int
-    overlap_size_in_tokens: Optional[int] = None
-
-
-@json_schema_type
-class KeyValueMemoryBank(Resource):
-    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
-    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
-        MemoryBankType.keyvalue.value
-    )
-
-
-@json_schema_type
-class KeywordMemoryBank(Resource):
-    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
-    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
-        MemoryBankType.keyword.value
-    )
-
-
-@json_schema_type
-class GraphMemoryBank(Resource):
-    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
-    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
-
-
+# Define params for each type of memory bank; this leads to a tagged union
+# accepted as input from the API or from the config.
 @json_schema_type
 class VectorMemoryBankParams(BaseModel):
     memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
@@ -88,6 +59,58 @@ class GraphMemoryBankParams(BaseModel):
     memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
 
 
+BankParams = Annotated[
+    Union[
+        VectorMemoryBankParams,
+        KeyValueMemoryBankParams,
+        KeywordMemoryBankParams,
+        GraphMemoryBankParams,
+    ],
+    Field(discriminator="memory_bank_type"),
+]
+
+
+# Some common functionality for memory banks.
+class MemoryBankResourceMixin(Resource):
+    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
+
+    @property
+    def memory_bank_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_memory_bank_id(self) -> str:
+        return self.provider_resource_id
+
+
+@json_schema_type
+class VectorMemoryBank(MemoryBankResourceMixin):
+    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
+    embedding_model: str
+    chunk_size_in_tokens: int
+    overlap_size_in_tokens: Optional[int] = None
+
+
+@json_schema_type
+class KeyValueMemoryBank(MemoryBankResourceMixin):
+    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
+        MemoryBankType.keyvalue.value
+    )
+
+
+# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention.
+@json_schema_type
+class KeywordMemoryBank(MemoryBankResourceMixin):
+    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
+        MemoryBankType.keyword.value
+    )
+
+
+@json_schema_type
+class GraphMemoryBank(MemoryBankResourceMixin):
+    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
+
+
 MemoryBank = Annotated[
     Union[
         VectorMemoryBank,
@@ -98,15 +121,12 @@ MemoryBank = Annotated[
     Field(discriminator="memory_bank_type"),
 ]
 
-BankParams = Annotated[
-    Union[
-        VectorMemoryBankParams,
-        KeyValueMemoryBankParams,
-        KeywordMemoryBankParams,
-        GraphMemoryBankParams,
-    ],
-    Field(discriminator="memory_bank_type"),
-]
+
+@json_schema_type
+class MemoryBankInput(BaseModel):
+    memory_bank_id: str
+    params: BankParams
+    provider_memory_bank_id: Optional[str] = None
 
 
 @runtime_checkable
@@ -123,5 +143,5 @@ class MemoryBanks(Protocol):
         memory_bank_id: str,
         params: BankParams,
         provider_id: Optional[str] = None,
-        provider_memorybank_id: Optional[str] = None,
+        provider_memory_bank_id: Optional[str] = None,
     ) -> MemoryBank: ...
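Because `BankParams` is a discriminated union keyed on `memory_bank_type`, parsing untyped config data picks the right params class automatically. A quick sketch of that behavior (the pydantic v2 `TypeAdapter` entry point is an assumption; the embedding model name is illustrative):

```python
from pydantic import TypeAdapter

from llama_stack.apis.memory_banks import BankParams, VectorMemoryBankParams

params = TypeAdapter(BankParams).validate_python(
    {
        "memory_bank_type": "vector",           # the discriminator tag
        "embedding_model": "all-MiniLM-L6-v2",  # illustrative model name
        "chunk_size_in_tokens": 512,
    }
)
# The tag routed the dict to the vector variant of the union.
assert isinstance(params, VectorMemoryBankParams)
```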
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index bb8d2c4ea..a5d226886 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -7,20 +7,38 @@
 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
 
-from pydantic import Field
+from pydantic import BaseModel, Field
 
 from llama_stack.apis.resource import Resource, ResourceType
 
 
-@json_schema_type
-class Model(Resource):
-    type: Literal[ResourceType.model.value] = ResourceType.model.value
+class CommonModelFields(BaseModel):
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this model",
     )
 
 
+@json_schema_type
+class Model(CommonModelFields, Resource):
+    type: Literal[ResourceType.model.value] = ResourceType.model.value
+
+    @property
+    def model_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_model_id(self) -> str:
+        return self.provider_resource_id
+
+
+@json_schema_type
+class ModelInput(CommonModelFields):
+    model_id: str
+    provider_id: Optional[str] = None
+    provider_model_id: Optional[str] = None
+
+
 @runtime_checkable
 class Models(Protocol):
     @webmethod(route="/models/list", method="GET")
diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py
index 0e488190b..93a3718a0 100644
--- a/llama_stack/apis/resource.py
+++ b/llama_stack/apis/resource.py
@@ -17,14 +17,12 @@ class ResourceType(Enum):
     memory_bank = "memory_bank"
     dataset = "dataset"
     scoring_function = "scoring_function"
+    eval_task = "eval_task"
 
 
 class Resource(BaseModel):
     """Base class for all Llama Stack resources"""
 
-    # TODO: I think we need to move these into the child classes
-    # and make them `model_id`, `shield_id`, etc. because otherwise
-    # the config file has these confusing generic names in there
     identifier: str = Field(
         description="Unique identifier for this resource in llama stack"
     )
diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index 6b2408e0d..7a2a83c72 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -66,11 +66,7 @@ ScoringFnParams = Annotated[
 ]
 
 
-@json_schema_type
-class ScoringFn(Resource):
-    type: Literal[ResourceType.scoring_function.value] = (
-        ResourceType.scoring_function.value
-    )
+class CommonScoringFnFields(BaseModel):
     description: Optional[str] = None
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
@@ -85,6 +81,28 @@
     )
 
 
+@json_schema_type
+class ScoringFn(CommonScoringFnFields, Resource):
+    type: Literal[ResourceType.scoring_function.value] = (
+        ResourceType.scoring_function.value
+    )
+
+    @property
+    def scoring_fn_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_scoring_fn_id(self) -> str:
+        return self.provider_resource_id
+
+
+@json_schema_type
+class ScoringFnInput(CommonScoringFnFields, BaseModel):
+    scoring_fn_id: str
+    provider_id: Optional[str] = None
+    provider_scoring_fn_id: Optional[str] = None
+
+
 @runtime_checkable
 class ScoringFunctions(Protocol):
     @webmethod(route="/scoring_functions/list", method="GET")
diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py
index 42fe717fa..1dcfd4f4c 100644
--- a/llama_stack/apis/shields/shields.py
+++ b/llama_stack/apis/shields/shields.py
@@ -8,6 +8,7 @@
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel
 
 from llama_stack.apis.resource import Resource, ResourceType
 
@@ -20,13 +21,30 @@ class ShieldType(Enum):
     prompt_guard = "prompt_guard"
 
 
+class CommonShieldFields(BaseModel):
+    shield_type: ShieldType
+    params: Optional[Dict[str, Any]] = None
+
+
 @json_schema_type
-class Shield(Resource):
+class Shield(CommonShieldFields, Resource):
     """A safety shield resource that can be used to check content"""
 
     type: Literal[ResourceType.shield.value] = ResourceType.shield.value
-    shield_type: ShieldType
-    params: Dict[str, Any] = {}
+
+    @property
+    def shield_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_shield_id(self) -> str:
+        return self.provider_resource_id
+
+
+class ShieldInput(CommonShieldFields):
+    shield_id: str
+    provider_id: Optional[str] = None
+    provider_shield_id: Optional[str] = None
 
 
 @runtime_checkable
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 2cba5b052..4aaf9c38a 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -18,7 +18,7 @@ from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.eval import Eval
-from llama_stack.apis.eval_tasks import EvalTask
+from llama_stack.apis.eval_tasks import EvalTaskInput
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.safety import Safety
@@ -152,12 +152,12 @@ a default SQLite store will be used.""",
     )
 
     # registry of "resources" in the distribution
-    models: List[Model] = Field(default_factory=list)
-    shields: List[Shield] = Field(default_factory=list)
-    memory_banks: List[MemoryBank] = Field(default_factory=list)
-    datasets: List[Dataset] = Field(default_factory=list)
-    scoring_fns: List[ScoringFn] = Field(default_factory=list)
-    eval_tasks: List[EvalTask] = Field(default_factory=list)
+    models: List[ModelInput] = Field(default_factory=list)
+    shields: List[ShieldInput] = Field(default_factory=list)
+    memory_banks: List[MemoryBankInput] = Field(default_factory=list)
+    datasets: List[DatasetInput] = Field(default_factory=list)
+    scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
+    eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
 
 
 class BuildConfig(BaseModel):
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index efed54ab8..7b369df2c 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -32,6 +32,10 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
     api = get_impl_api(p)
 
     if obj.provider_id == "remote":
+        # TODO: this is broken right now because we use the generic
+        # { identifier, provider_id, provider_resource_id } tuple here
+        # but the APIs expect things like ModelInput, ShieldInput, etc.
+
         # if this is just a passthrough, we want to let the remote
         # end actually do the registration with the correct provider
         obj = obj.model_copy(deep=True)
@@ -277,10 +281,10 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
         memory_bank_id: str,
         params: BankParams,
         provider_id: Optional[str] = None,
-        provider_memorybank_id: Optional[str] = None,
+        provider_memory_bank_id: Optional[str] = None,
     ) -> MemoryBank:
-        if provider_memorybank_id is None:
-            provider_memorybank_id = memory_bank_id
+        if provider_memory_bank_id is None:
+            provider_memory_bank_id = memory_bank_id
         if provider_id is None:
             # If provider_id not specified, use the only provider if it supports this shield type
             if len(self.impls_by_provider_id) == 1:
@@ -295,7 +299,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
                 "identifier": memory_bank_id,
                 "type": ResourceType.memory_bank.value,
                 "provider_id": provider_id,
-                "provider_resource_id": provider_memorybank_id,
+                "provider_resource_id": provider_memory_bank_id,
                 **params.model_dump(),
             },
         )
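With the registry fields switched from fully resolved resources to `*Input` models, the resources section of a `StackRunConfig` no longer has to spell out `provider_resource_id`, or even `provider_id` when there is only one provider. A hedged sketch of what that buys, in Python terms (the identifiers are placeholders, not a real distribution):

```python
from llama_stack.apis.models import ModelInput
from llama_stack.apis.shields import ShieldInput, ShieldType

# Before: full Model/Shield objects with identifier + provider_resource_id.
# After: just the user-facing fields; the stack resolves the rest at startup.
models = [ModelInput(model_id="Llama3.1-8B-Instruct")]  # provider_id now optional
shields = [
    ShieldInput(shield_id="llama_guard", shield_type=ShieldType.llama_guard)
]
```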
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index 7fe7d3ca7..3afd51304 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 from typing import Any, Dict
 
+from termcolor import colored
 from termcolor import colored
 
@@ -67,30 +68,29 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
     impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
 
-    objects = [
-        *run_config.models,
-        *run_config.shields,
-        *run_config.memory_banks,
-        *run_config.datasets,
-        *run_config.scoring_fns,
-        *run_config.eval_tasks,
-    ]
-    for obj in objects:
-        await dist_registry.register(obj)
-
     resources = [
-        ("models", Api.models),
-        ("shields", Api.shields),
-        ("memory_banks", Api.memory_banks),
-        ("datasets", Api.datasets),
-        ("scoring_fns", Api.scoring_functions),
-        ("eval_tasks", Api.eval_tasks),
+        ("models", Api.models, "register_model", "list_models"),
+        ("shields", Api.shields, "register_shield", "list_shields"),
+        ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
+        ("datasets", Api.datasets, "register_dataset", "list_datasets"),
+        (
+            "scoring_fns",
+            Api.scoring_functions,
+            "register_scoring_function",
+            "list_scoring_functions",
+        ),
+        ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
     ]
-    for rsrc, api in resources:
+    for rsrc, api, register_method, list_method in resources:
+        objects = getattr(run_config, rsrc)
         if api not in impls:
             continue
 
-        method = getattr(impls[api], f"list_{api.value}")
+        method = getattr(impls[api], register_method)
+        for obj in objects:
+            await method(**obj.model_dump())
+
+        method = getattr(impls[api], list_method)
         for obj in await method():
             print(
                 f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 9c3ec7750..12d012b16 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -128,7 +128,6 @@
         pass
 
     async def register_shield(self, shield: Shield) -> None:
-        print(f"Registering shield {shield}")
         if shield.shield_type != ShieldType.llama_guard:
             raise ValueError(f"Unsupported shield type: {shield.shield_type}")
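The `construct_stack` rewrite above replaces direct `dist_registry.register(obj)` calls with the public `register_*` methods, replaying each `*Input` via `await method(**obj.model_dump())`. That works because an Input's dump is keyword-compatible with the corresponding register signature; a sketch of the mapping for models (assuming `register_model(model_id, provider_model_id, provider_id, metadata)`, as the models API suggests):

```python
from llama_stack.apis.models import ModelInput

obj = ModelInput(model_id="Llama3.1-8B-Instruct")  # placeholder id
kwargs = obj.model_dump()
# kwargs == {
#     "metadata": {},
#     "model_id": "Llama3.1-8B-Instruct",
#     "provider_id": None,
#     "provider_model_id": None,
# }
# ...which the loop then feeds to the implementation:
#     method = getattr(impls[api], register_method)
#     await method(**kwargs)
```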
diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py
index 6ee17ff1f..64f493b88 100644
--- a/llama_stack/providers/tests/agents/fixtures.py
+++ b/llama_stack/providers/tests/agents/fixtures.py
@@ -9,7 +9,7 @@ import tempfile
 import pytest
 import pytest_asyncio
 
-from llama_stack.apis.models import Model
+from llama_stack.apis.models import ModelInput
 
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.agents.meta_reference import (
@@ -71,13 +71,9 @@ async def agents_stack(request, inference_model, safety_model):
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
 
-    inf_provider_id = providers["inference"][0].provider_id
-    safety_provider_id = providers["safety"][0].provider_id
-
-    shield = get_shield_to_register(
-        providers["safety"][0].provider_type, safety_provider_id, safety_model
+    shield_input = get_shield_to_register(
+        providers["safety"][0].provider_type, safety_model
     )
-
     inference_models = (
         inference_model if isinstance(inference_model, list) else [inference_model]
     )
@@ -86,13 +82,11 @@ async def agents_stack(request, inference_model, safety_model):
         providers,
         provider_data,
         models=[
-            Model(
-                identifier=model,
-                provider_id=inf_provider_id,
-                provider_resource_id=model,
+            ModelInput(
+                model_id=model,
             )
             for model in inference_models
         ],
-        shields=[shield],
+        shields=[shield_input],
     )
     return impls[Api.agents], impls[Api.memory]
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index fe91c6e03..d35ebab28 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -9,7 +9,7 @@ import os
 import pytest
 import pytest_asyncio
 
-from llama_stack.apis.models import Model
+from llama_stack.apis.models import ModelInput
 
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.inference.meta_reference import (
@@ -162,10 +162,8 @@ async def inference_stack(request, inference_model):
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
         models=[
-            Model(
-                identifier=inference_model,
-                provider_resource_id=inference_model,
-                provider_id=inference_fixture.providers[0].provider_id,
+            ModelInput(
+                model_id=inference_model,
             )
         ],
     )
diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py
index 5e553830c..66576e9d7 100644
--- a/llama_stack/providers/tests/safety/fixtures.py
+++ b/llama_stack/providers/tests/safety/fixtures.py
@@ -7,9 +7,9 @@
 import pytest
 import pytest_asyncio
 
-from llama_stack.apis.models import Model
+from llama_stack.apis.models import ModelInput
 
-from llama_stack.apis.shields import Shield, ShieldType
+from llama_stack.apis.shields import ShieldInput, ShieldType
 
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@@ -99,28 +99,21 @@ async def safety_stack(inference_model, safety_model, request):
         provider_data.update(safety_fixture.provider_data)
 
     shield_provider_type = safety_fixture.providers[0].provider_type
-    shield = get_shield_to_register(
-        shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
-    )
+    shield_input = get_shield_to_register(shield_provider_type, safety_model)
 
     impls = await resolve_impls_for_test_v2(
         [Api.safety, Api.shields, Api.inference],
         providers,
         provider_data,
-        models=[
-            Model(
-                identifier=inference_model,
-                provider_id=inference_fixture.providers[0].provider_id,
-                provider_resource_id=inference_model,
-            )
-        ],
-        shields=[shield],
+        models=[ModelInput(model_id=inference_model)],
+        shields=[shield_input],
     )
+    shield = await impls[Api.shields].get_shield(shield_input.shield_id)
     return impls[Api.safety], impls[Api.shields], shield
 
 
-def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
+def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
     shield_config = {}
     shield_type = ShieldType.llama_guard
     identifier = "llama_guard"
@@ -133,10 +126,8 @@ def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
         shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
         shield_type = ShieldType.generic_content_shield
 
-    return Shield(
-        identifier=identifier,
+    return ShieldInput(
+        shield_id=identifier,
         shield_type=shield_type,
         params=shield_config,
-        provider_id=provider_id,
-        provider_resource_id=identifier,
     )
diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py
index 14095b526..ee6999043 100644
--- a/llama_stack/providers/tests/scoring/fixtures.py
+++ b/llama_stack/providers/tests/scoring/fixtures.py
@@ -7,6 +7,8 @@
 import pytest
 import pytest_asyncio
 
+from llama_stack.apis.models import ModelInput
+
 from llama_stack.distribution.datatypes import Api, Provider
 
 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
@@ -76,20 +78,14 @@ async def scoring_stack(request, inference_model):
         [Api.scoring, Api.datasetio, Api.inference],
         providers,
         provider_data,
-    )
-
-    provider_id = providers["inference"][0].provider_id
-    await impls[Api.models].register_model(
-        model_id=inference_model,
-        provider_id=provider_id,
-    )
-    await impls[Api.models].register_model(
-        model_id="Llama3.1-405B-Instruct",
-        provider_id=provider_id,
-    )
-    await impls[Api.models].register_model(
-        model_id="Llama3.1-8B-Instruct",
-        provider_id=provider_id,
+        models=[
+            ModelInput(model_id=model)
+            for model in [
+                inference_model,
+                "Llama3.1-405B-Instruct",
+                "Llama3.1-8B-Instruct",
+            ]
+        ],
     )
 
     return impls
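Net effect on the test fixtures: instead of hand-building fully resolved `Model`/`Shield` objects and threading `provider_id` through helpers, tests now declare inputs and read the resolved resource back from the stack. Roughly (a sketch; the `resolve_impls_for_test_v2` arguments are abbreviated):

```python
from llama_stack.apis.shields import ShieldInput, ShieldType

shield_input = ShieldInput(
    shield_id="llama_guard",
    shield_type=ShieldType.llama_guard,
)

# impls = await resolve_impls_for_test_v2(..., shields=[shield_input])
# shield = await impls[Api.shields].get_shield(shield_input.shield_id)
# `shield.provider_id` and `shield.provider_resource_id` are now filled in
# by the routing table rather than by the fixture.
```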