From 49f7e04f833a87ecb2bbab8f094da52f2a0a1d08 Mon Sep 17 00:00:00 2001 From: Kaushik Date: Mon, 10 Feb 2025 15:30:48 -0800 Subject: [PATCH] adding prod url also cleans up print messages --- .../remote/safety/fiddlecube/config.py | 7 ++- .../remote/safety/fiddlecube/fiddlecube.py | 47 +++++++++---------- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/llama_stack/providers/remote/safety/fiddlecube/config.py b/llama_stack/providers/remote/safety/fiddlecube/config.py index f26e79f20..0ba5d7c32 100644 --- a/llama_stack/providers/remote/safety/fiddlecube/config.py +++ b/llama_stack/providers/remote/safety/fiddlecube/config.py @@ -4,11 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from pydantic import BaseModel +from typing import List + +from pydantic import BaseModel, Field from llama_models.schema_utils import json_schema_type @json_schema_type class FiddlecubeSafetyConfig(BaseModel): - pass + api_url: str = "https://api.fiddlecube.ai/api" + excluded_categories: List[str] = Field(default_factory=list) diff --git a/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py b/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py index 84228c5f8..dd5c49da5 100644 --- a/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py +++ b/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py @@ -14,8 +14,6 @@ from llama_stack.apis.inference import Message from llama_stack.apis.safety import ( RunShieldResponse, Safety, - SafetyViolation, - ViolationLevel, ) from llama_stack.apis.shields import Shield from llama_stack.providers.datatypes import ShieldsProtocolPrivate @@ -77,28 +75,29 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate): content_messages.append({"text": {"text": message.content}}) logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") - response = self.bedrock_runtime_client.apply_guardrail( 
guardrailIdentifier=shield.provider_resource_id, - guardrailVersion=shield_params["guardrailVersion"], - source="OUTPUT", # or 'INPUT' depending on your use case - content=content_messages, - ) - if response["action"] == "GUARDRAIL_INTERVENED": - user_message = "" - metadata = {} - for output in response["outputs"]: - # guardrails returns a list - however for this implementation we will leverage the last values - user_message = output["text"] - for assessment in response["assessments"]: - # guardrails returns a list - however for this implementation we will leverage the last values - metadata = dict(assessment) - - return RunShieldResponse( - violation=SafetyViolation( - user_message=user_message, - violation_level=ViolationLevel.ERROR, - metadata=metadata, - ) + # Make a call to the FiddleCube API for guardrails + async with httpx.AsyncClient(timeout=30.0) as client: + request_body = { + "messages": [message.model_dump(mode="json") for message in messages], + } + if params.get("excluded_categories"): + request_body["excluded_categories"] = params.get("excluded_categories") + headers = {"Content-Type": "application/json"} + response = await client.post( + f"{self.config.api_url}/safety/guard/check", + json=request_body, + headers=headers, ) + logger.debug("Response::: %s", response.status_code) + + # Check if the response is successful + if response.status_code != 200: + logger.error(f"FiddleCube API error: {response.status_code} - {response.text}") + raise RuntimeError("Failed to run shield with FiddleCube API") + + # Convert the response into the format RunShieldResponse expects + response_data = response.json() + logger.debug("Response data: %s", response_data) return RunShieldResponse()