mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
# What does this PR do? This is not part of the official OpenAI API, but we'll use this for the logs UI. In order to support more filtering options, I'm adopting the newly introduced sql store in in place of the kv store. ## Test Plan Added integration/unit tests.
339 lines
12 KiB
Python
339 lines
12 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import logging
|
|
import uuid
|
|
from collections.abc import AsyncGenerator
|
|
from datetime import datetime, timezone
|
|
|
|
from llama_stack.apis.agents import (
|
|
Agent,
|
|
AgentConfig,
|
|
AgentCreateResponse,
|
|
Agents,
|
|
AgentSessionCreateResponse,
|
|
AgentStepResponse,
|
|
AgentToolGroup,
|
|
AgentTurnCreateRequest,
|
|
AgentTurnResumeRequest,
|
|
Document,
|
|
ListOpenAIResponseObject,
|
|
OpenAIResponseInput,
|
|
OpenAIResponseInputTool,
|
|
OpenAIResponseObject,
|
|
Order,
|
|
Session,
|
|
Turn,
|
|
)
|
|
from llama_stack.apis.common.responses import PaginatedResponse
|
|
from llama_stack.apis.inference import (
|
|
Inference,
|
|
ToolConfig,
|
|
ToolResponse,
|
|
ToolResponseMessage,
|
|
UserMessage,
|
|
)
|
|
from llama_stack.apis.safety import Safety
|
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
|
from llama_stack.apis.vector_io import VectorIO
|
|
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
|
from llama_stack.providers.utils.pagination import paginate_records
|
|
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
|
|
|
from .agent_instance import ChatAgent
|
|
from .config import MetaReferenceAgentsImplConfig
|
|
from .openai_responses import OpenAIResponsesImpl
|
|
from .persistence import AgentInfo
|
|
|
|
logger = logging.getLogger()
|
|
|
|
|
|
class MetaReferenceAgentsImpl(Agents):
|
|
def __init__(
|
|
self,
|
|
config: MetaReferenceAgentsImplConfig,
|
|
inference_api: Inference,
|
|
vector_io_api: VectorIO,
|
|
safety_api: Safety,
|
|
tool_runtime_api: ToolRuntime,
|
|
tool_groups_api: ToolGroups,
|
|
):
|
|
self.config = config
|
|
self.inference_api = inference_api
|
|
self.vector_io_api = vector_io_api
|
|
self.safety_api = safety_api
|
|
self.tool_runtime_api = tool_runtime_api
|
|
self.tool_groups_api = tool_groups_api
|
|
|
|
self.in_memory_store = InmemoryKVStoreImpl()
|
|
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
|
|
|
async def initialize(self) -> None:
|
|
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
|
self.responses_store = ResponsesStore(self.config.responses_store)
|
|
await self.responses_store.initialize()
|
|
self.openai_responses_impl = OpenAIResponsesImpl(
|
|
inference_api=self.inference_api,
|
|
tool_groups_api=self.tool_groups_api,
|
|
tool_runtime_api=self.tool_runtime_api,
|
|
responses_store=self.responses_store,
|
|
)
|
|
|
|
async def create_agent(
|
|
self,
|
|
agent_config: AgentConfig,
|
|
) -> AgentCreateResponse:
|
|
agent_id = str(uuid.uuid4())
|
|
created_at = datetime.now(timezone.utc)
|
|
|
|
agent_info = AgentInfo(
|
|
**agent_config.model_dump(),
|
|
created_at=created_at,
|
|
)
|
|
|
|
# Store the agent info
|
|
await self.persistence_store.set(
|
|
key=f"agent:{agent_id}",
|
|
value=agent_info.model_dump_json(),
|
|
)
|
|
|
|
return AgentCreateResponse(
|
|
agent_id=agent_id,
|
|
)
|
|
|
|
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
|
agent_info_json = await self.persistence_store.get(
|
|
key=f"agent:{agent_id}",
|
|
)
|
|
if not agent_info_json:
|
|
raise ValueError(f"Could not find agent info for {agent_id}")
|
|
|
|
try:
|
|
agent_info = AgentInfo.model_validate_json(agent_info_json)
|
|
except Exception as e:
|
|
raise ValueError(f"Could not validate agent info for {agent_id}") from e
|
|
|
|
return ChatAgent(
|
|
agent_id=agent_id,
|
|
agent_config=agent_info,
|
|
inference_api=self.inference_api,
|
|
safety_api=self.safety_api,
|
|
vector_io_api=self.vector_io_api,
|
|
tool_runtime_api=self.tool_runtime_api,
|
|
tool_groups_api=self.tool_groups_api,
|
|
persistence_store=(
|
|
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
|
),
|
|
created_at=agent_info.created_at,
|
|
)
|
|
|
|
async def create_agent_session(
|
|
self,
|
|
agent_id: str,
|
|
session_name: str,
|
|
) -> AgentSessionCreateResponse:
|
|
agent = await self._get_agent_impl(agent_id)
|
|
|
|
session_id = await agent.create_session(session_name)
|
|
return AgentSessionCreateResponse(
|
|
session_id=session_id,
|
|
)
|
|
|
|
async def create_agent_turn(
|
|
self,
|
|
agent_id: str,
|
|
session_id: str,
|
|
messages: list[UserMessage | ToolResponseMessage],
|
|
toolgroups: list[AgentToolGroup] | None = None,
|
|
documents: list[Document] | None = None,
|
|
stream: bool | None = False,
|
|
tool_config: ToolConfig | None = None,
|
|
) -> AsyncGenerator:
|
|
request = AgentTurnCreateRequest(
|
|
agent_id=agent_id,
|
|
session_id=session_id,
|
|
messages=messages,
|
|
stream=True,
|
|
toolgroups=toolgroups,
|
|
documents=documents,
|
|
tool_config=tool_config,
|
|
)
|
|
if stream:
|
|
return self._create_agent_turn_streaming(request)
|
|
else:
|
|
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
|
|
|
async def _create_agent_turn_streaming(
|
|
self,
|
|
request: AgentTurnCreateRequest,
|
|
) -> AsyncGenerator:
|
|
agent = await self._get_agent_impl(request.agent_id)
|
|
async for event in agent.create_and_execute_turn(request):
|
|
yield event
|
|
|
|
async def resume_agent_turn(
|
|
self,
|
|
agent_id: str,
|
|
session_id: str,
|
|
turn_id: str,
|
|
tool_responses: list[ToolResponse],
|
|
stream: bool | None = False,
|
|
) -> AsyncGenerator:
|
|
request = AgentTurnResumeRequest(
|
|
agent_id=agent_id,
|
|
session_id=session_id,
|
|
turn_id=turn_id,
|
|
tool_responses=tool_responses,
|
|
stream=stream,
|
|
)
|
|
if stream:
|
|
return self._continue_agent_turn_streaming(request)
|
|
else:
|
|
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
|
|
|
async def _continue_agent_turn_streaming(
|
|
self,
|
|
request: AgentTurnResumeRequest,
|
|
) -> AsyncGenerator:
|
|
agent = await self._get_agent_impl(request.agent_id)
|
|
async for event in agent.resume_turn(request):
|
|
yield event
|
|
|
|
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
|
agent = await self._get_agent_impl(agent_id)
|
|
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
|
return turn
|
|
|
|
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
|
turn = await self.get_agents_turn(agent_id, session_id, turn_id)
|
|
for step in turn.steps:
|
|
if step.step_id == step_id:
|
|
return AgentStepResponse(step=step)
|
|
raise ValueError(f"Provided step_id {step_id} could not be found")
|
|
|
|
async def get_agents_session(
|
|
self,
|
|
agent_id: str,
|
|
session_id: str,
|
|
turn_ids: list[str] | None = None,
|
|
) -> Session:
|
|
agent = await self._get_agent_impl(agent_id)
|
|
|
|
session_info = await agent.storage.get_session_info(session_id)
|
|
if session_info is None:
|
|
raise ValueError(f"Session {session_id} not found")
|
|
turns = await agent.storage.get_session_turns(session_id)
|
|
if turn_ids:
|
|
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
|
return Session(
|
|
session_name=session_info.session_name,
|
|
session_id=session_id,
|
|
turns=turns,
|
|
started_at=session_info.started_at,
|
|
)
|
|
|
|
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
|
agent = await self._get_agent_impl(agent_id)
|
|
session_info = await agent.storage.get_session_info(session_id)
|
|
if session_info is None:
|
|
raise ValueError(f"Session {session_id} not found")
|
|
|
|
# Delete turns first, then the session
|
|
await agent.storage.delete_session_turns(session_id)
|
|
await agent.storage.delete_session(session_id)
|
|
|
|
async def delete_agent(self, agent_id: str) -> None:
|
|
# First get all sessions for this agent
|
|
agent = await self._get_agent_impl(agent_id)
|
|
sessions = await agent.storage.list_sessions()
|
|
|
|
# Delete all sessions
|
|
for session in sessions:
|
|
await self.delete_agents_session(agent_id, session.session_id)
|
|
|
|
# Finally delete the agent itself
|
|
await self.persistence_store.delete(f"agent:{agent_id}")
|
|
|
|
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
|
agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
|
|
agent_list: list[Agent] = []
|
|
for agent_key in agent_keys:
|
|
agent_id = agent_key.split(":")[1]
|
|
|
|
# Get the agent info using the key
|
|
agent_info_json = await self.persistence_store.get(agent_key)
|
|
if not agent_info_json:
|
|
logger.error(f"Could not find agent info for key {agent_key}")
|
|
continue
|
|
|
|
try:
|
|
agent_info = AgentInfo.model_validate_json(agent_info_json)
|
|
agent_list.append(
|
|
Agent(
|
|
agent_id=agent_id,
|
|
agent_config=agent_info,
|
|
created_at=agent_info.created_at,
|
|
)
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Error parsing agent info for {agent_id}: {e}")
|
|
continue
|
|
|
|
# Convert Agent objects to dictionaries
|
|
agent_dicts = [agent.model_dump() for agent in agent_list]
|
|
return paginate_records(agent_dicts, start_index, limit)
|
|
|
|
async def get_agent(self, agent_id: str) -> Agent:
|
|
chat_agent = await self._get_agent_impl(agent_id)
|
|
agent = Agent(
|
|
agent_id=agent_id,
|
|
agent_config=chat_agent.agent_config,
|
|
created_at=chat_agent.created_at,
|
|
)
|
|
return agent
|
|
|
|
async def list_agent_sessions(
|
|
self, agent_id: str, start_index: int | None = None, limit: int | None = None
|
|
) -> PaginatedResponse:
|
|
agent = await self._get_agent_impl(agent_id)
|
|
sessions = await agent.storage.list_sessions()
|
|
# Convert Session objects to dictionaries
|
|
session_dicts = [session.model_dump() for session in sessions]
|
|
return paginate_records(session_dicts, start_index, limit)
|
|
|
|
async def shutdown(self) -> None:
|
|
pass
|
|
|
|
# OpenAI responses
|
|
async def get_openai_response(
|
|
self,
|
|
response_id: str,
|
|
) -> OpenAIResponseObject:
|
|
return await self.openai_responses_impl.get_openai_response(response_id)
|
|
|
|
async def create_openai_response(
|
|
self,
|
|
input: str | list[OpenAIResponseInput],
|
|
model: str,
|
|
instructions: str | None = None,
|
|
previous_response_id: str | None = None,
|
|
store: bool | None = True,
|
|
stream: bool | None = False,
|
|
temperature: float | None = None,
|
|
tools: list[OpenAIResponseInputTool] | None = None,
|
|
) -> OpenAIResponseObject:
|
|
return await self.openai_responses_impl.create_openai_response(
|
|
input, model, instructions, previous_response_id, store, stream, temperature, tools
|
|
)
|
|
|
|
async def list_openai_responses(
|
|
self,
|
|
after: str | None = None,
|
|
limit: int | None = 50,
|
|
model: str | None = None,
|
|
order: Order | None = Order.desc,
|
|
) -> ListOpenAIResponseObject:
|
|
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
|