Support for Llama3.2 models and Swift SDK (#98)

This commit is contained in:
Ashwin Bharambe 2024-09-25 10:29:58 -07:00 committed by GitHub
parent 95abbf576b
commit 56aed59eb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
56 changed files with 3745 additions and 630 deletions

9
.gitignore vendored
View file

@ -5,6 +5,11 @@ dist
dev_requirements.txt dev_requirements.txt
build build
.DS_Store .DS_Store
.idea
*.iml
llama_stack/configs/* llama_stack/configs/*
xcuserdata/
*.hmap
.DS_Store
.build/
Package.resolved
*.pte
*.ipynb_checkpoints*

View file

@ -37,50 +37,74 @@ llama model list
You should see a table like this: You should see a table like this:
<pre style="font-family: monospace;"> <pre style="font-family: monospace;">
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Model Descriptor | HuggingFace Repo | Context Length | Hardware Requirements | | Model Descriptor | HuggingFace Repo | Context Length |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Meta-Llama3.1-8B | meta-llama/Meta-Llama-3.1-8B | 128K | 1 GPU, each >= 20GB VRAM | | Llama3.1-8B | meta-llama/Llama-3.1-8B | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Meta-Llama3.1-70B | meta-llama/Meta-Llama-3.1-70B | 128K | 8 GPUs, each >= 20GB VRAM | | Llama3.1-70B | meta-llama/Llama-3.1-70B | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Meta-Llama3.1-405B:bf16-mp8 | | 128K | 8 GPUs, each >= 120GB VRAM | | Llama3.1-405B:bf16-mp8 | meta-llama/Llama-3.1-405B | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Meta-Llama3.1-405B | meta-llama/Meta-Llama-3.1-405B-FP8 | 128K | 8 GPUs, each >= 70GB VRAM | | Llama3.1-405B | meta-llama/Llama-3.1-405B-FP8 | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Meta-Llama3.1-405B:bf16-mp16 | meta-llama/Meta-Llama-3.1-405B | 128K | 16 GPUs, each >= 70GB VRAM | | Llama3.1-405B:bf16-mp16 | meta-llama/Llama-3.1-405B | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Meta-Llama3.1-8B-Instruct | meta-llama/Meta-Llama-3.1-8B-Instruct | 128K | 1 GPU, each >= 20GB VRAM | | Llama3.1-8B-Instruct | meta-llama/Llama-3.1-8B-Instruct | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Meta-Llama3.1-70B-Instruct | meta-llama/Meta-Llama-3.1-70B-Instruct | 128K | 8 GPUs, each >= 20GB VRAM | | Llama3.1-70B-Instruct | meta-llama/Llama-3.1-70B-Instruct | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Meta-Llama3.1-405B-Instruct:bf16-mp8 | | 128K | 8 GPUs, each >= 120GB VRAM | | Llama3.1-405B-Instruct:bf16-mp8 | meta-llama/Llama-3.1-405B-Instruct | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Meta-Llama3.1-405B-Instruct | meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 | 128K | 8 GPUs, each >= 70GB VRAM | | Llama3.1-405B-Instruct | meta-llama/Llama-3.1-405B-Instruct-FP8 | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Meta-Llama3.1-405B-Instruct:bf16-mp16 | meta-llama/Meta-Llama-3.1-405B-Instruct | 128K | 16 GPUs, each >= 70GB VRAM | | Llama3.1-405B-Instruct:bf16-mp16 | meta-llama/Llama-3.1-405B-Instruct | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Llama-Guard-3-8B | meta-llama/Llama-Guard-3-8B | 128K | 1 GPU, each >= 20GB VRAM | | Llama3.2-1B | meta-llama/Llama-3.2-1B | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Llama-Guard-3-8B:int8-mp1 | meta-llama/Llama-Guard-3-8B-INT8 | 128K | 1 GPU, each >= 10GB VRAM | | Llama3.2-3B | meta-llama/Llama-3.2-3B | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Prompt-Guard-86M | meta-llama/Prompt-Guard-86M | 128K | 1 GPU, each >= 1GB VRAM | | Llama3.2-11B-Vision | meta-llama/Llama-3.2-11B-Vision | 128K |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+ +----------------------------------+------------------------------------------+----------------+
| Llama3.2-90B-Vision | meta-llama/Llama-3.2-90B-Vision | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.2-1B-Instruct | meta-llama/Llama-3.2-1B-Instruct | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.2-3B-Instruct | meta-llama/Llama-3.2-3B-Instruct | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.2-11B-Vision-Instruct | meta-llama/Llama-3.2-11B-Vision-Instruct | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.2-90B-Vision-Instruct | meta-llama/Llama-3.2-90B-Vision-Instruct | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama-Guard-3-11B-Vision | meta-llama/Llama-Guard-3-11B-Vision | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama-Guard-3-1B:int4-mp1 | meta-llama/Llama-Guard-3-1B-INT4 | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama-Guard-3-1B | meta-llama/Llama-Guard-3-1B | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama-Guard-3-8B | meta-llama/Llama-Guard-3-8B | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama-Guard-3-8B:int8-mp1 | meta-llama/Llama-Guard-3-8B-INT8 | 128K |
+----------------------------------+------------------------------------------+----------------+
| Prompt-Guard-86M | meta-llama/Prompt-Guard-86M | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama-Guard-2-8B | meta-llama/Llama-Guard-2-8B | 4K |
+----------------------------------+------------------------------------------+----------------+
</pre> </pre>
To download models, you can use the llama download command. To download models, you can use the llama download command.
#### Downloading from [Meta](https://llama.meta.com/llama-downloads/) #### Downloading from [Meta](https://llama.meta.com/llama-downloads/)
Here is an example download command to get the 8B/70B Instruct model. You will need META_URL which can be obtained from [here](https://llama.meta.com/docs/getting_the_models/meta/) Here is an example download command to get the 3B-Instruct/11B-Vision-Instruct model. You will need META_URL which can be obtained from [here](https://llama.meta.com/docs/getting_the_models/meta/)
Download the required checkpoints using the following commands: Download the required checkpoints using the following commands:
```bash ```bash
# download the 8B model, this can be run on a single GPU # download the 8B model, this can be run on a single GPU
llama download --source meta --model-id Meta-Llama3.1-8B-Instruct --meta-url META_URL llama download --source meta --model-id Llama3.2-3B-Instruct --meta-url META_URL
# you can also get the 70B model, this will require 8 GPUs however # you can also get the 70B model, this will require 8 GPUs however
llama download --source meta --model-id Meta-Llama3.1-70B-Instruct --meta-url META_URL llama download --source meta --model-id Llama3.2-11B-Vision-Instruct --meta-url META_URL
# llama-agents have safety enabled by default. For this, you will need # llama-agents have safety enabled by default. For this, you will need
# safety models -- Llama-Guard and Prompt-Guard # safety models -- Llama-Guard and Prompt-Guard
@ -124,7 +148,7 @@ The `llama model` command helps you explore the models interface.
### 2.1 Subcommands ### 2.1 Subcommands
1. `download`: Download the model from different sources. (meta, huggingface) 1. `download`: Download the model from different sources. (meta, huggingface)
2. `list`: Lists all the models available for download with hardware requirements to deploy the models. 2. `list`: Lists all the models available for download with hardware requirements to deploy the models.
3. `template`: <TODO: What is a template?> 3. `prompt-format`: Show llama model message formats.
4. `describe`: Describes all the properties of the model. 4. `describe`: Describes all the properties of the model.
### 2.2 Sample Usage ### 2.2 Sample Usage
@ -135,7 +159,7 @@ The `llama model` command helps you explore the models interface.
llama model --help llama model --help
``` ```
<pre style="font-family: monospace;"> <pre style="font-family: monospace;">
usage: llama model [-h] {download,list,template,describe} ... usage: llama model [-h] {download,list,prompt-format,describe} ...
Work with llama models Work with llama models
@ -143,124 +167,67 @@ options:
-h, --help show this help message and exit -h, --help show this help message and exit
model_subcommands: model_subcommands:
{download,list,template,describe} {download,list,prompt-format,describe}
</pre> </pre>
You can use the describe command to know more about a model: You can use the describe command to know more about a model:
``` ```
llama model describe -m Meta-Llama3.1-8B-Instruct llama model describe -m Llama3.2-3B-Instruct
``` ```
### 2.3 Describe ### 2.3 Describe
<pre style="font-family: monospace;"> <pre style="font-family: monospace;">
+-----------------------------+---------------------------------------+ +-----------------------------+----------------------------------+
| Model | Meta- | | Model | Llama3.2-3B-Instruct |
| | Llama3.1-8B-Instruct | +-----------------------------+----------------------------------+
+-----------------------------+---------------------------------------+ | HuggingFace ID | meta-llama/Llama-3.2-3B-Instruct |
| HuggingFace ID | meta-llama/Meta-Llama-3.1-8B-Instruct | +-----------------------------+----------------------------------+
+-----------------------------+---------------------------------------+ | Description | Llama 3.2 3b instruct model |
| Description | Llama 3.1 8b instruct model | +-----------------------------+----------------------------------+
+-----------------------------+---------------------------------------+
| Context Length | 128K tokens | | Context Length | 128K tokens |
+-----------------------------+---------------------------------------+ +-----------------------------+----------------------------------+
| Weights format | bf16 | | Weights format | bf16 |
+-----------------------------+---------------------------------------+ +-----------------------------+----------------------------------+
| Model params.json | { | | Model params.json | { |
| | "dim": 4096, | | | "dim": 3072, |
| | "n_layers": 32, | | | "n_layers": 28, |
| | "n_heads": 32, | | | "n_heads": 24, |
| | "n_kv_heads": 8, | | | "n_kv_heads": 8, |
| | "vocab_size": 128256, | | | "vocab_size": 128256, |
| | "ffn_dim_multiplier": 1.3, | | | "ffn_dim_multiplier": 1.0, |
| | "multiple_of": 1024, | | | "multiple_of": 256, |
| | "norm_eps": 1e-05, | | | "norm_eps": 1e-05, |
| | "rope_theta": 500000.0, | | | "rope_theta": 500000.0, |
| | "use_scaled_rope": true | | | "use_scaled_rope": true |
| | } | | | } |
+-----------------------------+---------------------------------------+ +-----------------------------+----------------------------------+
| Recommended sampling params | { | | Recommended sampling params | { |
| | "strategy": "top_p", | | | "strategy": "top_p", |
| | "temperature": 1.0, | | | "temperature": 1.0, |
| | "top_p": 0.9, | | | "top_p": 0.9, |
| | "top_k": 0 | | | "top_k": 0 |
| | } | | | } |
+-----------------------------+---------------------------------------+ +-----------------------------+----------------------------------+
</pre> </pre>
### 2.4 Template ### 2.4 Prompt Format
You can even run `llama model template` see all of the templates and their tokens: You can even run `llama model prompt-format` see all of the templates and their tokens:
``` ```
llama model template llama model prompt-format -m Llama3.2-3B-Instruct
``` ```
<p align="center">
<img width="719" alt="image" src="https://github.com/user-attachments/assets/c5332026-8c0b-4edc-b438-ec60cd7ca554">
</p>
<pre style="font-family: monospace;">
+-----------+---------------------------------+
| Role | Template Name |
+-----------+---------------------------------+
| user | user-default |
| assistant | assistant-builtin-tool-call |
| assistant | assistant-custom-tool-call |
| assistant | assistant-default |
| system | system-builtin-and-custom-tools |
| system | system-builtin-tools-only |
| system | system-custom-tools-only |
| system | system-default |
| tool | tool-success |
| tool | tool-failure |
+-----------+---------------------------------+
</pre>
And fetch an example by passing it to `--name`: You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
```
llama model template --name tool-success
```
<pre style="font-family: monospace;">
+----------+----------------------------------------------------------------+
| Name | tool-success |
+----------+----------------------------------------------------------------+
| Template | <|start_header_id|>ipython<|end_header_id|> |
| | |
| | completed |
| | [stdout]{"results":["something |
| | something"]}[/stdout]<|eot_id|> |
| | |
+----------+----------------------------------------------------------------+
| Notes | Note ipython header and [stdout] |
+----------+----------------------------------------------------------------+
</pre>
Or:
```
llama model template --name system-builtin-tools-only
```
<pre style="font-family: monospace;">
+----------+--------------------------------------------+
| Name | system-builtin-tools-only |
+----------+--------------------------------------------+
| Template | <|start_header_id|>system<|end_header_id|> |
| | |
| | Environment: ipython |
| | Tools: brave_search, wolfram_alpha |
| | |
| | Cutting Knowledge Date: December 2023 |
| | Today Date: 21 August 2024 |
| | <|eot_id|> |
| | |
+----------+--------------------------------------------+
| Notes | |
+----------+--------------------------------------------+
</pre>
These commands can help understand the model interface and how prompts / messages are formatted for various scenarios.
**NOTE**: Outputs in terminal are color printed to show special tokens. **NOTE**: Outputs in terminal are color printed to show special tokens.
## Step 3: Building, and Configuring Llama Stack Distributions ## Step 3: Building, and Configuring Llama Stack Distributions
- Please see our [Getting Started](getting_started.md) guide for details. - Please see our [Getting Started](getting_started.md) guide for more details on how to build and start a Llama Stack distribution.
### Step 3.1 Build ### Step 3.1 Build
In the following steps, imagine we'll be working with a `Meta-Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify: In the following steps, imagine we'll be working with a `Meta-Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify:

BIN
docs/dog.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

325
docs/getting_started.ipynb Normal file

File diff suppressed because one or more lines are too long

View file

@ -1,9 +1,70 @@
# llama-stack
[![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/)
[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/TZAAYNVtrU)
This repository contains the specifications and implementations of the APIs which are part of the Llama Stack.
The Llama Stack defines and standardizes the building blocks needed to bring generative AI applications to market. These blocks span the entire development lifecycle: from model training and fine-tuning, through product evaluation, to invoking AI agents in production. Beyond definition, we're developing open-source versions and partnering with cloud providers, ensuring developers can assemble AI solutions using consistent, interlocking pieces across platforms. The ultimate goal is to accelerate innovation in the AI space.
The Stack APIs are rapidly improving, but still very much work in progress and we invite feedback as well as direct contributions.
## APIs
The Llama Stack consists of the following set of APIs:
- Inference
- Safety
- Memory
- Agentic System
- Evaluation
- Post Training
- Synthetic Data Generation
- Reward Scoring
Each of the APIs themselves is a collection of REST endpoints.
## API Providers
A Provider is what makes the API real -- they provide the actual implementation backing the API.
As an example, for Inference, we could have the implementation be backed by open source libraries like `[ torch | vLLM | TensorRT ]` as possible options.
A provider can also be just a pointer to a remote REST service -- for example, cloud providers or dedicated inference providers could serve these APIs.
## Llama Stack Distribution
A Distribution is where APIs and Providers are assembled together to provide a consistent whole to the end application developer. You can mix-and-match providers -- some could be backed by local code and some could be remote. As a hobbyist, you can serve a small model locally, but can choose a cloud provider for a large model. Regardless, the higher level APIs your app needs to work with don't need to change at all. You can even imagine moving across the server / mobile-device boundary as well always using the same uniform set of APIs for developing Generative AI applications.
## Installation
You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack`
If you want to install from source:
```bash
mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git
conda create -n stack python=3.10
conda activate stack
cd llama-stack
$CONDA_PREFIX/bin/pip install -e .
```
# Getting Started # Getting Started
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package. The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
This guides allows you to quickly get started with building and running a Llama Stack server in < 5 minutes! This guides allows you to quickly get started with building and running a Llama Stack server in < 5 minutes!
You may also checkout this [notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb) for trying out out demo scripts.
## Quick Cheatsheet ## Quick Cheatsheet
- Quick 3 line command to build and start a LlamaStack server using our Meta Reference implementation for all API endpoints with `conda` as build type. - Quick 3 line command to build and start a LlamaStack server using our Meta Reference implementation for all API endpoints with `conda` as build type.
@ -12,7 +73,7 @@ This guides allows you to quickly get started with building and running a Llama
``` ```
llama stack build llama stack build
> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-llama-stack > Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-stack
> Enter the image type you want your distribution to be built with (docker or conda): conda > Enter the image type you want your distribution to be built with (docker or conda): conda
Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs. Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.
@ -24,47 +85,57 @@ llama stack build
> (Optional) Enter a short description for your Llama Stack distribution: > (Optional) Enter a short description for your Llama Stack distribution:
Build spec configuration saved at ~/.conda/envs/llamastack-my-local-llama-stack/my-local-llama-stack-build.yaml Build spec configuration saved at ~/.conda/envs/llamastack-my-local-stack/my-local-stack-build.yaml
You can now run `llama stack configure my-local-stack`
``` ```
**`llama stack configure`** **`llama stack configure`**
- Run `llama stack configure <name>` with the name you have previously defined in `build` step. - Run `llama stack configure <name>` with the name you have previously defined in `build` step.
``` ```
llama stack configure my-local-llama-stack llama stack configure <name>
```
- You will be prompted to enter configurations for your Llama Stack
Configuring APIs to serve... ```
Enter comma-separated list of APIs to serve: $ llama stack configure my-local-stack
Could not find my-local-stack. Trying conda build name instead...
Configuring API `inference`... Configuring API `inference`...
=== Configuring provider `meta-reference` for API inference...
Configuring provider `meta-reference`... Enter value for model (default: Llama3.1-8B-Instruct) (required):
Enter value for model (default: Meta-Llama3.1-8B-Instruct) (required):
Do you want to configure quantization? (y/n): n Do you want to configure quantization? (y/n): n
Enter value for torch_seed (optional): Enter value for torch_seed (optional):
Enter value for max_seq_len (required): 4096 Enter value for max_seq_len (default: 4096) (required):
Enter value for max_batch_size (default: 1) (required): Enter value for max_batch_size (default: 1) (required):
Configuring API `safety`...
Configuring provider `meta-reference`... Configuring API `safety`...
=== Configuring provider `meta-reference` for API safety...
Do you want to configure llama_guard_shield? (y/n): n Do you want to configure llama_guard_shield? (y/n): n
Do you want to configure prompt_guard_shield? (y/n): n Do you want to configure prompt_guard_shield? (y/n): n
Configuring API `agents`... Configuring API `agents`...
=== Configuring provider `meta-reference` for API agents...
Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite):
Configuring SqliteKVStoreConfig:
Enter value for namespace (optional):
Enter value for db_path (default: /home/xiyan/.llama/runtime/kvstore.db) (required):
Configuring provider `meta-reference`...
Configuring API `memory`... Configuring API `memory`...
=== Configuring provider `meta-reference` for API memory...
> Please enter the supported memory bank type your provider has for memory: vector
Configuring provider `meta-reference`...
Configuring API `telemetry`... Configuring API `telemetry`...
=== Configuring provider `meta-reference` for API telemetry...
Configuring provider `meta-reference`... > YAML configuration has been written to ~/.llama/builds/conda/my-local-stack-run.yaml.
> YAML configuration has been written to ~/.llama/builds/conda/my-local-llama-stack-run.yaml. You can now run `llama stack run my-local-stack --port PORT`
You can now run `llama stack run my-local-llama-stack --port PORT` or `llama stack run ~/.llama/builds/conda/my-local-llama-stack-run.yaml --port PORT
``` ```
**`llama stack run`** **`llama stack run`**
- Run `llama stack run <name>` with the name you have previously defined. - Run `llama stack run <name>` with the name you have previously defined.
``` ```
llama stack run my-local-llama-stack llama stack run my-local-stack
... ...
> initializing model parallel with size 1 > initializing model parallel with size 1
@ -126,7 +197,7 @@ llama stack build
Running the command above will allow you to fill in the configuration to build your Llama Stack distribution, you will see the following outputs. Running the command above will allow you to fill in the configuration to build your Llama Stack distribution, you will see the following outputs.
``` ```
> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-llama-stack > Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): 8b-instruct
> Enter the image type you want your distribution to be built with (docker or conda): conda > Enter the image type you want your distribution to be built with (docker or conda): conda
Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs. Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.
@ -138,9 +209,14 @@ Running the command above will allow you to fill in the configuration to build y
> (Optional) Enter a short description for your Llama Stack distribution: > (Optional) Enter a short description for your Llama Stack distribution:
Build spec configuration saved at ~/.conda/envs/llamastack-my-local-llama-stack/my-local-llama-stack-build.yaml Build spec configuration saved at ~/.conda/envs/llamastack-my-local-llama-stack/8b-instruct-build.yaml
``` ```
**Ollama (optional)**
If you plan to use Ollama for inference, you'll need to install the server [via these instructions](https://ollama.com/download).
#### Building from templates #### Building from templates
- To build from alternative API providers, we provide distribution templates for users to get started building a distribution backed by different providers. - To build from alternative API providers, we provide distribution templates for users to get started building a distribution backed by different providers.
@ -236,7 +312,7 @@ llama stack configure [ <name> | <docker-image-name> | <path/to/name.build.yaml>
- Run `docker images` to check list of available images on your machine. - Run `docker images` to check list of available images on your machine.
``` ```
$ llama stack configure ~/.llama/distributions/conda/8b-instruct-build.yaml $ llama stack configure 8b-instruct
Configuring API: inference (meta-reference) Configuring API: inference (meta-reference)
Enter value for model (existing: Meta-Llama3.1-8B-Instruct) (required): Enter value for model (existing: Meta-Llama3.1-8B-Instruct) (required):
@ -284,13 +360,13 @@ Note that all configurations as well as models are stored in `~/.llama`
Now, let's start the Llama Stack Distribution Server. You will need the YAML configuration file which was written out at the end by the `llama stack configure` step. Now, let's start the Llama Stack Distribution Server. You will need the YAML configuration file which was written out at the end by the `llama stack configure` step.
``` ```
llama stack run ~/.llama/builds/conda/8b-instruct-run.yaml llama stack run 8b-instruct
``` ```
You should see the Llama Stack server start and print the APIs that it is supporting You should see the Llama Stack server start and print the APIs that it is supporting
``` ```
$ llama stack run ~/.llama/builds/local/conda/8b-instruct.yaml $ llama stack run 8b-instruct
> initializing model parallel with size 1 > initializing model parallel with size 1
> initializing ddp with size 1 > initializing ddp with size 1
@ -357,4 +433,4 @@ Similarly you can test safety (if you configured llama-guard and/or prompt-guard
python -m llama_stack.apis.safety.client localhost 5000 python -m llama_stack.apis.safety.client localhost 5000
``` ```
You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo. You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps) repo.

View file

@ -21,7 +21,7 @@
"info": { "info": {
"title": "[DRAFT] Llama Stack Specification", "title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1", "version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-23 10:56:42.866760" "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
}, },
"servers": [ "servers": [
{ {
@ -2027,10 +2027,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
} }
] ]
@ -2053,6 +2063,35 @@
"tool_calls" "tool_calls"
] ]
}, },
"ImageMedia": {
"type": "object",
"properties": {
"image": {
"oneOf": [
{
"type": "object",
"properties": {
"format": {
"type": "string"
},
"format_description": {
"type": "string"
}
},
"additionalProperties": false,
"title": "This class represents an image object. To create"
},
{
"$ref": "#/components/schemas/URL"
}
]
}
},
"additionalProperties": false,
"required": [
"image"
]
},
"SamplingParams": { "SamplingParams": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -2115,10 +2154,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
} }
] ]
@ -2267,6 +2316,28 @@
"required": { "required": {
"type": "boolean", "type": "boolean",
"default": true "default": true
},
"default": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -2278,7 +2349,8 @@
"type": "string", "type": "string",
"enum": [ "enum": [
"json", "json",
"function_tag" "function_tag",
"python_list"
], ],
"title": "This Enum refers to the prompt format for calling custom / zero shot tools", "title": "This Enum refers to the prompt format for calling custom / zero shot tools",
"description": "`json` --\n Refers to the json format for calling tools.\n The json format takes the form like\n {\n \"type\": \"function\",\n \"function\" : {\n \"name\": \"function_name\",\n \"description\": \"function_description\",\n \"parameters\": {...}\n }\n }\n\n`function_tag` --\n This is an example of how you could define\n your own user defined format for making tool calls.\n The function_tag format looks like this,\n <function=function_name>(parameters)</function>\n\nThe detailed prompts for each of these formats are added to llama cli" "description": "`json` --\n Refers to the json format for calling tools.\n The json format takes the form like\n {\n \"type\": \"function\",\n \"function\" : {\n \"name\": \"function_name\",\n \"description\": \"function_description\",\n \"parameters\": {...}\n }\n }\n\n`function_tag` --\n This is an example of how you could define\n your own user defined format for making tool calls.\n The function_tag format looks like this,\n <function=function_name>(parameters)</function>\n\nThe detailed prompts for each of these formats are added to llama cli"
@ -2309,10 +2381,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
} }
] ]
@ -2326,6 +2408,11 @@
"content" "content"
] ]
}, },
"URL": {
"type": "string",
"format": "uri",
"pattern": "^(https?://|file://|data:)"
},
"UserMessage": { "UserMessage": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -2339,10 +2426,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
} }
] ]
@ -2352,10 +2449,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
} }
] ]
@ -2455,10 +2562,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
} }
] ]
@ -2714,10 +2831,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
} }
] ]
@ -3298,11 +3425,6 @@
"engine" "engine"
] ]
}, },
"URL": {
"type": "string",
"format": "uri",
"pattern": "^(https?://|file://|data:)"
},
"WolframAlphaToolDefinition": { "WolframAlphaToolDefinition": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -3396,10 +3518,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
}, },
{ {
@ -3731,10 +3863,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
} }
] ]
@ -3888,10 +4030,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
} }
] ]
@ -4316,10 +4468,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
} }
] ]
@ -4515,10 +4677,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
}, },
{ {
@ -5407,10 +5579,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
} }
] ]
@ -5460,10 +5642,20 @@
{ {
"type": "string" "type": "string"
}, },
{
"$ref": "#/components/schemas/ImageMedia"
},
{ {
"type": "array", "type": "array",
"items": { "items": {
"oneOf": [
{
"type": "string" "type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
} }
} }
] ]
@ -6027,32 +6219,32 @@
} }
], ],
"tags": [ "tags": [
{
"name": "Inference"
},
{ {
"name": "Shields" "name": "Shields"
}, },
{ {
"name": "Models" "name": "BatchInference"
},
{
"name": "MemoryBanks"
},
{
"name": "SyntheticDataGeneration"
}, },
{ {
"name": "RewardScoring" "name": "RewardScoring"
}, },
{ {
"name": "PostTraining" "name": "SyntheticDataGeneration"
},
{
"name": "Agents"
},
{
"name": "MemoryBanks"
}, },
{ {
"name": "Safety" "name": "Safety"
}, },
{ {
"name": "Evaluations" "name": "Models"
},
{
"name": "Inference"
}, },
{ {
"name": "Memory" "name": "Memory"
@ -6061,14 +6253,14 @@
"name": "Telemetry" "name": "Telemetry"
}, },
{ {
"name": "Agents" "name": "PostTraining"
},
{
"name": "BatchInference"
}, },
{ {
"name": "Datasets" "name": "Datasets"
}, },
{
"name": "Evaluations"
},
{ {
"name": "BuiltinTool", "name": "BuiltinTool",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
@ -6077,6 +6269,10 @@
"name": "CompletionMessage", "name": "CompletionMessage",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/CompletionMessage\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/CompletionMessage\" />"
}, },
{
"name": "ImageMedia",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ImageMedia\" />"
},
{ {
"name": "SamplingParams", "name": "SamplingParams",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/SamplingParams\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/SamplingParams\" />"
@ -6117,6 +6313,10 @@
"name": "ToolResponseMessage", "name": "ToolResponseMessage",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ToolResponseMessage\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/ToolResponseMessage\" />"
}, },
{
"name": "URL",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/URL\" />"
},
{ {
"name": "UserMessage", "name": "UserMessage",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UserMessage\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/UserMessage\" />"
@ -6221,10 +6421,6 @@
"name": "SearchToolDefinition", "name": "SearchToolDefinition",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/SearchToolDefinition\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/SearchToolDefinition\" />"
}, },
{
"name": "URL",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/URL\" />"
},
{ {
"name": "WolframAlphaToolDefinition", "name": "WolframAlphaToolDefinition",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/WolframAlphaToolDefinition\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/WolframAlphaToolDefinition\" />"
@ -6661,6 +6857,7 @@
"FunctionCallToolDefinition", "FunctionCallToolDefinition",
"GetAgentsSessionRequest", "GetAgentsSessionRequest",
"GetDocumentsRequest", "GetDocumentsRequest",
"ImageMedia",
"InferenceStep", "InferenceStep",
"InsertDocumentsRequest", "InsertDocumentsRequest",
"LogEventRequest", "LogEventRequest",

View file

@ -210,8 +210,11 @@ components:
content: content:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
- $ref: '#/components/schemas/URL' - $ref: '#/components/schemas/URL'
mime_type: mime_type:
@ -273,8 +276,11 @@ components:
items: items:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
type: array type: array
logprobs: logprobs:
@ -441,8 +447,11 @@ components:
content: content:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
role: role:
const: assistant const: assistant
@ -466,8 +475,11 @@ components:
content: content:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
logprobs: logprobs:
additionalProperties: false additionalProperties: false
@ -742,8 +754,11 @@ components:
items: items:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
type: array type: array
model: model:
@ -893,6 +908,23 @@ components:
required: required:
- document_ids - document_ids
type: object type: object
ImageMedia:
additionalProperties: false
properties:
image:
oneOf:
- additionalProperties: false
properties:
format:
type: string
format_description:
type: string
title: This class represents an image object. To create
type: object
- $ref: '#/components/schemas/URL'
required:
- image
type: object
InferenceStep: InferenceStep:
additionalProperties: false additionalProperties: false
properties: properties:
@ -1041,8 +1073,11 @@ components:
content: content:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
- $ref: '#/components/schemas/URL' - $ref: '#/components/schemas/URL'
document_id: document_id:
@ -1108,8 +1143,11 @@ components:
inserted_context: inserted_context:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
memory_bank_ids: memory_bank_ids:
items: items:
@ -1545,8 +1583,11 @@ components:
query: query:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
required: required:
- bank_id - bank_id
@ -1562,8 +1603,11 @@ components:
content: content:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
document_id: document_id:
type: string type: string
@ -2067,8 +2111,11 @@ components:
content: content:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
role: role:
const: system const: system
@ -2203,6 +2250,14 @@ components:
ToolParamDefinition: ToolParamDefinition:
additionalProperties: false additionalProperties: false
properties: properties:
default:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: description:
type: string type: string
param_type: param_type:
@ -2225,6 +2280,7 @@ components:
enum: enum:
- json - json
- function_tag - function_tag
- python_list
title: This Enum refers to the prompt format for calling custom / zero shot title: This Enum refers to the prompt format for calling custom / zero shot
tools tools
type: string type: string
@ -2236,8 +2292,11 @@ components:
content: content:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
tool_name: tool_name:
oneOf: oneOf:
@ -2256,8 +2315,11 @@ components:
content: content:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
role: role:
const: ipython const: ipython
@ -2451,14 +2513,20 @@ components:
content: content:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
context: context:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ImageMedia'
- items: - items:
type: string oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array type: array
role: role:
const: user const: user
@ -2501,7 +2569,7 @@ info:
description: "This is the specification of the llama stack that provides\n \ description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\ \ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\ \ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-09-23 10:56:42.866760" \ draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
title: '[DRAFT] Llama Stack Specification' title: '[DRAFT] Llama Stack Specification'
version: 0.0.1 version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@ -3739,25 +3807,27 @@ security:
servers: servers:
- url: http://any-hosted-llama-stack.com - url: http://any-hosted-llama-stack.com
tags: tags:
- name: Inference
- name: Shields - name: Shields
- name: Models - name: BatchInference
- name: MemoryBanks
- name: SyntheticDataGeneration
- name: RewardScoring - name: RewardScoring
- name: PostTraining - name: SyntheticDataGeneration
- name: Agents
- name: MemoryBanks
- name: Safety - name: Safety
- name: Evaluations - name: Models
- name: Inference
- name: Memory - name: Memory
- name: Telemetry - name: Telemetry
- name: Agents - name: PostTraining
- name: BatchInference
- name: Datasets - name: Datasets
- name: Evaluations
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" /> - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage" - description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
/> />
name: CompletionMessage name: CompletionMessage
- description: <SchemaDefinition schemaRef="#/components/schemas/ImageMedia" />
name: ImageMedia
- description: <SchemaDefinition schemaRef="#/components/schemas/SamplingParams" /> - description: <SchemaDefinition schemaRef="#/components/schemas/SamplingParams" />
name: SamplingParams name: SamplingParams
- description: <SchemaDefinition schemaRef="#/components/schemas/SamplingStrategy" - description: <SchemaDefinition schemaRef="#/components/schemas/SamplingStrategy"
@ -3790,6 +3860,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolResponseMessage" - description: <SchemaDefinition schemaRef="#/components/schemas/ToolResponseMessage"
/> />
name: ToolResponseMessage name: ToolResponseMessage
- description: <SchemaDefinition schemaRef="#/components/schemas/URL" />
name: URL
- description: <SchemaDefinition schemaRef="#/components/schemas/UserMessage" /> - description: <SchemaDefinition schemaRef="#/components/schemas/UserMessage" />
name: UserMessage name: UserMessage
- description: <SchemaDefinition schemaRef="#/components/schemas/BatchChatCompletionRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/BatchChatCompletionRequest"
@ -3876,8 +3948,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/SearchToolDefinition" - description: <SchemaDefinition schemaRef="#/components/schemas/SearchToolDefinition"
/> />
name: SearchToolDefinition name: SearchToolDefinition
- description: <SchemaDefinition schemaRef="#/components/schemas/URL" />
name: URL
- description: <SchemaDefinition schemaRef="#/components/schemas/WolframAlphaToolDefinition" - description: <SchemaDefinition schemaRef="#/components/schemas/WolframAlphaToolDefinition"
/> />
name: WolframAlphaToolDefinition name: WolframAlphaToolDefinition
@ -4233,6 +4303,7 @@ x-tagGroups:
- FunctionCallToolDefinition - FunctionCallToolDefinition
- GetAgentsSessionRequest - GetAgentsSessionRequest
- GetDocumentsRequest - GetDocumentsRequest
- ImageMedia
- InferenceStep - InferenceStep
- InsertDocumentsRequest - InsertDocumentsRequest
- LogEventRequest - LogEventRequest

View file

@ -94,14 +94,16 @@ class AgentsClient(Agents):
print(f"Error with parsing or validation: {e}") print(f"Error with parsing or validation: {e}")
async def _run_agent(api, tool_definitions, user_prompts, attachments=None): async def _run_agent(
api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
):
agent_config = AgentConfig( agent_config = AgentConfig(
model="Meta-Llama3.1-8B-Instruct", model=model,
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
sampling_params=SamplingParams(temperature=1.0, top_p=0.9), sampling_params=SamplingParams(temperature=0.6, top_p=0.9),
tools=tool_definitions, tools=tool_definitions,
tool_choice=ToolChoice.auto, tool_choice=ToolChoice.auto,
tool_prompt_format=ToolPromptFormat.function_tag, tool_prompt_format=tool_prompt_format,
enable_session_persistence=False, enable_session_persistence=False,
) )
@ -130,7 +132,8 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
log.print() log.print()
async def run_main(host: str, port: int): async def run_llama_3_1(host: str, port: int):
model = "Llama3.1-8B-Instruct"
api = AgentsClient(f"http://{host}:{port}") api = AgentsClient(f"http://{host}:{port}")
tool_definitions = [ tool_definitions = [
@ -167,10 +170,11 @@ async def run_main(host: str, port: int):
"Write code to check if a number is prime. Use that to check if 7 is prime", "Write code to check if a number is prime. Use that to check if 7 is prime",
"What is the boiling point of polyjuicepotion ?", "What is the boiling point of polyjuicepotion ?",
] ]
await _run_agent(api, tool_definitions, user_prompts) await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
async def run_rag(host: str, port: int): async def run_llama_3_2_rag(host: str, port: int):
model = "Llama3.2-3B-Instruct"
api = AgentsClient(f"http://{host}:{port}") api = AgentsClient(f"http://{host}:{port}")
urls = [ urls = [
@ -206,12 +210,71 @@ async def run_rag(host: str, port: int):
"Tell me briefly about llama3 and torchtune", "Tell me briefly about llama3 and torchtune",
] ]
await _run_agent(api, tool_definitions, user_prompts, attachments) await _run_agent(
api, model, tool_definitions, ToolPromptFormat.json, user_prompts, attachments
)
def main(host: str, port: int, rag: bool = False): async def run_llama_3_2(host: str, port: int):
fn = run_rag if rag else run_main model = "Llama3.2-3B-Instruct"
asyncio.run(fn(host, port)) api = AgentsClient(f"http://{host}:{port}")
# zero shot tools for llama3.2 text models
tool_definitions = [
FunctionCallToolDefinition(
function_name="get_boiling_point",
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
parameters={
"liquid_name": ToolParamDefinition(
param_type="str",
description="The name of the liquid",
required=True,
),
"celcius": ToolParamDefinition(
param_type="bool",
description="Whether to return the boiling point in Celcius",
required=False,
),
},
),
FunctionCallToolDefinition(
function_name="make_web_search",
description="Search the web / internet for more realtime information",
parameters={
"query": ToolParamDefinition(
param_type="str",
description="the query to search for",
required=True,
),
},
),
]
user_prompts = [
"Who are you?",
"what is the 100th prime number?",
"Who was 44th President of USA?",
# multiple tool calls in a single prompt
"What is the boiling point of polyjuicepotion and pinkponklyjuice?",
]
await _run_agent(
api, model, tool_definitions, ToolPromptFormat.python_list, user_prompts
)
def main(host: str, port: int, run_type: str):
assert run_type in [
"tools_llama_3_1",
"tools_llama_3_2",
"rag_llama_3_2",
], f"Invalid run type {run_type}, must be one of tools_llama_3_1, tools_llama_3_2, rag_llama_3_2"
fn = {
"tools_llama_3_1": run_llama_3_1,
"tools_llama_3_2": run_llama_3_2,
"rag_llama_3_2": run_llama_3_2_rag,
}
asyncio.run(fn[run_type](host, port))
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -10,6 +10,10 @@ from typing import Any, AsyncGenerator, List, Optional
import fire import fire
import httpx import httpx
from llama_models.llama3.api.datatypes import ImageMedia, URL
from PIL import Image as PIL_Image
from pydantic import BaseModel from pydantic import BaseModel
from llama_models.llama3.api import * # noqa: F403 from llama_models.llama3.api import * # noqa: F403
@ -105,7 +109,7 @@ async def run_main(host: str, port: int, stream: bool):
) )
cprint(f"User>{message.content}", "green") cprint(f"User>{message.content}", "green")
iterator = client.chat_completion( iterator = client.chat_completion(
model="Meta-Llama3.1-8B-Instruct", model="Llama3.1-8B-Instruct",
messages=[message], messages=[message],
stream=stream, stream=stream,
) )
@ -113,7 +117,33 @@ async def run_main(host: str, port: int, stream: bool):
log.print() log.print()
def main(host: str, port: int, stream: bool = True): async def run_mm_main(host: str, port: int, stream: bool, path: str):
client = InferenceClient(f"http://{host}:{port}")
with open(path, "rb") as f:
img = PIL_Image.open(f).convert("RGB")
message = UserMessage(
content=[
ImageMedia(image=URL(uri=f"file://{path}")),
# ImageMedia(image=img),
"Describe this image in two sentences",
],
)
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
model="Llama3.2-11B-Vision-Instruct",
messages=[message],
stream=stream,
)
async for log in EventLogger().log(iterator):
log.print()
def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
if mm:
asyncio.run(run_mm_main(host, port, stream, file))
else:
asyncio.run(run_main(host, port, stream)) asyncio.run(run_main(host, port, stream))

View file

@ -7,11 +7,11 @@
from typing import List, Optional, Protocol from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.memory import MemoryBankType from llama_stack.apis.memory import MemoryBankType
from llama_stack.distribution.datatypes import GenericProviderConfig from llama_stack.distribution.datatypes import GenericProviderConfig
from pydantic import BaseModel, Field
@json_schema_type @json_schema_type

View file

@ -51,6 +51,11 @@ class SafetyClient(Safety):
), ),
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps(
{
"together_api_key": "1882f9a484fc7c6ce3e4dc90272d5db52346c93838daab3d704803181f396b22"
}
),
}, },
timeout=20, timeout=20,
) )

View file

@ -44,7 +44,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument( parser.add_argument(
"--source", "--source",
choices=["meta", "huggingface"], choices=["meta", "huggingface"],
required=True, default="meta",
) )
parser.add_argument( parser.add_argument(
"--model-id", "--model-id",
@ -116,7 +116,7 @@ def _hf_download(
"You can find your token by visiting https://huggingface.co/settings/tokens" "You can find your token by visiting https://huggingface.co/settings/tokens"
) )
except RepositoryNotFoundError: except RepositoryNotFoundError:
parser.error(f"Repository '{args.repo_id}' not found on the Hugging Face Hub.") parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
except Exception as e: except Exception as e:
parser.error(e) parser.error(e)

View file

@ -9,7 +9,7 @@ import argparse
from llama_stack.cli.model.describe import ModelDescribe from llama_stack.cli.model.describe import ModelDescribe
from llama_stack.cli.model.download import ModelDownload from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.template import ModelTemplate from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
@ -30,5 +30,5 @@ class ModelParser(Subcommand):
# Add sub-commands # Add sub-commands
ModelDownload.create(subparsers) ModelDownload.create(subparsers)
ModelList.create(subparsers) ModelList.create(subparsers)
ModelTemplate.create(subparsers) ModelPromptFormat.create(subparsers)
ModelDescribe.create(subparsers) ModelDescribe.create(subparsers)

View file

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import subprocess
import textwrap
from io import StringIO
from llama_models.datatypes import CoreModelId, is_multimodal, model_family, ModelFamily
from llama_stack.cli.subcommand import Subcommand
class ModelPromptFormat(Subcommand):
"""Llama model cli for describe a model prompt format (message formats)"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"prompt-format",
prog="llama model prompt-format",
description="Show llama model message formats",
epilog=textwrap.dedent(
"""
Example:
llama model prompt-format <options>
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_template_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model-name",
type=str,
default="llama3_1",
help="Model Family (llama3_1, llama3_X, etc.)",
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
import pkg_resources
# Only Llama 3.1 and 3.2 are supported
supported_model_ids = [
m
for m in CoreModelId
if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
]
model_str = "\n".join([m.value for m in supported_model_ids])
try:
model_id = CoreModelId(args.model_name)
except ValueError:
raise argparse.ArgumentTypeError(
f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}"
) from None
if model_id not in supported_model_ids:
raise argparse.ArgumentTypeError(
f"{model_id} is not a valid Model. Choose one from --\n {model_str}"
) from None
llama_3_1_file = pkg_resources.resource_filename(
"llama_models", "llama3_1/prompt_format.md"
)
llama_3_2_text_file = pkg_resources.resource_filename(
"llama_models", "llama3_2/text_prompt_format.md"
)
llama_3_2_vision_file = pkg_resources.resource_filename(
"llama_models", "llama3_2/vision_prompt_format.md"
)
if model_family(model_id) == ModelFamily.llama3_1:
with open(llama_3_1_file, "r") as f:
content = f.read()
elif model_family(model_id) == ModelFamily.llama3_2:
if is_multimodal(model_id):
with open(llama_3_2_vision_file, "r") as f:
content = f.read()
else:
with open(llama_3_2_text_file, "r") as f:
content = f.read()
render_markdown_to_pager(content)
def render_markdown_to_pager(markdown_content: str):
from rich.console import Console
from rich.markdown import Markdown
from rich.style import Style
from rich.text import Text
class LeftAlignedHeaderMarkdown(Markdown):
def parse_header(self, token):
level = token.type.count("h")
content = Text(token.content)
header_style = Style(color="bright_blue", bold=True)
header = Text(f"{'#' * level} ", style=header_style) + content
self.add_text(header)
# Render the Markdown
md = LeftAlignedHeaderMarkdown(markdown_content)
# Capture the rendered output
output = StringIO()
console = Console(file=output, force_terminal=True, width=100) # Set a fixed width
console.print(md)
rendered_content = output.getvalue()
# Pipe to pager
pager = subprocess.Popen(["less", "-R"], stdin=subprocess.PIPE)
pager.communicate(input=rendered_content.encode())

View file

@ -1,113 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from termcolor import colored
from llama_stack.cli.subcommand import Subcommand
class ModelTemplate(Subcommand):
"""Llama model cli for describe a model template (message formats)"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"template",
prog="llama model template",
description="Show llama model message formats",
epilog=textwrap.dedent(
"""
Example:
llama model template <options>
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_template_cmd)
def _prompt_type(self, value):
from llama_models.llama3.api.datatypes import ToolPromptFormat
try:
return ToolPromptFormat(value.lower())
except ValueError:
raise argparse.ArgumentTypeError(
f"{value} is not a valid ToolPromptFormat. Choose from {', '.join(t.value for t in ToolPromptFormat)}"
) from None
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model-family",
type=str,
default="llama3_1",
help="Model Family (llama3_1, llama3_X, etc.)",
)
self.parser.add_argument(
"--name",
type=str,
help="Usecase template name (system_message, user_message, assistant_message, tool_message)...",
required=False,
)
self.parser.add_argument(
"--format",
type=str,
help="ToolPromptFormat (json or function_tag). This flag is used to print the template in a specific formats.",
required=False,
default="json",
)
self.parser.add_argument(
"--raw",
action="store_true",
help="If set to true, don't pretty-print into a table. Useful to copy-paste.",
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
from llama_models.llama3.api.interface import (
list_jinja_templates,
render_jinja_template,
)
from llama_stack.cli.table import print_table
if args.name:
tool_prompt_format = self._prompt_type(args.format)
template, tokens_info = render_jinja_template(args.name, tool_prompt_format)
rendered = ""
for tok, is_special in tokens_info:
if is_special:
rendered += colored(tok, "yellow", attrs=["bold"])
else:
rendered += tok
if not args.raw:
rendered = rendered.replace("\n", "\n")
print_table(
[
(
"Name",
colored(template.template_name, "white", attrs=["bold"]),
),
("Template", rendered),
("Notes", template.notes),
],
separate_rows=True,
)
else:
print("Template: ", template.template_name)
print("=" * 40)
print(rendered)
else:
templates = list_jinja_templates()
headers = ["Role", "Template Name"]
print_table(
[(t.role, t.template_name) for t in templates],
headers,
)

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -8,6 +8,7 @@
DOCKER_BINARY=${DOCKER_BINARY:-docker} DOCKER_BINARY=${DOCKER_BINARY:-docker}
DOCKER_OPTS=${DOCKER_OPTS:-} DOCKER_OPTS=${DOCKER_OPTS:-}
LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
set -euo pipefail set -euo pipefail
@ -37,10 +38,25 @@ port="$1"
shift shift
set -x set -x
$DOCKER_BINARY run $DOCKER_OPTS -it \
if [ -n "$LLAMA_CHECKPOINT_DIR" ]; then
$DOCKER_BINARY run $DOCKER_OPTS -it \
-p $port:$port \
-v "$yaml_config:/app/config.yaml" \
-v "$LLAMA_CHECKPOINT_DIR:/root/.llama" \
--gpus=all \
$docker_image \
python -m llama_stack.distribution.server.server \
--yaml_config /app/config.yaml \
--port $port "$@"
fi
if [ -z "$LLAMA_CHECKPOINT_DIR" ]; then
$DOCKER_BINARY run $DOCKER_OPTS -it \
-p $port:$port \ -p $port:$port \
-v "$yaml_config:/app/config.yaml" \ -v "$yaml_config:/app/config.yaml" \
$docker_image \ $docker_image \
python -m llama_stack.distribution.server.server \ python -m llama_stack.distribution.server.server \
--yaml_config /app/config.yaml \ --yaml_config /app/config.yaml \
--port $port "$@" --port $port "$@"
fi

View file

@ -15,14 +15,16 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import FireworksImplConfig from .config import FireworksImplConfig
FIREWORKS_SUPPORTED_MODELS = { FIREWORKS_SUPPORTED_MODELS = {
"Meta-Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct", "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Meta-Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct", "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
"Meta-Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct", "Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
} }
@ -106,7 +108,7 @@ class FireworksInferenceAdapter(Inference):
logprobs=logprobs, logprobs=logprobs,
) )
messages = prepare_messages(request) messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to fireworks # accumulate sampling params and other options to pass to fireworks
options = self.get_fireworks_chat_options(request) options = self.get_fireworks_chat_options(request)

View file

@ -16,14 +16,16 @@ from llama_models.sku_list import resolve_model
from ollama import AsyncClient from ollama import AsyncClient
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
# TODO: Eventually this will move to the llama cli model list command # TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models # mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = { OLLAMA_SUPPORTED_SKUS = {
# "Meta-Llama3.1-8B-Instruct": "llama3.1", # "Llama3.1-8B-Instruct": "llama3.1",
"Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16", "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16", "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
} }
@ -115,7 +117,7 @@ class OllamaInferenceAdapter(Inference):
logprobs=logprobs, logprobs=logprobs,
) )
messages = prepare_messages(request) messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to ollama # accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request) options = self.get_ollama_chat_options(request)
ollama_model = self.resolve_ollama_model(request.model) ollama_model = self.resolve_ollama_model(request.model)

View file

@ -14,7 +14,9 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import TGIImplConfig from .config import TGIImplConfig
@ -95,7 +97,7 @@ class TGIAdapter(Inference):
logprobs=logprobs, logprobs=logprobs,
) )
messages = prepare_messages(request) messages = augment_messages_for_tools(request)
model_input = self.formatter.encode_dialog_prompt(messages) model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens) prompt = self.tokenizer.decode(model_input.tokens)

View file

@ -15,14 +15,16 @@ from llama_models.sku_list import resolve_model
from together import Together from together import Together
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import TogetherImplConfig from .config import TogetherImplConfig
TOGETHER_SUPPORTED_MODELS = { TOGETHER_SUPPORTED_MODELS = {
"Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct-Turbo",
"Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct-Turbo",
"Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-Turbo",
} }
@ -110,7 +112,7 @@ class TogetherInferenceAdapter(Inference):
# accumulate sampling params and other options to pass to together # accumulate sampling params and other options to pass to together
options = self.get_together_chat_options(request) options = self.get_together_chat_options(request)
together_model = self.resolve_together_model(request.model) together_model = self.resolve_together_model(request.model)
messages = prepare_messages(request) messages = augment_messages_for_tools(request)
if not request.stream: if not request.stream:
# TODO: might need to add back an async here # TODO: might need to add back an async here

View file

@ -0,0 +1,548 @@
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 60;
objects = {
/* Begin PBXBuildFile section */
5C03561F2CA3AB9600E3BB46 /* LlamaStackClient in Frameworks */ = {isa = PBXBuildFile; productRef = 5C03561E2CA3AB9600E3BB46 /* LlamaStackClient */; };
5C5B6E212CA3D89F00AF6130 /* LlamaStackClient in Frameworks */ = {isa = PBXBuildFile; productRef = 5C5B6E202CA3D89F00AF6130 /* LlamaStackClient */; };
5CCBC60C2CA1F04A00E958D0 /* LocalInference.h in Headers */ = {isa = PBXBuildFile; fileRef = 5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */; settings = {ATTRIBUTES = (Public, ); }; };
5CCBC6752CA1F45800E958D0 /* executorch_debug in Frameworks */ = {isa = PBXBuildFile; productRef = 5CCBC6742CA1F45800E958D0 /* executorch_debug */; };
5CCBC6862CA1F64A00E958D0 /* LLaMARunner.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */; platformFilter = ios; };
5CCBC6872CA1F64A00E958D0 /* LLaMARunner.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = 5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */; platformFilter = ios; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
5CCBC68D2CA1F7A100E958D0 /* PromptTemplate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */; };
5CCBC68E2CA1F7A100E958D0 /* LocalInference.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */; };
5CCBC68F2CA1F7A100E958D0 /* Parsing.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */; };
5CCBC6902CA1F7A100E958D0 /* SystemPrompts.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */; };
5CCBC6932CA1F7D000E958D0 /* Stencil in Frameworks */ = {isa = PBXBuildFile; productRef = 5CCBC6922CA1F7D000E958D0 /* Stencil */; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
5CCBC67D2CA1F63F00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 036CAF9D2BB1444500D6C2D5;
remoteInfo = LLaMA;
};
5CCBC67F2CA1F63F00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 03729ED52BB1F8DE00152F2E;
remoteInfo = LLaMARunner;
};
5CCBC69E2CA2036B00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 5CCBC6982CA2036A00E958D0;
remoteInfo = LLaMAPerfBenchmark;
};
5CCBC6A02CA2036B00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 5CCBC6992CA2036A00E958D0;
remoteInfo = LLaMAPerfBenchmarkTests;
};
/* End PBXContainerItemProxy section */
/* Begin PBXCopyFilesBuildPhase section */
5CCBC6882CA1F64A00E958D0 /* Embed Frameworks */ = {
isa = PBXCopyFilesBuildPhase;
buildActionMask = 2147483647;
dstPath = "";
dstSubfolderSpec = 10;
files = (
5CCBC6872CA1F64A00E958D0 /* LLaMARunner.framework in Embed Frameworks */,
);
name = "Embed Frameworks";
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXCopyFilesBuildPhase section */
/* Begin PBXFileReference section */
5CCBC6082CA1F04A00E958D0 /* LocalInference.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = LocalInference.framework; sourceTree = BUILT_PRODUCTS_DIR; };
5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LocalInference.h; sourceTree = "<group>"; };
5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */ = {isa = PBXFileReference; lastKnownFileType = "wrapper.pb-project"; name = LLaMA.xcodeproj; path = "executorch/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj"; sourceTree = "<group>"; };
5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PromptTemplate.swift; sourceTree = "<group>"; };
5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LocalInference.swift; sourceTree = "<group>"; };
5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Parsing.swift; sourceTree = "<group>"; };
5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SystemPrompts.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
5CCBC6052CA1F04A00E958D0 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
5C03561F2CA3AB9600E3BB46 /* LlamaStackClient in Frameworks */,
5C5B6E212CA3D89F00AF6130 /* LlamaStackClient in Frameworks */,
5CCBC6932CA1F7D000E958D0 /* Stencil in Frameworks */,
5CCBC6862CA1F64A00E958D0 /* LLaMARunner.framework in Frameworks */,
5CCBC6752CA1F45800E958D0 /* executorch_debug in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
5CCBC5FE2CA1F04A00E958D0 = {
isa = PBXGroup;
children = (
5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */,
5CCBC60A2CA1F04A00E958D0 /* LocalInference */,
5CCBC6092CA1F04A00E958D0 /* Products */,
5CCBC6852CA1F64A00E958D0 /* Frameworks */,
);
sourceTree = "<group>";
};
5CCBC6092CA1F04A00E958D0 /* Products */ = {
isa = PBXGroup;
children = (
5CCBC6082CA1F04A00E958D0 /* LocalInference.framework */,
);
name = Products;
sourceTree = "<group>";
};
5CCBC60A2CA1F04A00E958D0 /* LocalInference */ = {
isa = PBXGroup;
children = (
5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */,
5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */,
5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */,
5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */,
5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */,
);
path = LocalInference;
sourceTree = "<group>";
};
5CCBC6772CA1F63F00E958D0 /* Products */ = {
isa = PBXGroup;
children = (
5CCBC67E2CA1F63F00E958D0 /* LLaMA.app */,
5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */,
5CCBC69F2CA2036B00E958D0 /* LLaMAPerfBenchmark.app */,
5CCBC6A12CA2036B00E958D0 /* LLaMAPerfBenchmarkTests.xctest */,
);
name = Products;
sourceTree = "<group>";
};
5CCBC6852CA1F64A00E958D0 /* Frameworks */ = {
isa = PBXGroup;
children = (
);
name = Frameworks;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXHeadersBuildPhase section */
5CCBC6032CA1F04A00E958D0 /* Headers */ = {
isa = PBXHeadersBuildPhase;
buildActionMask = 2147483647;
files = (
5CCBC60C2CA1F04A00E958D0 /* LocalInference.h in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXHeadersBuildPhase section */
/* Begin PBXNativeTarget section */
5CCBC6072CA1F04A00E958D0 /* LocalInference */ = {
isa = PBXNativeTarget;
buildConfigurationList = 5CCBC60F2CA1F04A00E958D0 /* Build configuration list for PBXNativeTarget "LocalInference" */;
buildPhases = (
5CCBC6032CA1F04A00E958D0 /* Headers */,
5CCBC6042CA1F04A00E958D0 /* Sources */,
5CCBC6052CA1F04A00E958D0 /* Frameworks */,
5CCBC6062CA1F04A00E958D0 /* Resources */,
5CCBC6882CA1F64A00E958D0 /* Embed Frameworks */,
);
buildRules = (
);
dependencies = (
);
name = LocalInference;
packageProductDependencies = (
5CCBC6742CA1F45800E958D0 /* executorch_debug */,
5CCBC6922CA1F7D000E958D0 /* Stencil */,
5C03561E2CA3AB9600E3BB46 /* LlamaStackClient */,
5C5B6E202CA3D89F00AF6130 /* LlamaStackClient */,
);
productName = LocalInferenceProvider;
productReference = 5CCBC6082CA1F04A00E958D0 /* LocalInference.framework */;
productType = "com.apple.product-type.framework";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
5CCBC5FF2CA1F04A00E958D0 /* Project object */ = {
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastUpgradeCheck = 1540;
TargetAttributes = {
5CCBC6072CA1F04A00E958D0 = {
CreatedOnToolsVersion = 15.4;
LastSwiftMigration = 1540;
};
};
};
buildConfigurationList = 5CCBC6022CA1F04A00E958D0 /* Build configuration list for PBXProject "LocalInference" */;
compatibilityVersion = "Xcode 14.0";
developmentRegion = en;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = 5CCBC5FE2CA1F04A00E958D0;
packageReferences = (
5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */,
5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */,
5C5B6E1F2CA3D89F00AF6130 /* XCLocalSwiftPackageReference "internal-llama-stack-client-swift" */,
);
productRefGroup = 5CCBC6092CA1F04A00E958D0 /* Products */;
projectDirPath = "";
projectReferences = (
{
ProductGroup = 5CCBC6772CA1F63F00E958D0 /* Products */;
ProjectRef = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
},
);
projectRoot = "";
targets = (
5CCBC6072CA1F04A00E958D0 /* LocalInference */,
);
};
/* End PBXProject section */
/* Begin PBXReferenceProxy section */
5CCBC67E2CA1F63F00E958D0 /* LLaMA.app */ = {
isa = PBXReferenceProxy;
fileType = wrapper.application;
path = LLaMA.app;
remoteRef = 5CCBC67D2CA1F63F00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */ = {
isa = PBXReferenceProxy;
fileType = wrapper.framework;
path = LLaMARunner.framework;
remoteRef = 5CCBC67F2CA1F63F00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC69F2CA2036B00E958D0 /* LLaMAPerfBenchmark.app */ = {
isa = PBXReferenceProxy;
fileType = wrapper.application;
path = LLaMAPerfBenchmark.app;
remoteRef = 5CCBC69E2CA2036B00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC6A12CA2036B00E958D0 /* LLaMAPerfBenchmarkTests.xctest */ = {
isa = PBXReferenceProxy;
fileType = wrapper.cfbundle;
path = LLaMAPerfBenchmarkTests.xctest;
remoteRef = 5CCBC6A02CA2036B00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
/* End PBXReferenceProxy section */
/* Begin PBXResourcesBuildPhase section */
5CCBC6062CA1F04A00E958D0 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
5CCBC6042CA1F04A00E958D0 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
5CCBC6902CA1F7A100E958D0 /* SystemPrompts.swift in Sources */,
5CCBC68D2CA1F7A100E958D0 /* PromptTemplate.swift in Sources */,
5CCBC68F2CA1F7A100E958D0 /* Parsing.swift in Sources */,
5CCBC68E2CA1F7A100E958D0 /* LocalInference.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin XCBuildConfiguration section */
5CCBC60D2CA1F04A00E958D0 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
};
name = Debug;
};
5CCBC60E2CA1F04A00E958D0 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
SDKROOT = iphoneos;
SWIFT_COMPILATION_MODE = wholemodule;
VALIDATE_PRODUCT = YES;
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
};
name = Release;
};
5CCBC6102CA1F04A00E958D0 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = "";
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
OTHER_LDFLAGS = "";
PRODUCT_BUNDLE_IDENTIFIER = meta.llamatsack.LocalInferenceProvider;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Debug;
};
5CCBC6112CA1F04A00E958D0 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = "";
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
OTHER_LDFLAGS = "";
PRODUCT_BUNDLE_IDENTIFIER = meta.llamatsack.LocalInferenceProvider;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
5CCBC6022CA1F04A00E958D0 /* Build configuration list for PBXProject "LocalInference" */ = {
isa = XCConfigurationList;
buildConfigurations = (
5CCBC60D2CA1F04A00E958D0 /* Debug */,
5CCBC60E2CA1F04A00E958D0 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
5CCBC60F2CA1F04A00E958D0 /* Build configuration list for PBXNativeTarget "LocalInference" */ = {
isa = XCConfigurationList;
buildConfigurations = (
5CCBC6102CA1F04A00E958D0 /* Debug */,
5CCBC6112CA1F04A00E958D0 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
/* Begin XCLocalSwiftPackageReference section */
5C5B6E1F2CA3D89F00AF6130 /* XCLocalSwiftPackageReference "internal-llama-stack-client-swift" */ = {
isa = XCLocalSwiftPackageReference;
relativePath = "internal-llama-stack-client-swift";
};
/* End XCLocalSwiftPackageReference section */
/* Begin XCRemoteSwiftPackageReference section */
5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/pytorch/executorch";
requirement = {
branch = latest;
kind = branch;
};
};
5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/stencilproject/Stencil";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 0.15.1;
};
};
/* End XCRemoteSwiftPackageReference section */
/* Begin XCSwiftPackageProductDependency section */
5C03561E2CA3AB9600E3BB46 /* LlamaStackClient */ = {
isa = XCSwiftPackageProductDependency;
productName = LlamaStackClient;
};
5C5B6E202CA3D89F00AF6130 /* LlamaStackClient */ = {
isa = XCSwiftPackageProductDependency;
productName = LlamaStackClient;
};
5CCBC6742CA1F45800E958D0 /* executorch_debug */ = {
isa = XCSwiftPackageProductDependency;
package = 5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */;
productName = executorch_debug;
};
5CCBC6922CA1F7D000E958D0 /* Stencil */ = {
isa = XCSwiftPackageProductDependency;
package = 5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */;
productName = Stencil;
};
/* End XCSwiftPackageProductDependency section */
};
rootObject = 5CCBC5FF2CA1F04A00E958D0 /* Project object */;
}

View file

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
version = "1.0">
<FileRef
location = "self:">
</FileRef>
</Workspace>

View file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>IDEDidComputeMac32BitWarning</key>
<true/>
</dict>
</plist>

View file

@ -0,0 +1,16 @@
//
// LocalInference.h
// LocalInference
//
// Created by Dalton Flanagan on 9/23/24.
//
#import <Foundation/Foundation.h>
//! Project version number for LocalInference.
FOUNDATION_EXPORT double LocalInferenceVersionNumber;
//! Project version string for LocalInference.
FOUNDATION_EXPORT const unsigned char LocalInferenceVersionString[];
// In this header, you should import all the public headers of your framework using statements like #import <LocalInference/PublicHeader.h>

View file

@ -0,0 +1,167 @@
import Foundation
import LLaMARunner
import LlamaStackClient
class RunnerHolder: ObservableObject {
var runner: Runner?
}
public class LocalInference: Inference {
private var runnerHolder = RunnerHolder()
private let runnerQueue: DispatchQueue
public init (queue: DispatchQueue) {
runnerQueue = queue
}
public func loadModel(modelPath: String, tokenizerPath: String, completion: @escaping (Result<Void, Error>) -> Void) {
runnerHolder.runner = runnerHolder.runner ?? Runner(
modelPath: modelPath,
tokenizerPath: tokenizerPath
)
runnerQueue.async {
let runner = self.runnerHolder.runner
do {
try runner!.load()
completion(.success(()))
} catch let loadError {
print("error: " + loadError.localizedDescription)
completion(.failure(loadError))
}
}
}
public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
return AsyncStream { continuation in
runnerQueue.async {
do {
var tokens: [String] = []
let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
var stopReason: Components.Schemas.StopReason? = nil
var buffer = ""
var ipython = false
var echoDropped = false
try self.runnerHolder.runner?.generate(prompt, sequenceLength: 4096) { token in
buffer += token
// HACK: Workaround until LlamaRunner exposes echo param
if (!echoDropped) {
if (buffer.hasPrefix(prompt)) {
buffer = String(buffer.dropFirst(prompt.count))
echoDropped = true
}
return
}
tokens.append(token)
if !ipython && (buffer.starts(with: "<|python_tag|>") || buffer.starts(with: "[") ) {
ipython = true
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .case1(""),
parse_status: Components.Schemas.ToolCallParseStatus.started
)
),
event_type: .progress
)
)
)
if (buffer.starts(with: "<|python_tag|>")) {
buffer = String(buffer.dropFirst("<|python_tag|>".count))
}
}
// TODO: Non-streaming lobprobs
var text = ""
if token == "<|eot_id|>" {
stopReason = Components.Schemas.StopReason.end_of_turn
} else if token == "<|eom_id|>" {
stopReason = Components.Schemas.StopReason.end_of_message
} else {
text = token
}
var delta: Components.Schemas.ChatCompletionResponseEvent.deltaPayload
if ipython {
delta = .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .case1(text),
parse_status: .in_progress
))
} else {
delta = .case1(text)
}
if stopReason == nil {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: delta,
event_type: .progress
)
)
)
}
}
if stopReason == nil {
stopReason = Components.Schemas.StopReason.out_of_tokens
}
let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: stopReason!)
// TODO: non-streaming support
let didParseToolCalls = message.tool_calls.count > 0
if ipython && !didParseToolCalls {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(content: .case1(""), parse_status: .failure)),
event_type: .progress
)
// TODO: stopReason
)
)
}
for toolCall in message.tool_calls {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .ToolCall(toolCall),
parse_status: .success
)),
event_type: .progress
)
// TODO: stopReason
)
)
}
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .case1(""),
event_type: .complete
)
// TODO: stopReason
)
)
}
catch (let error) {
print("Inference error: " + error.localizedDescription)
}
}
}
}
}

View file

@ -0,0 +1,235 @@
import Foundation
import LlamaStackClient
func encodeHeader(role: String) -> String {
return "<|start_header_id|>\(role)<|end_header_id|>\n\n"
}
func encodeDialogPrompt(messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload]) -> String {
var prompt = ""
prompt.append("<|begin_of_text|>")
for message in messages {
let msg = encodeMessage(message: message)
prompt += msg
}
prompt.append(encodeHeader(role: "assistant"))
return prompt
}
func getRole(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
switch (message) {
case .UserMessage(let m):
return m.role.rawValue
case .SystemMessage(let m):
return m.role.rawValue
case .ToolResponseMessage(let m):
return m.role.rawValue
case .CompletionMessage(let m):
return m.role.rawValue
}
}
func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
var prompt = encodeHeader(role: getRole(message: message))
switch (message) {
case .CompletionMessage(let m):
if (m.tool_calls.count > 0) {
prompt += "<|python_tag|>"
}
default:
break
}
func _processContent(_ content: Any) -> String {
func _process(_ c: Any) {
if let str = c as? String {
prompt += str
}
}
if let str = content as? String {
_process(str)
} else if let list = content as? [Any] {
for c in list {
_process(c)
}
}
return ""
}
switch (message) {
case .UserMessage(let m):
prompt += _processContent(m.content)
case .SystemMessage(let m):
prompt += _processContent(m.content)
case .ToolResponseMessage(let m):
prompt += _processContent(m.content)
case .CompletionMessage(let m):
prompt += _processContent(m.content)
}
var eom = false
switch (message) {
case .UserMessage(let m):
switch (m.content) {
case .case1(let c):
prompt += _processContent(c)
case .case2(let c):
prompt += _processContent(c)
}
case .CompletionMessage(let m):
// TODO: Support encoding past tool call history
// for t in m.tool_calls {
// _processContent(t.)
//}
eom = m.stop_reason == Components.Schemas.StopReason.end_of_message
case .SystemMessage(_):
break
case .ToolResponseMessage(_):
break
}
if (eom) {
prompt += "<|eom_id|>"
} else {
prompt += "<|eot_id|>"
}
return prompt
}
func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] {
var existingMessages = request.messages
var existingSystemMessage: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload?
// TODO: Existing system message
var messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] = []
let defaultGen = SystemDefaultGenerator()
let defaultTemplate = defaultGen.gen()
var sysContent = ""
// TODO: Built-in tools
sysContent += try defaultTemplate.render()
messages.append(.SystemMessage(Components.Schemas.SystemMessage(
content: .case1(sysContent),
role: .system))
)
if request.tools?.isEmpty == false {
// TODO: Separate built-ins and custom tools (right now everything treated as custom)
let toolGen = FunctionTagCustomToolGenerator()
let toolTemplate = try toolGen.gen(customTools: request.tools!)
let tools = try toolTemplate.render()
messages.append(.UserMessage(Components.Schemas.UserMessage(
content: .case1(tools),
role: .user)
))
}
messages.append(contentsOf: existingMessages)
return messages
}
struct FunctionCall {
let name: String
let params: [String: Any]
}
public func maybeExtractCustomToolCalls(input: String) -> [Components.Schemas.ToolCall] {
guard input.hasPrefix("[") && input.hasSuffix("]") else {
return []
}
do {
let trimmed = input.trimmingCharacters(in: CharacterSet(charactersIn: "[]"))
let calls = trimmed.components(separatedBy: "),").map { $0.hasSuffix(")") ? $0 : $0 + ")" }
var result: [Components.Schemas.ToolCall] = []
for call in calls {
guard let nameEndIndex = call.firstIndex(of: "("),
let paramsStartIndex = call.firstIndex(of: "{"),
let paramsEndIndex = call.lastIndex(of: "}") else {
return []
}
let name = String(call[..<nameEndIndex]).trimmingCharacters(in: .whitespacesAndNewlines)
let paramsString = String(call[paramsStartIndex...paramsEndIndex])
guard let data = paramsString.data(using: .utf8),
let params = try? JSONSerialization.jsonObject(with: data, options: []) as? [String: Any] else {
return []
}
var props: [String : Components.Schemas.ToolCall.argumentsPayload.additionalPropertiesPayload] = [:]
for (param_name, param) in params {
switch (param) {
case let value as String:
props[param_name] = .case1(value)
case let value as Int:
props[param_name] = .case2(value)
case let value as Double:
props[param_name] = .case3(value)
case let value as Bool:
props[param_name] = .case4(value)
default:
return []
}
}
result.append(
Components.Schemas.ToolCall(
arguments: .init(additionalProperties: props),
call_id: UUID().uuidString,
tool_name: .case2(name) // custom_tool
)
)
}
return result.isEmpty ? [] : result
} catch {
return []
}
}
func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.StopReason) -> Components.Schemas.CompletionMessage {
var content = tokens
let roles = ["user", "system", "assistant"]
for role in roles {
let headerStr = encodeHeader(role: role)
if content.hasPrefix(headerStr) {
content = String(content.dropFirst(encodeHeader(role: role).count))
}
}
if content.hasPrefix("<|python_tag|>") {
content = String(content.dropFirst("<|python_tag|>".count))
}
if content.hasSuffix("<|eot_id|>") {
content = String(content.dropLast("<|eot_id|>".count))
} else {
content = String(content.dropLast("<|eom_id|>".count))
}
return Components.Schemas.CompletionMessage(
content: .case1(content),
role: .assistant,
stop_reason: stopReason,
tool_calls: maybeExtractCustomToolCalls(input: content)
)
}

View file

@ -0,0 +1,12 @@
import Foundation
import Stencil
public struct PromptTemplate {
let template: String
let data: [String: Any]
public func render() throws -> String {
let template = Template(templateString: self.template)
return try template.render(self.data)
}
}

View file

@ -0,0 +1,91 @@
import Foundation
import LlamaStackClient
func convertToNativeSwiftType(_ value: Any) -> Any {
switch value {
case let number as NSNumber:
if CFGetTypeID(number) == CFBooleanGetTypeID() {
return number.boolValue
}
if floor(number.doubleValue) == number.doubleValue {
return number.intValue
}
return number.doubleValue
case let string as String:
return string
case let array as [Any]:
return array.map(convertToNativeSwiftType)
case let dict as [String: Any]:
return dict.mapValues(convertToNativeSwiftType)
case is NSNull:
return NSNull()
default:
return value
}
}
public class SystemDefaultGenerator {
public init() {}
public func gen() -> PromptTemplate {
let templateStr = """
Cutting Knowledge Date: December 2023
Today Date: {{ today }}
"""
let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "dd MMMM yyyy"
return PromptTemplate(
template: templateStr,
data: ["today": dateFormatter.string(from: Date())]
)
}
}
public class FunctionTagCustomToolGenerator {
public init() {}
public func gen(customTools: [Components.Schemas.ToolDefinition]) throws -> PromptTemplate {
// TODO: required params
// TODO: {{#unless @last}},{{/unless}}
let templateStr = """
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{% for t in custom_tools %}
{
"name": "{{t.tool_name}}",
"description": "{{t.description}}",
"parameters": {
"type": "dict",
"properties": { {{t.parameters}} }
}
{{/let}}
{% endfor -%}
]
"""
let encoder = JSONEncoder()
return PromptTemplate(
template: templateStr,
data: ["custom_tools": try customTools.map {
let data = try encoder.encode($0)
let obj = try JSONSerialization.jsonObject(with: data)
return convertToNativeSwiftType(obj)
}]
)
}
}

View file

@ -0,0 +1,541 @@
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 60;
objects = {
/* Begin PBXBuildFile section */
5CADC71A2CA471CC007662D2 /* LlamaStackClient in Frameworks */ = {isa = PBXBuildFile; productRef = 5CADC7192CA471CC007662D2 /* LlamaStackClient */; };
5CCBC60C2CA1F04A00E958D0 /* LocalInference.h in Headers */ = {isa = PBXBuildFile; fileRef = 5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */; settings = {ATTRIBUTES = (Public, ); }; };
5CCBC6752CA1F45800E958D0 /* executorch_debug in Frameworks */ = {isa = PBXBuildFile; productRef = 5CCBC6742CA1F45800E958D0 /* executorch_debug */; };
5CCBC6862CA1F64A00E958D0 /* LLaMARunner.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */; platformFilter = ios; };
5CCBC6872CA1F64A00E958D0 /* LLaMARunner.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = 5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */; platformFilter = ios; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
5CCBC68D2CA1F7A100E958D0 /* PromptTemplate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */; };
5CCBC68E2CA1F7A100E958D0 /* LocalInference.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */; };
5CCBC68F2CA1F7A100E958D0 /* Parsing.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */; };
5CCBC6902CA1F7A100E958D0 /* SystemPrompts.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */; };
5CCBC6932CA1F7D000E958D0 /* Stencil in Frameworks */ = {isa = PBXBuildFile; productRef = 5CCBC6922CA1F7D000E958D0 /* Stencil */; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
5CCBC67D2CA1F63F00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 036CAF9D2BB1444500D6C2D5;
remoteInfo = LLaMA;
};
5CCBC67F2CA1F63F00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 03729ED52BB1F8DE00152F2E;
remoteInfo = LLaMARunner;
};
5CCBC69E2CA2036B00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 5CCBC6982CA2036A00E958D0;
remoteInfo = LLaMAPerfBenchmark;
};
5CCBC6A02CA2036B00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 5CCBC6992CA2036A00E958D0;
remoteInfo = LLaMAPerfBenchmarkTests;
};
/* End PBXContainerItemProxy section */
/* Begin PBXCopyFilesBuildPhase section */
5CCBC6882CA1F64A00E958D0 /* Embed Frameworks */ = {
isa = PBXCopyFilesBuildPhase;
buildActionMask = 2147483647;
dstPath = "";
dstSubfolderSpec = 10;
files = (
5CCBC6872CA1F64A00E958D0 /* LLaMARunner.framework in Embed Frameworks */,
);
name = "Embed Frameworks";
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXCopyFilesBuildPhase section */
/* Begin PBXFileReference section */
5CCBC6082CA1F04A00E958D0 /* LocalInferenceImpl.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = LocalInferenceImpl.framework; sourceTree = BUILT_PRODUCTS_DIR; };
5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LocalInference.h; sourceTree = "<group>"; };
5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */ = {isa = PBXFileReference; lastKnownFileType = "wrapper.pb-project"; name = LLaMA.xcodeproj; path = "executorch/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj"; sourceTree = "<group>"; };
5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PromptTemplate.swift; sourceTree = "<group>"; };
5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LocalInference.swift; sourceTree = "<group>"; };
5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Parsing.swift; sourceTree = "<group>"; };
5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SystemPrompts.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
5CCBC6052CA1F04A00E958D0 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
5CADC71A2CA471CC007662D2 /* LlamaStackClient in Frameworks */,
5CCBC6932CA1F7D000E958D0 /* Stencil in Frameworks */,
5CCBC6862CA1F64A00E958D0 /* LLaMARunner.framework in Frameworks */,
5CCBC6752CA1F45800E958D0 /* executorch_debug in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
5CCBC5FE2CA1F04A00E958D0 = {
isa = PBXGroup;
children = (
5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */,
5CCBC60A2CA1F04A00E958D0 /* LocalInferenceImpl */,
5CCBC6092CA1F04A00E958D0 /* Products */,
5CCBC6852CA1F64A00E958D0 /* Frameworks */,
);
sourceTree = "<group>";
};
5CCBC6092CA1F04A00E958D0 /* Products */ = {
isa = PBXGroup;
children = (
5CCBC6082CA1F04A00E958D0 /* LocalInferenceImpl.framework */,
);
name = Products;
sourceTree = "<group>";
};
5CCBC60A2CA1F04A00E958D0 /* LocalInferenceImpl */ = {
isa = PBXGroup;
children = (
5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */,
5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */,
5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */,
5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */,
5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */,
);
path = LocalInferenceImpl;
sourceTree = "<group>";
};
5CCBC6772CA1F63F00E958D0 /* Products */ = {
isa = PBXGroup;
children = (
5CCBC67E2CA1F63F00E958D0 /* LLaMA.app */,
5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */,
5CCBC69F2CA2036B00E958D0 /* LLaMAPerfBenchmark.app */,
5CCBC6A12CA2036B00E958D0 /* LLaMAPerfBenchmarkTests.xctest */,
);
name = Products;
sourceTree = "<group>";
};
5CCBC6852CA1F64A00E958D0 /* Frameworks */ = {
isa = PBXGroup;
children = (
);
name = Frameworks;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXHeadersBuildPhase section */
5CCBC6032CA1F04A00E958D0 /* Headers */ = {
isa = PBXHeadersBuildPhase;
buildActionMask = 2147483647;
files = (
5CCBC60C2CA1F04A00E958D0 /* LocalInference.h in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXHeadersBuildPhase section */
/* Begin PBXNativeTarget section */
5CCBC6072CA1F04A00E958D0 /* LocalInferenceImpl */ = {
isa = PBXNativeTarget;
buildConfigurationList = 5CCBC60F2CA1F04A00E958D0 /* Build configuration list for PBXNativeTarget "LocalInferenceImpl" */;
buildPhases = (
5CCBC6032CA1F04A00E958D0 /* Headers */,
5CCBC6042CA1F04A00E958D0 /* Sources */,
5CCBC6052CA1F04A00E958D0 /* Frameworks */,
5CCBC6062CA1F04A00E958D0 /* Resources */,
5CCBC6882CA1F64A00E958D0 /* Embed Frameworks */,
);
buildRules = (
);
dependencies = (
);
name = LocalInferenceImpl;
packageProductDependencies = (
5CCBC6742CA1F45800E958D0 /* executorch_debug */,
5CCBC6922CA1F7D000E958D0 /* Stencil */,
5CADC7192CA471CC007662D2 /* LlamaStackClient */,
);
productName = LocalInferenceProvider;
productReference = 5CCBC6082CA1F04A00E958D0 /* LocalInferenceImpl.framework */;
productType = "com.apple.product-type.framework";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
5CCBC5FF2CA1F04A00E958D0 /* Project object */ = {
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastUpgradeCheck = 1540;
TargetAttributes = {
5CCBC6072CA1F04A00E958D0 = {
CreatedOnToolsVersion = 15.4;
LastSwiftMigration = 1540;
};
};
};
buildConfigurationList = 5CCBC6022CA1F04A00E958D0 /* Build configuration list for PBXProject "LocalInferenceImpl" */;
compatibilityVersion = "Xcode 14.0";
developmentRegion = en;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = 5CCBC5FE2CA1F04A00E958D0;
packageReferences = (
5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */,
5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */,
5CADC7182CA471CC007662D2 /* XCLocalSwiftPackageReference "internal-llama-stack-client-swift" */,
);
productRefGroup = 5CCBC6092CA1F04A00E958D0 /* Products */;
projectDirPath = "";
projectReferences = (
{
ProductGroup = 5CCBC6772CA1F63F00E958D0 /* Products */;
ProjectRef = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
},
);
projectRoot = "";
targets = (
5CCBC6072CA1F04A00E958D0 /* LocalInferenceImpl */,
);
};
/* End PBXProject section */
/* Begin PBXReferenceProxy section */
5CCBC67E2CA1F63F00E958D0 /* LLaMA.app */ = {
isa = PBXReferenceProxy;
fileType = wrapper.application;
path = LLaMA.app;
remoteRef = 5CCBC67D2CA1F63F00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */ = {
isa = PBXReferenceProxy;
fileType = wrapper.framework;
path = LLaMARunner.framework;
remoteRef = 5CCBC67F2CA1F63F00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC69F2CA2036B00E958D0 /* LLaMAPerfBenchmark.app */ = {
isa = PBXReferenceProxy;
fileType = wrapper.application;
path = LLaMAPerfBenchmark.app;
remoteRef = 5CCBC69E2CA2036B00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC6A12CA2036B00E958D0 /* LLaMAPerfBenchmarkTests.xctest */ = {
isa = PBXReferenceProxy;
fileType = wrapper.cfbundle;
path = LLaMAPerfBenchmarkTests.xctest;
remoteRef = 5CCBC6A02CA2036B00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
/* End PBXReferenceProxy section */
/* Begin PBXResourcesBuildPhase section */
5CCBC6062CA1F04A00E958D0 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
5CCBC6042CA1F04A00E958D0 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
5CCBC6902CA1F7A100E958D0 /* SystemPrompts.swift in Sources */,
5CCBC68D2CA1F7A100E958D0 /* PromptTemplate.swift in Sources */,
5CCBC68F2CA1F7A100E958D0 /* Parsing.swift in Sources */,
5CCBC68E2CA1F7A100E958D0 /* LocalInference.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin XCBuildConfiguration section */
5CCBC60D2CA1F04A00E958D0 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
};
name = Debug;
};
5CCBC60E2CA1F04A00E958D0 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
SDKROOT = iphoneos;
SWIFT_COMPILATION_MODE = wholemodule;
VALIDATE_PRODUCT = YES;
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
};
name = Release;
};
5CCBC6102CA1F04A00E958D0 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = "";
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
OTHER_LDFLAGS = "";
PRODUCT_BUNDLE_IDENTIFIER = meta.llamatsack.LocalInferenceProvider;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Debug;
};
5CCBC6112CA1F04A00E958D0 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = "";
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
OTHER_LDFLAGS = "";
PRODUCT_BUNDLE_IDENTIFIER = meta.llamatsack.LocalInferenceProvider;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
5CCBC6022CA1F04A00E958D0 /* Build configuration list for PBXProject "LocalInferenceImpl" */ = {
isa = XCConfigurationList;
buildConfigurations = (
5CCBC60D2CA1F04A00E958D0 /* Debug */,
5CCBC60E2CA1F04A00E958D0 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
5CCBC60F2CA1F04A00E958D0 /* Build configuration list for PBXNativeTarget "LocalInferenceImpl" */ = {
isa = XCConfigurationList;
buildConfigurations = (
5CCBC6102CA1F04A00E958D0 /* Debug */,
5CCBC6112CA1F04A00E958D0 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
/* Begin XCLocalSwiftPackageReference section */
5CADC7182CA471CC007662D2 /* XCLocalSwiftPackageReference "internal-llama-stack-client-swift" */ = {
isa = XCLocalSwiftPackageReference;
relativePath = "internal-llama-stack-client-swift";
};
/* End XCLocalSwiftPackageReference section */
/* Begin XCRemoteSwiftPackageReference section */
5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/pytorch/executorch";
requirement = {
branch = latest;
kind = branch;
};
};
5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/stencilproject/Stencil";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 0.15.1;
};
};
/* End XCRemoteSwiftPackageReference section */
/* Begin XCSwiftPackageProductDependency section */
5CADC7192CA471CC007662D2 /* LlamaStackClient */ = {
isa = XCSwiftPackageProductDependency;
productName = LlamaStackClient;
};
5CCBC6742CA1F45800E958D0 /* executorch_debug */ = {
isa = XCSwiftPackageProductDependency;
package = 5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */;
productName = executorch_debug;
};
5CCBC6922CA1F7D000E958D0 /* Stencil */ = {
isa = XCSwiftPackageProductDependency;
package = 5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */;
productName = Stencil;
};
/* End XCSwiftPackageProductDependency section */
};
rootObject = 5CCBC5FF2CA1F04A00E958D0 /* Project object */;
}

View file

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
version = "1.0">
<FileRef
location = "self:">
</FileRef>
</Workspace>

View file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>IDEDidComputeMac32BitWarning</key>
<true/>
</dict>
</plist>

View file

@ -0,0 +1,16 @@
//
// LocalInference.h
// LocalInference
//
// Created by Dalton Flanagan on 9/23/24.
//
#import <Foundation/Foundation.h>
//! Project version number for LocalInference.
FOUNDATION_EXPORT double LocalInferenceVersionNumber;
//! Project version string for LocalInference.
FOUNDATION_EXPORT const unsigned char LocalInferenceVersionString[];
// In this header, you should import all the public headers of your framework using statements like #import <LocalInference/PublicHeader.h>

View file

@ -0,0 +1,167 @@
import Foundation
import LLaMARunner
import LlamaStackClient
class RunnerHolder: ObservableObject {
var runner: Runner?
}
public class LocalInference: Inference {
private var runnerHolder = RunnerHolder()
private let runnerQueue: DispatchQueue
public init (queue: DispatchQueue) {
runnerQueue = queue
}
public func loadModel(modelPath: String, tokenizerPath: String, completion: @escaping (Result<Void, Error>) -> Void) {
runnerHolder.runner = runnerHolder.runner ?? Runner(
modelPath: modelPath,
tokenizerPath: tokenizerPath
)
runnerQueue.async {
let runner = self.runnerHolder.runner
do {
try runner!.load()
completion(.success(()))
} catch let loadError {
print("error: " + loadError.localizedDescription)
completion(.failure(loadError))
}
}
}
public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
return AsyncStream { continuation in
runnerQueue.async {
do {
var tokens: [String] = []
let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
var stopReason: Components.Schemas.StopReason? = nil
var buffer = ""
var ipython = false
var echoDropped = false
try self.runnerHolder.runner?.generate(prompt, sequenceLength: 4096) { token in
buffer += token
// HACK: Workaround until LlamaRunner exposes echo param
if (!echoDropped) {
if (buffer.hasPrefix(prompt)) {
buffer = String(buffer.dropFirst(prompt.count))
echoDropped = true
}
return
}
tokens.append(token)
if !ipython && (buffer.starts(with: "<|python_tag|>") || buffer.starts(with: "[") ) {
ipython = true
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .case1(""),
parse_status: Components.Schemas.ToolCallParseStatus.started
)
),
event_type: .progress
)
)
)
if (buffer.starts(with: "<|python_tag|>")) {
buffer = String(buffer.dropFirst("<|python_tag|>".count))
}
}
// TODO: Non-streaming lobprobs
var text = ""
if token == "<|eot_id|>" {
stopReason = Components.Schemas.StopReason.end_of_turn
} else if token == "<|eom_id|>" {
stopReason = Components.Schemas.StopReason.end_of_message
} else {
text = token
}
var delta: Components.Schemas.ChatCompletionResponseEvent.deltaPayload
if ipython {
delta = .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .case1(text),
parse_status: .in_progress
))
} else {
delta = .case1(text)
}
if stopReason == nil {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: delta,
event_type: .progress
)
)
)
}
}
if stopReason == nil {
stopReason = Components.Schemas.StopReason.out_of_tokens
}
let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: stopReason!)
// TODO: non-streaming support
let didParseToolCalls = message.tool_calls.count > 0
if ipython && !didParseToolCalls {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(content: .case1(""), parse_status: .failure)),
event_type: .progress
)
// TODO: stopReason
)
)
}
for toolCall in message.tool_calls {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .ToolCall(toolCall),
parse_status: .success
)),
event_type: .progress
)
// TODO: stopReason
)
)
}
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .case1(""),
event_type: .complete
)
// TODO: stopReason
)
)
}
catch (let error) {
print("Inference error: " + error.localizedDescription)
}
}
}
}
}

View file

@ -0,0 +1,235 @@
import Foundation
import LlamaStackClient
func encodeHeader(role: String) -> String {
return "<|start_header_id|>\(role)<|end_header_id|>\n\n"
}
func encodeDialogPrompt(messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload]) -> String {
var prompt = ""
prompt.append("<|begin_of_text|>")
for message in messages {
let msg = encodeMessage(message: message)
prompt += msg
}
prompt.append(encodeHeader(role: "assistant"))
return prompt
}
func getRole(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
switch (message) {
case .UserMessage(let m):
return m.role.rawValue
case .SystemMessage(let m):
return m.role.rawValue
case .ToolResponseMessage(let m):
return m.role.rawValue
case .CompletionMessage(let m):
return m.role.rawValue
}
}
func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
var prompt = encodeHeader(role: getRole(message: message))
switch (message) {
case .CompletionMessage(let m):
if (m.tool_calls.count > 0) {
prompt += "<|python_tag|>"
}
default:
break
}
func _processContent(_ content: Any) -> String {
func _process(_ c: Any) {
if let str = c as? String {
prompt += str
}
}
if let str = content as? String {
_process(str)
} else if let list = content as? [Any] {
for c in list {
_process(c)
}
}
return ""
}
switch (message) {
case .UserMessage(let m):
prompt += _processContent(m.content)
case .SystemMessage(let m):
prompt += _processContent(m.content)
case .ToolResponseMessage(let m):
prompt += _processContent(m.content)
case .CompletionMessage(let m):
prompt += _processContent(m.content)
}
var eom = false
switch (message) {
case .UserMessage(let m):
switch (m.content) {
case .case1(let c):
prompt += _processContent(c)
case .case2(let c):
prompt += _processContent(c)
}
case .CompletionMessage(let m):
// TODO: Support encoding past tool call history
// for t in m.tool_calls {
// _processContent(t.)
//}
eom = m.stop_reason == Components.Schemas.StopReason.end_of_message
case .SystemMessage(_):
break
case .ToolResponseMessage(_):
break
}
if (eom) {
prompt += "<|eom_id|>"
} else {
prompt += "<|eot_id|>"
}
return prompt
}
func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] {
var existingMessages = request.messages
var existingSystemMessage: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload?
// TODO: Existing system message
var messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] = []
let defaultGen = SystemDefaultGenerator()
let defaultTemplate = defaultGen.gen()
var sysContent = ""
// TODO: Built-in tools
sysContent += try defaultTemplate.render()
messages.append(.SystemMessage(Components.Schemas.SystemMessage(
content: .case1(sysContent),
role: .system))
)
if request.tools?.isEmpty == false {
// TODO: Separate built-ins and custom tools (right now everything treated as custom)
let toolGen = FunctionTagCustomToolGenerator()
let toolTemplate = try toolGen.gen(customTools: request.tools!)
let tools = try toolTemplate.render()
messages.append(.UserMessage(Components.Schemas.UserMessage(
content: .case1(tools),
role: .user)
))
}
messages.append(contentsOf: existingMessages)
return messages
}
struct FunctionCall {
let name: String
let params: [String: Any]
}
public func maybeExtractCustomToolCalls(input: String) -> [Components.Schemas.ToolCall] {
guard input.hasPrefix("[") && input.hasSuffix("]") else {
return []
}
do {
let trimmed = input.trimmingCharacters(in: CharacterSet(charactersIn: "[]"))
let calls = trimmed.components(separatedBy: "),").map { $0.hasSuffix(")") ? $0 : $0 + ")" }
var result: [Components.Schemas.ToolCall] = []
for call in calls {
guard let nameEndIndex = call.firstIndex(of: "("),
let paramsStartIndex = call.firstIndex(of: "{"),
let paramsEndIndex = call.lastIndex(of: "}") else {
return []
}
let name = String(call[..<nameEndIndex]).trimmingCharacters(in: .whitespacesAndNewlines)
let paramsString = String(call[paramsStartIndex...paramsEndIndex])
guard let data = paramsString.data(using: .utf8),
let params = try? JSONSerialization.jsonObject(with: data, options: []) as? [String: Any] else {
return []
}
var props: [String : Components.Schemas.ToolCall.argumentsPayload.additionalPropertiesPayload] = [:]
for (param_name, param) in params {
switch (param) {
case let value as String:
props[param_name] = .case1(value)
case let value as Int:
props[param_name] = .case2(value)
case let value as Double:
props[param_name] = .case3(value)
case let value as Bool:
props[param_name] = .case4(value)
default:
return []
}
}
result.append(
Components.Schemas.ToolCall(
arguments: .init(additionalProperties: props),
call_id: UUID().uuidString,
tool_name: .case2(name) // custom_tool
)
)
}
return result.isEmpty ? [] : result
} catch {
return []
}
}
func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.StopReason) -> Components.Schemas.CompletionMessage {
var content = tokens
let roles = ["user", "system", "assistant"]
for role in roles {
let headerStr = encodeHeader(role: role)
if content.hasPrefix(headerStr) {
content = String(content.dropFirst(encodeHeader(role: role).count))
}
}
if content.hasPrefix("<|python_tag|>") {
content = String(content.dropFirst("<|python_tag|>".count))
}
if content.hasSuffix("<|eot_id|>") {
content = String(content.dropLast("<|eot_id|>".count))
} else {
content = String(content.dropLast("<|eom_id|>".count))
}
return Components.Schemas.CompletionMessage(
content: .case1(content),
role: .assistant,
stop_reason: stopReason,
tool_calls: maybeExtractCustomToolCalls(input: content)
)
}

View file

@ -0,0 +1,12 @@
import Foundation
import Stencil
public struct PromptTemplate {
let template: String
let data: [String: Any]
public func render() throws -> String {
let template = Template(templateString: self.template)
return try template.render(self.data)
}
}

View file

@ -0,0 +1,91 @@
import Foundation
import LlamaStackClient
func convertToNativeSwiftType(_ value: Any) -> Any {
switch value {
case let number as NSNumber:
if CFGetTypeID(number) == CFBooleanGetTypeID() {
return number.boolValue
}
if floor(number.doubleValue) == number.doubleValue {
return number.intValue
}
return number.doubleValue
case let string as String:
return string
case let array as [Any]:
return array.map(convertToNativeSwiftType)
case let dict as [String: Any]:
return dict.mapValues(convertToNativeSwiftType)
case is NSNull:
return NSNull()
default:
return value
}
}
public class SystemDefaultGenerator {
public init() {}
public func gen() -> PromptTemplate {
let templateStr = """
Cutting Knowledge Date: December 2023
Today Date: {{ today }}
"""
let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "dd MMMM yyyy"
return PromptTemplate(
template: templateStr,
data: ["today": dateFormatter.string(from: Date())]
)
}
}
public class FunctionTagCustomToolGenerator {
public init() {}
public func gen(customTools: [Components.Schemas.ToolDefinition]) throws -> PromptTemplate {
// TODO: required params
// TODO: {{#unless @last}},{{/unless}}
let templateStr = """
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{% for t in custom_tools %}
{
"name": "{{t.tool_name}}",
"description": "{{t.description}}",
"parameters": {
"type": "dict",
"properties": { {{t.parameters}} }
}
{{/let}}
{% endfor -%}
]
"""
let encoder = JSONEncoder()
return PromptTemplate(
template: templateStr,
data: ["custom_tools": try customTools.map {
let data = try encoder.encode($0)
let obj = try JSONSerialization.jsonObject(with: data)
return convertToNativeSwiftType(obj)
}]
)
}
}

