mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
# What does this PR do? This PR has two fixes needed for correct trace context propagation across the asyncio boundary. Fix 1: Start using context vars to store the global trace context. This is needed since we cannot use the same trace context across coroutines, because the state is shared; each coroutine should have its own trace context so that each of them can store its state correctly. Fix 2: Start a new span for each new coroutine started for running shields, to keep the span tree clean. ## Test Plan ### Integration tests with server LLAMA_STACK_DISABLE_VERSION_CHECK=true llama stack run ~/.llama/distributions/together/together-run.yaml LLAMA_STACK_CONFIG=http://localhost:8321 pytest -s --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct server logs: https://gist.github.com/dineshyv/51ac5d9864ed031d0d89ce77352821fe test logs: https://gist.github.com/dineshyv/e66acc1c4648a42f1854600609c467f3 ### Integration tests with library client LLAMA_STACK_CONFIG=fireworks pytest -s --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct logs: https://gist.github.com/dineshyv/ca160696a0b167223378673fb1dcefb8 ### Apps test with server: ``` LLAMA_STACK_DISABLE_VERSION_CHECK=true llama stack run ~/.llama/distributions/together/together-run.yaml python -m examples.agents.e2e_loop_with_client_tools localhost 8321 ``` server logs: https://gist.github.com/dineshyv/1717a572d8f7c14279c36123b79c5797 app logs: https://gist.github.com/dineshyv/44167e9f57806a0ba3b710c32aec02f8
52 lines
1.8 KiB
Python
52 lines
1.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import List
|
|
|
|
from llama_stack.apis.inference import Message
|
|
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
|
from llama_stack.providers.utils.telemetry import tracing
|
|
|
|
log = logging.getLogger(__name__)  # module-level logger used for WARN-level shield violations below
|
|
|
|
|
|
class SafetyException(Exception):  # noqa: N818
    """Raised when a safety shield reports a violation that must stop processing.

    Carries the triggering ``SafetyViolation`` and uses its ``user_message``
    as the exception text.
    """

    def __init__(self, violation: SafetyViolation):
        super().__init__(violation.user_message)
        self.violation = violation
|
|
|
|
|
|
class ShieldRunnerMixin:
    """Mixin that runs configured safety shields over a list of messages.

    Holds a ``Safety`` API client plus optional lists of input/output shield
    identifiers, and provides :meth:`run_multiple_shields` to execute a set of
    shields concurrently, each in its own tracing span.
    """

    def __init__(
        self,
        safety_api: Safety,
        # Fixed implicit-Optional annotations: the originals said ``List[str]``
        # while defaulting to ``None`` (disallowed by PEP 484).
        input_shields: List[str] | None = None,
        output_shields: List[str] | None = None,
    ):
        # Safety API client used to evaluate shields.
        self.safety_api = safety_api
        # Shield identifiers to run on incoming / outgoing messages.
        # NOTE(review): stored as-is; ``None`` is preserved rather than
        # normalized to an empty list — callers appear to handle that.
        self.input_shields = input_shields
        self.output_shields = output_shields

    async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
        """Run every shield named in ``identifiers`` over ``messages`` concurrently.

        Each shield runs in its own coroutine under its own tracing span so the
        span tree stays clean across the asyncio boundary.

        :param messages: messages to check against each shield.
        :param identifiers: shield ids passed to ``safety_api.run_shield``.
        :raises SafetyException: if any shield reports an ERROR-level violation.
        """

        async def run_shield_with_span(identifier: str):
            # A dedicated span per coroutine keeps trace-context propagation
            # correct when shields run concurrently.
            async with tracing.span(f"run_shield_{identifier}"):
                return await self.safety_api.run_shield(
                    shield_id=identifier,
                    messages=messages,
                )

        responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
        for identifier, response in zip(identifiers, responses, strict=False):
            if not response.violation:
                continue

            violation = response.violation
            if violation.violation_level == ViolationLevel.ERROR:
                raise SafetyException(violation)
            elif violation.violation_level == ViolationLevel.WARN:
                # Lazy %-style args avoid formatting when WARNING is disabled;
                # output is identical to the previous f-string.
                log.warning("[Warn]%s raised a warning", identifier)