This commit is contained in:
Zain Hasan 2024-09-24 16:21:54 -04:00
commit af1710af75
6 changed files with 56 additions and 26 deletions

View file

@ -55,7 +55,7 @@ class ModelDescribe(Subcommand):
("Description", model.description_markdown), ("Description", model.description_markdown),
("Context Length", f"{model.max_seq_length // 1024}K tokens"), ("Context Length", f"{model.max_seq_length // 1024}K tokens"),
("Weights format", model.quantization_format.value), ("Weights format", model.quantization_format.value),
("Model params.json", json.dumps(model.model_args, indent=4)), ("Model params.json", json.dumps(model.arch_args, indent=4)),
] ]
if model.recommended_sampling_params is not None: if model.recommended_sampling_params is not None:

View file

@ -35,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
end_trace, end_trace,
@ -42,9 +45,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus, SpanStatus,
start_trace, start_trace,
) )
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import (
@ -90,10 +90,35 @@ async def global_exception_handler(request: Request, exc: Exception):
def translate_exception(exc: Exception) -> HTTPException: def translate_exception(exc: Exception) -> HTTPException:
if isinstance(exc, ValidationError): if isinstance(exc, ValidationError):
return RequestValidationError(exc.raw_errors) exc = RequestValidationError(exc.raw_errors)
# Add more custom exception translations here if isinstance(exc, RequestValidationError):
return HTTPException(status_code=500, detail="Internal server error") return HTTPException(
status_code=400,
detail={
"errors": [
{
"loc": list(error["loc"]),
"msg": error["msg"],
"type": error["type"],
}
for error in exc.errors()
]
},
)
elif isinstance(exc, ValueError):
return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, PermissionError):
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, TimeoutError):
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError):
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
else:
return HTTPException(
status_code=500,
detail="Internal server error: An unexpected error occurred.",
)
async def passthrough( async def passthrough(

View file

@ -130,7 +130,7 @@ class ChatAgent(ShieldRunnerMixin):
# CompletionMessage itself in the ShieldResponse # CompletionMessage itself in the ShieldResponse
messages.append( messages.append(
CompletionMessage( CompletionMessage(
content=violation.user_message, content=step.violation.user_message,
stop_reason=StopReason.end_of_turn, stop_reason=StopReason.end_of_turn,
) )
) )

View file

@ -32,7 +32,7 @@ class ShieldRunnerMixin:
self.output_shields = output_shields self.output_shields = output_shields
async def run_multiple_shields( async def run_multiple_shields(
self, messages: List[Message], shields: List[str] self, messages: List[Message], shield_types: List[str]
) -> None: ) -> None:
responses = await asyncio.gather( responses = await asyncio.gather(
*[ *[
@ -40,16 +40,18 @@ class ShieldRunnerMixin:
shield_type=shield_type, shield_type=shield_type,
messages=messages, messages=messages,
) )
for shield_type in shields for shield_type in shield_types
] ]
) )
for shield_type, response in zip(shield_types, responses):
if not response.violation:
continue
for shield, r in zip(shields, responses): violation = response.violation
if r.violation: if violation.violation_level == ViolationLevel.ERROR:
if shield.on_violation_action == OnViolationAction.RAISE: raise SafetyException(violation)
raise SafetyException(r) elif violation.violation_level == ViolationLevel.WARN:
elif shield.on_violation_action == OnViolationAction.WARN:
cprint( cprint(
f"[Warn]{shield.__class__.__name__} raised a warning", f"[Warn]{shield_type} raised a warning",
color="red", color="red",
) )

View file

@ -10,6 +10,10 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
)
from .config import MetaReferenceShieldType, SafetyConfig from .config import MetaReferenceShieldType, SafetyConfig
from .shields import ( from .shields import (
@ -69,9 +73,13 @@ class MetaReferenceSafetyImpl(Safety):
# TODO: we can refactor ShieldBase, etc. to be inline with the API types # TODO: we can refactor ShieldBase, etc. to be inline with the API types
res = await shield.run(messages) res = await shield.run(messages)
violation = None violation = None
if res.is_violation: if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
violation = SafetyViolation( violation = SafetyViolation(
violation_level=ViolationLevel.ERROR, violation_level=(
ViolationLevel.ERROR
if shield.on_violation_action == OnViolationAction.RAISE
else ViolationLevel.WARN
),
user_message=res.violation_return_message, user_message=res.violation_return_message,
metadata={ metadata={
"violation_type": res.violation_type, "violation_type": res.violation_type,

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.