View file

@ -0,0 +1,109 @@
# LocalInference
LocalInference provides a local inference implementation powered by [executorch](https://github.com/pytorch/executorch/).
Llama Stack currently supports on-device inference for iOS with Android coming soon. You can run on-device inference on Android today using [executorch](https://github.com/pytorch/executorch/tree/main/examples/demo-apps/android/LlamaDemo), PyTorchs on-device inference library.
## Installation
We're working on making LocalInference easier to set up. For now, you'll need to import it via `.xcframework`:
1. Clone the executorch submodule in this repo and its dependencies: `git submodule update --init --recursive`
1. Install [Cmake](https://cmake.org/) for the executorch build`
1. Drag `LocalInference.xcodeproj` into your project
1. Add `LocalInference` as a framework in your app target
1. Add a package dependency on https://github.com/pytorch/executorch (branch latest)
1. Add all the kernels / backends from executorch (but not exectuorch itself!) as frameworks in your app target:
- backend_coreml
- backend_mps
- backend_xnnpack
- kernels_custom
- kernels_optimized
- kernels_portable
- kernels_quantized
1. In "Build Settings" > "Other Linker Flags" > "Any iOS Simulator SDK", add:
```
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_optimized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_custom-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_quantized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_xnnpack-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_coreml-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_mps-simulator-release.a
```
1. In "Build Settings" > "Other Linker Flags" > "Any iOS SDK", add:
```
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_optimized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_custom-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_quantized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_xnnpack-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_coreml-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_mps-simulator-release.a
```
## Preparing a model
1. Prepare a `.pte` file [following the executorch docs](https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md#step-2-prepare-model)
2. Bundle the `.pte` and `tokenizer.model` file into your app
## Using LocalInference
1. Instantiate LocalInference with a DispatchQueue. Optionally, pass it into your agents service:
```swift
init () {
runnerQueue = DispatchQueue(label: "org.meta.llamastack")
inferenceService = LocalInferenceService(queue: runnerQueue)
agentsService = LocalAgentsService(inference: inferenceService)
}
```
2. Before making any inference calls, load your model from your bundle:
```swift
let mainBundle = Bundle.main
inferenceService.loadModel(
modelPath: mainBundle.url(forResource: "llama32_1b_spinquant", withExtension: "pte"),
tokenizerPath: mainBundle.url(forResource: "tokenizer", withExtension: "model"),
completion: {_ in } // use to handle load failures
)
```
3. Make inference calls (or agents calls) as you normally would with LlamaStack:
```
for await chunk in try await agentsService.initAndCreateTurn(
messages: [
.UserMessage(Components.Schemas.UserMessage(
content: .case1("Call functions as needed to handle any actions in the following text:\n\n" + text),
role: .user))
]
) {
```
## Troubleshooting
If you receive errors like "missing package product" or "invalid checksum", try cleaning the build folder and resetting the Swift package cache:
(Opt+Click) Product > Clean Build Folder Immediately
```
rm -rf \
~/Library/org.swift.swiftpm \
~/Library/Caches/org.swift.swiftpm \
~/Library/Caches/com.apple.dt.Xcode \
~/Library/Developer/Xcode/DerivedData
```

View file

@ -398,7 +398,11 @@ class ChatAgent(ShieldRunnerMixin):
color = "yellow" color = "yellow"
else: else:
color = None color = None
cprint(f"{str(msg)}", color=color) if len(str(msg)) > 1000:
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
else:
msg_str = str(msg)
cprint(f"{msg_str}", color=color)
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -466,6 +470,13 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason = event.stop_reason stop_reason = event.stop_reason
stop_reason = stop_reason or StopReason.out_of_tokens stop_reason = stop_reason or StopReason.out_of_tokens
# If tool calls are parsed successfully,
# if content is not made null the tool call str will also be in the content
# and tokens will have tool call syntax included twice
if tool_calls:
content = ""
message = CompletionMessage( message = CompletionMessage(
content=content, content=content,
stop_reason=stop_reason, stop_reason=stop_reason,

View file

@ -10,13 +10,14 @@ from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403 from llama_models.llama3.api import * # noqa: F403
from termcolor import cprint # noqa: F401
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig, DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig, LLMMemoryQueryGeneratorConfig,
MemoryQueryGenerator, MemoryQueryGenerator,
MemoryQueryGeneratorConfig, MemoryQueryGeneratorConfig,
) )
from termcolor import cprint # noqa: F401
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403

View file

@ -16,7 +16,7 @@ from pydantic import BaseModel, Field, field_validator
class MetaReferenceImplConfig(BaseModel): class MetaReferenceImplConfig(BaseModel):
model: str = Field( model: str = Field(
default="Meta-Llama3.1-8B-Instruct", default="Llama3.1-8B-Instruct",
description="Model descriptor from `llama model list`", description="Model descriptor from `llama model list`",
) )
quantization: Optional[QuantizationConfig] = None quantization: Optional[QuantizationConfig] = None
@ -30,7 +30,7 @@ class MetaReferenceImplConfig(BaseModel):
permitted_models = [ permitted_models = [
m.descriptor() m.descriptor()
for m in all_registered_models() for m in all_registered_models()
if m.model_family == ModelFamily.llama3_1 if m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
or m.core_model_id == CoreModelId.llama_guard_3_8b or m.core_model_id == CoreModelId.llama_guard_3_8b
] ]
if model not in permitted_models: if model not in permitted_models:
@ -42,14 +42,9 @@ class MetaReferenceImplConfig(BaseModel):
@property @property
def model_parallel_size(self) -> int: def model_parallel_size(self) -> int:
# HUGE HACK ALERT: this will be fixed when we move inference configuration # HACK ALERT: this will be fixed when we move inference configuration
# to ModelsRegistry and we can explicitly ask for `model_parallel_size` # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
# as configuration there # as configuration there
gpu_count = 1
resolved = resolve_model(self.model) resolved = resolve_model(self.model)
assert resolved is not None assert resolved is not None
descriptor = resolved.descriptor().lower() return resolved.pth_file_count
if "-70b" in descriptor or "-405b" in descriptor:
gpu_count = 8
return gpu_count

View file

@ -24,21 +24,31 @@ from fairscale.nn.model_parallel.initialize import (
) )
from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolPromptFormat,
)
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from termcolor import cprint
from llama_stack.apis.inference import QuantizationType from llama_stack.apis.inference import QuantizationType
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
from termcolor import cprint
from .config import MetaReferenceImplConfig from .config import MetaReferenceImplConfig
def model_checkpoint_dir(model) -> str: def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor())) checkpoint_dir = Path(model_local_dir(model.descriptor()))
if not Path(checkpoint_dir / "consolidated.00.pth").exists():
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original" checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), ( assert checkpoint_dir.exists(), (
@ -134,6 +144,10 @@ class Llama:
# load on CPU in bf16 so that fp8 conversion does not find an # load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype # unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor) torch.set_default_tensor_type(torch.BFloat16Tensor)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args) model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
model = convert_to_quantized_model(model, config) model = convert_to_quantized_model(model, config)
@ -142,6 +156,10 @@ class Llama:
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else: else:
torch.set_default_tensor_type(torch.cuda.HalfTensor) torch.set_default_tensor_type(torch.cuda.HalfTensor)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args) model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
@ -167,7 +185,11 @@ class Llama:
) -> Generator: ) -> Generator:
params = self.model.params params = self.model.params
# cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red") # input_tokens = [
# self.formatter.vision_token if t == 128256 else t
# for t in model_input.tokens
# ]
# cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
prompt_tokens = [model_input.tokens] prompt_tokens = [model_input.tokens]
bsz = 1 bsz = 1
@ -183,6 +205,21 @@ class Llama:
return return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
is_vision = isinstance(self.model, CrossAttentionTransformer)
if is_vision:
images = model_input.vision.images if model_input.vision is not None else []
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = (
self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
)
pad_id = self.tokenizer.pad_id pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens): for k, t in enumerate(prompt_tokens):
@ -206,6 +243,18 @@ class Llama:
stop_tokens = torch.tensor(self.tokenizer.stop_tokens) stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
for cur_pos in range(min_prompt_len, total_len): for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(
prev_pos, cur_pos, dtype=torch.long, device="cuda"
)
logits = self.model.forward(
position_ids,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
)
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0: if temperature > 0:
@ -222,6 +271,18 @@ class Llama:
tokens[:, cur_pos] = next_token tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1] target = tokens[:, prev_pos + 1 : cur_pos + 1]
if is_vision:
# the logits space (num_classes) is designed to never contain a media_token
# however our input token stream does contain them. we need to nuke them here
# or else the CUDA kernels will crash with an illegal memory access
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
masks = [target.eq(t) for t in vision_tokens]
if len(masks) > 1:
mask = torch.logical_or(*masks)
else:
mask = masks[0]
target[mask] = 0
if logprobs: if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2), input=logits.transpose(1, 2),
@ -248,7 +309,7 @@ class Llama:
def text_completion( def text_completion(
self, self,
prompt: str, content: InterleavedTextMedia,
temperature: float = 0.6, temperature: float = 0.6,
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None, max_gen_len: Optional[int] = None,
@ -262,10 +323,10 @@ class Llama:
): ):
max_gen_len = self.model.params.max_seq_len - 1 max_gen_len = self.model.params.max_seq_len - 1
prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False) model_input = self.formatter.encode_content(content)
yield from self.generate( yield from self.generate(
model_input=ModelInput(tokens=prompt_tokens), model_input=model_input,
max_gen_len=max_gen_len, max_gen_len=max_gen_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,

View file

@ -21,7 +21,9 @@ from llama_stack.apis.inference import (
ToolCallDelta, ToolCallDelta,
ToolCallParseStatus, ToolCallParseStatus,
) )
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import MetaReferenceImplConfig from .config import MetaReferenceImplConfig
from .model_parallel import LlamaModelParallelGenerator from .model_parallel import LlamaModelParallelGenerator
@ -57,7 +59,7 @@ class MetaReferenceInferenceImpl(Inference):
model: str, model: str,
messages: List[Message], messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = [], tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False, stream: Optional[bool] = False,
@ -70,14 +72,14 @@ class MetaReferenceInferenceImpl(Inference):
model=model, model=model,
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools, tools=tools or [],
tool_choice=tool_choice, tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format, tool_prompt_format=tool_prompt_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
) )
messages = prepare_messages(request) messages = augment_messages_for_tools(request)
model = resolve_model(request.model) model = resolve_model(request.model)
if model is None: if model is None:
raise RuntimeError( raise RuntimeError(

View file

@ -7,11 +7,11 @@
from .config import SafetyConfig from .config import SafetyConfig
async def get_provider_impl(config: SafetyConfig, _deps): async def get_provider_impl(config: SafetyConfig, deps):
from .safety import MetaReferenceSafetyImpl from .safety import MetaReferenceSafetyImpl
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}" assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config) impl = MetaReferenceSafetyImpl(config, deps)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -31,7 +31,10 @@ class LlamaGuardShieldConfig(BaseModel):
permitted_models = [ permitted_models = [
m.descriptor() m.descriptor()
for m in safety_models() for m in safety_models()
if m.core_model_id == CoreModelId.llama_guard_3_8b if (
m.core_model_id
in {CoreModelId.llama_guard_3_8b, CoreModelId.llama_guard_3_11b_vision}
)
] ]
if model not in permitted_models: if model not in permitted_models:
raise ValueError( raise ValueError(

View file

@ -7,8 +7,10 @@
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.impls.meta_reference.safety.shields.base import ( from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction, OnViolationAction,
@ -34,20 +36,11 @@ def resolve_and_get_path(model_name: str) -> str:
class MetaReferenceSafetyImpl(Safety): class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None: def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config self.config = config
self.inference_api = deps[Api.inference]
async def initialize(self) -> None: async def initialize(self) -> None:
shield_cfg = self.config.llama_guard_shield
if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model)
_ = LlamaGuardShield.instance(
model_dir=model_dir,
excluded_categories=shield_cfg.excluded_categories,
disable_input_check=shield_cfg.disable_input_check,
disable_output_check=shield_cfg.disable_output_check,
)
shield_cfg = self.config.prompt_guard_shield shield_cfg = self.config.prompt_guard_shield
if shield_cfg is not None: if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model) model_dir = resolve_and_get_path(shield_cfg.model)
@ -91,11 +84,18 @@ class MetaReferenceSafetyImpl(Safety):
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
cfg = self.config cfg = self.config
if typ == MetaReferenceShieldType.llama_guard: if typ == MetaReferenceShieldType.llama_guard:
cfg = cfg.llama_guard_shield
assert ( assert (
cfg.llama_guard_shield is not None cfg is not None
), "Cannot use LlamaGuardShield since not present in config" ), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir) return LlamaGuardShield(
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
disable_input_check=cfg.disable_input_check,
disable_output_check=cfg.disable_output_check,
)
elif typ == MetaReferenceShieldType.jailbreak_shield: elif typ == MetaReferenceShieldType.jailbreak_shield:
assert ( assert (
cfg.prompt_guard_shield is not None cfg.prompt_guard_shield is not None

View file

@ -9,9 +9,8 @@ import re
from string import Template from string import Template
from typing import List, Optional from typing import List, Optional
import torch
from llama_models.llama3.api.datatypes import Message, Role from llama_models.llama3.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer from llama_stack.apis.inference import * # noqa: F403
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@ -100,39 +99,17 @@ PROMPT_TEMPLATE = Template(
class LlamaGuardShield(ShieldBase): class LlamaGuardShield(ShieldBase):
@staticmethod
def instance(
on_violation_action=OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
) -> "LlamaGuardShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = LlamaGuardShield(
on_violation_action,
model_dir,
excluded_categories,
disable_input_check,
disable_output_check,
)
return _INSTANCE
def __init__( def __init__(
self, self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE, model: str,
model_dir: str = None, inference_api: Inference,
excluded_categories: List[str] = None, excluded_categories: List[str] = None,
disable_input_check: bool = False, disable_input_check: bool = False,
disable_output_check: bool = False, disable_output_check: bool = False,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
): ):
super().__init__(on_violation_action) super().__init__(on_violation_action)
dtype = torch.bfloat16
assert model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None: if excluded_categories is None:
excluded_categories = [] excluded_categories = []
@ -140,18 +117,12 @@ class LlamaGuardShield(ShieldBase):
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]" ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
self.device = "cuda" self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check self.disable_output_check = disable_output_check
# load model
torch_dtype = torch.bfloat16
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def check_unsafe_response(self, response: str) -> Optional[str]: def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response) match = re.match(r"^unsafe\n(.*)$", response)
if match: if match:
@ -212,26 +183,21 @@ class LlamaGuardShield(ShieldBase):
) )
else: else:
prompt = self.build_prompt(messages) prompt = self.build_prompt(messages)
llama_guard_input = {
"role": "user",
"content": prompt,
}
input_ids = self.tokenizer.apply_chat_template(
[llama_guard_input], return_tensors="pt", tokenize=True
).to(self.device)
prompt_len = input_ids.shape[1]
output = self.model.generate(
input_ids=input_ids,
max_new_tokens=20,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=0,
)
generated_tokens = output.sequences[:, prompt_len:]
response = self.tokenizer.decode( # TODO: llama-stack inference protocol has issues with non-streaming inference code
generated_tokens[0], skip_special_tokens=True content = ""
) async for chunk in self.inference_api.chat_completion(
response = response.strip() model=self.model,
shield_response = self.get_shield_response(response) messages=[
UserMessage(content=prompt),
],
stream=True,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.progress:
assert isinstance(event.delta, str)
content += event.delta
content = content.strip()
shield_response = self.get_shield_response(content)
return shield_response return shield_response

View file

@ -20,6 +20,7 @@ def available_providers() -> List[ProviderSpec]:
"fairscale", "fairscale",
"fbgemm-gpu==0.8.0", "fbgemm-gpu==0.8.0",
"torch", "torch",
"torchvision",
"transformers", "transformers",
"zmq", "zmq",
], ],
@ -75,15 +76,4 @@ def available_providers() -> List[ProviderSpec]:
header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor", header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
), ),
), ),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="bedrock",
pip_packages=[
"boto3",
],
module="llama_stack.providers.adapters.inference.bedrock",
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
),
),
] ]

