mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 00:05:18 +00:00

Merge branch 'meta-llama:main' into main

This commit is contained in: commit 619b03b2e9
6 changed files with 56 additions and 26 deletions
@@ -55,7 +55,7 @@ class ModelDescribe(Subcommand):
             ("Description", model.description_markdown),
             ("Context Length", f"{model.max_seq_length // 1024}K tokens"),
             ("Weights format", model.quantization_format.value),
-            ("Model params.json", json.dumps(model.model_args, indent=4)),
+            ("Model params.json", json.dumps(model.arch_args, indent=4)),
         ]

         if model.recommended_sampling_params is not None:
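The only functional change in this hunk renames the attribute read off the model card: `model.model_args` becomes `model.arch_args`. A minimal sketch of the rows being built, using a stand-in object (the real SKU type lives in llama_models and is not shown in the diff):

```python
import json
from types import SimpleNamespace

# Stand-in for the model card object; only the fields this hunk touches.
model = SimpleNamespace(
    max_seq_length=8192,
    arch_args={"dim": 4096, "n_layers": 32, "n_heads": 32},
)

rows = [
    ("Context Length", f"{model.max_seq_length // 1024}K tokens"),  # -> "8K tokens"
    ("Model params.json", json.dumps(model.arch_args, indent=4)),   # renamed field
]
for label, value in rows:
    print(f"{label}: {value}")
```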
@@ -35,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.routing import APIRoute
+from pydantic import BaseModel, ValidationError
+from termcolor import cprint
+from typing_extensions import Annotated

 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
@@ -42,9 +45,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
     SpanStatus,
     start_trace,
 )
-from pydantic import BaseModel, ValidationError
-from termcolor import cprint
-from typing_extensions import Annotated
 from llama_stack.distribution.datatypes import *  # noqa: F403

 from llama_stack.distribution.distribution import (
@@ -90,10 +90,35 @@ async def global_exception_handler(request: Request, exc: Exception):


 def translate_exception(exc: Exception) -> HTTPException:
     if isinstance(exc, ValidationError):
-        return RequestValidationError(exc.raw_errors)
+        exc = RequestValidationError(exc.raw_errors)

-    # Add more custom exception translations here
-    return HTTPException(status_code=500, detail="Internal server error")
+    if isinstance(exc, RequestValidationError):
+        return HTTPException(
+            status_code=400,
+            detail={
+                "errors": [
+                    {
+                        "loc": list(error["loc"]),
+                        "msg": error["msg"],
+                        "type": error["type"],
+                    }
+                    for error in exc.errors()
+                ]
+            },
+        )
+    elif isinstance(exc, ValueError):
+        return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
+    elif isinstance(exc, PermissionError):
+        return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
+    elif isinstance(exc, TimeoutError):
+        return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
+    elif isinstance(exc, NotImplementedError):
+        return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
+    else:
+        return HTTPException(
+            status_code=500,
+            detail="Internal server error: An unexpected error occurred.",
+        )


 async def passthrough(
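Instead of collapsing everything to a 500, translate_exception now maps common exception types to appropriate status codes and returns validation errors in a structured shape. A minimal sketch of how a catch-all handler would consume it (the handler name here is hypothetical; translate_exception is the function from the hunk above):

```python
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse

app = FastAPI()

@app.exception_handler(Exception)
async def catch_all(request: Request, exc: Exception) -> JSONResponse:
    # Delegate classification to translate_exception, then serialize.
    http_exc: HTTPException = translate_exception(exc)
    return JSONResponse(status_code=http_exc.status_code, content={"detail": http_exc.detail})

# Resulting mappings after this change:
#   pydantic ValidationError / RequestValidationError -> 400 with structured errors
#   ValueError          -> 400    PermissionError     -> 403
#   TimeoutError        -> 504    NotImplementedError -> 501
#   anything else       -> 500 generic message
```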
@@ -130,7 +130,7 @@ class ChatAgent(ShieldRunnerMixin):
                 # CompletionMessage itself in the ShieldResponse
                 messages.append(
                     CompletionMessage(
-                        content=violation.user_message,
+                        content=step.violation.user_message,
                         stop_reason=StopReason.end_of_turn,
                     )
                 )
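A one-line bug fix: the refusal text now comes from the violation attached to the shield-call step rather than from a bare name that is undefined in this scope after the shield refactor elsewhere in this commit. Roughly, under the assumption that `step` carries the SafetyViolation produced by the shield call:

```python
# `step.violation` (a SafetyViolation) holds the user-facing refusal text.
if step.violation is not None:
    messages.append(
        CompletionMessage(
            content=step.violation.user_message,
            stop_reason=StopReason.end_of_turn,
        )
    )
```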
@@ -32,7 +32,7 @@ class ShieldRunnerMixin:
         self.output_shields = output_shields

     async def run_multiple_shields(
-        self, messages: List[Message], shields: List[str]
+        self, messages: List[Message], shield_types: List[str]
     ) -> None:
         responses = await asyncio.gather(
             *[
@@ -40,16 +40,18 @@ class ShieldRunnerMixin:
                     shield_type=shield_type,
                     messages=messages,
                 )
-                for shield_type in shields
+                for shield_type in shield_types
             ]
         )
+        for shield_type, response in zip(shield_types, responses):
+            if not response.violation:
+                continue

-        for shield, r in zip(shields, responses):
-            if r.violation:
-                if shield.on_violation_action == OnViolationAction.RAISE:
-                    raise SafetyException(r)
-                elif shield.on_violation_action == OnViolationAction.WARN:
-                    cprint(
-                        f"[Warn]{shield.__class__.__name__} raised a warning",
-                        color="red",
-                    )
+            violation = response.violation
+            if violation.violation_level == ViolationLevel.ERROR:
+                raise SafetyException(violation)
+            elif violation.violation_level == ViolationLevel.WARN:
+                cprint(
+                    f"[Warn]{shield_type} raised a warning",
+                    color="red",
+                )
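The runner now takes shield type names rather than shield objects, and dispatches on the violation's level instead of the shield's on_violation_action. A runnable, self-contained sketch of the new control flow, with simplified stub types standing in for the real llama_stack datatypes:

```python
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class ViolationLevel(Enum):
    WARN = "warn"
    ERROR = "error"


@dataclass
class SafetyViolation:
    violation_level: ViolationLevel
    user_message: str


@dataclass
class RunShieldResponse:
    violation: Optional[SafetyViolation] = None


class SafetyException(Exception):
    pass


async def run_shield(shield_type: str, messages: List[str]) -> RunShieldResponse:
    # Stub shield: flag any message containing "attack" as an ERROR-level violation.
    if any("attack" in m for m in messages):
        return RunShieldResponse(SafetyViolation(ViolationLevel.ERROR, "I can't help with that."))
    return RunShieldResponse()


async def run_multiple_shields(messages: List[str], shield_types: List[str]) -> None:
    # Run all shields concurrently, then dispatch on each violation's level.
    responses = await asyncio.gather(*[run_shield(s, messages) for s in shield_types])
    for shield_type, response in zip(shield_types, responses):
        if not response.violation:
            continue
        violation = response.violation
        if violation.violation_level == ViolationLevel.ERROR:
            raise SafetyException(violation)
        elif violation.violation_level == ViolationLevel.WARN:
            print(f"[Warn]{shield_type} raised a warning")


try:
    asyncio.run(run_multiple_shields(["plan an attack"], ["llama_guard"]))
except SafetyException as e:
    print("blocked:", e.args[0].user_message)
```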
@@ -10,6 +10,10 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403

+from llama_stack.providers.impls.meta_reference.safety.shields.base import (
+    OnViolationAction,
+)
+
 from .config import MetaReferenceShieldType, SafetyConfig

 from .shields import (
@@ -69,9 +73,13 @@ class MetaReferenceSafetyImpl(Safety):
         # TODO: we can refactor ShieldBase, etc. to be inline with the API types
         res = await shield.run(messages)
         violation = None
-        if res.is_violation:
+        if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
             violation = SafetyViolation(
-                violation_level=ViolationLevel.ERROR,
+                violation_level=(
+                    ViolationLevel.ERROR
+                    if shield.on_violation_action == OnViolationAction.RAISE
+                    else ViolationLevel.WARN
+                ),
                 user_message=res.violation_return_message,
                 metadata={
                     "violation_type": res.violation_type,
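This is the provider-side half of the same refactor: the shield's configured action is folded into the violation level reported through the API, which is what the runner above now dispatches on. A hypothetical helper, `level_for` (not in the commit), summarizing the mapping; OnViolationAction and ViolationLevel are the enums used in the hunk:

```python
from typing import Optional

def level_for(action: OnViolationAction) -> Optional[ViolationLevel]:
    if action == OnViolationAction.IGNORE:
        return None  # no SafetyViolation is emitted at all
    return (
        ViolationLevel.ERROR  # runner raises SafetyException
        if action == OnViolationAction.RAISE
        else ViolationLevel.WARN  # e.g. OnViolationAction.WARN: runner prints a warning
    )
```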
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.