mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 20:22:27 +00:00
minor fix
This commit is contained in:
parent
6358d0a478
commit
093f017da2
7 changed files with 32 additions and 108 deletions
|
|
@ -10,7 +10,7 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
)
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
|
||||
from llama_stack.apis.safety.safety import ModerationObject
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
|
@ -82,20 +82,5 @@ class SafetyRouter(Safety):
|
|||
input=input,
|
||||
model=model,
|
||||
)
|
||||
self._validate_required_categories_exist(response)
|
||||
|
||||
return response
|
||||
|
||||
def _validate_required_categories_exist(self, response: ModerationObject) -> None:
|
||||
"""Validate the ProviderImpl response contains the required Open AI moderations categories."""
|
||||
required_categories = list(map(str, OpenAICategories))
|
||||
|
||||
categories = response.results[0].categories
|
||||
category_applied_input_types = response.results[0].category_applied_input_types
|
||||
category_scores = response.results[0].category_scores
|
||||
|
||||
for i in [categories, category_applied_input_types, category_scores]:
|
||||
if not set(required_categories).issubset(set(i.keys())):
|
||||
raise ValueError(
|
||||
f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue