mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
fix
This commit is contained in:
parent
1f5adff5a7
commit
ec1bae78e6
2 changed files with 11 additions and 7 deletions
|
|
@ -128,7 +128,9 @@ class StreamingResponseOrchestrator:
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.sequence_number = 0
|
self.sequence_number = 0
|
||||||
# Store MCP tool mapping that gets built during tool processing
|
# Store MCP tool mapping that gets built during tool processing
|
||||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools if ctx.tool_context else {}
|
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
|
||||||
|
ctx.tool_context.previous_tools if ctx.tool_context else {}
|
||||||
|
)
|
||||||
# Track final messages after all tool executions
|
# Track final messages after all tool executions
|
||||||
self.final_messages: list[OpenAIMessageParam] = []
|
self.final_messages: list[OpenAIMessageParam] = []
|
||||||
# mapping for annotations
|
# mapping for annotations
|
||||||
|
|
@ -1138,7 +1140,9 @@ class StreamingResponseOrchestrator:
|
||||||
yield evt
|
yield evt
|
||||||
# Process all remaining tools (including MCP tools) and emit streaming events
|
# Process all remaining tools (including MCP tools) and emit streaming events
|
||||||
if self.ctx.tool_context.tools_to_process:
|
if self.ctx.tool_context.tools_to_process:
|
||||||
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
|
async for stream_event in self._process_new_tools(
|
||||||
|
self.ctx.tool_context.tools_to_process, output_messages
|
||||||
|
):
|
||||||
yield stream_event
|
yield stream_event
|
||||||
|
|
||||||
def _approval_required(self, tool_name: str) -> bool:
|
def _approval_required(self, tool_name: str) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -330,8 +330,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
||||||
|
|
||||||
# Look up shields to get their provider_resource_id (actual model ID)
|
# Look up shields to get their provider_resource_id (actual model ID)
|
||||||
model_ids = []
|
model_ids = []
|
||||||
# list_shields not in Safety interface but available at runtime via API routing
|
# TODO: list_shields not in Safety interface but available at runtime via API routing
|
||||||
shields_list = await safety_api.list_shields() # type: ignore[attr-defined]
|
shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined]
|
||||||
|
|
||||||
for guardrail_id in guardrail_ids:
|
for guardrail_id in guardrail_ids:
|
||||||
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
|
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
|
||||||
|
|
@ -348,9 +348,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
||||||
for result in response.results:
|
for result in response.results:
|
||||||
if result.flagged:
|
if result.flagged:
|
||||||
message = result.user_message or "Content blocked by safety guardrails"
|
message = result.user_message or "Content blocked by safety guardrails"
|
||||||
flagged_categories = [
|
flagged_categories = (
|
||||||
cat for cat, flagged in result.categories.items() if flagged
|
[cat for cat, flagged in result.categories.items() if flagged] if result.categories else []
|
||||||
] if result.categories else []
|
)
|
||||||
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
|
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
|
||||||
|
|
||||||
if flagged_categories:
|
if flagged_categories:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue