llama-stack/llama_toolchain/agentic_system/api/api.py
Ashwin Bharambe 7bc7785b0d
API Updates: fleshing out RAG APIs, introduce "llama stack" CLI command (#51)
* add tools to chat completion request

* use templates for generating system prompts

* Moved ToolPromptFormat and jinja templates to llama_models.llama3.api

* <WIP> memory changes

- inlined AgenticSystemInstanceConfig so API feels more ergonomic
- renamed it to AgentConfig, AgentInstance -> Agent
- added a MemoryConfig and `memory` parameter
- added `attachments` to input and `output_attachments` to the response

- some naming changes

* InterleavedTextAttachment -> InterleavedTextMedia, introduce memory tool

* flesh out memory banks API

* agentic loop has a RAG implementation

* faiss provider implementation

* memory client works

* re-work tool definitions, fix FastAPI issues, fix tool regressions

* fix agentic_system utils

* basic RAG seems to work

* small bug fixes for inline attachments

* Refactor custom tool execution utilities

* Bug fix, show memory retrieval steps in EventLogger

* No need for api_key for Remote providers

* add special unicode character ↵ to showcase newlines in model prompt templates

* remove api.endpoints imports

* combine datatypes.py and endpoints.py into api.py

* Attachment / add TTL api

* split batch_inference from inference

* minor import fixes

* use a single impl for ChatFormat.decode_assistant_mesage

* use interleaved_text_media_as_str() utilityt

* Fix api.datatypes imports

* Add blobfile for tiktoken

* Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly

* templates take optional --format={json,function_tag}

* Rag Updates

* Add `api build` subcommand -- WIP

* fix

* build + run image seems to work

* <WIP> adapters

* bunch more work to make adapters work

* api build works for conda now

* ollama remote adapter works

* Several smaller fixes to make adapters work

Also, reorganized the pattern of __init__ inside providers so
configuration can stay lightweight

* llama distribution -> llama stack + containers (WIP)

* All the new CLI for api + stack work

* Make Fireworks and Together into the Adapter format

* Some quick fixes to the CLI behavior to make it consistent

* Updated README phew

* Update cli_reference.md

* llama_toolchain/distribution -> llama_toolchain/core

* Add termcolor

* update paths

* Add a log just for consistency

* chmod +x scripts

* Fix api dependencies not getting added to configuration

* missing import lol

* Delete utils.py; move to agentic system

* Support downloading of URLs for attachments for code interpreter

* Simplify and generalize `llama api build` yay

* Update `llama stack configure` to be very simple also

* Fix stack start

* Allow building an "adhoc" distribution

* Remote `llama api []` subcommands

* Fixes to llama stack commands and update docs

* Update documentation again and add error messages to llama stack start

* llama stack start -> llama stack run

* Change name of build for less confusion

* Add pyopenapi fork to the repository, update RFC assets

* Remove conflicting annotation

* Added a "--raw" option for model template printing

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
Co-authored-by: Dalton Flanagan <6599399+dltn@users.noreply.github.com>
2024-09-03 22:39:39 -07:00

413 lines
12 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.common.deployment_types import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403
from llama_toolchain.memory.api import * # noqa: F403
@json_schema_type
class Attachment(BaseModel):
content: InterleavedTextMedia | URL
mime_type: str
class AgenticSystemTool(Enum):
brave_search = "brave_search"
wolfram_alpha = "wolfram_alpha"
photogen = "photogen"
code_interpreter = "code_interpreter"
function_call = "function_call"
memory = "memory"
class ToolDefinitionCommon(BaseModel):
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@json_schema_type
class BraveSearchToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.brave_search.value] = (
AgenticSystemTool.brave_search.value
)
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class WolframAlphaToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.wolfram_alpha.value] = (
AgenticSystemTool.wolfram_alpha.value
)
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class PhotogenToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.code_interpreter.value] = (
AgenticSystemTool.code_interpreter.value
)
enable_inline_code_execution: bool = True
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class FunctionCallToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.function_call.value] = (
AgenticSystemTool.function_call.value
)
function_name: str
description: str
parameters: Dict[str, ToolParamDefinition]
remote_execution: Optional[RestAPIExecutionConfig] = None
class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
keys: List[str] # what keys to focus on
class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
AgenticSystemVectorMemoryBankConfig,
AgenticSystemKeyValueMemoryBankConfig,
AgenticSystemKeywordMemoryBankConfig,
AgenticSystemGraphMemoryBankConfig,
],
Field(discriminator="type"),
]
@json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
max_tokens_in_context: int = 4096
max_chunks: int = 10
AgenticSystemToolDefinition = Annotated[
Union[
BraveSearchToolDefinition,
WolframAlphaToolDefinition,
PhotogenToolDefinition,
CodeInterpreterToolDefinition,
FunctionCallToolDefinition,
MemoryToolDefinition,
],
Field(discriminator="type"),
]
class StepCommon(BaseModel):
turn_id: str
step_id: str
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class StepType(Enum):
inference = "inference"
tool_execution = "tool_execution"
shield_call = "shield_call"
memory_retrieval = "memory_retrieval"
@json_schema_type
class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value
model_response: CompletionMessage
@json_schema_type
class ToolExecutionStep(StepCommon):
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
tool_calls: List[ToolCall]
tool_responses: List[ToolResponse]
@json_schema_type
class ShieldCallStep(StepCommon):
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
response: ShieldResponse
@json_schema_type
class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval.value] = (
StepType.memory_retrieval.value
)
memory_bank_ids: List[str]
inserted_context: InterleavedTextMedia
Step = Annotated[
Union[
InferenceStep,
ToolExecutionStep,
ShieldCallStep,
MemoryRetrievalStep,
],
Field(discriminator="step_type"),
]
@json_schema_type
class Turn(BaseModel):
"""A single turn in an interaction with an Agentic System."""
turn_id: str
session_id: str
input_messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
steps: List[Step]
output_message: CompletionMessage
output_attachments: List[Attachment] = Field(default_factory=list)
started_at: datetime
completed_at: Optional[datetime] = None
@json_schema_type
class Session(BaseModel):
"""A single session of an interaction with an Agentic System."""
session_id: str
session_name: str
turns: List[Turn]
started_at: datetime
memory_bank: Optional[MemoryBank] = None
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
@json_schema_type
class AgentConfig(AgentConfigCommon):
model: str
instructions: str
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None
class AgenticSystemTurnResponseEventType(Enum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
turn_start = "turn_start"
turn_complete = "turn_complete"
@json_schema_type
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
AgenticSystemTurnResponseEventType.step_start.value
)
step_type: StepType
step_id: str
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@json_schema_type
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
AgenticSystemTurnResponseEventType.step_complete.value
)
step_type: StepType
step_details: Step
@json_schema_type
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
AgenticSystemTurnResponseEventType.step_progress.value
)
step_type: StepType
step_id: str
model_response_text_delta: Optional[str] = None
tool_call_delta: Optional[ToolCallDelta] = None
tool_response_text_delta: Optional[str] = None
@json_schema_type
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
AgenticSystemTurnResponseEventType.turn_start.value
)
turn_id: str
@json_schema_type
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
AgenticSystemTurnResponseEventType.turn_complete.value
)
turn: Turn
@json_schema_type
class AgenticSystemTurnResponseEvent(BaseModel):
"""Streamed agent execution response."""
payload: Annotated[
Union[
AgenticSystemTurnResponseStepStartPayload,
AgenticSystemTurnResponseStepProgressPayload,
AgenticSystemTurnResponseStepCompletePayload,
AgenticSystemTurnResponseTurnStartPayload,
AgenticSystemTurnResponseTurnCompletePayload,
],
Field(discriminator="event_type"),
]
@json_schema_type
class AgenticSystemCreateResponse(BaseModel):
agent_id: str
@json_schema_type
class AgenticSystemSessionCreateResponse(BaseModel):
session_id: str
@json_schema_type
class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
agent_id: str
session_id: str
# TODO: figure out how we can simplify this and make why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
attachments: Optional[List[Attachment]] = None
stream: Optional[bool] = False
@json_schema_type
class AgenticSystemTurnResponseStreamChunk(BaseModel):
event: AgenticSystemTurnResponseEvent
@json_schema_type
class AgenticSystemStepResponse(BaseModel):
step: Step
class AgenticSystem(Protocol):
@webmethod(route="/agentic_system/create")
async def create_agentic_system(
self,
agent_config: AgentConfig,
) -> AgenticSystemCreateResponse: ...
@webmethod(route="/agentic_system/turn/create")
async def create_agentic_system_turn(
self,
request: AgenticSystemTurnCreateRequest,
) -> AgenticSystemTurnResponseStreamChunk: ...
@webmethod(route="/agentic_system/turn/get")
async def get_agentic_system_turn(
self,
agent_id: str,
turn_id: str,
) -> Turn: ...
@webmethod(route="/agentic_system/step/get")
async def get_agentic_system_step(
self, agent_id: str, turn_id: str, step_id: str
) -> AgenticSystemStepResponse: ...
@webmethod(route="/agentic_system/session/create")
async def create_agentic_system_session(
self,
agent_id: str,
session_name: str,
) -> AgenticSystemSessionCreateResponse: ...
@webmethod(route="/agentic_system/session/get")
async def get_agentic_system_session(
self,
agent_id: str,
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session: ...
@webmethod(route="/agentic_system/session/delete")
async def delete_agentic_system_session(
self, agent_id: str, session_id: str
) -> None: ...
@webmethod(route="/agentic_system/delete")
async def delete_agentic_system(
self,
agent_id: str,
) -> None: ...