mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 02:58:21 +00:00
adding prod url
also cleans up print messages
This commit is contained in:
parent
aea5b2745d
commit
49f7e04f83
2 changed files with 28 additions and 26 deletions
|
@ -4,11 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FiddlecubeSafetyConfig(BaseModel):
|
||||
pass
|
||||
api_url: str = "https://api.fiddlecube.ai/api"
|
||||
excluded_categories: List[str] = Field(default_factory=list)
|
||||
|
|
|
@ -14,8 +14,6 @@ from llama_stack.apis.inference import Message
|
|||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
@ -77,28 +75,29 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||
|
||||
response = self.bedrock_runtime_client.apply_guardrail(
|
||||
guardrailIdentifier=shield.provider_resource_id,
|
||||
guardrailVersion=shield_params["guardrailVersion"],
|
||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||
content=content_messages,
|
||||
)
|
||||
if response["action"] == "GUARDRAIL_INTERVENED":
|
||||
user_message = ""
|
||||
metadata = {}
|
||||
for output in response["outputs"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
user_message = output["text"]
|
||||
for assessment in response["assessments"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
# Make a call to the FiddleCube API for guardrails
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
request_body = {
|
||||
"messages": [message.model_dump(mode="json") for message in messages],
|
||||
}
|
||||
if params.get("excluded_categories"):
|
||||
request_body["excluded_categories"] = params.get("excluded_categories")
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = await client.post(
|
||||
f"{self.config.api_url}/safety/guard/check",
|
||||
json=request_body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
logger.debug("Response:::", response.status_code)
|
||||
|
||||
# Check if the response is successful
|
||||
if response.status_code != 200:
|
||||
logger.error(f"FiddleCube API error: {response.status_code} - {response.text}")
|
||||
raise RuntimeError("Failed to run shield with FiddleCube API")
|
||||
|
||||
# Convert the response into the format RunShieldResponse expects
|
||||
response_data = response.json()
|
||||
logger.debug("Response data", response_data)
|
||||
|
||||
return RunShieldResponse()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue