Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-11 19:56:03 +00:00
Address PR feedback: improve code clarity and fix AllowedToolsFilter bug
- streaming.py: Extract has_tool_calls boolean for readability
- streaming.py: Replace nested function checks with assertions
- streaming.py: Fix AllowedToolsFilter to use tool_names instead of allowed/disallowed
- streaming.py: Add comment explaining tool_context can be None
- streaming.py, utils.py: Clarify Pydantic/dict compatibility comments
- utils.py: Document list invariance vs Sequence covariance in type signature
- utils.py: Clarify list_shields runtime availability comment
This commit is contained in:
parent 84d78ff48a
commit 53c6f846d4

2 changed files with 24 additions and 13 deletions
streaming.py

@@ -229,7 +229,8 @@ class StreamingResponseOrchestrator:
         params = OpenAIChatCompletionRequestWithExtraBody(
             model=self.ctx.model,
             messages=messages,
-            tools=self.ctx.chat_tools,  # type: ignore[arg-type]  # ChatCompletionFunctionToolParam compatible with expected dict type
+            # Pydantic models are dict-compatible but mypy treats them as distinct types
+            tools=self.ctx.chat_tools,  # type: ignore[arg-type]
             stream=True,
             temperature=self.ctx.temperature,
             response_format=response_format,
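Reviewer note on the `type: ignore[arg-type]` kept in this hunk: mypy compares nominal types, while at runtime the Pydantic tool definitions and the expected TypedDict-shaped params carry the same field names and values. A minimal, self-contained sketch of that mismatch, using illustrative stand-in names rather than the real llama-stack/OpenAI types:

```python
from typing import TypedDict

from pydantic import BaseModel


# Hypothetical stand-ins for OpenAI's ChatCompletionToolParam (a TypedDict)
# and a Pydantic tool definition like the ones kept in ctx.chat_tools.
class ToolParam(TypedDict):
    name: str
    description: str


class ToolModel(BaseModel):
    name: str
    description: str


def send_request(tools: list[ToolParam]) -> None:
    # The callee only needs the field values; both shapes carry the same data.
    for tool in tools:
        print(tool)


tools: list[ToolModel] = [ToolModel(name="get_weather", description="Look up weather")]

# mypy flags this: list[ToolModel] is not list[ToolParam], even though the
# fields match. Hence the targeted "# type: ignore[arg-type]" in streaming.py
# rather than a lossy per-call conversion.
send_request(tools)  # type: ignore[arg-type]
```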
@@ -272,7 +273,12 @@ class StreamingResponseOrchestrator:
 
         # Handle choices with no tool calls
         for choice in current_response.choices:
-            if not (isinstance(choice.message, OpenAIAssistantMessageParam) and choice.message.tool_calls and self.ctx.response_tools):
+            has_tool_calls = (
+                isinstance(choice.message, OpenAIAssistantMessageParam)
+                and choice.message.tool_calls
+                and self.ctx.response_tools
+            )
+            if not has_tool_calls:
                 output_messages.append(
                     await convert_chat_choice_to_response_message(
                         choice,
@@ -723,11 +729,12 @@ class StreamingResponseOrchestrator:
 
                     # Accumulate arguments for final response (only for subsequent chunks)
                     if not is_new_tool_call and response_tool_call is not None:
-                        # Need to check function is not None
-                        if response_tool_call.function and tool_call.function:
-                            response_tool_call.function.arguments = (
-                                response_tool_call.function.arguments or ""
-                            ) + tool_call.function.arguments
+                        # Both should have functions since we're inside the tool_call.function check above
+                        assert response_tool_call.function is not None
+                        assert tool_call.function is not None
+                        response_tool_call.function.arguments = (
+                            response_tool_call.function.arguments or ""
+                        ) + tool_call.function.arguments
 
                     # Output Safety Validation for this chunk
                     if self.guardrail_ids:
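Side note on the assertion swap above: `assert x is not None` both documents the invariant and lets mypy narrow `X | None` to `X`, whereas the old nested `if` would silently skip accumulation if the invariant were ever broken. A small self-contained sketch of the pattern (toy types, not the actual streaming chunk classes):

```python
from dataclasses import dataclass


@dataclass
class FunctionCall:
    name: str
    arguments: str | None = None


@dataclass
class ToolCall:
    function: FunctionCall | None = None


def accumulate(target: ToolCall, chunk: ToolCall) -> None:
    # The caller only reaches this point when both sides carry a function;
    # the asserts state that invariant and narrow Optional -> FunctionCall for mypy.
    assert target.function is not None
    assert chunk.function is not None
    target.function.arguments = (target.function.arguments or "") + (chunk.function.arguments or "")


acc = ToolCall(function=FunctionCall(name="get_weather", arguments='{"city": '))
accumulate(acc, ToolCall(function=FunctionCall(name="get_weather", arguments='"Paris"}')))
assert acc.function is not None
print(acc.function.arguments)  # {"city": "Paris"}
```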
@@ -1012,7 +1019,7 @@ class StreamingResponseOrchestrator:
                 description=tool.description,
                 input_schema=tool.input_schema,
             )
-            return convert_tooldef_to_openai_tool(tool_def)  # type: ignore[return-value]  # Dict compatible with ChatCompletionFunctionToolParam
+            return convert_tooldef_to_openai_tool(tool_def)  # type: ignore[return-value]  # Returns dict but ChatCompletionToolParam expects TypedDict
 
         # Initialize chat_tools if not already set
         if self.ctx.chat_tools is None:
@@ -1059,8 +1066,8 @@ class StreamingResponseOrchestrator:
             if isinstance(mcp_tool.allowed_tools, list):
                 always_allowed = mcp_tool.allowed_tools
             elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
-                always_allowed = mcp_tool.allowed_tools.allowed  # type: ignore[attr-defined]
-                never_allowed = mcp_tool.allowed_tools.disallowed  # type: ignore[attr-defined]
+                # AllowedToolsFilter only has tool_names field (not allowed/disallowed)
+                always_allowed = mcp_tool.allowed_tools.tool_names
 
             # Call list_mcp_tools
             tool_defs = None
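For context on the fix above: the previous code read `.allowed` / `.disallowed` attributes that `AllowedToolsFilter` does not define (hence the old `attr-defined` ignores), so the branch broke whenever a filter object was passed. A minimal sketch of the corrected branch, assuming, as the new comment states, that the filter is a Pydantic model whose only filtering field is `tool_names`; the real definition lives in the llama-stack API types:

```python
from pydantic import BaseModel


class AllowedToolsFilter(BaseModel):
    # Assumed shape per the comment in the hunk: a single tool_names field.
    tool_names: list[str] | None = None


def resolve_always_allowed(allowed_tools: list[str] | AllowedToolsFilter | None) -> list[str] | None:
    # Mirrors the branch in streaming.py: a plain list is used as-is,
    # a filter contributes its tool_names, and None means "no restriction".
    if isinstance(allowed_tools, list):
        return allowed_tools
    if isinstance(allowed_tools, AllowedToolsFilter):
        return allowed_tools.tool_names
    return None


print(resolve_always_allowed(["search", "fetch"]))                        # ['search', 'fetch']
print(resolve_always_allowed(AllowedToolsFilter(tool_names=["search"])))  # ['search']
print(resolve_always_allowed(None))                                       # None
```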
@@ -1092,7 +1099,7 @@ class StreamingResponseOrchestrator:
                 openai_tool = convert_tooldef_to_chat_tool(t)
                 if self.ctx.chat_tools is None:
                     self.ctx.chat_tools = []
-                self.ctx.chat_tools.append(openai_tool)  # type: ignore[arg-type]  # Dict compatible with ChatCompletionFunctionToolParam
+                self.ctx.chat_tools.append(openai_tool)  # type: ignore[arg-type]  # Returns dict but ChatCompletionToolParam expects TypedDict
 
                 # Add to MCP tool mapping
                 if t.name in self.mcp_tool_to_server:
@@ -1124,6 +1131,7 @@ class StreamingResponseOrchestrator:
         self, output_messages: list[OpenAIResponseOutput]
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Handle all mcp tool lists from previous response that are still valid:
+        # tool_context can be None when no tools are provided in the response request
        if self.ctx.tool_context:
            for tool in self.ctx.tool_context.previous_tool_listings:
                async for evt in self._reuse_mcp_list_tools(tool, output_messages):
@@ -1225,7 +1233,7 @@ class StreamingResponseOrchestrator:
             openai_tool = convert_tooldef_to_openai_tool(tool_def)
             if self.ctx.chat_tools is None:
                 self.ctx.chat_tools = []
-            self.ctx.chat_tools.append(openai_tool)  # type: ignore[arg-type]  # Dict compatible with ChatCompletionFunctionToolParam
+            self.ctx.chat_tools.append(openai_tool)  # type: ignore[arg-type]  # Returns dict but ChatCompletionToolParam expects TypedDict
 
             mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
                 id=f"mcp_list_{uuid.uuid4()}",
utils.py

@@ -82,8 +82,10 @@ async def convert_chat_choice_to_response_message(
 async def convert_response_content_to_chat_content(
     content: (
         str
+        # List types for exact matches (invariant)
         | list[OpenAIResponseInputMessageContent]
         | list[OpenAIResponseOutputMessageContent]
+        # Sequence for mixed content types (covariant - accepts list of subtypes)
         | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent]
     ),
 ) -> str | list[OpenAIChatCompletionContentPartParam]:
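The two comments added to this signature document a typing subtlety worth spelling out: `list` is invariant, so `list[Subtype]` is not accepted where `list[Base]` is expected, while `Sequence` is covariant and read-only, so it is. A short illustrative sketch with toy class names (not the real content types):

```python
from collections.abc import Sequence


class Content:
    pass


class InputContent(Content):
    pass


def takes_list(items: list[Content]) -> None:
    # Invariant: mypy rejects list[InputContent] here, because this function
    # could legally append some other Content subtype into the caller's list.
    items.append(Content())


def takes_sequence(items: Sequence[Content]) -> None:
    # Covariant and read-only: a list[InputContent] is accepted as-is.
    for item in items:
        print(type(item).__name__)


inputs: list[InputContent] = [InputContent()]

takes_list(inputs)      # type: ignore[arg-type]  # mypy: invariance violation
takes_sequence(inputs)  # OK: Sequence[Content] accepts a list of a subtype
```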
@@ -335,7 +337,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
 
     # Look up shields to get their provider_resource_id (actual model ID)
     model_ids = []
-    shields_list = await safety_api.list_shields()  # type: ignore[attr-defined]  # Safety API routing_table access
+    # list_shields not in Safety interface but available at runtime via API routing
+    shields_list = await safety_api.list_shields()  # type: ignore[attr-defined]
 
     for guardrail_id in guardrail_ids:
         matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
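On the relocated comment above: `list_shields` is not part of the declared `Safety` interface, so mypy reports `attr-defined`, but the concrete object injected at runtime does provide the method. A stand-alone sketch of that shape using stand-in names (the real `Safety` API and its router live elsewhere in llama-stack, and may differ in detail):

```python
import asyncio
from typing import Protocol


class Safety(Protocol):
    # Hypothetical minimal interface: the declared protocol says nothing
    # about listing shields.
    async def run_shield(self, shield_id: str, text: str) -> bool: ...


class SafetyRouter:
    # The concrete object wired in at runtime also exposes list_shields,
    # even though the Safety protocol omits it.
    async def run_shield(self, shield_id: str, text: str) -> bool:
        return True

    async def list_shields(self) -> list[str]:
        return ["llama-guard"]


async def resolve_shields(safety_api: Safety) -> list[str]:
    # mypy only sees the protocol, hence the attr-defined ignore;
    # at runtime the router implementation provides the method.
    return await safety_api.list_shields()  # type: ignore[attr-defined]


if __name__ == "__main__":
    print(asyncio.run(resolve_shields(SafetyRouter())))  # ['llama-guard']
```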