Mirror of https://github.com/meta-llama/llama-stack.git
Fix up safety client for versioned API (#573)
When running:

    python -m llama_stack.apis.safety.client localhost 5000

the API server was logging:

    INFO: ::1:57176 - "POST /safety/run_shield HTTP/1.1" 404 Not Found

This patch switches the client to the versioned API, uses the updated safety endpoint path, and updates the model name to the one actually being served. The above python command now demonstrates one passing and one failing example.
Commit a4daf4d3ec (parent 6eb5f2a865). 1 changed file with 4 additions and 2 deletions.
llama_stack/apis/safety/client.py

@@ -17,6 +17,8 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from pydantic import BaseModel
 from termcolor import cprint
 
+from llama_stack.apis.version import LLAMA_STACK_API_VERSION
+
 from llama_stack.distribution.datatypes import RemoteProviderConfig
 
 from llama_stack.apis.safety import *  # noqa: F403
@@ -45,7 +47,7 @@ class SafetyClient(Safety):
     ) -> RunShieldResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{self.base_url}/safety/run_shield",
+                f"{self.base_url}/{LLAMA_STACK_API_VERSION}/safety/run-shield",
                 json=dict(
                     shield_id=shield_id,
                     messages=[encodable_dict(m) for m in messages],
@@ -91,7 +93,7 @@ async def run_main(host: str, port: int, image_path: str = None):
     ]:
         cprint(f"User>{message.content}", "green")
         response = await client.run_shield(
-            shield_id="llama_guard",
+            shield_id="meta-llama/Llama-Guard-3-1B",
             messages=[message],
         )
         print(response)
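For context, below is a minimal sketch of what a request to the patched endpoint looks like. It is not the full SafetyClient: the base URL and port, the JSON message shape ({"role": ..., "content": ...}), and the example prompt are assumptions for illustration, while the versioned path, the LLAMA_STACK_API_VERSION import, and the meta-llama/Llama-Guard-3-1B shield id come from the diff above.

# Minimal sketch: POST to the versioned run-shield endpoint the way the
# patched SafetyClient does. Assumes a llama-stack server on localhost:5000
# serving the Llama Guard shield; the message dict shape is an assumption.
import asyncio

import httpx

from llama_stack.apis.version import LLAMA_STACK_API_VERSION


async def check_message(content: str) -> None:
    base_url = "http://localhost:5000"
    async with httpx.AsyncClient() as client:
        response = await client.post(
            # Versioned path; the old unversioned /safety/run_shield returned 404.
            f"{base_url}/{LLAMA_STACK_API_VERSION}/safety/run-shield",
            json={
                "shield_id": "meta-llama/Llama-Guard-3-1B",
                "messages": [{"role": "user", "content": content}],
            },
        )
        response.raise_for_status()
        print(response.json())


if __name__ == "__main__":
    asyncio.run(check_message("How do I write a hello world program in Python?"))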