llama-stack/llama_stack/providers/inline/agents/meta_reference/agents.py
ehhuang 3922999118
sys_prompt support in Agent (#938)
# What does this PR do?

The current default system prompt for llama3.2 tends to overindex on
tool calling and doesn't work well when the prompt does not require tool
calling.

This PR adds an option to override the default system prompt, and
organizes tool-related configs into a new config object.

- [ ] Addresses issue (#issue)


## Test Plan


LLAMA_STACK_CONFIG=together pytest
--inference-model=meta-llama/Llama-3.3-70B-Instruct -s -v
tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
2025-02-05 21:11:32 -08:00

220 lines
7.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import shutil
import tempfile
import uuid
from typing import AsyncGenerator, List, Optional, Union
from termcolor import colored
from llama_stack.apis.agents import (
AgentConfig,
AgentCreateResponse,
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
Document,
Session,
Turn,
)
from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
# NOTE(review): getLogger() without a name returns the *root* logger, so the
# setLevel() call below changes the level for the whole process as an import
# side effect, not just for this module. Left unchanged because other code may
# rely on it -- confirm before switching to logging.getLogger(__name__).
logger = logging.getLogger()
logger.setLevel(logging.INFO)
class MetaReferenceAgentsImpl(Agents):
    """Agents provider that persists agent configs in a KV store and runs turns through ``ChatAgent``.

    Agent configs are stored under ``agent:{agent_id}``. Sessions and turns are
    read back from ``session:{agent_id}:{session_id}`` and
    ``session:{agent_id}:{session_id}:{turn_id}`` respectively (presumably
    written by ``ChatAgent`` -- the writer is not visible in this module).
    """

    def __init__(
        self,
        config: MetaReferenceAgentsImplConfig,
        inference_api: Inference,
        vector_io_api: VectorIO,
        safety_api: Safety,
        tool_runtime_api: ToolRuntime,
        tool_groups_api: ToolGroups,
    ):
        """Store the provider config and the API handles later handed to each ``ChatAgent``."""
        self.config = config
        self.inference_api = inference_api
        self.vector_io_api = vector_io_api
        self.safety_api = safety_api
        self.tool_runtime_api = tool_runtime_api
        self.tool_groups_api = tool_groups_api

        # Fallback store for agents created with session persistence disabled.
        self.in_memory_store = InmemoryKVStoreImpl()
        # NOTE(review): mkdtemp() directories are not removed automatically and
        # nothing in this class deletes this one -- confirm cleanup happens elsewhere.
        self.tempdir = tempfile.mkdtemp()

    async def initialize(self) -> None:
        """Create the persistence store and warn if the ``bwrap`` executable is not on PATH."""
        self.persistence_store = await kvstore_impl(self.config.persistence_store)

        # check if "bwrap" is available
        if not shutil.which("bwrap"):
            print(
                colored(
                    "Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
                    "yellow",
                )
            )

    async def create_agent(
        self,
        agent_config: AgentConfig,
    ) -> AgentCreateResponse:
        """Persist ``agent_config`` under a freshly generated UUID and return that id.

        If ``agent_config.tool_config`` is unset, it is filled in (on the passed
        object) from the config's ``tool_choice`` / ``tool_prompt_format`` fields
        before being serialized.
        """
        agent_id = str(uuid.uuid4())

        if agent_config.tool_config is None:
            agent_config.tool_config = ToolConfig(
                tool_choice=agent_config.tool_choice,
                tool_prompt_format=agent_config.tool_prompt_format,
            )

        await self.persistence_store.set(
            key=f"agent:{agent_id}",
            value=agent_config.model_dump_json(),
        )
        return AgentCreateResponse(
            agent_id=agent_id,
        )

    async def get_agent(self, agent_id: str) -> ChatAgent:
        """Load the stored config for ``agent_id`` and build a ``ChatAgent`` around it.

        Raises:
            ValueError: if no config is stored for the id, or the stored value
                cannot be JSON-decoded or turned into an ``AgentConfig``.
        """
        agent_config = await self.persistence_store.get(
            key=f"agent:{agent_id}",
        )
        if not agent_config:
            raise ValueError(f"Could not find agent config for {agent_id}")

        try:
            agent_config = json.loads(agent_config)
        except json.JSONDecodeError as e:
            raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e

        try:
            agent_config = AgentConfig(**agent_config)
        except Exception as e:
            raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e

        return ChatAgent(
            agent_id=agent_id,
            agent_config=agent_config,
            tempdir=self.tempdir,
            inference_api=self.inference_api,
            safety_api=self.safety_api,
            vector_io_api=self.vector_io_api,
            tool_runtime_api=self.tool_runtime_api,
            tool_groups_api=self.tool_groups_api,
            # Session data only goes to the configured store when the agent opts in.
            persistence_store=(
                self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
            ),
        )

    async def create_agent_session(
        self,
        agent_id: str,
        session_name: str,
    ) -> AgentSessionCreateResponse:
        """Create a session named ``session_name`` on the given agent and return its id."""
        agent = await self.get_agent(agent_id)

        session_id = await agent.create_session(session_name)
        return AgentSessionCreateResponse(
            session_id=session_id,
        )

    async def create_agent_turn(
        self,
        agent_id: str,
        session_id: str,
        messages: List[
            Union[
                UserMessage,
                ToolResponseMessage,
            ]
        ],
        toolgroups: Optional[List[AgentToolGroup]] = None,
        documents: Optional[List[Document]] = None,
        stream: Optional[bool] = False,
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        """Build a turn request and return an async generator of turn events.

        Only streaming is implemented: awaiting this coroutine with a truthy
        ``stream`` yields the generator; a falsy ``stream`` raises.

        Raises:
            NotImplementedError: if ``stream`` is falsy.
        """
        request = AgentTurnCreateRequest(
            agent_id=agent_id,
            session_id=session_id,
            messages=messages,
            # The request is always marked as streaming, independent of ``stream``.
            stream=True,
            toolgroups=toolgroups,
            documents=documents,
            tool_config=tool_config,
        )
        if stream:
            return self._create_agent_turn_streaming(request)
        else:
            raise NotImplementedError("Non-streaming agent turns not yet implemented")

    async def _create_agent_turn_streaming(
        self,
        request: AgentTurnCreateRequest,
    ) -> AsyncGenerator:
        """Yield every event produced by the agent while it executes ``request``."""
        agent = await self.get_agent(request.agent_id)
        async for event in agent.create_and_execute_turn(request):
            yield event

    async def _load_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
        """Fetch and deserialize one stored turn.

        Raises:
            ValueError: if nothing is stored for the turn (instead of letting
                ``json.loads(None)`` fail with an uninformative ``TypeError``).
        """
        raw_turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
        if not raw_turn:
            raise ValueError(f"Could not find turn {turn_id} in session {session_id} for agent {agent_id}")
        return Turn(**json.loads(raw_turn))

    async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
        """Return the stored turn identified by agent, session and turn id.

        Raises:
            ValueError: if the turn is not found in the persistence store.
        """
        return await self._load_turn(agent_id, session_id, turn_id)

    async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
        """Return the step with ``step_id`` from the given stored turn.

        Raises:
            ValueError: if the turn is not found, or it has no step with ``step_id``.
        """
        turn = await self._load_turn(agent_id, session_id, turn_id)
        for step in turn.steps:
            if step.step_id == step_id:
                return AgentStepResponse(step=step)
        raise ValueError(f"Provided step_id {step_id} could not be found")

    async def get_agents_session(
        self,
        agent_id: str,
        session_id: str,
        turn_ids: Optional[List[str]] = None,
    ) -> Session:
        """Return the stored session, populated with only the turns listed in ``turn_ids``.

        When ``turn_ids`` is ``None`` or empty the returned session has no turns.

        Raises:
            ValueError: if the session, or any requested turn, is not found.
        """
        raw_session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
        if not raw_session:
            raise ValueError(f"Could not find session {session_id} for agent {agent_id}")
        # Parsed only to pick up the stored name and start time.
        session_info = Session(**json.loads(raw_session), turns=[])

        turns = [await self._load_turn(agent_id, session_id, turn_id) for turn_id in (turn_ids or [])]
        return Session(
            session_name=session_info.session_name,
            session_id=session_id,
            turns=turns,
            started_at=session_info.started_at,
        )

    async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
        """Delete the session record from the persistence store.

        NOTE(review): only the ``session:{agent_id}:{session_id}`` key is removed;
        per-turn keys are not touched here -- confirm whether that is intended.
        """
        await self.persistence_store.delete(f"session:{agent_id}:{session_id}")

    async def delete_agent(self, agent_id: str) -> None:
        """Delete the stored config for ``agent_id`` from the persistence store."""
        await self.persistence_store.delete(f"agent:{agent_id}")