re-work tool definitions, fix FastAPI issues, fix tool regressions

This commit is contained in:
Ashwin Bharambe 2024-08-24 22:07:06 -07:00
parent 8d14d4228b
commit 8efe614719
11 changed files with 144 additions and 104 deletions

View file

@ -10,6 +10,7 @@ import fire
import httpx
from llama_models.llama3.api.datatypes import UserMessage
from pydantic import BaseModel
from termcolor import cprint
from .api import (
@ -25,6 +26,10 @@ async def get_client_impl(base_url: str):
return SafetyClient(base_url)
def encodable_dict(d: BaseModel):
return json.loads(d.json())
class SafetyClient(Safety):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")
@ -40,7 +45,9 @@ class SafetyClient(Safety):
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/safety/run_shields",
data=request.json(),
json={
"request": encodable_dict(request),
},
headers={"Content-Type": "application/json"},
timeout=20,
)