From a4daf4d3ecc3d53ec14725634f2be16a8948ce56 Mon Sep 17 00:00:00 2001
From: Steve Grubb
Date: Thu, 5 Dec 2024 17:13:49 -0500
Subject: [PATCH] Fix up safety client for versioned API (#573)

When running:
python -m llama_stack.apis.safety.client localhost 5000

The API server was logging:
INFO: ::1:57176 - "POST /safety/run_shield HTTP/1.1" 404 Not Found

This patch uses the versioned API, uses the updated safety endpoint, and
updates the model name to what's being served. The above python command
now demonstrates a passing and failing example.
---
 llama_stack/apis/safety/client.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llama_stack/apis/safety/client.py b/llama_stack/apis/safety/client.py
index d7d4bc981..a9396c70c 100644
--- a/llama_stack/apis/safety/client.py
+++ b/llama_stack/apis/safety/client.py
@@ -17,6 +17,8 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from pydantic import BaseModel
 from termcolor import cprint
 
+from llama_stack.apis.version import LLAMA_STACK_API_VERSION
+
 from llama_stack.distribution.datatypes import RemoteProviderConfig
 
 from llama_stack.apis.safety import *  # noqa: F403
@@ -45,7 +47,7 @@ class SafetyClient(Safety):
     ) -> RunShieldResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{self.base_url}/safety/run_shield",
+                f"{self.base_url}/{LLAMA_STACK_API_VERSION}/safety/run-shield",
                 json=dict(
                     shield_id=shield_id,
                     messages=[encodable_dict(m) for m in messages],
@@ -91,7 +93,7 @@ async def run_main(host: str, port: int, image_path: str = None):
     ]:
         cprint(f"User>{message.content}", "green")
         response = await client.run_shield(
-            shield_id="llama_guard",
+            shield_id="meta-llama/Llama-Guard-3-1B",
             messages=[message],
         )
         print(response)