adding prod url

also cleans up print messages
This commit is contained in:
Kaushik 2025-02-10 15:30:48 -08:00
parent aea5b2745d
commit 49f7e04f83
2 changed files with 28 additions and 26 deletions

View file

@ -4,11 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from pydantic import BaseModel from typing import List
from pydantic import BaseModel, Field
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class FiddlecubeSafetyConfig(BaseModel): class FiddlecubeSafetyConfig(BaseModel):
pass api_url: str = "https://api.fiddlecube.ai/api"
excluded_categories: List[str] = Field(default_factory=list)

View file

@ -14,8 +14,6 @@ from llama_stack.apis.inference import Message
from llama_stack.apis.safety import ( from llama_stack.apis.safety import (
RunShieldResponse, RunShieldResponse,
Safety, Safety,
SafetyViolation,
ViolationLevel,
) )
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@ -77,28 +75,29 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
content_messages.append({"text": {"text": message.content}}) content_messages.append({"text": {"text": message.content}})
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = self.bedrock_runtime_client.apply_guardrail( # Make a call to the FiddleCube API for guardrails
guardrailIdentifier=shield.provider_resource_id, async with httpx.AsyncClient(timeout=30.0) as client:
guardrailVersion=shield_params["guardrailVersion"], request_body = {
source="OUTPUT", # or 'INPUT' depending on your use case "messages": [message.model_dump(mode="json") for message in messages],
content=content_messages, }
) if params.get("excluded_categories"):
if response["action"] == "GUARDRAIL_INTERVENED": request_body["excluded_categories"] = params.get("excluded_categories")
user_message = "" headers = {"Content-Type": "application/json"}
metadata = {} response = await client.post(
for output in response["outputs"]: f"{self.config.api_url}/safety/guard/check",
# guardrails returns a list - however for this implementation we will leverage the last values json=request_body,
user_message = output["text"] headers=headers,
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
) )
logger.debug("Response:::", response.status_code)
# Check if the response is successful
if response.status_code != 200:
logger.error(f"FiddleCube API error: {response.status_code} - {response.text}")
raise RuntimeError("Failed to run shield with FiddleCube API")
# Convert the response into the format RunShieldResponse expects
response_data = response.json()
logger.debug("Response data", response_data)
return RunShieldResponse() return RunShieldResponse()