forked from phoenix-oss/llama-stack-mirror
Fix precommit check after moving to ruff (#927)
The lint check on the main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file, as well as fix and ignore some additional checks.
Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
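The commit message above refers to a `ruff.toml` file, but that file is not part of the hunks shown below. As a purely hypothetical sketch (the line length and the exact rules selected or ignored in the actual repository may differ), a top-level ruff configuration of this kind generally looks like:

# Hypothetical ruff.toml sketch - not the actual file added in this commit
line-length = 120

[lint]
select = ["B", "E", "F", "I", "UP", "W"]
ignore = ["E501"]  # leave long lines to the formatter rather than the linter

With a configuration along these lines, `ruff format` joins statements that fit within the configured line length, which is the kind of rewrapping visible throughout the diff below; the corresponding pre-commit hooks live at https://github.com/astral-sh/ruff-pre-commit.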
@@ -74,9 +74,7 @@ log = logging.getLogger(__name__)
 
 
 def make_random_string(length: int = 8):
-    return "".join(
-        secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
-    )
+    return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
 
 
 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
@@ -153,9 +151,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)
 
-    async def create_and_execute_turn(
-        self, request: AgentTurnCreateRequest
-    ) -> AsyncGenerator:
+    async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
         with tracing.span("create_and_execute_turn") as span:
             span.set_attribute("session_id", request.session_id)
             span.set_attribute("agent_id", self.agent_id)
@@ -206,14 +202,9 @@ class ChatAgent(ShieldRunnerMixin):
                     output_message = chunk
                     continue
 
-                assert isinstance(
-                    chunk, AgentTurnResponseStreamChunk
-                ), f"Unexpected type {type(chunk)}"
+                assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
                 event = chunk.event
-                if (
-                    event.payload.event_type
-                    == AgentTurnResponseEventType.step_complete.value
-                ):
+                if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
                     steps.append(event.payload.step_details)
 
                 yield chunk
@@ -388,9 +379,7 @@ class ChatAgent(ShieldRunnerMixin):
 
         tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
         if documents:
-            await self.handle_documents(
-                session_id, documents, input_messages, tool_defs
-            )
+            await self.handle_documents(session_id, documents, input_messages, tool_defs)
 
         if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
             with tracing.span(MEMORY_QUERY_TOOL) as span:
@@ -408,9 +397,7 @@ class ChatAgent(ShieldRunnerMixin):
                 vector_db_ids = args.get("vector_db_ids", [])
                 query_config = args.get("query_config")
                 if query_config:
-                    query_config = TypeAdapter(RAGQueryConfig).validate_python(
-                        query_config
-                    )
+                    query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
                 else:
                     # handle someone passing an empty dict
                     query_config = RAGQueryConfig()
@@ -438,9 +425,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
                 result = await self.tool_runtime_api.rag_tool.query(
-                    content=concat_interleaved_content(
-                        [msg.content for msg in input_messages]
-                    ),
+                    content=concat_interleaved_content([msg.content for msg in input_messages]),
                     vector_db_ids=vector_db_ids,
                     query_config=query_config,
                 )
@@ -472,9 +457,7 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
-                span.set_attribute(
-                    "input", [m.model_dump_json() for m in input_messages]
-                )
+                span.set_attribute("input", [m.model_dump_json() for m in input_messages])
                 span.set_attribute("output", retrieved_context)
                 span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
 
@@ -511,9 +494,7 @@ class ChatAgent(ShieldRunnerMixin):
                     self.agent_config.model,
                     input_messages,
                     tools=[
-                        tool
-                        for tool in tool_defs.values()
-                        if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
+                        tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
                     ],
                     tool_prompt_format=self.agent_config.tool_prompt_format,
                     response_format=self.agent_config.response_format,
@@ -560,12 +541,8 @@ class ChatAgent(ShieldRunnerMixin):
                     if event.stop_reason is not None:
                         stop_reason = event.stop_reason
                 span.set_attribute("stop_reason", stop_reason)
-                span.set_attribute(
-                    "input", [m.model_dump_json() for m in input_messages]
-                )
-                span.set_attribute(
-                    "output", f"content: {content} tool_calls: {tool_calls}"
-                )
+                span.set_attribute("input", [m.model_dump_json() for m in input_messages])
+                span.set_attribute("output", f"content: {content} tool_calls: {tool_calls}")
 
             stop_reason = stop_reason or StopReason.out_of_tokens
 
@@ -667,9 +644,7 @@ class ChatAgent(ShieldRunnerMixin):
                         toolgroup_args,
                         tool_to_group,
                     )
-                    assert (
-                        len(result_messages) == 1
-                    ), "Currently not supporting multiple messages"
+                    assert len(result_messages) == 1, "Currently not supporting multiple messages"
                     result_message = result_messages[0]
                     span.set_attribute("output", result_message.model_dump_json())
 
@@ -697,9 +672,7 @@ class ChatAgent(ShieldRunnerMixin):
                 # TODO: add tool-input touchpoint and a "start" event for this step also
                 # but that needs a lot more refactoring of Tool code potentially
 
-                if out_attachment := _interpret_content_as_attachment(
-                    result_message.content
-                ):
+                if out_attachment := _interpret_content_as_attachment(result_message.content):
                     # NOTE: when we push this message back to the model, the model may ignore the
                     # attached file path etc. since the model is trained to only provide a user message
                     # with the summary. We keep all generated attachments and then attach them to final message
@@ -714,22 +687,14 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
         # Determine which tools to include
         agent_config_toolgroups = set(
-            (
-                toolgroup.name
-                if isinstance(toolgroup, AgentToolGroupWithArgs)
-                else toolgroup
-            )
+            (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
             for toolgroup in self.agent_config.toolgroups
         )
         toolgroups_for_turn_set = (
             agent_config_toolgroups
             if toolgroups_for_turn is None
             else {
-                (
-                    toolgroup.name
-                    if isinstance(toolgroup, AgentToolGroupWithArgs)
-                    else toolgroup
-                )
+                (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
                 for toolgroup in toolgroups_for_turn
             }
         )
@@ -759,10 +724,7 @@ class ChatAgent(ShieldRunnerMixin):
                 continue
             tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
             for tool_def in tools.data:
-                if (
-                    toolgroup_name.startswith("builtin")
-                    and toolgroup_name != RAG_TOOL_GROUP
-                ):
+                if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
                     tool_name = tool_def.identifier
                     built_in_type = BuiltinTool.brave_search
                     if tool_name == "web_search":
@@ -773,9 +735,7 @@ class ChatAgent(ShieldRunnerMixin):
                     if tool_def_map.get(built_in_type, None):
                         raise ValueError(f"Tool {built_in_type} already exists")
 
-                    tool_def_map[built_in_type] = ToolDefinition(
-                        tool_name=built_in_type
-                    )
+                    tool_def_map[built_in_type] = ToolDefinition(tool_name=built_in_type)
                     tool_to_group[built_in_type] = tool_def.toolgroup_id
                     continue
 
@@ -821,9 +781,7 @@ class ChatAgent(ShieldRunnerMixin):
         # Save the contents to a tempdir and use its path as a URL if code interpreter is present
         if code_interpreter_tool:
             for c in content_items:
-                temp_file_path = os.path.join(
-                    self.tempdir, f"{make_random_string()}.txt"
-                )
+                temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
                 with open(temp_file_path, "w") as temp_file:
                     temp_file.write(c.content)
                 url_items.append(URL(uri=f"file://{temp_file_path}"))
@@ -849,8 +807,7 @@ class ChatAgent(ShieldRunnerMixin):
             # we try to load the data from the URLs and content items as a message to inference
             # and add it to the last message's context
             input_messages[-1].context = "\n".join(
-                [doc.content for doc in content_items]
-                + await load_data_from_urls(url_items)
+                [doc.content for doc in content_items] + await load_data_from_urls(url_items)
             )
 
     async def _ensure_vector_db(self, session_id: str) -> str:
@@ -874,9 +831,7 @@ class ChatAgent(ShieldRunnerMixin):
 
         return vector_db_id
 
-    async def add_to_session_vector_db(
-        self, session_id: str, data: List[Document]
-    ) -> None:
+    async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None:
         vector_db_id = await self._ensure_vector_db(session_id)
         documents = [
             RAGDocument(
@@ -931,11 +886,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
         else:
             raise ValueError(f"Unsupported URL {url}")
 
-        content.append(
-            TextContentItem(
-                text=f'# There is a file accessible to you at "{filepath}"\n'
-            )
-        )
+        content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))
 
     return ToolResponseMessage(
         call_id="",