pre-commit fixes

This commit is contained in:
Chantal D Gama Rose 2025-03-14 13:56:05 -07:00
parent 967dd0aa08
commit 7e211f8553
314 changed files with 5574 additions and 11369 deletions

View file

@ -12,6 +12,7 @@ import uuid
from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
Agents,
@ -21,12 +22,15 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
ListAgentSessionsResponse,
ListAgentsResponse,
Session,
Turn,
)
from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
@ -83,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id=agent_id,
)
async def get_agent(self, agent_id: str) -> ChatAgent:
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
@ -119,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse(
@ -140,7 +144,6 @@ class MetaReferenceAgentsImpl(Agents):
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
@ -150,7 +153,6 @@ class MetaReferenceAgentsImpl(Agents):
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
allow_turn_resume=allow_turn_resume,
)
if stream:
return self._create_agent_turn_streaming(request)
@ -161,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.create_and_execute_turn(request):
yield event
@ -170,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponseMessage],
tool_responses: List[ToolResponse],
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(
@ -189,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
return turn
@ -211,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
@ -233,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def shutdown(self) -> None:
pass
async def list_agents(self) -> ListAgentsResponse:
pass
async def get_agent(self, agent_id: str) -> Agent:
pass
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
pass