InterleavedTextAttachment -> InterleavedTextMedia, introduce memory tool

This commit is contained in:
Ashwin Bharambe 2024-08-22 17:44:56 -07:00
parent 48c6a32edd
commit 31289e3f47
5 changed files with 56 additions and 40 deletions

View file

@ -10,7 +10,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, validator
from typing_extensions import Annotated
from llama_toolchain.common.deployment_types import * # noqa: F403
@ -20,7 +20,31 @@ from llama_toolchain.memory.api.datatypes import * # noqa: F403
@json_schema_type
class AgenticSystemToolDefinition(ToolDefinition):
class Attachment(BaseModel):
url: URL
mime_type: str
class AgenticSystemBuiltinTool(BuiltinTool):
memory = "memory"
@json_schema_type
class AgenticSystemToolDefinition(BaseModel):
tool_name: Union[AgenticSystemBuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@validator("tool_name", pre=True)
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return AgenticSystemBuiltinTool(v)
except ValueError:
return v
return v
execution_config: Optional[RestAPIExecutionConfig] = None
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@ -124,11 +148,7 @@ class MemoryConfig(BaseModel):
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot or built-in tool configurations as input to the model
available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
default_factory=list
)
memory: Optional[List[MemoryConfig]] = Field(default_factory=list)
memory_configs: Optional[List[MemoryConfig]] = Field(default_factory=list)
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@ -136,6 +156,9 @@ class AgentConfigCommon(BaseModel):
# if you completely want to replace the messages prefixed by the system,
# this is debug only
debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)

View file

@ -46,7 +46,7 @@ class AgenticSystemTurnCreateRequest(BaseModel, AgentConfigOverridablePerTurn):
ToolResponseMessage,
]
]
attachments: List[Attachment]
attachments: Optional[List[Attachment]] = None
stream: Optional[bool] = False