This commit is contained in:
Ashwin Bharambe 2025-10-28 16:22:13 -07:00
parent 1f5adff5a7
commit ec1bae78e6
2 changed files with 11 additions and 7 deletions

View file

@ -128,7 +128,9 @@ class StreamingResponseOrchestrator:
self.prompt = prompt self.prompt = prompt
self.sequence_number = 0 self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing # Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools if ctx.tool_context else {} self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
ctx.tool_context.previous_tools if ctx.tool_context else {}
)
# Track final messages after all tool executions # Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = [] self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations # mapping for annotations
@ -1138,7 +1140,9 @@ class StreamingResponseOrchestrator:
yield evt yield evt
# Process all remaining tools (including MCP tools) and emit streaming events # Process all remaining tools (including MCP tools) and emit streaming events
if self.ctx.tool_context.tools_to_process: if self.ctx.tool_context.tools_to_process:
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages): async for stream_event in self._process_new_tools(
self.ctx.tool_context.tools_to_process, output_messages
):
yield stream_event yield stream_event
def _approval_required(self, tool_name: str) -> bool: def _approval_required(self, tool_name: str) -> bool:

View file

@ -330,8 +330,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
# Look up shields to get their provider_resource_id (actual model ID) # Look up shields to get their provider_resource_id (actual model ID)
model_ids = [] model_ids = []
# list_shields not in Safety interface but available at runtime via API routing # TODO: list_shields not in Safety interface but available at runtime via API routing
shields_list = await safety_api.list_shields() # type: ignore[attr-defined] shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined]
for guardrail_id in guardrail_ids: for guardrail_id in guardrail_ids:
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id] matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
@ -348,9 +348,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
for result in response.results: for result in response.results:
if result.flagged: if result.flagged:
message = result.user_message or "Content blocked by safety guardrails" message = result.user_message or "Content blocked by safety guardrails"
flagged_categories = [ flagged_categories = (
cat for cat, flagged in result.categories.items() if flagged [cat for cat, flagged in result.categories.items() if flagged] if result.categories else []
] if result.categories else [] )
violation_type = result.metadata.get("violation_type", []) if result.metadata else [] violation_type = result.metadata.get("violation_type", []) if result.metadata else []
if flagged_categories: if flagged_categories: