mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
fix
This commit is contained in:
parent
1f5adff5a7
commit
ec1bae78e6
2 changed files with 11 additions and 7 deletions
|
|
@ -128,7 +128,9 @@ class StreamingResponseOrchestrator:
|
|||
self.prompt = prompt
|
||||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools if ctx.tool_context else {}
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
|
||||
ctx.tool_context.previous_tools if ctx.tool_context else {}
|
||||
)
|
||||
# Track final messages after all tool executions
|
||||
self.final_messages: list[OpenAIMessageParam] = []
|
||||
# mapping for annotations
|
||||
|
|
@ -1138,7 +1140,9 @@ class StreamingResponseOrchestrator:
|
|||
yield evt
|
||||
# Process all remaining tools (including MCP tools) and emit streaming events
|
||||
if self.ctx.tool_context.tools_to_process:
|
||||
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
|
||||
async for stream_event in self._process_new_tools(
|
||||
self.ctx.tool_context.tools_to_process, output_messages
|
||||
):
|
||||
yield stream_event
|
||||
|
||||
def _approval_required(self, tool_name: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -330,8 +330,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
|
||||
# Look up shields to get their provider_resource_id (actual model ID)
|
||||
model_ids = []
|
||||
# list_shields not in Safety interface but available at runtime via API routing
|
||||
shields_list = await safety_api.list_shields() # type: ignore[attr-defined]
|
||||
# TODO: list_shields not in Safety interface but available at runtime via API routing
|
||||
shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined]
|
||||
|
||||
for guardrail_id in guardrail_ids:
|
||||
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
|
||||
|
|
@ -348,9 +348,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
for result in response.results:
|
||||
if result.flagged:
|
||||
message = result.user_message or "Content blocked by safety guardrails"
|
||||
flagged_categories = [
|
||||
cat for cat, flagged in result.categories.items() if flagged
|
||||
] if result.categories else []
|
||||
flagged_categories = (
|
||||
[cat for cat, flagged in result.categories.items() if flagged] if result.categories else []
|
||||
)
|
||||
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
|
||||
|
||||
if flagged_categories:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue