adding prod url

also cleans up print messages
This commit is contained in:
Kaushik 2025-02-10 15:30:48 -08:00
parent aea5b2745d
commit 49f7e04f83
2 changed files with 28 additions and 26 deletions

View file

@ -4,11 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from typing import List
from pydantic import BaseModel, Field
from llama_models.schema_utils import json_schema_type
@json_schema_type
class FiddlecubeSafetyConfig(BaseModel):
pass
api_url: str = "https://api.fiddlecube.ai/api"
excluded_categories: List[str] = Field(default_factory=list)

View file

@ -14,8 +14,6 @@ from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@ -77,28 +75,29 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
content_messages.append({"text": {"text": message.content}})
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = self.bedrock_runtime_client.apply_guardrail(
guardrailIdentifier=shield.provider_resource_id,
guardrailVersion=shield_params["guardrailVersion"],
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
# Make a call to the FiddleCube API for guardrails
async with httpx.AsyncClient(timeout=30.0) as client:
request_body = {
"messages": [message.model_dump(mode="json") for message in messages],
}
if params.get("excluded_categories"):
request_body["excluded_categories"] = params.get("excluded_categories")
headers = {"Content-Type": "application/json"}
response = await client.post(
f"{self.config.api_url}/safety/guard/check",
json=request_body,
headers=headers,
)
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
logger.debug("Response:::", response.status_code)
# Check if the response is successful
if response.status_code != 200:
logger.error(f"FiddleCube API error: {response.status_code} - {response.text}")
raise RuntimeError("Failed to run shield with FiddleCube API")
# Convert the response into the format RunShieldResponse expects
response_data = response.json()
logger.debug("Response data", response_data)
return RunShieldResponse()