mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
Address PR feedback: improve code clarity and fix AllowedToolsFilter bug
- streaming.py: Extract has_tool_calls boolean for readability - streaming.py: Replace nested function checks with assertions - streaming.py: Fix AllowedToolsFilter to use tool_names instead of allowed/disallowed - streaming.py: Add comment explaining tool_context can be None - streaming.py, utils.py: Clarify Pydantic/dict compatibility comments - utils.py: Document list invariance vs Sequence covariance in type signature - utils.py: Clarify list_shields runtime availability comment
This commit is contained in:
parent
84d78ff48a
commit
53c6f846d4
2 changed files with 24 additions and 13 deletions
|
|
@ -229,7 +229,8 @@ class StreamingResponseOrchestrator:
|
|||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.ctx.model,
|
||||
messages=messages,
|
||||
tools=self.ctx.chat_tools, # type: ignore[arg-type] # ChatCompletionFunctionToolParam compatible with expected dict type
|
||||
# Pydantic models are dict-compatible but mypy treats them as distinct types
|
||||
tools=self.ctx.chat_tools, # type: ignore[arg-type]
|
||||
stream=True,
|
||||
temperature=self.ctx.temperature,
|
||||
response_format=response_format,
|
||||
|
|
@ -272,7 +273,12 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (isinstance(choice.message, OpenAIAssistantMessageParam) and choice.message.tool_calls and self.ctx.response_tools):
|
||||
has_tool_calls = (
|
||||
isinstance(choice.message, OpenAIAssistantMessageParam)
|
||||
and choice.message.tool_calls
|
||||
and self.ctx.response_tools
|
||||
)
|
||||
if not has_tool_calls:
|
||||
output_messages.append(
|
||||
await convert_chat_choice_to_response_message(
|
||||
choice,
|
||||
|
|
@ -723,8 +729,9 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
# Accumulate arguments for final response (only for subsequent chunks)
|
||||
if not is_new_tool_call and response_tool_call is not None:
|
||||
# Need to check function is not None
|
||||
if response_tool_call.function and tool_call.function:
|
||||
# Both should have functions since we're inside the tool_call.function check above
|
||||
assert response_tool_call.function is not None
|
||||
assert tool_call.function is not None
|
||||
response_tool_call.function.arguments = (
|
||||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
|
|
@ -1012,7 +1019,7 @@ class StreamingResponseOrchestrator:
|
|||
description=tool.description,
|
||||
input_schema=tool.input_schema,
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def) # type: ignore[return-value] # Dict compatible with ChatCompletionFunctionToolParam
|
||||
return convert_tooldef_to_openai_tool(tool_def) # type: ignore[return-value] # Returns dict but ChatCompletionToolParam expects TypedDict
|
||||
|
||||
# Initialize chat_tools if not already set
|
||||
if self.ctx.chat_tools is None:
|
||||
|
|
@ -1059,8 +1066,8 @@ class StreamingResponseOrchestrator:
|
|||
if isinstance(mcp_tool.allowed_tools, list):
|
||||
always_allowed = mcp_tool.allowed_tools
|
||||
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
|
||||
always_allowed = mcp_tool.allowed_tools.allowed # type: ignore[attr-defined]
|
||||
never_allowed = mcp_tool.allowed_tools.disallowed # type: ignore[attr-defined]
|
||||
# AllowedToolsFilter only has tool_names field (not allowed/disallowed)
|
||||
always_allowed = mcp_tool.allowed_tools.tool_names
|
||||
|
||||
# Call list_mcp_tools
|
||||
tool_defs = None
|
||||
|
|
@ -1092,7 +1099,7 @@ class StreamingResponseOrchestrator:
|
|||
openai_tool = convert_tooldef_to_chat_tool(t)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Dict compatible with ChatCompletionFunctionToolParam
|
||||
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
|
||||
|
||||
# Add to MCP tool mapping
|
||||
if t.name in self.mcp_tool_to_server:
|
||||
|
|
@ -1124,6 +1131,7 @@ class StreamingResponseOrchestrator:
|
|||
self, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Handle all mcp tool lists from previous response that are still valid:
|
||||
# tool_context can be None when no tools are provided in the response request
|
||||
if self.ctx.tool_context:
|
||||
for tool in self.ctx.tool_context.previous_tool_listings:
|
||||
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
|
||||
|
|
@ -1225,7 +1233,7 @@ class StreamingResponseOrchestrator:
|
|||
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Dict compatible with ChatCompletionFunctionToolParam
|
||||
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
|
||||
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
|
|
|
|||
|
|
@ -82,8 +82,10 @@ async def convert_chat_choice_to_response_message(
|
|||
async def convert_response_content_to_chat_content(
|
||||
content: (
|
||||
str
|
||||
# List types for exact matches (invariant)
|
||||
| list[OpenAIResponseInputMessageContent]
|
||||
| list[OpenAIResponseOutputMessageContent]
|
||||
# Sequence for mixed content types (covariant - accepts list of subtypes)
|
||||
| Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent]
|
||||
),
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
|
|
@ -335,7 +337,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
|
||||
# Look up shields to get their provider_resource_id (actual model ID)
|
||||
model_ids = []
|
||||
shields_list = await safety_api.list_shields() # type: ignore[attr-defined] # Safety API routing_table access
|
||||
# list_shields not in Safety interface but available at runtime via API routing
|
||||
shields_list = await safety_api.list_shields() # type: ignore[attr-defined]
|
||||
|
||||
for guardrail_id in guardrail_ids:
|
||||
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue