mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
Merge branch 'main' of https://github.com/zainhas/weaviate-memory-adapter
This commit is contained in:
commit
af1710af75
6 changed files with 56 additions and 26 deletions
|
@ -55,7 +55,7 @@ class ModelDescribe(Subcommand):
|
|||
("Description", model.description_markdown),
|
||||
("Context Length", f"{model.max_seq_length // 1024}K tokens"),
|
||||
("Weights format", model.quantization_format.value),
|
||||
("Model params.json", json.dumps(model.model_args, indent=4)),
|
||||
("Model params.json", json.dumps(model.arch_args, indent=4)),
|
||||
]
|
||||
|
||||
if model.recommended_sampling_params is not None:
|
||||
|
|
|
@ -35,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
|
|||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
|
@ -42,9 +45,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.distribution import (
|
||||
|
@ -90,10 +90,35 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|||
|
||||
def translate_exception(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, ValidationError):
|
||||
return RequestValidationError(exc.raw_errors)
|
||||
exc = RequestValidationError(exc.raw_errors)
|
||||
|
||||
# Add more custom exception translations here
|
||||
return HTTPException(status_code=500, detail="Internal server error")
|
||||
if isinstance(exc, RequestValidationError):
|
||||
return HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"errors": [
|
||||
{
|
||||
"loc": list(error["loc"]),
|
||||
"msg": error["msg"],
|
||||
"type": error["type"],
|
||||
}
|
||||
for error in exc.errors()
|
||||
]
|
||||
},
|
||||
)
|
||||
elif isinstance(exc, ValueError):
|
||||
return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
|
||||
elif isinstance(exc, PermissionError):
|
||||
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
|
||||
elif isinstance(exc, TimeoutError):
|
||||
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
|
||||
elif isinstance(exc, NotImplementedError):
|
||||
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
|
||||
else:
|
||||
return HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal server error: An unexpected error occurred.",
|
||||
)
|
||||
|
||||
|
||||
async def passthrough(
|
||||
|
|
|
@ -130,7 +130,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# CompletionMessage itself in the ShieldResponse
|
||||
messages.append(
|
||||
CompletionMessage(
|
||||
content=violation.user_message,
|
||||
content=step.violation.user_message,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -32,7 +32,7 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(
|
||||
self, messages: List[Message], shields: List[str]
|
||||
self, messages: List[Message], shield_types: List[str]
|
||||
) -> None:
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
|
@ -40,16 +40,18 @@ class ShieldRunnerMixin:
|
|||
shield_type=shield_type,
|
||||
messages=messages,
|
||||
)
|
||||
for shield_type in shields
|
||||
for shield_type in shield_types
|
||||
]
|
||||
)
|
||||
for shield_type, response in zip(shield_types, responses):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
||||
for shield, r in zip(shields, responses):
|
||||
if r.violation:
|
||||
if shield.on_violation_action == OnViolationAction.RAISE:
|
||||
raise SafetyException(r)
|
||||
elif shield.on_violation_action == OnViolationAction.WARN:
|
||||
cprint(
|
||||
f"[Warn]{shield.__class__.__name__} raised a warning",
|
||||
color="red",
|
||||
)
|
||||
violation = response.violation
|
||||
if violation.violation_level == ViolationLevel.ERROR:
|
||||
raise SafetyException(violation)
|
||||
elif violation.violation_level == ViolationLevel.WARN:
|
||||
cprint(
|
||||
f"[Warn]{shield_type} raised a warning",
|
||||
color="red",
|
||||
)
|
||||
|
|
|
@ -10,6 +10,10 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
|
|||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceShieldType, SafetyConfig
|
||||
|
||||
from .shields import (
|
||||
|
@ -69,9 +73,13 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
|
||||
res = await shield.run(messages)
|
||||
violation = None
|
||||
if res.is_violation:
|
||||
if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
violation_level=(
|
||||
ViolationLevel.ERROR
|
||||
if shield.on_violation_action == OnViolationAction.RAISE
|
||||
else ViolationLevel.WARN
|
||||
),
|
||||
user_message=res.violation_return_message,
|
||||
metadata={
|
||||
"violation_type": res.violation_type,
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
Loading…
Add table
Add a link
Reference in a new issue