Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-29 03:14:19 +00:00
Enable sane naming of registered objects with defaults (#429)
# What does this PR do?

This is a follow-up to #425. That PR allows for specifying models in the registry, but each entry needs to look like:

```yaml
- identifier: ...
  provider_id: ...
  provider_resource_id: ...
```

This is headache-inducing. The current PR makes this situation better by adopting the shape of our APIs. Namely, the user only needs to specify `model_id`; the rest is optional, figured out by the Stack, and can always be overridden. Here's what an example `ollama` "full stack" registry looks like (we still need to kill or simplify the `shield_type` cruft):

```yaml
models:
  - model_id: Llama3.2-3B-Instruct
  - model_id: Llama-Guard-3-1B
shields:
  - shield_id: llama_guard
    shield_type: llama_guard
```

## Test Plan

See the test plan for #425. Re-ran it.
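To make the defaulting rule concrete, here's a minimal Python sketch of what a bare registry entry expands to (the `ModelInput` shape mirrors the diff below; the inline defaulting logic is a simplification, and the `ollama` provider id is just an assumed example):

```python
from typing import Optional

from pydantic import BaseModel


class ModelInput(BaseModel):
    model_id: str  # the only required field
    provider_id: Optional[str] = None  # defaults to the sole configured provider
    provider_model_id: Optional[str] = None  # defaults to model_id


entry = ModelInput(model_id="Llama3.2-3B-Instruct")

# Simplified view of how the Stack fills in the blanks at registration time:
provider_id = entry.provider_id or "ollama"  # assumed: only one provider configured
provider_model_id = entry.provider_model_id or entry.model_id
print(provider_id, provider_model_id)  # ollama Llama3.2-3B-Instruct
```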
This commit is contained in:
parent d9d271a684
commit 09269e2a44
17 changed files with 295 additions and 207 deletions
```diff
@@ -21,7 +21,7 @@
     "info": {
         "title": "[DRAFT] Llama Stack Specification",
         "version": "0.0.1",
-        "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-11 18:44:30.967321"
+        "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871"
     },
     "servers": [
         {
@@ -5778,8 +5778,7 @@
                 "provider_resource_id",
                 "provider_id",
                 "type",
-                "shield_type",
-                "params"
+                "shield_type"
             ],
             "title": "A safety shield resource that can be used to check content"
         },
@@ -7027,7 +7026,7 @@
                 "provider_id": {
                     "type": "string"
                 },
-                "provider_memorybank_id": {
+                "provider_memory_bank_id": {
                     "type": "string"
                 }
             },
@@ -7854,59 +7853,59 @@
         }
     ],
     "tags": [
-        {
-            "name": "Datasets"
-        },
-        {
-            "name": "Telemetry"
-        },
-        {
-            "name": "PostTraining"
-        },
-        {
-            "name": "MemoryBanks"
-        },
-        {
-            "name": "Eval"
-        },
-        {
-            "name": "Memory"
-        },
-        {
-            "name": "EvalTasks"
-        },
-        {
-            "name": "Models"
-        },
-        {
-            "name": "Scoring"
-        },
         {
             "name": "Inference"
         },
-        {
-            "name": "Shields"
-        },
-        {
-            "name": "DatasetIO"
-        },
-        {
-            "name": "Safety"
-        },
         {
             "name": "Agents"
         },
         {
-            "name": "SyntheticDataGeneration"
+            "name": "Telemetry"
+        },
+        {
+            "name": "Eval"
+        },
+        {
+            "name": "Models"
+        },
+        {
+            "name": "Inspect"
+        },
+        {
+            "name": "EvalTasks"
         },
         {
             "name": "ScoringFunctions"
         },
         {
-            "name": "BatchInference"
+            "name": "Memory"
         },
         {
-            "name": "Inspect"
+            "name": "Safety"
+        },
+        {
+            "name": "DatasetIO"
+        },
+        {
+            "name": "MemoryBanks"
+        },
+        {
+            "name": "Shields"
+        },
+        {
+            "name": "PostTraining"
+        },
+        {
+            "name": "Datasets"
+        },
+        {
+            "name": "Scoring"
+        },
+        {
+            "name": "SyntheticDataGeneration"
+        },
+        {
+            "name": "BatchInference"
         },
         {
             "name": "BuiltinTool",
```
```diff
@@ -2068,7 +2068,7 @@ components:
       - $ref: '#/components/schemas/GraphMemoryBankParams'
       provider_id:
         type: string
-      provider_memorybank_id:
+      provider_memory_bank_id:
        type: string
     required:
     - memory_bank_id
@@ -2710,7 +2710,6 @@ components:
     - provider_id
     - type
     - shield_type
-    - params
     title: A safety shield resource that can be used to check content
     type: object
   ShieldCallStep:
@@ -3398,7 +3397,7 @@ info:
   description: "This is the specification of the llama stack that provides\n \
     \ a set of endpoints and their corresponding interfaces that are tailored\
     \ to\n best leverage Llama Models. The specification is still in\
-    \ draft and subject to change.\n Generated at 2024-11-11 18:44:30.967321"
+    \ draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871"
   title: '[DRAFT] Llama Stack Specification'
   version: 0.0.1
 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -4762,24 +4761,24 @@ security:
 servers:
 - url: http://any-hosted-llama-stack.com
 tags:
-- name: Datasets
-- name: Telemetry
-- name: PostTraining
-- name: MemoryBanks
-- name: Eval
-- name: Memory
-- name: EvalTasks
-- name: Models
-- name: Scoring
 - name: Inference
-- name: Shields
-- name: DatasetIO
-- name: Safety
 - name: Agents
-- name: SyntheticDataGeneration
-- name: ScoringFunctions
-- name: BatchInference
+- name: Telemetry
+- name: Eval
+- name: Models
 - name: Inspect
+- name: EvalTasks
+- name: ScoringFunctions
+- name: Memory
+- name: Safety
+- name: DatasetIO
+- name: MemoryBanks
+- name: Shields
+- name: PostTraining
+- name: Datasets
+- name: Scoring
+- name: SyntheticDataGeneration
+- name: BatchInference
 - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
   name: BuiltinTool
 - description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
```
```diff
@@ -10,15 +10,13 @@ from llama_models.llama3.api.datatypes import URL

 from llama_models.schema_utils import json_schema_type, webmethod

-from pydantic import Field
+from pydantic import BaseModel, Field

 from llama_stack.apis.common.type_system import ParamType
-from llama_stack.apis.resource import Resource
+from llama_stack.apis.resource import Resource, ResourceType


-@json_schema_type
-class Dataset(Resource):
-    type: Literal["dataset"] = "dataset"
+class CommonDatasetFields(BaseModel):
     schema: Dict[str, ParamType]
     url: URL
     metadata: Dict[str, Any] = Field(
@@ -27,6 +25,26 @@ class Dataset(Resource):
     )


+@json_schema_type
+class Dataset(CommonDatasetFields, Resource):
+    type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value
+
+    @property
+    def dataset_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_dataset_id(self) -> str:
+        return self.provider_resource_id
+
+
+@json_schema_type
+class DatasetInput(CommonDatasetFields, BaseModel):
+    dataset_id: str
+    provider_id: Optional[str] = None
+    provider_dataset_id: Optional[str] = None
+
+
 class Datasets(Protocol):
     @webmethod(route="/datasets/register", method="POST")
     async def register_dataset(
```
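The same Common*Fields split recurs below for eval tasks, models, scoring functions, and shields. Here's a self-contained sketch of the pattern (pydantic v2 assumed; `Resource` is inlined rather than imported, the `schema`/`metadata` fields are omitted for brevity, and the `localfs` provider id is illustrative):

```python
from typing import Optional

from pydantic import BaseModel


class Resource(BaseModel):
    identifier: str
    provider_resource_id: str
    provider_id: str


class CommonDatasetFields(BaseModel):
    url: str  # a URL type in the real code


class Dataset(CommonDatasetFields, Resource):
    @property
    def dataset_id(self) -> str:
        return self.identifier


class DatasetInput(CommonDatasetFields, BaseModel):
    dataset_id: str
    provider_id: Optional[str] = None
    provider_dataset_id: Optional[str] = None


# Users write the small Input form; the Stack expands it into the full resource.
inp = DatasetInput(dataset_id="mmlu", url="https://example.com/mmlu.jsonl")
full = Dataset(
    identifier=inp.dataset_id,
    provider_id=inp.provider_id or "localfs",  # illustrative default
    provider_resource_id=inp.provider_dataset_id or inp.dataset_id,
    url=inp.url,
)
assert full.dataset_id == "mmlu"
```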
```diff
@@ -7,14 +7,12 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod

-from pydantic import Field
+from pydantic import BaseModel, Field

-from llama_stack.apis.resource import Resource
+from llama_stack.apis.resource import Resource, ResourceType


-@json_schema_type
-class EvalTask(Resource):
-    type: Literal["eval_task"] = "eval_task"
+class CommonEvalTaskFields(BaseModel):
     dataset_id: str
     scoring_functions: List[str]
     metadata: Dict[str, Any] = Field(
@@ -23,6 +21,26 @@ class EvalTask(Resource):
     )


+@json_schema_type
+class EvalTask(CommonEvalTaskFields, Resource):
+    type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value
+
+    @property
+    def eval_task_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_eval_task_id(self) -> str:
+        return self.provider_resource_id
+
+
+@json_schema_type
+class EvalTaskInput(CommonEvalTaskFields, BaseModel):
+    eval_task_id: str
+    provider_id: Optional[str] = None
+    provider_eval_task_id: Optional[str] = None
+
+
 @runtime_checkable
 class EvalTasks(Protocol):
     @webmethod(route="/eval_tasks/list", method="GET")
```
```diff
@@ -30,37 +30,8 @@ class MemoryBankType(Enum):
     graph = "graph"


-@json_schema_type
-class VectorMemoryBank(Resource):
-    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
-    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
-    embedding_model: str
-    chunk_size_in_tokens: int
-    overlap_size_in_tokens: Optional[int] = None
-
-
-@json_schema_type
-class KeyValueMemoryBank(Resource):
-    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
-    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
-        MemoryBankType.keyvalue.value
-    )
-
-
-@json_schema_type
-class KeywordMemoryBank(Resource):
-    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
-    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
-        MemoryBankType.keyword.value
-    )
-
-
-@json_schema_type
-class GraphMemoryBank(Resource):
-    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
-    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
-
-
+# define params for each type of memory bank, this leads to a tagged union
+# accepted as input from the API or from the config.
 @json_schema_type
 class VectorMemoryBankParams(BaseModel):
     memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
@@ -88,6 +59,58 @@ class GraphMemoryBankParams(BaseModel):
     memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value


+BankParams = Annotated[
+    Union[
+        VectorMemoryBankParams,
+        KeyValueMemoryBankParams,
+        KeywordMemoryBankParams,
+        GraphMemoryBankParams,
+    ],
+    Field(discriminator="memory_bank_type"),
+]
+
+
+# Some common functionality for memory banks.
+class MemoryBankResourceMixin(Resource):
+    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
+
+    @property
+    def memory_bank_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_memory_bank_id(self) -> str:
+        return self.provider_resource_id
+
+
+@json_schema_type
+class VectorMemoryBank(MemoryBankResourceMixin):
+    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
+    embedding_model: str
+    chunk_size_in_tokens: int
+    overlap_size_in_tokens: Optional[int] = None
+
+
+@json_schema_type
+class KeyValueMemoryBank(MemoryBankResourceMixin):
+    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
+        MemoryBankType.keyvalue.value
+    )
+
+
+# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention.
+@json_schema_type
+class KeywordMemoryBank(MemoryBankResourceMixin):
+    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
+        MemoryBankType.keyword.value
+    )
+
+
+@json_schema_type
+class GraphMemoryBank(MemoryBankResourceMixin):
+    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
+
+
 MemoryBank = Annotated[
     Union[
         VectorMemoryBank,
@@ -98,15 +121,12 @@ MemoryBank = Annotated[
     Field(discriminator="memory_bank_type"),
 ]

-BankParams = Annotated[
-    Union[
-        VectorMemoryBankParams,
-        KeyValueMemoryBankParams,
-        KeywordMemoryBankParams,
-        GraphMemoryBankParams,
-    ],
-    Field(discriminator="memory_bank_type"),
-]
+
+@json_schema_type
+class MemoryBankInput(BaseModel):
+    memory_bank_id: str
+    params: BankParams
+    provider_memory_bank_id: Optional[str] = None


 @runtime_checkable
@@ -123,5 +143,5 @@ class MemoryBanks(Protocol):
         memory_bank_id: str,
         params: BankParams,
         provider_id: Optional[str] = None,
-        provider_memorybank_id: Optional[str] = None,
+        provider_memory_bank_id: Optional[str] = None,
     ) -> MemoryBank: ...
```
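A runnable sketch of why `BankParams` is declared as a tagged union (pydantic v2 assumed; only two of the four variants are shown, with abbreviated fields): the `memory_bank_type` literal is the discriminator, so a plain dict coming from YAML or JSON parses into the right params class.

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class VectorMemoryBankParams(BaseModel):
    memory_bank_type: Literal["vector"] = "vector"
    embedding_model: str
    chunk_size_in_tokens: int


class GraphMemoryBankParams(BaseModel):
    memory_bank_type: Literal["graph"] = "graph"


BankParams = Annotated[
    Union[VectorMemoryBankParams, GraphMemoryBankParams],
    Field(discriminator="memory_bank_type"),
]

# The discriminator routes this dict to VectorMemoryBankParams automatically.
params = TypeAdapter(BankParams).validate_python(
    {
        "memory_bank_type": "vector",
        "embedding_model": "all-MiniLM-L6-v2",  # illustrative model name
        "chunk_size_in_tokens": 512,
    }
)
print(type(params).__name__)  # VectorMemoryBankParams
```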
```diff
@@ -7,20 +7,38 @@
 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import Field
+from pydantic import BaseModel, Field

 from llama_stack.apis.resource import Resource, ResourceType


-@json_schema_type
-class Model(Resource):
-    type: Literal[ResourceType.model.value] = ResourceType.model.value
+class CommonModelFields(BaseModel):
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this model",
     )


+@json_schema_type
+class Model(CommonModelFields, Resource):
+    type: Literal[ResourceType.model.value] = ResourceType.model.value
+
+    @property
+    def model_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_model_id(self) -> str:
+        return self.provider_resource_id
+
+
+@json_schema_type
+class ModelInput(CommonModelFields):
+    model_id: str
+    provider_id: Optional[str] = None
+    provider_model_id: Optional[str] = None
+
+
 @runtime_checkable
 class Models(Protocol):
     @webmethod(route="/models/list", method="GET")
```
```diff
@@ -17,14 +17,12 @@ class ResourceType(Enum):
     memory_bank = "memory_bank"
     dataset = "dataset"
     scoring_function = "scoring_function"
+    eval_task = "eval_task"


 class Resource(BaseModel):
     """Base class for all Llama Stack resources"""

-    # TODO: I think we need to move these into the child classes
-    # and make them `model_id`, `shield_id`, etc. because otherwise
-    # the config file has these confusing generic names in there
     identifier: str = Field(
         description="Unique identifier for this resource in llama stack"
     )
```
```diff
@@ -66,11 +66,7 @@ ScoringFnParams = Annotated[
 ]


-@json_schema_type
-class ScoringFn(Resource):
-    type: Literal[ResourceType.scoring_function.value] = (
-        ResourceType.scoring_function.value
-    )
+class CommonScoringFnFields(BaseModel):
     description: Optional[str] = None
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
@@ -85,6 +81,28 @@ class ScoringFn(Resource):
     )


+@json_schema_type
+class ScoringFn(CommonScoringFnFields, Resource):
+    type: Literal[ResourceType.scoring_function.value] = (
+        ResourceType.scoring_function.value
+    )
+
+    @property
+    def scoring_fn_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_scoring_fn_id(self) -> str:
+        return self.provider_resource_id
+
+
+@json_schema_type
+class ScoringFnInput(CommonScoringFnFields, BaseModel):
+    scoring_fn_id: str
+    provider_id: Optional[str] = None
+    provider_scoring_fn_id: Optional[str] = None
+
+
 @runtime_checkable
 class ScoringFunctions(Protocol):
     @webmethod(route="/scoring_functions/list", method="GET")
```
```diff
@@ -8,6 +8,7 @@ from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel

 from llama_stack.apis.resource import Resource, ResourceType

@@ -20,13 +21,30 @@ class ShieldType(Enum):
     prompt_guard = "prompt_guard"


+class CommonShieldFields(BaseModel):
+    shield_type: ShieldType
+    params: Optional[Dict[str, Any]] = None
+
+
 @json_schema_type
-class Shield(Resource):
+class Shield(CommonShieldFields, Resource):
     """A safety shield resource that can be used to check content"""

     type: Literal[ResourceType.shield.value] = ResourceType.shield.value
-    shield_type: ShieldType
-    params: Dict[str, Any] = {}
+
+    @property
+    def shield_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_shield_id(self) -> str:
+        return self.provider_resource_id
+
+
+class ShieldInput(CommonShieldFields):
+    shield_id: str
+    provider_id: Optional[str] = None
+    provider_shield_id: Optional[str] = None


 @runtime_checkable
```
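Making `params` an `Optional[...]` with a `None` default is also what removed `"params"` from the required lists in the generated JSON and YAML specs above. A minimal sketch (pydantic v2 assumed; `shield_type` narrowed to `str` for brevity):

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel


class CommonShieldFields(BaseModel):
    shield_type: str  # a ShieldType enum in the real code
    params: Optional[Dict[str, Any]] = None


# Only shield_type is required now, matching the spec change above.
print(CommonShieldFields.model_json_schema()["required"])  # ['shield_type']
```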
```diff
@@ -18,7 +18,7 @@ from llama_stack.apis.datasets import * # noqa: F403
 from llama_stack.apis.scoring_functions import * # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.eval import Eval
-from llama_stack.apis.eval_tasks import EvalTask
+from llama_stack.apis.eval_tasks import EvalTaskInput
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.safety import Safety
@@ -152,12 +152,12 @@ a default SQLite store will be used.""",
     )

     # registry of "resources" in the distribution
-    models: List[Model] = Field(default_factory=list)
-    shields: List[Shield] = Field(default_factory=list)
-    memory_banks: List[MemoryBank] = Field(default_factory=list)
-    datasets: List[Dataset] = Field(default_factory=list)
-    scoring_fns: List[ScoringFn] = Field(default_factory=list)
-    eval_tasks: List[EvalTask] = Field(default_factory=list)
+    models: List[ModelInput] = Field(default_factory=list)
+    shields: List[ShieldInput] = Field(default_factory=list)
+    memory_banks: List[MemoryBankInput] = Field(default_factory=list)
+    datasets: List[DatasetInput] = Field(default_factory=list)
+    scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
+    eval_tasks: List[EvalTaskInput] = Field(default_factory=list)


 class BuildConfig(BaseModel):
```
```diff
@@ -32,6 +32,10 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
     api = get_impl_api(p)

     if obj.provider_id == "remote":
+        # TODO: this is broken right now because we use the generic
+        # { identifier, provider_id, provider_resource_id } tuple here
+        # but the APIs expect things like ModelInput, ShieldInput, etc.
+
         # if this is just a passthrough, we want to let the remote
         # end actually do the registration with the correct provider
         obj = obj.model_copy(deep=True)
@@ -277,10 +281,10 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
         memory_bank_id: str,
         params: BankParams,
         provider_id: Optional[str] = None,
-        provider_memorybank_id: Optional[str] = None,
+        provider_memory_bank_id: Optional[str] = None,
     ) -> MemoryBank:
-        if provider_memorybank_id is None:
-            provider_memorybank_id = memory_bank_id
+        if provider_memory_bank_id is None:
+            provider_memory_bank_id = memory_bank_id
         if provider_id is None:
             # If provider_id not specified, use the only provider if it supports this shield type
             if len(self.impls_by_provider_id) == 1:
@@ -295,7 +299,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
                 "identifier": memory_bank_id,
                 "type": ResourceType.memory_bank.value,
                 "provider_id": provider_id,
-                "provider_resource_id": provider_memorybank_id,
+                "provider_resource_id": provider_memory_bank_id,
                 **params.model_dump(),
             },
         )
```
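The defaulting visible in this hunk is the mechanism behind the terse registry entries in the PR description. A standalone sketch of the rule (the helper function is hypothetical, not part of the codebase):

```python
from typing import Any, Dict, Optional, Tuple


def resolve_defaults(
    memory_bank_id: str,
    impls_by_provider_id: Dict[str, Any],
    provider_id: Optional[str] = None,
    provider_memory_bank_id: Optional[str] = None,
) -> Tuple[str, str]:
    # An omitted provider_memory_bank_id falls back to the stack-level id.
    if provider_memory_bank_id is None:
        provider_memory_bank_id = memory_bank_id
    # An omitted provider_id is resolved only when exactly one provider exists.
    if provider_id is None:
        if len(impls_by_provider_id) != 1:
            raise ValueError("provider_id is ambiguous with multiple providers")
        provider_id = next(iter(impls_by_provider_id))
    return provider_id, provider_memory_bank_id


print(resolve_defaults("my_bank", {"faiss-0": object()}))  # ('faiss-0', 'my_bank')
```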
```diff
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from typing import Any, Dict
+from termcolor import colored

 from termcolor import colored

@@ -67,30 +68,29 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:

     impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)

-    objects = [
-        *run_config.models,
-        *run_config.shields,
-        *run_config.memory_banks,
-        *run_config.datasets,
-        *run_config.scoring_fns,
-        *run_config.eval_tasks,
-    ]
-    for obj in objects:
-        await dist_registry.register(obj)
-
     resources = [
-        ("models", Api.models),
-        ("shields", Api.shields),
-        ("memory_banks", Api.memory_banks),
-        ("datasets", Api.datasets),
-        ("scoring_fns", Api.scoring_functions),
-        ("eval_tasks", Api.eval_tasks),
+        ("models", Api.models, "register_model", "list_models"),
+        ("shields", Api.shields, "register_shield", "list_shields"),
+        ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
+        ("datasets", Api.datasets, "register_dataset", "list_datasets"),
+        (
+            "scoring_fns",
+            Api.scoring_functions,
+            "register_scoring_function",
+            "list_scoring_functions",
+        ),
+        ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
     ]
-    for rsrc, api in resources:
+    for rsrc, api, register_method, list_method in resources:
+        objects = getattr(run_config, rsrc)
         if api not in impls:
             continue

-        method = getattr(impls[api], f"list_{api.value}")
+        method = getattr(impls[api], register_method)
+        for obj in objects:
+            await method(**obj.model_dump())
+
+        method = getattr(impls[api], list_method)
         for obj in await method():
             print(
                 f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
```
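A self-contained sketch of the new register-then-list flow in `construct_stack` (the `FakeModels` impl and its record shape are stand-ins for illustration, not the real API implementations):

```python
import asyncio


class FakeModels:
    """Stand-in for the real models API implementation."""

    def __init__(self) -> None:
        self._registered = []

    async def register_model(self, model_id, provider_id=None, provider_model_id=None):
        self._registered.append((model_id, provider_id or "default"))

    async def list_models(self):
        return self._registered


async def main() -> None:
    impls = {"models": FakeModels()}
    run_config = {"models": [{"model_id": "Llama3.2-3B-Instruct"}]}
    resources = [("models", "models", "register_model", "list_models")]

    for rsrc, api, register_method, list_method in resources:
        # Register every input object from the run config...
        register = getattr(impls[api], register_method)
        for obj in run_config[rsrc]:
            await register(**obj)
        # ...then list what the provider actually serves.
        for obj in await getattr(impls[api], list_method)():
            print(f"{rsrc.capitalize()}: {obj}")


asyncio.run(main())
```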
```diff
@@ -128,7 +128,6 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         pass

     async def register_shield(self, shield: Shield) -> None:
-        print(f"Registering shield {shield}")
         if shield.shield_type != ShieldType.llama_guard:
             raise ValueError(f"Unsupported shield type: {shield.shield_type}")
```
```diff
@@ -9,7 +9,7 @@ import tempfile
 import pytest
 import pytest_asyncio

-from llama_stack.apis.models import Model
+from llama_stack.apis.models import ModelInput
 from llama_stack.distribution.datatypes import Api, Provider

 from llama_stack.providers.inline.agents.meta_reference import (
@@ -71,13 +71,9 @@ async def agents_stack(request, inference_model, safety_model):
     if fixture.provider_data:
         provider_data.update(fixture.provider_data)

-    inf_provider_id = providers["inference"][0].provider_id
-    safety_provider_id = providers["safety"][0].provider_id
-
-    shield = get_shield_to_register(
-        providers["safety"][0].provider_type, safety_provider_id, safety_model
+    shield_input = get_shield_to_register(
+        providers["safety"][0].provider_type, safety_model
     )

     inference_models = (
         inference_model if isinstance(inference_model, list) else [inference_model]
     )
@@ -86,13 +82,11 @@ async def agents_stack(request, inference_model, safety_model):
         providers,
         provider_data,
         models=[
-            Model(
-                identifier=model,
-                provider_id=inf_provider_id,
-                provider_resource_id=model,
+            ModelInput(
+                model_id=model,
             )
             for model in inference_models
         ],
-        shields=[shield],
+        shields=[shield_input],
     )
     return impls[Api.agents], impls[Api.memory]
```
```diff
@@ -9,7 +9,7 @@ import os
 import pytest
 import pytest_asyncio

-from llama_stack.apis.models import Model
+from llama_stack.apis.models import ModelInput

 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.inference.meta_reference import (
@@ -162,10 +162,8 @@ async def inference_stack(request, inference_model):
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
         models=[
-            Model(
-                identifier=inference_model,
-                provider_resource_id=inference_model,
-                provider_id=inference_fixture.providers[0].provider_id,
+            ModelInput(
+                model_id=inference_model,
             )
         ],
     )
```
```diff
@@ -7,9 +7,9 @@
 import pytest
 import pytest_asyncio

-from llama_stack.apis.models import Model
+from llama_stack.apis.models import ModelInput

-from llama_stack.apis.shields import Shield, ShieldType
+from llama_stack.apis.shields import ShieldInput, ShieldType

 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@@ -99,28 +99,21 @@ async def safety_stack(inference_model, safety_model, request):
         provider_data.update(safety_fixture.provider_data)

     shield_provider_type = safety_fixture.providers[0].provider_type
-    shield = get_shield_to_register(
-        shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
-    )
+    shield_input = get_shield_to_register(shield_provider_type, safety_model)

     impls = await resolve_impls_for_test_v2(
         [Api.safety, Api.shields, Api.inference],
         providers,
         provider_data,
-        models=[
-            Model(
-                identifier=inference_model,
-                provider_id=inference_fixture.providers[0].provider_id,
-                provider_resource_id=inference_model,
-            )
-        ],
-        shields=[shield],
+        models=[ModelInput(model_id=inference_model)],
+        shields=[shield_input],
     )

+    shield = await impls[Api.shields].get_shield(shield_input.shield_id)
     return impls[Api.safety], impls[Api.shields], shield


-def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
+def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
     shield_config = {}
     shield_type = ShieldType.llama_guard
     identifier = "llama_guard"
@@ -133,10 +126,8 @@ def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
         shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
         shield_type = ShieldType.generic_content_shield

-    return Shield(
-        identifier=identifier,
+    return ShieldInput(
+        shield_id=identifier,
         shield_type=shield_type,
         params=shield_config,
-        provider_id=provider_id,
-        provider_resource_id=identifier,
     )
```
```diff
@@ -7,6 +7,8 @@
 import pytest
 import pytest_asyncio

+from llama_stack.apis.models import ModelInput
+
 from llama_stack.distribution.datatypes import Api, Provider

 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
@@ -76,20 +78,14 @@ async def scoring_stack(request, inference_model):
         [Api.scoring, Api.datasetio, Api.inference],
         providers,
         provider_data,
-    )
-
-    provider_id = providers["inference"][0].provider_id
-    await impls[Api.models].register_model(
-        model_id=inference_model,
-        provider_id=provider_id,
-    )
-    await impls[Api.models].register_model(
-        model_id="Llama3.1-405B-Instruct",
-        provider_id=provider_id,
-    )
-    await impls[Api.models].register_model(
-        model_id="Llama3.1-8B-Instruct",
-        provider_id=provider_id,
-    )
+        models=[
+            ModelInput(model_id=model)
+            for model in [
+                inference_model,
+                "Llama3.1-405B-Instruct",
+                "Llama3.1-8B-Instruct",
+            ]
+        ],
+    )

     return impls
```