mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 02:58:21 +00:00
adding prod url
also cleans up print messages
This commit is contained in:
parent
aea5b2745d
commit
49f7e04f83
2 changed files with 28 additions and 26 deletions
|
@ -4,11 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from typing import List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class FiddlecubeSafetyConfig(BaseModel):
|
class FiddlecubeSafetyConfig(BaseModel):
|
||||||
pass
|
api_url: str = "https://api.fiddlecube.ai/api"
|
||||||
|
excluded_categories: List[str] = Field(default_factory=list)
|
||||||
|
|
|
@ -14,8 +14,6 @@ from llama_stack.apis.inference import Message
|
||||||
from llama_stack.apis.safety import (
|
from llama_stack.apis.safety import (
|
||||||
RunShieldResponse,
|
RunShieldResponse,
|
||||||
Safety,
|
Safety,
|
||||||
SafetyViolation,
|
|
||||||
ViolationLevel,
|
|
||||||
)
|
)
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||||
|
@ -77,28 +75,29 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
content_messages.append({"text": {"text": message.content}})
|
content_messages.append({"text": {"text": message.content}})
|
||||||
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||||
|
|
||||||
response = self.bedrock_runtime_client.apply_guardrail(
|
# Make a call to the FiddleCube API for guardrails
|
||||||
guardrailIdentifier=shield.provider_resource_id,
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
guardrailVersion=shield_params["guardrailVersion"],
|
request_body = {
|
||||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
"messages": [message.model_dump(mode="json") for message in messages],
|
||||||
content=content_messages,
|
}
|
||||||
)
|
if params.get("excluded_categories"):
|
||||||
if response["action"] == "GUARDRAIL_INTERVENED":
|
request_body["excluded_categories"] = params.get("excluded_categories")
|
||||||
user_message = ""
|
headers = {"Content-Type": "application/json"}
|
||||||
metadata = {}
|
response = await client.post(
|
||||||
for output in response["outputs"]:
|
f"{self.config.api_url}/safety/guard/check",
|
||||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
json=request_body,
|
||||||
user_message = output["text"]
|
headers=headers,
|
||||||
for assessment in response["assessments"]:
|
|
||||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
|
||||||
metadata = dict(assessment)
|
|
||||||
|
|
||||||
return RunShieldResponse(
|
|
||||||
violation=SafetyViolation(
|
|
||||||
user_message=user_message,
|
|
||||||
violation_level=ViolationLevel.ERROR,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.debug("Response:::", response.status_code)
|
||||||
|
|
||||||
|
# Check if the response is successful
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"FiddleCube API error: {response.status_code} - {response.text}")
|
||||||
|
raise RuntimeError("Failed to run shield with FiddleCube API")
|
||||||
|
|
||||||
|
# Convert the response into the format RunShieldResponse expects
|
||||||
|
response_data = response.json()
|
||||||
|
logger.debug("Response data", response_data)
|
||||||
|
|
||||||
return RunShieldResponse()
|
return RunShieldResponse()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue