API Updates: fleshing out RAG APIs, introduce "llama stack" CLI command (#51)

* add tools to chat completion request

* use templates for generating system prompts

* Moved ToolPromptFormat and jinja templates to llama_models.llama3.api

* <WIP> memory changes

- inlined AgenticSystemInstanceConfig so API feels more ergonomic
- renamed it to AgentConfig, AgentInstance -> Agent
- added a MemoryConfig and `memory` parameter
- added `attachments` to input and `output_attachments` to the response

- some naming changes

* InterleavedTextAttachment -> InterleavedTextMedia, introduce memory tool

* flesh out memory banks API

* agentic loop has a RAG implementation

* faiss provider implementation

* memory client works

* re-work tool definitions, fix FastAPI issues, fix tool regressions

* fix agentic_system utils

* basic RAG seems to work

* small bug fixes for inline attachments

* Refactor custom tool execution utilities

* Bug fix, show memory retrieval steps in EventLogger

* No need for api_key for Remote providers

* add special unicode character ↵ to showcase newlines in model prompt templates

* remove api.endpoints imports

* combine datatypes.py and endpoints.py into api.py

* Attachment / add TTL api

* split batch_inference from inference

* minor import fixes

* use a single impl for ChatFormat.decode_assistant_message

* use interleaved_text_media_as_str() utility

* Fix api.datatypes imports

* Add blobfile for tiktoken

* Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly

* templates take optional --format={json,function_tag}

* Rag Updates

* Add `api build` subcommand -- WIP

* fix

* build + run image seems to work

* <WIP> adapters

* bunch more work to make adapters work

* api build works for conda now

* ollama remote adapter works

* Several smaller fixes to make adapters work

Also, reorganized the pattern of __init__ inside providers so
configuration can stay lightweight

* llama distribution -> llama stack + containers (WIP)

* All the new CLI for api + stack work

* Make Fireworks and Together into the Adapter format

* Some quick fixes to the CLI behavior to make it consistent

* Updated README phew

* Update cli_reference.md

* llama_toolchain/distribution -> llama_toolchain/core

* Add termcolor

* update paths

* Add a log just for consistency

* chmod +x scripts

* Fix api dependencies not getting added to configuration

* missing import lol

* Delete utils.py; move to agentic system

* Support downloading of URLs for attachments for code interpreter

* Simplify and generalize `llama api build` yay

* Update `llama stack configure` to be very simple also

* Fix stack start

* Allow building an "adhoc" distribution

* Remove `llama api []` subcommands

* Fixes to llama stack commands and update docs

* Update documentation again and add error messages to llama stack start

* llama stack start -> llama stack run

* Change name of build for less confusion

* Add pyopenapi fork to the repository, update RFC assets

* Remove conflicting annotation

* Added a "--raw" option for model template printing

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
Co-authored-by: Dalton Flanagan <6599399+dltn@users.noreply.github.com>
Ashwin Bharambe authored on 2024-09-03 22:39:39 -07:00; committed by GitHub
parent 35093c0b6f
commit 7bc7785b0d
141 changed files with 8252 additions and 4032 deletions


@@ -1,4 +1,4 @@
 include requirements.txt
 include llama_toolchain/data/*.yaml
-include llama_toolchain/distribution/*.sh
+include llama_toolchain/core/*.sh
 include llama_toolchain/cli/scripts/*.sh


@@ -2,10 +2,10 @@
 The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-toolchain` package.
 ### Subcommands
 1. `download`: `llama` cli tools supports downloading the model from Meta or HuggingFace.
 2. `model`: Lists available models and their properties.
-3. `distribution`: A distribution is a set of REST APIs, this command allows you to manage (list, install, create, configure, start) distributions. You can read more about this [here](https://github.com/meta-llama/llama-stack/blob/main/docs/cli_reference.md#step-3-installing-and-configuring-distributions).
+3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](https://github.com/meta-llama/llama-stack/blob/api_updates_1/docs/cli_reference.md#step-3-building-configuring-and-running-llama-stack-servers).
 ### Sample Usage
@@ -13,7 +13,7 @@ The `llama` CLI tool helps you setup and use the Llama toolchain & agentic syste
 llama --help
 ```
 <pre style="font-family: monospace;">
-usage: llama [-h] {download,model,distribution} ...
+usage: llama [-h] {download,model,stack,api} ...
 Welcome to the Llama CLI
@@ -21,7 +21,7 @@ options:
 -h, --help show this help message and exit
 subcommands:
-{download,model,distribution}
+{download,model,stack,api}
 </pre>
 ## Step 1. Get the models
@@ -101,9 +101,9 @@ The `llama model` command helps you explore the models interface.
 ### 2.1 Subcommands
 1. `download`: Download the model from different sources. (meta, huggingface)
 2. `list`: Lists all the models available for download with hardware requirements to deploy the models.
 3. `template`: <TODO: What is a template?>
 4. `describe`: Describes all the properties of the model.
 ### 2.2 Sample Usage
@@ -236,11 +236,13 @@ These commands can help understand the model interface and how prompts / message
 **NOTE**: Outputs in terminal are color printed to show special tokens.
-## Step 3: Installing and Configuring Distributions
+## Step 3: Building, Configuring and Running Llama Stack servers
 An agentic app has several components including model inference, tool execution and system safety shields. Running all these components is made simpler (we hope!) with Llama Stack Distributions.
-A Distribution is simply a collection of REST API providers that are part of the Llama stack. As an example, by running a simple command `llama distribution start`, you can bring up a server serving the following endpoints, among others:
+The Llama Stack is a collection of REST APIs. An API is _implemented_ by Provider. An assembly of Providers together provides the implementation for the Stack -- this package is called a Distribution.
+As an example, by running a simple command `llama stack run`, you can bring up a server serving the following endpoints, among others:
 ```
 POST /inference/chat_completion
 POST /inference/completion
@@ -253,103 +255,135 @@ POST /agentic_system/delete
 The agentic app can now simply point to this server to execute all its needed components.
-A distributions behavior can be configured by defining a specification or “spec”. This specification lays out the different API “Providers” that constitute this distribution.
-Lets install, configure and start a distribution to understand more !
-Lets start with listing available distributions
+Lets build, configure and start a Llama Stack server specified via a "Distribution ID" to understand more !
+Lets start with listing available distributions:
 ```
-llama distribution list
+llama stack list-distributions
 ```
 <pre style="font-family: monospace;">
-+--------------+---------------------------------------------+----------------------------------------------------------------------+
-| Spec ID      | ProviderSpecs                               | Description                                                          |
-+--------------+---------------------------------------------+----------------------------------------------------------------------+
-| local        | {                                           | Use code from `llama_toolchain` itself to serve all llama stack APIs |
-|              |   "inference": "meta-reference",            |                                                                      |
-|              |   "safety": "meta-reference",               |                                                                      |
-|              |   "agentic_system": "meta-reference"        |                                                                      |
-|              | }                                           |                                                                      |
-+--------------+---------------------------------------------+----------------------------------------------------------------------+
-| remote       | {                                           | Point to remote services for all llama stack APIs                    |
-|              |   "inference": "inference-remote",          |                                                                      |
-|              |   "safety": "safety-remote",                |                                                                      |
-|              |   "agentic_system": "agentic_system-remote" |                                                                      |
-|              | }                                           |                                                                      |
-+--------------+---------------------------------------------+----------------------------------------------------------------------+
-| local-ollama | {                                           | Like local, but use ollama for running LLM inference                 |
-|              |   "inference": "meta-ollama",               |                                                                      |
-|              |   "safety": "meta-reference",               |                                                                      |
-|              |   "agentic_system": "meta-reference"        |                                                                      |
-|              | }                                           |                                                                      |
-+--------------+---------------------------------------------+----------------------------------------------------------------------+
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| Distribution ID                | Providers                             | Description                                                          |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local                          | {                                     | Use code from `llama_toolchain` itself to serve all llama stack APIs |
+|                                |   "inference": "meta-reference",      |                                                                      |
+|                                |   "memory": "meta-reference-faiss",   |                                                                      |
+|                                |   "safety": "meta-reference",         |                                                                      |
+|                                |   "agentic_system": "meta-reference"  |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| remote                         | {                                     | Point to remote services for all llama stack APIs                    |
+|                                |   "inference": "remote",              |                                                                      |
+|                                |   "safety": "remote",                 |                                                                      |
+|                                |   "agentic_system": "remote",         |                                                                      |
+|                                |   "memory": "remote"                  |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-ollama                   | {                                     | Like local, but use ollama for running LLM inference                 |
+|                                |   "inference": "remote::ollama",      |                                                                      |
+|                                |   "safety": "meta-reference",         |                                                                      |
+|                                |   "agentic_system": "meta-reference", |                                                                      |
+|                                |   "memory": "meta-reference-faiss"    |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-fireworks-inference | {                                     | Use Fireworks.ai for running LLM inference                           |
+|                                |   "inference": "remote::fireworks",   |                                                                      |
+|                                |   "safety": "meta-reference",         |                                                                      |
+|                                |   "agentic_system": "meta-reference", |                                                                      |
+|                                |   "memory": "meta-reference-faiss"    |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-together-inference  | {                                     | Use Together.ai for running LLM inference                            |
+|                                |   "inference": "remote::together",    |                                                                      |
+|                                |   "safety": "meta-reference",         |                                                                      |
+|                                |   "agentic_system": "meta-reference", |                                                                      |
+|                                |   "memory": "meta-reference-faiss"    |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 </pre>
-As you can see above, each “spec” details the “providers” that make up that spec. For eg. The `local` spec uses the “meta-reference” provider for inference while the `local-ollama` spec relies on a different provider ( ollama ) for inference.
-Lets install the fully local implementation of the llama-stack named `local` above.
-To install a distro, we run a simple command providing 2 inputs
-- **Spec Id** of the distribution that we want to install ( as obtained from the list command )
-- A **Name** by which this installation will be known locally.
+As you can see above, each “distribution” details the “providers” it is composed of. For example, `local` uses the “meta-reference” provider for inference while local-ollama relies on a different provider (Ollama) for inference. Similarly, you can use Fireworks or Together.AI for running inference as well.
+To install a distribution, we run a simple command providing 2 inputs:
+- **Distribution Id** of the distribution that we want to install ( as obtained from the list-distributions command )
+- A **Name** for the specific build and configuration of this distribution.
+Let's imagine you are working with a 8B-Instruct model. The following command will build a package (in the form of a Conda environment) _and_ configure it. As part of the configuration, you will be asked for some inputs (model_id, max_seq_len, etc.) Since we are working with a 8B model, we will name our build `8b-instruct` to help us remember the config.
 ```
-llama distribution install --spec local --name local_llama_8b
+llama stack build local --name 8b-instruct
 ```
-This will create a new conda environment (name can be passed optionally) and install dependencies (via pip) as required by the distro.
-Once it runs successfully , you should see some outputs in the form
+Once it runs successfully , you should see some outputs in the form:
 ```
-llama distribution install --spec local --name local_llama_8b
-```
-<pre style="font-family: monospace;">
+$ llama stack build local --name 8b-instruct
+....
+....
 Successfully installed cfgv-3.4.0 distlib-0.3.8 identify-2.6.0 libcst-1.4.0 llama_toolchain-0.0.2 moreorless-0.4.0 nodeenv-1.9.1 pre-commit-3.8.0 stdlibs-2024.5.15 toml-0.10.2 tomlkit-0.13.0 trailrunner-1.4.0 ufmt-2.7.0 usort-1.0.8 virtualenv-20.26.3
-Distribution `local_llama_8b` (with spec local) has been installed successfully!
-</pre>
-Next step is to configure the distribution that you just installed. We provide a simple CLI tool to enable simple configuration.
-This command will walk you through the configuration process.
-It will ask for some details like model name, paths to models, etc.
-**NOTE**: You will have to download the models if not done already. Follow instructions here on how to download using the llama cli
-```
-llama distribution configure --name local_llama_8b
+Successfully setup conda environment. Configuring build...
+...
+...
+YAML configuration has been written to ~/.llama/builds/local/conda/8b-instruct.yaml
 ```
-Here is an example output of how the cli will guide you to fill the configuration:
-<pre style="font-family: monospace;">
-Configuring API surface: inference
+You can re-configure this distribution by running:
+```
+llama stack configure local --name 8b-instruct
+```
+Here is an example run of how the CLI will guide you to fill the configuration
+```
+$ llama stack configure local --name 8b-instruct
+Configuring API: inference (meta-reference)
 Enter value for model (required): Meta-Llama3.1-8B-Instruct
 Enter value for quantization (optional):
 Enter value for torch_seed (optional):
 Enter value for max_seq_len (required): 4096
 Enter value for max_batch_size (default: 1): 1
-Configuring API surface: safety
-Do you want to configure llama_guard_shield? (y/n): n
-Do you want to configure prompt_guard_shield? (y/n): n
-Configuring API surface: agentic_system
+Configuring API: safety (meta-reference)
+Do you want to configure llama_guard_shield? (y/n): y
+Entering sub-configuration for llama_guard_shield:
+Enter value for model (required): Llama-Guard-3-8B
+Enter value for excluded_categories (required): []
+Enter value for disable_input_check (default: False):
+Enter value for disable_output_check (default: False):
+Do you want to configure prompt_guard_shield? (y/n): y
+Entering sub-configuration for prompt_guard_shield:
+Enter value for model (required): Prompt-Guard-86M
+...
+...
+YAML configuration has been written to ~/.llama/builds/local/conda/8b-instruct.yaml
+```
-YAML configuration has been written to ~/.llama/distributions/local0/config.yaml
-</pre>
-As you can see, we did basic configuration above and configured inference to run on model Meta-Llama3.1-8B-Instruct ( obtained from the llama model list command ).
-For this initial setup we did not set up safety.
+As you can see, we did basic configuration above and configured:
+- inference to run on model `Meta-Llama3.1-8B-Instruct` (obtained from `llama model list`)
+- Llama Guard safety shield with model `Llama-Guard-3-8B`
+- Prompt Guard safety shield with model `Prompt-Guard-86M`
 For how these configurations are stored as yaml, checkout the file printed at the end of the configuration.
-## Step 4: Starting a Distribution and Testing it
-Now lets start the distribution using the cli.
-```
-llama distribution start --name local_llama_8b --port 5000
-```
-You should see the distribution start and print the APIs that it is supporting:
-<pre style="font-family: monospace;">
+Note that all configurations as well as models are stored in `~/.llama`
+## Step 4: Starting a Llama Stack Distribution and Testing it
+Now lets start Llama Stack server.
+You need the YAML configuration file which was written out at the end by the `llama stack build` step.
+```
+llama stack run local --name 8b-instruct --port 5000
+```
+You should see the Stack server start and print the APIs that it is supporting,
+```
+$ llama stack run local --name 8b-instruct --port 5000
 > initializing model parallel with size 1
 > initializing ddp with size 1
 > initializing pipeline with size 1
@@ -376,15 +410,23 @@ INFO: Started server process [453333]
 INFO: Waiting for application startup.
 INFO: Application startup complete.
 INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit)
-</pre>
-Lets test with a client
-```
-cd /path/to/llama-toolchain
-conda activate <env-for-distribution> # ( Eg. local_llama_8b in above example )
-python -m llama_toolchain.inference.client localhost 5000
+```
+> [!NOTE]
+> Configuration is in `~/.llama/builds/local/conda/8b-instruct.yaml`. Feel free to increase `max_seq_len`.
+> [!IMPORTANT]
+> The "local" distribution inference server currently only supports CUDA. It will not work on Apple Silicon machines.
+This server is running a Llama model locally.
+Lets test with a client.
+```
+cd /path/to/llama-stack
+conda activate <env> # any environment containing the llama-toolchain pip package will work
+python -m llama_toolchain.inference.client localhost 5000
 ```
 This will run the chat completion client and query the distributions /inference/chat_completion API.
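For illustration only, the request that test client exercises can be approximated with a few lines of plain HTTP. This is a rough sketch, not the actual `llama_toolchain.inference.client` implementation; the endpoint path comes from the docs above, while the payload field names (`model`, `messages`, `stream`) are assumptions and may not match the server's real request schema.
```
# Hypothetical sketch: POST a chat completion request to a locally running
# Llama Stack server started with `llama stack run`. Field names are assumed.
import httpx

def chat_completion(host: str = "localhost", port: int = 5000) -> None:
    payload = {
        "model": "Meta-Llama3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello! What can you do?"}],
        "stream": False,  # the endpoint also supports streaming responses
    }
    response = httpx.post(
        f"http://{host}:{port}/inference/chat_completion",
        json=payload,
        headers={"Content-Type": "application/json"},
        timeout=30,
    )
    response.raise_for_status()
    print(response.json())

if __name__ == "__main__":
    chat_completion()
```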


@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from .datatypes import *  # noqa
-from .endpoints import *  # noqa
+from .api import *  # noqa: F401 F403


@@ -0,0 +1,413 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.common.deployment_types import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403
from llama_toolchain.memory.api import * # noqa: F403
@json_schema_type
class Attachment(BaseModel):
content: InterleavedTextMedia | URL
mime_type: str
class AgenticSystemTool(Enum):
brave_search = "brave_search"
wolfram_alpha = "wolfram_alpha"
photogen = "photogen"
code_interpreter = "code_interpreter"
function_call = "function_call"
memory = "memory"
class ToolDefinitionCommon(BaseModel):
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@json_schema_type
class BraveSearchToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.brave_search.value] = (
AgenticSystemTool.brave_search.value
)
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class WolframAlphaToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.wolfram_alpha.value] = (
AgenticSystemTool.wolfram_alpha.value
)
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class PhotogenToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.code_interpreter.value] = (
AgenticSystemTool.code_interpreter.value
)
enable_inline_code_execution: bool = True
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class FunctionCallToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.function_call.value] = (
AgenticSystemTool.function_call.value
)
function_name: str
description: str
parameters: Dict[str, ToolParamDefinition]
remote_execution: Optional[RestAPIExecutionConfig] = None
class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
keys: List[str] # what keys to focus on
class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
AgenticSystemVectorMemoryBankConfig,
AgenticSystemKeyValueMemoryBankConfig,
AgenticSystemKeywordMemoryBankConfig,
AgenticSystemGraphMemoryBankConfig,
],
Field(discriminator="type"),
]
@json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
max_tokens_in_context: int = 4096
max_chunks: int = 10
AgenticSystemToolDefinition = Annotated[
Union[
BraveSearchToolDefinition,
WolframAlphaToolDefinition,
PhotogenToolDefinition,
CodeInterpreterToolDefinition,
FunctionCallToolDefinition,
MemoryToolDefinition,
],
Field(discriminator="type"),
]
class StepCommon(BaseModel):
turn_id: str
step_id: str
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class StepType(Enum):
inference = "inference"
tool_execution = "tool_execution"
shield_call = "shield_call"
memory_retrieval = "memory_retrieval"
@json_schema_type
class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value
model_response: CompletionMessage
@json_schema_type
class ToolExecutionStep(StepCommon):
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
tool_calls: List[ToolCall]
tool_responses: List[ToolResponse]
@json_schema_type
class ShieldCallStep(StepCommon):
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
response: ShieldResponse
@json_schema_type
class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval.value] = (
StepType.memory_retrieval.value
)
memory_bank_ids: List[str]
inserted_context: InterleavedTextMedia
Step = Annotated[
Union[
InferenceStep,
ToolExecutionStep,
ShieldCallStep,
MemoryRetrievalStep,
],
Field(discriminator="step_type"),
]
@json_schema_type
class Turn(BaseModel):
"""A single turn in an interaction with an Agentic System."""
turn_id: str
session_id: str
input_messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
steps: List[Step]
output_message: CompletionMessage
output_attachments: List[Attachment] = Field(default_factory=list)
started_at: datetime
completed_at: Optional[datetime] = None
@json_schema_type
class Session(BaseModel):
"""A single session of an interaction with an Agentic System."""
session_id: str
session_name: str
turns: List[Turn]
started_at: datetime
memory_bank: Optional[MemoryBank] = None
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
@json_schema_type
class AgentConfig(AgentConfigCommon):
model: str
instructions: str
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None
class AgenticSystemTurnResponseEventType(Enum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
turn_start = "turn_start"
turn_complete = "turn_complete"
@json_schema_type
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
AgenticSystemTurnResponseEventType.step_start.value
)
step_type: StepType
step_id: str
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@json_schema_type
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
AgenticSystemTurnResponseEventType.step_complete.value
)
step_type: StepType
step_details: Step
@json_schema_type
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
AgenticSystemTurnResponseEventType.step_progress.value
)
step_type: StepType
step_id: str
model_response_text_delta: Optional[str] = None
tool_call_delta: Optional[ToolCallDelta] = None
tool_response_text_delta: Optional[str] = None
@json_schema_type
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
AgenticSystemTurnResponseEventType.turn_start.value
)
turn_id: str
@json_schema_type
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
AgenticSystemTurnResponseEventType.turn_complete.value
)
turn: Turn
@json_schema_type
class AgenticSystemTurnResponseEvent(BaseModel):
"""Streamed agent execution response."""
payload: Annotated[
Union[
AgenticSystemTurnResponseStepStartPayload,
AgenticSystemTurnResponseStepProgressPayload,
AgenticSystemTurnResponseStepCompletePayload,
AgenticSystemTurnResponseTurnStartPayload,
AgenticSystemTurnResponseTurnCompletePayload,
],
Field(discriminator="event_type"),
]
@json_schema_type
class AgenticSystemCreateResponse(BaseModel):
agent_id: str
@json_schema_type
class AgenticSystemSessionCreateResponse(BaseModel):
session_id: str
@json_schema_type
class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
agent_id: str
session_id: str
# TODO: figure out how we can simplify this and make why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
attachments: Optional[List[Attachment]] = None
stream: Optional[bool] = False
@json_schema_type
class AgenticSystemTurnResponseStreamChunk(BaseModel):
event: AgenticSystemTurnResponseEvent
@json_schema_type
class AgenticSystemStepResponse(BaseModel):
step: Step
class AgenticSystem(Protocol):
@webmethod(route="/agentic_system/create")
async def create_agentic_system(
self,
agent_config: AgentConfig,
) -> AgenticSystemCreateResponse: ...
@webmethod(route="/agentic_system/turn/create")
async def create_agentic_system_turn(
self,
request: AgenticSystemTurnCreateRequest,
) -> AgenticSystemTurnResponseStreamChunk: ...
@webmethod(route="/agentic_system/turn/get")
async def get_agentic_system_turn(
self,
agent_id: str,
turn_id: str,
) -> Turn: ...
@webmethod(route="/agentic_system/step/get")
async def get_agentic_system_step(
self, agent_id: str, turn_id: str, step_id: str
) -> AgenticSystemStepResponse: ...
@webmethod(route="/agentic_system/session/create")
async def create_agentic_system_session(
self,
agent_id: str,
session_name: str,
) -> AgenticSystemSessionCreateResponse: ...
@webmethod(route="/agentic_system/session/get")
async def get_agentic_system_session(
self,
agent_id: str,
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session: ...
@webmethod(route="/agentic_system/session/delete")
async def delete_agentic_system_session(
self, agent_id: str, session_id: str
) -> None: ...
@webmethod(route="/agentic_system/delete")
async def delete_agentic_system(
self,
agent_id: str,
) -> None: ...
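To make the shape of these datatypes concrete, here is a minimal sketch of assembling an `AgentConfig` that enables the new memory (RAG) tool alongside a built-in search tool. It only constructs the pydantic models defined above; the model name and token limits are example values, and the import path assumes the package re-exports these names as the accompanying `__init__.py` does.
```
# Illustrative only: an AgentConfig with the memory tool for RAG plus a
# built-in search tool, using the datatypes defined in the file above.
from llama_toolchain.agentic_system.api import (
    AgentConfig,
    BraveSearchToolDefinition,
    MemoryToolDefinition,
)

agent_config = AgentConfig(
    model="Meta-Llama3.1-8B-Instruct",   # example value
    instructions="You are a helpful assistant",
    tools=[
        BraveSearchToolDefinition(),
        MemoryToolDefinition(
            memory_bank_configs=[],      # banks can also be supplied per bank config
            max_tokens_in_context=2048,  # cap on retrieved context
            max_chunks=10,
        ),
    ],
)
```
Since each tool definition carries a literal `type` discriminator, a config like this serializes to JSON with enough information for a server to reconstruct the right tool classes on the other side.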


@@ -1,234 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_toolchain.common.deployment_types import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.safety.api.datatypes import * # noqa: F403
from llama_toolchain.memory.api.datatypes import * # noqa: F403
@json_schema_type
class AgenticSystemToolDefinition(ToolDefinition):
execution_config: Optional[RestAPIExecutionConfig] = None
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
class StepCommon(BaseModel):
turn_id: str
step_id: str
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class StepType(Enum):
inference = "inference"
tool_execution = "tool_execution"
shield_call = "shield_call"
memory_retrieval = "memory_retrieval"
@json_schema_type
class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value
model_response: CompletionMessage
@json_schema_type
class ToolExecutionStep(StepCommon):
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
tool_calls: List[ToolCall]
tool_responses: List[ToolResponse]
@json_schema_type
class ShieldCallStep(StepCommon):
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
response: ShieldResponse
@json_schema_type
class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval.value] = (
StepType.memory_retrieval.value
)
memory_bank_ids: List[str]
documents: List[MemoryBankDocument]
scores: List[float]
Step = Annotated[
Union[
InferenceStep,
ToolExecutionStep,
ShieldCallStep,
MemoryRetrievalStep,
],
Field(discriminator="step_type"),
]
@json_schema_type
class Turn(BaseModel):
"""A single turn in an interaction with an Agentic System."""
turn_id: str
session_id: str
input_messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
steps: List[Step]
output_message: CompletionMessage
started_at: datetime
completed_at: Optional[datetime] = None
@json_schema_type
class Session(BaseModel):
"""A single session of an interaction with an Agentic System."""
session_id: str
session_name: str
turns: List[Turn]
started_at: datetime
@json_schema_type
class ToolPromptFormat(Enum):
"""This Enum refers to the prompt format for calling zero shot tools
`json` --
Refers to the json format for calling tools.
The json format takes the form like
{
"type": "function",
"function" : {
"name": "function_name",
"description": "function_description",
"parameters": {...}
}
}
`function_tag` --
This is an example of how you could define
your own user defined format for making tool calls.
The function_tag format looks like this,
<function=function_name>(parameters)</function>
The detailed prompts for each of these formats are defined in `system_prompt.py`
"""
json = "json"
function_tag = "function_tag"
@json_schema_type
class AgenticSystemInstanceConfig(BaseModel):
instructions: str
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot or built-in tool configurations as input to the model
available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
default_factory=list
)
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
# if you completely want to replace the messages prefixed by the system,
# this is debug only
debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
class AgenticSystemTurnResponseEventType(Enum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
turn_start = "turn_start"
turn_complete = "turn_complete"
@json_schema_type
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
AgenticSystemTurnResponseEventType.step_start.value
)
step_type: StepType
step_id: str
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@json_schema_type
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
AgenticSystemTurnResponseEventType.step_complete.value
)
step_type: StepType
step_details: Step
@json_schema_type
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
AgenticSystemTurnResponseEventType.step_progress.value
)
step_type: StepType
step_id: str
model_response_text_delta: Optional[str] = None
tool_call_delta: Optional[ToolCallDelta] = None
tool_response_text_delta: Optional[str] = None
@json_schema_type
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
AgenticSystemTurnResponseEventType.turn_start.value
)
turn_id: str
@json_schema_type
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
AgenticSystemTurnResponseEventType.turn_complete.value
)
turn: Turn
@json_schema_type
class AgenticSystemTurnResponseEvent(BaseModel):
"""Streamed agent execution response."""
payload: Annotated[
Union[
AgenticSystemTurnResponseStepStartPayload,
AgenticSystemTurnResponseStepProgressPayload,
AgenticSystemTurnResponseStepCompletePayload,
AgenticSystemTurnResponseTurnStartPayload,
AgenticSystemTurnResponseTurnCompletePayload,
],
Field(discriminator="event_type"),
]


@@ -1,127 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F403
from typing import Protocol
# this dependency is annoying and we need a forked up version anyway
from llama_models.schema_utils import json_schema_type, webmethod
@json_schema_type
class AgenticSystemCreateRequest(BaseModel):
model: str
instance_config: AgenticSystemInstanceConfig
@json_schema_type
class AgenticSystemCreateResponse(BaseModel):
system_id: str
@json_schema_type
class AgenticSystemSessionCreateRequest(BaseModel):
system_id: str
session_name: str
@json_schema_type
class AgenticSystemSessionCreateResponse(BaseModel):
session_id: str
@json_schema_type
# what's the URI?
class AgenticSystemTurnCreateRequest(BaseModel):
system_id: str
session_id: str
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
stream: Optional[bool] = False
override_config: Optional[AgenticSystemInstanceConfig] = None
@json_schema_type
class AgenticSystemTurnResponseStreamChunk(BaseModel):
event: AgenticSystemTurnResponseEvent
@json_schema_type
class AgenticSystemStepResponse(BaseModel):
step: Step
class AgenticSystem(Protocol):
@webmethod(route="/agentic_system/create")
async def create_agentic_system(
self,
request: AgenticSystemCreateRequest,
) -> AgenticSystemCreateResponse: ...
@webmethod(route="/agentic_system/turn/create")
async def create_agentic_system_turn(
self,
request: AgenticSystemTurnCreateRequest,
) -> AgenticSystemTurnResponseStreamChunk: ...
@webmethod(route="/agentic_system/turn/get")
async def get_agentic_system_turn(
self,
agent_id: str,
turn_id: str,
) -> Turn: ...
@webmethod(route="/agentic_system/step/get")
async def get_agentic_system_step(
self, agent_id: str, turn_id: str, step_id: str
) -> AgenticSystemStepResponse: ...
@webmethod(route="/agentic_system/session/create")
async def create_agentic_system_session(
self,
request: AgenticSystemSessionCreateRequest,
) -> AgenticSystemSessionCreateResponse: ...
@webmethod(route="/agentic_system/memory_bank/attach")
async def attach_memory_bank_to_agentic_system(
self,
agent_id: str,
session_id: str,
memory_bank_ids: List[str],
) -> None: ...
@webmethod(route="/agentic_system/memory_bank/detach")
async def detach_memory_bank_from_agentic_system(
self,
agent_id: str,
session_id: str,
memory_bank_ids: List[str],
) -> None: ...
@webmethod(route="/agentic_system/session/get")
async def get_agentic_system_session(
self,
agent_id: str,
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session: ...
@webmethod(route="/agentic_system/session/delete")
async def delete_agentic_system_session(
self, agent_id: str, session_id: str
) -> None: ...
@webmethod(route="/agentic_system/delete")
async def delete_agentic_system(
self,
agent_id: str,
) -> None: ...


@@ -6,38 +6,28 @@
 import asyncio
 import json
 from typing import AsyncGenerator
 import fire
 import httpx
-from llama_models.llama3.api.datatypes import (
-    BuiltinTool,
-    SamplingParams,
-    ToolParamDefinition,
-    UserMessage,
-)
+from pydantic import BaseModel
 from termcolor import cprint
-from llama_toolchain.agentic_system.event_logger import EventLogger
-from .api import (
-    AgenticSystem,
-    AgenticSystemCreateRequest,
-    AgenticSystemCreateResponse,
-    AgenticSystemInstanceConfig,
-    AgenticSystemSessionCreateRequest,
-    AgenticSystemSessionCreateResponse,
-    AgenticSystemToolDefinition,
-    AgenticSystemTurnCreateRequest,
-    AgenticSystemTurnResponseStreamChunk,
-    ToolPromptFormat,
-)
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_toolchain.core.datatypes import RemoteProviderConfig
+from .api import *  # noqa: F403
+from .event_logger import EventLogger
-async def get_client_impl(base_url: str):
-    return AgenticSystemClient(base_url)
+async def get_client_impl(config: RemoteProviderConfig, _deps):
+    return AgenticSystemClient(config.url)
+def encodable_dict(d: BaseModel):
+    return json.loads(d.json())
 class AgenticSystemClient(AgenticSystem):
@@ -45,12 +35,14 @@ class AgenticSystemClient(AgenticSystem):
         self.base_url = base_url
     async def create_agentic_system(
-        self, request: AgenticSystemCreateRequest
+        self, agent_config: AgentConfig
     ) -> AgenticSystemCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/agentic_system/create",
-                data=request.json(),
+                json={
+                    "agent_config": encodable_dict(agent_config),
+                },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
@@ -58,12 +50,16 @@ class AgenticSystemClient(AgenticSystem):
     async def create_agentic_system_session(
         self,
-        request: AgenticSystemSessionCreateRequest,
+        agent_id: str,
+        session_name: str,
     ) -> AgenticSystemSessionCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/agentic_system/session/create",
-                data=request.json(),
+                json={
+                    "agent_id": agent_id,
+                    "session_name": session_name,
+                },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
@@ -77,7 +73,9 @@ class AgenticSystemClient(AgenticSystem):
             async with client.stream(
                 "POST",
                 f"{self.base_url}/agentic_system/turn/create",
-                data=request.json(),
+                json={
+                    "request": encodable_dict(request),
+                },
                 headers={"Content-Type": "application/json"},
                 timeout=20,
             ) as response:
@@ -85,6 +83,10 @@ class AgenticSystemClient(AgenticSystem):
                     if line.startswith("data:"):
                         data = line[len("data: ") :]
                         try:
+                            if "error" in data:
+                                cprint(data, "red")
+                                continue
                            yield AgenticSystemTurnResponseStreamChunk(
                                **json.loads(data)
                            )
@@ -93,24 +95,52 @@ class AgenticSystemClient(AgenticSystem):
                            print(f"Error with parsing or validation: {e}")
+async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
+    agent_config = AgentConfig(
+        model="Meta-Llama3.1-8B-Instruct",
+        instructions="You are a helpful assistant",
+        sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
+        tools=tool_definitions,
+        tool_choice=ToolChoice.auto,
+        tool_prompt_format=ToolPromptFormat.function_tag,
+    )
+    create_response = await api.create_agentic_system(agent_config)
+    session_response = await api.create_agentic_system_session(
+        agent_id=create_response.agent_id,
+        session_name="test_session",
+    )
+    for content in user_prompts:
+        cprint(f"User> {content}", color="white", attrs=["bold"])
+        iterator = api.create_agentic_system_turn(
+            AgenticSystemTurnCreateRequest(
+                agent_id=create_response.agent_id,
+                session_id=session_response.session_id,
+                messages=[
+                    UserMessage(content=content),
+                ],
+                attachments=attachments,
+                stream=True,
+            )
+        )
+        async for event, log in EventLogger().log(iterator):
+            if log is not None:
+                log.print()
 async def run_main(host: str, port: int):
-    # client to test remote impl of agentic system
     api = AgenticSystemClient(f"http://{host}:{port}")
     tool_definitions = [
-        AgenticSystemToolDefinition(
-            tool_name=BuiltinTool.brave_search,
-        ),
-        AgenticSystemToolDefinition(
-            tool_name=BuiltinTool.wolfram_alpha,
-        ),
-        AgenticSystemToolDefinition(
-            tool_name=BuiltinTool.code_interpreter,
-        ),
+        BraveSearchToolDefinition(),
+        WolframAlphaToolDefinition(),
+        CodeInterpreterToolDefinition(),
     ]
     tool_definitions += [
-        AgenticSystemToolDefinition(
-            tool_name="get_boiling_point",
+        FunctionCallToolDefinition(
+            function_name="get_boiling_point",
             description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
             parameters={
                 "liquid_name": ToolParamDefinition(
@@ -127,30 +157,6 @@ async def run_main(host: str, port: int):
        ),
    ]
-    create_request = AgenticSystemCreateRequest(
-        model="Meta-Llama3.1-8B-Instruct",
-        instance_config=AgenticSystemInstanceConfig(
-            instructions="You are a helpful assistant",
-            sampling_params=SamplingParams(),
-            available_tools=tool_definitions,
-            input_shields=[],
-            output_shields=[],
-            debug_prefix_messages=[],
-            tool_prompt_format=ToolPromptFormat.json,
-        ),
-    )
-    create_response = await api.create_agentic_system(create_request)
-    print(create_response)
-    session_response = await api.create_agentic_system_session(
-        AgenticSystemSessionCreateRequest(
-            system_id=create_response.system_id,
-            session_name="test_session",
-        )
-    )
-    print(session_response)
     user_prompts = [
         "Who are you?",
         "what is the 100th prime number?",
@@ -158,26 +164,51 @@ async def run_main(host: str, port: int):
         "Write code to check if a number is prime. Use that to check if 7 is prime",
         "What is the boiling point of polyjuicepotion ?",
     ]
-    for content in user_prompts:
-        cprint(f"User> {content}", color="blue")
-        iterator = api.create_agentic_system_turn(
-            AgenticSystemTurnCreateRequest(
-                system_id=create_response.system_id,
-                session_id=session_response.session_id,
-                messages=[
-                    UserMessage(content=content),
-                ],
-                stream=True,
-            )
-        )
-        async for event, log in EventLogger().log(iterator):
-            if log is not None:
-                log.print()
+    await _run_agent(api, tool_definitions, user_prompts)
+async def run_rag(host: str, port: int):
+    api = AgenticSystemClient(f"http://{host}:{port}")
+    urls = [
+        "memory_optimizations.rst",
+        "chat.rst",
+        "llama3.rst",
+        "datasets.rst",
+        "qat_finetune.rst",
+        "lora_finetune.rst",
+    ]
+    attachments = [
+        Attachment(
+            content=URL(
+                uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
+            ),
+            mime_type="text/plain",
+        )
+        for i, url in enumerate(urls)
+    ]
+    # Alternatively, you can pre-populate the memory bank with documents for example,
+    # using `llama_toolchain.memory.client`. Then you can grab the bank_id
+    # from the output of that run.
+    tool_definitions = [
+        MemoryToolDefinition(
+            max_tokens_in_context=2048,
+            memory_bank_configs=[],
+        ),
+    ]
+    user_prompts = [
+        "How do I use Lora?",
+        "Tell me briefly about llama3 and torchtune",
+    ]
+    await _run_agent(api, tool_definitions, user_prompts, attachments)
-def main(host: str, port: int):
-    asyncio.run(run_main(host, port))
+def main(host: str, port: int, rag: bool = False):
+    fn = run_rag if rag else run_main
+    asyncio.run(fn(host, port))
 if __name__ == "__main__":

@@ -6,7 +6,7 @@
 from typing import Optional
-from llama_models.llama3.api.datatypes import ToolResponseMessage
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tool_utils import ToolUtils
 from termcolor import cprint
@@ -44,7 +44,12 @@ EventType = AgenticSystemTurnResponseEventType
 class EventLogger:
-    async def log(self, event_generator, stream=True):
+    async def log(
+        self,
+        event_generator,
+        stream=True,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+    ):
         previous_event_type = None
         previous_step_type = None
@@ -132,7 +137,9 @@ class EventLogger:
             if event_type == EventType.step_complete.value:
                 response = event.payload.step_details.model_response
                 if response.tool_calls:
-                    content = ToolUtils.encode_tool_call(response.tool_calls[0])
+                    content = ToolUtils.encode_tool_call(
+                        response.tool_calls[0], tool_prompt_format
+                    )
                 else:
                     content = response.content
                 yield event, LogEvent(
@@ -162,5 +169,19 @@ class EventLogger:
                     color="green",
                 )
+            if (
+                step_type == StepType.memory_retrieval
+                and event_type == EventType.step_complete.value
+            ):
+                details = event.payload.step_details
+                content = interleaved_text_media_as_str(details.inserted_context)
+                content = content[:200] + "..." if len(content) > 200 else content
+                yield event, LogEvent(
+                    role=step_type,
+                    content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>",
+                    color="cyan",
+                )
             preivous_event_type = event_type
             previous_step_type = step_type


@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.agentic_system.api import * # noqa: F403
from llama_toolchain.memory.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403
from llama_toolchain.agentic_system.api import (
AgenticSystemTurnResponseEventType as EventType,
)
from llama_toolchain.tools.custom.datatypes import CustomTool
class AgentWithCustomToolExecutor:
def __init__(
self,
api: AgenticSystem,
agent_id: str,
session_id: str,
agent_config: AgentConfig,
custom_tools: List[CustomTool],
):
self.api = api
self.agent_id = agent_id
self.session_id = session_id
self.agent_config = agent_config
self.custom_tools = custom_tools
async def execute_turn(
self,
messages: List[Message],
attachments: Optional[List[Attachment]] = None,
max_iters: int = 5,
stream: bool = True,
) -> AsyncGenerator:
tools_dict = {t.get_name(): t for t in self.custom_tools}
current_messages = messages.copy()
n_iter = 0
while n_iter < max_iters:
n_iter += 1
request = AgenticSystemTurnCreateRequest(
agent_id=self.agent_id,
session_id=self.session_id,
messages=current_messages,
attachments=attachments,
stream=stream,
)
turn = None
async for chunk in self.api.create_agentic_system_turn(request):
if chunk.event.payload.event_type != EventType.turn_complete.value:
yield chunk
else:
turn = chunk.event.payload.turn
message = turn.output_message
if len(message.tool_calls) == 0:
yield chunk
return
if message.stop_reason == StopReason.out_of_tokens:
yield chunk
return
tool_call = message.tool_calls[0]
if tool_call.tool_name not in tools_dict:
m = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
)
next_message = m
else:
tool = tools_dict[tool_call.tool_name]
result_messages = await execute_custom_tool(tool, message)
next_message = result_messages[0]
yield next_message
current_messages = [next_message]
async def execute_custom_tool(tool: CustomTool, message: Message) -> List[Message]:
result_messages = await tool.run([message])
assert (
len(result_messages) == 1
), f"Expected single message, got {len(result_messages)}"
return result_messages
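A rough usage sketch for the executor above, assuming an `AgenticSystemClient` (see the client changes later in this diff) and a list of already-constructed `CustomTool` instances. The module paths in the imports are hypothetical since the diff does not show file names; the class and method signatures follow the code above.
```
# Hypothetical wiring sketch; import paths are assumed, signatures follow
# the AgentWithCustomToolExecutor definition above.
from llama_models.llama3.api.datatypes import UserMessage

from llama_toolchain.agentic_system.client import AgenticSystemClient  # assumed path
from llama_toolchain.agentic_system.execute_with_custom_tools import (  # assumed path
    AgentWithCustomToolExecutor,
)

async def run_with_custom_tools(agent_config, custom_tools, host="localhost", port=5000):
    api = AgenticSystemClient(f"http://{host}:{port}")
    create_response = await api.create_agentic_system(agent_config)
    session = await api.create_agentic_system_session(
        agent_id=create_response.agent_id,
        session_name="custom-tool-session",
    )
    executor = AgentWithCustomToolExecutor(
        api=api,
        agent_id=create_response.agent_id,
        session_id=session.session_id,
        agent_config=agent_config,
        custom_tools=custom_tools,
    )
    # execute_turn streams server events and runs matching custom tools locally
    async for chunk in executor.execute_turn([UserMessage(content="Hi!")]):
        print(chunk)
```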


@@ -4,5 +4,27 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from .agentic_system import get_provider_impl  # noqa
-from .config import AgenticSystemConfig  # noqa
+from typing import Dict
+from llama_toolchain.core.datatypes import Api, ProviderSpec
+from .config import MetaReferenceImplConfig
+async def get_provider_impl(
+    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
+):
+    from .agentic_system import MetaReferenceAgenticSystemImpl
+    assert isinstance(
+        config, MetaReferenceImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = MetaReferenceAgenticSystemImpl(
+        config,
+        deps[Api.inference],
+        deps[Api.memory],
+        deps[Api.safety],
+    )
+    await impl.initialize()
+    return impl


@@ -4,111 +4,111 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import asyncio
 import copy
+import os
+import secrets
+import shutil
+import string
+import tempfile
 import uuid
 from datetime import datetime
-from typing import AsyncGenerator, List, Optional
+from typing import AsyncGenerator, List, Tuple
+from urllib.parse import urlparse
+import httpx
 from termcolor import cprint
-from llama_toolchain.agentic_system.api.datatypes import (
-    AgenticSystemInstanceConfig,
-    AgenticSystemTurnResponseEvent,
-    AgenticSystemTurnResponseEventType,
-    AgenticSystemTurnResponseStepCompletePayload,
-    AgenticSystemTurnResponseStepProgressPayload,
-    AgenticSystemTurnResponseStepStartPayload,
-    AgenticSystemTurnResponseTurnCompletePayload,
-    AgenticSystemTurnResponseTurnStartPayload,
-    InferenceStep,
-    Session,
-    ShieldCallStep,
-    StepType,
-    ToolExecutionStep,
-    ToolPromptFormat,
-    Turn,
-)
-from llama_toolchain.inference.api import ChatCompletionRequest, Inference
-from llama_toolchain.inference.api.datatypes import (
-    Attachment,
-    BuiltinTool,
-    ChatCompletionResponseEventType,
-    CompletionMessage,
-    Message,
-    Role,
-    SamplingParams,
-    StopReason,
-    ToolCallDelta,
-    ToolCallParseStatus,
-    ToolDefinition,
-    ToolResponse,
-    ToolResponseMessage,
-    URL,
-)
-from llama_toolchain.safety.api import Safety
-from llama_toolchain.safety.api.datatypes import (
-    BuiltinShield,
-    ShieldDefinition,
-    ShieldResponse,
-)
-from llama_toolchain.agentic_system.api.endpoints import *  # noqa
+from llama_toolchain.agentic_system.api import *  # noqa: F403
+from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.memory.api import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
+from llama_toolchain.tools.base import BaseTool
+from llama_toolchain.tools.builtin import (
+    interpret_content_as_attachment,
+    SingleMessageBuiltinTool,
+)
 from .safety import SafetyException, ShieldRunnerMixin
-from .system_prompt import get_agentic_prefix_messages
-from .tools.base import BaseTool
-from .tools.builtin import SingleMessageBuiltinTool
-class AgentInstance(ShieldRunnerMixin):
+def make_random_string(length: int = 8):
+    return "".join(
+        secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
+    )
+class ChatAgent(ShieldRunnerMixin):
     def __init__(
         self,
-        system_id: int,
-        instance_config: AgenticSystemInstanceConfig,
-        model: str,
+        agent_config: AgentConfig,
         inference_api: Inference,
+        memory_api: Memory,
         safety_api: Safety,
         builtin_tools: List[SingleMessageBuiltinTool],
-        custom_tool_definitions: List[ToolDefinition],
-        input_shields: List[ShieldDefinition],
-        output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
-        prefix_messages: Optional[List[Message]] = None,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
     ):
-        self.system_id = system_id
-        self.instance_config = instance_config
-        self.model = model
+        self.agent_config = agent_config
         self.inference_api = inference_api
+        self.memory_api = memory_api
         self.safety_api = safety_api
-        if prefix_messages is not None and len(prefix_messages) > 0:
-            self.prefix_messages = prefix_messages
-        else:
-            self.prefix_messages = get_agentic_prefix_messages(
-                builtin_tools,
-                custom_tool_definitions,
-                tool_prompt_format,
-            )
-        for m in self.prefix_messages:
-            print(m.content)
         self.max_infer_iters = max_infer_iters
         self.tools_dict = {t.get_name(): t for t in builtin_tools}
+        self.tempdir = tempfile.mkdtemp()
         self.sessions = {}
         ShieldRunnerMixin.__init__(
             self,
             safety_api,
-            input_shields=input_shields,
-            output_shields=output_shields,
+            input_shields=agent_config.input_shields,
+            output_shields=agent_config.output_shields,
         )
+    def __del__(self):
+        shutil.rmtree(self.tempdir)
+    def turn_to_messages(self, turn: Turn) -> List[Message]:
+        messages = []
+        # We do not want to keep adding RAG context to the input messages
+        # May be this should be a parameter of the agentic instance
+        # that can define its behavior in a custom way
+        for m in turn.input_messages:
+            msg = m.copy()
+            if isinstance(msg, UserMessage):
+                msg.context = None
+            messages.append(msg)
+        # messages.extend(turn.input_messages)
+        for step in turn.steps:
+            if step.step_type == StepType.inference.value:
+                messages.append(step.model_response)
+            elif step.step_type == StepType.tool_execution.value:
+                for response in step.tool_responses:
+                    messages.append(
+                        ToolResponseMessage(
+                            call_id=response.call_id,
+                            tool_name=response.tool_name,
+                            content=response.content,
+                        )
+                    )
+            elif step.step_type == StepType.shield_call.value:
+                response = step.response
+                if response.is_violation:
+                    # CompletionMessage itself in the ShieldResponse
+                    messages.append(
+                        CompletionMessage(
+                            content=response.violation_return_message,
+                            stop_reason=StopReason.end_of_turn,
+                        )
+                    )
+        # print_dialog(messages)
+        return messages
     def create_session(self, name: str) -> Session:
         session_id = str(uuid.uuid4())
         session = Session(
@@ -131,32 +131,7 @@ class AgentInstance(ShieldRunnerMixin):
         messages = []
         for i, turn in enumerate(session.turns):
+            messages.extend(self.turn_to_messages(turn))
-            # print(f"turn {i}")
-            # print_dialog(turn.input_messages)
-            messages.extend(turn.input_messages)
-            for step in turn.steps:
-                if step.step_type == StepType.inference.value:
-                    messages.append(step.model_response)
-                elif step.step_type == StepType.tool_execution.value:
-                    for response in step.tool_responses:
-                        messages.append(
-                            ToolResponseMessage(
-                                call_id=response.call_id,
-                                tool_name=response.tool_name,
-                                content=response.content,
-                            )
-                        )
-                elif step.step_type == StepType.shield_call.value:
-                    response = step.response
-                    if response.is_violation:
-                        # TODO: Properly persist the
-                        # CompletionMessage itself in the ShieldResponse
-                        messages.append(
-                            CompletionMessage(
-                                content=response.violation_return_message,
-                                stop_reason=StopReason.end_of_turn,
)
)
messages.extend(request.messages) messages.extend(request.messages)
@ -164,7 +139,6 @@ class AgentInstance(ShieldRunnerMixin):
# print_dialog(messages) # print_dialog(messages)
turn_id = str(uuid.uuid4()) turn_id = str(uuid.uuid4())
params = self.instance_config.sampling_params
start_time = datetime.now() start_time = datetime.now()
yield AgenticSystemTurnResponseStreamChunk( yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent( event=AgenticSystemTurnResponseEvent(
@ -177,12 +151,12 @@ class AgentInstance(ShieldRunnerMixin):
steps = [] steps = []
output_message = None output_message = None
async for chunk in self.run( async for chunk in self.run(
session=session,
turn_id=turn_id, turn_id=turn_id,
input_messages=messages, input_messages=messages,
temperature=params.temperature, attachments=request.attachments or [],
top_p=params.top_p, sampling_params=self.agent_config.sampling_params,
stream=request.stream, stream=request.stream,
max_gen_len=params.max_tokens,
): ):
if isinstance(chunk, CompletionMessage): if isinstance(chunk, CompletionMessage):
cprint( cprint(
@ -227,6 +201,53 @@ class AgentInstance(ShieldRunnerMixin):
) )
yield chunk yield chunk
async def run(
self,
session: Session,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
async for res in self.run_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
if isinstance(res, bool):
return
else:
yield res
async for res in self._run(
session, turn_id, input_messages, attachments, sampling_params, stream
):
if isinstance(res, bool):
return
elif isinstance(res, CompletionMessage):
final_response = res
break
else:
yield res
assert final_response is not None
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
async for res in self.run_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return
else:
yield res
yield final_response
async def run_shields_wrapper( async def run_shields_wrapper(
self, self,
turn_id: str, turn_id: str,
@ -288,65 +309,62 @@ class AgentInstance(ShieldRunnerMixin):
) )
) )
async def run(
self,
turn_id: str,
input_messages: List[Message],
temperature: float,
top_p: float,
stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# stremaing. However, it also makes things complicated here because AsyncGenerators cannot
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
async for res in self.run_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
if isinstance(res, bool):
return
else:
yield res
async for res in self._run(
turn_id, input_messages, temperature, top_p, stream, max_gen_len
):
if isinstance(res, bool):
return
elif isinstance(res, CompletionMessage):
final_response = res
break
else:
yield res
assert final_response is not None
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
async for res in self.run_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return
else:
yield res
yield final_response
async def _run( async def _run(
self, self,
session: Session,
turn_id: str, turn_id: str,
input_messages: List[Message], input_messages: List[Message],
temperature: float, attachments: List[Attachment],
top_p: float, sampling_params: SamplingParams,
stream: bool = False, stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
input_messages = preprocess_dialog(input_messages, self.prefix_messages) enabled_tools = set(t.type for t in self.agent_config.tools)
need_rag_context = await self._should_retrieve_context(
input_messages, attachments
)
if need_rag_context:
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
)
)
)
attachments = [] # TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
rag_context, bank_ids = await self._retrieve_context(
session, input_messages, attachments
)
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
step_details=MemoryRetrievalStep(
turn_id=turn_id,
step_id=step_id,
memory_bank_ids=bank_ids,
inserted_context=rag_context or "",
),
)
)
)
if rag_context:
last_message = input_messages[-1]
last_message.context = "\n".join(rag_context)
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)]
msg = await attachment_message(self.tempdir, urls)
input_messages.append(msg)
output_attachments = []
n_iter = 0 n_iter = 0
while True: while True:
@ -369,17 +387,13 @@ class AgentInstance(ShieldRunnerMixin):
) )
) )
# where are the available tools?
req = ChatCompletionRequest( req = ChatCompletionRequest(
model=self.model, model=self.agent_config.model,
messages=input_messages, messages=input_messages,
available_tools=self.instance_config.available_tools, tools=self._get_tools(),
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True, stream=True,
sampling_params=SamplingParams( sampling_params=sampling_params,
temperature=temperature,
top_p=top_p,
max_tokens=max_gen_len,
),
) )
tool_calls = [] tool_calls = []
@ -464,7 +478,8 @@ class AgentInstance(ShieldRunnerMixin):
if len(message.tool_calls) == 0: if len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn: if stop_reason == StopReason.end_of_turn:
if len(attachments) > 0: # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list): if isinstance(message.content, list):
message.content += attachments message.content += attachments
else: else:
@ -572,63 +587,175 @@ class AgentInstance(ShieldRunnerMixin):
yield False yield False
return return
if isinstance(result_message.content, Attachment): if out_attachment := interpret_content_as_attachment(
result_message.content
):
# NOTE: when we push this message back to the model, the model may ignore the # NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message # attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message # with the summary. We keep all generated attachments and then attach them to final message
attachments.append(result_message.content) output_attachments.append(out_attachment)
elif isinstance(result_message.content, list) or isinstance(
result_message.content, tuple
):
for c in result_message.content:
if isinstance(c, Attachment):
attachments.append(c)
input_messages = input_messages + [message, result_message] input_messages = input_messages + [message, result_message]
n_iter += 1 n_iter += 1
async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
if session.memory_bank is None:
session.memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session.session_id}",
config=VectorMemoryBankConfig(
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
)
def attachment_message(url: URL) -> ToolResponseMessage: return session.memory_bank
uri = url.uri
assert uri.startswith("file://") async def _should_retrieve_context(
filepath = uri[len("file://") :] self, messages: List[Message], attachments: List[Attachment]
) -> bool:
enabled_tools = set(t.type for t in self.agent_config.tools)
if attachments:
if (
AgenticSystemTool.code_interpreter.value in enabled_tools
and self.agent_config.tool_choice == ToolChoice.required
):
return False
else:
return True
return AgenticSystemTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
for t in self.agent_config.tools:
if t.type == AgenticSystemTool.memory.value:
return t
return None
async def _retrieve_context(
self, session: Session, messages: List[Message], attachments: List[Attachment]
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
bank_ids = []
memory = self._memory_tool_definition()
assert memory is not None, "Memory tool not configured"
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
if attachments:
bank = await self._ensure_memory_bank(session)
bank_ids.append(bank.bank_id)
documents = [
MemoryBankDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in attachments
]
await self.memory_api.insert_documents(bank.bank_id, documents)
elif session.memory_bank:
bank_ids.append(session.memory_bank.bank_id)
if not bank_ids:
# this can happen if the per-session memory bank is not yet populated
# (i.e., no prior turns uploaded an Attachment)
return None, []
query = " ".join(m.content for m in messages)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
query=query,
params={
"max_chunks": 5,
},
)
for bank_id in bank_ids
]
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
)
if not chunks:
return None, bank_ids
tokens = 0
picked = []
for c in chunks[: memory.max_chunks]:
tokens += c.token_count
if tokens > memory.max_tokens_in_context:
cprint(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
"red",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
return [
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
], bank_ids
def _get_tools(self) -> List[ToolDefinition]:
ret = []
for t in self.agent_config.tools:
if isinstance(t, BraveSearchToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
elif isinstance(t, WolframAlphaToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
elif isinstance(t, PhotogenToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
elif isinstance(t, CodeInterpreterToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
elif isinstance(t, FunctionCallToolDefinition):
ret.append(
ToolDefinition(
tool_name=t.function_name,
description=t.description,
parameters=t.parameters,
)
)
return ret
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
content = []
for url in urls:
uri = url.uri
if uri.startswith("file://"):
filepath = uri[len("file://") :]
elif uri.startswith("http"):
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
print(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
resp = r.text
with open(filepath, "w") as fp:
fp.write(resp)
else:
raise ValueError(f"Unsupported URL {url}")
content.append(f'# There is a file accessible to you at "{filepath}"\n')
return ToolResponseMessage( return ToolResponseMessage(
call_id="", call_id="",
tool_name=BuiltinTool.code_interpreter, tool_name=BuiltinTool.code_interpreter,
content=f'# There is a file accessible to you at "{filepath}"', content=content,
) )
def preprocess_dialog(
messages: List[Message], prefix_messages: List[Message]
) -> List[Message]:
"""
Preprocesses the dialog by removing the system message and
adding the system message to the beginning of the dialog.
"""
ret = prefix_messages.copy()
for m in messages:
if m.role == Role.system.value:
continue
# NOTE: the ideal behavior is to use `file_path = ...` but that
# means we need to have stateful execution o f code which we currently
# do not have.
if isinstance(m.content, Attachment):
ret.append(attachment_message(m.content.url))
elif isinstance(m.content, list):
for c in m.content:
if isinstance(c, Attachment):
ret.append(attachment_message(c.url))
ret.append(m)
return ret
async def execute_tool_call_maybe( async def execute_tool_call_maybe(
tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage] tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
) -> List[ToolResponseMessage]: ) -> List[ToolResponseMessage]:
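
One idiom in ChatAgent.run and _run above is worth spelling out: an async generator cannot `return` a value, so the final result is smuggled out as a specially typed yield that the caller tests for (here, a boolean signalling an early exit and a CompletionMessage carrying the final answer). A standalone sketch of that pattern, with illustrative names that are not part of this change:

import asyncio
from typing import AsyncGenerator


async def worker() -> AsyncGenerator:
    yield "event: step 1"
    yield "event: step 2"
    # Async generators may only use a bare `return`, so the "return value"
    # is just the last item yielded; callers distinguish it by type.
    yield {"final": "all done"}


async def driver() -> AsyncGenerator:
    final = None
    async for item in worker():
        if isinstance(item, dict):  # sentinel marking the final value
            final = item
            break
        yield item  # stream intermediate events to our own caller
    assert final is not None
    yield f"driver finished: {final['final']}"


async def main() -> None:
    async for event in driver():
        print(event)


if __name__ == "__main__":
    asyncio.run(main())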


@ -8,62 +8,42 @@
import logging import logging
import os import os
import uuid import uuid
from typing import AsyncGenerator, Dict from typing import AsyncGenerator
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import Inference from llama_toolchain.inference.api import Inference
from llama_toolchain.inference.api.datatypes import BuiltinTool from llama_toolchain.memory.api import Memory
from llama_toolchain.safety.api import Safety from llama_toolchain.safety.api import Safety
from llama_toolchain.agentic_system.api.endpoints import * # noqa from llama_toolchain.agentic_system.api import * # noqa: F403
from llama_toolchain.agentic_system.api import ( from llama_toolchain.tools.builtin import (
AgenticSystem,
AgenticSystemCreateRequest,
AgenticSystemCreateResponse,
AgenticSystemSessionCreateRequest,
AgenticSystemSessionCreateResponse,
AgenticSystemTurnCreateRequest,
)
from .agent_instance import AgentInstance
from .config import AgenticSystemConfig
from .tools.builtin import (
BraveSearchTool, BraveSearchTool,
CodeInterpreterTool, CodeInterpreterTool,
PhotogenTool, PhotogenTool,
WolframAlphaTool, WolframAlphaTool,
) )
from .tools.safety import with_safety from llama_toolchain.tools.safety import with_safety
from .agent_instance import ChatAgent
from .config import MetaReferenceImplConfig
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
async def get_provider_impl(config: AgenticSystemConfig, deps: Dict[Api, ProviderSpec]):
assert isinstance(
config, AgenticSystemConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceAgenticSystemImpl(
config,
deps[Api.inference],
deps[Api.safety],
)
await impl.initialize()
return impl
AGENT_INSTANCES_BY_ID = {} AGENT_INSTANCES_BY_ID = {}
class MetaReferenceAgenticSystemImpl(AgenticSystem): class MetaReferenceAgenticSystemImpl(AgenticSystem):
def __init__( def __init__(
self, config: AgenticSystemConfig, inference_api: Inference, safety_api: Safety self,
config: MetaReferenceImplConfig,
inference_api: Inference,
memory_api: Memory,
safety_api: Safety,
): ):
self.config = config self.config = config
self.inference_api = inference_api self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api self.safety_api = safety_api
async def initialize(self) -> None: async def initialize(self) -> None:
@ -71,69 +51,61 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
async def create_agentic_system( async def create_agentic_system(
self, self,
request: AgenticSystemCreateRequest, agent_config: AgentConfig,
) -> AgenticSystemCreateResponse: ) -> AgenticSystemCreateResponse:
system_id = str(uuid.uuid4()) agent_id = str(uuid.uuid4())
builtin_tools = [] builtin_tools = []
custom_tool_definitions = [] for tool_defn in agent_config.tools:
cfg = request.instance_config if isinstance(tool_defn, WolframAlphaToolDefinition):
for dfn in cfg.available_tools: key = self.config.wolfram_api_key
if isinstance(dfn.tool_name, BuiltinTool): if not key:
if dfn.tool_name == BuiltinTool.wolfram_alpha: raise ValueError("Wolfram API key not defined in config")
key = self.config.wolfram_api_key tool = WolframAlphaTool(key)
if not key: elif isinstance(tool_defn, BraveSearchToolDefinition):
raise ValueError("Wolfram API key not defined in config") key = self.config.brave_search_api_key
tool = WolframAlphaTool(key) if not key:
elif dfn.tool_name == BuiltinTool.brave_search: raise ValueError("Brave API key not defined in config")
key = self.config.brave_search_api_key tool = BraveSearchTool(key)
if not key: elif isinstance(tool_defn, CodeInterpreterToolDefinition):
raise ValueError("Brave API key not defined in config") tool = CodeInterpreterTool()
tool = BraveSearchTool(key) elif isinstance(tool_defn, PhotogenToolDefinition):
elif dfn.tool_name == BuiltinTool.code_interpreter: tool = PhotogenTool(
tool = CodeInterpreterTool() dump_dir="/tmp/photogen_dump_" + os.environ["USER"],
elif dfn.tool_name == BuiltinTool.photogen:
tool = PhotogenTool(
dump_dir="/tmp/photogen_dump_" + os.environ["USER"],
)
else:
raise ValueError(f"Unknown builtin tool: {dfn.tool_name}")
builtin_tools.append(
with_safety(
tool, self.safety_api, dfn.input_shields, dfn.output_shields
)
) )
else: else:
custom_tool_definitions.append(dfn) continue
AGENT_INSTANCES_BY_ID[system_id] = AgentInstance( builtin_tools.append(
system_id=system_id, with_safety(
instance_config=request.instance_config, tool,
model=request.model, self.safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent(
agent_config=agent_config,
inference_api=self.inference_api, inference_api=self.inference_api,
builtin_tools=builtin_tools,
custom_tool_definitions=custom_tool_definitions,
safety_api=self.safety_api, safety_api=self.safety_api,
input_shields=cfg.input_shields, memory_api=self.memory_api,
output_shields=cfg.output_shields, builtin_tools=builtin_tools,
prefix_messages=cfg.debug_prefix_messages,
tool_prompt_format=cfg.tool_prompt_format,
) )
return AgenticSystemCreateResponse( return AgenticSystemCreateResponse(
system_id=system_id, agent_id=agent_id,
) )
async def create_agentic_system_session( async def create_agentic_system_session(
self, self,
request: AgenticSystemSessionCreateRequest, agent_id: str,
session_name: str,
) -> AgenticSystemSessionCreateResponse: ) -> AgenticSystemSessionCreateResponse:
system_id = request.system_id assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found" agent = AGENT_INSTANCES_BY_ID[agent_id]
agent = AGENT_INSTANCES_BY_ID[system_id]
session = agent.create_session(request.session_name) session = agent.create_session(session_name)
return AgenticSystemSessionCreateResponse( return AgenticSystemSessionCreateResponse(
session_id=session.session_id, session_id=session.session_id,
) )
@ -142,9 +114,9 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
self, self,
request: AgenticSystemTurnCreateRequest, request: AgenticSystemTurnCreateRequest,
) -> AsyncGenerator: ) -> AsyncGenerator:
system_id = request.system_id agent_id = request.agent_id
assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found" assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[system_id] agent = AGENT_INSTANCES_BY_ID[agent_id]
assert ( assert (
request.session_id in agent.sessions request.session_id in agent.sessions
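
Put together, the new surface is driven roughly as sketched below. The AgentConfig field names are the ones visible in this diff (model, sampling_params, tools, input_shields, output_shields); the `instructions` field and the concrete values are assumptions carried over from the old instance config, so treat this as a shape rather than a recipe.

from llama_models.llama3.api.datatypes import SamplingParams, UserMessage
from llama_toolchain.agentic_system.api import *  # noqa: F403


async def demo_turn(impl):
    # `impl` is an already-initialized MetaReferenceAgenticSystemImpl (or any
    # other AgenticSystem implementation).
    agent_config = AgentConfig(
        model="Meta-Llama3.1-8B-Instruct",
        instructions="You are a helpful assistant",  # assumed field
        sampling_params=SamplingParams(),
        tools=[],
        input_shields=[],
        output_shields=[],
    )
    create_resp = await impl.create_agentic_system(agent_config)
    session_resp = await impl.create_agentic_system_session(
        agent_id=create_resp.agent_id,
        session_name="demo-session",
    )
    request = AgenticSystemTurnCreateRequest(
        agent_id=create_resp.agent_id,
        session_id=session_resp.session_id,
        messages=[UserMessage(content="Hello!")],
        stream=True,
    )
    async for chunk in impl.create_agentic_system_turn(request):
        print(chunk)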


@ -9,6 +9,6 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
class AgenticSystemConfig(BaseModel): class MetaReferenceImplConfig(BaseModel):
brave_search_api_key: Optional[str] = None brave_search_api_key: Optional[str] = None
wolfram_api_key: Optional[str] = None wolfram_api_key: Optional[str] = None


@ -9,12 +9,13 @@ from typing import List
from llama_models.llama3.api.datatypes import Message, Role, UserMessage from llama_models.llama3.api.datatypes import Message, Role, UserMessage
from termcolor import cprint from termcolor import cprint
from llama_toolchain.safety.api.datatypes import ( from llama_toolchain.safety.api import (
OnViolationAction, OnViolationAction,
RunShieldRequest,
Safety,
ShieldDefinition, ShieldDefinition,
ShieldResponse, ShieldResponse,
) )
from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
class SafetyException(Exception): # noqa: N818 class SafetyException(Exception): # noqa: N818


@ -1,180 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import textwrap
from datetime import datetime
from typing import List
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
from llama_toolchain.inference.api import (
BuiltinTool,
Message,
SystemMessage,
ToolDefinition,
UserMessage,
)
from .tools.builtin import SingleMessageBuiltinTool
def get_agentic_prefix_messages(
builtin_tools: List[SingleMessageBuiltinTool],
custom_tools: List[ToolDefinition],
tool_prompt_format: ToolPromptFormat,
) -> List[Message]:
messages = []
content = ""
if builtin_tools:
content += "Environment: ipython\n"
tool_str = ", ".join(
[
t.get_name()
for t in builtin_tools
if t.get_name() != BuiltinTool.code_interpreter.value
]
)
if tool_str:
content += f"Tools: {tool_str}"
current_date = datetime.now()
formatted_date = current_date.strftime("%d %B %Y")
date_str = f"""
Cutting Knowledge Date: December 2023
Today Date: {formatted_date}\n"""
content += date_str
messages.append(SystemMessage(content=content))
if custom_tools:
if tool_prompt_format == ToolPromptFormat.function_tag:
text = prompt_for_function_tag(custom_tools)
messages.append(UserMessage(content=text))
elif tool_prompt_format == ToolPromptFormat.json:
text = prompt_for_json(custom_tools)
messages.append(UserMessage(content=text))
else:
raise NotImplementedError(
f"Tool prompt format {tool_prompt_format} is not supported"
)
else:
messages.append(SystemMessage(content=content))
return messages
def prompt_for_json(custom_tools: List[ToolDefinition]) -> str:
tool_defs = "\n".join(
translate_custom_tool_definition_to_json(t) for t in custom_tools
)
content = textwrap.dedent(
"""
Answer the user's question by making use of the following functions if needed.
If none of the function can be used, please say so.
Here is a list of functions in JSON format:
{tool_defs}
Return function calls in JSON format.
"""
)
content = content.lstrip("\n").format(tool_defs=tool_defs)
return content
def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str:
custom_tool_params = ""
for t in custom_tools:
custom_tool_params += get_instruction_string(t) + "\n"
custom_tool_params += get_parameters_string(t) + "\n\n"
content = f"""
You have access to the following functions:
{custom_tool_params}
Think very carefully before calling functions.
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{{"example_name": "example_value"}}</function>
Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
"""
return content
def get_instruction_string(custom_tool_definition) -> str:
return f"Use the function '{custom_tool_definition.tool_name}' to '{custom_tool_definition.description}'"
def get_parameters_string(custom_tool_definition) -> str:
return json.dumps(
{
"name": custom_tool_definition.tool_name,
"description": custom_tool_definition.description,
"parameters": {
name: definition.__dict__
for name, definition in custom_tool_definition.parameters.items()
},
}
)
def translate_custom_tool_definition_to_json(tool_def):
"""Translates ToolDefinition to json as expected by model
eg. output for a function
{
"type": "function",
"function": {
"name": "conv_int",
"description": "Convert serialized fract24 integer into int value.",
"parameters": {
"type": "object",
"properties": [
{
"data": {
"type": "object",
"description": ""
}
}
],
"required": ["data"]
}
}
}
"""
assert isinstance(tool_def.tool_name, str)
func_def = {"type": "function", "function": {}}
func_def["function"]["name"] = tool_def.tool_name
func_def["function"]["description"] = tool_def.description or ""
if tool_def.parameters:
required = []
properties = []
for p_name, p_def in tool_def.parameters.items():
properties.append(
{
p_name: {
# TODO: see if this should not always be object
"type": "object",
"description": p_def.description or "",
}
}
)
if p_def.required:
required.append(p_name)
func_def["function"]["parameters"] = {
"type": "object",
"properties": properties,
"required": required,
}
else:
func_def["function"]["parameters"] = {}
return json.dumps(func_def, indent=4)


@ -6,7 +6,7 @@
from typing import List from typing import List
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_agentic_system_providers() -> List[ProviderSpec]: def available_agentic_system_providers() -> List[ProviderSpec]:
@ -16,15 +16,19 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
provider_id="meta-reference", provider_id="meta-reference",
pip_packages=[ pip_packages=[
"codeshield", "codeshield",
"matplotlib",
"pillow", "pillow",
"pandas",
"scikit-learn",
"torch", "torch",
"transformers", "transformers",
], ],
module="llama_toolchain.agentic_system.meta_reference", module="llama_toolchain.agentic_system.meta_reference",
config_class="llama_toolchain.agentic_system.meta_reference.AgenticSystemConfig", config_class="llama_toolchain.agentic_system.meta_reference.MetaReferenceImplConfig",
api_dependencies=[ api_dependencies=[
Api.inference, Api.inference,
Api.safety, Api.safety,
Api.memory,
], ],
), ),
] ]
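
For comparison, registering an additional agentic_system provider would follow the same shape. Everything below marked hypothetical is an assumption, as is the `api=` keyword, which falls outside the lines shown in this hunk.

from llama_toolchain.core.datatypes import Api, InlineProviderSpec

hypothetical_spec = InlineProviderSpec(
    api=Api.agentic_system,  # assumed keyword; not visible in the hunk above
    provider_id="my-agentic-system",  # hypothetical
    pip_packages=["torch"],
    module="my_package.agentic_system",  # hypothetical
    config_class="my_package.agentic_system.MyImplConfig",  # hypothetical
    api_dependencies=[Api.inference, Api.safety, Api.memory],
)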


@ -1,83 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, List
from llama_models.llama3.api.datatypes import StopReason, ToolResponseMessage
from llama_toolchain.agentic_system.api import (
AgenticSystem,
AgenticSystemTurnCreateRequest,
AgenticSystemTurnResponseEventType as EventType,
)
from llama_toolchain.inference.api import Message
async def execute_with_custom_tools(
system: AgenticSystem,
system_id: str,
session_id: str,
messages: List[Message],
custom_tools: List[Any],
max_iters: int = 5,
stream: bool = True,
) -> AsyncGenerator:
# first create a session, or do you keep a persistent session?
tools_dict = {t.get_name(): t for t in custom_tools}
current_messages = messages.copy()
n_iter = 0
while n_iter < max_iters:
n_iter += 1
request = AgenticSystemTurnCreateRequest(
system_id=system_id,
session_id=session_id,
messages=current_messages,
stream=stream,
)
turn = None
async for chunk in system.create_agentic_system_turn(request):
if chunk.event.payload.event_type != EventType.turn_complete.value:
yield chunk
else:
turn = chunk.event.payload.turn
message = turn.output_message
if len(message.tool_calls) == 0:
yield chunk
return
if message.stop_reason == StopReason.out_of_tokens:
yield chunk
return
tool_call = message.tool_calls[0]
if tool_call.tool_name not in tools_dict:
m = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
)
next_message = m
else:
tool = tools_dict[tool_call.tool_name]
result_messages = await execute_custom_tool(tool, message)
next_message = result_messages[0]
yield next_message
current_messages = [next_message]
async def execute_custom_tool(tool: Any, message: Message) -> List[Message]:
result_messages = await tool.run([message])
assert (
len(result_messages) == 1
), f"Expected single message, got {len(result_messages)}"
return result_messages


@ -1,122 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import Any, List, Optional
from llama_models.llama3.api.datatypes import BuiltinTool, Message, SamplingParams
from llama_toolchain.agentic_system.api import (
AgenticSystemCreateRequest,
AgenticSystemInstanceConfig,
AgenticSystemSessionCreateRequest,
AgenticSystemToolDefinition,
)
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
from llama_toolchain.agentic_system.client import AgenticSystemClient
from llama_toolchain.agentic_system.tools.custom.execute import (
execute_with_custom_tools,
)
from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
# TODO: this should move back to the llama-agentic-system repo
class AgenticSystemClientWrapper:
def __init__(self, api, system_id, custom_tools):
self.api = api
self.system_id = system_id
self.custom_tools = custom_tools
self.session_id = None
async def create_session(self, name: str = None):
if name is None:
name = f"Session-{uuid.uuid4()}"
response = await self.api.create_agentic_system_session(
AgenticSystemSessionCreateRequest(
system_id=self.system_id,
session_name=name,
)
)
self.session_id = response.session_id
return self.session_id
async def run(self, messages: List[Message], stream: bool = True):
async for chunk in execute_with_custom_tools(
self.api,
self.system_id,
self.session_id,
messages,
self.custom_tools,
stream=stream,
):
yield chunk
async def get_agent_system_instance(
host: str,
port: int,
custom_tools: Optional[List[Any]] = None,
disable_safety: bool = False,
model: str = "Meta-Llama3.1-8B-Instruct",
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> AgenticSystemClientWrapper:
custom_tools = custom_tools or []
api = AgenticSystemClient(base_url=f"http://{host}:{port}")
tool_definitions = [
AgenticSystemToolDefinition(
tool_name=BuiltinTool.brave_search,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.wolfram_alpha,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.photogen,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.code_interpreter,
),
] + [t.get_tool_definition() for t in custom_tools]
if not disable_safety:
for t in tool_definitions:
t.input_shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
t.output_shields = [
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
ShieldDefinition(shield_type=BuiltinShield.injection_shield),
]
create_request = AgenticSystemCreateRequest(
model=model,
instance_config=AgenticSystemInstanceConfig(
instructions="You are a helpful assistant",
available_tools=tool_definitions,
input_shields=(
[]
if disable_safety
else [
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
ShieldDefinition(shield_type=BuiltinShield.jailbreak_shield),
]
),
output_shields=(
[]
if disable_safety
else [
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
]
),
sampling_params=SamplingParams(),
tool_prompt_format=tool_prompt_format,
),
)
create_response = await api.create_agentic_system(create_request)
return AgenticSystemClientWrapper(api, create_response.system_id, custom_tools)


@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .distribution import DistributionParser # noqa from .api import * # noqa: F401 F403


@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
@json_schema_type
class BatchCompletionRequest(BaseModel):
model: str
content_batch: List[InterleavedTextMedia]
sampling_params: Optional[SamplingParams] = SamplingParams()
logprobs: Optional[LogProbConfig] = None
@json_schema_type
class BatchCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
@json_schema_type
class BatchChatCompletionRequest(BaseModel):
model: str
messages_batch: List[List[Message]]
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
logprobs: Optional[LogProbConfig] = None
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
class BatchInference(Protocol):
@webmethod(route="/batch_inference/completion")
async def batch_completion(
self,
request: BatchCompletionRequest,
) -> BatchCompletionResponse: ...
@webmethod(route="/batch_inference/chat_completion")
async def batch_chat_completion(
self,
request: BatchChatCompletionRequest,
) -> BatchChatCompletionResponse: ...
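
A hedged sketch of a caller against this protocol, meant to sit alongside the definitions above; `client` stands for any BatchInference implementation (for example a remote adapter) and the model name is only an example.

from llama_models.llama3.api.datatypes import UserMessage


async def summarize_batch(client: BatchInference, texts: List[str]):
    request = BatchChatCompletionRequest(
        model="Meta-Llama3.1-8B-Instruct",
        messages_batch=[
            [UserMessage(content=f"Summarize in one sentence: {t}")] for t in texts
        ],
    )
    response = await client.batch_chat_completion(request)
    return [m.content for m in response.completion_message_batch]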


@ -1,106 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
import shlex
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from termcolor import cprint
class DistributionConfigure(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"configure",
prog="llama distribution configure",
description="configure a llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_configure_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
type=str,
help="Name of the distribution to configure",
required=True,
)
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.datatypes import DistributionConfig
from llama_toolchain.distribution.registry import resolve_distribution_spec
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama distribution install` first"
)
return
# we need to find the spec from the name
with open(config_file, "r") as f:
config = DistributionConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config.spec)
if dist is None:
raise ValueError(f"Could not find any registered spec `{config.spec}`")
configure_llama_distribution(dist, config)
def configure_llama_distribution(dist: "Distribution", config: "DistributionConfig"):
from llama_toolchain.common.exec import run_command
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.dynamic import instantiate_class_type
python_exe = run_command(shlex.split("which python"))
# simple check
conda_env = config.conda_env
if conda_env not in python_exe:
raise ValueError(
f"Please re-run configure by activating the `{conda_env}` conda environment"
)
if config.providers:
cprint(
f"Configuration already exists for {config.name}. Will overwrite...",
"yellow",
attrs=["bold"],
)
for api, provider_spec in dist.provider_specs.items():
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
provider_config = prompt_for_config(
config_type,
(
config_type(**config.providers[api.value])
if api.value in config.providers
else None
),
)
print("")
config.providers[api.value] = {
"provider_id": provider_spec.provider_id,
**provider_config.dict(),
}
config_path = DISTRIBS_BASE_DIR / config.name / "config.yaml"
with open(config_path, "w") as fp:
dist_config = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(dist_config, sort_keys=False))
print(f"YAML configuration has been written to {config_path}")


@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_toolchain.cli.subcommand import Subcommand
class DistributionCreate(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"create",
prog="llama distribution create",
description="create a Llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_create_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
type=str,
help="Name of the distribution to create",
required=True,
)
# for each Api the user wants to support, we should
# get the list of available providers, ask which one the user
# wants to pick and then ask for their configuration.
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.registry import resolve_distribution_spec
dist = resolve_distribution_spec(args.name)
if dist is not None:
self.parser.error(f"Distribution with name {args.name} already exists")
return
raise NotImplementedError()


@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_toolchain.cli.subcommand import Subcommand
from .configure import DistributionConfigure
from .create import DistributionCreate
from .install import DistributionInstall
from .list import DistributionList
from .start import DistributionStart
class DistributionParser(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"distribution",
prog="llama distribution",
description="Operate on llama stack distributions",
)
subparsers = self.parser.add_subparsers(title="distribution_subcommands")
# Add sub-commands
DistributionList.create(subparsers)
DistributionInstall.create(subparsers)
DistributionCreate.create(subparsers)
DistributionConfigure.create(subparsers)
DistributionStart.create(subparsers)


@ -1,111 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from termcolor import cprint
class DistributionInstall(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"install",
prog="llama distribution install",
description="Install a llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_install_cmd)
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distribution_specs
self.parser.add_argument(
"--spec",
type=str,
help="Distribution spec to install (try local-ollama)",
required=True,
choices=[d.spec_id for d in available_distribution_specs()],
)
self.parser.add_argument(
"--name",
type=str,
help="What should the installation be called locally?",
required=True,
)
self.parser.add_argument(
"--conda-env",
type=str,
help="conda env in which this distribution will run (default = distribution name)",
)
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.datatypes import DistributionConfig
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import resolve_distribution_spec
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/install_distribution.sh",
)
dist = resolve_distribution_spec(args.spec)
if dist is None:
self.parser.error(f"Could not find distribution {args.spec}")
return
distrib_dir = DISTRIBS_BASE_DIR / args.name
os.makedirs(distrib_dir, exist_ok=True)
deps = distribution_dependencies(dist)
if not args.conda_env:
print(f"Using {args.name} as the Conda environment for this distribution")
conda_env = args.conda_env or args.name
config_file = distrib_dir / "config.yaml"
if config_file.exists():
c = DistributionConfig(**yaml.safe_load(config_file.read_text()))
if c.spec != dist.spec_id:
self.parser.error(
f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`"
)
return
if c.conda_env != conda_env:
self.parser.error(
f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`"
)
return
else:
with open(config_file, "w") as f:
c = DistributionConfig(
spec=dist.spec_id,
name=args.name,
conda_env=conda_env,
)
f.write(yaml.dump(c.dict(), sort_keys=False))
return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)])
assert return_code == 0, cprint(
f"Failed to install distribution {dist.spec_id}", color="red"
)
cprint(
f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
color="green",
)


@ -1,81 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
class DistributionStart(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"start",
prog="llama distribution start",
description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_start_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
type=str,
help="Name of the distribution to start",
required=True,
)
self.parser.add_argument(
"--port",
type=int,
help="Port to run the server on. Defaults to 5000",
default=5000,
)
self.parser.add_argument(
"--disable-ipv6",
action="store_true",
help="Disable IPv6 support",
default=False,
)
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.registry import resolve_distribution_spec
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama distribution install` first"
)
return
# we need to find the spec from the name
with open(config_file, "r") as f:
config = yaml.safe_load(f)
dist = resolve_distribution_spec(config["spec"])
if dist is None:
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
conda_env = config["conda_env"]
if not conda_env:
raise ValueError(
f"Could not find Conda environment for distribution `{args.name}`"
)
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_distribution.sh",
)
args = [script, conda_env, config_file, "--port", str(args.port)] + (
["--disable-ipv6"] if args.disable_ipv6 else []
)
run_with_pty(args)


@ -6,9 +6,9 @@
import argparse import argparse
from .distribution import DistributionParser
from .download import Download from .download import Download
from .model import ModelParser from .model import ModelParser
from .stack import StackParser
class LlamaCLIParser: class LlamaCLIParser:
@ -29,7 +29,7 @@ class LlamaCLIParser:
# Add sub-commands # Add sub-commands
Download.create(subparsers) Download.create(subparsers)
ModelParser.create(subparsers) ModelParser.create(subparsers)
DistributionParser.create(subparsers) StackParser.create(subparsers)
# Import sub-commands from agentic_system if they exist # Import sub-commands from agentic_system if they exist
try: try:


@ -32,6 +32,16 @@ class ModelTemplate(Subcommand):
self._add_arguments() self._add_arguments()
self.parser.set_defaults(func=self._run_model_template_cmd) self.parser.set_defaults(func=self._run_model_template_cmd)
def _prompt_type(self, value):
from llama_models.llama3.api.datatypes import ToolPromptFormat
try:
return ToolPromptFormat(value.lower())
except ValueError:
raise argparse.ArgumentTypeError(
f"{value} is not a valid ToolPromptFormat. Choose from {', '.join(t.value for t in ToolPromptFormat)}"
) from None
def _add_arguments(self): def _add_arguments(self):
self.parser.add_argument( self.parser.add_argument(
"-m", "-m",
@ -46,6 +56,18 @@ class ModelTemplate(Subcommand):
help="Usecase template name (system_message, user_message, assistant_message, tool_message)...", help="Usecase template name (system_message, user_message, assistant_message, tool_message)...",
required=False, required=False,
) )
self.parser.add_argument(
"--format",
type=str,
help="ToolPromptFormat (json or function_tag). This flag is used to print the template in a specific formats.",
required=False,
default="json",
)
self.parser.add_argument(
"--raw",
action="store_true",
help="If set to true, don't pretty-print into a table. Useful to copy-paste.",
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None: def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
from llama_models.llama3.api.interface import ( from llama_models.llama3.api.interface import (
@ -56,22 +78,32 @@ class ModelTemplate(Subcommand):
from llama_toolchain.cli.table import print_table from llama_toolchain.cli.table import print_table
if args.name: if args.name:
template, tokens_info = render_jinja_template(args.name) tool_prompt_format = self._prompt_type(args.format)
template, tokens_info = render_jinja_template(args.name, tool_prompt_format)
rendered = "" rendered = ""
for tok, is_special in tokens_info: for tok, is_special in tokens_info:
if is_special: if is_special:
rendered += colored(tok, "yellow", attrs=["bold"]) rendered += colored(tok, "yellow", attrs=["bold"])
else: else:
rendered += tok rendered += tok
rendered += "\n"
print_table( if not args.raw:
[ rendered = rendered.replace("\n", "\n")
("Name", colored(template.template_name, "white", attrs=["bold"])), print_table(
("Template", rendered), [
("Notes", template.notes), (
], "Name",
separate_rows=True, colored(template.template_name, "white", attrs=["bold"]),
) ),
("Template", rendered),
("Notes", template.notes),
],
separate_rows=True,
)
else:
print("Template: ", template.template_name)
print("=" * 40)
print(rendered)
else: else:
templates = list_jinja_templates() templates = list_jinja_templates()
headers = ["Role", "Template Name"] headers = ["Role", "Template Name"]


@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .config import OllamaImplConfig # noqa from .stack import StackParser # noqa
from .ollama import get_provider_impl # noqa


@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.core.datatypes import * # noqa: F403
def parse_api_provider_tuples(
tuples: str, parser: argparse.ArgumentParser
) -> Dict[str, ProviderSpec]:
from llama_toolchain.core.distribution import api_providers
all_providers = api_providers()
deps = {}
for dep in tuples.split(","):
dep = dep.strip()
if not dep:
continue
api_str, provider = dep.split("=")
api = Api(api_str)
provider = provider.strip()
if provider not in all_providers[api]:
parser.error(f"Provider `{provider}` is not available for API `{api}`")
return
deps[api] = all_providers[api][provider]
return deps
class StackBuild(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"build",
prog="llama stack build",
description="Build a Llama stack container",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_build_command)
def _add_arguments(self):
from llama_toolchain.core.distribution_registry import available_distribution_specs
from llama_toolchain.core.package import (
BuildType,
)
allowed_ids = [d.distribution_id for d in available_distribution_specs()]
self.parser.add_argument(
"distribution",
type=str,
help="Distribution to build (either \"adhoc\" OR one of: {})".format(allowed_ids),
)
self.parser.add_argument(
"api_providers",
nargs='?',
help="Comma separated list of (api=provider) tuples",
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the build target (image, conda env)",
required=True,
)
self.parser.add_argument(
"--type",
type=str,
default="conda_env",
choices=[v.value for v in BuildType],
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
from llama_toolchain.core.distribution_registry import resolve_distribution_spec
from llama_toolchain.core.package import (
ApiInput,
BuildType,
build_package,
)
api_inputs = []
if args.distribution == "adhoc":
if not args.api_providers:
self.parser.error("You must specify API providers with (api=provider,...) for building an adhoc distribution")
return
parsed = parse_api_provider_tuples(args.api_providers, self.parser)
for api, provider_spec in parsed.items():
for dep in provider_spec.api_dependencies:
if dep not in parsed:
self.parser.error(f"API {api} needs dependency {dep} provided also")
return
api_inputs.append(
ApiInput(
api=api,
provider=provider_spec.provider_id,
)
)
docker_image = None
else:
if args.api_providers:
self.parser.error("You cannot specify API providers for pre-registered distributions")
return
dist = resolve_distribution_spec(args.distribution)
if dist is None:
self.parser.error(f"Could not find distribution {args.distribution}")
return
for api, provider_id in dist.providers.items():
api_inputs.append(
ApiInput(
api=api,
provider=provider_id,
)
)
docker_image = dist.docker_image
build_package(
api_inputs,
build_type=BuildType(args.type),
name=args.name,
distribution_id=args.distribution,
docker_image=docker_image,
)
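
To make the adhoc path concrete: an invocation along the lines of `llama stack build adhoc inference=meta-reference,safety=meta-reference --name my-stack` is what the tuple syntax is for (the provider names here are placeholders). Below is a simplified stand-in for parse_api_provider_tuples that splits the string the same way but validates against a stubbed registry instead of llama_toolchain.core.distribution.api_providers():

def parse_pairs(spec: str) -> dict:
    # Stubbed registry of known providers per API, for illustration only.
    known = {"inference": {"meta-reference"}, "safety": {"meta-reference"}}
    out = {}
    for item in spec.split(","):
        item = item.strip()
        if not item:
            continue
        api, provider = item.split("=")
        provider = provider.strip()
        if provider not in known.get(api, set()):
            raise ValueError(f"Provider `{provider}` is not available for API `{api}`")
        out[api] = provider
    return out


print(parse_pairs("inference=meta-reference, safety=meta-reference"))
# -> {'inference': 'meta-reference', 'safety': 'meta-reference'}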


@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
from pathlib import Path
import yaml
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.core.datatypes import * # noqa: F403
class StackConfigure(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"configure",
prog="llama stack configure",
description="configure a llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_configure_cmd)
def _add_arguments(self):
from llama_toolchain.core.distribution_registry import (
available_distribution_specs,
)
from llama_toolchain.core.package import BuildType
allowed_ids = [d.distribution_id for d in available_distribution_specs()]
self.parser.add_argument(
"distribution",
type=str,
choices=allowed_ids,
help="Distribution (one of: {})".format(allowed_ids),
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the build",
required=True,
)
self.parser.add_argument(
"--type",
type=str,
default="conda_env",
choices=[v.value for v in BuildType],
)
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.core.package import BuildType
build_type = BuildType(args.type)
name = args.name
config_file = (
BUILDS_BASE_DIR
/ args.distribution
/ build_type.descriptor()
/ f"{name}.yaml"
)
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama stack build` first"
)
return
configure_llama_distribution(config_file)
def configure_llama_distribution(config_file: Path) -> None:
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.core.configure import configure_api_providers
from llama_toolchain.core.distribution_registry import resolve_distribution_spec
with open(config_file, "r") as f:
config = PackageConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config.distribution_id)
if dist is None:
raise ValueError(
f"Could not find any registered distribution `{config.distribution_id}`"
)
if config.providers:
cprint(
f"Configuration already exists for {config.distribution_id}. Will overwrite...",
"yellow",
attrs=["bold"],
)
config.providers = configure_api_providers(config.providers)
with open(config_file, "w") as fp:
to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(to_write, sort_keys=False))
print(f"YAML configuration has been written to {config_file}")
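A hypothetical follow-up to the build above; the config path is derived from BUILDS_BASE_DIR (~/.llama/builds), the distribution id, and the build type descriptor:

# re-run the interactive provider configuration for an existing conda build
llama stack configure local --name 8b-instruct --type conda_env
# reads and rewrites ~/.llama/builds/local/conda/8b-instruct.yaml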


@ -10,13 +10,13 @@ import json
from llama_toolchain.cli.subcommand import Subcommand
-class DistributionList(Subcommand):
+class StackList(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
-"list",
-prog="llama distribution list",
-description="Show available llama stack distributions",
+"list-distributions",
+prog="llama stack list-distributions",
+description="Show available Llama Stack Distributions",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
@ -27,21 +27,23 @@ class DistributionList(Subcommand):
def _run_distribution_list_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.cli.table import print_table
-from llama_toolchain.distribution.registry import available_distribution_specs
+from llama_toolchain.core.distribution_registry import (
+available_distribution_specs,
+)
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
-"Spec ID",
-"ProviderSpecs",
+"Distribution ID",
+"Providers",
"Description",
]
rows = []
for spec in available_distribution_specs():
-providers = {k.value: v.provider_id for k, v in spec.provider_specs.items()}
+providers = {k.value: v for k, v in spec.providers.items()}
rows.append(
[
-spec.spec_id,
+spec.distribution_id,
json.dumps(providers, indent=2),
spec.description,
]
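With the rename in place, listing distributions looks roughly like this (rows come from available_distribution_specs(); the output shape follows the headers above):

llama stack list-distributions
# prints a table with columns Distribution ID | Providers | Description,
# e.g. local, remote, local-ollama, local-plus-fireworks-inference, ...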


@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from pathlib import Path
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.core.datatypes import * # noqa: F403
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
class StackRun(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"run",
prog="llama stack run",
description="""start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_run_cmd)
def _add_arguments(self):
from llama_toolchain.core.package import BuildType
self.parser.add_argument(
"distribution",
type=str,
help="Distribution whose build you want to start",
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the build you want to start",
required=True,
)
self.parser.add_argument(
"--type",
type=str,
default="conda_env",
choices=[v.value for v in BuildType],
)
self.parser.add_argument(
"--port",
type=int,
help="Port to run the server on. Defaults to 5000",
default=5000,
)
self.parser.add_argument(
"--disable-ipv6",
action="store_true",
help="Disable IPv6 support",
default=False,
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.core.package import BuildType
build_type = BuildType(args.type)
build_dir = BUILDS_BASE_DIR / args.distribution / build_type.descriptor()
path = build_dir / f"{args.name}.yaml"
config_file = Path(path)
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist. Did you run `llama stack build`?"
)
return
with open(config_file, "r") as f:
config = PackageConfig(**yaml.safe_load(f))
if not config.distribution_id:
raise ValueError("Build config appears to be corrupt.")
if config.docker_image:
script = pkg_resources.resource_filename(
"llama_toolchain",
"core/start_container.sh",
)
run_args = [script, config.docker_image]
else:
script = pkg_resources.resource_filename(
"llama_toolchain",
"core/start_conda_env.sh",
)
run_args = [
script,
config.conda_env,
]
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:
run_args.append("--disable-ipv6")
run_with_pty(run_args)
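A hypothetical run of the build configured above; --type, --port and --disable-ipv6 map directly onto the arguments defined here:

# start the server for a conda-env build on the default port
llama stack run local --name 8b-instruct --type conda_env --port 5000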


@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_toolchain.cli.subcommand import Subcommand
from .build import StackBuild
from .configure import StackConfigure
from .list import StackList
from .run import StackRun
class StackParser(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"stack",
prog="llama stack",
description="Operations for the Llama Stack / Distributions",
)
subparsers = self.parser.add_subparsers(title="stack_subcommands")
# Add sub-commands
StackBuild.create(subparsers)
StackConfigure.create(subparsers)
StackList.create(subparsers)
StackRun.create(subparsers)


@ -13,3 +13,5 @@ LLAMA_STACK_CONFIG_DIR = Path(os.path.expanduser("~/.llama/"))
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
+BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"


@ -5,6 +5,7 @@
# the root directory of this source tree.
import json
+from datetime import datetime
from enum import Enum
@ -12,4 +13,6 @@ class EnumEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Enum):
return obj.value
+elif isinstance(obj, datetime):
+return obj.isoformat()
return super().default(obj)


@ -10,20 +10,36 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
+if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
+echo "Using llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR"
+fi
+if [ -n "$LLAMA_MODELS_DIR" ]; then
+echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
+fi
set -euo pipefail
+if [ "$#" -ne 3 ]; then
+echo "Usage: $0 <distribution_id> <build_name> <pip_dependencies>" >&2
+echo "Example: $0 <distribution_id> mybuild 'numpy pandas scipy'" >&2
+exit 1
+fi
+distribution_id="$1"
+build_name="$2"
+env_name="llamastack-$build_name"
+pip_dependencies="$3"
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
-error_handler() {
-echo "Error occurred in script at line: ${1}" >&2
-exit 1
-}
-# Set up the error trap
-trap 'error_handler ${LINENO}' ERR
+# this is set if we actually create a new conda in which case we need to clean up
+ENVNAME=""
+SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
+source "$SCRIPT_DIR/common.sh"
ensure_conda_env_python310() {
local env_name="$1"
@ -32,26 +48,29 @@ ensure_conda_env_python310() {
# Check if conda command is available
if ! command -v conda &>/dev/null; then
-echo -e "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
+printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1
fi
# Check if the environment exists
if conda env list | grep -q "^${env_name} "; then
-echo "Conda environment '${env_name}' exists. Checking Python version..."
+printf "Conda environment '${env_name}' exists. Checking Python version...\n"
# Check Python version in the environment
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
if [ "$current_version" = "$python_version" ]; then
-echo "Environment '${env_name}' already has Python ${python_version}. No action needed."
+printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
else
-echo "Updating environment '${env_name}' to Python ${python_version}..."
+printf "Updating environment '${env_name}' to Python ${python_version}...\n"
conda install -n "${env_name}" python="${python_version}" -y
fi
else
-echo "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}..."
+printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
conda create -n "${env_name}" python="${python_version}" -y
+ENVNAME="${env_name}"
+# setup_cleanup_handlers
fi
eval "$(conda shell.bash hook)"
@ -65,48 +84,45 @@ ensure_conda_env_python310() {
# Re-installing llama-toolchain in the new conda environment
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
-echo -e "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2
+printf "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}\n" >&2
exit 1
fi
-echo "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR"
-pip install -e "$LLAMA_TOOLCHAIN_DIR"
+printf "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR\n"
+pip install --no-cache-dir -e "$LLAMA_TOOLCHAIN_DIR"
else
-pip install llama-toolchain
+pip install --no-cache-dir llama-toolchain
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
-echo -e "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
+printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
exit 1
fi
-echo "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR"
+printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
pip uninstall -y llama-models
-pip install -e "$LLAMA_MODELS_DIR"
+pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
fi
# Install pip dependencies
if [ -n "$pip_dependencies" ]; then
-echo "Installing pip dependencies: $pip_dependencies"
+printf "Installing pip dependencies: $pip_dependencies\n"
pip install $pip_dependencies
fi
fi
}
-if [ "$#" -ne 3 ]; then
-echo "Usage: $0 <environment_name> <distribution_name> <pip_dependencies>" >&2
-echo "Example: $0 my_env local-llama-8b 'numpy pandas scipy'" >&2
-exit 1
-fi
-env_name="$1"
-distribution_name="$2"
-pip_dependencies="$3"
ensure_conda_env_python310 "$env_name" "$pip_dependencies"
-echo -e "${GREEN}Successfully setup distribution environment. Configuring...${NC}"
-which python3
-python3 -m llama_toolchain.cli.llama distribution configure --name "$distribution_name"
+printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}\n"
+if [ "$distribution_id" = "adhoc" ]; then
+subcommand="api"
+target=""
+else
+subcommand="stack"
+target="$distribution_id"
+fi
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --name "$build_name" --type conda_env
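The script is normally driven by `llama stack build`, but it can also be exercised by hand; a sketch with illustrative argument values (LLAMA_TOOLCHAIN_DIR is optional and only needed for an editable install):

# <distribution_id> <build_name> <pip_dependencies>
LLAMA_TOOLCHAIN_DIR=~/local/llama-toolchain \
./build_conda_env.sh local 8b-instruct 'fastapi uvicorn'
# creates or updates the conda env "llamastack-8b-instruct" and then chains into
# `llama stack configure local --name 8b-instruct --type conda_env`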


@ -0,0 +1,120 @@
#!/bin/bash
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ "$#" -ne 4 ]; then
echo "Usage: $0 <distribution_id> <build_name> <docker_base> <pip_dependencies>" >&2
echo "Example: $0 distribution_id my-fastapi-app python:3.9-slim 'fastapi uvicorn'" >&2
exit 1
fi
distribution_id=$1
build_name="$2"
image_name="llamastack-$build_name"
docker_base=$3
pip_dependencies=$4
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
set -euo pipefail
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))
TEMP_DIR=$(mktemp -d)
add_to_docker() {
local input
output_file="$TEMP_DIR/Dockerfile"
if [ -t 0 ]; then
printf '%s\n' "$1" >>"$output_file"
else
# If stdin is not a terminal, read from it (heredoc)
cat >>"$output_file"
fi
}
add_to_docker <<EOF
FROM $docker_base
WORKDIR /app
RUN apt-get update && apt-get install -y \
iputils-ping net-tools iproute2 dnsutils telnet \
curl wget telnet \
procps psmisc lsof \
traceroute \
&& rm -rf /var/lib/apt/lists/*
EOF
toolchain_mount="/app/llama-toolchain-source"
models_mount="/app/llama-models-source"
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
echo -e "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2
exit 1
fi
add_to_docker "RUN pip install $toolchain_mount"
else
add_to_docker "RUN pip install llama-toolchain"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
echo -e "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
exit 1
fi
add_to_docker <<EOF
RUN pip uninstall -y llama-models
RUN pip install $models_mount
EOF
fi
if [ -n "$pip_dependencies" ]; then
add_to_docker "RUN pip install $pip_dependencies"
fi
add_to_docker <<EOF
# This would be good in production but for debugging flexibility lets not add it right now
# We need a more solid production ready entrypoint.sh anyway
#
# ENTRYPOINT ["python", "-m", "llama_toolchain.core.server"]
EOF
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
cat $TEMP_DIR/Dockerfile
printf "\n"
mounts=""
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_TOOLCHAIN_DIR):$toolchain_mount"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi
set -x
podman build -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
set +x
printf "${GREEN}Successfully setup Podman image. Configuring build...${NC}\n"
echo "You can run it with: podman run -p 8000:8000 $image_name"
if [ "$distribution_id" = "adhoc" ]; then
subcommand="api"
target=""
else
subcommand="stack"
target="$distribution_id"
fi
$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --name "$build_name" --type container

llama_toolchain/core/common.sh Executable file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
cleanup() {
envname="$1"
set +x
echo "Cleaning up..."
conda deactivate
conda env remove --name $envname -y
}
handle_int() {
if [ -n "$ENVNAME" ]; then
cleanup $ENVNAME
fi
exit 1
}
handle_exit() {
if [ $? -ne 0 ]; then
echo -e "\033[1;31mABORTING.\033[0m"
if [ -n "$ENVNAME" ]; then
cleanup $ENVNAME
fi
fi
}
setup_cleanup_handlers() {
trap handle_int INT
trap handle_exit EXIT
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
eval "$__conda_setup"
conda deactivate
}


@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_toolchain.core.datatypes import * # noqa: F403
from termcolor import cprint
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.core.distribution import api_providers
from llama_toolchain.core.dynamic import instantiate_class_type
def configure_api_providers(existing_configs: Dict[str, Any]) -> Dict[str, Any]:
all_providers = api_providers()
provider_configs = {}
for api_str, stub_config in existing_configs.items():
api = Api(api_str)
providers = all_providers[api]
provider_id = stub_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
)
provider_spec = providers[provider_id]
cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
try:
existing_provider_config = config_type(**stub_config)
except Exception:
existing_provider_config = None
provider_config = prompt_for_config(
config_type,
existing_provider_config,
)
print("")
provider_configs[api_str] = {
"provider_id": provider_id,
**provider_config.dict(),
}
return provider_configs


@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, validator
@json_schema_type
class Api(Enum):
inference = "inference"
safety = "safety"
agentic_system = "agentic_system"
memory = "memory"
@json_schema_type
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
provider_id: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
)
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: Optional[str] = Field(
default=None,
description="Fully-qualified classname of the config for this provider",
)
@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
docker_image: Optional[str] = Field(
default=None,
description="""
The docker image to use for this implementation. If one is provided, pip_packages will be ignored.
If a provider depends on other providers, the dependencies MUST NOT specify a docker image.
""",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
class RemoteProviderConfig(BaseModel):
url: str = Field(..., description="The URL for the provider")
@validator("url")
@classmethod
def validate_url(cls, url: str) -> str:
if not url.startswith("http"):
raise ValueError(f"URL must start with http: {url}")
return url.rstrip("/")
def remote_provider_id(adapter_id: str) -> str:
return f"remote::{adapter_id}"
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
""",
)
@property
def docker_image(self) -> Optional[str]:
return None
@property
def module(self) -> str:
if self.adapter:
return self.adapter.module
return f"llama_toolchain.{self.api.value}.client"
@property
def pip_packages(self) -> List[str]:
if self.adapter:
return self.adapter.pip_packages
return []
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
config_class = (
adapter.config_class
if adapter and adapter.config_class
else "llama_toolchain.core.datatypes.RemoteProviderConfig"
)
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
return RemoteProviderSpec(
api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
)
@json_schema_type
class DistributionSpec(BaseModel):
distribution_id: str
description: str
docker_image: Optional[str] = None
providers: Dict[Api, str] = Field(
default_factory=dict,
description="Provider IDs for each of the APIs provided by this distribution",
)
@json_schema_type
class PackageConfig(BaseModel):
built_at: datetime
package_name: str = Field(
...,
description="""
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
this could be just a hash
""",
)
distribution_id: Optional[str] = None
docker_image: Optional[str] = Field(
default=None,
description="Reference to the docker image if this package refers to a container",
)
conda_env: Optional[str] = Field(
default=None,
description="Reference to the conda environment if this package refers to a conda environment",
)
providers: Dict[str, Any] = Field(
default_factory=dict,
description="""
Provider configurations for each of the APIs provided by this package. This includes configurations for
the dependencies of these providers as well.
""",
)


@ -7,11 +7,13 @@
import inspect
from typing import Dict, List
-from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
+from llama_toolchain.agentic_system.api import AgenticSystem
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
-from llama_toolchain.inference.api.endpoints import Inference
+from llama_toolchain.inference.api import Inference
from llama_toolchain.inference.providers import available_inference_providers
-from llama_toolchain.safety.api.endpoints import Safety
+from llama_toolchain.memory.api import Memory
+from llama_toolchain.memory.providers import available_memory_providers
+from llama_toolchain.safety.api import Safety
from llama_toolchain.safety.providers import available_safety_providers
from .datatypes import (
@ -20,6 +22,7 @@ from .datatypes import (
DistributionSpec,
InlineProviderSpec,
ProviderSpec,
+remote_provider_spec,
)
# These are the dependencies needed by the distribution server.
@ -40,6 +43,10 @@ def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
] + SERVER_DEPENDENCIES
+def stack_apis() -> List[Api]:
+return [Api.inference, Api.safety, Api.agentic_system, Api.memory]
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
@ -47,6 +54,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
Api.inference: Inference,
Api.safety: Safety,
Api.agentic_system: AgenticSystem,
+Api.memory: Memory,
}
for api, protocol in protocols.items():
@ -60,9 +68,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
webmethod = method.__webmethod__
route = webmethod.route
-# use `post` for all methods right now until we fix up the `webmethod` openapi
-# annotation and write our own openapi generator
-endpoints.append(ApiEndpoint(route=route, method="post", name=name))
+if webmethod.method == "GET":
+method = "get"
+elif webmethod.method == "DELETE":
+method = "delete"
+else:
+method = "post"
+endpoints.append(ApiEndpoint(route=route, method=method, name=name))
apis[api] = endpoints
@ -78,8 +90,12 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
a.provider_id: a for a in available_agentic_system_providers()
}
-return {
+ret = {
Api.inference: inference_providers_by_id,
Api.safety: safety_providers_by_id,
Api.agentic_system: agentic_system_providers_by_id,
+Api.memory: {a.provider_id: a for a in available_memory_providers()},
}
+for k, v in ret.items():
+v["remote"] = remote_provider_spec(k)
+return ret


@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from functools import lru_cache
from typing import List, Optional
from .datatypes import * # noqa: F403
@lru_cache()
def available_distribution_specs() -> List[DistributionSpec]:
return [
DistributionSpec(
distribution_id="local",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
providers={
Api.inference: "meta-reference",
Api.memory: "meta-reference-faiss",
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
},
),
DistributionSpec(
distribution_id="remote",
description="Point to remote services for all llama stack APIs",
providers={x: "remote" for x in Api},
),
DistributionSpec(
distribution_id="local-ollama",
description="Like local, but use ollama for running LLM inference",
providers={
Api.inference: remote_provider_id("ollama"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
},
),
DistributionSpec(
distribution_id="local-plus-fireworks-inference",
description="Use Fireworks.ai for running LLM inference",
providers={
Api.inference: remote_provider_id("fireworks"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
},
),
DistributionSpec(
distribution_id="local-plus-together-inference",
description="Use Together.ai for running LLM inference",
providers={
Api.inference: remote_provider_id("together"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
},
),
]
@lru_cache()
def resolve_distribution_spec(distribution_id: str) -> Optional[DistributionSpec]:
for spec in available_distribution_specs():
if spec.distribution_id == distribution_id:
return spec
return None


@ -8,7 +8,7 @@ import asyncio
import importlib
from typing import Any, Dict
-from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec
+from .datatypes import ProviderSpec, RemoteProviderSpec
def instantiate_class_type(fully_qualified_name):
@ -19,18 +19,24 @@ def instantiate_class_type(fully_qualified_name):
# returns a class implementing the protocol corresponding to the Api
def instantiate_provider(
-provider_spec: InlineProviderSpec,
+provider_spec: ProviderSpec,
provider_config: Dict[str, Any],
deps: Dict[str, ProviderSpec],
):
module = importlib.import_module(provider_spec.module)
config_type = instantiate_class_type(provider_spec.config_class)
+if isinstance(provider_spec, RemoteProviderSpec):
+if provider_spec.adapter:
+method = "get_adapter_impl"
+else:
+method = "get_client_impl"
+else:
+method = "get_provider_impl"
config = config_type(**provider_config)
-return asyncio.run(module.get_provider_impl(config, deps))
-def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str):
-module = importlib.import_module(provider_spec.module)
-return asyncio.run(module.get_client_impl(base_url))
+fn = getattr(module, method)
+impl = asyncio.run(fn(config, deps))
+impl.__provider_spec__ = provider_spec
+impl.__provider_config__ = config
+return impl


@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
from datetime import datetime
from enum import Enum
from typing import List, Optional
import pkg_resources
import yaml
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.core.datatypes import * # noqa: F403
from llama_toolchain.core.distribution import api_providers, SERVER_DEPENDENCIES
class BuildType(Enum):
container = "container"
conda_env = "conda_env"
def descriptor(self) -> str:
return "docker" if self == self.container else "conda"
class Dependencies(BaseModel):
pip_packages: List[str]
docker_image: Optional[str] = None
class ApiInput(BaseModel):
api: Api
provider: str
def build_package(
api_inputs: List[ApiInput],
build_type: BuildType,
name: str,
distribution_id: Optional[str] = None,
docker_image: Optional[str] = None,
):
if not distribution_id:
distribution_id = "adhoc"
build_dir = BUILDS_BASE_DIR / distribution_id / build_type.descriptor()
os.makedirs(build_dir, exist_ok=True)
package_name = name.replace("::", "-")
package_file = build_dir / f"{package_name}.yaml"
all_providers = api_providers()
package_deps = Dependencies(
docker_image=docker_image or "python:3.10-slim",
pip_packages=SERVER_DEPENDENCIES,
)
stub_config = {}
for api_input in api_inputs:
api = api_input.api
providers_for_api = all_providers[api]
if api_input.provider not in providers_for_api:
raise ValueError(
f"Provider `{api_input.provider}` is not available for API `{api}`"
)
provider = providers_for_api[api_input.provider]
package_deps.pip_packages.extend(provider.pip_packages)
if provider.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image")
stub_config[api.value] = {"provider_id": api_input.provider}
if package_file.exists():
cprint(
f"Build `{package_name}` exists; will reconfigure",
color="yellow",
)
c = PackageConfig(**yaml.safe_load(package_file.read_text()))
for api_str, new_config in stub_config.items():
if api_str not in c.providers:
c.providers[api_str] = new_config
else:
existing_config = c.providers[api_str]
if existing_config["provider_id"] != new_config["provider_id"]:
cprint(
f"Provider `{api_str}` has changed from `{existing_config}` to `{new_config}`",
color="yellow",
)
c.providers[api_str] = new_config
else:
c = PackageConfig(
built_at=datetime.now(),
package_name=package_name,
providers=stub_config,
)
c.distribution_id = distribution_id
c.docker_image = package_name if build_type == BuildType.container else None
c.conda_env = package_name if build_type == BuildType.conda_env else None
with open(package_file, "w") as f:
to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder))
f.write(yaml.dump(to_write, sort_keys=False))
if build_type == BuildType.container:
script = pkg_resources.resource_filename(
"llama_toolchain", "core/build_container.sh"
)
args = [
script,
distribution_id,
package_name,
package_deps.docker_image,
" ".join(package_deps.pip_packages),
]
else:
script = pkg_resources.resource_filename(
"llama_toolchain", "core/build_conda_env.sh"
)
args = [
script,
distribution_id,
package_name,
" ".join(package_deps.pip_packages),
]
return_code = run_with_pty(args)
if return_code != 0:
cprint(
f"Failed to build target {package_name} with return code {return_code}",
color="red",
)
return
cprint(
f"Target `{package_name}` built with configuration at {str(package_file)}",
color="green",
)
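For reference, a rough sketch of what a conda build of the `local` distribution leaves on disk (values are illustrative; the real contents are produced by PackageConfig and later filled in by `llama stack configure`):

cat ~/.llama/builds/local/conda/8b-instruct.yaml
# built_at: ...                  # datetime, serialized via EnumEncoder
# package_name: 8b-instruct
# distribution_id: local
# docker_image: null             # set instead of conda_env for container builds
# conda_env: 8b-instruct
# providers:
#   inference:
#     provider_id: meta-reference
#     ...                        # provider-specific config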


@ -5,8 +5,10 @@
# the root directory of this source tree.
import asyncio
+import inspect
import json
import signal
+import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
@ -28,18 +30,17 @@ import fire
import httpx
import yaml
-from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
+from typing_extensions import Annotated
-from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
-from .distribution import api_endpoints
-from .dynamic import instantiate_client, instantiate_provider
-from .registry import resolve_distribution_spec
+from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
+from .distribution import api_endpoints, api_providers
+from .dynamic import instantiate_provider
def is_async_iterator_type(typ):
@ -66,6 +67,7 @@ def create_sse_event(data: Any) -> str:
async def global_exception_handler(request: Request, exc: Exception):
+traceback.print_exception(exc)
http_exc = translate_exception(exc)
return JSONResponse(
@ -155,9 +157,8 @@ def create_dynamic_passthrough(
return endpoint
-def create_dynamic_typed_route(func: Any):
+def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
-request_model = next(iter(hints.values()))
response_model = hints["return"]
# NOTE: I think it is better to just add a method within each Api
@ -168,7 +169,7 @@ def create_dynamic_typed_route(func: Any):
if is_streaming:
-async def endpoint(request: request_model):
+async def endpoint(**kwargs):
async def sse_generator(event_gen):
try:
async for item in event_gen:
@ -178,10 +179,7 @@
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
-print(e)
-import traceback
-traceback.print_exc()
+traceback.print_exception(e)
yield create_sse_event(
{
"error": {
@ -191,25 +189,38 @@
)
return StreamingResponse(
-sse_generator(func(request)), media_type="text/event-stream"
+sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
-async def endpoint(request: request_model):
+async def endpoint(**kwargs):
try:
return (
-await func(request)
+await func(**kwargs)
if asyncio.iscoroutinefunction(func)
-else func(request)
+else func(**kwargs)
)
except Exception as e:
-print(e)
-import traceback
-traceback.print_exc()
+traceback.print_exception(e)
raise translate_exception(e) from e
+sig = inspect.signature(func)
+if method == "post":
+# make sure every parameter is annotated with Body() so FASTAPI doesn't
+# do anything too intelligent and ask for some parameters in the query
+# and some in the body
+endpoint.__signature__ = sig.replace(
+parameters=[
+param.replace(
+annotation=Annotated[param.annotation, Body(..., embed=True)]
+)
+for param in sig.parameters.values()
+]
+)
+else:
+endpoint.__signature__ = sig
return endpoint
@ -219,10 +230,9 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
-if not isinstance(a, RemoteProviderSpec):
-for api in a.api_dependencies:
-if api not in visited:
-dfs(by_id[api], visited, stack)
+for api in a.api_dependencies:
+if api not in visited:
+dfs(by_id[api], visited, stack)
stack.append(a.api)
@ -236,9 +246,11 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack]
-def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
+def resolve_impls(
+provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
+) -> Dict[Api, Any]:
provider_configs = config["providers"]
-provider_specs = topological_sort(dist.provider_specs.values())
+provider_specs = topological_sort(provider_specs.values())
impls = {}
for provider_spec in provider_specs:
@ -248,15 +260,13 @@ def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, A
f"Could not find provider_spec config for {api}. Please add it to the config"
)
-provider_config = provider_configs[api.value]
-if isinstance(provider_spec, RemoteProviderSpec):
-impls[api] = instantiate_client(
-provider_spec, provider_config["base_url"].rstrip("/")
-)
-else:
-deps = {api: impls[api] for api in provider_spec.api_dependencies}
-impl = instantiate_provider(provider_spec, provider_config, deps)
-impls[api] = impl
+if isinstance(provider_spec, InlineProviderSpec):
+deps = {api: impls[api] for api in provider_spec.api_dependencies}
+else:
+deps = {}
+provider_config = provider_configs[api.value]
+impl = instantiate_provider(provider_spec, provider_config, deps)
+impls[api] = impl
return impls
@ -265,24 +275,36 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp)
-spec = config["spec"]
-dist = resolve_distribution_spec(spec)
-if dist is None:
-raise ValueError(f"Could not find distribution specification `{spec}`")
app = FastAPI()
all_endpoints = api_endpoints()
-impls = resolve_impls(dist, config)
-for provider_spec in dist.provider_specs.values():
+all_providers = api_providers()
+provider_specs = {}
+for api_str, provider_config in config["providers"].items():
+api = Api(api_str)
+providers = all_providers[api]
+provider_id = provider_config["provider_id"]
+if provider_id not in providers:
+raise ValueError(
+f"Unknown provider `{provider_id}` is not available for API `{api}`"
+)
+provider_specs[api] = providers[provider_id]
+impls = resolve_impls(provider_specs, config)
+for provider_spec in provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
impl = impls[api]
-if isinstance(provider_spec, RemoteProviderSpec):
+if (
+isinstance(provider_spec, RemoteProviderSpec)
+and provider_spec.adapter is None
+):
for endpoint in endpoints:
-url = impl.base_url + endpoint.route
+url = impl.__provider_config__.url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
@ -296,7 +318,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
-create_dynamic_typed_route(impl_method)
+create_dynamic_typed_route(impl_method, endpoint.method)
)
for route in app.routes:
@ -307,6 +329,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
attrs=["bold"],
)
+app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)


@ -8,7 +8,6 @@
set -euo pipefail
-# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
@ -17,20 +16,27 @@ error_handler() {
exit 1
}
-# Set up the error trap
trap 'error_handler ${LINENO}' ERR
-if [ $# -lt 2 ]; then
-echo "Usage: $0 <environment_name> <script_args...>"
+if [ $# -lt 3 ]; then
+echo "Usage: $0 <build_name> <yaml_config> <port> <script_args...>"
exit 1
fi
-env_name="$1"
+build_name="$1"
+env_name="llamastack-$build_name"
+shift
+yaml_config="$1"
+shift
+port="$1"
shift
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
-python_interp=$(conda run -n "$env_name" which python)
-$python_interp -m llama_toolchain.distribution.server "$@"
+$CONDA_PREFIX/bin/python \
+-m llama_toolchain.core.server \
+--yaml_config "$yaml_config" \
+--port "$port" "$@"
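`llama stack run` resolves this script via pkg_resources and passes the build name, config path and port; invoked by hand it would look something like this (paths illustrative):

./start_conda_env.sh 8b-instruct ~/.llama/builds/local/conda/8b-instruct.yaml 5000
# activates the conda env "llamastack-8b-instruct" and execs
# python -m llama_toolchain.core.server --yaml_config <config> --port 5000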


@ -0,0 +1,43 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <build_name> <yaml_config> <port> <other_args...>"
exit 1
fi
build_name="$1"
docker_image="llamastack-$build_name"
shift
yaml_config="$1"
shift
port="$1"
shift
set -x
podman run -it \
-p $port:$port \
-v "$yaml_config:/app/config.yaml" \
$docker_image \
python -m llama_toolchain.core.server \
--yaml_config /app/config.yaml \
--port $port "$@"


@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .datatypes import * # noqa: F401 F403
-from .endpoints import * # noqa: F401 F403
+from .api import * # noqa: F401 F403


@ -4,13 +4,34 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from typing import Protocol
+from enum import Enum
+from typing import Any, Dict, Optional, Protocol
+from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
-from .datatypes import * # noqa: F403
+@json_schema_type
+class TrainEvalDatasetColumnType(Enum):
+dialog = "dialog"
+text = "text"
+media = "media"
+number = "number"
+json = "json"
+@json_schema_type
+class TrainEvalDataset(BaseModel):
+"""Dataset to be used for training or evaluating language models."""
+# TODO(ashwin): figure out if we need to add an enum for a "dataset type"
+columns: Dict[str, TrainEvalDatasetColumnType]
+content_url: URL
+metadata: Optional[Dict[str, Any]] = None
@json_schema_type


@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Optional
from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class TrainEvalDatasetColumnType(Enum):
dialog = "dialog"
text = "text"
media = "media"
number = "number"
json = "json"
@json_schema_type
class TrainEvalDataset(BaseModel):
"""Dataset to be used for training or evaluating language models."""
# TODO(ashwin): figure out if we need to add an enum for a "dataset type"
columns: Dict[str, TrainEvalDatasetColumnType]
content_url: URL
metadata: Optional[Dict[str, Any]] = None


@ -1,106 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, validator
@json_schema_type
class Api(Enum):
inference = "inference"
safety = "safety"
agentic_system = "agentic_system"
@json_schema_type
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
provider_id: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
)
@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
class RemoteProviderConfig(BaseModel):
base_url: str = Field(..., description="The base URL for the llama stack provider")
api_key: Optional[str] = Field(
..., description="API key, if needed, for the provider"
)
@validator("base_url")
@classmethod
def validate_base_url(cls, base_url: str) -> str:
if not base_url.startswith("http"):
raise ValueError(f"URL must start with http: {base_url}")
return base_url
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
""",
)
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
@json_schema_type
class DistributionSpec(BaseModel):
spec_id: str
description: str
provider_specs: Dict[Api, ProviderSpec] = Field(
default_factory=dict,
description="Provider specifications for each of the APIs provided by this distribution",
)
@json_schema_type
class DistributionConfig(BaseModel):
"""References to a installed / configured DistributionSpec"""
name: str
spec: str
conda_env: str
providers: Dict[str, Any] = Field(
default_factory=dict,
description="Provider configurations for each of the APIs provided by this distribution",
)


@ -1,79 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from functools import lru_cache
from typing import List, Optional
from .datatypes import Api, DistributionSpec, RemoteProviderSpec
from .distribution import api_providers
def client_module(api: Api) -> str:
return f"llama_toolchain.{api.value}.client"
def remote_spec(api: Api) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
provider_id=f"{api.value}-remote",
module=client_module(api),
)
@lru_cache()
def available_distribution_specs() -> List[DistributionSpec]:
providers = api_providers()
return [
DistributionSpec(
spec_id="local",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
provider_specs={
Api.inference: providers[Api.inference]["meta-reference"],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
},
),
DistributionSpec(
spec_id="remote",
description="Point to remote services for all llama stack APIs",
provider_specs={x: remote_spec(x) for x in providers},
),
DistributionSpec(
spec_id="local-ollama",
description="Like local, but use ollama for running LLM inference",
provider_specs={
Api.inference: providers[Api.inference]["meta-ollama"],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
},
),
DistributionSpec(
spec_id="remote-fireworks",
description="Use Fireworks.ai for running LLM inference",
provider_specs={
Api.inference: providers[Api.inference]["fireworks"],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
},
),
DistributionSpec(
spec_id="remote-together",
description="Use Together.ai for running LLM inference",
provider_specs={
Api.inference: providers[Api.inference]["together"],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
},
),
]
@lru_cache()
def resolve_distribution_spec(spec_id: str) -> Optional[DistributionSpec]:
for spec in available_distribution_specs():
if spec.spec_id == spec_id:
return spec
return None


@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .datatypes import * # noqa: F401 F403
-from .endpoints import * # noqa: F401 F403
+from .api import * # noqa: F401 F403


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+from enum import Enum
from typing import List, Protocol
from llama_models.schema_utils import webmethod
@ -11,11 +12,34 @@ from llama_models.schema_utils import webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
-from .datatypes import * # noqa: F403
-from llama_toolchain.dataset.api.datatypes import * # noqa: F403
+from llama_toolchain.dataset.api import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403
+class TextGenerationMetric(Enum):
+perplexity = "perplexity"
+rouge = "rouge"
+bleu = "bleu"
+class QuestionAnsweringMetric(Enum):
+em = "em"
+f1 = "f1"
+class SummarizationMetric(Enum):
+rouge = "rouge"
+bleu = "bleu"
+class EvaluationJob(BaseModel):
+job_uuid: str
+class EvaluationJobLogStream(BaseModel):
+job_uuid: str
class EvaluateTaskRequestCommon(BaseModel):
job_uuid: str
dataset: TrainEvalDataset


@ -1,33 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from pydantic import BaseModel
class TextGenerationMetric(Enum):
perplexity = "perplexity"
rouge = "rouge"
bleu = "bleu"
class QuestionAnsweringMetric(Enum):
em = "em"
f1 = "f1"
class SummarizationMetric(Enum):
rouge = "rouge"
bleu = "bleu"
class EvaluationJob(BaseModel):
job_uuid: str
class EvaluationJobLogStream(BaseModel):
job_uuid: str


@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_toolchain.inference.api import Inference
from .config import FireworksImplConfig
async def get_adapter_impl(config: FireworksImplConfig, _deps) -> Inference:
from .fireworks import FireworksInferenceAdapter
assert isinstance(
config, FireworksImplConfig
), f"Unexpected config type: {type(config)}"
impl = FireworksInferenceAdapter(config)
await impl.initialize()
return impl


@ -5,9 +5,9 @@
# the root directory of this source tree.
import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator
-import httpx
+from fireworks.client import Fireworks
from llama_models.llama3.api.datatypes import (
BuiltinTool,
@ -18,20 +18,8 @@ from llama_models.llama3.api.datatypes import (
)
from llama_models.llama3.api.tool_utils import ToolUtils
from llama_models.sku_list import resolve_model
-from fireworks.client import Fireworks
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import (
-ChatCompletionRequest,
-ChatCompletionResponse,
-ChatCompletionResponseEvent,
-ChatCompletionResponseEventType,
-ChatCompletionResponseStreamChunk,
-CompletionRequest,
-Inference,
-ToolCallDelta,
-ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import * # noqa: F403
from .config import FireworksImplConfig
@ -42,18 +30,7 @@ FIREWORKS_SUPPORTED_MODELS = {
}
-async def get_provider_impl(
-config: FireworksImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-assert isinstance(
-config, FireworksImplConfig
-), f"Unexpected config type: {type(config)}"
-impl = FireworksInference(config)
-await impl.initialize()
-return impl
-class FireworksInference(Inference):
+class FireworksInferenceAdapter(Inference):
def __init__(self, config: FireworksImplConfig) -> None:
self.config = config


@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_toolchain.core.datatypes import RemoteProviderConfig
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
from .ollama import OllamaInferenceAdapter
impl = OllamaInferenceAdapter(config.url)
await impl.initialize()
return impl


@ -4,63 +4,37 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import uuid from typing import AsyncGenerator
from typing import AsyncGenerator, Dict
import httpx import httpx
from llama_models.llama3.api.datatypes import ( from llama_models.llama3.api.chat_format import ChatFormat
BuiltinTool, from llama_models.llama3.api.datatypes import Message, StopReason
CompletionMessage, from llama_models.llama3.api.tokenizer import Tokenizer
Message,
StopReason,
ToolCall,
)
from llama_models.llama3.api.tool_utils import ToolUtils
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from ollama import AsyncClient from ollama import AsyncClient
from llama_toolchain.distribution.datatypes import Api, ProviderSpec from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.inference.api import ( from llama_toolchain.inference.prepare_messages import prepare_messages
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
ToolCallDelta,
ToolCallParseStatus,
)
from .config import OllamaImplConfig
# TODO: Eventually this will move to the llama cli model list command # TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models # mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = { OLLAMA_SUPPORTED_SKUS = {
# "Meta-Llama3.1-8B-Instruct": "llama3.1",
"Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16", "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16", "Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
} }
async def get_provider_impl( class OllamaInferenceAdapter(Inference):
config: OllamaImplConfig, _deps: Dict[Api, ProviderSpec] def __init__(self, url: str) -> None:
) -> Inference: self.url = url
assert isinstance( tokenizer = Tokenizer.get_instance()
config, OllamaImplConfig self.formatter = ChatFormat(tokenizer)
), f"Unexpected config type: {type(config)}"
impl = OllamaInference(config)
await impl.initialize()
return impl
class OllamaInference(Inference):
def __init__(self, config: OllamaImplConfig) -> None:
self.config = config
@property @property
def client(self) -> AsyncClient: def client(self) -> AsyncClient:
return AsyncClient(host=self.config.url) return AsyncClient(host=self.url)
async def initialize(self) -> None: async def initialize(self) -> None:
try: try:
@ -111,6 +85,7 @@ class OllamaInference(Inference):
return options return options
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
messages = prepare_messages(request)
# accumulate sampling params and other options to pass to ollama # accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request) options = self.get_ollama_chat_options(request)
ollama_model = self.resolve_ollama_model(request.model) ollama_model = self.resolve_ollama_model(request.model)
@ -132,7 +107,7 @@ class OllamaInference(Inference):
if not request.stream: if not request.stream:
r = await self.client.chat( r = await self.client.chat(
model=ollama_model, model=ollama_model,
messages=self._messages_to_ollama_messages(request.messages), messages=self._messages_to_ollama_messages(messages),
stream=False, stream=False,
options=options, options=options,
) )
@ -143,9 +118,8 @@ class OllamaInference(Inference):
elif r["done_reason"] == "length": elif r["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens stop_reason = StopReason.out_of_tokens
completion_message = decode_assistant_message_from_content( completion_message = self.formatter.decode_assistant_message_from_content(
r["message"]["content"], r["message"]["content"], stop_reason
stop_reason,
) )
yield ChatCompletionResponse( yield ChatCompletionResponse(
completion_message=completion_message, completion_message=completion_message,
@ -160,7 +134,7 @@ class OllamaInference(Inference):
) )
stream = await self.client.chat( stream = await self.client.chat(
model=ollama_model, model=ollama_model,
messages=self._messages_to_ollama_messages(request.messages), messages=self._messages_to_ollama_messages(messages),
stream=True, stream=True,
options=options, options=options,
) )
@ -228,7 +202,9 @@ class OllamaInference(Inference):
) )
# parse tool calls and report errors # parse tool calls and report errors
message = decode_assistant_message_from_content(buffer, stop_reason) message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0 parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls: if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
@ -261,70 +237,3 @@ class OllamaInference(Inference):
stop_reason=stop_reason, stop_reason=stop_reason,
) )
) )
# TODO: Consolidate this with impl in llama-models
def decode_assistant_message_from_content(
content: str,
stop_reason: StopReason,
) -> CompletionMessage:
ipython = content.startswith("<|python_tag|>")
if ipython:
content = content[len("<|python_tag|>") :]
if content.endswith("<|eot_id|>"):
content = content[: -len("<|eot_id|>")]
stop_reason = StopReason.end_of_turn
elif content.endswith("<|eom_id|>"):
content = content[: -len("<|eom_id|>")]
stop_reason = StopReason.end_of_message
tool_name = None
tool_arguments = {}
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
if custom_tool_info is not None:
tool_name, tool_arguments = custom_tool_info
# Sometimes when the agent has custom tools alongside builtin tools,
# it responds to builtin tool calls in the format of the custom tools.
# This code tries to handle that case.
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
else:
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
if builtin_tool_info is not None:
tool_name, query = builtin_tool_info
tool_arguments = {
"query": query,
}
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
elif ipython:
tool_name = BuiltinTool.code_interpreter
tool_arguments = {
"code": content,
}
tool_calls = []
if tool_name is not None and tool_arguments is not None:
call_id = str(uuid.uuid4())
tool_calls.append(
ToolCall(
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
)
)
content = ""
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
return CompletionMessage(
content=content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)
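The provider-local helper deleted above is superseded by the shared formatter that the adapters in this diff now call. A minimal sketch of the replacement call path (illustrative, not part of the diff; the input strings are placeholders):

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import CompletionMessage, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer

def parse_assistant(raw_text: str, stop_reason: StopReason) -> CompletionMessage:
    # One shared implementation instead of a per-provider copy.
    formatter = ChatFormat(Tokenizer.get_instance())
    return formatter.decode_assistant_message_from_content(raw_text, stop_reason)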


@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import TogetherImplConfig
async def get_adapter_impl(config: TogetherImplConfig, _deps) -> Inference:
from .together import TogetherInferenceAdapter
assert isinstance(
config, TogetherImplConfig
), f"Unexpected config type: {type(config)}"
impl = TogetherInferenceAdapter(config)
await impl.initialize()
return impl


@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import uuid import uuid
from typing import AsyncGenerator, Dict from typing import AsyncGenerator
from llama_models.llama3.api.datatypes import ( from llama_models.llama3.api.datatypes import (
BuiltinTool, BuiltinTool,
@ -18,18 +18,7 @@ from llama_models.llama3.api.tool_utils import ToolUtils
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from together import Together from together import Together
from llama_toolchain.distribution.datatypes import Api, ProviderSpec from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.inference.api import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
ToolCallDelta,
ToolCallParseStatus,
)
from .config import TogetherImplConfig from .config import TogetherImplConfig
@ -40,18 +29,7 @@ TOGETHER_SUPPORTED_MODELS = {
} }
async def get_provider_impl( class TogetherInferenceAdapter(Inference):
config: TogetherImplConfig, _deps: Dict[Api, ProviderSpec]
) -> Inference:
assert isinstance(
config, TogetherImplConfig
), f"Unexpected config type: {type(config)}"
impl = TogetherInference(config)
await impl.initialize()
return impl
class TogetherInference(Inference):
def __init__(self, config: TogetherImplConfig) -> None: def __init__(self, config: TogetherImplConfig) -> None:
self.config = config self.config = config


@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .datatypes import * # noqa: F401 F403 from .api import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403


@ -4,17 +4,79 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .datatypes import * # noqa: F403 from enum import Enum
from typing import Optional, Protocol
# this dependency is annoying and we need a forked up version anyway from typing import List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import webmethod
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
class LogProbConfig(BaseModel):
top_k: Optional[int] = 0
@json_schema_type
class QuantizationType(Enum):
bf16 = "bf16"
fp8 = "fp8"
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
QuantizationConfig = Annotated[
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
Field(discriminator="type"),
]
@json_schema_type
class ChatCompletionResponseEventType(Enum):
start = "start"
complete = "complete"
progress = "progress"
@json_schema_type
class ToolCallParseStatus(Enum):
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
@json_schema_type
class ToolCallDelta(BaseModel):
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""Chat completion response event."""
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: Optional[StopReason] = None
@json_schema_type @json_schema_type
class CompletionRequest(BaseModel): class CompletionRequest(BaseModel):
model: str model: str
content: InterleavedTextAttachment content: InterleavedTextMedia
sampling_params: Optional[SamplingParams] = SamplingParams() sampling_params: Optional[SamplingParams] = SamplingParams()
stream: Optional[bool] = False stream: Optional[bool] = False
@ -39,7 +101,7 @@ class CompletionResponseStreamChunk(BaseModel):
@json_schema_type @json_schema_type
class BatchCompletionRequest(BaseModel): class BatchCompletionRequest(BaseModel):
model: str model: str
content_batch: List[InterleavedTextAttachment] content_batch: List[InterleavedTextMedia]
sampling_params: Optional[SamplingParams] = SamplingParams() sampling_params: Optional[SamplingParams] = SamplingParams()
logprobs: Optional[LogProbConfig] = None logprobs: Optional[LogProbConfig] = None
@ -56,7 +118,11 @@ class ChatCompletionRequest(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams() sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model # zero-shot tool definitions as input to the model
available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list) tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
stream: Optional[bool] = False stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None logprobs: Optional[LogProbConfig] = None
@ -82,8 +148,11 @@ class BatchChatCompletionRequest(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams() sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model # zero-shot tool definitions as input to the model
available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list) tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
logprobs: Optional[LogProbConfig] = None logprobs: Optional[LogProbConfig] = None
@ -92,6 +161,11 @@ class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage] completion_message_batch: List[CompletionMessage]
@json_schema_type
class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
class Inference(Protocol): class Inference(Protocol):
@webmethod(route="/inference/completion") @webmethod(route="/inference/completion")
async def completion( async def completion(
@ -105,14 +179,9 @@ class Inference(Protocol):
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ... ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
@webmethod(route="/inference/batch_completion") @webmethod(route="/inference/embeddings")
async def batch_completion( async def embeddings(
self, self,
request: BatchCompletionRequest, model: str,
) -> BatchCompletionResponse: ... contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ...
@webmethod(route="/inference/batch_chat_completion")
async def batch_chat_completion(
self,
request: BatchChatCompletionRequest,
) -> BatchChatCompletionResponse: ...
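Taken together, a minimal sketch of driving the reworked inference surface (not part of the diff). The request fields and the embeddings signature come from the hunk above; the model SKU and the builtin-tool member name are placeholders, and iterating with async for assumes an async-generator-style implementation like the adapters in this diff:

from llama_toolchain.inference.api import *  # noqa: F403

async def demo(impl: Inference) -> None:
    request = ChatCompletionRequest(
        model="Meta-Llama3.1-8B-Instruct",              # placeholder SKU
        messages=[UserMessage(content="What is the capital of France?")],
        tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],  # member name assumed
        tool_choice=ToolChoice.auto,
        tool_prompt_format=ToolPromptFormat.json,
        stream=True,
    )
    async for chunk in impl.chat_completion(request):
        print(chunk)

    # Batch endpoints are gone from this protocol; embeddings is the new method.
    response = await impl.embeddings(
        model="Meta-Llama3.1-8B-Instruct",
        contents=["hello world"],
    )
    print(len(response.embeddings))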


@ -1,72 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
class LogProbConfig(BaseModel):
top_k: Optional[int] = 0
@json_schema_type
class QuantizationType(Enum):
bf16 = "bf16"
fp8 = "fp8"
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
QuantizationConfig = Annotated[
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
Field(discriminator="type"),
]
@json_schema_type
class ChatCompletionResponseEventType(Enum):
start = "start"
complete = "complete"
progress = "progress"
@json_schema_type
class ToolCallParseStatus(Enum):
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
@json_schema_type
class ToolCallDelta(BaseModel):
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""Chat completion response event."""
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: Optional[StopReason] = None


@ -6,12 +6,15 @@
import asyncio import asyncio
import json import json
from typing import AsyncGenerator from typing import Any, AsyncGenerator
import fire import fire
import httpx import httpx
from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from llama_toolchain.core.datatypes import RemoteProviderConfig
from .api import ( from .api import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
@ -23,13 +26,16 @@ from .api import (
from .event_logger import EventLogger from .event_logger import EventLogger
async def get_client_impl(base_url: str): async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
return InferenceClient(base_url) return InferenceClient(config.url)
def encodable_dict(d: BaseModel):
return json.loads(d.json())
class InferenceClient(Inference): class InferenceClient(Inference):
def __init__(self, base_url: str): def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")
self.base_url = base_url self.base_url = base_url
async def initialize(self) -> None: async def initialize(self) -> None:
@ -46,7 +52,9 @@ class InferenceClient(Inference):
async with client.stream( async with client.stream(
"POST", "POST",
f"{self.base_url}/inference/chat_completion", f"{self.base_url}/inference/chat_completion",
data=request.json(), json={
"request": encodable_dict(request),
},
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20, timeout=20,
) as response: ) as response:
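For context on the payload change above: the client now posts the serialized request under a "request" key rather than as the raw body. A hedged sketch of an equivalent non-streaming call (hypothetical helper, not part of the diff; endpoint path and body shape mirror the client code above):

import json

import httpx
from pydantic import BaseModel

def encodable_dict(d: BaseModel):
    # Round-trip through JSON so enums and nested models become plain JSON values.
    return json.loads(d.json())

async def post_chat_completion(base_url: str, request: BaseModel) -> dict:
    async with httpx.AsyncClient() as client:
        r = await client.post(
            f"{base_url}/inference/chat_completion",
            json={"request": encodable_dict(request)},
            headers={"Content-Type": "application/json"},
            timeout=20,
        )
        r.raise_for_status()
        return r.json()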


@ -5,4 +5,15 @@
# the root directory of this source tree. # the root directory of this source tree.
from .config import MetaReferenceImplConfig # noqa from .config import MetaReferenceImplConfig # noqa
from .inference import get_provider_impl # noqa
async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
from .inference import MetaReferenceInferenceImpl
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceInferenceImpl(config)
await impl.initialize()
return impl


@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from llama_models.sku_list import all_registered_models from llama_models.sku_list import all_registered_models
from llama_toolchain.inference.api import QuantizationConfig
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from llama_toolchain.inference.api import QuantizationConfig
@json_schema_type @json_schema_type
class MetaReferenceImplConfig(BaseModel): class MetaReferenceImplConfig(BaseModel):


@ -24,7 +24,7 @@ from fairscale.nn.model_parallel.initialize import (
) )
from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer from llama_models.llama3.reference_impl.model import Transformer
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
@ -279,6 +279,7 @@ class Llama:
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None, max_gen_len: Optional[int] = None,
logprobs: bool = False, logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> Generator: ) -> Generator:
if ( if (
max_gen_len is None max_gen_len is None
@ -288,7 +289,10 @@ class Llama:
max_gen_len = self.model.params.max_seq_len - 1 max_gen_len = self.model.params.max_seq_len - 1
yield from self.generate( yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(messages), model_input=self.formatter.encode_dialog_prompt(
messages,
tool_prompt_format,
),
max_gen_len=max_gen_len, max_gen_len=max_gen_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,


@ -6,12 +6,11 @@
import asyncio import asyncio
from typing import AsyncIterator, Dict, Union from typing import AsyncIterator, Union
from llama_models.llama3.api.datatypes import StopReason from llama_models.llama3.api.datatypes import StopReason
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import ( from llama_toolchain.inference.api import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
@ -22,23 +21,11 @@ from llama_toolchain.inference.api import (
ToolCallDelta, ToolCallDelta,
ToolCallParseStatus, ToolCallParseStatus,
) )
from llama_toolchain.inference.prepare_messages import prepare_messages
from .config import MetaReferenceImplConfig from .config import MetaReferenceImplConfig
from .model_parallel import LlamaModelParallelGenerator from .model_parallel import LlamaModelParallelGenerator
async def get_provider_impl(
config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
):
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceInferenceImpl(config)
await impl.initialize()
return impl
# there's a single model parallel process running serving the model. for now, # there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process. # we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1) SEMAPHORE = asyncio.Semaphore(1)
@ -67,6 +54,7 @@ class MetaReferenceInferenceImpl(Inference):
) -> AsyncIterator[ ) -> AsyncIterator[
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse] Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
]: ]:
messages = prepare_messages(request)
model = resolve_model(request.model) model = resolve_model(request.model)
if model is None: if model is None:
raise RuntimeError( raise RuntimeError(
@ -98,11 +86,12 @@ class MetaReferenceInferenceImpl(Inference):
ipython = False ipython = False
for token_result in self.generator.chat_completion( for token_result in self.generator.chat_completion(
messages=request.messages, messages=messages,
temperature=request.sampling_params.temperature, temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p, top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens, max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs, logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
): ):
buffer += token_result.text buffer += token_result.text
tokens.append(token_result.token) tokens.append(token_result.token)


@ -11,7 +11,7 @@ from functools import partial
from typing import Generator, List, Optional from typing import Generator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
@ -27,6 +27,7 @@ class InferenceArgs:
top_p: float top_p: float
max_gen_len: int max_gen_len: int
logprobs: bool logprobs: bool
tool_prompt_format: ToolPromptFormat
class ModelRunner: class ModelRunner:
@ -41,6 +42,7 @@ class ModelRunner:
task.top_p, task.top_p,
task.max_gen_len, task.max_gen_len,
task.logprobs, task.logprobs,
task.tool_prompt_format,
) )
@ -93,6 +95,7 @@ class LlamaModelParallelGenerator:
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None, max_gen_len: Optional[int] = None,
logprobs: bool = False, logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> Generator: ) -> Generator:
req_obj = InferenceArgs( req_obj = InferenceArgs(
messages=deepcopy(messages), messages=deepcopy(messages),
@ -100,6 +103,7 @@ class LlamaModelParallelGenerator:
top_p=top_p, top_p=top_p,
max_gen_len=max_gen_len, max_gen_len=max_gen_len,
logprobs=logprobs, logprobs=logprobs,
tool_prompt_format=tool_prompt_format,
) )
gen = self.group.run_inference(req_obj) gen = self.group.run_inference(req_obj)


@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_models.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
SystemDefaultGenerator,
)
def prepare_messages(request: ChatCompletionRequest) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if request.tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(request.tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join(
[_process(c) for c in existing_system_message.content]
)
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
if request.tool_prompt_format == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif request.tool_prompt_format == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
# Add back existing messages from the request
messages += existing_messages
return messages
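A hedged sketch of what prepare_messages produces (not part of the diff). The custom tool, its parameter fields, and the builtin-tool member name are assumptions about the llama_models datatypes; the rendered template text is elided because it lives in llama_models.llama3.prompt_templates:

request = ChatCompletionRequest(
    model="Meta-Llama3.1-8B-Instruct",                       # placeholder SKU
    messages=[UserMessage(content="Is ticket 42 resolved?")],
    tools=[
        ToolDefinition(tool_name=BuiltinTool.brave_search),  # member name assumed
        ToolDefinition(
            tool_name="get_ticket_status",                   # custom tool => str name
            description="Fetch a support ticket by id",
            parameters={
                "ticket_id": ToolParamDefinition(
                    param_type="str", description="Ticket id"  # field names assumed
                ),
            },
        ),
    ],
    tool_prompt_format=ToolPromptFormat.json,
)

messages = prepare_messages(request)
# Roughly: [SystemMessage(builtin tools + system defaults + any prior system text),
#           UserMessage(JSON-format instructions for the custom tool),
#           UserMessage("Is ticket 42 resolved?")]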


@ -6,7 +6,7 @@
from typing import List from typing import List
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec from llama_toolchain.core.datatypes import * # noqa: F403
def available_inference_providers() -> List[ProviderSpec]: def available_inference_providers() -> List[ProviderSpec]:
@ -27,14 +27,13 @@ def available_inference_providers() -> List[ProviderSpec]:
module="llama_toolchain.inference.meta_reference", module="llama_toolchain.inference.meta_reference",
config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig", config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
), ),
InlineProviderSpec( remote_provider_spec(
api=Api.inference, api=Api.inference,
provider_id="meta-ollama", adapter=AdapterSpec(
pip_packages=[ adapter_id="ollama",
"ollama", pip_packages=["ollama"],
], module="llama_toolchain.inference.adapters.ollama",
module="llama_toolchain.inference.ollama", ),
config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
), ),
InlineProviderSpec( InlineProviderSpec(
api=Api.inference, api=Api.inference,

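Registering another remote adapter follows the same shape as the ollama entry above; a hypothetical sketch using only the fields visible in this diff (the id, packages, and module path are placeholders):

remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(
        adapter_id="my-remote-inference",                       # placeholder
        pip_packages=["httpx"],                                 # placeholder
        module="llama_toolchain.inference.adapters.my_remote",  # placeholder
    ),
),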

@ -14,12 +14,12 @@ import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3.api.model import Transformer, TransformerBlock from llama_models.llama3.api.model import Transformer, TransformerBlock
from llama_toolchain.inference.api import QuantizationType
from llama_toolchain.inference.api.config import ( from llama_toolchain.inference.api.config import (
CheckpointQuantizationFormat, CheckpointQuantizationFormat,
MetaReferenceImplConfig, MetaReferenceImplConfig,
) )
from llama_toolchain.inference.api.datatypes import QuantizationType
from termcolor import cprint from termcolor import cprint
from torch import Tensor from torch import Tensor


@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .datatypes import * # noqa: F401 F403 from .api import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403


@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
@json_schema_type
class MemoryBankDocument(BaseModel):
document_id: str
content: InterleavedTextMedia | URL
mime_type: str
metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class MemoryBankType(Enum):
vector = "vector"
keyvalue = "keyvalue"
keyword = "keyword"
graph = "graph"
class VectorMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
class KeyValueMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
class KeywordMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class GraphMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class Chunk(BaseModel):
content: InterleavedTextMedia
token_count: int
document_id: str
@json_schema_type
class QueryDocumentsResponse(BaseModel):
chunks: List[Chunk]
scores: List[float]
@json_schema_type
class QueryAPI(Protocol):
@webmethod(route="/query_documents")
def query_documents(
self,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@json_schema_type
class MemoryBank(BaseModel):
bank_id: str
name: str
config: MemoryBankConfig
# if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
url: Optional[URL] = None
class Memory(Protocol):
@webmethod(route="/memory_banks/create")
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank: ...
@webmethod(route="/memory_banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get", method="GET")
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory_banks/drop", method="DELETE")
async def drop_memory_bank(
self,
bank_id: str,
) -> str: ...
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/memory_bank/insert")
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None: ...
@webmethod(route="/memory_bank/update")
async def update_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/query")
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@webmethod(route="/memory_bank/documents/get", method="GET")
async def get_documents(
self,
bank_id: str,
document_ids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/documents/delete", method="DELETE")
async def delete_documents(
self,
bank_id: str,
document_ids: List[str],
) -> None: ...
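One detail worth calling out in the model above: MemoryBankConfig is a pydantic discriminated union keyed on "type", so a serialized config deserializes back to the right subclass. A small sketch (assuming a pydantic version with parse_obj_as and discriminated-union support; the model name is a placeholder):

from pydantic import parse_obj_as

cfg = parse_obj_as(
    MemoryBankConfig,
    {
        "type": "vector",                        # discriminator picks VectorMemoryBankConfig
        "embedding_model": "all-MiniLM-L6-v2",   # placeholder model name
        "chunk_size_in_tokens": 512,
        "overlap_size_in_tokens": 64,
    },
)
assert isinstance(cfg, VectorMemoryBankConfig)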


@ -1,25 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class MemoryBank(BaseModel):
memory_bank_id: str
memory_bank_name: str
@json_schema_type
class MemoryBankDocument(BaseModel):
document_id: str
content: bytes
metadata: Dict[str, Any]
mime_type: str


@ -1,61 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Protocol
from llama_models.schema_utils import webmethod
from .datatypes import * # noqa: F403
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/create")
def create_memory_bank(
self,
bank_id: str,
bank_name: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_banks/list")
def get_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get")
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/drop")
def delete_memory_bank(
self,
bank_id: str,
) -> str: ...
@webmethod(route="/memory_bank/insert")
def insert_memory_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/update")
def update_memory_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/get")
def get_memory_documents(
self,
bank_id: str,
document_uuids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/delete")
def delete_memory_documents(
self,
bank_id: str,
document_uuids: List[str],
) -> List[str]: ...


@ -0,0 +1,181 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any, Dict, List, Optional
import fire
import httpx
from llama_toolchain.core.datatypes import RemoteProviderConfig
from .api import * # noqa: F403
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
return MemoryClient(config.url)
class MemoryClient(Memory):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client:
r = await client.get(
f"{self.base_url}/memory_banks/get",
params={
"bank_id": bank_id,
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory_banks/create",
json={
"name": name,
"config": config.dict(),
"url": url,
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory_bank/insert",
json={
"bank_id": bank_id,
"documents": [d.dict() for d in documents],
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory_bank/query",
json={
"bank_id": bank_id,
"query": query,
"params": params,
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
return QueryDocumentsResponse(**r.json())
async def run_main(host: str, port: int, stream: bool):
client = MemoryClient(f"http://{host}:{port}")
# create a memory bank
bank = await client.create_memory_bank(
name="test_bank",
config=VectorMemoryBankConfig(
bank_id="test_bank",
embedding_model="dragon-roberta-query-2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
print(bank)
retrieved_bank = await client.get_memory_bank(bank.bank_id)
assert retrieved_bank is not None
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
documents = [
MemoryBankDocument(
document_id=f"num-{i}",
content=URL(
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
),
mime_type="text/plain",
)
for i, url in enumerate(urls)
]
# insert some documents
await client.insert_documents(
bank_id=bank.bank_id,
documents=documents,
)
# query the documents
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"How do I use Lora?",
],
)
for chunk, score in zip(response.chunks, response.scores):
print(f"Score: {score}")
print(f"Chunk:\n========\n{chunk}\n========\n")
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"Tell me more about llama3 and torchtune",
],
)
for chunk, score in zip(response.chunks, response.scores):
print(f"Score: {score}")
print(f"Chunk:\n========\n{chunk}\n========\n")
def main(host: str, port: int, stream: bool = True):
asyncio.run(run_main(host, port, stream))
if __name__ == "__main__":
fire.Fire(main)


@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import FaissImplConfig
async def get_provider_impl(config: FaissImplConfig, _deps):
from .faiss import FaissMemoryImpl
assert isinstance(
config, FaissImplConfig
), f"Unexpected config type: {type(config)}"
impl = FaissMemoryImpl(config)
await impl.initialize()
return impl


@ -5,12 +5,9 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from pydantic import BaseModel
@json_schema_type @json_schema_type
class OllamaImplConfig(BaseModel): class FaissImplConfig(BaseModel): ...
url: str = Field(
default="http://localhost:11434",
description="The URL for the ollama server",
)


@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import faiss
import httpx
import numpy as np
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_toolchain.memory.api import * # noqa: F403
from .config import FaissImplConfig
async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL):
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
return r.text
return interleaved_text_media_as_str(doc.content)
def make_overlapped_chunks(
text: str, window_len: int, overlap_len: int
) -> List[Tuple[str, int]]:
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
chunks = []
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
chunks.append((chunk, len(toks)))
return chunks
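The stride above is window_len - overlap_len, so consecutive chunks share overlap_len tokens. A tokenizer-free illustration of the same arithmetic (not part of the diff):

tokens = list(range(1200))          # stand-in for real token ids
window_len, overlap_len = 512, 64
starts = list(range(0, len(tokens), window_len - overlap_len))
# starts == [0, 448, 896]; windows cover [0:512], [448:960], [896:1200]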
@dataclass
class BankState:
bank: MemoryBank
index: Optional[faiss.IndexFlatL2] = None
doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
id_by_index: Dict[int, str] = field(default_factory=dict)
chunk_by_index: Dict[int, str] = field(default_factory=dict)
async def insert_documents(
self,
model: "SentenceTransformer",
documents: List[MemoryBankDocument],
) -> None:
tokenizer = Tokenizer.get_instance()
chunk_size = self.bank.config.chunk_size_in_tokens
for doc in documents:
indexlen = len(self.id_by_index)
self.doc_by_id[doc.document_id] = doc
content = await content_from_doc(doc)
chunks = make_overlapped_chunks(
content,
self.bank.config.chunk_size_in_tokens,
self.bank.config.overlap_size_in_tokens
or (self.bank.config.chunk_size_in_tokens // 4),
)
embeddings = model.encode([x[0] for x in chunks]).astype(np.float32)
await self._ensure_index(embeddings.shape[1])
self.index.add(embeddings)
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = Chunk(
content=chunk[0],
token_count=chunk[1],
document_id=doc.document_id,
)
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
self.id_by_index[indexlen + i] = doc.document_id
async def query_documents(
self,
model: "SentenceTransformer",
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
if params is None:
params = {}
k = params.get("max_chunks", 3)
def _process(c) -> str:
if isinstance(c, str):
return c
else:
return "<media>"
if isinstance(query, list):
query_str = " ".join([_process(c) for c in query])
else:
query_str = _process(query)
query_vector = model.encode([query_str])[0]
distances, indices = self.index.search(
query_vector.reshape(1, -1).astype(np.float32), k
)
chunks = []
scores = []
for d, i in zip(distances[0], indices[0]):
if i < 0:
continue
chunks.append(self.chunk_by_index[int(i)])
scores.append(1.0 / float(d))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
if self.index is None:
self.index = faiss.IndexFlatL2(dimension)
return self.index
class FaissMemoryImpl(Memory):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.model = None
self.states = {}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
assert url is None, "URL is not supported for this implementation"
assert (
config.type == MemoryBankType.vector.value
), f"Only vector banks are supported {config.type}"
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
state = BankState(bank=bank)
self.states[bank_id] = state
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
if bank_id not in self.states:
return None
return self.states[bank_id].bank
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
assert bank_id in self.states, f"Bank {bank_id} not found"
state = self.states[bank_id]
await state.insert_documents(self.get_model(), documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
assert bank_id in self.states, f"Bank {bank_id} not found"
state = self.states[bank_id]
return await state.query_documents(self.get_model(), query, params)
def get_model(self) -> "SentenceTransformer":
from sentence_transformers import SentenceTransformer
if self.model is None:
print("Loading sentence transformer")
self.model = SentenceTransformer("all-MiniLM-L6-v2")
return self.model
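A minimal end-to-end sketch of the in-process faiss provider (not part of the diff). The import path follows the provider registry below and the module layout shown earlier; it assumes faiss-cpu, sentence-transformers, and the llama tokenizer assets are available locally, and the document text is a placeholder:

import asyncio

from llama_toolchain.memory.api import *  # noqa: F403
from llama_toolchain.memory.meta_reference.faiss import (  # path per the provider spec
    FaissImplConfig,
    get_provider_impl,
)

async def main() -> None:
    impl = await get_provider_impl(FaissImplConfig(), {})
    bank = await impl.create_memory_bank(
        name="notes",
        config=VectorMemoryBankConfig(
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=512,
            overlap_size_in_tokens=64,
        ),
    )
    await impl.insert_documents(
        bank_id=bank.bank_id,
        documents=[
            MemoryBankDocument(
                document_id="doc-0",
                content="LoRA adds trainable low-rank matrices to frozen attention weights.",
                mime_type="text/plain",
            )
        ],
    )
    response = await impl.query_documents(bank_id=bank.bank_id, query="What does LoRA do?")
    for chunk, score in zip(response.chunks, response.scores):
        print(score, chunk.content)

asyncio.run(main())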


@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_memory_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.memory,
provider_id="meta-reference-faiss",
pip_packages=[
"blobfile",
"faiss-cpu",
"sentence-transformers",
],
module="llama_toolchain.memory.meta_reference.faiss",
config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
),
]


@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .datatypes import * # noqa: F401 F403 from .api import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403


@ -5,12 +5,79 @@
# the root directory of this source tree. # the root directory of this source tree.
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Protocol from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from .datatypes import * # noqa: F403
@json_schema_type
class ExperimentStatus(Enum):
NOT_STARTED = "not_started"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
@json_schema_type
class Experiment(BaseModel):
id: str
name: str
status: ExperimentStatus
created_at: datetime
updated_at: datetime
metadata: Dict[str, Any]
@json_schema_type
class Run(BaseModel):
id: str
experiment_id: str
status: str
started_at: datetime
ended_at: Optional[datetime]
metadata: Dict[str, Any]
@json_schema_type
class Metric(BaseModel):
name: str
value: Union[float, int, str, bool]
timestamp: datetime
run_id: str
@json_schema_type
class Log(BaseModel):
message: str
level: str
timestamp: datetime
additional_info: Dict[str, Any]
@json_schema_type
class ArtifactType(Enum):
MODEL = "model"
DATASET = "dataset"
CHECKPOINT = "checkpoint"
PLOT = "plot"
METRIC = "metric"
CONFIG = "config"
CODE = "code"
OTHER = "other"
@json_schema_type
class Artifact(BaseModel):
id: str
name: str
type: ArtifactType
size: int
created_at: datetime
metadata: Dict[str, Any]
@json_schema_type @json_schema_type
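Purely illustrative construction of the tracking datatypes consolidated above (values are invented, not part of the diff):

from datetime import datetime, timezone

experiment = Experiment(
    id="exp-001",
    name="rag-eval",
    status=ExperimentStatus.RUNNING,
    created_at=datetime.now(timezone.utc),
    updated_at=datetime.now(timezone.utc),
    metadata={"owner": "example"},
)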


@ -1,80 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional, Union
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class ExperimentStatus(Enum):
NOT_STARTED = "not_started"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
@json_schema_type
class Experiment(BaseModel):
id: str
name: str
status: ExperimentStatus
created_at: datetime
updated_at: datetime
metadata: Dict[str, Any]
@json_schema_type
class Run(BaseModel):
id: str
experiment_id: str
status: str
started_at: datetime
ended_at: Optional[datetime]
metadata: Dict[str, Any]
@json_schema_type
class Metric(BaseModel):
name: str
value: Union[float, int, str, bool]
timestamp: datetime
run_id: str
@json_schema_type
class Log(BaseModel):
message: str
level: str
timestamp: datetime
additional_info: Dict[str, Any]
@json_schema_type
class ArtifactType(Enum):
MODEL = "model"
DATASET = "dataset"
CHECKPOINT = "checkpoint"
PLOT = "plot"
METRIC = "metric"
CONFIG = "config"
CODE = "code"
OTHER = "other"
@json_schema_type
class Artifact(BaseModel):
id: str
name: str
type: ArtifactType
size: int
created_at: datetime
metadata: Dict[str, Any]


@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .datatypes import * # noqa: F401 F403 from .api import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403


@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from datetime import datetime from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol from typing import Any, Dict, List, Optional, Protocol
@ -13,9 +14,90 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.dataset.api.datatypes import * # noqa: F403 from llama_toolchain.dataset.api import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403
from .datatypes import * # noqa: F403
class OptimizerType(Enum):
adam = "adam"
adamw = "adamw"
sgd = "sgd"
@json_schema_type
class OptimizerConfig(BaseModel):
optimizer_type: OptimizerType
lr: float
lr_min: float
weight_decay: float
@json_schema_type
class TrainingConfig(BaseModel):
n_epochs: int
batch_size: int
shuffle: bool
n_iters: int
enable_activation_checkpointing: bool
memory_efficient_fsdp_wrap: bool
fsdp_cpu_offload: bool
@json_schema_type
class FinetuningAlgorithm(Enum):
full = "full"
lora = "lora"
qlora = "qlora"
dora = "dora"
@json_schema_type
class LoraFinetuningConfig(BaseModel):
lora_attn_modules: List[str]
apply_lora_to_mlp: bool
apply_lora_to_output: bool
rank: int
alpha: int
@json_schema_type
class QLoraFinetuningConfig(LoraFinetuningConfig):
pass
@json_schema_type
class DoraFinetuningConfig(LoraFinetuningConfig):
pass
@json_schema_type
class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job."""
job_uuid: str
log_lines: List[str]
@json_schema_type
class PostTrainingJobStatus(Enum):
running = "running"
completed = "completed"
failed = "failed"
scheduled = "scheduled"
@json_schema_type
class RLHFAlgorithm(Enum):
dpo = "dpo"
@json_schema_type
class DPOAlignmentConfig(BaseModel):
reward_scale: float
reward_clip: float
epsilon: float
gamma: float
@json_schema_type @json_schema_type
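Purely as an illustration of the finetuning datatypes consolidated above (values are invented, not recommendations, and not part of the diff):

algorithm_config = LoraFinetuningConfig(
    lora_attn_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
)
optimizer_config = OptimizerConfig(
    optimizer_type=OptimizerType.adamw,
    lr=3e-4,
    lr_min=3e-5,
    weight_decay=0.01,
)
training_config = TrainingConfig(
    n_epochs=1,
    batch_size=4,
    shuffle=True,
    n_iters=1000,
    enable_activation_checkpointing=True,
    memory_efficient_fsdp_wrap=True,
    fsdp_cpu_offload=False,
)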


@ -1,94 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
class OptimizerType(Enum):
adam = "adam"
adamw = "adamw"
sgd = "sgd"
@json_schema_type
class OptimizerConfig(BaseModel):
optimizer_type: OptimizerType
lr: float
lr_min: float
weight_decay: float
@json_schema_type
class TrainingConfig(BaseModel):
n_epochs: int
batch_size: int
shuffle: bool
n_iters: int
enable_activation_checkpointing: bool
memory_efficient_fsdp_wrap: bool
fsdp_cpu_offload: bool
@json_schema_type
class FinetuningAlgorithm(Enum):
full = "full"
lora = "lora"
qlora = "qlora"
dora = "dora"
@json_schema_type
class LoraFinetuningConfig(BaseModel):
lora_attn_modules: List[str]
apply_lora_to_mlp: bool
apply_lora_to_output: bool
rank: int
alpha: int
@json_schema_type
class QLoraFinetuningConfig(LoraFinetuningConfig):
pass
@json_schema_type
class DoraFinetuningConfig(LoraFinetuningConfig):
pass
@json_schema_type
class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job."""
job_uuid: str
log_lines: List[str]
@json_schema_type
class PostTrainingJobStatus(Enum):
running = "running"
completed = "completed"
failed = "failed"
scheduled = "scheduled"
@json_schema_type
class RLHFAlgorithm(Enum):
dpo = "dpo"
@json_schema_type
class DPOAlignmentConfig(BaseModel):
reward_scale: float
reward_clip: float
epsilon: float
gamma: float


@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .datatypes import * # noqa: F401 F403 from .api import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403


@ -5,9 +5,30 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import List, Protocol, Union from typing import List, Protocol, Union
from .datatypes import * # noqa: F403
from llama_models.schema_utils import webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
@json_schema_type
class ScoredMessage(BaseModel):
message: Message
score: float
@json_schema_type
class DialogGenerations(BaseModel):
dialog: List[Message]
sampled_generations: List[Message]
@json_schema_type
class ScoredDialogGenerations(BaseModel):
dialog: List[Message]
scored_generations: List[ScoredMessage]
@json_schema_type @json_schema_type


@ -1,31 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
@json_schema_type
class ScoredMessage(BaseModel):
message: Message
score: float
@json_schema_type
class DialogGenerations(BaseModel):
dialog: List[Message]
sampled_generations: List[Message]
@json_schema_type
class ScoredDialogGenerations(BaseModel):
dialog: List[Message]
scored_generations: List[ScoredMessage]


@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .datatypes import * # noqa from .api import * # noqa: F401 F403
from .endpoints import * # noqa


@ -5,13 +5,12 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Dict, Optional, Union from typing import Dict, List, Optional, Protocol, Union
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
@ -70,3 +69,22 @@ class ShieldResponse(BaseModel):
except ValueError: except ValueError:
return v return v
return v return v
@json_schema_type
class RunShieldRequest(BaseModel):
messages: List[Message]
shields: List[ShieldDefinition]
@json_schema_type
class RunShieldResponse(BaseModel):
responses: List[ShieldResponse]
class Safety(Protocol):
@webmethod(route="/safety/run_shields")
async def run_shields(
self,
request: RunShieldRequest,
) -> RunShieldResponse: ...
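A small sketch of driving the consolidated Safety protocol above (not part of the diff); how ShieldDefinition instances are built is outside this hunk, so they are taken as given here:

async def run_safety_checks(
    safety: Safety,
    messages: List[Message],
    shields: List[ShieldDefinition],
) -> RunShieldResponse:
    # One ShieldResponse per shield comes back in `responses`.
    response = await safety.run_shields(
        RunShieldRequest(messages=messages, shields=shields)
    )
    for shield_response in response.responses:
        print(shield_response)
    return response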


@ -1,32 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F403
from typing import List, Protocol
from llama_models.llama3.api.datatypes import Message
# this dependency is annoying and we need a forked up version anyway
from llama_models.schema_utils import webmethod
@json_schema_type
class RunShieldRequest(BaseModel):
messages: List[Message]
shields: List[ShieldDefinition]
@json_schema_type
class RunShieldResponse(BaseModel):
responses: List[ShieldResponse]
class Safety(Protocol):
@webmethod(route="/safety/run_shields")
async def run_shields(
self,
request: RunShieldRequest,
) -> RunShieldResponse: ...

Some files were not shown because too many files have changed in this diff Show more