mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-18 15:27:16 +00:00
Kill the notion of shield_type
This commit is contained in:
parent
09269e2a44
commit
b1c3a95485
20 changed files with 87 additions and 161 deletions
|
@ -38,7 +38,6 @@ class Dataset(CommonDatasetFields, Resource):
|
|||
return self.provider_resource_id
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatasetInput(CommonDatasetFields, BaseModel):
|
||||
dataset_id: str
|
||||
provider_id: Optional[str] = None
|
||||
|
|
|
@ -34,7 +34,6 @@ class EvalTask(CommonEvalTaskFields, Resource):
|
|||
return self.provider_resource_id
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvalTaskInput(CommonEvalTaskFields, BaseModel):
|
||||
eval_task_id: str
|
||||
provider_id: Optional[str] = None
|
||||
|
|
|
@ -122,7 +122,6 @@ MemoryBank = Annotated[
|
|||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryBankInput(BaseModel):
|
||||
memory_bank_id: str
|
||||
params: BankParams
|
||||
|
|
|
@ -32,7 +32,6 @@ class Model(CommonModelFields, Resource):
|
|||
return self.provider_resource_id
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelInput(CommonModelFields):
|
||||
model_id: str
|
||||
provider_id: Optional[str] = None
|
||||
|
|
|
@ -96,7 +96,6 @@ class ScoringFn(CommonScoringFnFields, Resource):
|
|||
return self.provider_resource_id
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoringFnInput(CommonScoringFnFields, BaseModel):
|
||||
scoring_fn_id: str
|
||||
provider_id: Optional[str] = None
|
||||
|
|
|
@ -37,7 +37,6 @@ class ShieldsClient(Shields):
|
|||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
shield_type: ShieldType,
|
||||
provider_shield_id: Optional[str],
|
||||
provider_id: Optional[str],
|
||||
params: Optional[Dict[str, Any]],
|
||||
|
@ -47,7 +46,6 @@ class ShieldsClient(Shields):
|
|||
f"{self.base_url}/shields/register",
|
||||
json={
|
||||
"shield_id": shield_id,
|
||||
"shield_type": shield_type,
|
||||
"provider_shield_id": provider_shield_id,
|
||||
"provider_id": provider_id,
|
||||
"params": params,
|
||||
|
@ -56,12 +54,12 @@ class ShieldsClient(Shields):
|
|||
)
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[Shield]:
|
||||
async def get_shield(self, shield_id: str) -> Optional[Shield]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/shields/get",
|
||||
params={
|
||||
"shield_type": shield_type,
|
||||
"shield_id": shield_id,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
@ -13,16 +12,7 @@ from pydantic import BaseModel
|
|||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldType(Enum):
|
||||
generic_content_shield = "generic_content_shield"
|
||||
llama_guard = "llama_guard"
|
||||
code_scanner = "code_scanner"
|
||||
prompt_guard = "prompt_guard"
|
||||
|
||||
|
||||
class CommonShieldFields(BaseModel):
|
||||
shield_type: ShieldType
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
|
@ -59,7 +49,6 @@ class Shields(Protocol):
|
|||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
shield_type: ShieldType,
|
||||
provider_shield_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue