Merge branch 'meta-llama:main' into main

Zain Hasan, 2024-09-24 11:31:34 -04:00, committed by GitHub
commit 619b03b2e9
6 changed files with 56 additions and 26 deletions


@@ -55,7 +55,7 @@ class ModelDescribe(Subcommand):
             ("Description", model.description_markdown),
             ("Context Length", f"{model.max_seq_length // 1024}K tokens"),
             ("Weights format", model.quantization_format.value),
-            ("Model params.json", json.dumps(model.model_args, indent=4)),
+            ("Model params.json", json.dumps(model.arch_args, indent=4)),
         ]

         if model.recommended_sampling_params is not None:
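
Reviewer note: the model_args -> arch_args rename appears to track a field rename in the llama-models Model schema. A minimal sketch of reading the renamed field; resolve_model and the descriptor string are illustrative assumptions, not part of this diff:

import json

from llama_models.sku_list import resolve_model

# Assumes a llama-models version in which Model.model_args became Model.arch_args.
model = resolve_model("Llama3.1-8B-Instruct")
if model is not None:
    print(json.dumps(model.arch_args, indent=4))  # was: model.model_args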


@@ -35,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.routing import APIRoute
+from pydantic import BaseModel, ValidationError
+from termcolor import cprint
+from typing_extensions import Annotated

 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
@@ -42,9 +45,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
     SpanStatus,
     start_trace,
 )
-from pydantic import BaseModel, ValidationError
-from termcolor import cprint
-from typing_extensions import Annotated

 from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.distribution import (
@@ -90,10 +90,35 @@ async def global_exception_handler(request: Request, exc: Exception):

 def translate_exception(exc: Exception) -> HTTPException:
     if isinstance(exc, ValidationError):
-        return RequestValidationError(exc.raw_errors)
+        exc = RequestValidationError(exc.raw_errors)

-    # Add more custom exception translations here
-    return HTTPException(status_code=500, detail="Internal server error")
+    if isinstance(exc, RequestValidationError):
+        return HTTPException(
+            status_code=400,
+            detail={
+                "errors": [
+                    {
+                        "loc": list(error["loc"]),
+                        "msg": error["msg"],
+                        "type": error["type"],
+                    }
+                    for error in exc.errors()
+                ]
+            },
+        )
+    elif isinstance(exc, ValueError):
+        return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
+    elif isinstance(exc, PermissionError):
+        return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
+    elif isinstance(exc, TimeoutError):
+        return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
+    elif isinstance(exc, NotImplementedError):
+        return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
+    else:
+        return HTTPException(
+            status_code=500,
+            detail="Internal server error: An unexpected error occurred.",
+        )


 async def passthrough(
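
Reviewer note: a quick sketch of the status codes the expanded translate_exception now yields; the import path is an assumption about where this function lives, and the sample exceptions are illustrative:

# Assumed module path for the function changed above.
from llama_stack.distribution.server.server import translate_exception

assert translate_exception(ValueError("bad input")).status_code == 400
assert translate_exception(PermissionError("denied")).status_code == 403
assert translate_exception(TimeoutError("too slow")).status_code == 504
assert translate_exception(NotImplementedError("todo")).status_code == 501
assert translate_exception(RuntimeError("boom")).status_code == 500  # catch-all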


@@ -130,7 +130,7 @@ class ChatAgent(ShieldRunnerMixin):
                 # CompletionMessage itself in the ShieldResponse
                 messages.append(
                     CompletionMessage(
-                        content=violation.user_message,
+                        content=step.violation.user_message,
                         stop_reason=StopReason.end_of_turn,
                     )
                 )


@@ -32,7 +32,7 @@ class ShieldRunnerMixin:
         self.output_shields = output_shields

     async def run_multiple_shields(
-        self, messages: List[Message], shields: List[str]
+        self, messages: List[Message], shield_types: List[str]
     ) -> None:
         responses = await asyncio.gather(
             *[
@@ -40,16 +40,18 @@ class ShieldRunnerMixin:
                     shield_type=shield_type,
                     messages=messages,
                 )
-                for shield_type in shields
+                for shield_type in shield_types
             ]
         )

-        for shield, r in zip(shields, responses):
-            if r.violation:
-                if shield.on_violation_action == OnViolationAction.RAISE:
-                    raise SafetyException(r)
-                elif shield.on_violation_action == OnViolationAction.WARN:
-                    cprint(
-                        f"[Warn]{shield.__class__.__name__} raised a warning",
-                        color="red",
-                    )
+        for shield_type, response in zip(shield_types, responses):
+            if not response.violation:
+                continue
+            violation = response.violation
+            if violation.violation_level == ViolationLevel.ERROR:
+                raise SafetyException(violation)
+            elif violation.violation_level == ViolationLevel.WARN:
+                cprint(
+                    f"[Warn]{shield_type} raised a warning",
+                    color="red",
+                )
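
Reviewer note: the dispatch now keys off the violation's level rather than the shield object's on_violation_action, so callers pass shield type strings only. A self-contained sketch of the new control flow, with stand-in types where the real ones (SafetyViolation, ViolationLevel) come from llama_stack.apis.safety:

from dataclasses import dataclass
from enum import Enum
from typing import Optional

class ViolationLevel(Enum):
    WARN = "warn"
    ERROR = "error"

@dataclass
class Violation:  # stand-in for SafetyViolation
    violation_level: ViolationLevel
    user_message: str

class SafetyException(Exception):
    pass

def handle_response(shield_type: str, violation: Optional[Violation]) -> None:
    if not violation:
        return  # mirrors the `continue` above
    if violation.violation_level == ViolationLevel.ERROR:
        raise SafetyException(violation)
    elif violation.violation_level == ViolationLevel.WARN:
        print(f"[Warn]{shield_type} raised a warning")

handle_response("llama_guard", Violation(ViolationLevel.WARN, "flagged"))  # warns, does not raise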


@@ -10,6 +10,10 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403

+from llama_stack.providers.impls.meta_reference.safety.shields.base import (
+    OnViolationAction,
+)
+
 from .config import MetaReferenceShieldType, SafetyConfig
 from .shields import (
@@ -69,9 +73,13 @@ class MetaReferenceSafetyImpl(Safety):
         # TODO: we can refactor ShieldBase, etc. to be inline with the API types
         res = await shield.run(messages)
         violation = None
-        if res.is_violation:
+        if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
             violation = SafetyViolation(
-                violation_level=ViolationLevel.ERROR,
+                violation_level=(
+                    ViolationLevel.ERROR
+                    if shield.on_violation_action == OnViolationAction.RAISE
+                    else ViolationLevel.WARN
+                ),
                 user_message=res.violation_return_message,
                 metadata={
                     "violation_type": res.violation_type,


@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.