forked from phoenix-oss/llama-stack-mirror
Resource oriented design for shields (#399)
* init * working bedrock tests * bedrock test for inference fixes * use env vars for bedrock guardrail vars * add register in meta reference * use correct shield impl in meta ref * dont add together fixture * right naming * minor updates * improved registration flow * address feedback --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
7ee9f8d8ac
commit
d800a16acd
20 changed files with 262 additions and 124 deletions
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
|
@ -26,27 +25,38 @@ class ShieldsClient(Shields):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_shields(self) -> List[ShieldDefWithProvider]:
|
||||
async def list_shields(self) -> List[Shield]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/shields/list",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [ShieldDefWithProvider(**x) for x in response.json()]
|
||||
return [Shield(**x) for x in response.json()]
|
||||
|
||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
|
||||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
shield_type: ShieldType,
|
||||
provider_shield_id: Optional[str],
|
||||
provider_id: Optional[str],
|
||||
params: Optional[Dict[str, Any]],
|
||||
) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/shields/register",
|
||||
json={
|
||||
"shield": json.loads(shield.json()),
|
||||
"shield_id": shield_id,
|
||||
"shield_type": shield_type,
|
||||
"provider_shield_id": provider_shield_id,
|
||||
"provider_id": provider_id,
|
||||
"params": params,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
||||
async def get_shield(self, shield_type: str) -> Optional[Shield]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/shields/get",
|
||||
|
@ -61,7 +71,7 @@ class ShieldsClient(Shields):
|
|||
if j is None:
|
||||
return None
|
||||
|
||||
return ShieldDefWithProvider(**j)
|
||||
return Shield(**j)
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue