mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:12:29 +00:00
agents to use tools api
This commit is contained in:
parent
596afc6497
commit
f90e9c2003
21 changed files with 538 additions and 329 deletions
|
|
@ -14,18 +14,16 @@ from typing import (
|
|||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
|
|
@ -40,7 +38,6 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.memory import MemoryBank
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
|
|
@ -110,85 +107,6 @@ class FunctionCallToolDefinition(ToolDefinitionCommon):
|
|||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
class _MemoryBankConfigCommon(BaseModel):
|
||||
bank_id: str
|
||||
|
||||
|
||||
class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["vector"] = "vector"
|
||||
|
||||
|
||||
class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["keyvalue"] = "keyvalue"
|
||||
keys: List[str] # what keys to focus on
|
||||
|
||||
|
||||
class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["keyword"] = "keyword"
|
||||
|
||||
|
||||
class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["graph"] = "graph"
|
||||
entities: List[str] # what entities to focus on
|
||||
|
||||
|
||||
MemoryBankConfig = Annotated[
|
||||
Union[
|
||||
AgentVectorMemoryBankConfig,
|
||||
AgentKeyValueMemoryBankConfig,
|
||||
AgentKeywordMemoryBankConfig,
|
||||
AgentGraphMemoryBankConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class MemoryQueryGenerator(Enum):
|
||||
default = "default"
|
||||
llm = "llm"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
class DefaultMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.default.value] = (
|
||||
MemoryQueryGenerator.default.value
|
||||
)
|
||||
sep: str = " "
|
||||
|
||||
|
||||
class LLMMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
|
||||
model: str
|
||||
template: str
|
||||
|
||||
|
||||
class CustomMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
|
||||
|
||||
|
||||
MemoryQueryGeneratorConfig = Annotated[
|
||||
Union[
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
LLMMemoryQueryGeneratorConfig,
|
||||
CustomMemoryQueryGeneratorConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgentTool.memory.value] = AgentTool.memory.value
|
||||
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
|
||||
# This config defines how a query is generated using the messages
|
||||
# for memory bank retrieval.
|
||||
query_generator_config: MemoryQueryGeneratorConfig = Field(
|
||||
default=DefaultMemoryQueryGeneratorConfig()
|
||||
)
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 10
|
||||
|
||||
|
||||
AgentToolDefinition = Annotated[
|
||||
Union[
|
||||
SearchToolDefinition,
|
||||
|
|
@ -196,7 +114,6 @@ AgentToolDefinition = Annotated[
|
|||
PhotogenToolDefinition,
|
||||
CodeInterpreterToolDefinition,
|
||||
FunctionCallToolDefinition,
|
||||
MemoryToolDefinition,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
@ -295,7 +212,11 @@ class AgentConfigCommon(BaseModel):
|
|||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
|
||||
tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
|
||||
tools: Optional[List[AgentToolDefinition]] = Field(
|
||||
default_factory=list, deprecated=True
|
||||
)
|
||||
available_tools: Optional[List[str]] = Field(default_factory=list)
|
||||
preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue