diff --git a/llama_stack/cli/model/describe.py b/llama_stack/cli/model/describe.py
index b100f7544..c99cb06c1 100644
--- a/llama_stack/cli/model/describe.py
+++ b/llama_stack/cli/model/describe.py
@@ -55,7 +55,7 @@ class ModelDescribe(Subcommand):
             ("Description", model.description_markdown),
             ("Context Length", f"{model.max_seq_length // 1024}K tokens"),
             ("Weights format", model.quantization_format.value),
-            ("Model params.json", json.dumps(model.model_args, indent=4)),
+            ("Model params.json", json.dumps(model.arch_args, indent=4)),
         ]
 
         if model.recommended_sampling_params is not None:
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index f09e1c586..38218ab8b 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -35,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.routing import APIRoute
+from pydantic import BaseModel, ValidationError
+from termcolor import cprint
+from typing_extensions import Annotated
 
 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
@@ -42,9 +45,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
     SpanStatus,
     start_trace,
 )
-from pydantic import BaseModel, ValidationError
-from termcolor import cprint
-from typing_extensions import Annotated
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.distribution import (
@@ -90,10 +90,35 @@ async def global_exception_handler(request: Request, exc: Exception):
 
 def translate_exception(exc: Exception) -> HTTPException:
     if isinstance(exc, ValidationError):
-        return RequestValidationError(exc.raw_errors)
+        exc = RequestValidationError(exc.raw_errors)
 
-    # Add more custom exception translations here
-    return HTTPException(status_code=500, detail="Internal server error")
+    if isinstance(exc, RequestValidationError):
+        return HTTPException(
+            status_code=400,
+            detail={
+                "errors": [
+                    {
+                        "loc": list(error["loc"]),
+                        "msg": error["msg"],
+                        "type": error["type"],
+                    }
+                    for error in exc.errors()
+                ]
+            },
+        )
+    elif isinstance(exc, ValueError):
+        return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
+    elif isinstance(exc, PermissionError):
+        return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
+    elif isinstance(exc, TimeoutError):
+        return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
+    elif isinstance(exc, NotImplementedError):
+        return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
+    else:
+        return HTTPException(
+            status_code=500,
+            detail="Internal server error: An unexpected error occurred.",
+        )
 
 
 async def passthrough(
diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
index 0ac26a857..797a1bc7f 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
@@ -130,7 +130,7 @@ class ChatAgent(ShieldRunnerMixin):
                 # CompletionMessage itself in the ShieldResponse
                 messages.append(
                     CompletionMessage(
-                        content=violation.user_message,
+                        content=step.violation.user_message,
                         stop_reason=StopReason.end_of_turn,
                     )
                 )
diff --git a/llama_stack/providers/impls/meta_reference/agents/safety.py b/llama_stack/providers/impls/meta_reference/agents/safety.py
index 44d47b16c..fb5821f6a 100644
--- a/llama_stack/providers/impls/meta_reference/agents/safety.py
+++ b/llama_stack/providers/impls/meta_reference/agents/safety.py
@@ -32,7 +32,7 @@ class ShieldRunnerMixin:
         self.output_shields = output_shields
 
     async def run_multiple_shields(
-        self, messages: List[Message], shields: List[str]
+        self, messages: List[Message], shield_types: List[str]
    ) -> None:
         responses = await asyncio.gather(
             *[
@@ -40,16 +40,18 @@ class ShieldRunnerMixin:
                     shield_type=shield_type,
                     messages=messages,
                 )
-                for shield_type in shields
+                for shield_type in shield_types
             ]
         )
+        for shield_type, response in zip(shield_types, responses):
+            if not response.violation:
+                continue
 
-        for shield, r in zip(shields, responses):
-            if r.violation:
-                if shield.on_violation_action == OnViolationAction.RAISE:
-                    raise SafetyException(r)
-                elif shield.on_violation_action == OnViolationAction.WARN:
-                    cprint(
-                        f"[Warn]{shield.__class__.__name__} raised a warning",
-                        color="red",
-                    )
+            violation = response.violation
+            if violation.violation_level == ViolationLevel.ERROR:
+                raise SafetyException(violation)
+            elif violation.violation_level == ViolationLevel.WARN:
+                cprint(
+                    f"[Warn]{shield_type} raised a warning",
+                    color="red",
+                )
diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py
index 6eccf47a5..6cf8a79d2 100644
--- a/llama_stack/providers/impls/meta_reference/safety/safety.py
+++ b/llama_stack/providers/impls/meta_reference/safety/safety.py
@@ -10,6 +10,10 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 
+from llama_stack.providers.impls.meta_reference.safety.shields.base import (
+    OnViolationAction,
+)
+
 from .config import MetaReferenceShieldType, SafetyConfig
 
 from .shields import (
@@ -69,9 +73,13 @@ class MetaReferenceSafetyImpl(Safety):
         # TODO: we can refactor ShieldBase, etc. to be inline with the API types
         res = await shield.run(messages)
         violation = None
-        if res.is_violation:
+        if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
             violation = SafetyViolation(
-                violation_level=ViolationLevel.ERROR,
+                violation_level=(
+                    ViolationLevel.ERROR
+                    if shield.on_violation_action == OnViolationAction.RAISE
+                    else ViolationLevel.WARN
+                ),
                 user_message=res.violation_return_message,
                 metadata={
                     "violation_type": res.violation_type,
diff --git a/llama_stack/providers/impls/sqlite/__init__.py b/llama_stack/providers/impls/sqlite/__init__.py
deleted file mode 100644
index 756f351d8..000000000
--- a/llama_stack/providers/impls/sqlite/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
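Reviewer note (illustrative, not part of the patch): the server.py hunk above replaces the blanket 500 response with per-exception-type HTTP status codes. A minimal standalone sketch of that mapping, assuming only that FastAPI is installed; translate_exception_demo and the __main__ driver are hypothetical names used for the example, not identifiers from the repo:

from fastapi import HTTPException


def translate_exception_demo(exc: Exception) -> HTTPException:
    # Mirrors the status-code mapping introduced in translate_exception() above.
    if isinstance(exc, ValueError):
        return HTTPException(status_code=400, detail=f"Invalid value: {exc}")
    elif isinstance(exc, PermissionError):
        return HTTPException(status_code=403, detail=f"Permission denied: {exc}")
    elif isinstance(exc, TimeoutError):
        return HTTPException(status_code=504, detail=f"Operation timed out: {exc}")
    elif isinstance(exc, NotImplementedError):
        return HTTPException(status_code=501, detail=f"Not implemented: {exc}")
    # Anything unrecognized still falls back to a generic 500.
    return HTTPException(status_code=500, detail="Internal server error")


if __name__ == "__main__":
    for err in (ValueError("bad input"), PermissionError("no access"), TimeoutError()):
        print(type(err).__name__, "->", translate_exception_demo(err).status_code)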