Kill the notion of shield_type

This commit is contained in:
Ashwin Bharambe 2024-11-12 11:41:23 -08:00
parent 09269e2a44
commit b1c3a95485
20 changed files with 87 additions and 161 deletions

View file

@ -38,7 +38,6 @@ class Dataset(CommonDatasetFields, Resource):
return self.provider_resource_id
@json_schema_type
class DatasetInput(CommonDatasetFields, BaseModel):
dataset_id: str
provider_id: Optional[str] = None

View file

@ -34,7 +34,6 @@ class EvalTask(CommonEvalTaskFields, Resource):
return self.provider_resource_id
@json_schema_type
class EvalTaskInput(CommonEvalTaskFields, BaseModel):
eval_task_id: str
provider_id: Optional[str] = None

View file

@ -122,7 +122,6 @@ MemoryBank = Annotated[
]
@json_schema_type
class MemoryBankInput(BaseModel):
memory_bank_id: str
params: BankParams

View file

@ -32,7 +32,6 @@ class Model(CommonModelFields, Resource):
return self.provider_resource_id
@json_schema_type
class ModelInput(CommonModelFields):
model_id: str
provider_id: Optional[str] = None

View file

@ -96,7 +96,6 @@ class ScoringFn(CommonScoringFnFields, Resource):
return self.provider_resource_id
@json_schema_type
class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str
provider_id: Optional[str] = None

View file

@ -37,7 +37,6 @@ class ShieldsClient(Shields):
async def register_shield(
self,
shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str],
provider_id: Optional[str],
params: Optional[Dict[str, Any]],
@ -47,7 +46,6 @@ class ShieldsClient(Shields):
f"{self.base_url}/shields/register",
json={
"shield_id": shield_id,
"shield_type": shield_type,
"provider_shield_id": provider_shield_id,
"provider_id": provider_id,
"params": params,
@ -56,12 +54,12 @@ class ShieldsClient(Shields):
)
response.raise_for_status()
async def get_shield(self, shield_type: str) -> Optional[Shield]:
async def get_shield(self, shield_id: str) -> Optional[Shield]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/shields/get",
params={
"shield_type": shield_type,
"shield_id": shield_id,
},
headers={"Content-Type": "application/json"},
)

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
@ -13,16 +12,7 @@ from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
@json_schema_type
class ShieldType(Enum):
generic_content_shield = "generic_content_shield"
llama_guard = "llama_guard"
code_scanner = "code_scanner"
prompt_guard = "prompt_guard"
class CommonShieldFields(BaseModel):
shield_type: ShieldType
params: Optional[Dict[str, Any]] = None
@ -59,7 +49,6 @@ class Shields(Protocol):
async def register_shield(
self,
shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,