forked from phoenix-oss/llama-stack-mirror
# What does this PR do? PR #639 introduced the notion of Tools API and ability to invoke tools through API just as any resource. This PR changes the Agents to start using the Tools API to invoke tools. Major changes include: 1) Ability to specify tool groups with AgentConfig 2) Agent gets the corresponding tool definitions for the specified tools and pass along to the model 3) Attachements are now named as Documents and their behavior is mostly unchanged from user perspective 4) You can specify args that can be injected to a tool call through Agent config. This is especially useful in case of memory tool, where you want the tool to operate on a specific memory bank. 5) You can also register tool groups with args, which lets the agent inject these as well into the tool call. 6) All tests have been migrated to use new tools API and fixtures including client SDK tests 7) Telemetry just works with tools API because of our trace protocol decorator ## Test Plan ``` pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct pytest -s -v -k together llama_stack/providers/tests/tools/test_tools.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py ``` run.yaml: https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994 Notebook: https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
227 lines
7.1 KiB
Python
227 lines
7.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import json
|
|
import logging
|
|
import shutil
|
|
import tempfile
|
|
import uuid
|
|
from typing import AsyncGenerator, List, Optional, Union
|
|
|
|
from termcolor import colored
|
|
|
|
from llama_stack.apis.agents import (
|
|
AgentConfig,
|
|
AgentCreateResponse,
|
|
Agents,
|
|
AgentSessionCreateResponse,
|
|
AgentStepResponse,
|
|
AgentToolGroup,
|
|
AgentTurnCreateRequest,
|
|
Document,
|
|
Session,
|
|
Turn,
|
|
)
|
|
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
|
|
from llama_stack.apis.memory import Memory
|
|
from llama_stack.apis.memory_banks import MemoryBanks
|
|
from llama_stack.apis.safety import Safety
|
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
|
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
|
|
|
from .agent_instance import ChatAgent
|
|
from .config import MetaReferenceAgentsImplConfig
|
|
|
|
logger = logging.getLogger()
|
|
logger.setLevel(logging.INFO)
|
|
|
|
|
|
class MetaReferenceAgentsImpl(Agents):
|
|
def __init__(
|
|
self,
|
|
config: MetaReferenceAgentsImplConfig,
|
|
inference_api: Inference,
|
|
memory_api: Memory,
|
|
safety_api: Safety,
|
|
memory_banks_api: MemoryBanks,
|
|
tool_runtime_api: ToolRuntime,
|
|
tool_groups_api: ToolGroups,
|
|
):
|
|
self.config = config
|
|
self.inference_api = inference_api
|
|
self.memory_api = memory_api
|
|
self.safety_api = safety_api
|
|
self.memory_banks_api = memory_banks_api
|
|
self.tool_runtime_api = tool_runtime_api
|
|
self.tool_groups_api = tool_groups_api
|
|
|
|
self.in_memory_store = InmemoryKVStoreImpl()
|
|
self.tempdir = tempfile.mkdtemp()
|
|
|
|
async def initialize(self) -> None:
|
|
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
|
|
|
# check if "bwrap" is available
|
|
if not shutil.which("bwrap"):
|
|
print(
|
|
colored(
|
|
"Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
|
|
"yellow",
|
|
)
|
|
)
|
|
|
|
async def create_agent(
|
|
self,
|
|
agent_config: AgentConfig,
|
|
) -> AgentCreateResponse:
|
|
agent_id = str(uuid.uuid4())
|
|
|
|
await self.persistence_store.set(
|
|
key=f"agent:{agent_id}",
|
|
value=agent_config.model_dump_json(),
|
|
)
|
|
return AgentCreateResponse(
|
|
agent_id=agent_id,
|
|
)
|
|
|
|
async def get_agent(self, agent_id: str) -> ChatAgent:
|
|
agent_config = await self.persistence_store.get(
|
|
key=f"agent:{agent_id}",
|
|
)
|
|
if not agent_config:
|
|
raise ValueError(f"Could not find agent config for {agent_id}")
|
|
|
|
try:
|
|
agent_config = json.loads(agent_config)
|
|
except json.JSONDecodeError as e:
|
|
raise ValueError(
|
|
f"Could not JSON decode agent config for {agent_id}"
|
|
) from e
|
|
|
|
try:
|
|
agent_config = AgentConfig(**agent_config)
|
|
except Exception as e:
|
|
raise ValueError(
|
|
f"Could not validate(?) agent config for {agent_id}"
|
|
) from e
|
|
|
|
return ChatAgent(
|
|
agent_id=agent_id,
|
|
agent_config=agent_config,
|
|
tempdir=self.tempdir,
|
|
inference_api=self.inference_api,
|
|
safety_api=self.safety_api,
|
|
memory_api=self.memory_api,
|
|
memory_banks_api=self.memory_banks_api,
|
|
tool_runtime_api=self.tool_runtime_api,
|
|
tool_groups_api=self.tool_groups_api,
|
|
persistence_store=(
|
|
self.persistence_store
|
|
if agent_config.enable_session_persistence
|
|
else self.in_memory_store
|
|
),
|
|
)
|
|
|
|
async def create_agent_session(
|
|
self,
|
|
agent_id: str,
|
|
session_name: str,
|
|
) -> AgentSessionCreateResponse:
|
|
agent = await self.get_agent(agent_id)
|
|
|
|
session_id = await agent.create_session(session_name)
|
|
return AgentSessionCreateResponse(
|
|
session_id=session_id,
|
|
)
|
|
|
|
async def create_agent_turn(
|
|
self,
|
|
agent_id: str,
|
|
session_id: str,
|
|
messages: List[
|
|
Union[
|
|
UserMessage,
|
|
ToolResponseMessage,
|
|
]
|
|
],
|
|
toolgroups: Optional[List[AgentToolGroup]] = None,
|
|
documents: Optional[List[Document]] = None,
|
|
stream: Optional[bool] = False,
|
|
) -> AsyncGenerator:
|
|
request = AgentTurnCreateRequest(
|
|
agent_id=agent_id,
|
|
session_id=session_id,
|
|
messages=messages,
|
|
stream=True,
|
|
toolgroups=toolgroups,
|
|
documents=documents,
|
|
)
|
|
if stream:
|
|
return self._create_agent_turn_streaming(request)
|
|
else:
|
|
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
|
|
|
async def _create_agent_turn_streaming(
|
|
self,
|
|
request: AgentTurnCreateRequest,
|
|
) -> AsyncGenerator:
|
|
agent = await self.get_agent(request.agent_id)
|
|
async for event in agent.create_and_execute_turn(request):
|
|
yield event
|
|
|
|
async def get_agents_turn(
|
|
self, agent_id: str, session_id: str, turn_id: str
|
|
) -> Turn:
|
|
turn = await self.persistence_store.get(
|
|
f"session:{agent_id}:{session_id}:{turn_id}"
|
|
)
|
|
turn = json.loads(turn)
|
|
turn = Turn(**turn)
|
|
return turn
|
|
|
|
async def get_agents_step(
|
|
self, agent_id: str, session_id: str, turn_id: str, step_id: str
|
|
) -> AgentStepResponse:
|
|
turn = await self.persistence_store.get(
|
|
f"session:{agent_id}:{session_id}:{turn_id}"
|
|
)
|
|
turn = json.loads(turn)
|
|
turn = Turn(**turn)
|
|
steps = turn.steps
|
|
for step in steps:
|
|
if step.step_id == step_id:
|
|
return AgentStepResponse(step=step)
|
|
raise ValueError(f"Provided step_id {step_id} could not be found")
|
|
|
|
async def get_agents_session(
|
|
self,
|
|
agent_id: str,
|
|
session_id: str,
|
|
turn_ids: Optional[List[str]] = None,
|
|
) -> Session:
|
|
session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
|
|
session = Session(**json.loads(session), turns=[])
|
|
turns = []
|
|
if turn_ids:
|
|
for turn_id in turn_ids:
|
|
turn = await self.persistence_store.get(
|
|
f"session:{agent_id}:{session_id}:{turn_id}"
|
|
)
|
|
turn = json.loads(turn)
|
|
turn = Turn(**turn)
|
|
turns.append(turn)
|
|
return Session(
|
|
session_name=session.session_name,
|
|
session_id=session_id,
|
|
turns=turns if turns else [],
|
|
started_at=session.started_at,
|
|
)
|
|
|
|
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
|
await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
|
|
|
|
async def delete_agents(self, agent_id: str) -> None:
|
|
await self.persistence_store.delete(f"agent:{agent_id}")
|