Fix up safety client for versioned API (#573)

When running:
python -m llama_stack.apis.safety.client localhost 5000

The API server was logging:
INFO:    ::1:57176 - "POST /safety/run_shield HTTP/1.1" 404 Not Found

This patch switches the client to the versioned API, updates the safety
endpoint path accordingly, and updates the model name to match what is
actually being served. The python command above now demonstrates both a
passing and a failing example.
Author: Steve Grubb
Date:   2024-12-05 17:13:49 -05:00 (committed by GitHub)
parent 6eb5f2a865
commit a4daf4d3ec


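For reference, a minimal sketch of the URL change this patch makes. The
concrete value of LLAMA_STACK_API_VERSION is an assumption here for
illustration only; the real constant is imported from
llama_stack.apis.version in the diff below.

# Hypothetical stand-in for llama_stack.apis.version.LLAMA_STACK_API_VERSION;
# the actual value is defined by the installed llama_stack package.
LLAMA_STACK_API_VERSION = "v1"  # assumed value, for illustration only

base_url = "http://localhost:5000"

# Old path -- the server no longer routes this, hence the 404:
old_url = f"{base_url}/safety/run_shield"

# New path -- version-prefixed, and the endpoint name is now hyphenated:
new_url = f"{base_url}/{LLAMA_STACK_API_VERSION}/safety/run-shield"

print(old_url)  # http://localhost:5000/safety/run_shield
print(new_url)  # http://localhost:5000/v1/safety/run-shield
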
@@ -17,6 +17,8 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from pydantic import BaseModel
 from termcolor import cprint
 
+from llama_stack.apis.version import LLAMA_STACK_API_VERSION
+
 from llama_stack.distribution.datatypes import RemoteProviderConfig
 
 from llama_stack.apis.safety import *  # noqa: F403
@@ -45,7 +47,7 @@ class SafetyClient(Safety):
     ) -> RunShieldResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{self.base_url}/safety/run_shield",
+                f"{self.base_url}/{LLAMA_STACK_API_VERSION}/safety/run-shield",
                 json=dict(
                     shield_id=shield_id,
                     messages=[encodable_dict(m) for m in messages],
@@ -91,7 +93,7 @@ async def run_main(host: str, port: int, image_path: str = None):
     ]:
         cprint(f"User>{message.content}", "green")
         response = await client.run_shield(
-            shield_id="llama_guard",
+            shield_id="meta-llama/Llama-Guard-3-1B",
             messages=[message],
         )
         print(response)
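
For context, this is roughly how the patched endpoint could be exercised over
raw HTTP, outside the SafetyClient. This is a sketch, not part of the commit:
the serialized message shape and the version value are assumptions, and the
server address matches the command in the commit message.

import asyncio

import httpx

BASE_URL = "http://localhost:5000"  # matches the command above
LLAMA_STACK_API_VERSION = "v1"  # assumed; the client imports the real constant


async def main() -> None:
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{BASE_URL}/{LLAMA_STACK_API_VERSION}/safety/run-shield",
            # Mirrors the client.post() call in the diff; the message payload
            # shape is an assumption for illustration.
            json={
                "shield_id": "meta-llama/Llama-Guard-3-1B",
                "messages": [{"role": "user", "content": "Tell me a joke"}],
            },
        )
        response.raise_for_status()
        print(response.json())


asyncio.run(main())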