Fix precommit check after moving to ruff (#927)

Lint check in main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file as well as fixing and ignoring some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
Yuan Tang 2025-02-02 09:46:45 -05:00 committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
217 changed files with 981 additions and 2681 deletions

View file

@ -74,9 +74,7 @@ log = logging.getLogger(__name__)
def make_random_string(length: int = 8):
return "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
)
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
@ -153,9 +151,7 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
@ -206,14 +202,9 @@ class ChatAgent(ShieldRunnerMixin):
output_message = chunk
continue
assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
):
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk
@ -388,9 +379,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(
session_id, documents, input_messages, tool_defs
)
await self.handle_documents(session_id, documents, input_messages, tool_defs)
if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
@ -408,9 +397,7 @@ class ChatAgent(ShieldRunnerMixin):
vector_db_ids = args.get("vector_db_ids", [])
query_config = args.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(
query_config
)
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
@ -438,9 +425,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
result = await self.tool_runtime_api.rag_tool.query(
content=concat_interleaved_content(
[msg.content for msg in input_messages]
),
content=concat_interleaved_content([msg.content for msg in input_messages]),
vector_db_ids=vector_db_ids,
query_config=query_config,
)
@ -472,9 +457,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
span.set_attribute("output", retrieved_context)
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
@ -511,9 +494,7 @@ class ChatAgent(ShieldRunnerMixin):
self.agent_config.model,
input_messages,
tools=[
tool
for tool in tool_defs.values()
if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
],
tool_prompt_format=self.agent_config.tool_prompt_format,
response_format=self.agent_config.response_format,
@ -560,12 +541,8 @@ class ChatAgent(ShieldRunnerMixin):
if event.stop_reason is not None:
stop_reason = event.stop_reason
span.set_attribute("stop_reason", stop_reason)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute(
"output", f"content: {content} tool_calls: {tool_calls}"
)
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
span.set_attribute("output", f"content: {content} tool_calls: {tool_calls}")
stop_reason = stop_reason or StopReason.out_of_tokens
@ -667,9 +644,7 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_args,
tool_to_group,
)
assert (
len(result_messages) == 1
), "Currently not supporting multiple messages"
assert len(result_messages) == 1, "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
@ -697,9 +672,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if out_attachment := _interpret_content_as_attachment(
result_message.content
):
if out_attachment := _interpret_content_as_attachment(result_message.content):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
@ -714,22 +687,14 @@ class ChatAgent(ShieldRunnerMixin):
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
# Determine which tools to include
agent_config_toolgroups = set(
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
for toolgroup in self.agent_config.toolgroups
)
toolgroups_for_turn_set = (
agent_config_toolgroups
if toolgroups_for_turn is None
else {
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
for toolgroup in toolgroups_for_turn
}
)
@ -759,10 +724,7 @@ class ChatAgent(ShieldRunnerMixin):
continue
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
for tool_def in tools.data:
if (
toolgroup_name.startswith("builtin")
and toolgroup_name != RAG_TOOL_GROUP
):
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
tool_name = tool_def.identifier
built_in_type = BuiltinTool.brave_search
if tool_name == "web_search":
@ -773,9 +735,7 @@ class ChatAgent(ShieldRunnerMixin):
if tool_def_map.get(built_in_type, None):
raise ValueError(f"Tool {built_in_type} already exists")
tool_def_map[built_in_type] = ToolDefinition(
tool_name=built_in_type
)
tool_def_map[built_in_type] = ToolDefinition(tool_name=built_in_type)
tool_to_group[built_in_type] = tool_def.toolgroup_id
continue
@ -821,9 +781,7 @@ class ChatAgent(ShieldRunnerMixin):
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
if code_interpreter_tool:
for c in content_items:
temp_file_path = os.path.join(
self.tempdir, f"{make_random_string()}.txt"
)
temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
with open(temp_file_path, "w") as temp_file:
temp_file.write(c.content)
url_items.append(URL(uri=f"file://{temp_file_path}"))
@ -849,8 +807,7 @@ class ChatAgent(ShieldRunnerMixin):
# we try to load the data from the URLs and content items as a message to inference
# and add it to the last message's context
input_messages[-1].context = "\n".join(
[doc.content for doc in content_items]
+ await load_data_from_urls(url_items)
[doc.content for doc in content_items] + await load_data_from_urls(url_items)
)
async def _ensure_vector_db(self, session_id: str) -> str:
@ -874,9 +831,7 @@ class ChatAgent(ShieldRunnerMixin):
return vector_db_id
async def add_to_session_vector_db(
self, session_id: str, data: List[Document]
) -> None:
async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None:
vector_db_id = await self._ensure_vector_db(session_id)
documents = [
RAGDocument(
@ -931,11 +886,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
else:
raise ValueError(f"Unsupported URL {url}")
content.append(
TextContentItem(
text=f'# There is a file accessible to you at "{filepath}"\n'
)
)
content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))
return ToolResponseMessage(
call_id="",