Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
API Updates: fleshing out RAG APIs, introduce "llama stack" CLI command (#51)
* add tools to chat completion request
* use templates for generating system prompts
* Moved ToolPromptFormat and jinja templates to llama_models.llama3.api
* <WIP> memory changes
  - inlined AgenticSystemInstanceConfig so API feels more ergonomic
  - renamed it to AgentConfig, AgentInstance -> Agent
  - added a MemoryConfig and `memory` parameter
  - added `attachments` to input and `output_attachments` to the response
  - some naming changes
* InterleavedTextAttachment -> InterleavedTextMedia, introduce memory tool
* flesh out memory banks API
* agentic loop has a RAG implementation
* faiss provider implementation
* memory client works
* re-work tool definitions, fix FastAPI issues, fix tool regressions
* fix agentic_system utils
* basic RAG seems to work
* small bug fixes for inline attachments
* Refactor custom tool execution utilities
* Bug fix, show memory retrieval steps in EventLogger
* No need for api_key for Remote providers
* add special unicode character ↵ to showcase newlines in model prompt templates
* remove api.endpoints imports
* combine datatypes.py and endpoints.py into api.py
* Attachment / add TTL api
* split batch_inference from inference
* minor import fixes
* use a single impl for ChatFormat.decode_assistant_message
* use interleaved_text_media_as_str() utility
* Fix api.datatypes imports
* Add blobfile for tiktoken
* Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly
* templates take optional --format={json,function_tag}
* Rag Updates
* Add `api build` subcommand -- WIP
* fix
* build + run image seems to work
* <WIP> adapters
* bunch more work to make adapters work
* api build works for conda now
* ollama remote adapter works
* Several smaller fixes to make adapters work. Also, reorganized the pattern of __init__ inside providers so configuration can stay lightweight
* llama distribution -> llama stack + containers (WIP)
* All the new CLI for api + stack work
* Make Fireworks and Together into the Adapter format
* Some quick fixes to the CLI behavior to make it consistent
* Updated README, phew
* Update cli_reference.md
* llama_toolchain/distribution -> llama_toolchain/core
* Add termcolor
* update paths
* Add a log just for consistency
* chmod +x scripts
* Fix api dependencies not getting added to configuration
* missing import lol
* Delete utils.py; move to agentic system
* Support downloading of URLs for attachments for code interpreter
* Simplify and generalize `llama api build` yay
* Update `llama stack configure` to be very simple also
* Fix stack start
* Allow building an "adhoc" distribution
* Remove `llama api []` subcommands
* Fixes to llama stack commands and update docs
* Update documentation again and add error messages to llama stack start
* llama stack start -> llama stack run
* Change name of build for less confusion
* Add pyopenapi fork to the repository, update RFC assets
* Remove conflicting annotation
* Added a "--raw" option for model template printing

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
Co-authored-by: Dalton Flanagan <6599399+dltn@users.noreply.github.com>
This commit is contained in:
Parent: 35093c0b6f
Commit: 7bc7785b0d
141 changed files with 8252 additions and 4032 deletions
|
@ -1,4 +1,4 @@
|
||||||
include requirements.txt
|
include requirements.txt
|
||||||
include llama_toolchain/data/*.yaml
|
include llama_toolchain/data/*.yaml
|
||||||
include llama_toolchain/distribution/*.sh
|
include llama_toolchain/core/*.sh
|
||||||
include llama_toolchain/cli/scripts/*.sh
|
include llama_toolchain/cli/scripts/*.sh
|
||||||
|
|
|
@ -2,10 +2,10 @@
|
||||||
|
|
||||||
The `llama` CLI tool helps you set up and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-toolchain` package.
|
The `llama` CLI tool helps you set up and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-toolchain` package.
|
||||||
|
|
||||||
### Subcommands
|
### Subcommands
|
||||||
1. `download`: Supports downloading models from Meta or Hugging Face.
|
1. `download`: Supports downloading models from Meta or Hugging Face.
|
||||||
2. `model`: Lists available models and their properties.
|
2. `model`: Lists available models and their properties.
|
||||||
3. `distribution`: A distribution is a set of REST APIs, this command allows you to manage (list, install, create, configure, start) distributions. You can read more about this [here](https://github.com/meta-llama/llama-stack/blob/main/docs/cli_reference.md#step-3-installing-and-configuring-distributions).
|
3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](https://github.com/meta-llama/llama-stack/blob/api_updates_1/docs/cli_reference.md#step-3-building-configuring-and-running-llama-stack-servers).
|
||||||
|
|
||||||
### Sample Usage
|
### Sample Usage
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ The `llama` CLI tool helps you setup and use the Llama toolchain & agentic syste
|
||||||
llama --help
|
llama --help
|
||||||
```
|
```
|
||||||
<pre style="font-family: monospace;">
|
<pre style="font-family: monospace;">
|
||||||
usage: llama [-h] {download,model,distribution} ...
|
usage: llama [-h] {download,model,stack,api} ...
|
||||||
|
|
||||||
Welcome to the Llama CLI
|
Welcome to the Llama CLI
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ options:
|
||||||
-h, --help show this help message and exit
|
-h, --help show this help message and exit
|
||||||
|
|
||||||
subcommands:
|
subcommands:
|
||||||
{download,model,distribution}
|
{download,model,stack,api}
|
||||||
</pre>
|
</pre>
|
||||||
|
|
||||||
## Step 1. Get the models
|
## Step 1. Get the models
|
||||||
|
@ -101,9 +101,9 @@ The `llama model` command helps you explore the model’s interface.
|
||||||
|
|
||||||
### 2.1 Subcommands
|
### 2.1 Subcommands
|
||||||
1. `download`: Download a model from different sources (Meta, Hugging Face).
|
1. `download`: Download a model from different sources (Meta, Hugging Face).
|
||||||
2. `list`: Lists all the models available for download with hardware requirements to deploy the models.
|
2. `list`: Lists all the models available for download with hardware requirements to deploy the models.
|
||||||
3. `template`: <TODO: What is a template?>
|
3. `template`: <TODO: What is a template?>
|
||||||
4. `describe`: Describes all the properties of the model.
|
4. `describe`: Describes all the properties of the model.
|
||||||
|
|
||||||
### 2.2 Sample Usage
|
### 2.2 Sample Usage
|
||||||
|
|
||||||
|
@ -236,11 +236,13 @@ These commands can help understand the model interface and how prompts / message
|
||||||
**NOTE**: Outputs in terminal are color printed to show special tokens.
|
**NOTE**: Outputs in terminal are color printed to show special tokens.
|
||||||
|
|
||||||
|
|
||||||
## Step 3: Installing and Configuring Distributions
|
## Step 3: Building, Configuring and Running Llama Stack servers
|
||||||
|
|
||||||
An agentic app has several components including model inference, tool execution and system safety shields. Running all these components is made simpler (we hope!) with Llama Stack Distributions.
|
An agentic app has several components including model inference, tool execution and system safety shields. Running all these components is made simpler (we hope!) with Llama Stack Distributions.
|
||||||
|
|
||||||
A Distribution is simply a collection of REST API providers that are part of the Llama stack. As an example, by running a simple command `llama distribution start`, you can bring up a server serving the following endpoints, among others:
|
The Llama Stack is a collection of REST APIs. An API is _implemented_ by a Provider. An assembly of Providers together provides the implementation for the Stack -- this package is called a Distribution.
|
||||||
|
|
||||||
|
As an example, by running a simple command `llama stack run`, you can bring up a server serving the following endpoints, among others:
|
||||||
```
|
```
|
||||||
POST /inference/chat_completion
|
POST /inference/chat_completion
|
||||||
POST /inference/completion
|
POST /inference/completion
|
||||||
|
@ -253,103 +255,135 @@ POST /agentic_system/delete
|
||||||
|
|
||||||
The agentic app can now simply point to this server to execute all its needed components.
|
The agentic app can now simply point to this server to execute all its needed components.
|
||||||
|
|
||||||
A distribution’s behavior can be configured by defining a specification or “spec”. This specification lays out the different API “Providers” that constitute this distribution.
|
Let's build, configure, and start a Llama Stack server (specified via a "Distribution ID") to understand this better!
|
||||||
|
|
||||||
Lets install, configure and start a distribution to understand more !
|
Let’s start with listing available distributions:
|
||||||
|
|
||||||
Let’s start with listing available distributions
|
|
||||||
```
|
```
|
||||||
llama distribution list
|
llama stack list-distributions
|
||||||
```
|
```
|
||||||
|
|
||||||
<pre style="font-family: monospace;">
|
<pre style="font-family: monospace;">
|
||||||
+--------------+---------------------------------------------+----------------------------------------------------------------------+
|
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
|
||||||
| Spec ID | ProviderSpecs | Description |
|
| Distribution ID | Providers | Description |
|
||||||
+--------------+---------------------------------------------+----------------------------------------------------------------------+
|
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
|
||||||
| local | { | Use code from `llama_toolchain` itself to serve all llama stack APIs |
|
| local | { | Use code from `llama_toolchain` itself to serve all llama stack APIs |
|
||||||
| | "inference": "meta-reference", | |
|
| | "inference": "meta-reference", | |
|
||||||
| | "safety": "meta-reference", | |
|
| | "memory": "meta-reference-faiss", | |
|
||||||
| | "agentic_system": "meta-reference" | |
|
| | "safety": "meta-reference", | |
|
||||||
| | } | |
|
| | "agentic_system": "meta-reference" | |
|
||||||
+--------------+---------------------------------------------+----------------------------------------------------------------------+
|
| | } | |
|
||||||
| remote | { | Point to remote services for all llama stack APIs |
|
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
|
||||||
| | "inference": "inference-remote", | |
|
| remote | { | Point to remote services for all llama stack APIs |
|
||||||
| | "safety": "safety-remote", | |
|
| | "inference": "remote", | |
|
||||||
| | "agentic_system": "agentic_system-remote" | |
|
| | "safety": "remote", | |
|
||||||
| | } | |
|
| | "agentic_system": "remote", | |
|
||||||
+--------------+---------------------------------------------+----------------------------------------------------------------------+
|
| | "memory": "remote" | |
|
||||||
| local-ollama | { | Like local, but use ollama for running LLM inference |
|
| | } | |
|
||||||
| | "inference": "meta-ollama", | |
|
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
|
||||||
| | "safety": "meta-reference", | |
|
| local-ollama | { | Like local, but use ollama for running LLM inference |
|
||||||
| | "agentic_system": "meta-reference" | |
|
| | "inference": "remote::ollama", | |
|
||||||
| | } | |
|
| | "safety": "meta-reference", | |
|
||||||
+--------------+---------------------------------------------+----------------------------------------------------------------------+
|
| | "agentic_system": "meta-reference", | |
|
||||||
|
| | "memory": "meta-reference-faiss" | |
|
||||||
|
| | } | |
|
||||||
|
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
|
||||||
|
| local-plus-fireworks-inference | { | Use Fireworks.ai for running LLM inference |
|
||||||
|
| | "inference": "remote::fireworks", | |
|
||||||
|
| | "safety": "meta-reference", | |
|
||||||
|
| | "agentic_system": "meta-reference", | |
|
||||||
|
| | "memory": "meta-reference-faiss" | |
|
||||||
|
| | } | |
|
||||||
|
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
|
||||||
|
| local-plus-together-inference | { | Use Together.ai for running LLM inference |
|
||||||
|
| | "inference": "remote::together", | |
|
||||||
|
| | "safety": "meta-reference", | |
|
||||||
|
| | "agentic_system": "meta-reference", | |
|
||||||
|
| | "memory": "meta-reference-faiss" | |
|
||||||
|
| | } | |
|
||||||
|
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
|
||||||
</pre>
|
</pre>
|
||||||
|
|
||||||
As you can see above, each “spec” details the “providers” that make up that spec. For eg. The `local` spec uses the “meta-reference” provider for inference while the `local-ollama` spec relies on a different provider ( ollama ) for inference.
|
As you can see above, each “distribution” details the “providers” it is composed of. For example, `local` uses the “meta-reference” provider for inference, while `local-ollama` relies on a different provider (Ollama) for inference. Similarly, you can use Fireworks or Together.ai for running inference as well.
|
||||||
|
|
||||||
Lets install the fully local implementation of the llama-stack – named `local` above.
|
To install a distribution, we run a simple command providing 2 inputs:
|
||||||
|
- **Distribution ID** of the distribution that we want to install (as obtained from the `list-distributions` command)
|
||||||
|
- A **Name** for the specific build and configuration of this distribution.
|
||||||
|
|
||||||
To install a distro, we run a simple command providing 2 inputs –
|
Let's imagine you are working with an 8B-Instruct model. The following command will build a package (in the form of a Conda environment) _and_ configure it. As part of the configuration, you will be asked for some inputs (model_id, max_seq_len, etc.). Since we are working with an 8B model, we will name our build `8b-instruct` to help us remember the config.
|
||||||
- **Spec Id** of the distribution that we want to install ( as obtained from the list command )
|
|
||||||
- A **Name** by which this installation will be known locally.
|
|
||||||
|
|
||||||
```
|
```
|
||||||
llama distribution install --spec local --name local_llama_8b
|
llama stack build local --name 8b-instruct
|
||||||
```
|
```
|
||||||
|
|
||||||
This will create a new conda environment (name can be passed optionally) and install dependencies (via pip) as required by the distro.
|
Once it runs successfully, you should see output of the form:
|
||||||
|
|
||||||
Once it runs successfully , you should see some outputs in the form
|
|
||||||
|
|
||||||
```
|
```
|
||||||
llama distribution install --spec local --name local_llama_8b
|
$ llama stack build local --name 8b-instruct
|
||||||
```
|
....
|
||||||
<pre style="font-family: monospace;">
|
....
|
||||||
Successfully installed cfgv-3.4.0 distlib-0.3.8 identify-2.6.0 libcst-1.4.0 llama_toolchain-0.0.2 moreorless-0.4.0 nodeenv-1.9.1 pre-commit-3.8.0 stdlibs-2024.5.15 toml-0.10.2 tomlkit-0.13.0 trailrunner-1.4.0 ufmt-2.7.0 usort-1.0.8 virtualenv-20.26.3
|
Successfully installed cfgv-3.4.0 distlib-0.3.8 identify-2.6.0 libcst-1.4.0 llama_toolchain-0.0.2 moreorless-0.4.0 nodeenv-1.9.1 pre-commit-3.8.0 stdlibs-2024.5.15 toml-0.10.2 tomlkit-0.13.0 trailrunner-1.4.0 ufmt-2.7.0 usort-1.0.8 virtualenv-20.26.3
|
||||||
|
|
||||||
Distribution `local_llama_8b` (with spec local) has been installed successfully!
|
Successfully setup conda environment. Configuring build...
|
||||||
</pre>
|
|
||||||
|
|
||||||
Next step is to configure the distribution that you just installed. We provide a simple CLI tool to enable simple configuration.
|
...
|
||||||
This command will walk you through the configuration process.
|
...
|
||||||
It will ask for some details like model name, paths to models, etc.
|
|
||||||
|
|
||||||
**NOTE**: You will have to download the models if not done already. Follow instructions here on how to download using the llama cli
|
YAML configuration has been written to ~/.llama/builds/local/conda/8b-instruct.yaml
|
||||||
```
|
|
||||||
llama distribution configure --name local_llama_8b
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Here is an example output of how the cli will guide you to fill the configuration:
|
You can re-configure this distribution by running:
|
||||||
<pre style="font-family: monospace;">
|
```
|
||||||
Configuring API surface: inference
|
llama stack configure local --name 8b-instruct
|
||||||
|
```
|
||||||
|
|
||||||
|
Here is an example run of how the CLI will guide you through filling in the configuration:
|
||||||
|
```
|
||||||
|
$ llama stack configure local --name 8b-instruct
|
||||||
|
|
||||||
|
Configuring API: inference (meta-reference)
|
||||||
Enter value for model (required): Meta-Llama3.1-8B-Instruct
|
Enter value for model (required): Meta-Llama3.1-8B-Instruct
|
||||||
Enter value for quantization (optional):
|
Enter value for quantization (optional):
|
||||||
Enter value for torch_seed (optional):
|
Enter value for torch_seed (optional):
|
||||||
Enter value for max_seq_len (required): 4096
|
Enter value for max_seq_len (required): 4096
|
||||||
Enter value for max_batch_size (default: 1): 1
|
Enter value for max_batch_size (default: 1): 1
|
||||||
Configuring API surface: safety
|
Configuring API: safety (meta-reference)
|
||||||
Do you want to configure llama_guard_shield? (y/n): n
|
Do you want to configure llama_guard_shield? (y/n): y
|
||||||
Do you want to configure prompt_guard_shield? (y/n): n
|
Entering sub-configuration for llama_guard_shield:
|
||||||
Configuring API surface: agentic_system
|
Enter value for model (required): Llama-Guard-3-8B
|
||||||
|
Enter value for excluded_categories (required): []
|
||||||
|
Enter value for disable_input_check (default: False):
|
||||||
|
Enter value for disable_output_check (default: False):
|
||||||
|
Do you want to configure prompt_guard_shield? (y/n): y
|
||||||
|
Entering sub-configuration for prompt_guard_shield:
|
||||||
|
Enter value for model (required): Prompt-Guard-86M
|
||||||
|
...
|
||||||
|
...
|
||||||
|
YAML configuration has been written to ~/.llama/builds/local/conda/8b-instruct.yaml
|
||||||
|
```
|
||||||
|
|
||||||
YAML configuration has been written to ~/.llama/distributions/local0/config.yaml
|
As you can see, we did some basic configuration above, setting up:
|
||||||
</pre>
|
- inference to run on model `Meta-Llama3.1-8B-Instruct` (obtained from `llama model list`)
|
||||||
|
- Llama Guard safety shield with model `Llama-Guard-3-8B`
|
||||||
As you can see, we did basic configuration above and configured inference to run on model Meta-Llama3.1-8B-Instruct ( obtained from the llama model list command ).
|
- Prompt Guard safety shield with model `Prompt-Guard-86M`
|
||||||
For this initial setup we did not set up safety.
|
|
||||||
|
|
||||||
To see how these configurations are stored as YAML, check out the file printed at the end of the configuration.
|
To see how these configurations are stored as YAML, check out the file printed at the end of the configuration.
|
||||||
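If you want a quick programmatic peek at that file, here is a minimal sketch. It assumes the build above wrote `~/.llama/builds/local/conda/8b-instruct.yaml` (the path printed at the end of the configure step) and that PyYAML is installed; the exact keys inside the file depend on the providers you configured, so this just shows the top level.

```
# Minimal sketch: load the YAML written by `llama stack configure` and list its
# top-level sections. Adjust the path if you chose a different build name.
from pathlib import Path

import yaml  # pip install pyyaml

config_path = Path.home() / ".llama" / "builds" / "local" / "conda" / "8b-instruct.yaml"
with open(config_path) as f:
    config = yaml.safe_load(f)

# The exact layout is whatever `llama stack configure` wrote; just show the top level.
print(type(config).__name__)
if isinstance(config, dict):
    for key, value in config.items():
        print(f"  {key}: {type(value).__name__}")
```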
|
|
||||||
## Step 4: Starting a Distribution and Testing it
|
Note that all configurations, as well as models, are stored in `~/.llama`.
|
||||||
|
|
||||||
Now let’s start the distribution using the cli.
|
## Step 4: Starting a Llama Stack Distribution and Testing it
|
||||||
```
|
|
||||||
llama distribution start --name local_llama_8b --port 5000
|
Now let’s start the Llama Stack server.
|
||||||
```
|
|
||||||
You should see the distribution start and print the APIs that it is supporting:
|
You will need the YAML configuration file that was written out at the end of the `llama stack build` step.
|
||||||
|
|
||||||
|
```
|
||||||
|
llama stack run local --name 8b-instruct --port 5000
|
||||||
|
```
|
||||||
|
You should see the Stack server start and print the APIs that it supports:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ llama stack run local --name 8b-instruct --port 5000
|
||||||
|
|
||||||
<pre style="font-family: monospace;">
|
|
||||||
> initializing model parallel with size 1
|
> initializing model parallel with size 1
|
||||||
> initializing ddp with size 1
|
> initializing ddp with size 1
|
||||||
> initializing pipeline with size 1
|
> initializing pipeline with size 1
|
||||||
|
@ -376,15 +410,23 @@ INFO: Started server process [453333]
|
||||||
INFO: Waiting for application startup.
|
INFO: Waiting for application startup.
|
||||||
INFO: Application startup complete.
|
INFO: Application startup complete.
|
||||||
INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit)
|
INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit)
|
||||||
</pre>
|
|
||||||
|
|
||||||
Lets test with a client
|
|
||||||
|
|
||||||
```
|
```
|
||||||
cd /path/to/llama-toolchain
|
|
||||||
conda activate <env-for-distribution> # ( Eg. local_llama_8b in above example )
|
|
||||||
|
|
||||||
python -m llama_toolchain.inference.client localhost 5000
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Configuration is in `~/.llama/builds/local/conda/8b-instruct.yaml`. Feel free to increase `max_seq_len`.
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> The "local" distribution inference server currently only supports CUDA. It will not work on Apple Silicon machines.
|
||||||
|
|
||||||
|
This server is running a Llama model locally.
|
||||||
|
|
||||||
|
Let's test it with a client.
|
||||||
|
```
|
||||||
|
cd /path/to/llama-stack
|
||||||
|
conda activate <env> # any environment containing the llama-toolchain pip package will work
|
||||||
|
|
||||||
|
python -m llama_toolchain.inference.client localhost 5000
|
||||||
```
|
```
|
||||||
|
|
||||||
This will run the chat completion client and query the distribution’s /inference/chat_completion API.
|
This will run the chat completion client and query the distribution’s /inference/chat_completion API.
|
||||||
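If you would rather hit the endpoint directly instead of going through the bundled client, a request along these lines should work. Note that the payload fields below (`model`, `messages`, `stream`) are assumptions based on the request types in this commit rather than a verified schema, so adjust them to match the Inference API definitions.

```
# Hedged sketch: POST to /inference/chat_completion with httpx.
# The server address matches the `llama stack run ... --port 5000` invocation above.
import httpx

payload = {
    "model": "Meta-Llama3.1-8B-Instruct",  # the model configured during `llama stack configure`
    "messages": [{"role": "user", "content": "Hello! Who are you?"}],
    "stream": False,
}

response = httpx.post(
    "http://localhost:5000/inference/chat_completion",
    json=payload,
    headers={"Content-Type": "application/json"},
    timeout=60,
)
response.raise_for_status()
print(response.json())
```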
|
|
|
@ -4,5 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa
|
from .api import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa
|
|
||||||
|
|
llama_toolchain/agentic_system/api/api.py (new file, 413 additions)
|
@ -0,0 +1,413 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
from llama_toolchain.common.deployment_types import * # noqa: F403
|
||||||
|
from llama_toolchain.inference.api import * # noqa: F403
|
||||||
|
from llama_toolchain.safety.api import * # noqa: F403
|
||||||
|
from llama_toolchain.memory.api import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Attachment(BaseModel):
|
||||||
|
content: InterleavedTextMedia | URL
|
||||||
|
mime_type: str
|
||||||
|
|
||||||
|
|
||||||
|
class AgenticSystemTool(Enum):
|
||||||
|
brave_search = "brave_search"
|
||||||
|
wolfram_alpha = "wolfram_alpha"
|
||||||
|
photogen = "photogen"
|
||||||
|
code_interpreter = "code_interpreter"
|
||||||
|
|
||||||
|
function_call = "function_call"
|
||||||
|
memory = "memory"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolDefinitionCommon(BaseModel):
|
||||||
|
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||||
|
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class BraveSearchToolDefinition(ToolDefinitionCommon):
|
||||||
|
type: Literal[AgenticSystemTool.brave_search.value] = (
|
||||||
|
AgenticSystemTool.brave_search.value
|
||||||
|
)
|
||||||
|
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class WolframAlphaToolDefinition(ToolDefinitionCommon):
|
||||||
|
type: Literal[AgenticSystemTool.wolfram_alpha.value] = (
|
||||||
|
AgenticSystemTool.wolfram_alpha.value
|
||||||
|
)
|
||||||
|
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PhotogenToolDefinition(ToolDefinitionCommon):
|
||||||
|
type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value
|
||||||
|
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
|
||||||
|
type: Literal[AgenticSystemTool.code_interpreter.value] = (
|
||||||
|
AgenticSystemTool.code_interpreter.value
|
||||||
|
)
|
||||||
|
enable_inline_code_execution: bool = True
|
||||||
|
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class FunctionCallToolDefinition(ToolDefinitionCommon):
|
||||||
|
type: Literal[AgenticSystemTool.function_call.value] = (
|
||||||
|
AgenticSystemTool.function_call.value
|
||||||
|
)
|
||||||
|
function_name: str
|
||||||
|
description: str
|
||||||
|
parameters: Dict[str, ToolParamDefinition]
|
||||||
|
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
class _MemoryBankConfigCommon(BaseModel):
|
||||||
|
bank_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
|
||||||
|
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||||
|
|
||||||
|
|
||||||
|
class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
|
||||||
|
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
|
||||||
|
keys: List[str] # what keys to focus on
|
||||||
|
|
||||||
|
|
||||||
|
class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
|
||||||
|
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
|
||||||
|
|
||||||
|
|
||||||
|
class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
|
||||||
|
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||||
|
entities: List[str] # what entities to focus on
|
||||||
|
|
||||||
|
|
||||||
|
MemoryBankConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
AgenticSystemVectorMemoryBankConfig,
|
||||||
|
AgenticSystemKeyValueMemoryBankConfig,
|
||||||
|
AgenticSystemKeywordMemoryBankConfig,
|
||||||
|
AgenticSystemGraphMemoryBankConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class MemoryToolDefinition(ToolDefinitionCommon):
|
||||||
|
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
|
||||||
|
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
|
||||||
|
max_tokens_in_context: int = 4096
|
||||||
|
max_chunks: int = 10
|
||||||
|
|
||||||
|
|
||||||
|
AgenticSystemToolDefinition = Annotated[
|
||||||
|
Union[
|
||||||
|
BraveSearchToolDefinition,
|
||||||
|
WolframAlphaToolDefinition,
|
||||||
|
PhotogenToolDefinition,
|
||||||
|
CodeInterpreterToolDefinition,
|
||||||
|
FunctionCallToolDefinition,
|
||||||
|
MemoryToolDefinition,
|
||||||
|
],
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class StepCommon(BaseModel):
|
||||||
|
turn_id: str
|
||||||
|
step_id: str
|
||||||
|
started_at: Optional[datetime] = None
|
||||||
|
completed_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
class StepType(Enum):
|
||||||
|
inference = "inference"
|
||||||
|
tool_execution = "tool_execution"
|
||||||
|
shield_call = "shield_call"
|
||||||
|
memory_retrieval = "memory_retrieval"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class InferenceStep(StepCommon):
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
step_type: Literal[StepType.inference.value] = StepType.inference.value
|
||||||
|
model_response: CompletionMessage
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ToolExecutionStep(StepCommon):
|
||||||
|
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
|
||||||
|
tool_calls: List[ToolCall]
|
||||||
|
tool_responses: List[ToolResponse]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ShieldCallStep(StepCommon):
|
||||||
|
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
|
||||||
|
response: ShieldResponse
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class MemoryRetrievalStep(StepCommon):
|
||||||
|
step_type: Literal[StepType.memory_retrieval.value] = (
|
||||||
|
StepType.memory_retrieval.value
|
||||||
|
)
|
||||||
|
memory_bank_ids: List[str]
|
||||||
|
inserted_context: InterleavedTextMedia
|
||||||
|
|
||||||
|
|
||||||
|
Step = Annotated[
|
||||||
|
Union[
|
||||||
|
InferenceStep,
|
||||||
|
ToolExecutionStep,
|
||||||
|
ShieldCallStep,
|
||||||
|
MemoryRetrievalStep,
|
||||||
|
],
|
||||||
|
Field(discriminator="step_type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Turn(BaseModel):
|
||||||
|
"""A single turn in an interaction with an Agentic System."""
|
||||||
|
|
||||||
|
turn_id: str
|
||||||
|
session_id: str
|
||||||
|
input_messages: List[
|
||||||
|
Union[
|
||||||
|
UserMessage,
|
||||||
|
ToolResponseMessage,
|
||||||
|
]
|
||||||
|
]
|
||||||
|
steps: List[Step]
|
||||||
|
output_message: CompletionMessage
|
||||||
|
output_attachments: List[Attachment] = Field(default_factory=list)
|
||||||
|
|
||||||
|
started_at: datetime
|
||||||
|
completed_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Session(BaseModel):
|
||||||
|
"""A single session of an interaction with an Agentic System."""
|
||||||
|
|
||||||
|
session_id: str
|
||||||
|
session_name: str
|
||||||
|
turns: List[Turn]
|
||||||
|
started_at: datetime
|
||||||
|
|
||||||
|
memory_bank: Optional[MemoryBank] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfigCommon(BaseModel):
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||||
|
|
||||||
|
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||||
|
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
|
||||||
|
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||||
|
default=ToolPromptFormat.json
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentConfig(AgentConfigCommon):
|
||||||
|
model: str
|
||||||
|
instructions: str
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||||
|
instructions: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgenticSystemTurnResponseEventType(Enum):
|
||||||
|
step_start = "step_start"
|
||||||
|
step_complete = "step_complete"
|
||||||
|
step_progress = "step_progress"
|
||||||
|
|
||||||
|
turn_start = "turn_start"
|
||||||
|
turn_complete = "turn_complete"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
|
||||||
|
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
|
||||||
|
AgenticSystemTurnResponseEventType.step_start.value
|
||||||
|
)
|
||||||
|
step_type: StepType
|
||||||
|
step_id: str
|
||||||
|
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
|
||||||
|
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
|
||||||
|
AgenticSystemTurnResponseEventType.step_complete.value
|
||||||
|
)
|
||||||
|
step_type: StepType
|
||||||
|
step_details: Step
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
|
||||||
|
AgenticSystemTurnResponseEventType.step_progress.value
|
||||||
|
)
|
||||||
|
step_type: StepType
|
||||||
|
step_id: str
|
||||||
|
|
||||||
|
model_response_text_delta: Optional[str] = None
|
||||||
|
tool_call_delta: Optional[ToolCallDelta] = None
|
||||||
|
tool_response_text_delta: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
|
||||||
|
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
|
||||||
|
AgenticSystemTurnResponseEventType.turn_start.value
|
||||||
|
)
|
||||||
|
turn_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
|
||||||
|
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
|
||||||
|
AgenticSystemTurnResponseEventType.turn_complete.value
|
||||||
|
)
|
||||||
|
turn: Turn
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemTurnResponseEvent(BaseModel):
|
||||||
|
"""Streamed agent execution response."""
|
||||||
|
|
||||||
|
payload: Annotated[
|
||||||
|
Union[
|
||||||
|
AgenticSystemTurnResponseStepStartPayload,
|
||||||
|
AgenticSystemTurnResponseStepProgressPayload,
|
||||||
|
AgenticSystemTurnResponseStepCompletePayload,
|
||||||
|
AgenticSystemTurnResponseTurnStartPayload,
|
||||||
|
AgenticSystemTurnResponseTurnCompletePayload,
|
||||||
|
],
|
||||||
|
Field(discriminator="event_type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemCreateResponse(BaseModel):
|
||||||
|
agent_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemSessionCreateResponse(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||||
|
agent_id: str
|
||||||
|
session_id: str
|
||||||
|
|
||||||
|
# TODO: figure out how we can simplify this and make clear why
|
||||||
|
# ToolResponseMessage needs to be here (it is function call
|
||||||
|
# execution from outside the system)
|
||||||
|
messages: List[
|
||||||
|
Union[
|
||||||
|
UserMessage,
|
||||||
|
ToolResponseMessage,
|
||||||
|
]
|
||||||
|
]
|
||||||
|
attachments: Optional[List[Attachment]] = None
|
||||||
|
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemTurnResponseStreamChunk(BaseModel):
|
||||||
|
event: AgenticSystemTurnResponseEvent
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemStepResponse(BaseModel):
|
||||||
|
step: Step
|
||||||
|
|
||||||
|
|
||||||
|
class AgenticSystem(Protocol):
|
||||||
|
@webmethod(route="/agentic_system/create")
|
||||||
|
async def create_agentic_system(
|
||||||
|
self,
|
||||||
|
agent_config: AgentConfig,
|
||||||
|
) -> AgenticSystemCreateResponse: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/turn/create")
|
||||||
|
async def create_agentic_system_turn(
|
||||||
|
self,
|
||||||
|
request: AgenticSystemTurnCreateRequest,
|
||||||
|
) -> AgenticSystemTurnResponseStreamChunk: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/turn/get")
|
||||||
|
async def get_agentic_system_turn(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
turn_id: str,
|
||||||
|
) -> Turn: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/step/get")
|
||||||
|
async def get_agentic_system_step(
|
||||||
|
self, agent_id: str, turn_id: str, step_id: str
|
||||||
|
) -> AgenticSystemStepResponse: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/session/create")
|
||||||
|
async def create_agentic_system_session(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
session_name: str,
|
||||||
|
) -> AgenticSystemSessionCreateResponse: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/session/get")
|
||||||
|
async def get_agentic_system_session(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
session_id: str,
|
||||||
|
turn_ids: Optional[List[str]] = None,
|
||||||
|
) -> Session: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/session/delete")
|
||||||
|
async def delete_agentic_system_session(
|
||||||
|
self, agent_id: str, session_id: str
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/delete")
|
||||||
|
async def delete_agentic_system(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
) -> None: ...
|
|
@ -1,234 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Union
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from llama_toolchain.common.deployment_types import * # noqa: F403
|
|
||||||
from llama_toolchain.inference.api import * # noqa: F403
|
|
||||||
from llama_toolchain.safety.api.datatypes import * # noqa: F403
|
|
||||||
from llama_toolchain.memory.api.datatypes import * # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemToolDefinition(ToolDefinition):
|
|
||||||
execution_config: Optional[RestAPIExecutionConfig] = None
|
|
||||||
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
|
||||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class StepCommon(BaseModel):
|
|
||||||
turn_id: str
|
|
||||||
step_id: str
|
|
||||||
started_at: Optional[datetime] = None
|
|
||||||
completed_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
class StepType(Enum):
|
|
||||||
inference = "inference"
|
|
||||||
tool_execution = "tool_execution"
|
|
||||||
shield_call = "shield_call"
|
|
||||||
memory_retrieval = "memory_retrieval"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class InferenceStep(StepCommon):
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
|
||||||
|
|
||||||
step_type: Literal[StepType.inference.value] = StepType.inference.value
|
|
||||||
model_response: CompletionMessage
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolExecutionStep(StepCommon):
|
|
||||||
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
|
|
||||||
tool_calls: List[ToolCall]
|
|
||||||
tool_responses: List[ToolResponse]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ShieldCallStep(StepCommon):
|
|
||||||
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
|
|
||||||
response: ShieldResponse
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class MemoryRetrievalStep(StepCommon):
|
|
||||||
step_type: Literal[StepType.memory_retrieval.value] = (
|
|
||||||
StepType.memory_retrieval.value
|
|
||||||
)
|
|
||||||
memory_bank_ids: List[str]
|
|
||||||
documents: List[MemoryBankDocument]
|
|
||||||
scores: List[float]
|
|
||||||
|
|
||||||
|
|
||||||
Step = Annotated[
|
|
||||||
Union[
|
|
||||||
InferenceStep,
|
|
||||||
ToolExecutionStep,
|
|
||||||
ShieldCallStep,
|
|
||||||
MemoryRetrievalStep,
|
|
||||||
],
|
|
||||||
Field(discriminator="step_type"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class Turn(BaseModel):
|
|
||||||
"""A single turn in an interaction with an Agentic System."""
|
|
||||||
|
|
||||||
turn_id: str
|
|
||||||
session_id: str
|
|
||||||
input_messages: List[
|
|
||||||
Union[
|
|
||||||
UserMessage,
|
|
||||||
ToolResponseMessage,
|
|
||||||
]
|
|
||||||
]
|
|
||||||
steps: List[Step]
|
|
||||||
output_message: CompletionMessage
|
|
||||||
started_at: datetime
|
|
||||||
completed_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class Session(BaseModel):
|
|
||||||
"""A single session of an interaction with an Agentic System."""
|
|
||||||
|
|
||||||
session_id: str
|
|
||||||
session_name: str
|
|
||||||
turns: List[Turn]
|
|
||||||
started_at: datetime
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolPromptFormat(Enum):
|
|
||||||
"""This Enum refers to the prompt format for calling zero shot tools
|
|
||||||
|
|
||||||
`json` --
|
|
||||||
Refers to the json format for calling tools.
|
|
||||||
The json format takes the form like
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function" : {
|
|
||||||
"name": "function_name",
|
|
||||||
"description": "function_description",
|
|
||||||
"parameters": {...}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
`function_tag` --
|
|
||||||
This is an example of how you could define
|
|
||||||
your own user defined format for making tool calls.
|
|
||||||
The function_tag format looks like this,
|
|
||||||
<function=function_name>(parameters)</function>
|
|
||||||
|
|
||||||
The detailed prompts for each of these formats are defined in `system_prompt.py`
|
|
||||||
"""
|
|
||||||
|
|
||||||
json = "json"
|
|
||||||
function_tag = "function_tag"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemInstanceConfig(BaseModel):
|
|
||||||
instructions: str
|
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
|
||||||
# zero-shot or built-in tool configurations as input to the model
|
|
||||||
available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
|
|
||||||
default_factory=list
|
|
||||||
)
|
|
||||||
|
|
||||||
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
|
||||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
|
||||||
|
|
||||||
# if you completely want to replace the messages prefixed by the system,
|
|
||||||
# this is debug only
|
|
||||||
debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
|
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
|
||||||
default=ToolPromptFormat.json
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AgenticSystemTurnResponseEventType(Enum):
|
|
||||||
step_start = "step_start"
|
|
||||||
step_complete = "step_complete"
|
|
||||||
step_progress = "step_progress"
|
|
||||||
|
|
||||||
turn_start = "turn_start"
|
|
||||||
turn_complete = "turn_complete"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
|
|
||||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
|
|
||||||
AgenticSystemTurnResponseEventType.step_start.value
|
|
||||||
)
|
|
||||||
step_type: StepType
|
|
||||||
step_id: str
|
|
||||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
|
|
||||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
|
|
||||||
AgenticSystemTurnResponseEventType.step_complete.value
|
|
||||||
)
|
|
||||||
step_type: StepType
|
|
||||||
step_details: Step
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
|
||||||
|
|
||||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
|
|
||||||
AgenticSystemTurnResponseEventType.step_progress.value
|
|
||||||
)
|
|
||||||
step_type: StepType
|
|
||||||
step_id: str
|
|
||||||
|
|
||||||
model_response_text_delta: Optional[str] = None
|
|
||||||
tool_call_delta: Optional[ToolCallDelta] = None
|
|
||||||
tool_response_text_delta: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
|
|
||||||
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
|
|
||||||
AgenticSystemTurnResponseEventType.turn_start.value
|
|
||||||
)
|
|
||||||
turn_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
|
|
||||||
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
|
|
||||||
AgenticSystemTurnResponseEventType.turn_complete.value
|
|
||||||
)
|
|
||||||
turn: Turn
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemTurnResponseEvent(BaseModel):
|
|
||||||
"""Streamed agent execution response."""
|
|
||||||
|
|
||||||
payload: Annotated[
|
|
||||||
Union[
|
|
||||||
AgenticSystemTurnResponseStepStartPayload,
|
|
||||||
AgenticSystemTurnResponseStepProgressPayload,
|
|
||||||
AgenticSystemTurnResponseStepCompletePayload,
|
|
||||||
AgenticSystemTurnResponseTurnStartPayload,
|
|
||||||
AgenticSystemTurnResponseTurnCompletePayload,
|
|
||||||
],
|
|
||||||
Field(discriminator="event_type"),
|
|
||||||
]
|
|
|
@ -1,127 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from .datatypes import * # noqa: F403
|
|
||||||
from typing import Protocol
|
|
||||||
|
|
||||||
# this dependency is annoying and we need a forked up version anyway
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemCreateRequest(BaseModel):
|
|
||||||
model: str
|
|
||||||
instance_config: AgenticSystemInstanceConfig
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemCreateResponse(BaseModel):
|
|
||||||
system_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemSessionCreateRequest(BaseModel):
|
|
||||||
system_id: str
|
|
||||||
session_name: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemSessionCreateResponse(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
# what's the URI?
|
|
||||||
class AgenticSystemTurnCreateRequest(BaseModel):
|
|
||||||
system_id: str
|
|
||||||
session_id: str
|
|
||||||
|
|
||||||
messages: List[
|
|
||||||
Union[
|
|
||||||
UserMessage,
|
|
||||||
ToolResponseMessage,
|
|
||||||
]
|
|
||||||
]
|
|
||||||
|
|
||||||
stream: Optional[bool] = False
|
|
||||||
override_config: Optional[AgenticSystemInstanceConfig] = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemTurnResponseStreamChunk(BaseModel):
|
|
||||||
event: AgenticSystemTurnResponseEvent
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemStepResponse(BaseModel):
|
|
||||||
step: Step
|
|
||||||
|
|
||||||
|
|
||||||
class AgenticSystem(Protocol):
|
|
||||||
@webmethod(route="/agentic_system/create")
|
|
||||||
async def create_agentic_system(
|
|
||||||
self,
|
|
||||||
request: AgenticSystemCreateRequest,
|
|
||||||
) -> AgenticSystemCreateResponse: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/turn/create")
|
|
||||||
async def create_agentic_system_turn(
|
|
||||||
self,
|
|
||||||
request: AgenticSystemTurnCreateRequest,
|
|
||||||
) -> AgenticSystemTurnResponseStreamChunk: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/turn/get")
|
|
||||||
async def get_agentic_system_turn(
|
|
||||||
self,
|
|
||||||
agent_id: str,
|
|
||||||
turn_id: str,
|
|
||||||
) -> Turn: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/step/get")
|
|
||||||
async def get_agentic_system_step(
|
|
||||||
self, agent_id: str, turn_id: str, step_id: str
|
|
||||||
) -> AgenticSystemStepResponse: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/session/create")
|
|
||||||
async def create_agentic_system_session(
|
|
||||||
self,
|
|
||||||
request: AgenticSystemSessionCreateRequest,
|
|
||||||
) -> AgenticSystemSessionCreateResponse: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/memory_bank/attach")
|
|
||||||
async def attach_memory_bank_to_agentic_system(
|
|
||||||
self,
|
|
||||||
agent_id: str,
|
|
||||||
session_id: str,
|
|
||||||
memory_bank_ids: List[str],
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/memory_bank/detach")
|
|
||||||
async def detach_memory_bank_from_agentic_system(
|
|
||||||
self,
|
|
||||||
agent_id: str,
|
|
||||||
session_id: str,
|
|
||||||
memory_bank_ids: List[str],
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/session/get")
|
|
||||||
async def get_agentic_system_session(
|
|
||||||
self,
|
|
||||||
agent_id: str,
|
|
||||||
session_id: str,
|
|
||||||
turn_ids: Optional[List[str]] = None,
|
|
||||||
) -> Session: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/session/delete")
|
|
||||||
async def delete_agentic_system_session(
|
|
||||||
self, agent_id: str, session_id: str
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/delete")
|
|
||||||
async def delete_agentic_system(
|
|
||||||
self,
|
|
||||||
agent_id: str,
|
|
||||||
) -> None: ...
|
|
|
@ -6,38 +6,28 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import (
|
from pydantic import BaseModel
|
||||||
BuiltinTool,
|
|
||||||
SamplingParams,
|
|
||||||
ToolParamDefinition,
|
|
||||||
UserMessage,
|
|
||||||
)
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_toolchain.agentic_system.event_logger import EventLogger
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from .api import (
|
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
||||||
AgenticSystem,
|
|
||||||
AgenticSystemCreateRequest,
|
from .api import * # noqa: F403
|
||||||
AgenticSystemCreateResponse,
|
from .event_logger import EventLogger
|
||||||
AgenticSystemInstanceConfig,
|
|
||||||
AgenticSystemSessionCreateRequest,
|
|
||||||
AgenticSystemSessionCreateResponse,
|
|
||||||
AgenticSystemToolDefinition,
|
|
||||||
AgenticSystemTurnCreateRequest,
|
|
||||||
AgenticSystemTurnResponseStreamChunk,
|
|
||||||
ToolPromptFormat,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_client_impl(base_url: str):
|
async def get_client_impl(config: RemoteProviderConfig, _deps):
|
||||||
return AgenticSystemClient(base_url)
|
return AgenticSystemClient(config.url)
|
||||||
|
|
||||||
|
|
||||||
|
def encodable_dict(d: BaseModel):
|
||||||
|
return json.loads(d.json())
|
||||||
|
|
||||||
|
|
||||||
class AgenticSystemClient(AgenticSystem):
|
class AgenticSystemClient(AgenticSystem):
|
||||||
|
@ -45,12 +35,14 @@ class AgenticSystemClient(AgenticSystem):
|
||||||
self.base_url = base_url
|
         self.base_url = base_url
 
     async def create_agentic_system(
-        self, request: AgenticSystemCreateRequest
+        self, agent_config: AgentConfig
     ) -> AgenticSystemCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/agentic_system/create",
-                data=request.json(),
+                json={
+                    "agent_config": encodable_dict(agent_config),
+                },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
@@ -58,12 +50,16 @@ class AgenticSystemClient(AgenticSystem):
 
     async def create_agentic_system_session(
         self,
-        request: AgenticSystemSessionCreateRequest,
+        agent_id: str,
+        session_name: str,
     ) -> AgenticSystemSessionCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/agentic_system/session/create",
-                data=request.json(),
+                json={
+                    "agent_id": agent_id,
+                    "session_name": session_name,
+                },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
@@ -77,7 +73,9 @@ class AgenticSystemClient(AgenticSystem):
             async with client.stream(
                 "POST",
                 f"{self.base_url}/agentic_system/turn/create",
-                data=request.json(),
+                json={
+                    "request": encodable_dict(request),
+                },
                 headers={"Content-Type": "application/json"},
                 timeout=20,
             ) as response:
@@ -85,6 +83,10 @@ class AgenticSystemClient(AgenticSystem):
                     if line.startswith("data:"):
                         data = line[len("data: ") :]
                         try:
+                            if "error" in data:
+                                cprint(data, "red")
+                                continue
+
                             yield AgenticSystemTurnResponseStreamChunk(
                                 **json.loads(data)
                             )
@@ -93,24 +95,52 @@ class AgenticSystemClient(AgenticSystem):
                         except Exception as e:
                             print(f"Error with parsing or validation: {e}")
 
 
+async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
+    agent_config = AgentConfig(
+        model="Meta-Llama3.1-8B-Instruct",
+        instructions="You are a helpful assistant",
+        sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
+        tools=tool_definitions,
+        tool_choice=ToolChoice.auto,
+        tool_prompt_format=ToolPromptFormat.function_tag,
+    )
+
+    create_response = await api.create_agentic_system(agent_config)
+    session_response = await api.create_agentic_system_session(
+        agent_id=create_response.agent_id,
+        session_name="test_session",
+    )
+
+    for content in user_prompts:
+        cprint(f"User> {content}", color="white", attrs=["bold"])
+        iterator = api.create_agentic_system_turn(
+            AgenticSystemTurnCreateRequest(
+                agent_id=create_response.agent_id,
+                session_id=session_response.session_id,
+                messages=[
+                    UserMessage(content=content),
+                ],
+                attachments=attachments,
+                stream=True,
+            )
+        )
+
+        async for event, log in EventLogger().log(iterator):
+            if log is not None:
+                log.print()
+
+
 async def run_main(host: str, port: int):
-    # client to test remote impl of agentic system
     api = AgenticSystemClient(f"http://{host}:{port}")
 
     tool_definitions = [
-        AgenticSystemToolDefinition(
-            tool_name=BuiltinTool.brave_search,
-        ),
-        AgenticSystemToolDefinition(
-            tool_name=BuiltinTool.wolfram_alpha,
-        ),
-        AgenticSystemToolDefinition(
-            tool_name=BuiltinTool.code_interpreter,
-        ),
+        BraveSearchToolDefinition(),
+        WolframAlphaToolDefinition(),
+        CodeInterpreterToolDefinition(),
     ]
     tool_definitions += [
-        AgenticSystemToolDefinition(
-            tool_name="get_boiling_point",
+        FunctionCallToolDefinition(
+            function_name="get_boiling_point",
             description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
             parameters={
                 "liquid_name": ToolParamDefinition(
@@ -127,30 +157,6 @@ async def run_main(host: str, port: int):
             ),
         ),
     ]
 
-    create_request = AgenticSystemCreateRequest(
-        model="Meta-Llama3.1-8B-Instruct",
-        instance_config=AgenticSystemInstanceConfig(
-            instructions="You are a helpful assistant",
-            sampling_params=SamplingParams(),
-            available_tools=tool_definitions,
-            input_shields=[],
-            output_shields=[],
-            debug_prefix_messages=[],
-            tool_prompt_format=ToolPromptFormat.json,
-        ),
-    )
-
-    create_response = await api.create_agentic_system(create_request)
-    print(create_response)
-
-    session_response = await api.create_agentic_system_session(
-        AgenticSystemSessionCreateRequest(
-            system_id=create_response.system_id,
-            session_name="test_session",
-        )
-    )
-    print(session_response)
-
     user_prompts = [
         "Who are you?",
         "what is the 100th prime number?",
@@ -158,26 +164,51 @@ async def run_main(host: str, port: int):
         "Write code to check if a number is prime. Use that to check if 7 is prime",
         "What is the boiling point of polyjuicepotion ?",
     ]
-    for content in user_prompts:
-        cprint(f"User> {content}", color="blue")
-        iterator = api.create_agentic_system_turn(
-            AgenticSystemTurnCreateRequest(
-                system_id=create_response.system_id,
-                session_id=session_response.session_id,
-                messages=[
-                    UserMessage(content=content),
-                ],
-                stream=True,
-            )
-        )
-
-        async for event, log in EventLogger().log(iterator):
-            if log is not None:
-                log.print()
+    await _run_agent(api, tool_definitions, user_prompts)
+
+
+async def run_rag(host: str, port: int):
+    api = AgenticSystemClient(f"http://{host}:{port}")
+
+    urls = [
+        "memory_optimizations.rst",
+        "chat.rst",
+        "llama3.rst",
+        "datasets.rst",
+        "qat_finetune.rst",
+        "lora_finetune.rst",
+    ]
+    attachments = [
+        Attachment(
+            content=URL(
+                uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
+            ),
+            mime_type="text/plain",
+        )
+        for i, url in enumerate(urls)
+    ]
+
+    # Alternatively, you can pre-populate the memory bank with documents for example,
+    # using `llama_toolchain.memory.client`. Then you can grab the bank_id
+    # from the output of that run.
+    tool_definitions = [
+        MemoryToolDefinition(
+            max_tokens_in_context=2048,
+            memory_bank_configs=[],
+        ),
+    ]
+
+    user_prompts = [
+        "How do I use Lora?",
+        "Tell me briefly about llama3 and torchtune",
+    ]
+
+    await _run_agent(api, tool_definitions, user_prompts, attachments)
 
 
-def main(host: str, port: int):
-    asyncio.run(run_main(host, port))
+def main(host: str, port: int, rag: bool = False):
+    fn = run_rag if rag else run_main
+    asyncio.run(fn(host, port))
 
 
 if __name__ == "__main__":
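Note: the client hunks above switch from posting pydantic's `request.json()` to a plain JSON body built with `encodable_dict(...)`. That helper's definition is not part of this hunk; a minimal sketch of what such a helper could look like, assuming pydantic models throughout (the actual implementation in the repo may differ):

    import json

    from pydantic import BaseModel


    def encodable_dict(obj: BaseModel) -> dict:
        # Round-trip through pydantic's JSON encoder so enums, datetimes and
        # nested models all become plain JSON-serializable values.
        return json.loads(obj.json())
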
@@ -6,7 +6,7 @@
 
 from typing import Optional
 
-from llama_models.llama3.api.datatypes import ToolResponseMessage
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tool_utils import ToolUtils
 
 from termcolor import cprint
@@ -44,7 +44,12 @@ EventType = AgenticSystemTurnResponseEventType
 
 
 class EventLogger:
-    async def log(self, event_generator, stream=True):
+    async def log(
+        self,
+        event_generator,
+        stream=True,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+    ):
         previous_event_type = None
         previous_step_type = None
 
@@ -132,7 +137,9 @@ class EventLogger:
             if event_type == EventType.step_complete.value:
                 response = event.payload.step_details.model_response
                 if response.tool_calls:
-                    content = ToolUtils.encode_tool_call(response.tool_calls[0])
+                    content = ToolUtils.encode_tool_call(
+                        response.tool_calls[0], tool_prompt_format
+                    )
                 else:
                     content = response.content
                 yield event, LogEvent(
@@ -162,5 +169,19 @@ class EventLogger:
                     color="green",
                 )
 
+            if (
+                step_type == StepType.memory_retrieval
+                and event_type == EventType.step_complete.value
+            ):
+                details = event.payload.step_details
+                content = interleaved_text_media_as_str(details.inserted_context)
+                content = content[:200] + "..." if len(content) > 200 else content
+
+                yield event, LogEvent(
+                    role=step_type,
+                    content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>",
+                    color="cyan",
+                )
+
             preivous_event_type = event_type
             previous_step_type = step_type
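With the logger change above, callers can ask for tool calls to be rendered in the same format the agent was prompted with. A small illustrative driver (the `iterator` here stands for any turn-response stream, such as the one returned by `create_agentic_system_turn`; this is a sketch, not test code from the repo):

    async def print_turn(iterator):
        # Render tool calls as <function=...> tags so they match function_tag prompting.
        async for event, log in EventLogger().log(
            iterator, tool_prompt_format=ToolPromptFormat.function_tag
        ):
            if log is not None:
                log.print()
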
 llama_toolchain/agentic_system/execute_with_custom_tools.py (new file, 96 lines)
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import AsyncGenerator, List
+
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_toolchain.agentic_system.api import *  # noqa: F403
+from llama_toolchain.memory.api import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
+
+from llama_toolchain.agentic_system.api import (
+    AgenticSystemTurnResponseEventType as EventType,
+)
+from llama_toolchain.tools.custom.datatypes import CustomTool
+
+
+class AgentWithCustomToolExecutor:
+    def __init__(
+        self,
+        api: AgenticSystem,
+        agent_id: str,
+        session_id: str,
+        agent_config: AgentConfig,
+        custom_tools: List[CustomTool],
+    ):
+        self.api = api
+        self.agent_id = agent_id
+        self.session_id = session_id
+        self.agent_config = agent_config
+        self.custom_tools = custom_tools
+
+    async def execute_turn(
+        self,
+        messages: List[Message],
+        attachments: Optional[List[Attachment]] = None,
+        max_iters: int = 5,
+        stream: bool = True,
+    ) -> AsyncGenerator:
+        tools_dict = {t.get_name(): t for t in self.custom_tools}
+
+        current_messages = messages.copy()
+        n_iter = 0
+        while n_iter < max_iters:
+            n_iter += 1
+
+            request = AgenticSystemTurnCreateRequest(
+                agent_id=self.agent_id,
+                session_id=self.session_id,
+                messages=current_messages,
+                attachments=attachments,
+                stream=stream,
+            )
+
+            turn = None
+            async for chunk in self.api.create_agentic_system_turn(request):
+                if chunk.event.payload.event_type != EventType.turn_complete.value:
+                    yield chunk
+                else:
+                    turn = chunk.event.payload.turn
+
+            message = turn.output_message
+            if len(message.tool_calls) == 0:
+                yield chunk
+                return
+
+            if message.stop_reason == StopReason.out_of_tokens:
+                yield chunk
+                return
+
+            tool_call = message.tool_calls[0]
+            if tool_call.tool_name not in tools_dict:
+                m = ToolResponseMessage(
+                    call_id=tool_call.call_id,
+                    tool_name=tool_call.tool_name,
+                    content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
+                )
+                next_message = m
+            else:
+                tool = tools_dict[tool_call.tool_name]
+                result_messages = await execute_custom_tool(tool, message)
+                next_message = result_messages[0]
+
+            yield next_message
+            current_messages = [next_message]
+
+
+async def execute_custom_tool(tool: CustomTool, message: Message) -> List[Message]:
+    result_messages = await tool.run([message])
+    assert (
+        len(result_messages) == 1
+    ), f"Expected single message, got {len(result_messages)}"
+
+    return result_messages
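A sketch of how the new executor might be driven end to end; the agent and session creation mirror the client calls earlier in this change, and `my_custom_tools` stands for a hypothetical list of `CustomTool` implementations (illustrative, not code from the repo):

    async def chat_with_custom_tools(api, agent_config, my_custom_tools, prompt):
        create_response = await api.create_agentic_system(agent_config)
        session_response = await api.create_agentic_system_session(
            agent_id=create_response.agent_id,
            session_name="custom_tool_session",
        )

        executor = AgentWithCustomToolExecutor(
            api=api,
            agent_id=create_response.agent_id,
            session_id=session_response.session_id,
            agent_config=agent_config,
            custom_tools=my_custom_tools,
        )

        # execute_turn loops: model turn -> local tool execution -> feed the
        # ToolResponseMessage back in, until the model stops calling tools.
        async for chunk in executor.execute_turn([UserMessage(content=prompt)]):
            print(chunk)
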
@@ -4,5 +4,27 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .agentic_system import get_provider_impl  # noqa
-from .config import AgenticSystemConfig  # noqa
+from typing import Dict
+
+from llama_toolchain.core.datatypes import Api, ProviderSpec
+
+from .config import MetaReferenceImplConfig
+
+
+async def get_provider_impl(
+    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
+):
+    from .agentic_system import MetaReferenceAgenticSystemImpl
+
+    assert isinstance(
+        config, MetaReferenceImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceAgenticSystemImpl(
+        config,
+        deps[Api.inference],
+        deps[Api.memory],
+        deps[Api.safety],
+    )
+    await impl.initialize()
+    return impl
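For orientation, a rough sketch of how the stack's resolver might call this entry point once the providers it depends on are initialized (the variable names are illustrative; the real wiring lives in llama_toolchain.core):

    async def wire_agentic_system(config, inference_impl, memory_impl, safety_impl):
        # deps maps each declared api_dependency to its already-initialized provider.
        deps = {
            Api.inference: inference_impl,
            Api.memory: memory_impl,
            Api.safety: safety_impl,
        }
        return await get_provider_impl(config, deps)
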
@ -4,111 +4,111 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
import shutil
|
||||||
|
import string
|
||||||
|
import tempfile
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import AsyncGenerator, List, Optional
|
from typing import AsyncGenerator, List, Tuple
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_toolchain.agentic_system.api.datatypes import (
|
from llama_toolchain.agentic_system.api import * # noqa: F403
|
||||||
AgenticSystemInstanceConfig,
|
from llama_toolchain.inference.api import * # noqa: F403
|
||||||
AgenticSystemTurnResponseEvent,
|
from llama_toolchain.memory.api import * # noqa: F403
|
||||||
AgenticSystemTurnResponseEventType,
|
from llama_toolchain.safety.api import * # noqa: F403
|
||||||
AgenticSystemTurnResponseStepCompletePayload,
|
|
||||||
AgenticSystemTurnResponseStepProgressPayload,
|
|
||||||
AgenticSystemTurnResponseStepStartPayload,
|
|
||||||
AgenticSystemTurnResponseTurnCompletePayload,
|
|
||||||
AgenticSystemTurnResponseTurnStartPayload,
|
|
||||||
InferenceStep,
|
|
||||||
Session,
|
|
||||||
ShieldCallStep,
|
|
||||||
StepType,
|
|
||||||
ToolExecutionStep,
|
|
||||||
ToolPromptFormat,
|
|
||||||
Turn,
|
|
||||||
)
|
|
||||||
|
|
||||||
from llama_toolchain.inference.api import ChatCompletionRequest, Inference
|
from llama_toolchain.tools.base import BaseTool
|
||||||
|
from llama_toolchain.tools.builtin import (
|
||||||
from llama_toolchain.inference.api.datatypes import (
|
interpret_content_as_attachment,
|
||||||
Attachment,
|
SingleMessageBuiltinTool,
|
||||||
BuiltinTool,
|
|
||||||
ChatCompletionResponseEventType,
|
|
||||||
CompletionMessage,
|
|
||||||
Message,
|
|
||||||
Role,
|
|
||||||
SamplingParams,
|
|
||||||
StopReason,
|
|
||||||
ToolCallDelta,
|
|
||||||
ToolCallParseStatus,
|
|
||||||
ToolDefinition,
|
|
||||||
ToolResponse,
|
|
||||||
ToolResponseMessage,
|
|
||||||
URL,
|
|
||||||
)
|
)
|
||||||
from llama_toolchain.safety.api import Safety
|
|
||||||
from llama_toolchain.safety.api.datatypes import (
|
|
||||||
BuiltinShield,
|
|
||||||
ShieldDefinition,
|
|
||||||
ShieldResponse,
|
|
||||||
)
|
|
||||||
from llama_toolchain.agentic_system.api.endpoints import * # noqa
|
|
||||||
|
|
||||||
from .safety import SafetyException, ShieldRunnerMixin
|
from .safety import SafetyException, ShieldRunnerMixin
|
||||||
from .system_prompt import get_agentic_prefix_messages
|
|
||||||
from .tools.base import BaseTool
|
|
||||||
from .tools.builtin import SingleMessageBuiltinTool
|
|
||||||
|
|
||||||
|
|
||||||
class AgentInstance(ShieldRunnerMixin):
|
def make_random_string(length: int = 8):
|
||||||
|
return "".join(
|
||||||
|
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatAgent(ShieldRunnerMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
system_id: int,
|
agent_config: AgentConfig,
|
||||||
instance_config: AgenticSystemInstanceConfig,
|
|
||||||
model: str,
|
|
||||||
inference_api: Inference,
|
inference_api: Inference,
|
||||||
|
memory_api: Memory,
|
||||||
safety_api: Safety,
|
safety_api: Safety,
|
||||||
builtin_tools: List[SingleMessageBuiltinTool],
|
builtin_tools: List[SingleMessageBuiltinTool],
|
||||||
custom_tool_definitions: List[ToolDefinition],
|
|
||||||
input_shields: List[ShieldDefinition],
|
|
||||||
output_shields: List[ShieldDefinition],
|
|
||||||
max_infer_iters: int = 10,
|
max_infer_iters: int = 10,
|
||||||
prefix_messages: Optional[List[Message]] = None,
|
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
|
||||||
):
|
):
|
||||||
self.system_id = system_id
|
self.agent_config = agent_config
|
||||||
self.instance_config = instance_config
|
|
||||||
|
|
||||||
self.model = model
|
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
|
self.memory_api = memory_api
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
|
|
||||||
if prefix_messages is not None and len(prefix_messages) > 0:
|
|
||||||
self.prefix_messages = prefix_messages
|
|
||||||
else:
|
|
||||||
self.prefix_messages = get_agentic_prefix_messages(
|
|
||||||
builtin_tools,
|
|
||||||
custom_tool_definitions,
|
|
||||||
tool_prompt_format,
|
|
||||||
)
|
|
||||||
|
|
||||||
for m in self.prefix_messages:
|
|
||||||
print(m.content)
|
|
||||||
|
|
||||||
self.max_infer_iters = max_infer_iters
|
self.max_infer_iters = max_infer_iters
|
||||||
self.tools_dict = {t.get_name(): t for t in builtin_tools}
|
self.tools_dict = {t.get_name(): t for t in builtin_tools}
|
||||||
|
|
||||||
|
self.tempdir = tempfile.mkdtemp()
|
||||||
self.sessions = {}
|
self.sessions = {}
|
||||||
|
|
||||||
ShieldRunnerMixin.__init__(
|
ShieldRunnerMixin.__init__(
|
||||||
self,
|
self,
|
||||||
safety_api,
|
safety_api,
|
||||||
input_shields=input_shields,
|
input_shields=agent_config.input_shields,
|
||||||
output_shields=output_shields,
|
output_shields=agent_config.output_shields,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
shutil.rmtree(self.tempdir)
|
||||||
|
|
||||||
|
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
# We do not want to keep adding RAG context to the input messages
|
||||||
|
# May be this should be a parameter of the agentic instance
|
||||||
|
# that can define its behavior in a custom way
|
||||||
|
for m in turn.input_messages:
|
||||||
|
msg = m.copy()
|
||||||
|
if isinstance(msg, UserMessage):
|
||||||
|
msg.context = None
|
||||||
|
messages.append(msg)
|
||||||
|
|
||||||
|
# messages.extend(turn.input_messages)
|
||||||
|
for step in turn.steps:
|
||||||
|
if step.step_type == StepType.inference.value:
|
||||||
|
messages.append(step.model_response)
|
||||||
|
elif step.step_type == StepType.tool_execution.value:
|
||||||
|
for response in step.tool_responses:
|
||||||
|
messages.append(
|
||||||
|
ToolResponseMessage(
|
||||||
|
call_id=response.call_id,
|
||||||
|
tool_name=response.tool_name,
|
||||||
|
content=response.content,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif step.step_type == StepType.shield_call.value:
|
||||||
|
response = step.response
|
||||||
|
if response.is_violation:
|
||||||
|
# CompletionMessage itself in the ShieldResponse
|
||||||
|
messages.append(
|
||||||
|
CompletionMessage(
|
||||||
|
content=response.violation_return_message,
|
||||||
|
stop_reason=StopReason.end_of_turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# print_dialog(messages)
|
||||||
|
return messages
|
||||||
|
|
||||||
def create_session(self, name: str) -> Session:
|
def create_session(self, name: str) -> Session:
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
session = Session(
|
session = Session(
|
||||||
|
@ -131,32 +131,7 @@ class AgentInstance(ShieldRunnerMixin):
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
for i, turn in enumerate(session.turns):
|
for i, turn in enumerate(session.turns):
|
||||||
# print(f"turn {i}")
|
messages.extend(self.turn_to_messages(turn))
|
||||||
# print_dialog(turn.input_messages)
|
|
||||||
messages.extend(turn.input_messages)
|
|
||||||
for step in turn.steps:
|
|
||||||
if step.step_type == StepType.inference.value:
|
|
||||||
messages.append(step.model_response)
|
|
||||||
elif step.step_type == StepType.tool_execution.value:
|
|
||||||
for response in step.tool_responses:
|
|
||||||
messages.append(
|
|
||||||
ToolResponseMessage(
|
|
||||||
call_id=response.call_id,
|
|
||||||
tool_name=response.tool_name,
|
|
||||||
content=response.content,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif step.step_type == StepType.shield_call.value:
|
|
||||||
response = step.response
|
|
||||||
if response.is_violation:
|
|
||||||
# TODO: Properly persist the
|
|
||||||
# CompletionMessage itself in the ShieldResponse
|
|
||||||
messages.append(
|
|
||||||
CompletionMessage(
|
|
||||||
content=response.violation_return_message,
|
|
||||||
stop_reason=StopReason.end_of_turn,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
messages.extend(request.messages)
|
messages.extend(request.messages)
|
||||||
|
|
||||||
|
@ -164,7 +139,6 @@ class AgentInstance(ShieldRunnerMixin):
|
||||||
# print_dialog(messages)
|
# print_dialog(messages)
|
||||||
|
|
||||||
turn_id = str(uuid.uuid4())
|
turn_id = str(uuid.uuid4())
|
||||||
params = self.instance_config.sampling_params
|
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
yield AgenticSystemTurnResponseStreamChunk(
|
yield AgenticSystemTurnResponseStreamChunk(
|
||||||
event=AgenticSystemTurnResponseEvent(
|
event=AgenticSystemTurnResponseEvent(
|
||||||
|
@ -177,12 +151,12 @@ class AgentInstance(ShieldRunnerMixin):
|
||||||
steps = []
|
steps = []
|
||||||
output_message = None
|
output_message = None
|
||||||
async for chunk in self.run(
|
async for chunk in self.run(
|
||||||
|
session=session,
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
input_messages=messages,
|
input_messages=messages,
|
||||||
temperature=params.temperature,
|
attachments=request.attachments or [],
|
||||||
top_p=params.top_p,
|
sampling_params=self.agent_config.sampling_params,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
max_gen_len=params.max_tokens,
|
|
||||||
):
|
):
|
||||||
if isinstance(chunk, CompletionMessage):
|
if isinstance(chunk, CompletionMessage):
|
||||||
cprint(
|
cprint(
|
||||||
|
@@ -227,6 +201,53 @@ class AgentInstance(ShieldRunnerMixin):
                 )
             yield chunk
 
+    async def run(
+        self,
+        session: Session,
+        turn_id: str,
+        input_messages: List[Message],
+        attachments: List[Attachment],
+        sampling_params: SamplingParams,
+        stream: bool = False,
+    ) -> AsyncGenerator:
+        # Doing async generators makes downstream code much simpler and everything amenable to
+        # streaming. However, it also makes things complicated here because AsyncGenerators cannot
+        # return a "final value" for the `yield from` statement. we simulate that by yielding a
+        # final boolean (to see whether an exception happened) and then explicitly testing for it.
+
+        async for res in self.run_shields_wrapper(
+            turn_id, input_messages, self.input_shields, "user-input"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        async for res in self._run(
+            session, turn_id, input_messages, attachments, sampling_params, stream
+        ):
+            if isinstance(res, bool):
+                return
+            elif isinstance(res, CompletionMessage):
+                final_response = res
+                break
+            else:
+                yield res
+
+        assert final_response is not None
+        # for output shields run on the full input and output combination
+        messages = input_messages + [final_response]
+
+        async for res in self.run_shields_wrapper(
+            turn_id, messages, self.output_shields, "assistant-output"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        yield final_response
+
     async def run_shields_wrapper(
         self,
         turn_id: str,
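The comment in run() above describes simulating a return value from an async generator by yielding a final sentinel and testing its type at the call site. A stripped-down illustration of that pattern, independent of the agent code:

    import asyncio


    async def step_stream(fail: bool):
        yield "progress"  # intermediate events
        yield fail        # final bool sentinel instead of a return value


    async def caller():
        async for item in step_stream(fail=False):
            if isinstance(item, bool):
                # this is the simulated "return value" of the generator
                print("failed" if item else "ok")
                return
            print("event:", item)


    asyncio.run(caller())
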
@ -288,65 +309,62 @@ class AgentInstance(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
turn_id: str,
|
|
||||||
input_messages: List[Message],
|
|
||||||
temperature: float,
|
|
||||||
top_p: float,
|
|
||||||
stream: bool = False,
|
|
||||||
max_gen_len: Optional[int] = None,
|
|
||||||
) -> AsyncGenerator:
|
|
||||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
|
||||||
# stremaing. However, it also makes things complicated here because AsyncGenerators cannot
|
|
||||||
# return a "final value" for the `yield from` statement. we simulate that by yielding a
|
|
||||||
# final boolean (to see whether an exception happened) and then explicitly testing for it.
|
|
||||||
|
|
||||||
async for res in self.run_shields_wrapper(
|
|
||||||
turn_id, input_messages, self.input_shields, "user-input"
|
|
||||||
):
|
|
||||||
if isinstance(res, bool):
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
yield res
|
|
||||||
|
|
||||||
async for res in self._run(
|
|
||||||
turn_id, input_messages, temperature, top_p, stream, max_gen_len
|
|
||||||
):
|
|
||||||
if isinstance(res, bool):
|
|
||||||
return
|
|
||||||
elif isinstance(res, CompletionMessage):
|
|
||||||
final_response = res
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
yield res
|
|
||||||
|
|
||||||
assert final_response is not None
|
|
||||||
# for output shields run on the full input and output combination
|
|
||||||
messages = input_messages + [final_response]
|
|
||||||
|
|
||||||
async for res in self.run_shields_wrapper(
|
|
||||||
turn_id, messages, self.output_shields, "assistant-output"
|
|
||||||
):
|
|
||||||
if isinstance(res, bool):
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
yield res
|
|
||||||
|
|
||||||
yield final_response
|
|
||||||
|
|
||||||
async def _run(
|
async def _run(
|
||||||
self,
|
self,
|
||||||
|
session: Session,
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
input_messages: List[Message],
|
input_messages: List[Message],
|
||||||
temperature: float,
|
attachments: List[Attachment],
|
||||||
top_p: float,
|
sampling_params: SamplingParams,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
max_gen_len: Optional[int] = None,
|
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
input_messages = preprocess_dialog(input_messages, self.prefix_messages)
|
enabled_tools = set(t.type for t in self.agent_config.tools)
|
||||||
|
need_rag_context = await self._should_retrieve_context(
|
||||||
|
input_messages, attachments
|
||||||
|
)
|
||||||
|
if need_rag_context:
|
||||||
|
step_id = str(uuid.uuid4())
|
||||||
|
yield AgenticSystemTurnResponseStreamChunk(
|
||||||
|
event=AgenticSystemTurnResponseEvent(
|
||||||
|
payload=AgenticSystemTurnResponseStepStartPayload(
|
||||||
|
step_type=StepType.memory_retrieval.value,
|
||||||
|
step_id=step_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
attachments = []
|
# TODO: find older context from the session and either replace it
|
||||||
|
# or append with a sliding window. this is really a very simplistic implementation
|
||||||
|
rag_context, bank_ids = await self._retrieve_context(
|
||||||
|
session, input_messages, attachments
|
||||||
|
)
|
||||||
|
|
||||||
|
step_id = str(uuid.uuid4())
|
||||||
|
yield AgenticSystemTurnResponseStreamChunk(
|
||||||
|
event=AgenticSystemTurnResponseEvent(
|
||||||
|
payload=AgenticSystemTurnResponseStepCompletePayload(
|
||||||
|
step_type=StepType.memory_retrieval.value,
|
||||||
|
step_id=step_id,
|
||||||
|
step_details=MemoryRetrievalStep(
|
||||||
|
turn_id=turn_id,
|
||||||
|
step_id=step_id,
|
||||||
|
memory_bank_ids=bank_ids,
|
||||||
|
inserted_context=rag_context or "",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if rag_context:
|
||||||
|
last_message = input_messages[-1]
|
||||||
|
last_message.context = "\n".join(rag_context)
|
||||||
|
|
||||||
|
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
|
||||||
|
urls = [a.content for a in attachments if isinstance(a.content, URL)]
|
||||||
|
msg = await attachment_message(self.tempdir, urls)
|
||||||
|
input_messages.append(msg)
|
||||||
|
|
||||||
|
output_attachments = []
|
||||||
|
|
||||||
n_iter = 0
|
n_iter = 0
|
||||||
while True:
|
while True:
|
||||||
|
@ -369,17 +387,13 @@ class AgentInstance(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# where are the available tools?
|
|
||||||
req = ChatCompletionRequest(
|
req = ChatCompletionRequest(
|
||||||
model=self.model,
|
model=self.agent_config.model,
|
||||||
messages=input_messages,
|
messages=input_messages,
|
||||||
available_tools=self.instance_config.available_tools,
|
tools=self._get_tools(),
|
||||||
|
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||||
stream=True,
|
stream=True,
|
||||||
sampling_params=SamplingParams(
|
sampling_params=sampling_params,
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
max_tokens=max_gen_len,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
|
@ -464,7 +478,8 @@ class AgentInstance(ShieldRunnerMixin):
|
||||||
|
|
||||||
if len(message.tool_calls) == 0:
|
if len(message.tool_calls) == 0:
|
||||||
if stop_reason == StopReason.end_of_turn:
|
if stop_reason == StopReason.end_of_turn:
|
||||||
if len(attachments) > 0:
|
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
|
||||||
|
if len(output_attachments) > 0:
|
||||||
if isinstance(message.content, list):
|
if isinstance(message.content, list):
|
||||||
message.content += attachments
|
message.content += attachments
|
||||||
else:
|
else:
|
||||||
|
@@ -572,63 +587,175 @@ class AgentInstance(ShieldRunnerMixin):
                     yield False
                     return
 
-                if isinstance(result_message.content, Attachment):
+                if out_attachment := interpret_content_as_attachment(
+                    result_message.content
+                ):
                     # NOTE: when we push this message back to the model, the model may ignore the
                     # attached file path etc. since the model is trained to only provide a user message
                     # with the summary. We keep all generated attachments and then attach them to final message
-                    attachments.append(result_message.content)
-                elif isinstance(result_message.content, list) or isinstance(
-                    result_message.content, tuple
-                ):
-                    for c in result_message.content:
-                        if isinstance(c, Attachment):
-                            attachments.append(c)
+                    output_attachments.append(out_attachment)
 
             input_messages = input_messages + [message, result_message]
 
             n_iter += 1
 
+    async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
+        if session.memory_bank is None:
+            session.memory_bank = await self.memory_api.create_memory_bank(
+                name=f"memory_bank_{session.session_id}",
+                config=VectorMemoryBankConfig(
+                    embedding_model="sentence-transformer/all-MiniLM-L6-v2",
+                    chunk_size_in_tokens=512,
+                ),
+            )
+
+        return session.memory_bank
+
+    async def _should_retrieve_context(
+        self, messages: List[Message], attachments: List[Attachment]
+    ) -> bool:
+        enabled_tools = set(t.type for t in self.agent_config.tools)
+        if attachments:
+            if (
+                AgenticSystemTool.code_interpreter.value in enabled_tools
+                and self.agent_config.tool_choice == ToolChoice.required
+            ):
+                return False
+            else:
+                return True
+
+        return AgenticSystemTool.memory.value in enabled_tools
+
+    def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
+        for t in self.agent_config.tools:
+            if t.type == AgenticSystemTool.memory.value:
+                return t
+
+        return None
+
+    async def _retrieve_context(
+        self, session: Session, messages: List[Message], attachments: List[Attachment]
+    ) -> Tuple[List[str], List[int]]:  # (rag_context, bank_ids)
+        bank_ids = []
+
+        memory = self._memory_tool_definition()
+        assert memory is not None, "Memory tool not configured"
+        bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
+
+        if attachments:
+            bank = await self._ensure_memory_bank(session)
+            bank_ids.append(bank.bank_id)
+
+            documents = [
+                MemoryBankDocument(
+                    document_id=str(uuid.uuid4()),
+                    content=a.content,
+                    mime_type=a.mime_type,
+                    metadata={},
+                )
+                for a in attachments
+            ]
+            await self.memory_api.insert_documents(bank.bank_id, documents)
+        elif session.memory_bank:
+            bank_ids.append(session.memory_bank.bank_id)
+
+        if not bank_ids:
+            # this can happen if the per-session memory bank is not yet populated
+            # (i.e., no prior turns uploaded an Attachment)
+            return None, []
+
+        query = " ".join(m.content for m in messages)
+        tasks = [
+            self.memory_api.query_documents(
+                bank_id=bank_id,
+                query=query,
+                params={
+                    "max_chunks": 5,
+                },
+            )
+            for bank_id in bank_ids
+        ]
+        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
+        chunks = [c for r in results for c in r.chunks]
+        scores = [s for r in results for s in r.scores]
+
+        # sort by score
+        chunks, scores = zip(
+            *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
+        )
+        if not chunks:
+            return None, bank_ids
+
+        tokens = 0
+        picked = []
+        for c in chunks[: memory.max_chunks]:
+            tokens += c.token_count
+            if tokens > memory.max_tokens_in_context:
+                cprint(
+                    f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
+                    "red",
+                )
+                break
+            picked.append(f"id:{c.document_id}; content:{c.content}")
+
+        return [
+            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+            *picked,
+            "\n=== END-RETRIEVED-CONTEXT ===\n",
+        ], bank_ids
+
+    def _get_tools(self) -> List[ToolDefinition]:
+        ret = []
+        for t in self.agent_config.tools:
+            if isinstance(t, BraveSearchToolDefinition):
+                ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
+            elif isinstance(t, WolframAlphaToolDefinition):
+                ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
+            elif isinstance(t, PhotogenToolDefinition):
+                ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
+            elif isinstance(t, CodeInterpreterToolDefinition):
+                ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
+            elif isinstance(t, FunctionCallToolDefinition):
+                ret.append(
+                    ToolDefinition(
+                        tool_name=t.function_name,
+                        description=t.description,
+                        parameters=t.parameters,
+                    )
+                )
+        return ret
+
 
-def attachment_message(url: URL) -> ToolResponseMessage:
-    uri = url.uri
-    assert uri.startswith("file://")
-    filepath = uri[len("file://") :]
+async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
+    content = []
+
+    for url in urls:
+        uri = url.uri
+        if uri.startswith("file://"):
+            filepath = uri[len("file://") :]
+        elif uri.startswith("http"):
+            path = urlparse(uri).path
+            basename = os.path.basename(path)
+            filepath = f"{tempdir}/{make_random_string() + basename}"
+            print(f"Downloading {url} -> {filepath}")
+
+            async with httpx.AsyncClient() as client:
+                r = await client.get(uri)
+                resp = r.text
+                with open(filepath, "w") as fp:
+                    fp.write(resp)
+        else:
+            raise ValueError(f"Unsupported URL {url}")
+
+        content.append(f'# There is a file accessible to you at "{filepath}"\n')
+
     return ToolResponseMessage(
         call_id="",
         tool_name=BuiltinTool.code_interpreter,
-        content=f'# There is a file accessible to you at "{filepath}"',
+        content=content,
     )
 
 
-def preprocess_dialog(
-    messages: List[Message], prefix_messages: List[Message]
-) -> List[Message]:
-    """
-    Preprocesses the dialog by removing the system message and
-    adding the system message to the beginning of the dialog.
-    """
-    ret = prefix_messages.copy()
-
-    for m in messages:
-        if m.role == Role.system.value:
-            continue
-
-        # NOTE: the ideal behavior is to use `file_path = ...` but that
-        # means we need to have stateful execution o f code which we currently
-        # do not have.
-        if isinstance(m.content, Attachment):
-            ret.append(attachment_message(m.content.url))
-        elif isinstance(m.content, list):
-            for c in m.content:
-                if isinstance(c, Attachment):
-                    ret.append(attachment_message(c.url))
-
-        ret.append(m)
-
-    return ret
-
 
 async def execute_tool_call_maybe(
     tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
 ) -> List[ToolResponseMessage]:
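As a usage illustration of the token-budget selection in _retrieve_context above, here is a toy, self-contained version of the same rule (the Chunk record and the numbers are made up for the example):

    from dataclasses import dataclass


    @dataclass
    class Chunk:
        document_id: str
        content: str
        token_count: int


    def pick_chunks(chunks, max_chunks, max_tokens_in_context):
        tokens, picked = 0, []
        for c in chunks[:max_chunks]:
            tokens += c.token_count
            if tokens > max_tokens_in_context:
                break  # stop before the chunk that would blow the budget
            picked.append(f"id:{c.document_id}; content:{c.content}")
        return picked


    print(pick_chunks([Chunk("a", "lora docs", 300), Chunk("b", "qat docs", 900)], 5, 1000))
    # -> ['id:a; content:lora docs']  (the second chunk would exceed the 1000-token budget)
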
@ -8,62 +8,42 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import AsyncGenerator, Dict
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
|
||||||
from llama_toolchain.inference.api import Inference
|
from llama_toolchain.inference.api import Inference
|
||||||
from llama_toolchain.inference.api.datatypes import BuiltinTool
|
from llama_toolchain.memory.api import Memory
|
||||||
from llama_toolchain.safety.api import Safety
|
from llama_toolchain.safety.api import Safety
|
||||||
from llama_toolchain.agentic_system.api.endpoints import * # noqa
|
from llama_toolchain.agentic_system.api import * # noqa: F403
|
||||||
from llama_toolchain.agentic_system.api import (
|
from llama_toolchain.tools.builtin import (
|
||||||
AgenticSystem,
|
|
||||||
AgenticSystemCreateRequest,
|
|
||||||
AgenticSystemCreateResponse,
|
|
||||||
AgenticSystemSessionCreateRequest,
|
|
||||||
AgenticSystemSessionCreateResponse,
|
|
||||||
AgenticSystemTurnCreateRequest,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .agent_instance import AgentInstance
|
|
||||||
|
|
||||||
from .config import AgenticSystemConfig
|
|
||||||
|
|
||||||
from .tools.builtin import (
|
|
||||||
BraveSearchTool,
|
BraveSearchTool,
|
||||||
CodeInterpreterTool,
|
CodeInterpreterTool,
|
||||||
PhotogenTool,
|
PhotogenTool,
|
||||||
WolframAlphaTool,
|
WolframAlphaTool,
|
||||||
)
|
)
|
||||||
from .tools.safety import with_safety
|
from llama_toolchain.tools.safety import with_safety
|
||||||
|
|
||||||
|
from .agent_instance import ChatAgent
|
||||||
|
from .config import MetaReferenceImplConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: AgenticSystemConfig, deps: Dict[Api, ProviderSpec]):
|
|
||||||
assert isinstance(
|
|
||||||
config, AgenticSystemConfig
|
|
||||||
), f"Unexpected config type: {type(config)}"
|
|
||||||
|
|
||||||
impl = MetaReferenceAgenticSystemImpl(
|
|
||||||
config,
|
|
||||||
deps[Api.inference],
|
|
||||||
deps[Api.safety],
|
|
||||||
)
|
|
||||||
await impl.initialize()
|
|
||||||
return impl
|
|
||||||
|
|
||||||
|
|
||||||
AGENT_INSTANCES_BY_ID = {}
|
AGENT_INSTANCES_BY_ID = {}
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: AgenticSystemConfig, inference_api: Inference, safety_api: Safety
|
self,
|
||||||
|
config: MetaReferenceImplConfig,
|
||||||
|
inference_api: Inference,
|
||||||
|
memory_api: Memory,
|
||||||
|
safety_api: Safety,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
|
self.memory_api = memory_api
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
@ -71,69 +51,61 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
|
|
||||||
async def create_agentic_system(
|
async def create_agentic_system(
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemCreateRequest,
|
agent_config: AgentConfig,
|
||||||
) -> AgenticSystemCreateResponse:
|
) -> AgenticSystemCreateResponse:
|
||||||
system_id = str(uuid.uuid4())
|
agent_id = str(uuid.uuid4())
|
||||||
|
|
||||||
builtin_tools = []
|
builtin_tools = []
|
||||||
custom_tool_definitions = []
|
for tool_defn in agent_config.tools:
|
||||||
cfg = request.instance_config
|
if isinstance(tool_defn, WolframAlphaToolDefinition):
|
||||||
for dfn in cfg.available_tools:
|
key = self.config.wolfram_api_key
|
||||||
if isinstance(dfn.tool_name, BuiltinTool):
|
if not key:
|
||||||
if dfn.tool_name == BuiltinTool.wolfram_alpha:
|
raise ValueError("Wolfram API key not defined in config")
|
||||||
key = self.config.wolfram_api_key
|
tool = WolframAlphaTool(key)
|
||||||
if not key:
|
elif isinstance(tool_defn, BraveSearchToolDefinition):
|
||||||
raise ValueError("Wolfram API key not defined in config")
|
key = self.config.brave_search_api_key
|
||||||
tool = WolframAlphaTool(key)
|
if not key:
|
||||||
elif dfn.tool_name == BuiltinTool.brave_search:
|
raise ValueError("Brave API key not defined in config")
|
||||||
key = self.config.brave_search_api_key
|
tool = BraveSearchTool(key)
|
||||||
if not key:
|
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
|
||||||
raise ValueError("Brave API key not defined in config")
|
tool = CodeInterpreterTool()
|
||||||
tool = BraveSearchTool(key)
|
elif isinstance(tool_defn, PhotogenToolDefinition):
|
||||||
elif dfn.tool_name == BuiltinTool.code_interpreter:
|
tool = PhotogenTool(
|
||||||
tool = CodeInterpreterTool()
|
dump_dir="/tmp/photogen_dump_" + os.environ["USER"],
|
||||||
elif dfn.tool_name == BuiltinTool.photogen:
|
|
||||||
tool = PhotogenTool(
|
|
||||||
dump_dir="/tmp/photogen_dump_" + os.environ["USER"],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown builtin tool: {dfn.tool_name}")
|
|
||||||
|
|
||||||
builtin_tools.append(
|
|
||||||
with_safety(
|
|
||||||
tool, self.safety_api, dfn.input_shields, dfn.output_shields
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
custom_tool_definitions.append(dfn)
|
continue
|
||||||
|
|
||||||
AGENT_INSTANCES_BY_ID[system_id] = AgentInstance(
|
builtin_tools.append(
|
||||||
system_id=system_id,
|
with_safety(
|
||||||
instance_config=request.instance_config,
|
tool,
|
||||||
model=request.model,
|
self.safety_api,
|
||||||
|
tool_defn.input_shields,
|
||||||
|
tool_defn.output_shields,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent(
|
||||||
|
agent_config=agent_config,
|
||||||
inference_api=self.inference_api,
|
inference_api=self.inference_api,
|
||||||
builtin_tools=builtin_tools,
|
|
||||||
custom_tool_definitions=custom_tool_definitions,
|
|
||||||
safety_api=self.safety_api,
|
safety_api=self.safety_api,
|
||||||
input_shields=cfg.input_shields,
|
memory_api=self.memory_api,
|
||||||
output_shields=cfg.output_shields,
|
builtin_tools=builtin_tools,
|
||||||
prefix_messages=cfg.debug_prefix_messages,
|
|
||||||
tool_prompt_format=cfg.tool_prompt_format,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return AgenticSystemCreateResponse(
|
return AgenticSystemCreateResponse(
|
||||||
system_id=system_id,
|
agent_id=agent_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_agentic_system_session(
|
async def create_agentic_system_session(
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemSessionCreateRequest,
|
agent_id: str,
|
||||||
|
session_name: str,
|
||||||
) -> AgenticSystemSessionCreateResponse:
|
) -> AgenticSystemSessionCreateResponse:
|
||||||
system_id = request.system_id
|
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
||||||
assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found"
|
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
||||||
agent = AGENT_INSTANCES_BY_ID[system_id]
|
|
||||||
|
|
||||||
session = agent.create_session(request.session_name)
|
session = agent.create_session(session_name)
|
||||||
return AgenticSystemSessionCreateResponse(
|
return AgenticSystemSessionCreateResponse(
|
||||||
session_id=session.session_id,
|
session_id=session.session_id,
|
||||||
)
|
)
|
||||||
|
@ -142,9 +114,9 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemTurnCreateRequest,
|
request: AgenticSystemTurnCreateRequest,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
system_id = request.system_id
|
agent_id = request.agent_id
|
||||||
assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found"
|
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
||||||
agent = AGENT_INSTANCES_BY_ID[system_id]
|
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
request.session_id in agent.sessions
|
request.session_id in agent.sessions
|
||||||
|
|
|
@@ -9,6 +9,6 @@ from typing import Optional
 from pydantic import BaseModel
 
 
-class AgenticSystemConfig(BaseModel):
+class MetaReferenceImplConfig(BaseModel):
     brave_search_api_key: Optional[str] = None
     wolfram_api_key: Optional[str] = None
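A small sketch of constructing the renamed config by hand, e.g. when wiring the provider outside the CLI flow; reading the keys from environment variables is an assumption of this example, not something the diff prescribes:

    import os

    config = MetaReferenceImplConfig(
        brave_search_api_key=os.environ.get("BRAVE_SEARCH_API_KEY"),
        wolfram_api_key=os.environ.get("WOLFRAM_API_KEY"),
    )
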
@@ -9,12 +9,13 @@ from typing import List
 from llama_models.llama3.api.datatypes import Message, Role, UserMessage
 from termcolor import cprint
 
-from llama_toolchain.safety.api.datatypes import (
+from llama_toolchain.safety.api import (
     OnViolationAction,
+    RunShieldRequest,
+    Safety,
     ShieldDefinition,
     ShieldResponse,
 )
-from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
 
 
 class SafetyException(Exception):  # noqa: N818
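With the consolidated imports above, a shield check might be invoked roughly as follows; the exact Safety.run_shields signature and response shape are assumptions here, so treat this as a sketch rather than the mixin's actual implementation:

    async def check_messages(safety_api: Safety, messages, shields: List[ShieldDefinition]):
        # Assumed shape: one ShieldResponse per requested shield.
        responses = await safety_api.run_shields(
            RunShieldRequest(messages=messages, shields=shields)
        )
        for r in responses:
            if r.is_violation:
                raise SafetyException(r)
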
@ -1,180 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import json
|
|
||||||
import textwrap
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
|
|
||||||
|
|
||||||
from llama_toolchain.inference.api import (
|
|
||||||
BuiltinTool,
|
|
||||||
Message,
|
|
||||||
SystemMessage,
|
|
||||||
ToolDefinition,
|
|
||||||
UserMessage,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .tools.builtin import SingleMessageBuiltinTool
|
|
||||||
|
|
||||||
|
|
||||||
def get_agentic_prefix_messages(
|
|
||||||
builtin_tools: List[SingleMessageBuiltinTool],
|
|
||||||
custom_tools: List[ToolDefinition],
|
|
||||||
tool_prompt_format: ToolPromptFormat,
|
|
||||||
) -> List[Message]:
|
|
||||||
messages = []
|
|
||||||
content = ""
|
|
||||||
if builtin_tools:
|
|
||||||
content += "Environment: ipython\n"
|
|
||||||
|
|
||||||
tool_str = ", ".join(
|
|
||||||
[
|
|
||||||
t.get_name()
|
|
||||||
for t in builtin_tools
|
|
||||||
if t.get_name() != BuiltinTool.code_interpreter.value
|
|
||||||
]
|
|
||||||
)
|
|
||||||
if tool_str:
|
|
||||||
content += f"Tools: {tool_str}"
|
|
||||||
|
|
||||||
current_date = datetime.now()
|
|
||||||
formatted_date = current_date.strftime("%d %B %Y")
|
|
||||||
date_str = f"""
|
|
||||||
Cutting Knowledge Date: December 2023
|
|
||||||
Today Date: {formatted_date}\n"""
|
|
||||||
content += date_str
|
|
||||||
messages.append(SystemMessage(content=content))
|
|
||||||
|
|
||||||
if custom_tools:
|
|
||||||
if tool_prompt_format == ToolPromptFormat.function_tag:
|
|
||||||
text = prompt_for_function_tag(custom_tools)
|
|
||||||
messages.append(UserMessage(content=text))
|
|
||||||
elif tool_prompt_format == ToolPromptFormat.json:
|
|
||||||
text = prompt_for_json(custom_tools)
|
|
||||||
messages.append(UserMessage(content=text))
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Tool prompt format {tool_prompt_format} is not supported"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
messages.append(SystemMessage(content=content))
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_for_json(custom_tools: List[ToolDefinition]) -> str:
|
|
||||||
tool_defs = "\n".join(
|
|
||||||
translate_custom_tool_definition_to_json(t) for t in custom_tools
|
|
||||||
)
|
|
||||||
content = textwrap.dedent(
|
|
||||||
"""
|
|
||||||
Answer the user's question by making use of the following functions if needed.
|
|
||||||
If none of the function can be used, please say so.
|
|
||||||
Here is a list of functions in JSON format:
|
|
||||||
{tool_defs}
|
|
||||||
|
|
||||||
Return function calls in JSON format.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
content = content.lstrip("\n").format(tool_defs=tool_defs)
|
|
||||||
return content
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str:
|
|
||||||
custom_tool_params = ""
|
|
||||||
for t in custom_tools:
|
|
||||||
custom_tool_params += get_instruction_string(t) + "\n"
|
|
||||||
custom_tool_params += get_parameters_string(t) + "\n\n"
|
|
||||||
|
|
||||||
content = f"""
|
|
||||||
You have access to the following functions:
|
|
||||||
|
|
||||||
{custom_tool_params}
|
|
||||||
Think very carefully before calling functions.
|
|
||||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
|
||||||
|
|
||||||
<function=example_function_name>{{"example_name": "example_value"}}</function>
|
|
||||||
|
|
||||||
Reminder:
|
|
||||||
- If looking for real time information use relevant functions before falling back to brave_search
|
|
||||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
|
||||||
- Required parameters MUST be specified
|
|
||||||
- Only call one function at a time
|
|
||||||
- Put the entire function call reply on one line
|
|
||||||
"""
|
|
||||||
return content
|
|
||||||
|
|
||||||
|
|
||||||
def get_instruction_string(custom_tool_definition) -> str:
|
|
||||||
return f"Use the function '{custom_tool_definition.tool_name}' to '{custom_tool_definition.description}'"
|
|
||||||
|
|
||||||
|
|
||||||
def get_parameters_string(custom_tool_definition) -> str:
|
|
||||||
return json.dumps(
|
|
||||||
{
|
|
||||||
"name": custom_tool_definition.tool_name,
|
|
||||||
"description": custom_tool_definition.description,
|
|
||||||
"parameters": {
|
|
||||||
name: definition.__dict__
|
|
||||||
for name, definition in custom_tool_definition.parameters.items()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def translate_custom_tool_definition_to_json(tool_def):
|
|
||||||
"""Translates ToolDefinition to json as expected by model
|
|
||||||
eg. output for a function
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "conv_int",
|
|
||||||
"description": "Convert serialized fract24 integer into int value.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "object",
|
|
||||||
"description": ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"required": ["data"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
assert isinstance(tool_def.tool_name, str)
|
|
||||||
func_def = {"type": "function", "function": {}}
|
|
||||||
func_def["function"]["name"] = tool_def.tool_name
|
|
||||||
func_def["function"]["description"] = tool_def.description or ""
|
|
||||||
if tool_def.parameters:
|
|
||||||
required = []
|
|
||||||
properties = []
|
|
||||||
for p_name, p_def in tool_def.parameters.items():
|
|
||||||
properties.append(
|
|
||||||
{
|
|
||||||
p_name: {
|
|
||||||
# TODO: see if this should not always be object
|
|
||||||
"type": "object",
|
|
||||||
"description": p_def.description or "",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if p_def.required:
|
|
||||||
required.append(p_name)
|
|
||||||
func_def["function"]["parameters"] = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": properties,
|
|
||||||
"required": required,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
func_def["function"]["parameters"] = {}
|
|
||||||
|
|
||||||
return json.dumps(func_def, indent=4)
|
|
|
@@ -6,7 +6,7 @@
 
 from typing import List
 
-from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
 
 
 def available_agentic_system_providers() -> List[ProviderSpec]:
@@ -16,15 +16,19 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
             provider_id="meta-reference",
             pip_packages=[
                 "codeshield",
+                "matplotlib",
                 "pillow",
+                "pandas",
+                "scikit-learn",
                 "torch",
                 "transformers",
             ],
             module="llama_toolchain.agentic_system.meta_reference",
-            config_class="llama_toolchain.agentic_system.meta_reference.AgenticSystemConfig",
+            config_class="llama_toolchain.agentic_system.meta_reference.MetaReferenceImplConfig",
             api_dependencies=[
                 Api.inference,
                 Api.safety,
+                Api.memory,
             ],
         ),
     ]
@@ -1,83 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, AsyncGenerator, List

from llama_models.llama3.api.datatypes import StopReason, ToolResponseMessage

from llama_toolchain.agentic_system.api import (
    AgenticSystem,
    AgenticSystemTurnCreateRequest,
    AgenticSystemTurnResponseEventType as EventType,
)

from llama_toolchain.inference.api import Message


async def execute_with_custom_tools(
    system: AgenticSystem,
    system_id: str,
    session_id: str,
    messages: List[Message],
    custom_tools: List[Any],
    max_iters: int = 5,
    stream: bool = True,
) -> AsyncGenerator:
    # first create a session, or do you keep a persistent session?
    tools_dict = {t.get_name(): t for t in custom_tools}

    current_messages = messages.copy()
    n_iter = 0
    while n_iter < max_iters:
        n_iter += 1

        request = AgenticSystemTurnCreateRequest(
            system_id=system_id,
            session_id=session_id,
            messages=current_messages,
            stream=stream,
        )

        turn = None
        async for chunk in system.create_agentic_system_turn(request):
            if chunk.event.payload.event_type != EventType.turn_complete.value:
                yield chunk
            else:
                turn = chunk.event.payload.turn

        message = turn.output_message
        if len(message.tool_calls) == 0:
            yield chunk
            return

        if message.stop_reason == StopReason.out_of_tokens:
            yield chunk
            return

        tool_call = message.tool_calls[0]
        if tool_call.tool_name not in tools_dict:
            m = ToolResponseMessage(
                call_id=tool_call.call_id,
                tool_name=tool_call.tool_name,
                content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
            )
            next_message = m
        else:
            tool = tools_dict[tool_call.tool_name]
            result_messages = await execute_custom_tool(tool, message)
            next_message = result_messages[0]

        yield next_message
        current_messages = [next_message]


async def execute_custom_tool(tool: Any, message: Message) -> List[Message]:
    result_messages = await tool.run([message])
    assert (
        len(result_messages) == 1
    ), f"Expected single message, got {len(result_messages)}"

    return result_messages
@@ -1,122 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from typing import Any, List, Optional

from llama_models.llama3.api.datatypes import BuiltinTool, Message, SamplingParams

from llama_toolchain.agentic_system.api import (
    AgenticSystemCreateRequest,
    AgenticSystemInstanceConfig,
    AgenticSystemSessionCreateRequest,
    AgenticSystemToolDefinition,
)
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
from llama_toolchain.agentic_system.client import AgenticSystemClient

from llama_toolchain.agentic_system.tools.custom.execute import (
    execute_with_custom_tools,
)
from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition


# TODO: this should move back to the llama-agentic-system repo


class AgenticSystemClientWrapper:
    def __init__(self, api, system_id, custom_tools):
        self.api = api
        self.system_id = system_id
        self.custom_tools = custom_tools
        self.session_id = None

    async def create_session(self, name: str = None):
        if name is None:
            name = f"Session-{uuid.uuid4()}"

        response = await self.api.create_agentic_system_session(
            AgenticSystemSessionCreateRequest(
                system_id=self.system_id,
                session_name=name,
            )
        )
        self.session_id = response.session_id
        return self.session_id

    async def run(self, messages: List[Message], stream: bool = True):
        async for chunk in execute_with_custom_tools(
            self.api,
            self.system_id,
            self.session_id,
            messages,
            self.custom_tools,
            stream=stream,
        ):
            yield chunk


async def get_agent_system_instance(
    host: str,
    port: int,
    custom_tools: Optional[List[Any]] = None,
    disable_safety: bool = False,
    model: str = "Meta-Llama3.1-8B-Instruct",
    tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> AgenticSystemClientWrapper:
    custom_tools = custom_tools or []

    api = AgenticSystemClient(base_url=f"http://{host}:{port}")

    tool_definitions = [
        AgenticSystemToolDefinition(
            tool_name=BuiltinTool.brave_search,
        ),
        AgenticSystemToolDefinition(
            tool_name=BuiltinTool.wolfram_alpha,
        ),
        AgenticSystemToolDefinition(
            tool_name=BuiltinTool.photogen,
        ),
        AgenticSystemToolDefinition(
            tool_name=BuiltinTool.code_interpreter,
        ),
    ] + [t.get_tool_definition() for t in custom_tools]

    if not disable_safety:
        for t in tool_definitions:
            t.input_shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
            t.output_shields = [
                ShieldDefinition(shield_type=BuiltinShield.llama_guard),
                ShieldDefinition(shield_type=BuiltinShield.injection_shield),
            ]

    create_request = AgenticSystemCreateRequest(
        model=model,
        instance_config=AgenticSystemInstanceConfig(
            instructions="You are a helpful assistant",
            available_tools=tool_definitions,
            input_shields=(
                []
                if disable_safety
                else [
                    ShieldDefinition(shield_type=BuiltinShield.llama_guard),
                    ShieldDefinition(shield_type=BuiltinShield.jailbreak_shield),
                ]
            ),
            output_shields=(
                []
                if disable_safety
                else [
                    ShieldDefinition(shield_type=BuiltinShield.llama_guard),
                ]
            ),
            sampling_params=SamplingParams(),
            tool_prompt_format=tool_prompt_format,
        ),
    )
    create_response = await api.create_agentic_system(create_request)
    return AgenticSystemClientWrapper(api, create_response.system_id, custom_tools)
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .distribution import DistributionParser # noqa
+from .api import * # noqa: F401 F403
llama_toolchain/batch_inference/api/api.py (new file)
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Optional, Protocol

from llama_models.schema_utils import json_schema_type, webmethod

from pydantic import BaseModel, Field

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_toolchain.inference.api import *  # noqa: F403


@json_schema_type
class BatchCompletionRequest(BaseModel):
    model: str
    content_batch: List[InterleavedTextMedia]
    sampling_params: Optional[SamplingParams] = SamplingParams()
    logprobs: Optional[LogProbConfig] = None


@json_schema_type
class BatchCompletionResponse(BaseModel):
    completion_message_batch: List[CompletionMessage]


@json_schema_type
class BatchChatCompletionRequest(BaseModel):
    model: str
    messages_batch: List[List[Message]]
    sampling_params: Optional[SamplingParams] = SamplingParams()

    # zero-shot tool definitions as input to the model
    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
    tool_prompt_format: Optional[ToolPromptFormat] = Field(
        default=ToolPromptFormat.json
    )
    logprobs: Optional[LogProbConfig] = None


@json_schema_type
class BatchChatCompletionResponse(BaseModel):
    completion_message_batch: List[CompletionMessage]


class BatchInference(Protocol):
    @webmethod(route="/batch_inference/completion")
    async def batch_completion(
        self,
        request: BatchCompletionRequest,
    ) -> BatchCompletionResponse: ...

    @webmethod(route="/batch_inference/chat_completion")
    async def batch_chat_completion(
        self,
        request: BatchChatCompletionRequest,
    ) -> BatchChatCompletionResponse: ...
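A hedged usage sketch, not part of this commit, showing how a caller might drive the BatchInference protocol above; `client` stands for any object implementing the protocol, UserMessage is assumed to be available via the wildcard imports at the top of the module, and the model id is only an example:

import asyncio


async def demo_batch_chat(client: BatchInference) -> None:
    # One request carrying two independent dialogs; field names follow
    # BatchChatCompletionRequest above.
    request = BatchChatCompletionRequest(
        model="Meta-Llama3.1-8B-Instruct",
        messages_batch=[
            [UserMessage(content="What is retrieval augmented generation?")],
            [UserMessage(content="Name three uses for a code interpreter tool.")],
        ],
    )
    response = await client.batch_chat_completion(request)
    for completion in response.completion_message_batch:
        print(completion.content)

# asyncio.run(demo_batch_chat(client))  # `client` must implement BatchInference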
@@ -1,106 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import json
import shlex

import yaml

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from termcolor import cprint


class DistributionConfigure(Subcommand):
    """Llama cli for configuring llama toolchain configs"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "configure",
            prog="llama distribution configure",
            description="configure a llama stack distribution",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_distribution_configure_cmd)

    def _add_arguments(self):
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the distribution to configure",
            required=True,
        )

    def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
        from llama_toolchain.distribution.datatypes import DistributionConfig
        from llama_toolchain.distribution.registry import resolve_distribution_spec

        config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
        if not config_file.exists():
            self.parser.error(
                f"Could not find {config_file}. Please run `llama distribution install` first"
            )
            return

        # we need to find the spec from the name
        with open(config_file, "r") as f:
            config = DistributionConfig(**yaml.safe_load(f))

        dist = resolve_distribution_spec(config.spec)
        if dist is None:
            raise ValueError(f"Could not find any registered spec `{config.spec}`")

        configure_llama_distribution(dist, config)


def configure_llama_distribution(dist: "Distribution", config: "DistributionConfig"):
    from llama_toolchain.common.exec import run_command
    from llama_toolchain.common.prompt_for_config import prompt_for_config
    from llama_toolchain.common.serialize import EnumEncoder
    from llama_toolchain.distribution.dynamic import instantiate_class_type

    python_exe = run_command(shlex.split("which python"))
    # simple check
    conda_env = config.conda_env
    if conda_env not in python_exe:
        raise ValueError(
            f"Please re-run configure by activating the `{conda_env}` conda environment"
        )

    if config.providers:
        cprint(
            f"Configuration already exists for {config.name}. Will overwrite...",
            "yellow",
            attrs=["bold"],
        )

    for api, provider_spec in dist.provider_specs.items():
        cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
        config_type = instantiate_class_type(provider_spec.config_class)
        provider_config = prompt_for_config(
            config_type,
            (
                config_type(**config.providers[api.value])
                if api.value in config.providers
                else None
            ),
        )
        print("")

        config.providers[api.value] = {
            "provider_id": provider_spec.provider_id,
            **provider_config.dict(),
        }

    config_path = DISTRIBS_BASE_DIR / config.name / "config.yaml"
    with open(config_path, "w") as fp:
        dist_config = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
        fp.write(yaml.dump(dist_config, sort_keys=False))

    print(f"YAML configuration has been written to {config_path}")
@@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from llama_toolchain.cli.subcommand import Subcommand


class DistributionCreate(Subcommand):
    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "create",
            prog="llama distribution create",
            description="create a Llama stack distribution",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_distribution_create_cmd)

    def _add_arguments(self):
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the distribution to create",
            required=True,
        )
        # for each Api the user wants to support, we should
        # get the list of available providers, ask which one the user
        # wants to pick and then ask for their configuration.

    def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
        from llama_toolchain.distribution.registry import resolve_distribution_spec

        dist = resolve_distribution_spec(args.name)
        if dist is not None:
            self.parser.error(f"Distribution with name {args.name} already exists")
            return

        raise NotImplementedError()
@@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from llama_toolchain.cli.subcommand import Subcommand

from .configure import DistributionConfigure
from .create import DistributionCreate
from .install import DistributionInstall
from .list import DistributionList
from .start import DistributionStart


class DistributionParser(Subcommand):
    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "distribution",
            prog="llama distribution",
            description="Operate on llama stack distributions",
        )

        subparsers = self.parser.add_subparsers(title="distribution_subcommands")

        # Add sub-commands
        DistributionList.create(subparsers)
        DistributionInstall.create(subparsers)
        DistributionCreate.create(subparsers)
        DistributionConfigure.create(subparsers)
        DistributionStart.create(subparsers)
@@ -1,111 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import os

import pkg_resources
import yaml

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR

from termcolor import cprint


class DistributionInstall(Subcommand):
    """Llama cli for configuring llama toolchain configs"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "install",
            prog="llama distribution install",
            description="Install a llama stack distribution",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_distribution_install_cmd)

    def _add_arguments(self):
        from llama_toolchain.distribution.registry import available_distribution_specs

        self.parser.add_argument(
            "--spec",
            type=str,
            help="Distribution spec to install (try local-ollama)",
            required=True,
            choices=[d.spec_id for d in available_distribution_specs()],
        )
        self.parser.add_argument(
            "--name",
            type=str,
            help="What should the installation be called locally?",
            required=True,
        )
        self.parser.add_argument(
            "--conda-env",
            type=str,
            help="conda env in which this distribution will run (default = distribution name)",
        )

    def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
        from llama_toolchain.common.exec import run_with_pty
        from llama_toolchain.distribution.datatypes import DistributionConfig
        from llama_toolchain.distribution.distribution import distribution_dependencies
        from llama_toolchain.distribution.registry import resolve_distribution_spec

        os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
        script = pkg_resources.resource_filename(
            "llama_toolchain",
            "distribution/install_distribution.sh",
        )

        dist = resolve_distribution_spec(args.spec)
        if dist is None:
            self.parser.error(f"Could not find distribution {args.spec}")
            return

        distrib_dir = DISTRIBS_BASE_DIR / args.name
        os.makedirs(distrib_dir, exist_ok=True)

        deps = distribution_dependencies(dist)
        if not args.conda_env:
            print(f"Using {args.name} as the Conda environment for this distribution")

        conda_env = args.conda_env or args.name

        config_file = distrib_dir / "config.yaml"
        if config_file.exists():
            c = DistributionConfig(**yaml.safe_load(config_file.read_text()))
            if c.spec != dist.spec_id:
                self.parser.error(
                    f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`"
                )
                return
            if c.conda_env != conda_env:
                self.parser.error(
                    f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`"
                )
                return
        else:
            with open(config_file, "w") as f:
                c = DistributionConfig(
                    spec=dist.spec_id,
                    name=args.name,
                    conda_env=conda_env,
                )
                f.write(yaml.dump(c.dict(), sort_keys=False))

        return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)])

        assert return_code == 0, cprint(
            f"Failed to install distribution {dist.spec_id}", color="red"
        )
        cprint(
            f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
            color="green",
        )
@@ -1,81 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

import pkg_resources
import yaml

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR


class DistributionStart(Subcommand):
    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "start",
            prog="llama distribution start",
            description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_distribution_start_cmd)

    def _add_arguments(self):
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the distribution to start",
            required=True,
        )
        self.parser.add_argument(
            "--port",
            type=int,
            help="Port to run the server on. Defaults to 5000",
            default=5000,
        )
        self.parser.add_argument(
            "--disable-ipv6",
            action="store_true",
            help="Disable IPv6 support",
            default=False,
        )

    def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
        from llama_toolchain.common.exec import run_with_pty
        from llama_toolchain.distribution.registry import resolve_distribution_spec

        config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
        if not config_file.exists():
            self.parser.error(
                f"Could not find {config_file}. Please run `llama distribution install` first"
            )
            return

        # we need to find the spec from the name
        with open(config_file, "r") as f:
            config = yaml.safe_load(f)

        dist = resolve_distribution_spec(config["spec"])
        if dist is None:
            raise ValueError(f"Could not find any registered spec `{config['spec']}`")

        conda_env = config["conda_env"]
        if not conda_env:
            raise ValueError(
                f"Could not find Conda environment for distribution `{args.name}`"
            )

        script = pkg_resources.resource_filename(
            "llama_toolchain",
            "distribution/start_distribution.sh",
        )
        args = [script, conda_env, config_file, "--port", str(args.port)] + (
            ["--disable-ipv6"] if args.disable_ipv6 else []
        )

        run_with_pty(args)
@@ -6,9 +6,9 @@

 import argparse

-from .distribution import DistributionParser
 from .download import Download
 from .model import ModelParser
+from .stack import StackParser


 class LlamaCLIParser:
@@ -29,7 +29,7 @@ class LlamaCLIParser:
         # Add sub-commands
         Download.create(subparsers)
         ModelParser.create(subparsers)
-        DistributionParser.create(subparsers)
+        StackParser.create(subparsers)

         # Import sub-commands from agentic_system if they exist
         try:
@@ -32,6 +32,16 @@ class ModelTemplate(Subcommand):
         self._add_arguments()
         self.parser.set_defaults(func=self._run_model_template_cmd)

+    def _prompt_type(self, value):
+        from llama_models.llama3.api.datatypes import ToolPromptFormat
+
+        try:
+            return ToolPromptFormat(value.lower())
+        except ValueError:
+            raise argparse.ArgumentTypeError(
+                f"{value} is not a valid ToolPromptFormat. Choose from {', '.join(t.value for t in ToolPromptFormat)}"
+            ) from None
+
     def _add_arguments(self):
         self.parser.add_argument(
             "-m",
@@ -46,6 +56,18 @@ class ModelTemplate(Subcommand):
             help="Usecase template name (system_message, user_message, assistant_message, tool_message)...",
             required=False,
         )
+        self.parser.add_argument(
+            "--format",
+            type=str,
+            help="ToolPromptFormat (json or function_tag). This flag is used to print the template in a specific formats.",
+            required=False,
+            default="json",
+        )
+        self.parser.add_argument(
+            "--raw",
+            action="store_true",
+            help="If set to true, don't pretty-print into a table. Useful to copy-paste.",
+        )

     def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
         from llama_models.llama3.api.interface import (
@@ -56,22 +78,32 @@ class ModelTemplate(Subcommand):
         from llama_toolchain.cli.table import print_table

         if args.name:
-            template, tokens_info = render_jinja_template(args.name)
+            tool_prompt_format = self._prompt_type(args.format)
+            template, tokens_info = render_jinja_template(args.name, tool_prompt_format)
             rendered = ""
             for tok, is_special in tokens_info:
                 if is_special:
                     rendered += colored(tok, "yellow", attrs=["bold"])
                 else:
                     rendered += tok
-            rendered += "\n"
-            print_table(
-                [
-                    ("Name", colored(template.template_name, "white", attrs=["bold"])),
-                    ("Template", rendered),
-                    ("Notes", template.notes),
-                ],
-                separate_rows=True,
-            )
+
+            if not args.raw:
+                rendered = rendered.replace("\n", "↵\n")
+                print_table(
+                    [
+                        (
+                            "Name",
+                            colored(template.template_name, "white", attrs=["bold"]),
+                        ),
+                        ("Template", rendered),
+                        ("Notes", template.notes),
+                    ],
+                    separate_rows=True,
+                )
+            else:
+                print("Template: ", template.template_name)
+                print("=" * 40)
+                print(rendered)
         else:
             templates = list_jinja_templates()
             headers = ["Role", "Template Name"]
@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import OllamaImplConfig # noqa
-from .ollama import get_provider_impl # noqa
+from .stack import StackParser # noqa
llama_toolchain/cli/stack/build.py (new file)
@@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.core.datatypes import *  # noqa: F403


def parse_api_provider_tuples(
    tuples: str, parser: argparse.ArgumentParser
) -> Dict[str, ProviderSpec]:
    from llama_toolchain.core.distribution import api_providers

    all_providers = api_providers()

    deps = {}
    for dep in tuples.split(","):
        dep = dep.strip()
        if not dep:
            continue
        api_str, provider = dep.split("=")
        api = Api(api_str)

        provider = provider.strip()
        if provider not in all_providers[api]:
            parser.error(f"Provider `{provider}` is not available for API `{api}`")
            return
        deps[api] = all_providers[api][provider]

    return deps


class StackBuild(Subcommand):
    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "build",
            prog="llama stack build",
            description="Build a Llama stack container",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_stack_build_command)

    def _add_arguments(self):
        from llama_toolchain.core.distribution_registry import available_distribution_specs
        from llama_toolchain.core.package import (
            BuildType,
        )

        allowed_ids = [d.distribution_id for d in available_distribution_specs()]
        self.parser.add_argument(
            "distribution",
            type=str,
            help="Distribution to build (either \"adhoc\" OR one of: {})".format(allowed_ids),
        )
        self.parser.add_argument(
            "api_providers",
            nargs='?',
            help="Comma separated list of (api=provider) tuples",
        )

        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the build target (image, conda env)",
            required=True,
        )
        self.parser.add_argument(
            "--type",
            type=str,
            default="conda_env",
            choices=[v.value for v in BuildType],
        )

    def _run_stack_build_command(self, args: argparse.Namespace) -> None:
        from llama_toolchain.core.distribution_registry import resolve_distribution_spec
        from llama_toolchain.core.package import (
            ApiInput,
            BuildType,
            build_package,
        )

        api_inputs = []
        if args.distribution == "adhoc":
            if not args.api_providers:
                self.parser.error("You must specify API providers with (api=provider,...) for building an adhoc distribution")
                return

            parsed = parse_api_provider_tuples(args.api_providers, self.parser)
            for api, provider_spec in parsed.items():
                for dep in provider_spec.api_dependencies:
                    if dep not in parsed:
                        self.parser.error(f"API {api} needs dependency {dep} provided also")
                        return

                api_inputs.append(
                    ApiInput(
                        api=api,
                        provider=provider_spec.provider_id,
                    )
                )
            docker_image = None
        else:
            if args.api_providers:
                self.parser.error("You cannot specify API providers for pre-registered distributions")
                return

            dist = resolve_distribution_spec(args.distribution)
            if dist is None:
                self.parser.error(f"Could not find distribution {args.distribution}")
                return

            for api, provider_id in dist.providers.items():
                api_inputs.append(
                    ApiInput(
                        api=api,
                        provider=provider_id,
                    )
                )
            docker_image = dist.docker_image

        build_package(
            api_inputs,
            build_type=BuildType(args.type),
            name=args.name,
            distribution_id=args.distribution,
            docker_image=docker_image,
        )
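For illustration only (not from the commit), a standalone sketch of the api=provider string format that the adhoc build path above parses; it uses plain strings instead of the real Api and ProviderSpec types, and the provider ids are placeholders:

def parse_pairs(spec: str) -> dict:
    # Mirrors the splitting logic of parse_api_provider_tuples(), without the
    # registry lookups or argparse error handling.
    pairs = {}
    for item in spec.split(","):
        item = item.strip()
        if not item:
            continue
        api, provider = item.split("=")
        pairs[api] = provider.strip()
    return pairs


print(parse_pairs("inference=meta-reference, safety=meta-reference"))
# -> {'inference': 'meta-reference', 'safety': 'meta-reference'}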
llama_toolchain/cli/stack/configure.py (new file)
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import json
from pathlib import Path

import yaml
from termcolor import cprint

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.core.datatypes import *  # noqa: F403


class StackConfigure(Subcommand):
    """Llama cli for configuring llama toolchain configs"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "configure",
            prog="llama stack configure",
            description="configure a llama stack distribution",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_stack_configure_cmd)

    def _add_arguments(self):
        from llama_toolchain.core.distribution_registry import (
            available_distribution_specs,
        )
        from llama_toolchain.core.package import BuildType

        allowed_ids = [d.distribution_id for d in available_distribution_specs()]
        self.parser.add_argument(
            "distribution",
            type=str,
            choices=allowed_ids,
            help="Distribution (one of: {})".format(allowed_ids),
        )
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the build",
            required=True,
        )
        self.parser.add_argument(
            "--type",
            type=str,
            default="conda_env",
            choices=[v.value for v in BuildType],
        )

    def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
        from llama_toolchain.core.package import BuildType

        build_type = BuildType(args.type)
        name = args.name
        config_file = (
            BUILDS_BASE_DIR
            / args.distribution
            / build_type.descriptor()
            / f"{name}.yaml"
        )
        if not config_file.exists():
            self.parser.error(
                f"Could not find {config_file}. Please run `llama stack build` first"
            )
            return

        configure_llama_distribution(config_file)


def configure_llama_distribution(config_file: Path) -> None:
    from llama_toolchain.common.serialize import EnumEncoder
    from llama_toolchain.core.configure import configure_api_providers
    from llama_toolchain.core.distribution_registry import resolve_distribution_spec

    with open(config_file, "r") as f:
        config = PackageConfig(**yaml.safe_load(f))

    dist = resolve_distribution_spec(config.distribution_id)
    if dist is None:
        raise ValueError(
            f"Could not find any registered distribution `{config.distribution_id}`"
        )

    if config.providers:
        cprint(
            f"Configuration already exists for {config.distribution_id}. Will overwrite...",
            "yellow",
            attrs=["bold"],
        )

    config.providers = configure_api_providers(config.providers)

    with open(config_file, "w") as fp:
        to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
        fp.write(yaml.dump(to_write, sort_keys=False))

    print(f"YAML configuration has been written to {config_file}")
@@ -10,13 +10,13 @@ import json
 from llama_toolchain.cli.subcommand import Subcommand


-class DistributionList(Subcommand):
+class StackList(Subcommand):
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(
-            "list",
-            prog="llama distribution list",
-            description="Show available llama stack distributions",
+            "list-distributions",
+            prog="llama stack list-distributions",
+            description="Show available Llama Stack Distributions",
             formatter_class=argparse.RawTextHelpFormatter,
         )
         self._add_arguments()
@@ -27,21 +27,23 @@ class DistributionList(Subcommand):

     def _run_distribution_list_cmd(self, args: argparse.Namespace) -> None:
         from llama_toolchain.cli.table import print_table
-        from llama_toolchain.distribution.registry import available_distribution_specs
+        from llama_toolchain.core.distribution_registry import (
+            available_distribution_specs,
+        )

         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
         headers = [
-            "Spec ID",
-            "ProviderSpecs",
+            "Distribution ID",
+            "Providers",
             "Description",
         ]

         rows = []
         for spec in available_distribution_specs():
-            providers = {k.value: v.provider_id for k, v in spec.provider_specs.items()}
+            providers = {k.value: v for k, v in spec.providers.items()}
             rows.append(
                 [
-                    spec.spec_id,
+                    spec.distribution_id,
                     json.dumps(providers, indent=2),
                     spec.description,
                 ]
llama_toolchain/cli/stack/run.py (new file)
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from pathlib import Path

import pkg_resources
import yaml

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.core.datatypes import *  # noqa: F403
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR


class StackRun(Subcommand):
    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "run",
            prog="llama stack run",
            description="""start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_stack_run_cmd)

    def _add_arguments(self):
        from llama_toolchain.core.package import BuildType

        self.parser.add_argument(
            "distribution",
            type=str,
            help="Distribution whose build you want to start",
        )
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the build you want to start",
            required=True,
        )
        self.parser.add_argument(
            "--type",
            type=str,
            default="conda_env",
            choices=[v.value for v in BuildType],
        )
        self.parser.add_argument(
            "--port",
            type=int,
            help="Port to run the server on. Defaults to 5000",
            default=5000,
        )
        self.parser.add_argument(
            "--disable-ipv6",
            action="store_true",
            help="Disable IPv6 support",
            default=False,
        )

    def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
        from llama_toolchain.common.exec import run_with_pty
        from llama_toolchain.core.package import BuildType

        build_type = BuildType(args.type)
        build_dir = BUILDS_BASE_DIR / args.distribution / build_type.descriptor()
        path = build_dir / f"{args.name}.yaml"

        config_file = Path(path)

        if not config_file.exists():
            self.parser.error(
                f"File {str(config_file)} does not exist. Did you run `llama stack build`?"
            )
            return

        with open(config_file, "r") as f:
            config = PackageConfig(**yaml.safe_load(f))

        if not config.distribution_id:
            raise ValueError("Build config appears to be corrupt.")

        if config.docker_image:
            script = pkg_resources.resource_filename(
                "llama_toolchain",
                "core/start_container.sh",
            )
            run_args = [script, config.docker_image]
        else:
            script = pkg_resources.resource_filename(
                "llama_toolchain",
                "core/start_conda_env.sh",
            )
            run_args = [
                script,
                config.conda_env,
            ]

        run_args.extend([str(config_file), str(args.port)])
        if args.disable_ipv6:
            run_args.append("--disable-ipv6")

        run_with_pty(run_args)
llama_toolchain/cli/stack/stack.py (new file)
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from llama_toolchain.cli.subcommand import Subcommand

from .build import StackBuild
from .configure import StackConfigure
from .list import StackList
from .run import StackRun


class StackParser(Subcommand):
    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "stack",
            prog="llama stack",
            description="Operations for the Llama Stack / Distributions",
        )

        subparsers = self.parser.add_subparsers(title="stack_subcommands")

        # Add sub-commands
        StackBuild.create(subparsers)
        StackConfigure.create(subparsers)
        StackList.create(subparsers)
        StackRun.create(subparsers)
@@ -13,3 +13,5 @@ LLAMA_STACK_CONFIG_DIR = Path(os.path.expanduser("~/.llama/"))
 DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"

 DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
+
+BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import json
+from datetime import datetime
 from enum import Enum


@@ -12,4 +13,6 @@ class EnumEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, Enum):
             return obj.value
+        elif isinstance(obj, datetime):
+            return obj.isoformat()
         return super().default(obj)
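A small illustrative example, not from the commit, of what the extended EnumEncoder handles after this change; the Color enum is purely hypothetical:

import json
from datetime import datetime
from enum import Enum

from llama_toolchain.common.serialize import EnumEncoder


class Color(Enum):  # hypothetical enum, only for the demo
    RED = "red"


payload = {"color": Color.RED, "created_at": datetime(2024, 9, 3, 12, 0, 0)}
print(json.dumps(payload, cls=EnumEncoder))
# {"color": "red", "created_at": "2024-09-03T12:00:00"}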
@@ -10,20 +10,36 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
 LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}

+if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
+  echo "Using llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR"
+fi
+if [ -n "$LLAMA_MODELS_DIR" ]; then
+  echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
+fi
+
 set -euo pipefail

+if [ "$#" -ne 3 ]; then
+  echo "Usage: $0 <distribution_id> <build_name> <pip_dependencies>" >&2
+  echo "Example: $0 <distribution_id> mybuild 'numpy pandas scipy'" >&2
+  exit 1
+fi
+
+distribution_id="$1"
+build_name="$2"
+env_name="llamastack-$build_name"
+pip_dependencies="$3"
+
 # Define color codes
 RED='\033[0;31m'
 GREEN='\033[0;32m'
 NC='\033[0m' # No Color

-error_handler() {
-  echo "Error occurred in script at line: ${1}" >&2
-  exit 1
-}
-
-# Set up the error trap
-trap 'error_handler ${LINENO}' ERR
+# this is set if we actually create a new conda in which case we need to clean up
+ENVNAME=""
+
+SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
+source "$SCRIPT_DIR/common.sh"

 ensure_conda_env_python310() {
   local env_name="$1"
@@ -32,26 +48,29 @@ ensure_conda_env_python310() {

   # Check if conda command is available
   if ! command -v conda &>/dev/null; then
-    echo -e "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
+    printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
     exit 1
   fi

   # Check if the environment exists
   if conda env list | grep -q "^${env_name} "; then
-    echo "Conda environment '${env_name}' exists. Checking Python version..."
+    printf "Conda environment '${env_name}' exists. Checking Python version...\n"

     # Check Python version in the environment
     current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)

     if [ "$current_version" = "$python_version" ]; then
-      echo "Environment '${env_name}' already has Python ${python_version}. No action needed."
+      printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
     else
-      echo "Updating environment '${env_name}' to Python ${python_version}..."
+      printf "Updating environment '${env_name}' to Python ${python_version}...\n"
       conda install -n "${env_name}" python="${python_version}" -y
     fi
   else
-    echo "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}..."
+    printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
     conda create -n "${env_name}" python="${python_version}" -y
+
+    ENVNAME="${env_name}"
+    # setup_cleanup_handlers
   fi

   eval "$(conda shell.bash hook)"
@@ -65,48 +84,45 @@ ensure_conda_env_python310() {
   # Re-installing llama-toolchain in the new conda environment
   if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
     if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
-      echo -e "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2
+      printf "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}\n" >&2
       exit 1
     fi

-    echo "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR"
-    pip install -e "$LLAMA_TOOLCHAIN_DIR"
+    printf "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR\n"
+    pip install --no-cache-dir -e "$LLAMA_TOOLCHAIN_DIR"
   else
-    pip install llama-toolchain
+    pip install --no-cache-dir llama-toolchain
   fi

   if [ -n "$LLAMA_MODELS_DIR" ]; then
     if [ ! -d "$LLAMA_MODELS_DIR" ]; then
-      echo -e "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
+      printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
      exit 1
    fi

-    echo "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR"
+    printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
     pip uninstall -y llama-models
-    pip install -e "$LLAMA_MODELS_DIR"
+    pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
   fi

   # Install pip dependencies
   if [ -n "$pip_dependencies" ]; then
-    echo "Installing pip dependencies: $pip_dependencies"
+    printf "Installing pip dependencies: $pip_dependencies\n"
     pip install $pip_dependencies
   fi
  fi
 }

-if [ "$#" -ne 3 ]; then
-  echo "Usage: $0 <environment_name> <distribution_name> <pip_dependencies>" >&2
-  echo "Example: $0 my_env local-llama-8b 'numpy pandas scipy'" >&2
-  exit 1
-fi
-
-env_name="$1"
-distribution_name="$2"
-pip_dependencies="$3"
-
 ensure_conda_env_python310 "$env_name" "$pip_dependencies"

-echo -e "${GREEN}Successfully setup distribution environment. Configuring...${NC}"
+printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}\n"

-which python3
-python3 -m llama_toolchain.cli.llama distribution configure --name "$distribution_name"
+if [ "$distribution_id" = "adhoc" ]; then
+  subcommand="api"
+  target=""
+else
+  subcommand="stack"
+  target="$distribution_id"
+fi
+
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --name "$build_name" --type conda_env
llama_toolchain/core/build_container.sh (new executable file)
@@ -0,0 +1,120 @@
#!/bin/bash

LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}

if [ "$#" -ne 4 ]; then
  echo "Usage: $0 <distribution_id> <build_name> <docker_base> <pip_dependencies>" >&2
  echo "Example: $0 distribution_id my-fastapi-app python:3.9-slim 'fastapi uvicorn'" >&2
  exit 1
fi

distribution_id=$1
build_name="$2"
image_name="llamastack-$build_name"
docker_base=$3
pip_dependencies=$4

# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

set -euo pipefail

SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))

TEMP_DIR=$(mktemp -d)

add_to_docker() {
  local input
  output_file="$TEMP_DIR/Dockerfile"
  if [ -t 0 ]; then
    printf '%s\n' "$1" >>"$output_file"
  else
    # If stdin is not a terminal, read from it (heredoc)
    cat >>"$output_file"
  fi
}

add_to_docker <<EOF
FROM $docker_base
WORKDIR /app

RUN apt-get update && apt-get install -y \
       iputils-ping net-tools iproute2 dnsutils telnet \
       curl wget telnet \
       procps psmisc lsof \
       traceroute \
       && rm -rf /var/lib/apt/lists/*

EOF

toolchain_mount="/app/llama-toolchain-source"
models_mount="/app/llama-models-source"

if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
  if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
    echo "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2
    exit 1
  fi
  add_to_docker "RUN pip install $toolchain_mount"
else
  add_to_docker "RUN pip install llama-toolchain"
fi

if [ -n "$LLAMA_MODELS_DIR" ]; then
  if [ ! -d "$LLAMA_MODELS_DIR" ]; then
    echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
    exit 1
  fi

  add_to_docker <<EOF
RUN pip uninstall -y llama-models
RUN pip install $models_mount

EOF
fi

if [ -n "$pip_dependencies" ]; then
  add_to_docker "RUN pip install $pip_dependencies"
fi

add_to_docker <<EOF

# This would be good in production but for debugging flexibility lets not add it right now
# We need a more solid production ready entrypoint.sh anyway
#
# ENTRYPOINT ["python", "-m", "llama_toolchain.core.server"]

EOF

printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
cat $TEMP_DIR/Dockerfile
printf "\n"

mounts=""
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
  mounts="$mounts -v $(readlink -f $LLAMA_TOOLCHAIN_DIR):$toolchain_mount"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
  mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi
set -x
podman build -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
set +x

printf "${GREEN}Successfully setup Podman image. Configuring build...${NC}"
echo "You can run it with: podman run -p 8000:8000 $image_name"

if [ "$distribution_id" = "adhoc" ]; then
  subcommand="api"
  target=""
else
  subcommand="stack"
  target="$distribution_id"
fi

$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --name "$build_name" --type container
llama_toolchain/core/common.sh (new executable file, 40 lines)
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

cleanup() {
  envname="$1"

  set +x
  echo "Cleaning up..."
  conda deactivate
  conda env remove --name $envname -y
}

handle_int() {
  if [ -n $ENVNAME ]; then
    cleanup $ENVNAME
  fi
  exit 1
}

handle_exit() {
  if [ $? -ne 0 ]; then
    echo -e "\033[1;31mABORTING.\033[0m"
    if [ -n $ENVNAME ]; then
      cleanup $ENVNAME
    fi
  fi
}

setup_cleanup_handlers() {
  trap handle_int INT
  trap handle_exit EXIT

  __conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
  eval "$__conda_setup"

  conda deactivate
}
llama_toolchain/core/configure.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict

from llama_toolchain.core.datatypes import *  # noqa: F403
from termcolor import cprint

from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.core.distribution import api_providers
from llama_toolchain.core.dynamic import instantiate_class_type


def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
    all_providers = api_providers()

    provider_configs = {}
    for api_str, stub_config in existing_configs.items():
        api = Api(api_str)
        providers = all_providers[api]
        provider_id = stub_config["provider_id"]
        if provider_id not in providers:
            raise ValueError(
                f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
            )

        provider_spec = providers[provider_id]
        cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
        config_type = instantiate_class_type(provider_spec.config_class)

        try:
            existing_provider_config = config_type(**stub_config)
        except Exception:
            existing_provider_config = None

        provider_config = prompt_for_config(
            config_type,
            existing_provider_config,
        )
        print("")

        provider_configs[api_str] = {
            "provider_id": provider_id,
            **provider_config.dict(),
        }

    return provider_configs
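As a rough usage sketch (not part of this diff; the stub mapping below is hypothetical), `configure_api_providers()` takes the `providers` section of a package config keyed by API name and interactively prompts for each provider's settings:

from llama_toolchain.core.configure import configure_api_providers

# Hypothetical stub, in the same shape that build_package() writes out.
stub = {
    "inference": {"provider_id": "meta-reference"},
    "memory": {"provider_id": "meta-reference-faiss"},
}
provider_configs = configure_api_providers(stub)  # prompts on the terminal for each provider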
llama_toolchain/core/datatypes.py (new file, 190 lines)
@@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from llama_models.schema_utils import json_schema_type

from pydantic import BaseModel, Field, validator


@json_schema_type
class Api(Enum):
    inference = "inference"
    safety = "safety"
    agentic_system = "agentic_system"
    memory = "memory"


@json_schema_type
class ApiEndpoint(BaseModel):
    route: str
    method: str
    name: str


@json_schema_type
class ProviderSpec(BaseModel):
    api: Api
    provider_id: str
    config_class: str = Field(
        ...,
        description="Fully-qualified classname of the config for this provider",
    )
    api_dependencies: List[Api] = Field(
        default_factory=list,
        description="Higher-level API surfaces may depend on other providers to provide their functionality",
    )


@json_schema_type
class AdapterSpec(BaseModel):
    adapter_id: str = Field(
        ...,
        description="Unique identifier for this adapter",
    )
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:

 - `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
    )
    pip_packages: List[str] = Field(
        default_factory=list,
        description="The pip dependencies needed for this implementation",
    )
    config_class: Optional[str] = Field(
        default=None,
        description="Fully-qualified classname of the config for this provider",
    )


@json_schema_type
class InlineProviderSpec(ProviderSpec):
    pip_packages: List[str] = Field(
        default_factory=list,
        description="The pip dependencies needed for this implementation",
    )
    docker_image: Optional[str] = Field(
        default=None,
        description="""
The docker image to use for this implementation. If one is provided, pip_packages will be ignored.
If a provider depends on other providers, the dependencies MUST NOT specify a docker image.
""",
    )
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:

 - `get_provider_impl(config, deps)`: returns the local implementation
""",
    )


class RemoteProviderConfig(BaseModel):
    url: str = Field(..., description="The URL for the provider")

    @validator("url")
    @classmethod
    def validate_url(cls, url: str) -> str:
        if not url.startswith("http"):
            raise ValueError(f"URL must start with http: {url}")
        return url.rstrip("/")


def remote_provider_id(adapter_id: str) -> str:
    return f"remote::{adapter_id}"


@json_schema_type
class RemoteProviderSpec(ProviderSpec):
    adapter: Optional[AdapterSpec] = Field(
        default=None,
        description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
""",
    )

    @property
    def docker_image(self) -> Optional[str]:
        return None

    @property
    def module(self) -> str:
        if self.adapter:
            return self.adapter.module
        return f"llama_toolchain.{self.api.value}.client"

    @property
    def pip_packages(self) -> List[str]:
        if self.adapter:
            return self.adapter.pip_packages
        return []


# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
    api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
    config_class = (
        adapter.config_class
        if adapter and adapter.config_class
        else "llama_toolchain.core.datatypes.RemoteProviderConfig"
    )
    provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"

    return RemoteProviderSpec(
        api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
    )


@json_schema_type
class DistributionSpec(BaseModel):
    distribution_id: str
    description: str

    docker_image: Optional[str] = None
    providers: Dict[Api, str] = Field(
        default_factory=dict,
        description="Provider IDs for each of the APIs provided by this distribution",
    )


@json_schema_type
class PackageConfig(BaseModel):
    built_at: datetime

    package_name: str = Field(
        ...,
        description="""
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
this could be just a hash
""",
    )
    distribution_id: Optional[str] = None

    docker_image: Optional[str] = Field(
        default=None,
        description="Reference to the docker image if this package refers to a container",
    )
    conda_env: Optional[str] = Field(
        default=None,
        description="Reference to the conda environment if this package refers to a conda environment",
    )
    providers: Dict[str, Any] = Field(
        default_factory=dict,
        description="""
Provider configurations for each of the APIs provided by this package. This includes configurations for
the dependencies of these providers as well.
""",
    )
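A small illustrative sketch (not part of the diff; the pip package list is an assumption) of how the new adapter datatypes compose: `remote_provider_spec()` derives the provider id `remote::<adapter_id>` and falls back to `RemoteProviderConfig` when the adapter declares no config class.

from llama_toolchain.core.datatypes import AdapterSpec, Api, remote_provider_spec

adapter = AdapterSpec(
    adapter_id="ollama",
    module="llama_toolchain.inference.adapters.ollama",
    pip_packages=["ollama"],  # assumed dependency, for illustration only
)
spec = remote_provider_spec(Api.inference, adapter)
assert spec.provider_id == "remote::ollama"
assert spec.config_class == "llama_toolchain.core.datatypes.RemoteProviderConfig"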
@@ -7,11 +7,13 @@
 import inspect
 from typing import Dict, List

-from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
+from llama_toolchain.agentic_system.api import AgenticSystem
 from llama_toolchain.agentic_system.providers import available_agentic_system_providers
-from llama_toolchain.inference.api.endpoints import Inference
+from llama_toolchain.inference.api import Inference
 from llama_toolchain.inference.providers import available_inference_providers
-from llama_toolchain.safety.api.endpoints import Safety
+from llama_toolchain.memory.api import Memory
+from llama_toolchain.memory.providers import available_memory_providers
+from llama_toolchain.safety.api import Safety
 from llama_toolchain.safety.providers import available_safety_providers

 from .datatypes import (
@@ -20,6 +22,7 @@ from .datatypes import (
     DistributionSpec,
     InlineProviderSpec,
     ProviderSpec,
+    remote_provider_spec,
 )

 # These are the dependencies needed by the distribution server.
@@ -40,6 +43,10 @@ def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
     ] + SERVER_DEPENDENCIES


+def stack_apis() -> List[Api]:
+    return [Api.inference, Api.safety, Api.agentic_system, Api.memory]
+
+
 def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     apis = {}
@@ -47,6 +54,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.inference: Inference,
         Api.safety: Safety,
         Api.agentic_system: AgenticSystem,
+        Api.memory: Memory,
     }

     for api, protocol in protocols.items():
@@ -60,9 +68,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
             webmethod = method.__webmethod__
             route = webmethod.route

-            # use `post` for all methods right now until we fix up the `webmethod` openapi
-            # annotation and write our own openapi generator
-            endpoints.append(ApiEndpoint(route=route, method="post", name=name))
+            if webmethod.method == "GET":
+                method = "get"
+            elif webmethod.method == "DELETE":
+                method = "delete"
+            else:
+                method = "post"
+            endpoints.append(ApiEndpoint(route=route, method=method, name=name))

         apis[api] = endpoints
@@ -78,8 +90,12 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
         a.provider_id: a for a in available_agentic_system_providers()
     }

-    return {
+    ret = {
         Api.inference: inference_providers_by_id,
         Api.safety: safety_providers_by_id,
         Api.agentic_system: agentic_system_providers_by_id,
+        Api.memory: {a.provider_id: a for a in available_memory_providers()},
     }
+    for k, v in ret.items():
+        v["remote"] = remote_provider_spec(k)
+    return ret
llama_toolchain/core/distribution_registry.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from functools import lru_cache
from typing import List, Optional

from .datatypes import *  # noqa: F403


@lru_cache()
def available_distribution_specs() -> List[DistributionSpec]:
    return [
        DistributionSpec(
            distribution_id="local",
            description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
            providers={
                Api.inference: "meta-reference",
                Api.memory: "meta-reference-faiss",
                Api.safety: "meta-reference",
                Api.agentic_system: "meta-reference",
            },
        ),
        DistributionSpec(
            distribution_id="remote",
            description="Point to remote services for all llama stack APIs",
            providers={x: "remote" for x in Api},
        ),
        DistributionSpec(
            distribution_id="local-ollama",
            description="Like local, but use ollama for running LLM inference",
            providers={
                Api.inference: remote_provider_id("ollama"),
                Api.safety: "meta-reference",
                Api.agentic_system: "meta-reference",
                Api.memory: "meta-reference-faiss",
            },
        ),
        DistributionSpec(
            distribution_id="local-plus-fireworks-inference",
            description="Use Fireworks.ai for running LLM inference",
            providers={
                Api.inference: remote_provider_id("fireworks"),
                Api.safety: "meta-reference",
                Api.agentic_system: "meta-reference",
                Api.memory: "meta-reference-faiss",
            },
        ),
        DistributionSpec(
            distribution_id="local-plus-together-inference",
            description="Use Together.ai for running LLM inference",
            providers={
                Api.inference: remote_provider_id("together"),
                Api.safety: "meta-reference",
                Api.agentic_system: "meta-reference",
                Api.memory: "meta-reference-faiss",
            },
        ),
    ]


@lru_cache()
def resolve_distribution_spec(distribution_id: str) -> Optional[DistributionSpec]:
    for spec in available_distribution_specs():
        if spec.distribution_id == distribution_id:
            return spec
    return None
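Usage sketch for the new registry (illustrative only, not part of the diff):

from llama_toolchain.core.distribution_registry import resolve_distribution_spec

spec = resolve_distribution_spec("local-ollama")
assert spec is not None
# inference -> "remote::ollama", safety / agentic_system -> "meta-reference", memory -> "meta-reference-faiss"
print(spec.providers)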
@@ -8,7 +8,7 @@ import asyncio
 import importlib
 from typing import Any, Dict

-from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec
+from .datatypes import ProviderSpec, RemoteProviderSpec


 def instantiate_class_type(fully_qualified_name):
@@ -19,18 +19,24 @@ def instantiate_class_type(fully_qualified_name):

 # returns a class implementing the protocol corresponding to the Api
 def instantiate_provider(
-    provider_spec: InlineProviderSpec,
+    provider_spec: ProviderSpec,
     provider_config: Dict[str, Any],
     deps: Dict[str, ProviderSpec],
 ):
     module = importlib.import_module(provider_spec.module)

     config_type = instantiate_class_type(provider_spec.config_class)
+    if isinstance(provider_spec, RemoteProviderSpec):
+        if provider_spec.adapter:
+            method = "get_adapter_impl"
+        else:
+            method = "get_client_impl"
+    else:
+        method = "get_provider_impl"
+
     config = config_type(**provider_config)
-    return asyncio.run(module.get_provider_impl(config, deps))
-
-
-def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str):
-    module = importlib.import_module(provider_spec.module)
-
-    return asyncio.run(module.get_client_impl(base_url))
+    fn = getattr(module, method)
+    impl = asyncio.run(fn(config, deps))
+    impl.__provider_spec__ = provider_spec
+    impl.__provider_config__ = config
+    return impl
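A rough sketch of the new dispatch (assumptions: a reachable ollama server at the given URL, and that the adapter module resolves as registered elsewhere in this commit; this is not the CLI's actual wiring): with a `RemoteProviderSpec` that carries an adapter, `instantiate_provider()` resolves `get_adapter_impl` from the adapter module and tags the returned impl with `__provider_spec__` and `__provider_config__`.

from llama_toolchain.core.datatypes import AdapterSpec, Api, remote_provider_spec
from llama_toolchain.core.dynamic import instantiate_provider

spec = remote_provider_spec(
    Api.inference,
    AdapterSpec(adapter_id="ollama", module="llama_toolchain.inference.adapters.ollama"),
)
# Runs the adapter's async get_adapter_impl() under asyncio.run().
impl = instantiate_provider(spec, {"url": "http://localhost:11434"}, deps={})
print(impl.__provider_config__.url)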
llama_toolchain/core/package.py (new file, 149 lines)
@@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import os
from datetime import datetime
from enum import Enum
from typing import List, Optional

import pkg_resources
import yaml
from pydantic import BaseModel

from termcolor import cprint

from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.common.serialize import EnumEncoder

from llama_toolchain.core.datatypes import *  # noqa: F403
from llama_toolchain.core.distribution import api_providers, SERVER_DEPENDENCIES


class BuildType(Enum):
    container = "container"
    conda_env = "conda_env"

    def descriptor(self) -> str:
        return "docker" if self == self.container else "conda"


class Dependencies(BaseModel):
    pip_packages: List[str]
    docker_image: Optional[str] = None


class ApiInput(BaseModel):
    api: Api
    provider: str


def build_package(
    api_inputs: List[ApiInput],
    build_type: BuildType,
    name: str,
    distribution_id: Optional[str] = None,
    docker_image: Optional[str] = None,
):
    if not distribution_id:
        distribution_id = "adhoc"

    build_dir = BUILDS_BASE_DIR / distribution_id / build_type.descriptor()
    os.makedirs(build_dir, exist_ok=True)

    package_name = name.replace("::", "-")
    package_file = build_dir / f"{package_name}.yaml"

    all_providers = api_providers()

    package_deps = Dependencies(
        docker_image=docker_image or "python:3.10-slim",
        pip_packages=SERVER_DEPENDENCIES,
    )

    stub_config = {}
    for api_input in api_inputs:
        api = api_input.api
        providers_for_api = all_providers[api]
        if api_input.provider not in providers_for_api:
            raise ValueError(
                f"Provider `{api_input.provider}` is not available for API `{api}`"
            )

        provider = providers_for_api[api_input.provider]
        package_deps.pip_packages.extend(provider.pip_packages)
        if provider.docker_image:
            raise ValueError("A stack's dependencies cannot have a docker image")

        stub_config[api.value] = {"provider_id": api_input.provider}

    if package_file.exists():
        cprint(
            f"Build `{package_name}` exists; will reconfigure",
            color="yellow",
        )
        c = PackageConfig(**yaml.safe_load(package_file.read_text()))
        for api_str, new_config in stub_config.items():
            if api_str not in c.providers:
                c.providers[api_str] = new_config
            else:
                existing_config = c.providers[api_str]
                if existing_config["provider_id"] != new_config["provider_id"]:
                    cprint(
                        f"Provider `{api_str}` has changed from `{existing_config}` to `{new_config}`",
                        color="yellow",
                    )
                    c.providers[api_str] = new_config
    else:
        c = PackageConfig(
            built_at=datetime.now(),
            package_name=package_name,
            providers=stub_config,
        )

    c.distribution_id = distribution_id
    c.docker_image = package_name if build_type == BuildType.container else None
    c.conda_env = package_name if build_type == BuildType.conda_env else None

    with open(package_file, "w") as f:
        to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder))
        f.write(yaml.dump(to_write, sort_keys=False))

    if build_type == BuildType.container:
        script = pkg_resources.resource_filename(
            "llama_toolchain", "core/build_container.sh"
        )
        args = [
            script,
            distribution_id,
            package_name,
            package_deps.docker_image,
            " ".join(package_deps.pip_packages),
        ]
    else:
        script = pkg_resources.resource_filename(
            "llama_toolchain", "core/build_conda_env.sh"
        )
        args = [
            script,
            distribution_id,
            package_name,
            " ".join(package_deps.pip_packages),
        ]

    return_code = run_with_pty(args)
    if return_code != 0:
        cprint(
            f"Failed to build target {package_name} with return code {return_code}",
            color="red",
        )
        return

    cprint(
        f"Target `{package_name}` built with configuration at {str(package_file)}",
        color="green",
    )
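As a usage sketch (assuming a working conda setup; the `llama` CLI build subcommands presumably drive this same function), `build_package()` writes the package YAML under BUILDS_BASE_DIR and then shells out to the matching build script:

from llama_toolchain.core.datatypes import Api
from llama_toolchain.core.package import ApiInput, BuildType, build_package

build_package(
    api_inputs=[
        ApiInput(api=Api.inference, provider="meta-reference"),
        ApiInput(api=Api.safety, provider="meta-reference"),
    ],
    build_type=BuildType.conda_env,
    name="my-local-stack",  # hypothetical build name
    distribution_id="local",
)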
@@ -5,8 +5,10 @@
 # the root directory of this source tree.

 import asyncio
+import inspect
 import json
 import signal
+import traceback
 from collections.abc import (
     AsyncGenerator as AsyncGeneratorABC,
     AsyncIterator as AsyncIteratorABC,
@@ -28,18 +30,17 @@ import fire
 import httpx
 import yaml

-from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi import Body, FastAPI, HTTPException, Request, Response
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.routing import APIRoute
 from pydantic import BaseModel, ValidationError
 from termcolor import cprint
+from typing_extensions import Annotated

-from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
-from .distribution import api_endpoints
-from .dynamic import instantiate_client, instantiate_provider
-
-from .registry import resolve_distribution_spec
+from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
+from .distribution import api_endpoints, api_providers
+from .dynamic import instantiate_provider


 def is_async_iterator_type(typ):
@@ -66,6 +67,7 @@ def create_sse_event(data: Any) -> str:

 async def global_exception_handler(request: Request, exc: Exception):
+    traceback.print_exception(exc)
     http_exc = translate_exception(exc)

     return JSONResponse(
@@ -155,9 +157,8 @@ def create_dynamic_passthrough(
     return endpoint


-def create_dynamic_typed_route(func: Any):
+def create_dynamic_typed_route(func: Any, method: str):
     hints = get_type_hints(func)
-    request_model = next(iter(hints.values()))
     response_model = hints["return"]

     # NOTE: I think it is better to just add a method within each Api
@@ -168,7 +169,7 @@ def create_dynamic_typed_route(func: Any):

     if is_streaming:

-        async def endpoint(request: request_model):
+        async def endpoint(**kwargs):
             async def sse_generator(event_gen):
                 try:
                     async for item in event_gen:
@@ -178,10 +179,7 @@
                     print("Generator cancelled")
                     await event_gen.aclose()
                 except Exception as e:
-                    print(e)
-                    import traceback
-
-                    traceback.print_exc()
+                    traceback.print_exception(e)
                     yield create_sse_event(
                         {
                             "error": {
@@ -191,25 +189,38 @@
             )

         return StreamingResponse(
-            sse_generator(func(request)), media_type="text/event-stream"
+            sse_generator(func(**kwargs)), media_type="text/event-stream"
         )

     else:

-        async def endpoint(request: request_model):
+        async def endpoint(**kwargs):
             try:
                 return (
-                    await func(request)
+                    await func(**kwargs)
                     if asyncio.iscoroutinefunction(func)
-                    else func(request)
+                    else func(**kwargs)
                 )
             except Exception as e:
-                print(e)
-                import traceback
-
-                traceback.print_exc()
+                traceback.print_exception(e)
                 raise translate_exception(e) from e

+    sig = inspect.signature(func)
+    if method == "post":
+        # make sure every parameter is annotated with Body() so FASTAPI doesn't
+        # do anything too intelligent and ask for some parameters in the query
+        # and some in the body
+        endpoint.__signature__ = sig.replace(
+            parameters=[
+                param.replace(
+                    annotation=Annotated[param.annotation, Body(..., embed=True)]
+                )
+                for param in sig.parameters.values()
+            ]
+        )
+    else:
+        endpoint.__signature__ = sig
+
     return endpoint
@@ -219,10 +230,9 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
     def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
         visited.add(a.api)

-        if not isinstance(a, RemoteProviderSpec):
-            for api in a.api_dependencies:
-                if api not in visited:
-                    dfs(by_id[api], visited, stack)
+        for api in a.api_dependencies:
+            if api not in visited:
+                dfs(by_id[api], visited, stack)

         stack.append(a.api)
@@ -236,9 +246,11 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
     return [by_id[x] for x in stack]


-def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
+def resolve_impls(
+    provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
+) -> Dict[Api, Any]:
     provider_configs = config["providers"]
-    provider_specs = topological_sort(dist.provider_specs.values())
+    provider_specs = topological_sort(provider_specs.values())

     impls = {}
     for provider_spec in provider_specs:
@@ -248,15 +260,13 @@ def resolve_impls(
                 f"Could not find provider_spec config for {api}. Please add it to the config"
             )

-        provider_config = provider_configs[api.value]
-        if isinstance(provider_spec, RemoteProviderSpec):
-            impls[api] = instantiate_client(
-                provider_spec, provider_config["base_url"].rstrip("/")
-            )
-        else:
+        if isinstance(provider_spec, InlineProviderSpec):
             deps = {api: impls[api] for api in provider_spec.api_dependencies}
-            impl = instantiate_provider(provider_spec, provider_config, deps)
-            impls[api] = impl
+        else:
+            deps = {}
+        provider_config = provider_configs[api.value]
+        impl = instantiate_provider(provider_spec, provider_config, deps)
+        impls[api] = impl

     return impls
@@ -265,24 +275,36 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
     with open(yaml_config, "r") as fp:
         config = yaml.safe_load(fp)

-    spec = config["spec"]
-    dist = resolve_distribution_spec(spec)
-    if dist is None:
-        raise ValueError(f"Could not find distribution specification `{spec}`")
-
     app = FastAPI()

     all_endpoints = api_endpoints()
-    impls = resolve_impls(dist, config)
+    all_providers = api_providers()

-    for provider_spec in dist.provider_specs.values():
+    provider_specs = {}
+    for api_str, provider_config in config["providers"].items():
+        api = Api(api_str)
+        providers = all_providers[api]
+        provider_id = provider_config["provider_id"]
+        if provider_id not in providers:
+            raise ValueError(
+                f"Unknown provider `{provider_id}` is not available for API `{api}`"
+            )
+
+        provider_specs[api] = providers[provider_id]
+
+    impls = resolve_impls(provider_specs, config)
+
+    for provider_spec in provider_specs.values():
         api = provider_spec.api
         endpoints = all_endpoints[api]
         impl = impls[api]

-        if isinstance(provider_spec, RemoteProviderSpec):
+        if (
+            isinstance(provider_spec, RemoteProviderSpec)
+            and provider_spec.adapter is None
+        ):
             for endpoint in endpoints:
-                url = impl.base_url + endpoint.route
+                url = impl.__provider_config__.url.rstrip("/") + endpoint.route
                 getattr(app, endpoint.method)(endpoint.route)(
                     create_dynamic_passthrough(url)
                 )
@@ -296,7 +318,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
                 impl_method = getattr(impl, endpoint.name)
                 getattr(app, endpoint.method)(endpoint.route, response_model=None)(
-                    create_dynamic_typed_route(impl_method)
+                    create_dynamic_typed_route(impl_method, endpoint.method)
                 )

     for route in app.routes:
@@ -307,6 +329,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
             attrs=["bold"],
         )

+    app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
     signal.signal(signal.SIGINT, handle_sigint)
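The signature rewrite in `create_dynamic_typed_route` can be illustrated standalone (hypothetical route, not from the repo): wrapping every parameter annotation in `Body(..., embed=True)` makes FastAPI pull all arguments from keyed fields of the JSON body rather than from the query string.

import inspect

from fastapi import Body, FastAPI
from typing_extensions import Annotated

app = FastAPI()


async def example(x: int, y: str) -> dict:
    # Echo endpoint used only to demonstrate the Body(embed=True) rewrite.
    return {"x": x, "y": y}


sig = inspect.signature(example)
example.__signature__ = sig.replace(
    parameters=[
        p.replace(annotation=Annotated[p.annotation, Body(..., embed=True)])
        for p in sig.parameters.values()
    ]
)
app.post("/example")(example)
# POST /example with {"x": 1, "y": "hi"} now binds both parameters from the body.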
@@ -8,7 +8,6 @@
 set -euo pipefail

-# Define color codes
 RED='\033[0;31m'
 NC='\033[0m' # No Color
@@ -17,20 +16,27 @@ error_handler() {
   exit 1
 }

-# Set up the error trap
 trap 'error_handler ${LINENO}' ERR

-if [ $# -lt 2 ]; then
-  echo "Usage: $0 <environment_name> <script_args...>"
+if [ $# -lt 3 ]; then
+  echo "Usage: $0 <build_name> <yaml_config> <port> <script_args...>"
   exit 1
 fi

+build_name="$1"
+env_name="llamastack-$build_name"
+shift
+
-env_name="$1"
+yaml_config="$1"
+shift
+
+port="$1"
 shift

 eval "$(conda shell.bash hook)"
 conda deactivate && conda activate "$env_name"

-python_interp=$(conda run -n "$env_name" which python)
-$python_interp -m llama_toolchain.distribution.server "$@"
+$CONDA_PREFIX/bin/python \
+  -m llama_toolchain.core.server \
+  --yaml_config "$yaml_config" \
+  --port "$port" "$@"
llama_toolchain/core/start_container.sh (new executable file, 43 lines)
@@ -0,0 +1,43 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -euo pipefail

RED='\033[0;31m'
NC='\033[0m' # No Color

error_handler() {
  echo "Error occurred in script at line: ${1}" >&2
  exit 1
}

trap 'error_handler ${LINENO}' ERR

if [ $# -lt 3 ]; then
  echo "Usage: $0 <build_name> <yaml_config> <port> <other_args...>"
  exit 1
fi

build_name="$1"
docker_image="llamastack-$build_name"
shift

yaml_config="$1"
shift

port="$1"
shift

set -x
podman run -it \
  -p $port:$port \
  -v "$yaml_config:/app/config.yaml" \
  $docker_image \
  python -m llama_toolchain.core.server \
  --yaml_config /app/config.yaml \
  --port $port "$@"
@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .datatypes import *  # noqa: F401 F403
-from .endpoints import *  # noqa: F401 F403
+from .api import *  # noqa: F401 F403
@@ -4,13 +4,34 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Protocol
+from enum import Enum
+from typing import Any, Dict, Optional, Protocol
+
+from llama_models.llama3.api.datatypes import URL

 from llama_models.schema_utils import json_schema_type, webmethod

 from pydantic import BaseModel

-from .datatypes import *  # noqa: F403
+
+@json_schema_type
+class TrainEvalDatasetColumnType(Enum):
+    dialog = "dialog"
+    text = "text"
+    media = "media"
+    number = "number"
+    json = "json"
+
+
+@json_schema_type
+class TrainEvalDataset(BaseModel):
+    """Dataset to be used for training or evaluating language models."""
+
+    # TODO(ashwin): figure out if we need to add an enum for a "dataset type"
+
+    columns: Dict[str, TrainEvalDatasetColumnType]
+    content_url: URL
+    metadata: Optional[Dict[str, Any]] = None
+
+
 @json_schema_type
@@ -1,34 +0,0 @@ (entire file removed)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any, Dict, Optional

from llama_models.llama3.api.datatypes import URL

from llama_models.schema_utils import json_schema_type

from pydantic import BaseModel


@json_schema_type
class TrainEvalDatasetColumnType(Enum):
    dialog = "dialog"
    text = "text"
    media = "media"
    number = "number"
    json = "json"


@json_schema_type
class TrainEvalDataset(BaseModel):
    """Dataset to be used for training or evaluating language models."""

    # TODO(ashwin): figure out if we need to add an enum for a "dataset type"

    columns: Dict[str, TrainEvalDatasetColumnType]
    content_url: URL
    metadata: Optional[Dict[str, Any]] = None
@@ -1,106 +0,0 @@ (entire file removed)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any, Dict, List, Optional

from llama_models.schema_utils import json_schema_type

from pydantic import BaseModel, Field, validator


@json_schema_type
class Api(Enum):
    inference = "inference"
    safety = "safety"
    agentic_system = "agentic_system"


@json_schema_type
class ApiEndpoint(BaseModel):
    route: str
    method: str
    name: str


@json_schema_type
class ProviderSpec(BaseModel):
    api: Api
    provider_id: str
    config_class: str = Field(
        ...,
        description="Fully-qualified classname of the config for this provider",
    )


@json_schema_type
class InlineProviderSpec(ProviderSpec):
    pip_packages: List[str] = Field(
        default_factory=list,
        description="The pip dependencies needed for this implementation",
    )
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:

 - `get_provider_impl(config, deps)`: returns the local implementation
""",
    )
    api_dependencies: List[Api] = Field(
        default_factory=list,
        description="Higher-level API surfaces may depend on other providers to provide their functionality",
    )


class RemoteProviderConfig(BaseModel):
    base_url: str = Field(..., description="The base URL for the llama stack provider")
    api_key: Optional[str] = Field(
        ..., description="API key, if needed, for the provider"
    )

    @validator("base_url")
    @classmethod
    def validate_base_url(cls, base_url: str) -> str:
        if not base_url.startswith("http"):
            raise ValueError(f"URL must start with http: {base_url}")
        return base_url


@json_schema_type
class RemoteProviderSpec(ProviderSpec):
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:
 - `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
""",
    )
    config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"


@json_schema_type
class DistributionSpec(BaseModel):
    spec_id: str
    description: str

    provider_specs: Dict[Api, ProviderSpec] = Field(
        default_factory=dict,
        description="Provider specifications for each of the APIs provided by this distribution",
    )


@json_schema_type
class DistributionConfig(BaseModel):
    """References to a installed / configured DistributionSpec"""

    name: str
    spec: str
    conda_env: str
    providers: Dict[str, Any] = Field(
        default_factory=dict,
        description="Provider configurations for each of the APIs provided by this distribution",
    )
@@ -1,79 +0,0 @@ (entire file removed)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from functools import lru_cache
from typing import List, Optional

from .datatypes import Api, DistributionSpec, RemoteProviderSpec
from .distribution import api_providers


def client_module(api: Api) -> str:
    return f"llama_toolchain.{api.value}.client"


def remote_spec(api: Api) -> RemoteProviderSpec:
    return RemoteProviderSpec(
        api=api,
        provider_id=f"{api.value}-remote",
        module=client_module(api),
    )


@lru_cache()
def available_distribution_specs() -> List[DistributionSpec]:
    providers = api_providers()
    return [
        DistributionSpec(
            spec_id="local",
            description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
            provider_specs={
                Api.inference: providers[Api.inference]["meta-reference"],
                Api.safety: providers[Api.safety]["meta-reference"],
                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
            },
        ),
        DistributionSpec(
            spec_id="remote",
            description="Point to remote services for all llama stack APIs",
            provider_specs={x: remote_spec(x) for x in providers},
        ),
        DistributionSpec(
            spec_id="local-ollama",
            description="Like local, but use ollama for running LLM inference",
            provider_specs={
                Api.inference: providers[Api.inference]["meta-ollama"],
                Api.safety: providers[Api.safety]["meta-reference"],
                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
            },
        ),
        DistributionSpec(
            spec_id="remote-fireworks",
            description="Use Fireworks.ai for running LLM inference",
            provider_specs={
                Api.inference: providers[Api.inference]["fireworks"],
                Api.safety: providers[Api.safety]["meta-reference"],
                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
            },
        ),
        DistributionSpec(
            spec_id="remote-together",
            description="Use Together.ai for running LLM inference",
            provider_specs={
                Api.inference: providers[Api.inference]["together"],
                Api.safety: providers[Api.safety]["meta-reference"],
                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
            },
        ),
    ]


@lru_cache()
def resolve_distribution_spec(spec_id: str) -> Optional[DistributionSpec]:
    for spec in available_distribution_specs():
        if spec.spec_id == spec_id:
            return spec
    return None
@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .datatypes import *  # noqa: F401 F403
-from .endpoints import *  # noqa: F401 F403
+from .api import *  # noqa: F401 F403
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from enum import Enum
 from typing import List, Protocol

 from llama_models.schema_utils import webmethod
@@ -11,11 +12,34 @@ from llama_models.schema_utils import webmethod
 from pydantic import BaseModel

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from .datatypes import *  # noqa: F403
-from llama_toolchain.dataset.api.datatypes import *  # noqa: F403
+from llama_toolchain.dataset.api import *  # noqa: F403
 from llama_toolchain.common.training_types import *  # noqa: F403


+class TextGenerationMetric(Enum):
+    perplexity = "perplexity"
+    rouge = "rouge"
+    bleu = "bleu"
+
+
+class QuestionAnsweringMetric(Enum):
+    em = "em"
+    f1 = "f1"
+
+
+class SummarizationMetric(Enum):
+    rouge = "rouge"
+    bleu = "bleu"
+
+
+class EvaluationJob(BaseModel):
+    job_uuid: str
+
+
+class EvaluationJobLogStream(BaseModel):
+    job_uuid: str
+
+
 class EvaluateTaskRequestCommon(BaseModel):
     job_uuid: str
     dataset: TrainEvalDataset
@@ -1,33 +0,0 @@ (entire file removed)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum

from pydantic import BaseModel


class TextGenerationMetric(Enum):
    perplexity = "perplexity"
    rouge = "rouge"
    bleu = "bleu"


class QuestionAnsweringMetric(Enum):
    em = "em"
    f1 = "f1"


class SummarizationMetric(Enum):
    rouge = "rouge"
    bleu = "bleu"


class EvaluationJob(BaseModel):
    job_uuid: str


class EvaluationJobLogStream(BaseModel):
    job_uuid: str
llama_toolchain/inference/adapters/fireworks/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import FireworksImplConfig


async def get_adapter_impl(config: FireworksImplConfig, _deps) -> Inference:
    from .fireworks import FireworksInferenceAdapter

    assert isinstance(
        config, FireworksImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = FireworksInferenceAdapter(config)
    await impl.initialize()
    return impl
@@ -5,9 +5,9 @@
 # the root directory of this source tree.

 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator

-import httpx
+from fireworks.client import Fireworks

 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
@@ -18,20 +18,8 @@ from llama_models.llama3.api.datatypes import (
 )
 from llama_models.llama3.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
-from fireworks.client import Fireworks

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import *  # noqa: F403

 from .config import FireworksImplConfig
@@ -42,18 +30,7 @@ FIREWORKS_SUPPORTED_MODELS = {
 }


-async def get_provider_impl(
-    config: FireworksImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, FireworksImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = FireworksInference(config)
-    await impl.initialize()
-    return impl
-
-
-class FireworksInference(Inference):
+class FireworksInferenceAdapter(Inference):
     def __init__(self, config: FireworksImplConfig) -> None:
         self.config = config
llama_toolchain/inference/adapters/ollama/__init__.py (new file, 15 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_toolchain.core.datatypes import RemoteProviderConfig


async def get_adapter_impl(config: RemoteProviderConfig, _deps):
    from .ollama import OllamaInferenceAdapter

    impl = OllamaInferenceAdapter(config.url)
    await impl.initialize()
    return impl

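A rough usage sketch of the adapter entry point above (not part of the commit; it assumes RemoteProviderConfig can be constructed with a url field, as the adapter reads config.url, and that an ollama server is reachable at that address):

import asyncio

from llama_toolchain.core.datatypes import RemoteProviderConfig
from llama_toolchain.inference.adapters.ollama import get_adapter_impl


async def demo() -> None:
    # the adapter only needs the server URL from the remote provider config
    config = RemoteProviderConfig(url="http://localhost:11434")
    impl = await get_adapter_impl(config, None)
    # get_adapter_impl has already awaited initialize(), so impl is ready to use
    print(type(impl).__name__)  # OllamaInferenceAdapter


asyncio.run(demo())
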
Changes to the Ollama adapter implementation:

@@ -4,63 +4,37 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator

 import httpx

-from llama_models.llama3.api.datatypes import (
-    BuiltinTool,
-    CompletionMessage,
-    Message,
-    StopReason,
-    ToolCall,
-)
-from llama_models.llama3.api.tool_utils import ToolUtils
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 from ollama import AsyncClient

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
-
-from .config import OllamaImplConfig
+from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.inference.prepare_messages import prepare_messages

 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
 OLLAMA_SUPPORTED_SKUS = {
+    # "Meta-Llama3.1-8B-Instruct": "llama3.1",
     "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
 }


-async def get_provider_impl(
-    config: OllamaImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, OllamaImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = OllamaInference(config)
-    await impl.initialize()
-    return impl
-
-
-class OllamaInference(Inference):
-    def __init__(self, config: OllamaImplConfig) -> None:
-        self.config = config
+class OllamaInferenceAdapter(Inference):
+    def __init__(self, url: str) -> None:
+        self.url = url
+        tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(tokenizer)

     @property
     def client(self) -> AsyncClient:
-        return AsyncClient(host=self.config.url)
+        return AsyncClient(host=self.url)

     async def initialize(self) -> None:
         try:
@@ -111,6 +85,7 @@ class OllamaInference(Inference):
         return options

     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+        messages = prepare_messages(request)
         # accumulate sampling params and other options to pass to ollama
         options = self.get_ollama_chat_options(request)
         ollama_model = self.resolve_ollama_model(request.model)
@@ -132,7 +107,7 @@ class OllamaInference(Inference):
         if not request.stream:
             r = await self.client.chat(
                 model=ollama_model,
-                messages=self._messages_to_ollama_messages(request.messages),
+                messages=self._messages_to_ollama_messages(messages),
                 stream=False,
                 options=options,
             )
@@ -143,9 +118,8 @@ class OllamaInference(Inference):
             elif r["done_reason"] == "length":
                 stop_reason = StopReason.out_of_tokens

-            completion_message = decode_assistant_message_from_content(
-                r["message"]["content"],
-                stop_reason,
+            completion_message = self.formatter.decode_assistant_message_from_content(
+                r["message"]["content"], stop_reason
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
@@ -160,7 +134,7 @@ class OllamaInference(Inference):
             )
             stream = await self.client.chat(
                 model=ollama_model,
-                messages=self._messages_to_ollama_messages(request.messages),
+                messages=self._messages_to_ollama_messages(messages),
                 stream=True,
                 options=options,
             )
@@ -228,7 +202,9 @@ class OllamaInference(Inference):
             )

             # parse tool calls and report errors
-            message = decode_assistant_message_from_content(buffer, stop_reason)
+            message = self.formatter.decode_assistant_message_from_content(
+                buffer, stop_reason
+            )
             parsed_tool_calls = len(message.tool_calls) > 0
             if ipython and not parsed_tool_calls:
                 yield ChatCompletionResponseStreamChunk(
@@ -261,70 +237,3 @@ class OllamaInference(Inference):
                     stop_reason=stop_reason,
                 )
             )
-
-
-# TODO: Consolidate this with impl in llama-models
-def decode_assistant_message_from_content(
-    content: str,
-    stop_reason: StopReason,
-) -> CompletionMessage:
-    ipython = content.startswith("<|python_tag|>")
-    if ipython:
-        content = content[len("<|python_tag|>") :]
-
-    if content.endswith("<|eot_id|>"):
-        content = content[: -len("<|eot_id|>")]
-        stop_reason = StopReason.end_of_turn
-    elif content.endswith("<|eom_id|>"):
-        content = content[: -len("<|eom_id|>")]
-        stop_reason = StopReason.end_of_message
-
-    tool_name = None
-    tool_arguments = {}
-
-    custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
-    if custom_tool_info is not None:
-        tool_name, tool_arguments = custom_tool_info
-        # Sometimes when agent has custom tools alongside builtin tools
-        # Agent responds for builtin tool calls in the format of the custom tools
-        # This code tries to handle that case
-        if tool_name in BuiltinTool.__members__:
-            tool_name = BuiltinTool[tool_name]
-            tool_arguments = {
-                "query": list(tool_arguments.values())[0],
-            }
-    else:
-        builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
-        if builtin_tool_info is not None:
-            tool_name, query = builtin_tool_info
-            tool_arguments = {
-                "query": query,
-            }
-            if tool_name in BuiltinTool.__members__:
-                tool_name = BuiltinTool[tool_name]
-        elif ipython:
-            tool_name = BuiltinTool.code_interpreter
-            tool_arguments = {
-                "code": content,
-            }
-
-    tool_calls = []
-    if tool_name is not None and tool_arguments is not None:
-        call_id = str(uuid.uuid4())
-        tool_calls.append(
-            ToolCall(
-                call_id=call_id,
-                tool_name=tool_name,
-                arguments=tool_arguments,
-            )
-        )
-        content = ""
-
-    if stop_reason is None:
-        stop_reason = StopReason.out_of_tokens
-
-    return CompletionMessage(
-        content=content,
-        stop_reason=stop_reason,
-        tool_calls=tool_calls,
-    )

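For context, a hedged sketch of driving the rewritten adapter end to end (illustrative only; the model id is one of the OLLAMA_SUPPORTED_SKUS keys above, and a local ollama server is assumed):

import asyncio

from llama_models.llama3.api.datatypes import UserMessage
from llama_toolchain.inference.api import ChatCompletionRequest
from llama_toolchain.inference.adapters.ollama.ollama import OllamaInferenceAdapter


async def demo() -> None:
    adapter = OllamaInferenceAdapter("http://localhost:11434")
    await adapter.initialize()

    request = ChatCompletionRequest(
        model="Meta-Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Hello!")],
        stream=True,
    )
    # chat_completion is an async generator; prepare_messages() is applied inside it
    async for chunk in adapter.chat_completion(request):
        print(chunk)


asyncio.run(demo())
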
llama_toolchain/inference/adapters/together/__init__.py (new file, 18 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import TogetherImplConfig


async def get_adapter_impl(config: TogetherImplConfig, _deps) -> Inference:
    from .together import TogetherInferenceAdapter

    assert isinstance(
        config, TogetherImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = TogetherInferenceAdapter(config)
    await impl.initialize()
    return impl

Changes to the Together adapter implementation:

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator

 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
@@ -18,18 +18,7 @@ from llama_models.llama3.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
 from together import Together

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import *  # noqa: F403

 from .config import TogetherImplConfig

@@ -40,18 +29,7 @@ TOGETHER_SUPPORTED_MODELS = {
 }


-async def get_provider_impl(
-    config: TogetherImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, TogetherImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = TogetherInference(config)
-    await impl.initialize()
-    return impl
-
-
-class TogetherInference(Inference):
+class TogetherInferenceAdapter(Inference):
     def __init__(self, config: TogetherImplConfig) -> None:
         self.config = config

Changes to the inference API package __init__:

@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .datatypes import *  # noqa: F401 F403
-from .endpoints import *  # noqa: F401 F403
+from .api import *  # noqa: F401 F403

Changes to the consolidated inference API module (datatypes merged into the endpoints file):

@@ -4,17 +4,79 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .datatypes import *  # noqa: F403
-from typing import Optional, Protocol
-
-# this dependency is annoying and we need a forked up version anyway
-from llama_models.schema_utils import webmethod
+from enum import Enum
+from typing import List, Literal, Optional, Protocol, Union
+
+from llama_models.schema_utils import json_schema_type, webmethod
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+
+
+class LogProbConfig(BaseModel):
+    top_k: Optional[int] = 0
+
+
+@json_schema_type
+class QuantizationType(Enum):
+    bf16 = "bf16"
+    fp8 = "fp8"
+
+
+@json_schema_type
+class Fp8QuantizationConfig(BaseModel):
+    type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
+
+
+@json_schema_type
+class Bf16QuantizationConfig(BaseModel):
+    type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
+
+
+QuantizationConfig = Annotated[
+    Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
+    Field(discriminator="type"),
+]
+
+
+@json_schema_type
+class ChatCompletionResponseEventType(Enum):
+    start = "start"
+    complete = "complete"
+    progress = "progress"
+
+
+@json_schema_type
+class ToolCallParseStatus(Enum):
+    started = "started"
+    in_progress = "in_progress"
+    failure = "failure"
+    success = "success"
+
+
+@json_schema_type
+class ToolCallDelta(BaseModel):
+    content: Union[str, ToolCall]
+    parse_status: ToolCallParseStatus
+
+
+@json_schema_type
+class ChatCompletionResponseEvent(BaseModel):
+    """Chat completion response event."""
+
+    event_type: ChatCompletionResponseEventType
+    delta: Union[str, ToolCallDelta]
+    logprobs: Optional[List[TokenLogProbs]] = None
+    stop_reason: Optional[StopReason] = None


 @json_schema_type
 class CompletionRequest(BaseModel):
     model: str
-    content: InterleavedTextAttachment
+    content: InterleavedTextMedia
     sampling_params: Optional[SamplingParams] = SamplingParams()

     stream: Optional[bool] = False
@@ -39,7 +101,7 @@ class CompletionResponseStreamChunk(BaseModel):
 @json_schema_type
 class BatchCompletionRequest(BaseModel):
     model: str
-    content_batch: List[InterleavedTextAttachment]
+    content_batch: List[InterleavedTextMedia]
     sampling_params: Optional[SamplingParams] = SamplingParams()
     logprobs: Optional[LogProbConfig] = None

@@ -56,7 +118,11 @@ class ChatCompletionRequest(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()

     # zero-shot tool definitions as input to the model
-    available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
+    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
+    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=ToolPromptFormat.json
+    )

     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
@@ -82,8 +148,11 @@ class BatchChatCompletionRequest(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()

     # zero-shot tool definitions as input to the model
-    available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
+    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
+    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=ToolPromptFormat.json
+    )
     logprobs: Optional[LogProbConfig] = None

@@ -92,6 +161,11 @@ class BatchChatCompletionResponse(BaseModel):
     completion_message_batch: List[CompletionMessage]


+@json_schema_type
+class EmbeddingsResponse(BaseModel):
+    embeddings: List[List[float]]
+
+
 class Inference(Protocol):
     @webmethod(route="/inference/completion")
     async def completion(
@@ -105,14 +179,9 @@ class Inference(Protocol):
         request: ChatCompletionRequest,
     ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...

-    @webmethod(route="/inference/batch_completion")
-    async def batch_completion(
+    @webmethod(route="/inference/embeddings")
+    async def embeddings(
         self,
-        request: BatchCompletionRequest,
-    ) -> BatchCompletionResponse: ...
-
-    @webmethod(route="/inference/batch_chat_completion")
-    async def batch_chat_completion(
-        self,
-        request: BatchChatCompletionRequest,
-    ) -> BatchChatCompletionResponse: ...
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse: ...

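A small sketch of the reshaped request type (values are illustrative; the calculator tool below is made up, not something defined by this commit):

from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat, UserMessage
from llama_toolchain.inference.api import ChatCompletionRequest

request = ChatCompletionRequest(
    model="Meta-Llama3.1-8B-Instruct",  # placeholder model id
    messages=[UserMessage(content="What is 2 + 2?")],
    tools=[ToolDefinition(tool_name="calculator")],  # hypothetical custom tool
    tool_prompt_format=ToolPromptFormat.json,
    stream=False,
)
print(request.tool_choice)  # defaults to ToolChoice.auto
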
Deleted file (72 lines, the old inference datatypes module now folded into api.py above):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import List, Literal, Optional, Union

from llama_models.schema_utils import json_schema_type

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from llama_models.llama3.api.datatypes import *  # noqa: F403


class LogProbConfig(BaseModel):
    top_k: Optional[int] = 0


@json_schema_type
class QuantizationType(Enum):
    bf16 = "bf16"
    fp8 = "fp8"


@json_schema_type
class Fp8QuantizationConfig(BaseModel):
    type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value


@json_schema_type
class Bf16QuantizationConfig(BaseModel):
    type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value


QuantizationConfig = Annotated[
    Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
    Field(discriminator="type"),
]


@json_schema_type
class ChatCompletionResponseEventType(Enum):
    start = "start"
    complete = "complete"
    progress = "progress"


@json_schema_type
class ToolCallParseStatus(Enum):
    started = "started"
    in_progress = "in_progress"
    failure = "failure"
    success = "success"


@json_schema_type
class ToolCallDelta(BaseModel):
    content: Union[str, ToolCall]
    parse_status: ToolCallParseStatus


@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
    """Chat completion response event."""

    event_type: ChatCompletionResponseEventType
    delta: Union[str, ToolCallDelta]
    logprobs: Optional[List[TokenLogProbs]] = None
    stop_reason: Optional[StopReason] = None

Changes to the inference client:

@@ -6,12 +6,15 @@

 import asyncio
 import json
-from typing import AsyncGenerator
+from typing import Any, AsyncGenerator

 import fire
 import httpx
+from pydantic import BaseModel
 from termcolor import cprint

+from llama_toolchain.core.datatypes import RemoteProviderConfig
+
 from .api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -23,13 +26,16 @@ from .api import (
 from .event_logger import EventLogger


-async def get_client_impl(base_url: str):
-    return InferenceClient(base_url)
+async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
+    return InferenceClient(config.url)
+
+
+def encodable_dict(d: BaseModel):
+    return json.loads(d.json())


 class InferenceClient(Inference):
     def __init__(self, base_url: str):
-        print(f"Initializing client for {base_url}")
         self.base_url = base_url

     async def initialize(self) -> None:
@@ -46,7 +52,9 @@ class InferenceClient(Inference):
             async with client.stream(
                 "POST",
                 f"{self.base_url}/inference/chat_completion",
-                data=request.json(),
+                json={
+                    "request": encodable_dict(request),
+                },
                 headers={"Content-Type": "application/json"},
                 timeout=20,
             ) as response:

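To make the new payload shape concrete, the same envelope can be produced without the client class (a minimal sketch; the model id is a placeholder):

import json

from llama_models.llama3.api.datatypes import UserMessage
from llama_toolchain.inference.api import ChatCompletionRequest

request = ChatCompletionRequest(
    model="Meta-Llama3.1-8B-Instruct",
    messages=[UserMessage(content="Hello")],
)
# encodable_dict(request) is equivalent to json.loads(request.json()): enums and
# nested models become plain JSON-safe values, nested under a "request" key.
payload = {"request": json.loads(request.json())}
print(list(payload["request"].keys()))
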
Changes to the meta-reference inference package __init__:

@@ -5,4 +5,15 @@
 # the root directory of this source tree.

 from .config import MetaReferenceImplConfig  # noqa
-from .inference import get_provider_impl  # noqa
+
+
+async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
+    from .inference import MetaReferenceInferenceImpl
+
+    assert isinstance(
+        config, MetaReferenceImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceInferenceImpl(config)
+    await impl.initialize()
+    return impl

Changes to the meta-reference config module (import reordering):

@@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
 from llama_models.sku_list import all_registered_models

-from llama_toolchain.inference.api import QuantizationConfig
-
 from pydantic import BaseModel, Field, field_validator

+from llama_toolchain.inference.api import QuantizationConfig
+

 @json_schema_type
 class MetaReferenceImplConfig(BaseModel):

Changes to the reference generation code (tool prompt format threaded into prompt encoding):

@@ -24,7 +24,7 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
@@ -279,6 +279,7 @@ class Llama:
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> Generator:
         if (
             max_gen_len is None
@@ -288,7 +289,10 @@ class Llama:
             max_gen_len = self.model.params.max_seq_len - 1

         yield from self.generate(
-            model_input=self.formatter.encode_dialog_prompt(messages),
+            model_input=self.formatter.encode_dialog_prompt(
+                messages,
+                tool_prompt_format,
+            ),
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,

Changes to the meta-reference inference implementation:

@@ -6,12 +6,11 @@

 import asyncio

-from typing import AsyncIterator, Dict, Union
+from typing import AsyncIterator, Union

 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -22,23 +21,11 @@ from llama_toolchain.inference.api import (
     ToolCallDelta,
     ToolCallParseStatus,
 )
+from llama_toolchain.inference.prepare_messages import prepare_messages
 from .config import MetaReferenceImplConfig
 from .model_parallel import LlamaModelParallelGenerator


-async def get_provider_impl(
-    config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
-):
-    assert isinstance(
-        config, MetaReferenceImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
-    impl = MetaReferenceInferenceImpl(config)
-    await impl.initialize()
-    return impl
-
-
 # there's a single model parallel process running serving the model. for now,
 # we don't support multiple concurrent requests to this process.
 SEMAPHORE = asyncio.Semaphore(1)
@@ -67,6 +54,7 @@ class MetaReferenceInferenceImpl(Inference):
     ) -> AsyncIterator[
         Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
     ]:
+        messages = prepare_messages(request)
         model = resolve_model(request.model)
         if model is None:
             raise RuntimeError(
@@ -98,11 +86,12 @@ class MetaReferenceInferenceImpl(Inference):
         ipython = False

         for token_result in self.generator.chat_completion(
-            messages=request.messages,
+            messages=messages,
             temperature=request.sampling_params.temperature,
             top_p=request.sampling_params.top_p,
             max_gen_len=request.sampling_params.max_tokens,
             logprobs=request.logprobs,
+            tool_prompt_format=request.tool_prompt_format,
         ):
             buffer += token_result.text
             tokens.append(token_result.token)

Changes to the model-parallel generator (tool prompt format carried through InferenceArgs):

@@ -11,7 +11,7 @@ from functools import partial
 from typing import Generator, List, Optional

 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

@@ -27,6 +27,7 @@ class InferenceArgs:
     top_p: float
     max_gen_len: int
     logprobs: bool
+    tool_prompt_format: ToolPromptFormat


 class ModelRunner:
@@ -41,6 +42,7 @@ class ModelRunner:
             task.top_p,
             task.max_gen_len,
             task.logprobs,
+            task.tool_prompt_format,
         )

@@ -93,6 +95,7 @@ class LlamaModelParallelGenerator:
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> Generator:
         req_obj = InferenceArgs(
             messages=deepcopy(messages),
@@ -100,6 +103,7 @@ class LlamaModelParallelGenerator:
             top_p=top_p,
             max_gen_len=max_gen_len,
             logprobs=logprobs,
+            tool_prompt_format=tool_prompt_format,
         )

         gen = self.group.run_inference(req_obj)

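A filled-in InferenceArgs showing the new field threaded through (illustrative values only; the import path assumes the meta-reference model_parallel module shown here):

from copy import deepcopy

from llama_models.llama3.api.datatypes import ToolPromptFormat, UserMessage
from llama_toolchain.inference.meta_reference.model_parallel import InferenceArgs

args = InferenceArgs(
    messages=deepcopy([UserMessage(content="hi")]),
    temperature=0.6,
    top_p=0.9,
    max_gen_len=64,
    logprobs=False,
    tool_prompt_format=ToolPromptFormat.json,
)
print(args.tool_prompt_format)
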
llama_toolchain/inference/prepare_messages.py (new file, 84 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_toolchain.inference.api import *  # noqa: F403
from llama_models.llama3.prompt_templates import (
    BuiltinToolGenerator,
    FunctionTagCustomToolGenerator,
    JsonCustomToolGenerator,
    SystemDefaultGenerator,
)


def prepare_messages(request: ChatCompletionRequest) -> List[Message]:

    assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"

    existing_messages = request.messages
    existing_system_message = None
    if existing_messages[0].role == Role.system.value:
        existing_system_message = existing_messages.pop(0)

    assert (
        existing_messages[0].role != Role.system.value
    ), "Should only have 1 system message"

    messages = []

    default_gen = SystemDefaultGenerator()
    default_template = default_gen.gen()

    sys_content = ""

    tool_template = None
    if request.tools:
        tool_gen = BuiltinToolGenerator()
        tool_template = tool_gen.gen(request.tools)

        sys_content += tool_template.render()
        sys_content += "\n"

    sys_content += default_template.render()

    if existing_system_message:
        # TODO: this fn is needed in many places
        def _process(c):
            if isinstance(c, str):
                return c
            else:
                return "<media>"

        sys_content += "\n"

        if isinstance(existing_system_message.content, str):
            sys_content += _process(existing_system_message.content)
        elif isinstance(existing_system_message.content, list):
            sys_content += "\n".join(
                [_process(c) for c in existing_system_message.content]
            )

    messages.append(SystemMessage(content=sys_content))

    has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
    if has_custom_tools:
        if request.tool_prompt_format == ToolPromptFormat.json:
            tool_gen = JsonCustomToolGenerator()
        elif request.tool_prompt_format == ToolPromptFormat.function_tag:
            tool_gen = FunctionTagCustomToolGenerator()
        else:
            raise ValueError(
                f"Non supported ToolPromptFormat {request.tool_prompt_format}"
            )

        custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
        custom_template = tool_gen.gen(custom_tools)
        messages.append(UserMessage(content=custom_template.render()))

    # Add back existing messages from the request
    messages += existing_messages

    return messages

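A hedged usage sketch for the new helper (the model id and weather tool below are placeholders, not defined by this commit):

from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat, UserMessage
from llama_toolchain.inference.api import ChatCompletionRequest
from llama_toolchain.inference.prepare_messages import prepare_messages

request = ChatCompletionRequest(
    model="Meta-Llama3.1-8B-Instruct",
    messages=[UserMessage(content="What is the weather in Menlo Park?")],
    tools=[ToolDefinition(tool_name="get_weather")],  # custom (string-named) tool
    tool_prompt_format=ToolPromptFormat.json,
)
# Expected output: a SystemMessage with the default/builtin-tool templates, a
# UserMessage carrying the JSON-format custom tool definitions, then the
# original request messages.
for message in prepare_messages(request):
    print(message.role, str(message.content)[:60])
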
Changes to the inference providers registry (ollama moved from an inline provider to a remote adapter):

@@ -6,7 +6,7 @@

 from typing import List

-from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_toolchain.core.datatypes import *  # noqa: F403


 def available_inference_providers() -> List[ProviderSpec]:
@@ -27,14 +27,13 @@ def available_inference_providers() -> List[ProviderSpec]:
             module="llama_toolchain.inference.meta_reference",
             config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
         ),
-        InlineProviderSpec(
+        remote_provider_spec(
             api=Api.inference,
-            provider_id="meta-ollama",
-            pip_packages=[
-                "ollama",
-            ],
-            module="llama_toolchain.inference.ollama",
-            config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
+            adapter=AdapterSpec(
+                adapter_id="ollama",
+                pip_packages=["ollama"],
+                module="llama_toolchain.inference.adapters.ollama",
+            ),
         ),
         InlineProviderSpec(
             api=Api.inference,

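The same AdapterSpec pattern could register any other remote adapter; a hypothetical example for illustration (the adapter id, package, and module below are made up, not registered by this commit):

from llama_toolchain.core.datatypes import AdapterSpec, Api, remote_provider_spec

spec = remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(
        adapter_id="my_adapter",                                  # hypothetical
        pip_packages=["my-client-lib"],                           # hypothetical dependency
        module="llama_toolchain.inference.adapters.my_adapter",   # hypothetical module
    ),
)
print(spec)
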
Changes to the quantization loader imports:

@@ -14,12 +14,12 @@ import torch

 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from llama_models.llama3.api.model import Transformer, TransformerBlock
+
+from llama_toolchain.inference.api import QuantizationType
+
 from llama_toolchain.inference.api.config import (
     CheckpointQuantizationFormat,
     MetaReferenceImplConfig,
 )
-from llama_toolchain.inference.api.datatypes import QuantizationType
-
 from termcolor import cprint
 from torch import Tensor

Changes to the memory API package __init__:

@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .datatypes import *  # noqa: F401 F403
-from .endpoints import *  # noqa: F401 F403
+from .api import *  # noqa: F401 F403

llama_toolchain/memory/api/api.py (new file, 157 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Optional, Protocol

from llama_models.schema_utils import json_schema_type, webmethod

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from llama_models.llama3.api.datatypes import *  # noqa: F403


@json_schema_type
class MemoryBankDocument(BaseModel):
    document_id: str
    content: InterleavedTextMedia | URL
    mime_type: str
    metadata: Dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class MemoryBankType(Enum):
    vector = "vector"
    keyvalue = "keyvalue"
    keyword = "keyword"
    graph = "graph"


class VectorMemoryBankConfig(BaseModel):
    type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
    embedding_model: str
    chunk_size_in_tokens: int
    overlap_size_in_tokens: Optional[int] = None


class KeyValueMemoryBankConfig(BaseModel):
    type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value


class KeywordMemoryBankConfig(BaseModel):
    type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value


class GraphMemoryBankConfig(BaseModel):
    type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value


MemoryBankConfig = Annotated[
    Union[
        VectorMemoryBankConfig,
        KeyValueMemoryBankConfig,
        KeywordMemoryBankConfig,
        GraphMemoryBankConfig,
    ],
    Field(discriminator="type"),
]


class Chunk(BaseModel):
    content: InterleavedTextMedia
    token_count: int
    document_id: str


@json_schema_type
class QueryDocumentsResponse(BaseModel):
    chunks: List[Chunk]
    scores: List[float]


@json_schema_type
class QueryAPI(Protocol):
    @webmethod(route="/query_documents")
    def query_documents(
        self,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse: ...


@json_schema_type
class MemoryBank(BaseModel):
    bank_id: str
    name: str
    config: MemoryBankConfig
    # if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
    url: Optional[URL] = None


class Memory(Protocol):
    @webmethod(route="/memory_banks/create")
    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank: ...

    @webmethod(route="/memory_banks/list", method="GET")
    async def list_memory_banks(self) -> List[MemoryBank]: ...

    @webmethod(route="/memory_banks/get", method="GET")
    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...

    @webmethod(route="/memory_banks/drop", method="DELETE")
    async def drop_memory_bank(
        self,
        bank_id: str,
    ) -> str: ...

    # this will just block now until documents are inserted, but it should
    # probably return a Job instance which can be polled for completion
    @webmethod(route="/memory_bank/insert")
    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None: ...

    @webmethod(route="/memory_bank/update")
    async def update_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
    ) -> None: ...

    @webmethod(route="/memory_bank/query")
    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse: ...

    @webmethod(route="/memory_bank/documents/get", method="GET")
    async def get_documents(
        self,
        bank_id: str,
        document_ids: List[str],
    ) -> List[MemoryBankDocument]: ...

    @webmethod(route="/memory_bank/documents/delete", method="DELETE")
    async def delete_documents(
        self,
        bank_id: str,
        document_ids: List[str],
    ) -> None: ...

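A short sketch of the new datatypes in use (the embedding model name and document URI are placeholders):

from llama_models.llama3.api.datatypes import URL
from llama_toolchain.memory.api import MemoryBankDocument, VectorMemoryBankConfig

config = VectorMemoryBankConfig(
    embedding_model="all-MiniLM-L6-v2",   # placeholder embedding model name
    chunk_size_in_tokens=512,
    overlap_size_in_tokens=64,
)
doc = MemoryBankDocument(
    document_id="doc-0",
    content=URL(uri="https://example.com/notes.txt"),  # URL-backed document
    mime_type="text/plain",
)
print(config.type, doc.document_id)
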
Deleted file (25 lines, the old memory datatypes module):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict

from llama_models.schema_utils import json_schema_type

from pydantic import BaseModel


@json_schema_type
class MemoryBank(BaseModel):
    memory_bank_id: str
    memory_bank_name: str


@json_schema_type
class MemoryBankDocument(BaseModel):
    document_id: str
    content: bytes
    metadata: Dict[str, Any]
    mime_type: str

Deleted file (61 lines, the old memory endpoints module):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Protocol

from llama_models.schema_utils import webmethod

from .datatypes import *  # noqa: F403


class MemoryBanks(Protocol):
    @webmethod(route="/memory_banks/create")
    def create_memory_bank(
        self,
        bank_id: str,
        bank_name: str,
        documents: List[MemoryBankDocument],
    ) -> None: ...

    @webmethod(route="/memory_banks/list")
    def get_memory_banks(self) -> List[MemoryBank]: ...

    @webmethod(route="/memory_banks/get")
    def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...

    @webmethod(route="/memory_banks/drop")
    def delete_memory_bank(
        self,
        bank_id: str,
    ) -> str: ...

    @webmethod(route="/memory_bank/insert")
    def insert_memory_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
    ) -> None: ...

    @webmethod(route="/memory_bank/update")
    def update_memory_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
    ) -> None: ...

    @webmethod(route="/memory_bank/get")
    def get_memory_documents(
        self,
        bank_id: str,
        document_uuids: List[str],
    ) -> List[MemoryBankDocument]: ...

    @webmethod(route="/memory_bank/delete")
    def delete_memory_documents(
        self,
        bank_id: str,
        document_uuids: List[str],
    ) -> List[str]: ...

llama_toolchain/memory/client.py (new file, 181 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

from typing import Any, Dict, List, Optional

import fire
import httpx

from llama_toolchain.core.datatypes import RemoteProviderConfig

from .api import *  # noqa: F403


async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
    return MemoryClient(config.url)


class MemoryClient(Memory):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        async with httpx.AsyncClient() as client:
            r = await client.get(
                f"{self.base_url}/memory_banks/get",
                params={
                    "bank_id": bank_id,
                },
                headers={"Content-Type": "application/json"},
                timeout=20,
            )
            r.raise_for_status()
            d = r.json()
            if not d:
                return None
            return MemoryBank(**d)

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        async with httpx.AsyncClient() as client:
            r = await client.post(
                f"{self.base_url}/memory_banks/create",
                json={
                    "name": name,
                    "config": config.dict(),
                    "url": url,
                },
                headers={"Content-Type": "application/json"},
                timeout=20,
            )
            r.raise_for_status()
            d = r.json()
            if not d:
                return None
            return MemoryBank(**d)

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
    ) -> None:
        async with httpx.AsyncClient() as client:
            r = await client.post(
                f"{self.base_url}/memory_bank/insert",
                json={
                    "bank_id": bank_id,
                    "documents": [d.dict() for d in documents],
                },
                headers={"Content-Type": "application/json"},
                timeout=20,
            )
            r.raise_for_status()

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        async with httpx.AsyncClient() as client:
            r = await client.post(
                f"{self.base_url}/memory_bank/query",
                json={
                    "bank_id": bank_id,
                    "query": query,
                    "params": params,
                },
                headers={"Content-Type": "application/json"},
                timeout=20,
            )
            r.raise_for_status()
            return QueryDocumentsResponse(**r.json())


async def run_main(host: str, port: int, stream: bool):
    client = MemoryClient(f"http://{host}:{port}")

    # create a memory bank
    bank = await client.create_memory_bank(
        name="test_bank",
        config=VectorMemoryBankConfig(
            bank_id="test_bank",
            embedding_model="dragon-roberta-query-2",
            chunk_size_in_tokens=512,
            overlap_size_in_tokens=64,
        ),
    )
    print(bank)

    retrieved_bank = await client.get_memory_bank(bank.bank_id)
    assert retrieved_bank is not None
    assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"

    urls = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    ]
    documents = [
        MemoryBankDocument(
            document_id=f"num-{i}",
            content=URL(
                uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
            ),
            mime_type="text/plain",
        )
        for i, url in enumerate(urls)
    ]

    # insert some documents
    await client.insert_documents(
        bank_id=bank.bank_id,
        documents=documents,
    )

    # query the documents
    response = await client.query_documents(
        bank_id=bank.bank_id,
        query=[
            "How do I use Lora?",
        ],
    )
    for chunk, score in zip(response.chunks, response.scores):
        print(f"Score: {score}")
        print(f"Chunk:\n========\n{chunk}\n========\n")

    response = await client.query_documents(
        bank_id=bank.bank_id,
        query=[
            "Tell me more about llama3 and torchtune",
        ],
    )
    for chunk, score in zip(response.chunks, response.scores):
        print(f"Score: {score}")
        print(f"Chunk:\n========\n{chunk}\n========\n")


def main(host: str, port: int, stream: bool = True):
    asyncio.run(run_main(host, port, stream))


if __name__ == "__main__":
    fire.Fire(main)

llama_toolchain/memory/meta_reference/faiss/__init__.py (new file, 19 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import FaissImplConfig


async def get_provider_impl(config: FaissImplConfig, _deps):
    from .faiss import FaissMemoryImpl

    assert isinstance(
        config, FaissImplConfig
    ), f"Unexpected config type: {type(config)}"

    impl = FaissMemoryImpl(config)
    await impl.initialize()
    return impl

Config module changes (the old OllamaImplConfig replaced by a minimal FaissImplConfig):

@@ -5,12 +5,9 @@
 # the root directory of this source tree.

 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel


 @json_schema_type
-class OllamaImplConfig(BaseModel):
-    url: str = Field(
-        default="http://localhost:11434",
-        description="The URL for the ollama server",
-    )
+class FaissImplConfig(BaseModel): ...

llama_toolchain/memory/meta_reference/faiss/faiss.py (new file, 194 lines)
@@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import faiss
import httpx
import numpy as np

from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_toolchain.memory.api import * # noqa: F403
from .config import FaissImplConfig


async def content_from_doc(doc: MemoryBankDocument) -> str:
    if isinstance(doc.content, URL):
        async with httpx.AsyncClient() as client:
            r = await client.get(doc.content.uri)
            return r.text

    return interleaved_text_media_as_str(doc.content)


def make_overlapped_chunks(
    text: str, window_len: int, overlap_len: int
) -> List[Tuple[str, int]]:
    tokenizer = Tokenizer.get_instance()
    tokens = tokenizer.encode(text, bos=False, eos=False)

    chunks = []
    for i in range(0, len(tokens), window_len - overlap_len):
        toks = tokens[i : i + window_len]
        chunk = tokenizer.decode(toks)
        chunks.append((chunk, len(toks)))

    return chunks


@dataclass
class BankState:
    bank: MemoryBank
    index: Optional[faiss.IndexFlatL2] = None
    doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
    id_by_index: Dict[int, str] = field(default_factory=dict)
    chunk_by_index: Dict[int, str] = field(default_factory=dict)

    async def insert_documents(
        self,
        model: "SentenceTransformer",
        documents: List[MemoryBankDocument],
    ) -> None:
        tokenizer = Tokenizer.get_instance()
        chunk_size = self.bank.config.chunk_size_in_tokens

        for doc in documents:
            indexlen = len(self.id_by_index)
            self.doc_by_id[doc.document_id] = doc

            content = await content_from_doc(doc)
            chunks = make_overlapped_chunks(
                content,
                self.bank.config.chunk_size_in_tokens,
                self.bank.config.overlap_size_in_tokens
                or (self.bank.config.chunk_size_in_tokens // 4),
            )
            embeddings = model.encode([x[0] for x in chunks]).astype(np.float32)
            await self._ensure_index(embeddings.shape[1])

            self.index.add(embeddings)
            for i, chunk in enumerate(chunks):
                self.chunk_by_index[indexlen + i] = Chunk(
                    content=chunk[0],
                    token_count=chunk[1],
                    document_id=doc.document_id,
                )
                print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
                self.id_by_index[indexlen + i] = doc.document_id

    async def query_documents(
        self,
        model: "SentenceTransformer",
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        if params is None:
            params = {}
        k = params.get("max_chunks", 3)

        def _process(c) -> str:
            if isinstance(c, str):
                return c
            else:
                return "<media>"

        if isinstance(query, list):
            query_str = " ".join([_process(c) for c in query])
        else:
            query_str = _process(query)

        query_vector = model.encode([query_str])[0]
        distances, indices = self.index.search(
            query_vector.reshape(1, -1).astype(np.float32), k
        )

        chunks = []
        scores = []
        for d, i in zip(distances[0], indices[0]):
            if i < 0:
                continue
            chunks.append(self.chunk_by_index[int(i)])
            scores.append(1.0 / float(d))

        return QueryDocumentsResponse(chunks=chunks, scores=scores)

    async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
        if self.index is None:
            self.index = faiss.IndexFlatL2(dimension)
        return self.index


class FaissMemoryImpl(Memory):
    def __init__(self, config: FaissImplConfig) -> None:
        self.config = config
        self.model = None
        self.states = {}

    async def initialize(self) -> None: ...

    async def shutdown(self) -> None: ...

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        assert url is None, "URL is not supported for this implementation"
        assert (
            config.type == MemoryBankType.vector.value
        ), f"Only vector banks are supported {config.type}"

        bank_id = str(uuid.uuid4())
        bank = MemoryBank(
            bank_id=bank_id,
            name=name,
            config=config,
            url=url,
        )
        state = BankState(bank=bank)
        self.states[bank_id] = state
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        if bank_id not in self.states:
            return None
        return self.states[bank_id].bank

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        assert bank_id in self.states, f"Bank {bank_id} not found"
        state = self.states[bank_id]

        await state.insert_documents(self.get_model(), documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        assert bank_id in self.states, f"Bank {bank_id} not found"
        state = self.states[bank_id]

        return await state.query_documents(self.get_model(), query, params)

    def get_model(self) -> "SentenceTransformer":
        from sentence_transformers import SentenceTransformer

        if self.model is None:
            print("Loading sentence transformer")
            self.model = SentenceTransformer("all-MiniLM-L6-v2")

        return self.model
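Not part of this diff: the implementation above keeps one flat L2 index per bank and reports 1/distance as the relevance score. A standalone sketch of that underlying faiss pattern, with random vectors standing in for sentence-transformer embeddings:

import faiss
import numpy as np

dim = 8
index = faiss.IndexFlatL2(dim)

rng = np.random.default_rng(0)
chunk_embeddings = rng.random((5, dim), dtype=np.float32)
index.add(chunk_embeddings)  # faiss expects float32 arrays of shape (n, dim)

query = rng.random((1, dim), dtype=np.float32)
distances, indices = index.search(query, 3)
for d, i in zip(distances[0], indices[0]):
    if i < 0:  # faiss returns -1 when fewer than k neighbors exist
        continue
    print(f"chunk #{i} score={1.0 / float(d):.3f}")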
llama_toolchain/memory/providers.py (new file, 25 lines)
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec


def available_memory_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.memory,
            provider_id="meta-reference-faiss",
            pip_packages=[
                "blobfile",
                "faiss-cpu",
                "sentence-transformers",
            ],
            module="llama_toolchain.memory.meta_reference.faiss",
            config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
        ),
    ]
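Not part of this diff: a small sketch that lists the registered memory providers and the pip packages each one pulls in, using only the registry function above.

from llama_toolchain.memory.providers import available_memory_providers

for spec in available_memory_providers():
    # Each InlineProviderSpec carries the module to import and its pip dependencies.
    print(f"{spec.provider_id}: {', '.join(spec.pip_packages)}")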
@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .datatypes import * # noqa: F401 F403
-from .endpoints import * # noqa: F401 F403
+from .api import * # noqa: F401 F403
@@ -5,12 +5,79 @@
 # the root directory of this source tree.

 from datetime import datetime
-from typing import Any, Dict, List, Optional, Protocol
+from enum import Enum
+
+from typing import Any, Dict, List, Optional, Protocol, Union

 from llama_models.schema_utils import json_schema_type, webmethod

 from pydantic import BaseModel
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from .datatypes import * # noqa: F403
+
+
+@json_schema_type
+class ExperimentStatus(Enum):
+    NOT_STARTED = "not_started"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+
+@json_schema_type
+class Experiment(BaseModel):
+    id: str
+    name: str
+    status: ExperimentStatus
+    created_at: datetime
+    updated_at: datetime
+    metadata: Dict[str, Any]
+
+
+@json_schema_type
+class Run(BaseModel):
+    id: str
+    experiment_id: str
+    status: str
+    started_at: datetime
+    ended_at: Optional[datetime]
+    metadata: Dict[str, Any]
+
+
+@json_schema_type
+class Metric(BaseModel):
+    name: str
+    value: Union[float, int, str, bool]
+    timestamp: datetime
+    run_id: str
+
+
+@json_schema_type
+class Log(BaseModel):
+    message: str
+    level: str
+    timestamp: datetime
+    additional_info: Dict[str, Any]
+
+
+@json_schema_type
+class ArtifactType(Enum):
+    MODEL = "model"
+    DATASET = "dataset"
+    CHECKPOINT = "checkpoint"
+    PLOT = "plot"
+    METRIC = "metric"
+    CONFIG = "config"
+    CODE = "code"
+    OTHER = "other"
+
+
+@json_schema_type
+class Artifact(BaseModel):
+    id: str
+    name: str
+    type: ArtifactType
+    size: int
+    created_at: datetime
+    metadata: Dict[str, Any]
+
+
 @json_schema_type
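Not part of this diff: a sketch of constructing the observability datatypes added above, assuming the classes are imported from the observability api module; the field values are arbitrary placeholders.

from datetime import datetime

now = datetime.now()
experiment = Experiment(
    id="exp-1",
    name="demo",
    status=ExperimentStatus.RUNNING,
    created_at=now,
    updated_at=now,
    metadata={},
)
metric = Metric(name="loss", value=0.42, timestamp=now, run_id="run-1")
print(experiment.status, metric.value)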
@@ -1,80 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from datetime import datetime
-from enum import Enum
-
-from typing import Any, Dict, Optional, Union
-
-from llama_models.schema_utils import json_schema_type
-
-from pydantic import BaseModel
-
-
-@json_schema_type
-class ExperimentStatus(Enum):
-    NOT_STARTED = "not_started"
-    RUNNING = "running"
-    COMPLETED = "completed"
-    FAILED = "failed"
-
-
-@json_schema_type
-class Experiment(BaseModel):
-    id: str
-    name: str
-    status: ExperimentStatus
-    created_at: datetime
-    updated_at: datetime
-    metadata: Dict[str, Any]
-
-
-@json_schema_type
-class Run(BaseModel):
-    id: str
-    experiment_id: str
-    status: str
-    started_at: datetime
-    ended_at: Optional[datetime]
-    metadata: Dict[str, Any]
-
-
-@json_schema_type
-class Metric(BaseModel):
-    name: str
-    value: Union[float, int, str, bool]
-    timestamp: datetime
-    run_id: str
-
-
-@json_schema_type
-class Log(BaseModel):
-    message: str
-    level: str
-    timestamp: datetime
-    additional_info: Dict[str, Any]
-
-
-@json_schema_type
-class ArtifactType(Enum):
-    MODEL = "model"
-    DATASET = "dataset"
-    CHECKPOINT = "checkpoint"
-    PLOT = "plot"
-    METRIC = "metric"
-    CONFIG = "config"
-    CODE = "code"
-    OTHER = "other"
-
-
-@json_schema_type
-class Artifact(BaseModel):
-    id: str
-    name: str
-    type: ArtifactType
-    size: int
-    created_at: datetime
-    metadata: Dict[str, Any]
@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .datatypes import * # noqa: F401 F403
-from .endpoints import * # noqa: F401 F403
+from .api import * # noqa: F401 F403
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from datetime import datetime
+from enum import Enum

 from typing import Any, Dict, List, Optional, Protocol

@@ -13,9 +14,90 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_toolchain.dataset.api.datatypes import * # noqa: F403
+from llama_toolchain.dataset.api import * # noqa: F403
 from llama_toolchain.common.training_types import * # noqa: F403
-from .datatypes import * # noqa: F403
+
+
+class OptimizerType(Enum):
+    adam = "adam"
+    adamw = "adamw"
+    sgd = "sgd"
+
+
+@json_schema_type
+class OptimizerConfig(BaseModel):
+    optimizer_type: OptimizerType
+    lr: float
+    lr_min: float
+    weight_decay: float
+
+
+@json_schema_type
+class TrainingConfig(BaseModel):
+    n_epochs: int
+    batch_size: int
+    shuffle: bool
+    n_iters: int
+
+    enable_activation_checkpointing: bool
+    memory_efficient_fsdp_wrap: bool
+    fsdp_cpu_offload: bool
+
+
+@json_schema_type
+class FinetuningAlgorithm(Enum):
+    full = "full"
+    lora = "lora"
+    qlora = "qlora"
+    dora = "dora"
+
+
+@json_schema_type
+class LoraFinetuningConfig(BaseModel):
+    lora_attn_modules: List[str]
+    apply_lora_to_mlp: bool
+    apply_lora_to_output: bool
+    rank: int
+    alpha: int
+
+
+@json_schema_type
+class QLoraFinetuningConfig(LoraFinetuningConfig):
+    pass
+
+
+@json_schema_type
+class DoraFinetuningConfig(LoraFinetuningConfig):
+    pass
+
+
+@json_schema_type
+class PostTrainingJobLogStream(BaseModel):
+    """Stream of logs from a finetuning job."""
+
+    job_uuid: str
+    log_lines: List[str]
+
+
+@json_schema_type
+class PostTrainingJobStatus(Enum):
+    running = "running"
+    completed = "completed"
+    failed = "failed"
+    scheduled = "scheduled"
+
+
+@json_schema_type
+class RLHFAlgorithm(Enum):
+    dpo = "dpo"
+
+
+@json_schema_type
+class DPOAlignmentConfig(BaseModel):
+    reward_scale: float
+    reward_clip: float
+    epsilon: float
+    gamma: float
+
+
 @json_schema_type
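Not part of this diff: a sketch showing how the fine-tuning configuration models above fit together; the field values are arbitrary placeholders.

optimizer = OptimizerConfig(
    optimizer_type=OptimizerType.adamw,
    lr=3e-4,
    lr_min=3e-5,
    weight_decay=0.01,
)
training = TrainingConfig(
    n_epochs=1,
    batch_size=4,
    shuffle=True,
    n_iters=100,
    enable_activation_checkpointing=True,
    memory_efficient_fsdp_wrap=True,
    fsdp_cpu_offload=False,
)
lora = LoraFinetuningConfig(
    lora_attn_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
)
print(optimizer.optimizer_type, training.n_epochs, lora.rank)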
@@ -1,94 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from enum import Enum
-from typing import List
-
-from llama_models.schema_utils import json_schema_type
-
-from pydantic import BaseModel
-
-
-class OptimizerType(Enum):
-    adam = "adam"
-    adamw = "adamw"
-    sgd = "sgd"
-
-
-@json_schema_type
-class OptimizerConfig(BaseModel):
-    optimizer_type: OptimizerType
-    lr: float
-    lr_min: float
-    weight_decay: float
-
-
-@json_schema_type
-class TrainingConfig(BaseModel):
-    n_epochs: int
-    batch_size: int
-    shuffle: bool
-    n_iters: int
-
-    enable_activation_checkpointing: bool
-    memory_efficient_fsdp_wrap: bool
-    fsdp_cpu_offload: bool
-
-
-@json_schema_type
-class FinetuningAlgorithm(Enum):
-    full = "full"
-    lora = "lora"
-    qlora = "qlora"
-    dora = "dora"
-
-
-@json_schema_type
-class LoraFinetuningConfig(BaseModel):
-    lora_attn_modules: List[str]
-    apply_lora_to_mlp: bool
-    apply_lora_to_output: bool
-    rank: int
-    alpha: int
-
-
-@json_schema_type
-class QLoraFinetuningConfig(LoraFinetuningConfig):
-    pass
-
-
-@json_schema_type
-class DoraFinetuningConfig(LoraFinetuningConfig):
-    pass
-
-
-@json_schema_type
-class PostTrainingJobLogStream(BaseModel):
-    """Stream of logs from a finetuning job."""
-
-    job_uuid: str
-    log_lines: List[str]
-
-
-@json_schema_type
-class PostTrainingJobStatus(Enum):
-    running = "running"
-    completed = "completed"
-    failed = "failed"
-    scheduled = "scheduled"
-
-
-@json_schema_type
-class RLHFAlgorithm(Enum):
-    dpo = "dpo"
-
-
-@json_schema_type
-class DPOAlignmentConfig(BaseModel):
-    reward_scale: float
-    reward_clip: float
-    epsilon: float
-    gamma: float
@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .datatypes import * # noqa: F401 F403
-from .endpoints import * # noqa: F401 F403
+from .api import * # noqa: F401 F403
@@ -5,9 +5,30 @@
 # the root directory of this source tree.

 from typing import List, Protocol, Union
-from .datatypes import * # noqa: F403

-from llama_models.schema_utils import webmethod
+from llama_models.schema_utils import json_schema_type, webmethod
+
+from pydantic import BaseModel
+
+from llama_models.llama3.api.datatypes import * # noqa: F403
+
+
+@json_schema_type
+class ScoredMessage(BaseModel):
+    message: Message
+    score: float
+
+
+@json_schema_type
+class DialogGenerations(BaseModel):
+    dialog: List[Message]
+    sampled_generations: List[Message]
+
+
+@json_schema_type
+class ScoredDialogGenerations(BaseModel):
+    dialog: List[Message]
+    scored_generations: List[ScoredMessage]
+
+
 @json_schema_type
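Not part of this diff: a rough sketch of building a scored dialog with the types above; it assumes UserMessage from llama_models.llama3.api.datatypes can be constructed with just a content string, which may differ across llama_models versions.

question = UserMessage(content="How do I use LoRA?")  # assumed constructor
candidate = UserMessage(content="Apply the lora finetuning config.")  # stand-in for a sampled generation

scored = ScoredDialogGenerations(
    dialog=[question],
    scored_generations=[ScoredMessage(message=candidate, score=0.87)],
)
print(scored.scored_generations[0].score)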
@@ -1,31 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import List
-
-from llama_models.schema_utils import json_schema_type
-
-from pydantic import BaseModel
-
-from llama_models.llama3.api.datatypes import * # noqa: F403
-
-
-@json_schema_type
-class ScoredMessage(BaseModel):
-    message: Message
-    score: float
-
-
-@json_schema_type
-class DialogGenerations(BaseModel):
-    dialog: List[Message]
-    sampled_generations: List[Message]
-
-
-@json_schema_type
-class ScoredDialogGenerations(BaseModel):
-    dialog: List[Message]
-    scored_generations: List[ScoredMessage]
@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .datatypes import * # noqa
-from .endpoints import * # noqa
+from .api import * # noqa: F401 F403
@@ -5,13 +5,12 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Protocol, Union

-from llama_models.llama3.api.datatypes import ToolParamDefinition
-from llama_models.schema_utils import json_schema_type
+from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, validator

+from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_toolchain.common.deployment_types import RestAPIExecutionConfig

@@ -70,3 +69,22 @@ class ShieldResponse(BaseModel):
         except ValueError:
             return v
         return v
+
+
+@json_schema_type
+class RunShieldRequest(BaseModel):
+    messages: List[Message]
+    shields: List[ShieldDefinition]
+
+
+@json_schema_type
+class RunShieldResponse(BaseModel):
+    responses: List[ShieldResponse]
+
+
+class Safety(Protocol):
+    @webmethod(route="/safety/run_shields")
+    async def run_shields(
+        self,
+        request: RunShieldRequest,
+    ) -> RunShieldResponse: ...
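Not part of this diff: a stub sketch of what satisfying the Safety protocol above looks like; a real provider would evaluate request.messages against each requested shield.

class NoopSafetyImpl:
    async def run_shields(self, request: RunShieldRequest) -> RunShieldResponse:
        # Return an empty response instead of running any shield checks.
        return RunShieldResponse(responses=[])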
@@ -1,32 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .datatypes import * # noqa: F403
-from typing import List, Protocol
-
-from llama_models.llama3.api.datatypes import Message
-
-# this dependency is annoying and we need a forked up version anyway
-from llama_models.schema_utils import webmethod
-
-
-@json_schema_type
-class RunShieldRequest(BaseModel):
-    messages: List[Message]
-    shields: List[ShieldDefinition]
-
-
-@json_schema_type
-class RunShieldResponse(BaseModel):
-    responses: List[ShieldResponse]
-
-
-class Safety(Protocol):
-    @webmethod(route="/safety/run_shields")
-    async def run_shields(
-        self,
-        request: RunShieldRequest,
-    ) -> RunShieldResponse: ...
Some files were not shown because too many files have changed in this diff.