From e5bdd6615af32fc8488826b349736fa33f9e676a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 23 Sep 2024 18:17:15 -0700 Subject: [PATCH 1/6] bug fix for safety violation --- .../impls/meta_reference/agents/agent_instance.py | 2 +- .../providers/impls/meta_reference/agents/safety.py | 12 +----------- .../providers/impls/meta_reference/safety/safety.py | 13 +++++++++++++ 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py index 0ac26a857..797a1bc7f 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py +++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py @@ -130,7 +130,7 @@ class ChatAgent(ShieldRunnerMixin): # CompletionMessage itself in the ShieldResponse messages.append( CompletionMessage( - content=violation.user_message, + content=step.violation.user_message, stop_reason=StopReason.end_of_turn, ) ) diff --git a/llama_stack/providers/impls/meta_reference/agents/safety.py b/llama_stack/providers/impls/meta_reference/agents/safety.py index 44d47b16c..e7c982181 100644 --- a/llama_stack/providers/impls/meta_reference/agents/safety.py +++ b/llama_stack/providers/impls/meta_reference/agents/safety.py @@ -34,7 +34,7 @@ class ShieldRunnerMixin: async def run_multiple_shields( self, messages: List[Message], shields: List[str] ) -> None: - responses = await asyncio.gather( + await asyncio.gather( *[ self.safety_api.run_shield( shield_type=shield_type, @@ -43,13 +43,3 @@ class ShieldRunnerMixin: for shield_type in shields ] ) - - for shield, r in zip(shields, responses): - if r.violation: - if shield.on_violation_action == OnViolationAction.RAISE: - raise SafetyException(r) - elif shield.on_violation_action == OnViolationAction.WARN: - cprint( - f"[Warn]{shield.__class__.__name__} raised a warning", - color="red", - ) diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index 6eccf47a5..e5c42b45c 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -10,6 +10,11 @@ from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.apis.safety import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException +from llama_stack.providers.impls.meta_reference.safety.shields.base import ( + OnViolationAction, +) + from .config import MetaReferenceShieldType, SafetyConfig from .shields import ( @@ -78,6 +83,14 @@ class MetaReferenceSafetyImpl(Safety): }, ) + if shield.on_violation_action == OnViolationAction.RAISE: + raise SafetyException(violation) + elif shield.on_violation_action == OnViolationAction.WARN: + cprint( + f"[Warn]{shield.__class__.__name__} raised a warning", + color="red", + ) + return RunShieldResponse(violation=violation) def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: From c9005e95ed602bca74c348b5251c78ce5d3e362c Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 23 Sep 2024 19:06:30 -0700 Subject: [PATCH 2/6] Another attempt at a proper bugfix for safety violations --- .../impls/meta_reference/agents/safety.py | 18 +++++++++++++++--- .../impls/meta_reference/safety/safety.py | 17 ++++++----------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/agents/safety.py b/llama_stack/providers/impls/meta_reference/agents/safety.py index e7c982181..b3aa53728 100644 --- a/llama_stack/providers/impls/meta_reference/agents/safety.py +++ b/llama_stack/providers/impls/meta_reference/agents/safety.py @@ -32,14 +32,26 @@ class ShieldRunnerMixin: self.output_shields = output_shields async def run_multiple_shields( - self, messages: List[Message], shields: List[str] + self, messages: List[Message], shield_types: List[str] ) -> None: - await asyncio.gather( + responses = await asyncio.gather( *[ self.safety_api.run_shield( shield_type=shield_type, messages=messages, ) - for shield_type in shields + for shield_type in shield_types ] ) + for shield_type, response in zip(shields, responses): + if not response.violation: + continue + + violation = response.violation + if violation.violation_level == ViolationLevel.ERROR: + raise SafetyException(violation) + elif violation.violation_level == ViolationLevel.WARN: + cprint( + f"[Warn]{shield_type} raised a warning", + color="red", + ) diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index e5c42b45c..6cf8a79d2 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -10,7 +10,6 @@ from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.apis.safety import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException from llama_stack.providers.impls.meta_reference.safety.shields.base import ( OnViolationAction, ) @@ -74,23 +73,19 @@ class MetaReferenceSafetyImpl(Safety): # TODO: we can refactor ShieldBase, etc. to be inline with the API types res = await shield.run(messages) violation = None - if res.is_violation: + if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE: violation = SafetyViolation( - violation_level=ViolationLevel.ERROR, + violation_level=( + ViolationLevel.ERROR + if shield.on_violation_action == OnViolationAction.RAISE + else ViolationLevel.WARN + ), user_message=res.violation_return_message, metadata={ "violation_type": res.violation_type, }, ) - if shield.on_violation_action == OnViolationAction.RAISE: - raise SafetyException(violation) - elif shield.on_violation_action == OnViolationAction.WARN: - cprint( - f"[Warn]{shield.__class__.__name__} raised a warning", - color="red", - ) - return RunShieldResponse(violation=violation) def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: From f92ff86b967f18f83199a7db7f5a42987b0f765b Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 23 Sep 2024 21:22:22 -0700 Subject: [PATCH 3/6] fix shields in agents safety --- llama_stack/providers/impls/meta_reference/agents/safety.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/providers/impls/meta_reference/agents/safety.py b/llama_stack/providers/impls/meta_reference/agents/safety.py index b3aa53728..fb5821f6a 100644 --- a/llama_stack/providers/impls/meta_reference/agents/safety.py +++ b/llama_stack/providers/impls/meta_reference/agents/safety.py @@ -43,7 +43,7 @@ class ShieldRunnerMixin: for shield_type in shield_types ] ) - for shield_type, response in zip(shields, responses): + for shield_type, response in zip(shield_types, responses): if not response.violation: continue From f136f802b1e1596e907559c3539aa344cd6d06bc Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 23 Sep 2024 21:39:47 -0700 Subject: [PATCH 4/6] Somewhat better error handling --- llama_stack/distribution/server/server.py | 37 +++++++++++++++++++---- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index f09e1c586..38218ab8b 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -35,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import APIRoute +from pydantic import BaseModel, ValidationError +from termcolor import cprint +from typing_extensions import Annotated from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -42,9 +45,6 @@ from llama_stack.providers.utils.telemetry.tracing import ( SpanStatus, start_trace, ) -from pydantic import BaseModel, ValidationError -from termcolor import cprint -from typing_extensions import Annotated from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.distribution import ( @@ -90,10 +90,35 @@ async def global_exception_handler(request: Request, exc: Exception): def translate_exception(exc: Exception) -> HTTPException: if isinstance(exc, ValidationError): - return RequestValidationError(exc.raw_errors) + exc = RequestValidationError(exc.raw_errors) - # Add more custom exception translations here - return HTTPException(status_code=500, detail="Internal server error") + if isinstance(exc, RequestValidationError): + return HTTPException( + status_code=400, + detail={ + "errors": [ + { + "loc": list(error["loc"]), + "msg": error["msg"], + "type": error["type"], + } + for error in exc.errors() + ] + }, + ) + elif isinstance(exc, ValueError): + return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}") + elif isinstance(exc, PermissionError): + return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}") + elif isinstance(exc, TimeoutError): + return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}") + elif isinstance(exc, NotImplementedError): + return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}") + else: + return HTTPException( + status_code=500, + detail="Internal server error: An unexpected error occurred.", + ) async def passthrough( From e617273d8c023148565d8a3134e03545dadc4dab Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 23 Sep 2024 21:44:26 -0700 Subject: [PATCH 5/6] attribute changed (model_args -> arch_args) --- llama_stack/cli/model/describe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/cli/model/describe.py b/llama_stack/cli/model/describe.py index b100f7544..c99cb06c1 100644 --- a/llama_stack/cli/model/describe.py +++ b/llama_stack/cli/model/describe.py @@ -55,7 +55,7 @@ class ModelDescribe(Subcommand): ("Description", model.description_markdown), ("Context Length", f"{model.max_seq_length // 1024}K tokens"), ("Weights format", model.quantization_format.value), - ("Model params.json", json.dumps(model.model_args, indent=4)), + ("Model params.json", json.dumps(model.arch_args, indent=4)), ] if model.recommended_sampling_params is not None: From d04cd97abaddd63c7d71ffbbe2756b4f142a71f3 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 24 Sep 2024 01:03:40 -0700 Subject: [PATCH 6/6] remove providers/impls/sqlite/* --- llama_stack/providers/impls/sqlite/__init__.py | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 llama_stack/providers/impls/sqlite/__init__.py diff --git a/llama_stack/providers/impls/sqlite/__init__.py b/llama_stack/providers/impls/sqlite/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/impls/sqlite/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree.