mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
# What does this PR do? Add Tavily as a built-in search tool, in addition to Brave and Bing. ## Test Plan It's tested using ollama remote, showing parity with the Brave search tool. - Install and run ollama with `ollama run llama3.1:8b-instruct-fp16` - Build ollama distribution `llama stack build --template ollama --image-type conda` - Run ollama `stack run /$USER/.llama/distributions/llamastack-ollama/ollama-run.yaml --port 5001` - Client test command: `python -m agents.test_agents.TestAgents.test_create_agent_turn_with_tavily_search`, with environment variables: MASTER_ADDR=0.0.0.0;MASTER_PORT=5001;RANK=0;REMOTE_STACK_HOST=0.0.0.0;REMOTE_STACK_PORT=5001;TAVILY_SEARCH_API_KEY=tvly-<YOUR-KEY>;WORLD_SIZE=1 Test passes on the specific case (ollama remote). Server output: ``` Listening on ['::', '0.0.0.0']:5001 INFO: Started server process [7220] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://['::', '0.0.0.0']:5001 (Press CTRL+C to quit) INFO: 127.0.0.1:65209 - "POST /agents/create HTTP/1.1" 200 OK INFO: 127.0.0.1:65210 - "POST /agents/session/create HTTP/1.1" 200 OK INFO: 127.0.0.1:65211 - "POST /agents/turn/create HTTP/1.1" 200 OK role='user' content='What are the latest developments in quantum computing?' 
context=None role='assistant' content='' stop_reason=<StopReason.end_of_turn: 'end_of_turn'> tool_calls=[ToolCall(call_id='fc92ccb8-1039-4ce8-ba5e-8f2b0147661c', tool_name=<BuiltinTool.brave_search: 'brave_search'>, arguments={'query': 'latest developments in quantum computing'})] role='ipython' call_id='fc92ccb8-1039-4ce8-ba5e-8f2b0147661c' tool_name=<BuiltinTool.brave_search: 'brave_search'> content='{"query": "latest developments in quantum computing", "top_k": [{"title": "IBM Unveils 400 Qubit-Plus Quantum Processor and Next-Generation IBM ...", "url": "https://newsroom.ibm.com/2022-11-09-IBM-Unveils-400-Qubit-Plus-Quantum-Processor-and-Next-Generation-IBM-Quantum-System-Two", "content": "This system is targeted to be online by the end of 2023 and will be a building b...<more>...onnect large-scale ...", "url": "https://news.mit.edu/2023/quantum-interconnects-photon-emission-0105", "content": "Quantum computers hold the promise of performing certain tasks that are intractable even on the world\'s most powerful supercomputers. In the future, scientists anticipate using quantum computing to emulate materials systems, simulate quantum chemistry, and optimize hard tasks, with impacts potentially spanning finance to pharmaceuticals.", "score": 0.71721, "raw_content": null}]}' Assistant: The latest developments in quantum computing include: * IBM unveiling its 400 qubit-plus quantum processor and next-generation IBM Quantum System Two, which will be a building block of quantum-centric supercomputing. * The development of utility-scale quantum computing, which can serve as a scientific tool to explore utility-scale classes of problems in chemistry, physics, and materials beyond brute force classical simulation of quantum mechanics. 
* The introduction of advanced hardware across IBM's global fleet of 100+ qubit systems, as well as easy-to-use software that users and computational scientists can now obtain reliable results from quantum systems as they map increasingly larger and more complex problems to quantum circuits. * Research on quantum repeaters, which use defects in diamond to interconnect quantum systems and could provide the foundation for scalable quantum networking. * The development of a new source of quantum light, which could be used to improve the efficiency of quantum computers. * The creation of a new mathematical "blueprint" that is accelerating fusion device development using Dyson maps. * Research on canceling noise to improve quantum devices, with MIT researchers developing a protocol to extend the life of quantum coherence. ``` Verified with tool response. The final model response is updated with the search requests. ## Sources ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [x] Updated relevant documentation. - [x] Wrote necessary unit or integration tests. Co-authored-by: Martin Yuan <myuan@meta.com>
475 lines
13 KiB
Python
475 lines
13 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import (
|
|
Any,
|
|
AsyncIterator,
|
|
Dict,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Protocol,
|
|
runtime_checkable,
|
|
Union,
|
|
)
|
|
|
|
from llama_models.schema_utils import json_schema_type, webmethod
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field
|
|
from typing_extensions import Annotated
|
|
|
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
from llama_stack.apis.common.deployment_types import * # noqa: F403
|
|
from llama_stack.apis.inference import * # noqa: F403
|
|
from llama_stack.apis.safety import * # noqa: F403
|
|
from llama_stack.apis.memory import * # noqa: F403
|
|
|
|
|
|
@json_schema_type
class Attachment(BaseModel):
    """A content payload paired with its MIME type."""

    # Either interleaved text/media content or a URL.
    content: InterleavedTextMedia | URL
    # MIME type string describing `content`.
    mime_type: str
|
|
|
|
|
|
class AgentTool(Enum):
    """Kinds of tools an agent can be configured with.

    Each member's value is the string that the tool-definition models in this
    module use as the default (and only allowed) value of their `type` field.
    """

    brave_search = "brave_search"
    wolfram_alpha = "wolfram_alpha"
    photogen = "photogen"
    code_interpreter = "code_interpreter"

    function_call = "function_call"
    memory = "memory"
|
|
|
|
|
|
class ToolDefinitionCommon(BaseModel):
    """Base model carrying the shield lists shared by all tool definitions."""

    # Shield names for the tool's input; a fresh empty list by default.
    # NOTE(review): how shields are resolved/applied is not visible here.
    input_shields: Optional[List[str]] = Field(default_factory=list)
    # Shield names for the tool's output; a fresh empty list by default.
    output_shields: Optional[List[str]] = Field(default_factory=list)
|
|
|
|
|
|
class SearchEngineType(Enum):
    """Search backends selectable through `SearchToolDefinition.engine`."""

    bing = "bing"
    brave = "brave"
    tavily = "tavily"
|
|
|
|
|
|
@json_schema_type
class SearchToolDefinition(ToolDefinitionCommon):
    """Definition of the web-search tool, backed by a selectable engine."""

    # NOTE: brave_search is just a placeholder since model always uses
    # brave_search as tool call name
    type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
    # API key for the search provider (required, no default).
    api_key: str
    # Which search backend to use; Brave unless overridden.
    engine: SearchEngineType = SearchEngineType.brave
    # Optional REST execution config; None when not provided.
    remote_execution: Optional[RestAPIExecutionConfig] = None
|
|
|
|
|
|
@json_schema_type
class WolframAlphaToolDefinition(ToolDefinitionCommon):
    """Definition of the Wolfram Alpha tool."""

    # Discriminator value; fixed to "wolfram_alpha".
    type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
    # API key for the provider (required, no default).
    api_key: str
    # Optional REST execution config; None when not provided.
    remote_execution: Optional[RestAPIExecutionConfig] = None
|
|
|
|
|
|
@json_schema_type
class PhotogenToolDefinition(ToolDefinitionCommon):
    """Definition of the photogen tool."""

    # Discriminator value; fixed to "photogen".
    type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
    # Optional REST execution config; None when not provided.
    remote_execution: Optional[RestAPIExecutionConfig] = None
|
|
|
|
|
|
@json_schema_type
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
    """Definition of the code-interpreter tool."""

    # Discriminator value; fixed to "code_interpreter".
    type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
    # Inline code execution flag; enabled by default.
    enable_inline_code_execution: bool = True
    # Optional REST execution config; None when not provided.
    remote_execution: Optional[RestAPIExecutionConfig] = None
|
|
|
|
|
|
@json_schema_type
class FunctionCallToolDefinition(ToolDefinitionCommon):
    """Definition of a custom function-call tool.

    All of `function_name`, `description` and `parameters` are required.
    """

    # Discriminator value; fixed to "function_call".
    type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
    # Name of the function exposed as a tool.
    function_name: str
    # Free-text description of the function.
    description: str
    # Parameter definitions keyed by string (presumably the parameter name).
    parameters: Dict[str, ToolParamDefinition]
    # Optional REST execution config; None when not provided.
    remote_execution: Optional[RestAPIExecutionConfig] = None
|
|
|
|
|
|
class _MemoryBankConfigCommon(BaseModel):
    """Private base for agent memory-bank configs; holds the bank id."""

    # Identifier of the memory bank this config refers to.
    bank_id: str
|
|
|
|
|
|
class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
    """Memory-bank config whose `type` is the vector bank type."""

    # Discriminator value taken from MemoryBankType.vector.
    type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
|
|
|
|
|
class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
    """Memory-bank config whose `type` is the key-value bank type."""

    # Discriminator value taken from MemoryBankType.keyvalue.
    type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
    keys: List[str]  # what keys to focus on
|
|
|
|
|
|
class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
    """Memory-bank config whose `type` is the keyword bank type."""

    # Discriminator value taken from MemoryBankType.keyword.
    type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
|
|
|
|
|
|
class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
    """Memory-bank config whose `type` is the graph bank type."""

    # Discriminator value taken from MemoryBankType.graph.
    type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
    entities: List[str]  # what entities to focus on
|
|
|
|
|
|
# Union of the four agent memory-bank config models, discriminated on their
# `type` field so the concrete model is chosen from that value.
MemoryBankConfig = Annotated[
    Union[
        AgentVectorMemoryBankConfig,
        AgentKeyValueMemoryBankConfig,
        AgentKeywordMemoryBankConfig,
        AgentGraphMemoryBankConfig,
    ],
    Field(discriminator="type"),
]
|
|
|
|
|
|
class MemoryQueryGenerator(Enum):
    """Kinds of memory query generator.

    Each value is used as the `type` discriminator of the matching
    `*MemoryQueryGeneratorConfig` model.
    """

    default = "default"
    llm = "llm"
    custom = "custom"
|
|
|
|
|
|
class DefaultMemoryQueryGeneratorConfig(BaseModel):
    """Config for the "default" memory query generator."""

    # Discriminator value; fixed to "default".
    type: Literal[MemoryQueryGenerator.default.value] = (
        MemoryQueryGenerator.default.value
    )
    # Separator string, a single space by default.
    # NOTE(review): presumably used to join message contents — confirm.
    sep: str = " "
|
|
|
|
|
|
class LLMMemoryQueryGeneratorConfig(BaseModel):
    """Config for the "llm" memory query generator."""

    # Discriminator value; fixed to "llm".
    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
    # Model identifier (required).
    model: str
    # Template string (required); its format is not specified in this module.
    template: str
|
|
|
|
|
|
class CustomMemoryQueryGeneratorConfig(BaseModel):
    """Config for the "custom" memory query generator; carries only `type`."""

    # Discriminator value; fixed to "custom".
    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
|
|
|
|
|
|
# Union of the memory query generator configs, discriminated on `type`.
MemoryQueryGeneratorConfig = Annotated[
    Union[
        DefaultMemoryQueryGeneratorConfig,
        LLMMemoryQueryGeneratorConfig,
        CustomMemoryQueryGeneratorConfig,
    ],
    Field(discriminator="type"),
]
|
|
|
|
|
|
@json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon):
    """Definition of the memory tool and its retrieval settings."""

    # Discriminator value; fixed to "memory".
    type: Literal[AgentTool.memory.value] = AgentTool.memory.value
    # Memory banks available to the tool; a fresh empty list by default.
    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
    # This config defines how a query is generated using the messages
    # for memory bank retrieval.
    query_generator_config: MemoryQueryGeneratorConfig = Field(
        default=DefaultMemoryQueryGeneratorConfig()
    )
    # Numeric limits with defaults 4096 and 10.
    # NOTE(review): presumably caps on retrieved context size — confirm.
    max_tokens_in_context: int = 4096
    max_chunks: int = 10
|
|
|
|
|
|
# Union of every tool-definition model, discriminated on `type`.
AgentToolDefinition = Annotated[
    Union[
        SearchToolDefinition,
        WolframAlphaToolDefinition,
        PhotogenToolDefinition,
        CodeInterpreterToolDefinition,
        FunctionCallToolDefinition,
        MemoryToolDefinition,
    ],
    Field(discriminator="type"),
]
|
|
|
|
|
|
class StepCommon(BaseModel):
    """Fields shared by every step model: identifiers and timestamps."""

    # Id of the turn this step belongs to.
    turn_id: str
    # Id of this step.
    step_id: str
    # Optional start/completion timestamps; both default to None.
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
|
|
|
|
|
|
class StepType(Enum):
    """Kinds of step within a turn.

    Each value is used as the `step_type` discriminator of the matching step
    model.
    """

    inference = "inference"
    tool_execution = "tool_execution"
    shield_call = "shield_call"
    memory_retrieval = "memory_retrieval"
|
|
|
|
|
|
@json_schema_type
class InferenceStep(StepCommon):
    """Step recording a model inference and the message it produced."""

    # Clear pydantic's protected namespaces so that the `model_response`
    # field name (which starts with "model_") is accepted without conflict.
    model_config = ConfigDict(protected_namespaces=())

    # Discriminator value; fixed to "inference".
    step_type: Literal[StepType.inference.value] = StepType.inference.value
    # The completion message for this step (required).
    model_response: CompletionMessage
|
|
|
|
|
|
@json_schema_type
class ToolExecutionStep(StepCommon):
    """Step recording tool calls together with their responses."""

    # Discriminator value; fixed to "tool_execution".
    step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
    # Tool calls made in this step (required).
    tool_calls: List[ToolCall]
    # Responses to those tool calls (required).
    tool_responses: List[ToolResponse]
|
|
|
|
|
|
@json_schema_type
class ShieldCallStep(StepCommon):
    """Step recording a shield call and the violation, if any, it reported."""

    # Discriminator value; fixed to "shield_call".
    step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
    # May be None, but has no default and so must be passed explicitly.
    violation: Optional[SafetyViolation]
|
|
|
|
|
|
@json_schema_type
class MemoryRetrievalStep(StepCommon):
    """Step recording a memory retrieval and the context it inserted."""

    # Discriminator value; fixed to "memory_retrieval".
    step_type: Literal[StepType.memory_retrieval.value] = (
        StepType.memory_retrieval.value
    )
    # Ids of the memory banks involved (required).
    memory_bank_ids: List[str]
    # Interleaved text/media context inserted by this step (required).
    inserted_context: InterleavedTextMedia
|
|
|
|
|
|
# Union of all step models, discriminated on `step_type`.
Step = Annotated[
    Union[
        InferenceStep,
        ToolExecutionStep,
        ShieldCallStep,
        MemoryRetrievalStep,
    ],
    Field(discriminator="step_type"),
]
|
|
|
|
|
|
@json_schema_type
class Turn(BaseModel):
    """A single turn in an interaction with an Agentic System."""

    # Identifiers of this turn and of the session it belongs to.
    turn_id: str
    session_id: str
    # Messages that started the turn: user messages and/or tool responses.
    input_messages: List[
        Union[
            UserMessage,
            ToolResponseMessage,
        ]
    ]
    # Steps recorded for this turn.
    steps: List[Step]
    # Completion message of the turn (required).
    output_message: CompletionMessage
    # Attachments on the output; a fresh empty list by default.
    output_attachments: List[Attachment] = Field(default_factory=list)

    # Start time is required; completion time defaults to None.
    started_at: datetime
    completed_at: Optional[datetime] = None
|
|
|
|
|
|
@json_schema_type
class Session(BaseModel):
    """A single session of an interaction with an Agentic System."""

    # Identifier and name of the session.
    session_id: str
    session_name: str
    # Turns belonging to this session.
    turns: List[Turn]
    # When the session started (required).
    started_at: datetime

    # Optional memory bank associated with the session; None by default.
    memory_bank: Optional[MemoryBank] = None
|
|
|
|
|
|
class AgentConfigCommon(BaseModel):
    """Settings shared by `AgentConfig` and per-turn overrides."""

    # Sampling parameters; a default-constructed SamplingParams by default.
    sampling_params: Optional[SamplingParams] = SamplingParams()

    # Shield names for input and output; fresh empty lists by default.
    input_shields: Optional[List[str]] = Field(default_factory=list)
    output_shields: Optional[List[str]] = Field(default_factory=list)

    # Tool definitions for the agent; a fresh empty list by default.
    tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
    # Tool choice; defaults to ToolChoice.auto.
    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
    # Tool prompt format; defaults to ToolPromptFormat.json.
    tool_prompt_format: Optional[ToolPromptFormat] = Field(
        default=ToolPromptFormat.json
    )

    # Integer limit, 10 by default.
    # NOTE(review): presumably the max inference iterations per turn — confirm.
    max_infer_iters: int = 10
|
|
|
|
|
|
@json_schema_type
class AgentConfig(AgentConfigCommon):
    """Full agent configuration: common settings plus three required fields."""

    # Model identifier (required).
    model: str
    # Instruction text for the agent (required).
    instructions: str
    # Session persistence flag (required; no default).
    enable_session_persistence: bool
|
|
|
|
|
|
class AgentConfigOverridablePerTurn(AgentConfigCommon):
    """Common agent settings plus optional instructions, for a single turn."""

    # Instructions for the turn; None by default.
    instructions: Optional[str] = None
|
|
|
|
|
|
class AgentTurnResponseEventType(Enum):
    """Kinds of event in an agent turn response.

    Each value is used as the `event_type` discriminator of the matching
    payload model.
    """

    step_start = "step_start"
    step_complete = "step_complete"
    step_progress = "step_progress"

    turn_start = "turn_start"
    turn_complete = "turn_complete"
|
|
|
|
|
|
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
    """Event payload for the "step_start" event type."""

    # Discriminator value; fixed to "step_start".
    event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
        AgentTurnResponseEventType.step_start.value
    )
    # Kind and id of the step.
    step_type: StepType
    step_id: str
    # Free-form metadata; a fresh empty dict by default.
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
|
|
|
|
|
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
    """Event payload for the "step_complete" event type."""

    # Discriminator value; fixed to "step_complete".
    event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
        AgentTurnResponseEventType.step_complete.value
    )
    # Kind of the step.
    step_type: StepType
    # The step record itself (one of the `Step` union members).
    step_details: Step
|
|
|
|
|
|
@json_schema_type
class AgentTurnResponseStepProgressPayload(BaseModel):
    """Event payload for the "step_progress" event type.

    Carries up to three optional deltas, each None by default.
    """

    # Clear pydantic's protected namespaces so that the
    # `model_response_text_delta` field name (prefix "model_") is accepted.
    model_config = ConfigDict(protected_namespaces=())

    # Discriminator value; fixed to "step_progress".
    event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
        AgentTurnResponseEventType.step_progress.value
    )
    # Kind and id of the step.
    step_type: StepType
    step_id: str

    # Optional deltas; which ones are set for a given event is not
    # determined by this model.
    model_response_text_delta: Optional[str] = None
    tool_call_delta: Optional[ToolCallDelta] = None
    tool_response_text_delta: Optional[str] = None
|
|
|
|
|
|
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
    """Event payload for the "turn_start" event type."""

    # Discriminator value; fixed to "turn_start".
    event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
        AgentTurnResponseEventType.turn_start.value
    )
    # Id of the turn.
    turn_id: str
|
|
|
|
|
|
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
    """Event payload for the "turn_complete" event type."""

    # Discriminator value; fixed to "turn_complete".
    event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
        AgentTurnResponseEventType.turn_complete.value
    )
    # The `Turn` record carried by the event.
    turn: Turn