View file

@ -21,13 +21,15 @@ def available_providers() -> List[ProviderSpec]:
api=Api.safety, api=Api.safety,
provider_id="meta-reference", provider_id="meta-reference",
pip_packages=[ pip_packages=[
"accelerate",
"codeshield", "codeshield",
"torch",
"transformers", "transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
], ],
module="llama_stack.providers.impls.meta_reference.safety", module="llama_stack.providers.impls.meta_reference.safety",
config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig", config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
api_dependencies=[
Api.inference,
],
), ),
remote_provider_spec( remote_provider_spec(
api=Api.safety, api=Api.safety,

View file

@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_models.datatypes import ModelFamily
from llama_models.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
PythonListCustomToolGenerator,
SystemDefaultGenerator,
)
from llama_models.sku_list import resolve_model
def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
"""Reads chat completion request and augments the messages to handle tools.
For eg. for llama_3_1, add system message with the appropriate tools or
add user messsage for custom tools, etc.
"""
model = resolve_model(request.model)
if model is None:
cprint(f"Could not resolve model {request.model}", color="red")
return request.messages
if model.model_family not in [ModelFamily.llama3_1, ModelFamily.llama3_2]:
cprint(f"Model family {model.model_family} not llama 3_1 or 3_2", color="red")
return request.messages
if model.model_family == ModelFamily.llama3_1 or (
model.model_family == ModelFamily.llama3_2 and is_multimodal(model)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
return augment_messages_for_tools_llama_3_1(request)
elif model.model_family == ModelFamily.llama3_2:
return augment_messages_for_tools_llama_3_2(request)
else:
return request.messages
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if request.tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(request.tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join(
[_process(c) for c in existing_system_message.content]
)
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
if request.tool_prompt_format == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif request.tool_prompt_format == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
# Add back existing messages from the request
messages += existing_messages
return messages
def augment_messages_for_tools_llama_3_2(
request: ChatCompletionRequest,
) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
messages = []
sys_content = ""
custom_tools, builtin_tools = [], []
for t in request.tools:
if isinstance(t.tool_name, str):
custom_tools.append(t)
else:
builtin_tools.append(t)
tool_template = None
if builtin_tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(builtin_tools)
sys_content += tool_template.render()
sys_content += "\n"
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
if request.tool_prompt_format != ToolPromptFormat.python_list:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
tool_gen = PythonListCustomToolGenerator()
tool_template = tool_gen.gen(custom_tools)
sys_content += tool_template.render()
sys_content += "\n"
if existing_system_message:
sys_content += interleaved_text_media_as_str(
existing_system_message.content, sep="\n"
)
messages.append(SystemMessage(content=sys_content))
# Add back existing messages from the request
messages += existing_messages
return messages

View file

@ -1,84 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_models.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
SystemDefaultGenerator,
)
def prepare_messages(request: ChatCompletionRequest) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if request.tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(request.tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join(
[_process(c) for c in existing_system_message.content]
)
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
if request.tool_prompt_format == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif request.tool_prompt_format == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
# Add back existing messages from the request
messages += existing_messages
return messages

View file

@ -7,4 +7,5 @@ prompt-toolkit
python-dotenv python-dotenv
pydantic pydantic
requests requests
rich
termcolor termcolor

View file

@ -8,9 +8,9 @@ import unittest
from llama_models.llama3.api import * # noqa: F403 from llama_models.llama3.api import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403 from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.prepare_messages import prepare_messages from llama_stack.inference.augment_messages import augment_messages_for_tools
MODEL = "Meta-Llama3.1-8B-Instruct" MODEL = "Llama3.1-8B-Instruct"
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
@ -22,7 +22,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
UserMessage(content=content), UserMessage(content=content),
], ],
) )
messages = prepare_messages(request) messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 2) self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content) self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@ -39,7 +39,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
ToolDefinition(tool_name=BuiltinTool.brave_search), ToolDefinition(tool_name=BuiltinTool.brave_search),
], ],
) )
messages = prepare_messages(request) messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 2) self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content) self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@ -67,7 +67,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
], ],
tool_prompt_format=ToolPromptFormat.json, tool_prompt_format=ToolPromptFormat.json,
) )
messages = prepare_messages(request) messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 3) self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content) self.assertTrue("Environment: ipython" in messages[0].content)
@ -97,7 +97,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
), ),
], ],
) )
messages = prepare_messages(request) messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 3) self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content) self.assertTrue("Environment: ipython" in messages[0].content)
@ -119,7 +119,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
ToolDefinition(tool_name=BuiltinTool.code_interpreter), ToolDefinition(tool_name=BuiltinTool.code_interpreter),
], ],
) )
messages = prepare_messages(request) messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 2, messages) self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt)) self.assertTrue(messages[0].content.endswith(system_prompt))

View file

@ -59,7 +59,7 @@ class TestE2E(unittest.IsolatedAsyncioTestCase):
host=TestE2E.HOST, host=TestE2E.HOST,
port=TestE2E.PORT, port=TestE2E.PORT,
custom_tools=custom_tools, custom_tools=custom_tools,
# model="Meta-Llama3.1-70B-Instruct", # Defaults to 8B # model="Llama3.1-70B-Instruct", # Defaults to 8B
tool_prompt_format=tool_prompt_format, tool_prompt_format=tool_prompt_format,
) )
await client.create_session(__file__) await client.create_session(__file__)

View file

@ -9,31 +9,15 @@
import asyncio import asyncio
import os import os
import textwrap
import unittest import unittest
from datetime import datetime from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_models.llama3.api.datatypes import (
BuiltinTool,
StopReason,
SystemMessage,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
ToolResponseMessage,
UserMessage,
)
from llama_stack.inference.api import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
)
from llama_stack.inference.meta_reference.config import MetaReferenceImplConfig from llama_stack.inference.meta_reference.config import MetaReferenceImplConfig
from llama_stack.inference.meta_reference.inference import get_provider_impl from llama_stack.inference.meta_reference.inference import get_provider_impl
MODEL = "Meta-Llama3.1-8B-Instruct" MODEL = "Llama3.1-8B-Instruct"
HELPER_MSG = """ HELPER_MSG = """
This test needs llama-3.1-8b-instruct models. This test needs llama-3.1-8b-instruct models.
Please donwload using the llama cli Please donwload using the llama cli
@ -45,11 +29,10 @@ llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token <
class InferenceTests(unittest.IsolatedAsyncioTestCase): class InferenceTests(unittest.IsolatedAsyncioTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
# This runs the async setup function
asyncio.run(cls.asyncSetUpClass()) asyncio.run(cls.asyncSetUpClass())
@classmethod @classmethod
async def asyncSetUpClass(cls): async def asyncSetUpClass(cls): # noqa
# assert model exists on local # assert model exists on local
model_dir = os.path.expanduser(f"~/.llama/checkpoints/{MODEL}/original/") model_dir = os.path.expanduser(f"~/.llama/checkpoints/{MODEL}/original/")
assert os.path.isdir(model_dir), HELPER_MSG assert os.path.isdir(model_dir), HELPER_MSG
@ -67,11 +50,10 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
# This runs the async teardown function
asyncio.run(cls.asyncTearDownClass()) asyncio.run(cls.asyncTearDownClass())
@classmethod @classmethod
async def asyncTearDownClass(cls): async def asyncTearDownClass(cls): # noqa
await cls.api.shutdown() await cls.api.shutdown()
async def asyncSetUp(self): async def asyncSetUp(self):

View file

@ -4,26 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import textwrap
import unittest import unittest
from datetime import datetime
from llama_models.llama3.api.datatypes import ( from llama_models.llama3.api.datatypes import * # noqa: F403
BuiltinTool, from llama_stack.inference.api import * # noqa: F403
SamplingParams,
SamplingStrategy,
StopReason,
SystemMessage,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
ToolResponseMessage,
UserMessage,
)
from llama_stack.inference.api import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
)
from llama_stack.inference.ollama.config import OllamaImplConfig from llama_stack.inference.ollama.config import OllamaImplConfig
from llama_stack.inference.ollama.ollama import get_provider_impl from llama_stack.inference.ollama.ollama import get_provider_impl
@ -52,7 +36,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
), ),
}, },
) )
self.valid_supported_model = "Meta-Llama3.1-8B-Instruct" self.valid_supported_model = "Llama3.1-8B-Instruct"
async def asyncTearDown(self): async def asyncTearDown(self):
await self.api.shutdown() await self.api.shutdown()
@ -272,7 +256,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
ollama_model = self.api.resolve_ollama_model(self.valid_supported_model) ollama_model = self.api.resolve_ollama_model(self.valid_supported_model)
self.assertEqual(ollama_model, "llama3.1:8b-instruct-fp16") self.assertEqual(ollama_model, "llama3.1:8b-instruct-fp16")
invalid_model = "Meta-Llama3.1-8B" invalid_model = "Llama3.1-8B"
with self.assertRaisesRegex( with self.assertRaisesRegex(
AssertionError, f"Unsupported model: {invalid_model}" AssertionError, f"Unsupported model: {invalid_model}"
): ):