diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 196a400f8..231633464 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
- "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-11 18:44:30.967321"
+ "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871"
},
"servers": [
{
@@ -5778,8 +5778,7 @@
"provider_resource_id",
"provider_id",
"type",
- "shield_type",
- "params"
+ "shield_type"
],
"title": "A safety shield resource that can be used to check content"
},
@@ -7027,7 +7026,7 @@
"provider_id": {
"type": "string"
},
- "provider_memorybank_id": {
+ "provider_memory_bank_id": {
"type": "string"
}
},
@@ -7854,59 +7853,59 @@
}
],
"tags": [
- {
- "name": "Datasets"
- },
- {
- "name": "Telemetry"
- },
- {
- "name": "PostTraining"
- },
- {
- "name": "MemoryBanks"
- },
- {
- "name": "Eval"
- },
- {
- "name": "Memory"
- },
- {
- "name": "EvalTasks"
- },
- {
- "name": "Models"
- },
- {
- "name": "Scoring"
- },
{
"name": "Inference"
},
- {
- "name": "Shields"
- },
- {
- "name": "DatasetIO"
- },
- {
- "name": "Safety"
- },
{
"name": "Agents"
},
{
- "name": "SyntheticDataGeneration"
+ "name": "Telemetry"
+ },
+ {
+ "name": "Eval"
+ },
+ {
+ "name": "Models"
+ },
+ {
+ "name": "Inspect"
+ },
+ {
+ "name": "EvalTasks"
},
{
"name": "ScoringFunctions"
},
{
- "name": "BatchInference"
+ "name": "Memory"
},
{
- "name": "Inspect"
+ "name": "Safety"
+ },
+ {
+ "name": "DatasetIO"
+ },
+ {
+ "name": "MemoryBanks"
+ },
+ {
+ "name": "Shields"
+ },
+ {
+ "name": "PostTraining"
+ },
+ {
+ "name": "Datasets"
+ },
+ {
+ "name": "Scoring"
+ },
+ {
+ "name": "SyntheticDataGeneration"
+ },
+ {
+ "name": "BatchInference"
},
{
"name": "BuiltinTool",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 164d3168c..4e02e8075 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -2068,7 +2068,7 @@ components:
- $ref: '#/components/schemas/GraphMemoryBankParams'
provider_id:
type: string
- provider_memorybank_id:
+ provider_memory_bank_id:
type: string
required:
- memory_bank_id
@@ -2710,7 +2710,6 @@ components:
- provider_id
- type
- shield_type
- - params
title: A safety shield resource that can be used to check content
type: object
ShieldCallStep:
@@ -3398,7 +3397,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
- \ draft and subject to change.\n Generated at 2024-11-11 18:44:30.967321"
+ \ draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -4762,24 +4761,24 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
-- name: Datasets
-- name: Telemetry
-- name: PostTraining
-- name: MemoryBanks
-- name: Eval
-- name: Memory
-- name: EvalTasks
-- name: Models
-- name: Scoring
- name: Inference
-- name: Shields
-- name: DatasetIO
-- name: Safety
- name: Agents
-- name: SyntheticDataGeneration
-- name: ScoringFunctions
-- name: BatchInference
+- name: Telemetry
+- name: Eval
+- name: Models
- name: Inspect
+- name: EvalTasks
+- name: ScoringFunctions
+- name: Memory
+- name: Safety
+- name: DatasetIO
+- name: MemoryBanks
+- name: Shields
+- name: PostTraining
+- name: Datasets
+- name: Scoring
+- name: SyntheticDataGeneration
+- name: BatchInference
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
  name: BuiltinTool
diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py
--- a/llama_stack/apis/datasets/datasets.py
+++ b/llama_stack/apis/datasets/datasets.py
@@ ... @@ class Dataset(CommonDatasetFields, Resource):
+ @property
+ def dataset_id(self) -> str:
+ return self.identifier
+
+ @property
+ def provider_dataset_id(self) -> str:
+ return self.provider_resource_id
+
+
+@json_schema_type
+class DatasetInput(CommonDatasetFields, BaseModel):
+ dataset_id: str
+ provider_id: Optional[str] = None
+ provider_dataset_id: Optional[str] = None
+
+
class Datasets(Protocol):
@webmethod(route="/datasets/register", method="POST")
async def register_dataset(
diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py
index 870673e58..10c35c3ee 100644
--- a/llama_stack/apis/eval_tasks/eval_tasks.py
+++ b/llama_stack/apis/eval_tasks/eval_tasks.py
@@ -7,14 +7,12 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import Field
+from pydantic import BaseModel, Field
-from llama_stack.apis.resource import Resource
+from llama_stack.apis.resource import Resource, ResourceType
-@json_schema_type
-class EvalTask(Resource):
- type: Literal["eval_task"] = "eval_task"
+class CommonEvalTaskFields(BaseModel):
dataset_id: str
scoring_functions: List[str]
metadata: Dict[str, Any] = Field(
@@ -23,6 +21,26 @@ class EvalTask(Resource):
)
+@json_schema_type
+class EvalTask(CommonEvalTaskFields, Resource):
+ type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value
+
+ @property
+ def eval_task_id(self) -> str:
+ return self.identifier
+
+ @property
+ def provider_eval_task_id(self) -> str:
+ return self.provider_resource_id
+
+
+@json_schema_type
+class EvalTaskInput(CommonEvalTaskFields, BaseModel):
+ eval_task_id: str
+ provider_id: Optional[str] = None
+ provider_eval_task_id: Optional[str] = None
+
+
@runtime_checkable
class EvalTasks(Protocol):
@webmethod(route="/eval_tasks/list", method="GET")
diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py
index 303104f25..83b292612 100644
--- a/llama_stack/apis/memory_banks/memory_banks.py
+++ b/llama_stack/apis/memory_banks/memory_banks.py
@@ -30,37 +30,8 @@ class MemoryBankType(Enum):
graph = "graph"
-@json_schema_type
-class VectorMemoryBank(Resource):
- type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
- memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
- embedding_model: str
- chunk_size_in_tokens: int
- overlap_size_in_tokens: Optional[int] = None
-
-
-@json_schema_type
-class KeyValueMemoryBank(Resource):
- type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
- memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
- MemoryBankType.keyvalue.value
- )
-
-
-@json_schema_type
-class KeywordMemoryBank(Resource):
- type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
- memory_bank_type: Literal[MemoryBankType.keyword.value] = (
- MemoryBankType.keyword.value
- )
-
-
-@json_schema_type
-class GraphMemoryBank(Resource):
- type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
- memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
-
-
+# Define params for each type of memory bank; this leads to a tagged union
+# accepted as input from the API or from the config.
@json_schema_type
class VectorMemoryBankParams(BaseModel):
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
@@ -88,6 +59,58 @@ class GraphMemoryBankParams(BaseModel):
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
+BankParams = Annotated[
+ Union[
+ VectorMemoryBankParams,
+ KeyValueMemoryBankParams,
+ KeywordMemoryBankParams,
+ GraphMemoryBankParams,
+ ],
+ Field(discriminator="memory_bank_type"),
+]
+
+
+# Some common functionality for memory banks.
+class MemoryBankResourceMixin(Resource):
+ type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
+
+ @property
+ def memory_bank_id(self) -> str:
+ return self.identifier
+
+ @property
+ def provider_memory_bank_id(self) -> str:
+ return self.provider_resource_id
+
+
+@json_schema_type
+class VectorMemoryBank(MemoryBankResourceMixin):
+ memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
+ embedding_model: str
+ chunk_size_in_tokens: int
+ overlap_size_in_tokens: Optional[int] = None
+
+
+@json_schema_type
+class KeyValueMemoryBank(MemoryBankResourceMixin):
+ memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
+ MemoryBankType.keyvalue.value
+ )
+
+
+# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention.
+@json_schema_type
+class KeywordMemoryBank(MemoryBankResourceMixin):
+ memory_bank_type: Literal[MemoryBankType.keyword.value] = (
+ MemoryBankType.keyword.value
+ )
+
+
+@json_schema_type
+class GraphMemoryBank(MemoryBankResourceMixin):
+ memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
+
+
MemoryBank = Annotated[
Union[
VectorMemoryBank,
@@ -98,15 +121,12 @@ MemoryBank = Annotated[
Field(discriminator="memory_bank_type"),
]
-BankParams = Annotated[
- Union[
- VectorMemoryBankParams,
- KeyValueMemoryBankParams,
- KeywordMemoryBankParams,
- GraphMemoryBankParams,
- ],
- Field(discriminator="memory_bank_type"),
-]
+
+@json_schema_type
+class MemoryBankInput(BaseModel):
+ memory_bank_id: str
+ params: BankParams
+ provider_memory_bank_id: Optional[str] = None
@runtime_checkable
@@ -123,5 +143,5 @@ class MemoryBanks(Protocol):
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
- provider_memorybank_id: Optional[str] = None,
+ provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank: ...
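
The comment in the hunk above describes `BankParams` as a tagged union: pydantic dispatches on the `memory_bank_type` discriminator to pick the concrete params class. A minimal self-contained sketch of that behavior, assuming pydantic v2 and reproducing only two of the four variants (the embedding model name is illustrative):

```python
from typing import Annotated, Literal, Optional, Union

from pydantic import BaseModel, Field, TypeAdapter


class VectorMemoryBankParams(BaseModel):
    memory_bank_type: Literal["vector"] = "vector"
    embedding_model: str
    chunk_size_in_tokens: int
    overlap_size_in_tokens: Optional[int] = None


class KeyValueMemoryBankParams(BaseModel):
    memory_bank_type: Literal["keyvalue"] = "keyvalue"


BankParams = Annotated[
    Union[VectorMemoryBankParams, KeyValueMemoryBankParams],
    Field(discriminator="memory_bank_type"),
]

# The discriminator routes raw input to the matching params class.
params = TypeAdapter(BankParams).validate_python(
    {
        "memory_bank_type": "vector",
        "embedding_model": "all-MiniLM-L6-v2",  # illustrative model name
        "chunk_size_in_tokens": 512,
    }
)
assert isinstance(params, VectorMemoryBankParams)
```

The same discriminator also drives the `MemoryBank` union of registered resources above, which is why the four bank classes could collapse onto `MemoryBankResourceMixin`.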
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index bb8d2c4ea..a5d226886 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -7,20 +7,38 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import Field
+from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
-@json_schema_type
-class Model(Resource):
- type: Literal[ResourceType.model.value] = ResourceType.model.value
+class CommonModelFields(BaseModel):
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this model",
)
+@json_schema_type
+class Model(CommonModelFields, Resource):
+ type: Literal[ResourceType.model.value] = ResourceType.model.value
+
+ @property
+ def model_id(self) -> str:
+ return self.identifier
+
+ @property
+ def provider_model_id(self) -> str:
+ return self.provider_resource_id
+
+
+@json_schema_type
+class ModelInput(CommonModelFields):
+ model_id: str
+ provider_id: Optional[str] = None
+ provider_model_id: Optional[str] = None
+
+
@runtime_checkable
class Models(Protocol):
@webmethod(route="/models/list", method="GET")
diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py
index 0e488190b..93a3718a0 100644
--- a/llama_stack/apis/resource.py
+++ b/llama_stack/apis/resource.py
@@ -17,14 +17,12 @@ class ResourceType(Enum):
memory_bank = "memory_bank"
dataset = "dataset"
scoring_function = "scoring_function"
+ eval_task = "eval_task"
class Resource(BaseModel):
"""Base class for all Llama Stack resources"""
- # TODO: I think we need to move these into the child classes
- # and make them `model_id`, `shield_id`, etc. because otherwise
- # the config file has these confusing generic names in there
identifier: str = Field(
description="Unique identifier for this resource in llama stack"
)
diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index 6b2408e0d..7a2a83c72 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -66,11 +66,7 @@ ScoringFnParams = Annotated[
]
-@json_schema_type
-class ScoringFn(Resource):
- type: Literal[ResourceType.scoring_function.value] = (
- ResourceType.scoring_function.value
- )
+class CommonScoringFnFields(BaseModel):
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
default_factory=dict,
@@ -85,6 +81,28 @@ class ScoringFn(Resource):
)
+@json_schema_type
+class ScoringFn(CommonScoringFnFields, Resource):
+ type: Literal[ResourceType.scoring_function.value] = (
+ ResourceType.scoring_function.value
+ )
+
+ @property
+ def scoring_fn_id(self) -> str:
+ return self.identifier
+
+ @property
+ def provider_scoring_fn_id(self) -> str:
+ return self.provider_resource_id
+
+
+@json_schema_type
+class ScoringFnInput(CommonScoringFnFields, BaseModel):
+ scoring_fn_id: str
+ provider_id: Optional[str] = None
+ provider_scoring_fn_id: Optional[str] = None
+
+
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring_functions/list", method="GET")
diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py
index 42fe717fa..1dcfd4f4c 100644
--- a/llama_stack/apis/shields/shields.py
+++ b/llama_stack/apis/shields/shields.py
@@ -8,6 +8,7 @@ from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
@@ -20,13 +21,30 @@ class ShieldType(Enum):
prompt_guard = "prompt_guard"
+class CommonShieldFields(BaseModel):
+ shield_type: ShieldType
+ params: Optional[Dict[str, Any]] = None
+
+
@json_schema_type
-class Shield(Resource):
+class Shield(CommonShieldFields, Resource):
"""A safety shield resource that can be used to check content"""
type: Literal[ResourceType.shield.value] = ResourceType.shield.value
- shield_type: ShieldType
- params: Dict[str, Any] = {}
+
+ @property
+ def shield_id(self) -> str:
+ return self.identifier
+
+ @property
+ def provider_shield_id(self) -> str:
+ return self.provider_resource_id
+
+
+class ShieldInput(CommonShieldFields):
+ shield_id: str
+ provider_id: Optional[str] = None
+ provider_shield_id: Optional[str] = None
@runtime_checkable
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 2cba5b052..4aaf9c38a 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -18,7 +18,7 @@ from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.eval import Eval
-from llama_stack.apis.eval_tasks import EvalTask
+from llama_stack.apis.eval_tasks import EvalTaskInput
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
@@ -152,12 +152,12 @@ a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
- models: List[Model] = Field(default_factory=list)
- shields: List[Shield] = Field(default_factory=list)
- memory_banks: List[MemoryBank] = Field(default_factory=list)
- datasets: List[Dataset] = Field(default_factory=list)
- scoring_fns: List[ScoringFn] = Field(default_factory=list)
- eval_tasks: List[EvalTask] = Field(default_factory=list)
+ models: List[ModelInput] = Field(default_factory=list)
+ shields: List[ShieldInput] = Field(default_factory=list)
+ memory_banks: List[MemoryBankInput] = Field(default_factory=list)
+ datasets: List[DatasetInput] = Field(default_factory=list)
+ scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
+ eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
class BuildConfig(BaseModel):
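
With the registry now typed as `*Input` lists, run configs can use the API-specific id names (`model_id`, `shield_id`, ...) instead of the generic `identifier` / `provider_resource_id` pair that the removed TODO in resource.py complained about. A hypothetical config fragment, validated against the new Input types (model and shield ids are illustrative; assumes pyyaml):

```python
import yaml

from llama_stack.apis.models import ModelInput
from llama_stack.apis.shields import ShieldInput

snippet = """
models:
- model_id: Llama3.1-8B-Instruct
shields:
- shield_id: llama_guard
  shield_type: llama_guard
"""

raw = yaml.safe_load(snippet)
models = [ModelInput(**entry) for entry in raw["models"]]
shields = [ShieldInput(**entry) for entry in raw["shields"]]
assert models[0].provider_id is None  # resolved later by the routing table
```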
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index efed54ab8..7b369df2c 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -32,6 +32,10 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if obj.provider_id == "remote":
+ # TODO: this is broken right now because we use the generic
+ # { identifier, provider_id, provider_resource_id } tuple here
+ # but the APIs expect things like ModelInput, ShieldInput, etc.
+
# if this is just a passthrough, we want to let the remote
# end actually do the registration with the correct provider
obj = obj.model_copy(deep=True)
@@ -277,10 +281,10 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
- provider_memorybank_id: Optional[str] = None,
+ provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank:
- if provider_memorybank_id is None:
- provider_memorybank_id = memory_bank_id
+ if provider_memory_bank_id is None:
+ provider_memory_bank_id = memory_bank_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this memory bank type
if len(self.impls_by_provider_id) == 1:
@@ -295,7 +299,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
- "provider_resource_id": provider_memorybank_id,
+ "provider_resource_id": provider_memory_bank_id,
**params.model_dump(),
},
)
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index 7fe7d3ca7..3afd51304 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -5,6 +5,7 @@
# the root directory of this source tree.
from typing import Any, Dict
+from termcolor import colored
@@ -67,30 +68,29 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
- objects = [
- *run_config.models,
- *run_config.shields,
- *run_config.memory_banks,
- *run_config.datasets,
- *run_config.scoring_fns,
- *run_config.eval_tasks,
- ]
- for obj in objects:
- await dist_registry.register(obj)
-
resources = [
- ("models", Api.models),
- ("shields", Api.shields),
- ("memory_banks", Api.memory_banks),
- ("datasets", Api.datasets),
- ("scoring_fns", Api.scoring_functions),
- ("eval_tasks", Api.eval_tasks),
+ ("models", Api.models, "register_model", "list_models"),
+ ("shields", Api.shields, "register_shield", "list_shields"),
+ ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
+ ("datasets", Api.datasets, "register_dataset", "list_datasets"),
+ (
+ "scoring_fns",
+ Api.scoring_functions,
+ "register_scoring_function",
+ "list_scoring_functions",
+ ),
+ ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
]
- for rsrc, api in resources:
+ for rsrc, api, register_method, list_method in resources:
+ objects = getattr(run_config, rsrc)
if api not in impls:
continue
- method = getattr(impls[api], f"list_{api.value}")
+ method = getattr(impls[api], register_method)
+ for obj in objects:
+ await method(**obj.model_dump())
+
+ method = getattr(impls[api], list_method)
for obj in await method():
print(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 9c3ec7750..12d012b16 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -128,7 +128,6 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
- print(f"Registering shield {shield}")
if shield.shield_type != ShieldType.llama_guard:
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py
index 6ee17ff1f..64f493b88 100644
--- a/llama_stack/providers/tests/agents/fixtures.py
+++ b/llama_stack/providers/tests/agents/fixtures.py
@@ -9,7 +9,7 @@ import tempfile
import pytest
import pytest_asyncio
-from llama_stack.apis.models import Model
+from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.agents.meta_reference import (
@@ -71,13 +71,9 @@ async def agents_stack(request, inference_model, safety_model):
if fixture.provider_data:
provider_data.update(fixture.provider_data)
- inf_provider_id = providers["inference"][0].provider_id
- safety_provider_id = providers["safety"][0].provider_id
-
- shield = get_shield_to_register(
- providers["safety"][0].provider_type, safety_provider_id, safety_model
+ shield_input = get_shield_to_register(
+ providers["safety"][0].provider_type, safety_model
)
-
inference_models = (
inference_model if isinstance(inference_model, list) else [inference_model]
)
@@ -86,13 +82,11 @@ async def agents_stack(request, inference_model, safety_model):
providers,
provider_data,
models=[
- Model(
- identifier=model,
- provider_id=inf_provider_id,
- provider_resource_id=model,
+ ModelInput(
+ model_id=model,
)
for model in inference_models
],
- shields=[shield],
+ shields=[shield_input],
)
return impls[Api.agents], impls[Api.memory]
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index fe91c6e03..d35ebab28 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -9,7 +9,7 @@ import os
import pytest
import pytest_asyncio
-from llama_stack.apis.models import Model
+from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import (
@@ -162,10 +162,8 @@ async def inference_stack(request, inference_model):
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
models=[
- Model(
- identifier=inference_model,
- provider_resource_id=inference_model,
- provider_id=inference_fixture.providers[0].provider_id,
+ ModelInput(
+ model_id=inference_model,
)
],
)
diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py
index 5e553830c..66576e9d7 100644
--- a/llama_stack/providers/tests/safety/fixtures.py
+++ b/llama_stack/providers/tests/safety/fixtures.py
@@ -7,9 +7,9 @@
import pytest
import pytest_asyncio
-from llama_stack.apis.models import Model
+from llama_stack.apis.models import ModelInput
-from llama_stack.apis.shields import Shield, ShieldType
+from llama_stack.apis.shields import ShieldInput, ShieldType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@@ -99,28 +99,21 @@ async def safety_stack(inference_model, safety_model, request):
provider_data.update(safety_fixture.provider_data)
shield_provider_type = safety_fixture.providers[0].provider_type
- shield = get_shield_to_register(
- shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
- )
+ shield_input = get_shield_to_register(shield_provider_type, safety_model)
impls = await resolve_impls_for_test_v2(
[Api.safety, Api.shields, Api.inference],
providers,
provider_data,
- models=[
- Model(
- identifier=inference_model,
- provider_id=inference_fixture.providers[0].provider_id,
- provider_resource_id=inference_model,
- )
- ],
- shields=[shield],
+ models=[ModelInput(model_id=inference_model)],
+ shields=[shield_input],
)
+ shield = await impls[Api.shields].get_shield(shield_input.shield_id)
return impls[Api.safety], impls[Api.shields], shield
-def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
+def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
shield_config = {}
shield_type = ShieldType.llama_guard
identifier = "llama_guard"
@@ -133,10 +126,8 @@ def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
shield_type = ShieldType.generic_content_shield
- return Shield(
- identifier=identifier,
+ return ShieldInput(
+ shield_id=identifier,
shield_type=shield_type,
params=shield_config,
- provider_id=provider_id,
- provider_resource_id=identifier,
)
diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py
index 14095b526..ee6999043 100644
--- a/llama_stack/providers/tests/scoring/fixtures.py
+++ b/llama_stack/providers/tests/scoring/fixtures.py
@@ -7,6 +7,8 @@
import pytest
import pytest_asyncio
+from llama_stack.apis.models import ModelInput
+
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
@@ -76,20 +78,14 @@ async def scoring_stack(request, inference_model):
[Api.scoring, Api.datasetio, Api.inference],
providers,
provider_data,
- )
-
- provider_id = providers["inference"][0].provider_id
- await impls[Api.models].register_model(
- model_id=inference_model,
- provider_id=provider_id,
- )
- await impls[Api.models].register_model(
- model_id="Llama3.1-405B-Instruct",
- provider_id=provider_id,
- )
- await impls[Api.models].register_model(
- model_id="Llama3.1-8B-Instruct",
- provider_id=provider_id,
+ models=[
+ ModelInput(model_id=model)
+ for model in [
+ inference_model,
+ "Llama3.1-405B-Instruct",
+ "Llama3.1-8B-Instruct",
+ ]
+ ],
)
return impls