|
|
|
|
|
|
@json_schema_type
class AgentTurnResponseEvent(BaseModel):
    """Streamed agent execution response."""

    # One of the five payload models, discriminated on `event_type`.
    payload: Annotated[
        Union[
            AgentTurnResponseStepStartPayload,
            AgentTurnResponseStepProgressPayload,
            AgentTurnResponseStepCompletePayload,
            AgentTurnResponseTurnStartPayload,
            AgentTurnResponseTurnCompletePayload,
        ],
        Field(discriminator="event_type"),
    ]
|
|
|
|
|
|
@json_schema_type
class AgentCreateResponse(BaseModel):
    """Return type of `Agents.create_agent`; carries the agent id."""

    agent_id: str
|
|
|
|
|
|
@json_schema_type
class AgentSessionCreateResponse(BaseModel):
    """Return type of `Agents.create_agent_session`; carries the session id."""

    session_id: str
|
|
|
|
|
|
@json_schema_type
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
    """Request model for creating a turn, including per-turn config fields."""

    # Agent and session the turn belongs to.
    agent_id: str
    session_id: str

    # TODO: figure out how we can simplify this and make clear why
    # ToolResponseMessage needs to be here (it is function call
    # execution from outside the system)
    messages: List[
        Union[
            UserMessage,
            ToolResponseMessage,
        ]
    ]
    # Optional attachments; None by default.
    attachments: Optional[List[Attachment]] = None

    # Streaming flag; False by default.
    stream: Optional[bool] = False
|
|
|
|
|
|
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
    """Streamed agent turn completion response."""

    # The event carried by this chunk.
    event: AgentTurnResponseEvent
|
|
|
|
|
|
@json_schema_type
class AgentStepResponse(BaseModel):
    """Return type of `Agents.get_agents_step`; wraps a single step."""

    step: Step
|
|
|
|
|
|
@runtime_checkable
class Agents(Protocol):
    """Protocol for the agents API.

    Every method is async, has no body here, and is tagged with `webmethod`
    carrying its route. Behaviour is defined by implementations, not by this
    protocol.
    """

    @webmethod(route="/agents/create")
    async def create_agent(
        self,
        agent_config: AgentConfig,
    ) -> AgentCreateResponse:
        """Create an agent from `agent_config`; returns an `AgentCreateResponse`."""
        ...

    @webmethod(route="/agents/turn/create")
    async def create_agent_turn(
        self,
        agent_id: str,
        session_id: str,
        messages: List[
            Union[
                UserMessage,
                ToolResponseMessage,
            ]
        ],
        attachments: Optional[List[Attachment]] = None,
        stream: Optional[bool] = False,
    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
        """Create a turn in a session of an agent.

        Returns either a `Turn` or an async iterator of stream chunks.
        NOTE(review): presumably the iterator form is used when `stream` is
        true — confirm against implementations.
        """
        ...

    @webmethod(route="/agents/turn/get")
    async def get_agents_turn(
        self, agent_id: str, session_id: str, turn_id: str
    ) -> Turn:
        """Return the `Turn` identified by agent, session and turn ids."""
        ...

    @webmethod(route="/agents/step/get")
    async def get_agents_step(
        self, agent_id: str, session_id: str, turn_id: str, step_id: str
    ) -> AgentStepResponse:
        """Return the step identified by agent, session, turn and step ids."""
        ...

    @webmethod(route="/agents/session/create")
    async def create_agent_session(
        self,
        agent_id: str,
        session_name: str,
    ) -> AgentSessionCreateResponse:
        """Create a named session for an agent; returns its session id wrapper."""
        ...

    @webmethod(route="/agents/session/get")
    async def get_agents_session(
        self,
        agent_id: str,
        session_id: str,
        turn_ids: Optional[List[str]] = None,
    ) -> Session:
        """Return a `Session`.

        NOTE(review): `turn_ids` presumably restricts which turns are
        included — confirm against implementations.
        """
        ...

    @webmethod(route="/agents/session/delete")
    async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
        """Delete the given session of the given agent."""
        ...

    @webmethod(route="/agents/delete")
    async def delete_agents(
        self,
        agent_id: str,
    ) -> None:
        """Delete the given agent."""
        ...
|