address feedback

This commit is contained in:
Dinesh Yeduguru 2025-01-02 18:42:20 -08:00
parent ee542a7373
commit 16d1f66f55
9 changed files with 286 additions and 149 deletions

View file

@ -18,7 +18,7 @@ from typing import (
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
@ -132,14 +132,27 @@ class Session(BaseModel):
memory_bank: Optional[MemoryBank] = None
class AgentToolWithArgs(BaseModel):
name: str
args: Dict[str, Any]
AgentTool = register_schema(
Union[
str,
AgentToolWithArgs,
],
name="AgentTool",
)
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
tool_names: Optional[List[str]] = Field(default_factory=list)
tools: Optional[List[AgentTool]] = Field(default_factory=list)
client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list)
preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
@ -295,6 +308,7 @@ class Agents(Protocol):
]
],
stream: Optional[bool] = False,
tools: Optional[List[AgentTool]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")