Make Safety test work, other cleanup

This commit is contained in:
Ashwin Bharambe 2024-10-09 21:09:50 -07:00
parent ba1f294cc6
commit fcd22b6baa
16 changed files with 229 additions and 123 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import json
from typing import List, Optional
@ -25,16 +26,27 @@ class ShieldsClient(Shields):
async def shutdown(self) -> None:
pass
async def list_shields(self) -> List[ShieldSpec]:
async def list_shields(self) -> List[ShieldDefWithProvider]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/shields/list",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return [ShieldSpec(**x) for x in response.json()]
return [ShieldDefWithProvider(**x) for x in response.json()]
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/shields/register",
json={
"shield": json.loads(shield.json()),
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/shields/get",
@ -49,7 +61,7 @@ class ShieldsClient(Shields):
if j is None:
return None
return ShieldSpec(**j)
return ShieldDefWithProvider(**j)
async def run_main(host: str, port: int, stream: bool):