Support for Llama3.2 models and Swift SDK (#98)

This commit is contained in:
Ashwin Bharambe 2024-09-25 10:29:58 -07:00 committed by GitHub
parent 95abbf576b
commit 56aed59eb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
56 changed files with 3745 additions and 630 deletions

9
.gitignore vendored
View file

@ -5,6 +5,11 @@ dist
dev_requirements.txt
build
.DS_Store
.idea
*.iml
llama_stack/configs/*
xcuserdata/
*.hmap
.DS_Store
.build/
Package.resolved
*.pte
*.ipynb_checkpoints*

View file

@ -37,50 +37,74 @@ llama model list
You should see a table like this:
<pre style="font-family: monospace;">
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Model Descriptor | HuggingFace Repo | Context Length | Hardware Requirements |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Meta-Llama3.1-8B | meta-llama/Meta-Llama-3.1-8B | 128K | 1 GPU, each >= 20GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Meta-Llama3.1-70B | meta-llama/Meta-Llama-3.1-70B | 128K | 8 GPUs, each >= 20GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Meta-Llama3.1-405B:bf16-mp8 | | 128K | 8 GPUs, each >= 120GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Meta-Llama3.1-405B | meta-llama/Meta-Llama-3.1-405B-FP8 | 128K | 8 GPUs, each >= 70GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Meta-Llama3.1-405B:bf16-mp16 | meta-llama/Meta-Llama-3.1-405B | 128K | 16 GPUs, each >= 70GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Meta-Llama3.1-8B-Instruct | meta-llama/Meta-Llama-3.1-8B-Instruct | 128K | 1 GPU, each >= 20GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Meta-Llama3.1-70B-Instruct | meta-llama/Meta-Llama-3.1-70B-Instruct | 128K | 8 GPUs, each >= 20GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Meta-Llama3.1-405B-Instruct:bf16-mp8 | | 128K | 8 GPUs, each >= 120GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Meta-Llama3.1-405B-Instruct | meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 | 128K | 8 GPUs, each >= 70GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Meta-Llama3.1-405B-Instruct:bf16-mp16 | meta-llama/Meta-Llama-3.1-405B-Instruct | 128K | 16 GPUs, each >= 70GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Llama-Guard-3-8B | meta-llama/Llama-Guard-3-8B | 128K | 1 GPU, each >= 20GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Llama-Guard-3-8B:int8-mp1 | meta-llama/Llama-Guard-3-8B-INT8 | 128K | 1 GPU, each >= 10GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
| Prompt-Guard-86M | meta-llama/Prompt-Guard-86M | 128K | 1 GPU, each >= 1GB VRAM |
+---------------------------------------+---------------------------------------------+----------------+----------------------------+
+----------------------------------+------------------------------------------+----------------+
| Model Descriptor | HuggingFace Repo | Context Length |
+----------------------------------+------------------------------------------+----------------+
| Llama3.1-8B | meta-llama/Llama-3.1-8B | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.1-70B | meta-llama/Llama-3.1-70B | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.1-405B:bf16-mp8 | meta-llama/Llama-3.1-405B | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.1-405B | meta-llama/Llama-3.1-405B-FP8 | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.1-405B:bf16-mp16 | meta-llama/Llama-3.1-405B | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.1-8B-Instruct | meta-llama/Llama-3.1-8B-Instruct | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.1-70B-Instruct | meta-llama/Llama-3.1-70B-Instruct | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.1-405B-Instruct:bf16-mp8 | meta-llama/Llama-3.1-405B-Instruct | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.1-405B-Instruct | meta-llama/Llama-3.1-405B-Instruct-FP8 | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.1-405B-Instruct:bf16-mp16 | meta-llama/Llama-3.1-405B-Instruct | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.2-1B | meta-llama/Llama-3.2-1B | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.2-3B | meta-llama/Llama-3.2-3B | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.2-11B-Vision | meta-llama/Llama-3.2-11B-Vision | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.2-90B-Vision | meta-llama/Llama-3.2-90B-Vision | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.2-1B-Instruct | meta-llama/Llama-3.2-1B-Instruct | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.2-3B-Instruct | meta-llama/Llama-3.2-3B-Instruct | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.2-11B-Vision-Instruct | meta-llama/Llama-3.2-11B-Vision-Instruct | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama3.2-90B-Vision-Instruct | meta-llama/Llama-3.2-90B-Vision-Instruct | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama-Guard-3-11B-Vision | meta-llama/Llama-Guard-3-11B-Vision | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama-Guard-3-1B:int4-mp1 | meta-llama/Llama-Guard-3-1B-INT4 | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama-Guard-3-1B | meta-llama/Llama-Guard-3-1B | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama-Guard-3-8B | meta-llama/Llama-Guard-3-8B | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama-Guard-3-8B:int8-mp1 | meta-llama/Llama-Guard-3-8B-INT8 | 128K |
+----------------------------------+------------------------------------------+----------------+
| Prompt-Guard-86M | meta-llama/Prompt-Guard-86M | 128K |
+----------------------------------+------------------------------------------+----------------+
| Llama-Guard-2-8B | meta-llama/Llama-Guard-2-8B | 4K |
+----------------------------------+------------------------------------------+----------------+
</pre>
To download models, you can use the llama download command.
#### Downloading from [Meta](https://llama.meta.com/llama-downloads/)
Here is an example download command to get the 8B/70B Instruct model. You will need META_URL which can be obtained from [here](https://llama.meta.com/docs/getting_the_models/meta/)
Here is an example download command to get the 3B-Instruct/11B-Vision-Instruct model. You will need META_URL which can be obtained from [here](https://llama.meta.com/docs/getting_the_models/meta/)
Download the required checkpoints using the following commands:
```bash
# download the 8B model, this can be run on a single GPU
llama download --source meta --model-id Meta-Llama3.1-8B-Instruct --meta-url META_URL
llama download --source meta --model-id Llama3.2-3B-Instruct --meta-url META_URL
# you can also get the 70B model, this will require 8 GPUs however
llama download --source meta --model-id Meta-Llama3.1-70B-Instruct --meta-url META_URL
llama download --source meta --model-id Llama3.2-11B-Vision-Instruct --meta-url META_URL
# llama-agents have safety enabled by default. For this, you will need
# safety models -- Llama-Guard and Prompt-Guard
@ -124,7 +148,7 @@ The `llama model` command helps you explore the models interface.
### 2.1 Subcommands
1. `download`: Download the model from different sources. (meta, huggingface)
2. `list`: Lists all the models available for download with hardware requirements to deploy the models.
3. `template`: <TODO: What is a template?>
3. `prompt-format`: Show llama model message formats.
4. `describe`: Describes all the properties of the model.
### 2.2 Sample Usage
@ -135,7 +159,7 @@ The `llama model` command helps you explore the models interface.
llama model --help
```
<pre style="font-family: monospace;">
usage: llama model [-h] {download,list,template,describe} ...
usage: llama model [-h] {download,list,prompt-format,describe} ...
Work with llama models
@ -143,124 +167,67 @@ options:
-h, --help show this help message and exit
model_subcommands:
{download,list,template,describe}
{download,list,prompt-format,describe}
</pre>
You can use the describe command to know more about a model:
```
llama model describe -m Meta-Llama3.1-8B-Instruct
llama model describe -m Llama3.2-3B-Instruct
```
### 2.3 Describe
<pre style="font-family: monospace;">
+-----------------------------+---------------------------------------+
| Model | Meta- |
| | Llama3.1-8B-Instruct |
+-----------------------------+---------------------------------------+
| HuggingFace ID | meta-llama/Meta-Llama-3.1-8B-Instruct |
+-----------------------------+---------------------------------------+
| Description | Llama 3.1 8b instruct model |
+-----------------------------+---------------------------------------+
+-----------------------------+----------------------------------+
| Model | Llama3.2-3B-Instruct |
+-----------------------------+----------------------------------+
| HuggingFace ID | meta-llama/Llama-3.2-3B-Instruct |
+-----------------------------+----------------------------------+
| Description | Llama 3.2 3b instruct model |
+-----------------------------+----------------------------------+
| Context Length | 128K tokens |
+-----------------------------+---------------------------------------+
+-----------------------------+----------------------------------+
| Weights format | bf16 |
+-----------------------------+---------------------------------------+
+-----------------------------+----------------------------------+
| Model params.json | { |
| | "dim": 4096, |
| | "n_layers": 32, |
| | "n_heads": 32, |
| | "dim": 3072, |
| | "n_layers": 28, |
| | "n_heads": 24, |
| | "n_kv_heads": 8, |
| | "vocab_size": 128256, |
| | "ffn_dim_multiplier": 1.3, |
| | "multiple_of": 1024, |
| | "ffn_dim_multiplier": 1.0, |
| | "multiple_of": 256, |
| | "norm_eps": 1e-05, |
| | "rope_theta": 500000.0, |
| | "use_scaled_rope": true |
| | } |
+-----------------------------+---------------------------------------+
+-----------------------------+----------------------------------+
| Recommended sampling params | { |
| | "strategy": "top_p", |
| | "temperature": 1.0, |
| | "top_p": 0.9, |
| | "top_k": 0 |
| | } |
+-----------------------------+---------------------------------------+
+-----------------------------+----------------------------------+
</pre>
### 2.4 Template
You can even run `llama model template` see all of the templates and their tokens:
### 2.4 Prompt Format
You can even run `llama model prompt-format` see all of the templates and their tokens:
```
llama model template
llama model prompt-format -m Llama3.2-3B-Instruct
```
<p align="center">
<img width="719" alt="image" src="https://github.com/user-attachments/assets/c5332026-8c0b-4edc-b438-ec60cd7ca554">
</p>
<pre style="font-family: monospace;">
+-----------+---------------------------------+
| Role | Template Name |
+-----------+---------------------------------+
| user | user-default |
| assistant | assistant-builtin-tool-call |
| assistant | assistant-custom-tool-call |
| assistant | assistant-default |
| system | system-builtin-and-custom-tools |
| system | system-builtin-tools-only |
| system | system-custom-tools-only |
| system | system-default |
| tool | tool-success |
| tool | tool-failure |
+-----------+---------------------------------+
</pre>
And fetch an example by passing it to `--name`:
```
llama model template --name tool-success
```
<pre style="font-family: monospace;">
+----------+----------------------------------------------------------------+
| Name | tool-success |
+----------+----------------------------------------------------------------+
| Template | <|start_header_id|>ipython<|end_header_id|> |
| | |
| | completed |
| | [stdout]{"results":["something |
| | something"]}[/stdout]<|eot_id|> |
| | |
+----------+----------------------------------------------------------------+
| Notes | Note ipython header and [stdout] |
+----------+----------------------------------------------------------------+
</pre>
Or:
```
llama model template --name system-builtin-tools-only
```
<pre style="font-family: monospace;">
+----------+--------------------------------------------+
| Name | system-builtin-tools-only |
+----------+--------------------------------------------+
| Template | <|start_header_id|>system<|end_header_id|> |
| | |
| | Environment: ipython |
| | Tools: brave_search, wolfram_alpha |
| | |
| | Cutting Knowledge Date: December 2023 |
| | Today Date: 21 August 2024 |
| | <|eot_id|> |
| | |
+----------+--------------------------------------------+
| Notes | |
+----------+--------------------------------------------+
</pre>
These commands can help understand the model interface and how prompts / messages are formatted for various scenarios.
You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
**NOTE**: Outputs in terminal are color printed to show special tokens.
## Step 3: Building, and Configuring Llama Stack Distributions
- Please see our [Getting Started](getting_started.md) guide for details.
- Please see our [Getting Started](getting_started.md) guide for more details on how to build and start a Llama Stack distribution.
### Step 3.1 Build
In the following steps, imagine we'll be working with a `Meta-Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify:

BIN
docs/dog.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

325
docs/getting_started.ipynb Normal file

File diff suppressed because one or more lines are too long

View file

@ -1,9 +1,70 @@
# llama-stack
[![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/)
[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/TZAAYNVtrU)
This repository contains the specifications and implementations of the APIs which are part of the Llama Stack.
The Llama Stack defines and standardizes the building blocks needed to bring generative AI applications to market. These blocks span the entire development lifecycle: from model training and fine-tuning, through product evaluation, to invoking AI agents in production. Beyond definition, we're developing open-source versions and partnering with cloud providers, ensuring developers can assemble AI solutions using consistent, interlocking pieces across platforms. The ultimate goal is to accelerate innovation in the AI space.
The Stack APIs are rapidly improving, but still very much work in progress and we invite feedback as well as direct contributions.
## APIs
The Llama Stack consists of the following set of APIs:
- Inference
- Safety
- Memory
- Agentic System
- Evaluation
- Post Training
- Synthetic Data Generation
- Reward Scoring
Each of the APIs themselves is a collection of REST endpoints.
## API Providers
A Provider is what makes the API real -- they provide the actual implementation backing the API.
As an example, for Inference, we could have the implementation be backed by open source libraries like `[ torch | vLLM | TensorRT ]` as possible options.
A provider can also be just a pointer to a remote REST service -- for example, cloud providers or dedicated inference providers could serve these APIs.
## Llama Stack Distribution
A Distribution is where APIs and Providers are assembled together to provide a consistent whole to the end application developer. You can mix-and-match providers -- some could be backed by local code and some could be remote. As a hobbyist, you can serve a small model locally, but can choose a cloud provider for a large model. Regardless, the higher level APIs your app needs to work with don't need to change at all. You can even imagine moving across the server / mobile-device boundary as well always using the same uniform set of APIs for developing Generative AI applications.
## Installation
You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack`
If you want to install from source:
```bash
mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git
conda create -n stack python=3.10
conda activate stack
cd llama-stack
$CONDA_PREFIX/bin/pip install -e .
```
# Getting Started
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
This guides allows you to quickly get started with building and running a Llama Stack server in < 5 minutes!
You may also checkout this [notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb) for trying out out demo scripts.
## Quick Cheatsheet
- Quick 3 line command to build and start a LlamaStack server using our Meta Reference implementation for all API endpoints with `conda` as build type.
@ -12,7 +73,7 @@ This guides allows you to quickly get started with building and running a Llama
```
llama stack build
> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-llama-stack
> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-stack
> Enter the image type you want your distribution to be built with (docker or conda): conda
Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.
@ -24,47 +85,57 @@ llama stack build
> (Optional) Enter a short description for your Llama Stack distribution:
Build spec configuration saved at ~/.conda/envs/llamastack-my-local-llama-stack/my-local-llama-stack-build.yaml
Build spec configuration saved at ~/.conda/envs/llamastack-my-local-stack/my-local-stack-build.yaml
You can now run `llama stack configure my-local-stack`
```
**`llama stack configure`**
- Run `llama stack configure <name>` with the name you have previously defined in `build` step.
```
llama stack configure my-local-llama-stack
llama stack configure <name>
```
- You will be prompted to enter configurations for your Llama Stack
Configuring APIs to serve...
Enter comma-separated list of APIs to serve:
```
$ llama stack configure my-local-stack
Could not find my-local-stack. Trying conda build name instead...
Configuring API `inference`...
Configuring provider `meta-reference`...
Enter value for model (default: Meta-Llama3.1-8B-Instruct) (required):
=== Configuring provider `meta-reference` for API inference...
Enter value for model (default: Llama3.1-8B-Instruct) (required):
Do you want to configure quantization? (y/n): n
Enter value for torch_seed (optional):
Enter value for max_seq_len (required): 4096
Enter value for max_seq_len (default: 4096) (required):
Enter value for max_batch_size (default: 1) (required):
Configuring API `safety`...
Configuring provider `meta-reference`...
Configuring API `safety`...
=== Configuring provider `meta-reference` for API safety...
Do you want to configure llama_guard_shield? (y/n): n
Do you want to configure prompt_guard_shield? (y/n): n
Configuring API `agents`...
=== Configuring provider `meta-reference` for API agents...
Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite):
Configuring SqliteKVStoreConfig:
Enter value for namespace (optional):
Enter value for db_path (default: /home/xiyan/.llama/runtime/kvstore.db) (required):
Configuring provider `meta-reference`...
Configuring API `memory`...
=== Configuring provider `meta-reference` for API memory...
> Please enter the supported memory bank type your provider has for memory: vector
Configuring provider `meta-reference`...
Configuring API `telemetry`...
=== Configuring provider `meta-reference` for API telemetry...
Configuring provider `meta-reference`...
> YAML configuration has been written to ~/.llama/builds/conda/my-local-llama-stack-run.yaml.
You can now run `llama stack run my-local-llama-stack --port PORT` or `llama stack run ~/.llama/builds/conda/my-local-llama-stack-run.yaml --port PORT
> YAML configuration has been written to ~/.llama/builds/conda/my-local-stack-run.yaml.
You can now run `llama stack run my-local-stack --port PORT`
```
**`llama stack run`**
- Run `llama stack run <name>` with the name you have previously defined.
```
llama stack run my-local-llama-stack
llama stack run my-local-stack
...
> initializing model parallel with size 1
@ -126,7 +197,7 @@ llama stack build
Running the command above will allow you to fill in the configuration to build your Llama Stack distribution, you will see the following outputs.
```
> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-llama-stack
> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): 8b-instruct
> Enter the image type you want your distribution to be built with (docker or conda): conda
Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.
@ -138,9 +209,14 @@ Running the command above will allow you to fill in the configuration to build y
> (Optional) Enter a short description for your Llama Stack distribution:
Build spec configuration saved at ~/.conda/envs/llamastack-my-local-llama-stack/my-local-llama-stack-build.yaml
Build spec configuration saved at ~/.conda/envs/llamastack-my-local-llama-stack/8b-instruct-build.yaml
```
**Ollama (optional)**
If you plan to use Ollama for inference, you'll need to install the server [via these instructions](https://ollama.com/download).
#### Building from templates
- To build from alternative API providers, we provide distribution templates for users to get started building a distribution backed by different providers.
@ -236,7 +312,7 @@ llama stack configure [ <name> | <docker-image-name> | <path/to/name.build.yaml>
- Run `docker images` to check list of available images on your machine.
```
$ llama stack configure ~/.llama/distributions/conda/8b-instruct-build.yaml
$ llama stack configure 8b-instruct
Configuring API: inference (meta-reference)
Enter value for model (existing: Meta-Llama3.1-8B-Instruct) (required):
@ -284,13 +360,13 @@ Note that all configurations as well as models are stored in `~/.llama`
Now, let's start the Llama Stack Distribution Server. You will need the YAML configuration file which was written out at the end by the `llama stack configure` step.
```
llama stack run ~/.llama/builds/conda/8b-instruct-run.yaml
llama stack run 8b-instruct
```
You should see the Llama Stack server start and print the APIs that it is supporting
```
$ llama stack run ~/.llama/builds/local/conda/8b-instruct.yaml
$ llama stack run 8b-instruct
> initializing model parallel with size 1
> initializing ddp with size 1
@ -357,4 +433,4 @@ Similarly you can test safety (if you configured llama-guard and/or prompt-guard
python -m llama_stack.apis.safety.client localhost 5000
```
You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo.
You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps) repo.

View file

@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-23 10:56:42.866760"
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
},
"servers": [
{
@ -2027,10 +2027,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
}
]
@ -2053,6 +2063,35 @@
"tool_calls"
]
},
"ImageMedia": {
"type": "object",
"properties": {
"image": {
"oneOf": [
{
"type": "object",
"properties": {
"format": {
"type": "string"
},
"format_description": {
"type": "string"
}
},
"additionalProperties": false,
"title": "This class represents an image object. To create"
},
{
"$ref": "#/components/schemas/URL"
}
]
}
},
"additionalProperties": false,
"required": [
"image"
]
},
"SamplingParams": {
"type": "object",
"properties": {
@ -2115,10 +2154,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
}
]
@ -2267,6 +2316,28 @@
"required": {
"type": "boolean",
"default": true
},
"default": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"additionalProperties": false,
@ -2278,7 +2349,8 @@
"type": "string",
"enum": [
"json",
"function_tag"
"function_tag",
"python_list"
],
"title": "This Enum refers to the prompt format for calling custom / zero shot tools",
"description": "`json` --\n Refers to the json format for calling tools.\n The json format takes the form like\n {\n \"type\": \"function\",\n \"function\" : {\n \"name\": \"function_name\",\n \"description\": \"function_description\",\n \"parameters\": {...}\n }\n }\n\n`function_tag` --\n This is an example of how you could define\n your own user defined format for making tool calls.\n The function_tag format looks like this,\n <function=function_name>(parameters)</function>\n\nThe detailed prompts for each of these formats are added to llama cli"
@ -2309,10 +2381,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
}
]
@ -2326,6 +2408,11 @@
"content"
]
},
"URL": {
"type": "string",
"format": "uri",
"pattern": "^(https?://|file://|data:)"
},
"UserMessage": {
"type": "object",
"properties": {
@ -2339,10 +2426,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
}
]
@ -2352,10 +2449,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
}
]
@ -2455,10 +2562,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
}
]
@ -2714,10 +2831,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
}
]
@ -3298,11 +3425,6 @@
"engine"
]
},
"URL": {
"type": "string",
"format": "uri",
"pattern": "^(https?://|file://|data:)"
},
"WolframAlphaToolDefinition": {
"type": "object",
"properties": {
@ -3396,10 +3518,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
},
{
@ -3731,10 +3863,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
}
]
@ -3888,10 +4030,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
}
]
@ -4316,10 +4468,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
}
]
@ -4515,10 +4677,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
},
{
@ -5407,10 +5579,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
}
]
@ -5460,10 +5642,20 @@
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ImageMedia"
}
]
}
}
]
@ -6027,32 +6219,32 @@
}
],
"tags": [
{
"name": "Inference"
},
{
"name": "Shields"
},
{
"name": "Models"
},
{
"name": "MemoryBanks"
},
{
"name": "SyntheticDataGeneration"
"name": "BatchInference"
},
{
"name": "RewardScoring"
},
{
"name": "PostTraining"
"name": "SyntheticDataGeneration"
},
{
"name": "Agents"
},
{
"name": "MemoryBanks"
},
{
"name": "Safety"
},
{
"name": "Evaluations"
"name": "Models"
},
{
"name": "Inference"
},
{
"name": "Memory"
@ -6061,14 +6253,14 @@
"name": "Telemetry"
},
{
"name": "Agents"
},
{
"name": "BatchInference"
"name": "PostTraining"
},
{
"name": "Datasets"
},
{
"name": "Evaluations"
},
{
"name": "BuiltinTool",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
@ -6077,6 +6269,10 @@
"name": "CompletionMessage",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/CompletionMessage\" />"
},
{
"name": "ImageMedia",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ImageMedia\" />"
},
{
"name": "SamplingParams",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/SamplingParams\" />"
@ -6117,6 +6313,10 @@
"name": "ToolResponseMessage",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ToolResponseMessage\" />"
},
{
"name": "URL",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/URL\" />"
},
{
"name": "UserMessage",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UserMessage\" />"
@ -6221,10 +6421,6 @@
"name": "SearchToolDefinition",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/SearchToolDefinition\" />"
},
{
"name": "URL",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/URL\" />"
},
{
"name": "WolframAlphaToolDefinition",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/WolframAlphaToolDefinition\" />"
@ -6661,6 +6857,7 @@
"FunctionCallToolDefinition",
"GetAgentsSessionRequest",
"GetDocumentsRequest",
"ImageMedia",
"InferenceStep",
"InsertDocumentsRequest",
"LogEventRequest",

View file

@ -210,8 +210,11 @@ components:
content:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
- $ref: '#/components/schemas/URL'
mime_type:
@ -273,8 +276,11 @@ components:
items:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
type: array
logprobs:
@ -441,8 +447,11 @@ components:
content:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
role:
const: assistant
@ -466,8 +475,11 @@ components:
content:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
logprobs:
additionalProperties: false
@ -742,8 +754,11 @@ components:
items:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
type: array
model:
@ -893,6 +908,23 @@ components:
required:
- document_ids
type: object
ImageMedia:
additionalProperties: false
properties:
image:
oneOf:
- additionalProperties: false
properties:
format:
type: string
format_description:
type: string
title: This class represents an image object. To create
type: object
- $ref: '#/components/schemas/URL'
required:
- image
type: object
InferenceStep:
additionalProperties: false
properties:
@ -1041,8 +1073,11 @@ components:
content:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
- $ref: '#/components/schemas/URL'
document_id:
@ -1108,8 +1143,11 @@ components:
inserted_context:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
memory_bank_ids:
items:
@ -1545,8 +1583,11 @@ components:
query:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
required:
- bank_id
@ -1562,8 +1603,11 @@ components:
content:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
document_id:
type: string
@ -2067,8 +2111,11 @@ components:
content:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
role:
const: system
@ -2203,6 +2250,14 @@ components:
ToolParamDefinition:
additionalProperties: false
properties:
default:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description:
type: string
param_type:
@ -2225,6 +2280,7 @@ components:
enum:
- json
- function_tag
- python_list
title: This Enum refers to the prompt format for calling custom / zero shot
tools
type: string
@ -2236,8 +2292,11 @@ components:
content:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
tool_name:
oneOf:
@ -2256,8 +2315,11 @@ components:
content:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
role:
const: ipython
@ -2451,14 +2513,20 @@ components:
content:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
context:
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
- items:
type: string
oneOf:
- type: string
- $ref: '#/components/schemas/ImageMedia'
type: array
role:
const: user
@ -2501,7 +2569,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-09-23 10:56:42.866760"
\ draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@ -3739,25 +3807,27 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
- name: Inference
- name: Shields
- name: Models
- name: MemoryBanks
- name: SyntheticDataGeneration
- name: BatchInference
- name: RewardScoring
- name: PostTraining
- name: SyntheticDataGeneration
- name: Agents
- name: MemoryBanks
- name: Safety
- name: Evaluations
- name: Models
- name: Inference
- name: Memory
- name: Telemetry
- name: Agents
- name: BatchInference
- name: PostTraining
- name: Datasets
- name: Evaluations
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
/>
name: CompletionMessage
- description: <SchemaDefinition schemaRef="#/components/schemas/ImageMedia" />
name: ImageMedia
- description: <SchemaDefinition schemaRef="#/components/schemas/SamplingParams" />
name: SamplingParams
- description: <SchemaDefinition schemaRef="#/components/schemas/SamplingStrategy"
@ -3790,6 +3860,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolResponseMessage"
/>
name: ToolResponseMessage
- description: <SchemaDefinition schemaRef="#/components/schemas/URL" />
name: URL
- description: <SchemaDefinition schemaRef="#/components/schemas/UserMessage" />
name: UserMessage
- description: <SchemaDefinition schemaRef="#/components/schemas/BatchChatCompletionRequest"
@ -3876,8 +3948,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/SearchToolDefinition"
/>
name: SearchToolDefinition
- description: <SchemaDefinition schemaRef="#/components/schemas/URL" />
name: URL
- description: <SchemaDefinition schemaRef="#/components/schemas/WolframAlphaToolDefinition"
/>
name: WolframAlphaToolDefinition
@ -4233,6 +4303,7 @@ x-tagGroups:
- FunctionCallToolDefinition
- GetAgentsSessionRequest
- GetDocumentsRequest
- ImageMedia
- InferenceStep
- InsertDocumentsRequest
- LogEventRequest

View file

@ -94,14 +94,16 @@ class AgentsClient(Agents):
print(f"Error with parsing or validation: {e}")
async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
async def _run_agent(
api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
):
agent_config = AgentConfig(
model="Meta-Llama3.1-8B-Instruct",
model=model,
instructions="You are a helpful assistant",
sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
sampling_params=SamplingParams(temperature=0.6, top_p=0.9),
tools=tool_definitions,
tool_choice=ToolChoice.auto,
tool_prompt_format=ToolPromptFormat.function_tag,
tool_prompt_format=tool_prompt_format,
enable_session_persistence=False,
)
@ -130,7 +132,8 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
log.print()
async def run_main(host: str, port: int):
async def run_llama_3_1(host: str, port: int):
model = "Llama3.1-8B-Instruct"
api = AgentsClient(f"http://{host}:{port}")
tool_definitions = [
@ -167,10 +170,11 @@ async def run_main(host: str, port: int):
"Write code to check if a number is prime. Use that to check if 7 is prime",
"What is the boiling point of polyjuicepotion ?",
]
await _run_agent(api, tool_definitions, user_prompts)
await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
async def run_rag(host: str, port: int):
async def run_llama_3_2_rag(host: str, port: int):
model = "Llama3.2-3B-Instruct"
api = AgentsClient(f"http://{host}:{port}")
urls = [
@ -206,12 +210,71 @@ async def run_rag(host: str, port: int):
"Tell me briefly about llama3 and torchtune",
]
await _run_agent(api, tool_definitions, user_prompts, attachments)
await _run_agent(
api, model, tool_definitions, ToolPromptFormat.json, user_prompts, attachments
)
def main(host: str, port: int, rag: bool = False):
fn = run_rag if rag else run_main
asyncio.run(fn(host, port))
async def run_llama_3_2(host: str, port: int):
model = "Llama3.2-3B-Instruct"
api = AgentsClient(f"http://{host}:{port}")
# zero shot tools for llama3.2 text models
tool_definitions = [
FunctionCallToolDefinition(
function_name="get_boiling_point",
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
parameters={
"liquid_name": ToolParamDefinition(
param_type="str",
description="The name of the liquid",
required=True,
),
"celcius": ToolParamDefinition(
param_type="bool",
description="Whether to return the boiling point in Celcius",
required=False,
),
},
),
FunctionCallToolDefinition(
function_name="make_web_search",
description="Search the web / internet for more realtime information",
parameters={
"query": ToolParamDefinition(
param_type="str",
description="the query to search for",
required=True,
),
},
),
]
user_prompts = [
"Who are you?",
"what is the 100th prime number?",
"Who was 44th President of USA?",
# multiple tool calls in a single prompt
"What is the boiling point of polyjuicepotion and pinkponklyjuice?",
]
await _run_agent(
api, model, tool_definitions, ToolPromptFormat.python_list, user_prompts
)
def main(host: str, port: int, run_type: str):
assert run_type in [
"tools_llama_3_1",
"tools_llama_3_2",
"rag_llama_3_2",
], f"Invalid run type {run_type}, must be one of tools_llama_3_1, tools_llama_3_2, rag_llama_3_2"
fn = {
"tools_llama_3_1": run_llama_3_1,
"tools_llama_3_2": run_llama_3_2,
"rag_llama_3_2": run_llama_3_2_rag,
}
asyncio.run(fn[run_type](host, port))
if __name__ == "__main__":

View file

@ -10,6 +10,10 @@ from typing import Any, AsyncGenerator, List, Optional
import fire
import httpx
from llama_models.llama3.api.datatypes import ImageMedia, URL
from PIL import Image as PIL_Image
from pydantic import BaseModel
from llama_models.llama3.api import * # noqa: F403
@ -105,7 +109,7 @@ async def run_main(host: str, port: int, stream: bool):
)
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
model="Meta-Llama3.1-8B-Instruct",
model="Llama3.1-8B-Instruct",
messages=[message],
stream=stream,
)
@ -113,7 +117,33 @@ async def run_main(host: str, port: int, stream: bool):
log.print()
def main(host: str, port: int, stream: bool = True):
async def run_mm_main(host: str, port: int, stream: bool, path: str):
client = InferenceClient(f"http://{host}:{port}")
with open(path, "rb") as f:
img = PIL_Image.open(f).convert("RGB")
message = UserMessage(
content=[
ImageMedia(image=URL(uri=f"file://{path}")),
# ImageMedia(image=img),
"Describe this image in two sentences",
],
)
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
model="Llama3.2-11B-Vision-Instruct",
messages=[message],
stream=stream,
)
async for log in EventLogger().log(iterator):
log.print()
def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
if mm:
asyncio.run(run_mm_main(host, port, stream, file))
else:
asyncio.run(run_main(host, port, stream))

View file

@ -7,11 +7,11 @@
from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.memory import MemoryBankType
from llama_stack.distribution.datatypes import GenericProviderConfig
from pydantic import BaseModel, Field
@json_schema_type

View file

@ -51,6 +51,11 @@ class SafetyClient(Safety):
),
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps(
{
"together_api_key": "1882f9a484fc7c6ce3e4dc90272d5db52346c93838daab3d704803181f396b22"
}
),
},
timeout=20,
)

View file

@ -44,7 +44,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--source",
choices=["meta", "huggingface"],
required=True,
default="meta",
)
parser.add_argument(
"--model-id",
@ -116,7 +116,7 @@ def _hf_download(
"You can find your token by visiting https://huggingface.co/settings/tokens"
)
except RepositoryNotFoundError:
parser.error(f"Repository '{args.repo_id}' not found on the Hugging Face Hub.")
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
except Exception as e:
parser.error(e)

View file

@ -9,7 +9,7 @@ import argparse
from llama_stack.cli.model.describe import ModelDescribe
from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.template import ModelTemplate
from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.subcommand import Subcommand
@ -30,5 +30,5 @@ class ModelParser(Subcommand):
# Add sub-commands
ModelDownload.create(subparsers)
ModelList.create(subparsers)
ModelTemplate.create(subparsers)
ModelPromptFormat.create(subparsers)
ModelDescribe.create(subparsers)

View file

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import subprocess
import textwrap
from io import StringIO
from llama_models.datatypes import CoreModelId, is_multimodal, model_family, ModelFamily
from llama_stack.cli.subcommand import Subcommand
class ModelPromptFormat(Subcommand):
"""Llama model cli for describe a model prompt format (message formats)"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"prompt-format",
prog="llama model prompt-format",
description="Show llama model message formats",
epilog=textwrap.dedent(
"""
Example:
llama model prompt-format <options>
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_template_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model-name",
type=str,
default="llama3_1",
help="Model Family (llama3_1, llama3_X, etc.)",
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
import pkg_resources
# Only Llama 3.1 and 3.2 are supported
supported_model_ids = [
m
for m in CoreModelId
if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
]
model_str = "\n".join([m.value for m in supported_model_ids])
try:
model_id = CoreModelId(args.model_name)
except ValueError:
raise argparse.ArgumentTypeError(
f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}"
) from None
if model_id not in supported_model_ids:
raise argparse.ArgumentTypeError(
f"{model_id} is not a valid Model. Choose one from --\n {model_str}"
) from None
llama_3_1_file = pkg_resources.resource_filename(
"llama_models", "llama3_1/prompt_format.md"
)
llama_3_2_text_file = pkg_resources.resource_filename(
"llama_models", "llama3_2/text_prompt_format.md"
)
llama_3_2_vision_file = pkg_resources.resource_filename(
"llama_models", "llama3_2/vision_prompt_format.md"
)
if model_family(model_id) == ModelFamily.llama3_1:
with open(llama_3_1_file, "r") as f:
content = f.read()
elif model_family(model_id) == ModelFamily.llama3_2:
if is_multimodal(model_id):
with open(llama_3_2_vision_file, "r") as f:
content = f.read()
else:
with open(llama_3_2_text_file, "r") as f:
content = f.read()
render_markdown_to_pager(content)
def render_markdown_to_pager(markdown_content: str):
from rich.console import Console
from rich.markdown import Markdown
from rich.style import Style
from rich.text import Text
class LeftAlignedHeaderMarkdown(Markdown):
def parse_header(self, token):
level = token.type.count("h")
content = Text(token.content)
header_style = Style(color="bright_blue", bold=True)
header = Text(f"{'#' * level} ", style=header_style) + content
self.add_text(header)
# Render the Markdown
md = LeftAlignedHeaderMarkdown(markdown_content)
# Capture the rendered output
output = StringIO()
console = Console(file=output, force_terminal=True, width=100) # Set a fixed width
console.print(md)
rendered_content = output.getvalue()
# Pipe to pager
pager = subprocess.Popen(["less", "-R"], stdin=subprocess.PIPE)
pager.communicate(input=rendered_content.encode())

View file

@ -1,113 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from termcolor import colored
from llama_stack.cli.subcommand import Subcommand
class ModelTemplate(Subcommand):
"""Llama model cli for describe a model template (message formats)"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"template",
prog="llama model template",
description="Show llama model message formats",
epilog=textwrap.dedent(
"""
Example:
llama model template <options>
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_template_cmd)
def _prompt_type(self, value):
from llama_models.llama3.api.datatypes import ToolPromptFormat
try:
return ToolPromptFormat(value.lower())
except ValueError:
raise argparse.ArgumentTypeError(
f"{value} is not a valid ToolPromptFormat. Choose from {', '.join(t.value for t in ToolPromptFormat)}"
) from None
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model-family",
type=str,
default="llama3_1",
help="Model Family (llama3_1, llama3_X, etc.)",
)
self.parser.add_argument(
"--name",
type=str,
help="Usecase template name (system_message, user_message, assistant_message, tool_message)...",
required=False,
)
self.parser.add_argument(
"--format",
type=str,
help="ToolPromptFormat (json or function_tag). This flag is used to print the template in a specific formats.",
required=False,
default="json",
)
self.parser.add_argument(
"--raw",
action="store_true",
help="If set to true, don't pretty-print into a table. Useful to copy-paste.",
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
from llama_models.llama3.api.interface import (
list_jinja_templates,
render_jinja_template,
)
from llama_stack.cli.table import print_table
if args.name:
tool_prompt_format = self._prompt_type(args.format)
template, tokens_info = render_jinja_template(args.name, tool_prompt_format)
rendered = ""
for tok, is_special in tokens_info:
if is_special:
rendered += colored(tok, "yellow", attrs=["bold"])
else:
rendered += tok
if not args.raw:
rendered = rendered.replace("\n", "\n")
print_table(
[
(
"Name",
colored(template.template_name, "white", attrs=["bold"]),
),
("Template", rendered),
("Notes", template.notes),
],
separate_rows=True,
)
else:
print("Template: ", template.template_name)
print("=" * 40)
print(rendered)
else:
templates = list_jinja_templates()
headers = ["Role", "Template Name"]
print_table(
[(t.role, t.template_name) for t in templates],
headers,
)

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -8,6 +8,7 @@
DOCKER_BINARY=${DOCKER_BINARY:-docker}
DOCKER_OPTS=${DOCKER_OPTS:-}
LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
set -euo pipefail
@ -37,6 +38,20 @@ port="$1"
shift
set -x
if [ -n "$LLAMA_CHECKPOINT_DIR" ]; then
$DOCKER_BINARY run $DOCKER_OPTS -it \
-p $port:$port \
-v "$yaml_config:/app/config.yaml" \
-v "$LLAMA_CHECKPOINT_DIR:/root/.llama" \
--gpus=all \
$docker_image \
python -m llama_stack.distribution.server.server \
--yaml_config /app/config.yaml \
--port $port "$@"
fi
if [ -z "$LLAMA_CHECKPOINT_DIR" ]; then
$DOCKER_BINARY run $DOCKER_OPTS -it \
-p $port:$port \
-v "$yaml_config:/app/config.yaml" \
@ -44,3 +59,4 @@ $DOCKER_BINARY run $DOCKER_OPTS -it \
python -m llama_stack.distribution.server.server \
--yaml_config /app/config.yaml \
--port $port "$@"
fi

View file

@ -15,14 +15,16 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import FireworksImplConfig
FIREWORKS_SUPPORTED_MODELS = {
"Meta-Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Meta-Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
"Meta-Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
}
@ -106,7 +108,7 @@ class FireworksInferenceAdapter(Inference):
logprobs=logprobs,
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to fireworks
options = self.get_fireworks_chat_options(request)

View file

@ -16,14 +16,16 @@ from llama_models.sku_list import resolve_model
from ollama import AsyncClient
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
# "Meta-Llama3.1-8B-Instruct": "llama3.1",
"Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
# "Llama3.1-8B-Instruct": "llama3.1",
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
}
@ -115,7 +117,7 @@ class OllamaInferenceAdapter(Inference):
logprobs=logprobs,
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
ollama_model = self.resolve_ollama_model(request.model)

View file

@ -14,7 +14,9 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import TGIImplConfig
@ -95,7 +97,7 @@ class TGIAdapter(Inference):
logprobs=logprobs,
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens)

View file

@ -15,14 +15,16 @@ from llama_models.sku_list import resolve_model
from together import Together
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import TogetherImplConfig
TOGETHER_SUPPORTED_MODELS = {
"Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
"Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct-Turbo",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct-Turbo",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-Turbo",
}
@ -110,7 +112,7 @@ class TogetherInferenceAdapter(Inference):
# accumulate sampling params and other options to pass to together
options = self.get_together_chat_options(request)
together_model = self.resolve_together_model(request.model)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
if not request.stream:
# TODO: might need to add back an async here

View file

@ -0,0 +1,548 @@
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 60;
objects = {
/* Begin PBXBuildFile section */
5C03561F2CA3AB9600E3BB46 /* LlamaStackClient in Frameworks */ = {isa = PBXBuildFile; productRef = 5C03561E2CA3AB9600E3BB46 /* LlamaStackClient */; };
5C5B6E212CA3D89F00AF6130 /* LlamaStackClient in Frameworks */ = {isa = PBXBuildFile; productRef = 5C5B6E202CA3D89F00AF6130 /* LlamaStackClient */; };
5CCBC60C2CA1F04A00E958D0 /* LocalInference.h in Headers */ = {isa = PBXBuildFile; fileRef = 5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */; settings = {ATTRIBUTES = (Public, ); }; };
5CCBC6752CA1F45800E958D0 /* executorch_debug in Frameworks */ = {isa = PBXBuildFile; productRef = 5CCBC6742CA1F45800E958D0 /* executorch_debug */; };
5CCBC6862CA1F64A00E958D0 /* LLaMARunner.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */; platformFilter = ios; };
5CCBC6872CA1F64A00E958D0 /* LLaMARunner.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = 5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */; platformFilter = ios; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
5CCBC68D2CA1F7A100E958D0 /* PromptTemplate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */; };
5CCBC68E2CA1F7A100E958D0 /* LocalInference.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */; };
5CCBC68F2CA1F7A100E958D0 /* Parsing.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */; };
5CCBC6902CA1F7A100E958D0 /* SystemPrompts.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */; };
5CCBC6932CA1F7D000E958D0 /* Stencil in Frameworks */ = {isa = PBXBuildFile; productRef = 5CCBC6922CA1F7D000E958D0 /* Stencil */; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
5CCBC67D2CA1F63F00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 036CAF9D2BB1444500D6C2D5;
remoteInfo = LLaMA;
};
5CCBC67F2CA1F63F00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 03729ED52BB1F8DE00152F2E;
remoteInfo = LLaMARunner;
};
5CCBC69E2CA2036B00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 5CCBC6982CA2036A00E958D0;
remoteInfo = LLaMAPerfBenchmark;
};
5CCBC6A02CA2036B00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 5CCBC6992CA2036A00E958D0;
remoteInfo = LLaMAPerfBenchmarkTests;
};
/* End PBXContainerItemProxy section */
/* Begin PBXCopyFilesBuildPhase section */
5CCBC6882CA1F64A00E958D0 /* Embed Frameworks */ = {
isa = PBXCopyFilesBuildPhase;
buildActionMask = 2147483647;
dstPath = "";
dstSubfolderSpec = 10;
files = (
5CCBC6872CA1F64A00E958D0 /* LLaMARunner.framework in Embed Frameworks */,
);
name = "Embed Frameworks";
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXCopyFilesBuildPhase section */
/* Begin PBXFileReference section */
5CCBC6082CA1F04A00E958D0 /* LocalInference.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = LocalInference.framework; sourceTree = BUILT_PRODUCTS_DIR; };
5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LocalInference.h; sourceTree = "<group>"; };
5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */ = {isa = PBXFileReference; lastKnownFileType = "wrapper.pb-project"; name = LLaMA.xcodeproj; path = "executorch/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj"; sourceTree = "<group>"; };
5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PromptTemplate.swift; sourceTree = "<group>"; };
5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LocalInference.swift; sourceTree = "<group>"; };
5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Parsing.swift; sourceTree = "<group>"; };
5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SystemPrompts.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
5CCBC6052CA1F04A00E958D0 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
5C03561F2CA3AB9600E3BB46 /* LlamaStackClient in Frameworks */,
5C5B6E212CA3D89F00AF6130 /* LlamaStackClient in Frameworks */,
5CCBC6932CA1F7D000E958D0 /* Stencil in Frameworks */,
5CCBC6862CA1F64A00E958D0 /* LLaMARunner.framework in Frameworks */,
5CCBC6752CA1F45800E958D0 /* executorch_debug in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
5CCBC5FE2CA1F04A00E958D0 = {
isa = PBXGroup;
children = (
5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */,
5CCBC60A2CA1F04A00E958D0 /* LocalInference */,
5CCBC6092CA1F04A00E958D0 /* Products */,
5CCBC6852CA1F64A00E958D0 /* Frameworks */,
);
sourceTree = "<group>";
};
5CCBC6092CA1F04A00E958D0 /* Products */ = {
isa = PBXGroup;
children = (
5CCBC6082CA1F04A00E958D0 /* LocalInference.framework */,
);
name = Products;
sourceTree = "<group>";
};
5CCBC60A2CA1F04A00E958D0 /* LocalInference */ = {
isa = PBXGroup;
children = (
5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */,
5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */,
5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */,
5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */,
5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */,
);
path = LocalInference;
sourceTree = "<group>";
};
5CCBC6772CA1F63F00E958D0 /* Products */ = {
isa = PBXGroup;
children = (
5CCBC67E2CA1F63F00E958D0 /* LLaMA.app */,
5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */,
5CCBC69F2CA2036B00E958D0 /* LLaMAPerfBenchmark.app */,
5CCBC6A12CA2036B00E958D0 /* LLaMAPerfBenchmarkTests.xctest */,
);
name = Products;
sourceTree = "<group>";
};
5CCBC6852CA1F64A00E958D0 /* Frameworks */ = {
isa = PBXGroup;
children = (
);
name = Frameworks;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXHeadersBuildPhase section */
5CCBC6032CA1F04A00E958D0 /* Headers */ = {
isa = PBXHeadersBuildPhase;
buildActionMask = 2147483647;
files = (
5CCBC60C2CA1F04A00E958D0 /* LocalInference.h in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXHeadersBuildPhase section */
/* Begin PBXNativeTarget section */
5CCBC6072CA1F04A00E958D0 /* LocalInference */ = {
isa = PBXNativeTarget;
buildConfigurationList = 5CCBC60F2CA1F04A00E958D0 /* Build configuration list for PBXNativeTarget "LocalInference" */;
buildPhases = (
5CCBC6032CA1F04A00E958D0 /* Headers */,
5CCBC6042CA1F04A00E958D0 /* Sources */,
5CCBC6052CA1F04A00E958D0 /* Frameworks */,
5CCBC6062CA1F04A00E958D0 /* Resources */,
5CCBC6882CA1F64A00E958D0 /* Embed Frameworks */,
);
buildRules = (
);
dependencies = (
);
name = LocalInference;
packageProductDependencies = (
5CCBC6742CA1F45800E958D0 /* executorch_debug */,
5CCBC6922CA1F7D000E958D0 /* Stencil */,
5C03561E2CA3AB9600E3BB46 /* LlamaStackClient */,
5C5B6E202CA3D89F00AF6130 /* LlamaStackClient */,
);
productName = LocalInferenceProvider;
productReference = 5CCBC6082CA1F04A00E958D0 /* LocalInference.framework */;
productType = "com.apple.product-type.framework";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
5CCBC5FF2CA1F04A00E958D0 /* Project object */ = {
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastUpgradeCheck = 1540;
TargetAttributes = {
5CCBC6072CA1F04A00E958D0 = {
CreatedOnToolsVersion = 15.4;
LastSwiftMigration = 1540;
};
};
};
buildConfigurationList = 5CCBC6022CA1F04A00E958D0 /* Build configuration list for PBXProject "LocalInference" */;
compatibilityVersion = "Xcode 14.0";
developmentRegion = en;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = 5CCBC5FE2CA1F04A00E958D0;
packageReferences = (
5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */,
5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */,
5C5B6E1F2CA3D89F00AF6130 /* XCLocalSwiftPackageReference "internal-llama-stack-client-swift" */,
);
productRefGroup = 5CCBC6092CA1F04A00E958D0 /* Products */;
projectDirPath = "";
projectReferences = (
{
ProductGroup = 5CCBC6772CA1F63F00E958D0 /* Products */;
ProjectRef = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
},
);
projectRoot = "";
targets = (
5CCBC6072CA1F04A00E958D0 /* LocalInference */,
);
};
/* End PBXProject section */
/* Begin PBXReferenceProxy section */
5CCBC67E2CA1F63F00E958D0 /* LLaMA.app */ = {
isa = PBXReferenceProxy;
fileType = wrapper.application;
path = LLaMA.app;
remoteRef = 5CCBC67D2CA1F63F00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */ = {
isa = PBXReferenceProxy;
fileType = wrapper.framework;
path = LLaMARunner.framework;
remoteRef = 5CCBC67F2CA1F63F00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC69F2CA2036B00E958D0 /* LLaMAPerfBenchmark.app */ = {
isa = PBXReferenceProxy;
fileType = wrapper.application;
path = LLaMAPerfBenchmark.app;
remoteRef = 5CCBC69E2CA2036B00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC6A12CA2036B00E958D0 /* LLaMAPerfBenchmarkTests.xctest */ = {
isa = PBXReferenceProxy;
fileType = wrapper.cfbundle;
path = LLaMAPerfBenchmarkTests.xctest;
remoteRef = 5CCBC6A02CA2036B00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
/* End PBXReferenceProxy section */
/* Begin PBXResourcesBuildPhase section */
5CCBC6062CA1F04A00E958D0 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
5CCBC6042CA1F04A00E958D0 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
5CCBC6902CA1F7A100E958D0 /* SystemPrompts.swift in Sources */,
5CCBC68D2CA1F7A100E958D0 /* PromptTemplate.swift in Sources */,
5CCBC68F2CA1F7A100E958D0 /* Parsing.swift in Sources */,
5CCBC68E2CA1F7A100E958D0 /* LocalInference.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin XCBuildConfiguration section */
5CCBC60D2CA1F04A00E958D0 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
};
name = Debug;
};
5CCBC60E2CA1F04A00E958D0 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
SDKROOT = iphoneos;
SWIFT_COMPILATION_MODE = wholemodule;
VALIDATE_PRODUCT = YES;
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
};
name = Release;
};
5CCBC6102CA1F04A00E958D0 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = "";
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
OTHER_LDFLAGS = "";
PRODUCT_BUNDLE_IDENTIFIER = meta.llamatsack.LocalInferenceProvider;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Debug;
};
5CCBC6112CA1F04A00E958D0 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = "";
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
OTHER_LDFLAGS = "";
PRODUCT_BUNDLE_IDENTIFIER = meta.llamatsack.LocalInferenceProvider;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
5CCBC6022CA1F04A00E958D0 /* Build configuration list for PBXProject "LocalInference" */ = {
isa = XCConfigurationList;
buildConfigurations = (
5CCBC60D2CA1F04A00E958D0 /* Debug */,
5CCBC60E2CA1F04A00E958D0 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
5CCBC60F2CA1F04A00E958D0 /* Build configuration list for PBXNativeTarget "LocalInference" */ = {
isa = XCConfigurationList;
buildConfigurations = (
5CCBC6102CA1F04A00E958D0 /* Debug */,
5CCBC6112CA1F04A00E958D0 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
/* Begin XCLocalSwiftPackageReference section */
5C5B6E1F2CA3D89F00AF6130 /* XCLocalSwiftPackageReference "internal-llama-stack-client-swift" */ = {
isa = XCLocalSwiftPackageReference;
relativePath = "internal-llama-stack-client-swift";
};
/* End XCLocalSwiftPackageReference section */
/* Begin XCRemoteSwiftPackageReference section */
5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/pytorch/executorch";
requirement = {
branch = latest;
kind = branch;
};
};
5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/stencilproject/Stencil";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 0.15.1;
};
};
/* End XCRemoteSwiftPackageReference section */
/* Begin XCSwiftPackageProductDependency section */
5C03561E2CA3AB9600E3BB46 /* LlamaStackClient */ = {
isa = XCSwiftPackageProductDependency;
productName = LlamaStackClient;
};
5C5B6E202CA3D89F00AF6130 /* LlamaStackClient */ = {
isa = XCSwiftPackageProductDependency;
productName = LlamaStackClient;
};
5CCBC6742CA1F45800E958D0 /* executorch_debug */ = {
isa = XCSwiftPackageProductDependency;
package = 5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */;
productName = executorch_debug;
};
5CCBC6922CA1F7D000E958D0 /* Stencil */ = {
isa = XCSwiftPackageProductDependency;
package = 5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */;
productName = Stencil;
};
/* End XCSwiftPackageProductDependency section */
};
rootObject = 5CCBC5FF2CA1F04A00E958D0 /* Project object */;
}

View file

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
version = "1.0">
<FileRef
location = "self:">
</FileRef>
</Workspace>

View file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>IDEDidComputeMac32BitWarning</key>
<true/>
</dict>
</plist>

View file

@ -0,0 +1,16 @@
//
// LocalInference.h
// LocalInference
//
// Created by Dalton Flanagan on 9/23/24.
//
#import <Foundation/Foundation.h>
//! Project version number for LocalInference.
FOUNDATION_EXPORT double LocalInferenceVersionNumber;
//! Project version string for LocalInference.
FOUNDATION_EXPORT const unsigned char LocalInferenceVersionString[];
// In this header, you should import all the public headers of your framework using statements like #import <LocalInference/PublicHeader.h>

View file

@ -0,0 +1,167 @@
import Foundation
import LLaMARunner
import LlamaStackClient
class RunnerHolder: ObservableObject {
var runner: Runner?
}
public class LocalInference: Inference {
private var runnerHolder = RunnerHolder()
private let runnerQueue: DispatchQueue
public init (queue: DispatchQueue) {
runnerQueue = queue
}
public func loadModel(modelPath: String, tokenizerPath: String, completion: @escaping (Result<Void, Error>) -> Void) {
runnerHolder.runner = runnerHolder.runner ?? Runner(
modelPath: modelPath,
tokenizerPath: tokenizerPath
)
runnerQueue.async {
let runner = self.runnerHolder.runner
do {
try runner!.load()
completion(.success(()))
} catch let loadError {
print("error: " + loadError.localizedDescription)
completion(.failure(loadError))
}
}
}
public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
return AsyncStream { continuation in
runnerQueue.async {
do {
var tokens: [String] = []
let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
var stopReason: Components.Schemas.StopReason? = nil
var buffer = ""
var ipython = false
var echoDropped = false
try self.runnerHolder.runner?.generate(prompt, sequenceLength: 4096) { token in
buffer += token
// HACK: Workaround until LlamaRunner exposes echo param
if (!echoDropped) {
if (buffer.hasPrefix(prompt)) {
buffer = String(buffer.dropFirst(prompt.count))
echoDropped = true
}
return
}
tokens.append(token)
if !ipython && (buffer.starts(with: "<|python_tag|>") || buffer.starts(with: "[") ) {
ipython = true
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .case1(""),
parse_status: Components.Schemas.ToolCallParseStatus.started
)
),
event_type: .progress
)
)
)
if (buffer.starts(with: "<|python_tag|>")) {
buffer = String(buffer.dropFirst("<|python_tag|>".count))
}
}
// TODO: Non-streaming lobprobs
var text = ""
if token == "<|eot_id|>" {
stopReason = Components.Schemas.StopReason.end_of_turn
} else if token == "<|eom_id|>" {
stopReason = Components.Schemas.StopReason.end_of_message
} else {
text = token
}
var delta: Components.Schemas.ChatCompletionResponseEvent.deltaPayload
if ipython {
delta = .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .case1(text),
parse_status: .in_progress
))
} else {
delta = .case1(text)
}
if stopReason == nil {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: delta,
event_type: .progress
)
)
)
}
}
if stopReason == nil {
stopReason = Components.Schemas.StopReason.out_of_tokens
}
let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: stopReason!)
// TODO: non-streaming support
let didParseToolCalls = message.tool_calls.count > 0
if ipython && !didParseToolCalls {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(content: .case1(""), parse_status: .failure)),
event_type: .progress
)
// TODO: stopReason
)
)
}
for toolCall in message.tool_calls {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .ToolCall(toolCall),
parse_status: .success
)),
event_type: .progress
)
// TODO: stopReason
)
)
}
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .case1(""),
event_type: .complete
)
// TODO: stopReason
)
)
}
catch (let error) {
print("Inference error: " + error.localizedDescription)
}
}
}
}
}

View file

@ -0,0 +1,235 @@
import Foundation
import LlamaStackClient
func encodeHeader(role: String) -> String {
return "<|start_header_id|>\(role)<|end_header_id|>\n\n"
}
func encodeDialogPrompt(messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload]) -> String {
var prompt = ""
prompt.append("<|begin_of_text|>")
for message in messages {
let msg = encodeMessage(message: message)
prompt += msg
}
prompt.append(encodeHeader(role: "assistant"))
return prompt
}
func getRole(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
switch (message) {
case .UserMessage(let m):
return m.role.rawValue
case .SystemMessage(let m):
return m.role.rawValue
case .ToolResponseMessage(let m):
return m.role.rawValue
case .CompletionMessage(let m):
return m.role.rawValue
}
}
func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
var prompt = encodeHeader(role: getRole(message: message))
switch (message) {
case .CompletionMessage(let m):
if (m.tool_calls.count > 0) {
prompt += "<|python_tag|>"
}
default:
break
}
func _processContent(_ content: Any) -> String {
func _process(_ c: Any) {
if let str = c as? String {
prompt += str
}
}
if let str = content as? String {
_process(str)
} else if let list = content as? [Any] {
for c in list {
_process(c)
}
}
return ""
}
switch (message) {
case .UserMessage(let m):
prompt += _processContent(m.content)
case .SystemMessage(let m):
prompt += _processContent(m.content)
case .ToolResponseMessage(let m):
prompt += _processContent(m.content)
case .CompletionMessage(let m):
prompt += _processContent(m.content)
}
var eom = false
switch (message) {
case .UserMessage(let m):
switch (m.content) {
case .case1(let c):
prompt += _processContent(c)
case .case2(let c):
prompt += _processContent(c)
}
case .CompletionMessage(let m):
// TODO: Support encoding past tool call history
// for t in m.tool_calls {
// _processContent(t.)
//}
eom = m.stop_reason == Components.Schemas.StopReason.end_of_message
case .SystemMessage(_):
break
case .ToolResponseMessage(_):
break
}
if (eom) {
prompt += "<|eom_id|>"
} else {
prompt += "<|eot_id|>"
}
return prompt
}
func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] {
var existingMessages = request.messages
var existingSystemMessage: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload?
// TODO: Existing system message
var messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] = []
let defaultGen = SystemDefaultGenerator()
let defaultTemplate = defaultGen.gen()
var sysContent = ""
// TODO: Built-in tools
sysContent += try defaultTemplate.render()
messages.append(.SystemMessage(Components.Schemas.SystemMessage(
content: .case1(sysContent),
role: .system))
)
if request.tools?.isEmpty == false {
// TODO: Separate built-ins and custom tools (right now everything treated as custom)
let toolGen = FunctionTagCustomToolGenerator()
let toolTemplate = try toolGen.gen(customTools: request.tools!)
let tools = try toolTemplate.render()
messages.append(.UserMessage(Components.Schemas.UserMessage(
content: .case1(tools),
role: .user)
))
}
messages.append(contentsOf: existingMessages)
return messages
}
struct FunctionCall {
let name: String
let params: [String: Any]
}
public func maybeExtractCustomToolCalls(input: String) -> [Components.Schemas.ToolCall] {
guard input.hasPrefix("[") && input.hasSuffix("]") else {
return []
}
do {
let trimmed = input.trimmingCharacters(in: CharacterSet(charactersIn: "[]"))
let calls = trimmed.components(separatedBy: "),").map { $0.hasSuffix(")") ? $0 : $0 + ")" }
var result: [Components.Schemas.ToolCall] = []
for call in calls {
guard let nameEndIndex = call.firstIndex(of: "("),
let paramsStartIndex = call.firstIndex(of: "{"),
let paramsEndIndex = call.lastIndex(of: "}") else {
return []
}
let name = String(call[..<nameEndIndex]).trimmingCharacters(in: .whitespacesAndNewlines)
let paramsString = String(call[paramsStartIndex...paramsEndIndex])
guard let data = paramsString.data(using: .utf8),
let params = try? JSONSerialization.jsonObject(with: data, options: []) as? [String: Any] else {
return []
}
var props: [String : Components.Schemas.ToolCall.argumentsPayload.additionalPropertiesPayload] = [:]
for (param_name, param) in params {
switch (param) {
case let value as String:
props[param_name] = .case1(value)
case let value as Int:
props[param_name] = .case2(value)
case let value as Double:
props[param_name] = .case3(value)
case let value as Bool:
props[param_name] = .case4(value)
default:
return []
}
}
result.append(
Components.Schemas.ToolCall(
arguments: .init(additionalProperties: props),
call_id: UUID().uuidString,
tool_name: .case2(name) // custom_tool
)
)
}
return result.isEmpty ? [] : result
} catch {
return []
}
}
func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.StopReason) -> Components.Schemas.CompletionMessage {
var content = tokens
let roles = ["user", "system", "assistant"]
for role in roles {
let headerStr = encodeHeader(role: role)
if content.hasPrefix(headerStr) {
content = String(content.dropFirst(encodeHeader(role: role).count))
}
}
if content.hasPrefix("<|python_tag|>") {
content = String(content.dropFirst("<|python_tag|>".count))
}
if content.hasSuffix("<|eot_id|>") {
content = String(content.dropLast("<|eot_id|>".count))
} else {
content = String(content.dropLast("<|eom_id|>".count))
}
return Components.Schemas.CompletionMessage(
content: .case1(content),
role: .assistant,
stop_reason: stopReason,
tool_calls: maybeExtractCustomToolCalls(input: content)
)
}

View file

@ -0,0 +1,12 @@
import Foundation
import Stencil
public struct PromptTemplate {
let template: String
let data: [String: Any]
public func render() throws -> String {
let template = Template(templateString: self.template)
return try template.render(self.data)
}
}

View file

@ -0,0 +1,91 @@
import Foundation
import LlamaStackClient
func convertToNativeSwiftType(_ value: Any) -> Any {
switch value {
case let number as NSNumber:
if CFGetTypeID(number) == CFBooleanGetTypeID() {
return number.boolValue
}
if floor(number.doubleValue) == number.doubleValue {
return number.intValue
}
return number.doubleValue
case let string as String:
return string
case let array as [Any]:
return array.map(convertToNativeSwiftType)
case let dict as [String: Any]:
return dict.mapValues(convertToNativeSwiftType)
case is NSNull:
return NSNull()
default:
return value
}
}
public class SystemDefaultGenerator {
public init() {}
public func gen() -> PromptTemplate {
let templateStr = """
Cutting Knowledge Date: December 2023
Today Date: {{ today }}
"""
let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "dd MMMM yyyy"
return PromptTemplate(
template: templateStr,
data: ["today": dateFormatter.string(from: Date())]
)
}
}
public class FunctionTagCustomToolGenerator {
public init() {}
public func gen(customTools: [Components.Schemas.ToolDefinition]) throws -> PromptTemplate {
// TODO: required params
// TODO: {{#unless @last}},{{/unless}}
let templateStr = """
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{% for t in custom_tools %}
{
"name": "{{t.tool_name}}",
"description": "{{t.description}}",
"parameters": {
"type": "dict",
"properties": { {{t.parameters}} }
}
{{/let}}
{% endfor -%}
]
"""
let encoder = JSONEncoder()
return PromptTemplate(
template: templateStr,
data: ["custom_tools": try customTools.map {
let data = try encoder.encode($0)
let obj = try JSONSerialization.jsonObject(with: data)
return convertToNativeSwiftType(obj)
}]
)
}
}

View file

@ -0,0 +1,541 @@
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 60;
objects = {
/* Begin PBXBuildFile section */
5CADC71A2CA471CC007662D2 /* LlamaStackClient in Frameworks */ = {isa = PBXBuildFile; productRef = 5CADC7192CA471CC007662D2 /* LlamaStackClient */; };
5CCBC60C2CA1F04A00E958D0 /* LocalInference.h in Headers */ = {isa = PBXBuildFile; fileRef = 5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */; settings = {ATTRIBUTES = (Public, ); }; };
5CCBC6752CA1F45800E958D0 /* executorch_debug in Frameworks */ = {isa = PBXBuildFile; productRef = 5CCBC6742CA1F45800E958D0 /* executorch_debug */; };
5CCBC6862CA1F64A00E958D0 /* LLaMARunner.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */; platformFilter = ios; };
5CCBC6872CA1F64A00E958D0 /* LLaMARunner.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = 5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */; platformFilter = ios; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
5CCBC68D2CA1F7A100E958D0 /* PromptTemplate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */; };
5CCBC68E2CA1F7A100E958D0 /* LocalInference.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */; };
5CCBC68F2CA1F7A100E958D0 /* Parsing.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */; };
5CCBC6902CA1F7A100E958D0 /* SystemPrompts.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */; };
5CCBC6932CA1F7D000E958D0 /* Stencil in Frameworks */ = {isa = PBXBuildFile; productRef = 5CCBC6922CA1F7D000E958D0 /* Stencil */; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
5CCBC67D2CA1F63F00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 036CAF9D2BB1444500D6C2D5;
remoteInfo = LLaMA;
};
5CCBC67F2CA1F63F00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 03729ED52BB1F8DE00152F2E;
remoteInfo = LLaMARunner;
};
5CCBC69E2CA2036B00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 5CCBC6982CA2036A00E958D0;
remoteInfo = LLaMAPerfBenchmark;
};
5CCBC6A02CA2036B00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 5CCBC6992CA2036A00E958D0;
remoteInfo = LLaMAPerfBenchmarkTests;
};
/* End PBXContainerItemProxy section */
/* Begin PBXCopyFilesBuildPhase section */
5CCBC6882CA1F64A00E958D0 /* Embed Frameworks */ = {
isa = PBXCopyFilesBuildPhase;
buildActionMask = 2147483647;
dstPath = "";
dstSubfolderSpec = 10;
files = (
5CCBC6872CA1F64A00E958D0 /* LLaMARunner.framework in Embed Frameworks */,
);
name = "Embed Frameworks";
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXCopyFilesBuildPhase section */
/* Begin PBXFileReference section */
5CCBC6082CA1F04A00E958D0 /* LocalInferenceImpl.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = LocalInferenceImpl.framework; sourceTree = BUILT_PRODUCTS_DIR; };
5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LocalInference.h; sourceTree = "<group>"; };
5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */ = {isa = PBXFileReference; lastKnownFileType = "wrapper.pb-project"; name = LLaMA.xcodeproj; path = "executorch/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj"; sourceTree = "<group>"; };
5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PromptTemplate.swift; sourceTree = "<group>"; };
5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LocalInference.swift; sourceTree = "<group>"; };
5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Parsing.swift; sourceTree = "<group>"; };
5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SystemPrompts.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
5CCBC6052CA1F04A00E958D0 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
5CADC71A2CA471CC007662D2 /* LlamaStackClient in Frameworks */,
5CCBC6932CA1F7D000E958D0 /* Stencil in Frameworks */,
5CCBC6862CA1F64A00E958D0 /* LLaMARunner.framework in Frameworks */,
5CCBC6752CA1F45800E958D0 /* executorch_debug in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
5CCBC5FE2CA1F04A00E958D0 = {
isa = PBXGroup;
children = (
5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */,
5CCBC60A2CA1F04A00E958D0 /* LocalInferenceImpl */,
5CCBC6092CA1F04A00E958D0 /* Products */,
5CCBC6852CA1F64A00E958D0 /* Frameworks */,
);
sourceTree = "<group>";
};
5CCBC6092CA1F04A00E958D0 /* Products */ = {
isa = PBXGroup;
children = (
5CCBC6082CA1F04A00E958D0 /* LocalInferenceImpl.framework */,
);
name = Products;
sourceTree = "<group>";
};
5CCBC60A2CA1F04A00E958D0 /* LocalInferenceImpl */ = {
isa = PBXGroup;
children = (
5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */,
5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */,
5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */,
5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */,
5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */,
);
path = LocalInferenceImpl;
sourceTree = "<group>";
};
5CCBC6772CA1F63F00E958D0 /* Products */ = {
isa = PBXGroup;
children = (
5CCBC67E2CA1F63F00E958D0 /* LLaMA.app */,
5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */,
5CCBC69F2CA2036B00E958D0 /* LLaMAPerfBenchmark.app */,
5CCBC6A12CA2036B00E958D0 /* LLaMAPerfBenchmarkTests.xctest */,
);
name = Products;
sourceTree = "<group>";
};
5CCBC6852CA1F64A00E958D0 /* Frameworks */ = {
isa = PBXGroup;
children = (
);
name = Frameworks;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXHeadersBuildPhase section */
5CCBC6032CA1F04A00E958D0 /* Headers */ = {
isa = PBXHeadersBuildPhase;
buildActionMask = 2147483647;
files = (
5CCBC60C2CA1F04A00E958D0 /* LocalInference.h in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXHeadersBuildPhase section */
/* Begin PBXNativeTarget section */
5CCBC6072CA1F04A00E958D0 /* LocalInferenceImpl */ = {
isa = PBXNativeTarget;
buildConfigurationList = 5CCBC60F2CA1F04A00E958D0 /* Build configuration list for PBXNativeTarget "LocalInferenceImpl" */;
buildPhases = (
5CCBC6032CA1F04A00E958D0 /* Headers */,
5CCBC6042CA1F04A00E958D0 /* Sources */,
5CCBC6052CA1F04A00E958D0 /* Frameworks */,
5CCBC6062CA1F04A00E958D0 /* Resources */,
5CCBC6882CA1F64A00E958D0 /* Embed Frameworks */,
);
buildRules = (
);
dependencies = (
);
name = LocalInferenceImpl;
packageProductDependencies = (
5CCBC6742CA1F45800E958D0 /* executorch_debug */,
5CCBC6922CA1F7D000E958D0 /* Stencil */,
5CADC7192CA471CC007662D2 /* LlamaStackClient */,
);
productName = LocalInferenceProvider;
productReference = 5CCBC6082CA1F04A00E958D0 /* LocalInferenceImpl.framework */;
productType = "com.apple.product-type.framework";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
5CCBC5FF2CA1F04A00E958D0 /* Project object */ = {
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastUpgradeCheck = 1540;
TargetAttributes = {
5CCBC6072CA1F04A00E958D0 = {
CreatedOnToolsVersion = 15.4;
LastSwiftMigration = 1540;
};
};
};
buildConfigurationList = 5CCBC6022CA1F04A00E958D0 /* Build configuration list for PBXProject "LocalInferenceImpl" */;
compatibilityVersion = "Xcode 14.0";
developmentRegion = en;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = 5CCBC5FE2CA1F04A00E958D0;
packageReferences = (
5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */,
5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */,
5CADC7182CA471CC007662D2 /* XCLocalSwiftPackageReference "internal-llama-stack-client-swift" */,
);
productRefGroup = 5CCBC6092CA1F04A00E958D0 /* Products */;
projectDirPath = "";
projectReferences = (
{
ProductGroup = 5CCBC6772CA1F63F00E958D0 /* Products */;
ProjectRef = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
},
);
projectRoot = "";
targets = (
5CCBC6072CA1F04A00E958D0 /* LocalInferenceImpl */,
);
};
/* End PBXProject section */
/* Begin PBXReferenceProxy section */
5CCBC67E2CA1F63F00E958D0 /* LLaMA.app */ = {
isa = PBXReferenceProxy;
fileType = wrapper.application;
path = LLaMA.app;
remoteRef = 5CCBC67D2CA1F63F00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */ = {
isa = PBXReferenceProxy;
fileType = wrapper.framework;
path = LLaMARunner.framework;
remoteRef = 5CCBC67F2CA1F63F00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC69F2CA2036B00E958D0 /* LLaMAPerfBenchmark.app */ = {
isa = PBXReferenceProxy;
fileType = wrapper.application;
path = LLaMAPerfBenchmark.app;
remoteRef = 5CCBC69E2CA2036B00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC6A12CA2036B00E958D0 /* LLaMAPerfBenchmarkTests.xctest */ = {
isa = PBXReferenceProxy;
fileType = wrapper.cfbundle;
path = LLaMAPerfBenchmarkTests.xctest;
remoteRef = 5CCBC6A02CA2036B00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
/* End PBXReferenceProxy section */
/* Begin PBXResourcesBuildPhase section */
5CCBC6062CA1F04A00E958D0 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
5CCBC6042CA1F04A00E958D0 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
5CCBC6902CA1F7A100E958D0 /* SystemPrompts.swift in Sources */,
5CCBC68D2CA1F7A100E958D0 /* PromptTemplate.swift in Sources */,
5CCBC68F2CA1F7A100E958D0 /* Parsing.swift in Sources */,
5CCBC68E2CA1F7A100E958D0 /* LocalInference.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin XCBuildConfiguration section */
5CCBC60D2CA1F04A00E958D0 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
};
name = Debug;
};
5CCBC60E2CA1F04A00E958D0 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
SDKROOT = iphoneos;
SWIFT_COMPILATION_MODE = wholemodule;
VALIDATE_PRODUCT = YES;
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
};
name = Release;
};
5CCBC6102CA1F04A00E958D0 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = "";
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
OTHER_LDFLAGS = "";
PRODUCT_BUNDLE_IDENTIFIER = meta.llamatsack.LocalInferenceProvider;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Debug;
};
5CCBC6112CA1F04A00E958D0 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = "";
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
OTHER_LDFLAGS = "";
PRODUCT_BUNDLE_IDENTIFIER = meta.llamatsack.LocalInferenceProvider;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
5CCBC6022CA1F04A00E958D0 /* Build configuration list for PBXProject "LocalInferenceImpl" */ = {
isa = XCConfigurationList;
buildConfigurations = (
5CCBC60D2CA1F04A00E958D0 /* Debug */,
5CCBC60E2CA1F04A00E958D0 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
5CCBC60F2CA1F04A00E958D0 /* Build configuration list for PBXNativeTarget "LocalInferenceImpl" */ = {
isa = XCConfigurationList;
buildConfigurations = (
5CCBC6102CA1F04A00E958D0 /* Debug */,
5CCBC6112CA1F04A00E958D0 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
/* Begin XCLocalSwiftPackageReference section */
5CADC7182CA471CC007662D2 /* XCLocalSwiftPackageReference "internal-llama-stack-client-swift" */ = {
isa = XCLocalSwiftPackageReference;
relativePath = "internal-llama-stack-client-swift";
};
/* End XCLocalSwiftPackageReference section */
/* Begin XCRemoteSwiftPackageReference section */
5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/pytorch/executorch";
requirement = {
branch = latest;
kind = branch;
};
};
5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/stencilproject/Stencil";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 0.15.1;
};
};
/* End XCRemoteSwiftPackageReference section */
/* Begin XCSwiftPackageProductDependency section */
5CADC7192CA471CC007662D2 /* LlamaStackClient */ = {
isa = XCSwiftPackageProductDependency;
productName = LlamaStackClient;
};
5CCBC6742CA1F45800E958D0 /* executorch_debug */ = {
isa = XCSwiftPackageProductDependency;
package = 5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */;
productName = executorch_debug;
};
5CCBC6922CA1F7D000E958D0 /* Stencil */ = {
isa = XCSwiftPackageProductDependency;
package = 5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */;
productName = Stencil;
};
/* End XCSwiftPackageProductDependency section */
};
rootObject = 5CCBC5FF2CA1F04A00E958D0 /* Project object */;
}

View file

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
version = "1.0">
<FileRef
location = "self:">
</FileRef>
</Workspace>

View file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>IDEDidComputeMac32BitWarning</key>
<true/>
</dict>
</plist>

View file

@ -0,0 +1,16 @@
//
// LocalInference.h
// LocalInference
//
// Created by Dalton Flanagan on 9/23/24.
//
#import <Foundation/Foundation.h>
//! Project version number for LocalInference.
FOUNDATION_EXPORT double LocalInferenceVersionNumber;
//! Project version string for LocalInference.
FOUNDATION_EXPORT const unsigned char LocalInferenceVersionString[];
// In this header, you should import all the public headers of your framework using statements like #import <LocalInference/PublicHeader.h>

View file

@ -0,0 +1,167 @@
import Foundation
import LLaMARunner
import LlamaStackClient
class RunnerHolder: ObservableObject {
var runner: Runner?
}
public class LocalInference: Inference {
private var runnerHolder = RunnerHolder()
private let runnerQueue: DispatchQueue
public init (queue: DispatchQueue) {
runnerQueue = queue
}
public func loadModel(modelPath: String, tokenizerPath: String, completion: @escaping (Result<Void, Error>) -> Void) {
runnerHolder.runner = runnerHolder.runner ?? Runner(
modelPath: modelPath,
tokenizerPath: tokenizerPath
)
runnerQueue.async {
let runner = self.runnerHolder.runner
do {
try runner!.load()
completion(.success(()))
} catch let loadError {
print("error: " + loadError.localizedDescription)
completion(.failure(loadError))
}
}
}
public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
return AsyncStream { continuation in
runnerQueue.async {
do {
var tokens: [String] = []
let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
var stopReason: Components.Schemas.StopReason? = nil
var buffer = ""
var ipython = false
var echoDropped = false
try self.runnerHolder.runner?.generate(prompt, sequenceLength: 4096) { token in
buffer += token
// HACK: Workaround until LlamaRunner exposes echo param
if (!echoDropped) {
if (buffer.hasPrefix(prompt)) {
buffer = String(buffer.dropFirst(prompt.count))
echoDropped = true
}
return
}
tokens.append(token)
if !ipython && (buffer.starts(with: "<|python_tag|>") || buffer.starts(with: "[") ) {
ipython = true
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .case1(""),
parse_status: Components.Schemas.ToolCallParseStatus.started
)
),
event_type: .progress
)
)
)
if (buffer.starts(with: "<|python_tag|>")) {
buffer = String(buffer.dropFirst("<|python_tag|>".count))
}
}
// TODO: Non-streaming lobprobs
var text = ""
if token == "<|eot_id|>" {
stopReason = Components.Schemas.StopReason.end_of_turn
} else if token == "<|eom_id|>" {
stopReason = Components.Schemas.StopReason.end_of_message
} else {
text = token
}
var delta: Components.Schemas.ChatCompletionResponseEvent.deltaPayload
if ipython {
delta = .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .case1(text),
parse_status: .in_progress
))
} else {
delta = .case1(text)
}
if stopReason == nil {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: delta,
event_type: .progress
)
)
)
}
}
if stopReason == nil {
stopReason = Components.Schemas.StopReason.out_of_tokens
}
let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: stopReason!)
// TODO: non-streaming support
let didParseToolCalls = message.tool_calls.count > 0
if ipython && !didParseToolCalls {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(content: .case1(""), parse_status: .failure)),
event_type: .progress
)
// TODO: stopReason
)
)
}
for toolCall in message.tool_calls {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .ToolCall(toolCall),
parse_status: .success
)),
event_type: .progress
)
// TODO: stopReason
)
)
}
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .case1(""),
event_type: .complete
)
// TODO: stopReason
)
)
}
catch (let error) {
print("Inference error: " + error.localizedDescription)
}
}
}
}
}

View file

@ -0,0 +1,235 @@
import Foundation
import LlamaStackClient
func encodeHeader(role: String) -> String {
return "<|start_header_id|>\(role)<|end_header_id|>\n\n"
}
func encodeDialogPrompt(messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload]) -> String {
var prompt = ""
prompt.append("<|begin_of_text|>")
for message in messages {
let msg = encodeMessage(message: message)
prompt += msg
}
prompt.append(encodeHeader(role: "assistant"))
return prompt
}
func getRole(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
switch (message) {
case .UserMessage(let m):
return m.role.rawValue
case .SystemMessage(let m):
return m.role.rawValue
case .ToolResponseMessage(let m):
return m.role.rawValue
case .CompletionMessage(let m):
return m.role.rawValue
}
}
func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
var prompt = encodeHeader(role: getRole(message: message))
switch (message) {
case .CompletionMessage(let m):
if (m.tool_calls.count > 0) {
prompt += "<|python_tag|>"
}
default:
break
}
func _processContent(_ content: Any) -> String {
func _process(_ c: Any) {
if let str = c as? String {
prompt += str
}
}
if let str = content as? String {
_process(str)
} else if let list = content as? [Any] {
for c in list {
_process(c)
}
}
return ""
}
switch (message) {
case .UserMessage(let m):
prompt += _processContent(m.content)
case .SystemMessage(let m):
prompt += _processContent(m.content)
case .ToolResponseMessage(let m):
prompt += _processContent(m.content)
case .CompletionMessage(let m):
prompt += _processContent(m.content)
}
var eom = false
switch (message) {
case .UserMessage(let m):
switch (m.content) {
case .case1(let c):
prompt += _processContent(c)
case .case2(let c):
prompt += _processContent(c)
}
case .CompletionMessage(let m):
// TODO: Support encoding past tool call history
// for t in m.tool_calls {
// _processContent(t.)
//}
eom = m.stop_reason == Components.Schemas.StopReason.end_of_message
case .SystemMessage(_):
break
case .ToolResponseMessage(_):
break
}
if (eom) {
prompt += "<|eom_id|>"
} else {
prompt += "<|eot_id|>"
}
return prompt
}
func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] {
var existingMessages = request.messages
var existingSystemMessage: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload?
// TODO: Existing system message
var messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] = []
let defaultGen = SystemDefaultGenerator()
let defaultTemplate = defaultGen.gen()
var sysContent = ""
// TODO: Built-in tools
sysContent += try defaultTemplate.render()
messages.append(.SystemMessage(Components.Schemas.SystemMessage(
content: .case1(sysContent),
role: .system))
)
if request.tools?.isEmpty == false {
// TODO: Separate built-ins and custom tools (right now everything treated as custom)
let toolGen = FunctionTagCustomToolGenerator()
let toolTemplate = try toolGen.gen(customTools: request.tools!)
let tools = try toolTemplate.render()
messages.append(.UserMessage(Components.Schemas.UserMessage(
content: .case1(tools),
role: .user)
))
}
messages.append(contentsOf: existingMessages)
return messages
}
struct FunctionCall {
let name: String
let params: [String: Any]
}
public func maybeExtractCustomToolCalls(input: String) -> [Components.Schemas.ToolCall] {
guard input.hasPrefix("[") && input.hasSuffix("]") else {
return []
}
do {
let trimmed = input.trimmingCharacters(in: CharacterSet(charactersIn: "[]"))
let calls = trimmed.components(separatedBy: "),").map { $0.hasSuffix(")") ? $0 : $0 + ")" }
var result: [Components.Schemas.ToolCall] = []
for call in calls {
guard let nameEndIndex = call.firstIndex(of: "("),
let paramsStartIndex = call.firstIndex(of: "{"),
let paramsEndIndex = call.lastIndex(of: "}") else {
return []
}
let name = String(call[..<nameEndIndex]).trimmingCharacters(in: .whitespacesAndNewlines)
let paramsString = String(call[paramsStartIndex...paramsEndIndex])
guard let data = paramsString.data(using: .utf8),
let params = try? JSONSerialization.jsonObject(with: data, options: []) as? [String: Any] else {
return []
}
var props: [String : Components.Schemas.ToolCall.argumentsPayload.additionalPropertiesPayload] = [:]
for (param_name, param) in params {
switch (param) {
case let value as String:
props[param_name] = .case1(value)
case let value as Int:
props[param_name] = .case2(value)
case let value as Double:
props[param_name] = .case3(value)
case let value as Bool:
props[param_name] = .case4(value)
default:
return []
}
}
result.append(
Components.Schemas.ToolCall(
arguments: .init(additionalProperties: props),
call_id: UUID().uuidString,
tool_name: .case2(name) // custom_tool
)
)
}
return result.isEmpty ? [] : result
} catch {
return []
}
}
func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.StopReason) -> Components.Schemas.CompletionMessage {
var content = tokens
let roles = ["user", "system", "assistant"]
for role in roles {
let headerStr = encodeHeader(role: role)
if content.hasPrefix(headerStr) {
content = String(content.dropFirst(encodeHeader(role: role).count))
}
}
if content.hasPrefix("<|python_tag|>") {
content = String(content.dropFirst("<|python_tag|>".count))
}
if content.hasSuffix("<|eot_id|>") {
content = String(content.dropLast("<|eot_id|>".count))
} else {
content = String(content.dropLast("<|eom_id|>".count))
}
return Components.Schemas.CompletionMessage(
content: .case1(content),
role: .assistant,
stop_reason: stopReason,
tool_calls: maybeExtractCustomToolCalls(input: content)
)
}

View file

@ -0,0 +1,12 @@
import Foundation
import Stencil
public struct PromptTemplate {
let template: String
let data: [String: Any]
public func render() throws -> String {
let template = Template(templateString: self.template)
return try template.render(self.data)
}
}

View file

@ -0,0 +1,91 @@
import Foundation
import LlamaStackClient
func convertToNativeSwiftType(_ value: Any) -> Any {
switch value {
case let number as NSNumber:
if CFGetTypeID(number) == CFBooleanGetTypeID() {
return number.boolValue
}
if floor(number.doubleValue) == number.doubleValue {
return number.intValue
}
return number.doubleValue
case let string as String:
return string
case let array as [Any]:
return array.map(convertToNativeSwiftType)
case let dict as [String: Any]:
return dict.mapValues(convertToNativeSwiftType)
case is NSNull:
return NSNull()
default:
return value
}
}
public class SystemDefaultGenerator {
public init() {}
public func gen() -> PromptTemplate {
let templateStr = """
Cutting Knowledge Date: December 2023
Today Date: {{ today }}
"""
let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "dd MMMM yyyy"
return PromptTemplate(
template: templateStr,
data: ["today": dateFormatter.string(from: Date())]
)
}
}
public class FunctionTagCustomToolGenerator {
public init() {}
public func gen(customTools: [Components.Schemas.ToolDefinition]) throws -> PromptTemplate {
// TODO: required params
// TODO: {{#unless @last}},{{/unless}}
let templateStr = """
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{% for t in custom_tools %}
{
"name": "{{t.tool_name}}",
"description": "{{t.description}}",
"parameters": {
"type": "dict",
"properties": { {{t.parameters}} }
}
{{/let}}
{% endfor -%}
]
"""
let encoder = JSONEncoder()
return PromptTemplate(
template: templateStr,
data: ["custom_tools": try customTools.map {
let data = try encoder.encode($0)
let obj = try JSONSerialization.jsonObject(with: data)
return convertToNativeSwiftType(obj)
}]
)
}
}

View file

@ -0,0 +1,109 @@
# LocalInference
LocalInference provides a local inference implementation powered by [executorch](https://github.com/pytorch/executorch/).
Llama Stack currently supports on-device inference for iOS with Android coming soon. You can run on-device inference on Android today using [executorch](https://github.com/pytorch/executorch/tree/main/examples/demo-apps/android/LlamaDemo), PyTorchs on-device inference library.
## Installation
We're working on making LocalInference easier to set up. For now, you'll need to import it via `.xcframework`:
1. Clone the executorch submodule in this repo and its dependencies: `git submodule update --init --recursive`
1. Install [Cmake](https://cmake.org/) for the executorch build`
1. Drag `LocalInference.xcodeproj` into your project
1. Add `LocalInference` as a framework in your app target
1. Add a package dependency on https://github.com/pytorch/executorch (branch latest)
1. Add all the kernels / backends from executorch (but not exectuorch itself!) as frameworks in your app target:
- backend_coreml
- backend_mps
- backend_xnnpack
- kernels_custom
- kernels_optimized
- kernels_portable
- kernels_quantized
1. In "Build Settings" > "Other Linker Flags" > "Any iOS Simulator SDK", add:
```
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_optimized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_custom-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_quantized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_xnnpack-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_coreml-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_mps-simulator-release.a
```
1. In "Build Settings" > "Other Linker Flags" > "Any iOS SDK", add:
```
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_optimized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_custom-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_quantized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_xnnpack-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_coreml-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_mps-simulator-release.a
```
## Preparing a model
1. Prepare a `.pte` file [following the executorch docs](https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md#step-2-prepare-model)
2. Bundle the `.pte` and `tokenizer.model` file into your app
## Using LocalInference
1. Instantiate LocalInference with a DispatchQueue. Optionally, pass it into your agents service:
```swift
init () {
runnerQueue = DispatchQueue(label: "org.meta.llamastack")
inferenceService = LocalInferenceService(queue: runnerQueue)
agentsService = LocalAgentsService(inference: inferenceService)
}
```
2. Before making any inference calls, load your model from your bundle:
```swift
let mainBundle = Bundle.main
inferenceService.loadModel(
modelPath: mainBundle.url(forResource: "llama32_1b_spinquant", withExtension: "pte"),
tokenizerPath: mainBundle.url(forResource: "tokenizer", withExtension: "model"),
completion: {_ in } // use to handle load failures
)
```
3. Make inference calls (or agents calls) as you normally would with LlamaStack:
```
for await chunk in try await agentsService.initAndCreateTurn(
messages: [
.UserMessage(Components.Schemas.UserMessage(
content: .case1("Call functions as needed to handle any actions in the following text:\n\n" + text),
role: .user))
]
) {
```
## Troubleshooting
If you receive errors like "missing package product" or "invalid checksum", try cleaning the build folder and resetting the Swift package cache:
(Opt+Click) Product > Clean Build Folder Immediately
```
rm -rf \
~/Library/org.swift.swiftpm \
~/Library/Caches/org.swift.swiftpm \
~/Library/Caches/com.apple.dt.Xcode \
~/Library/Developer/Xcode/DerivedData
```

View file

@ -398,7 +398,11 @@ class ChatAgent(ShieldRunnerMixin):
color = "yellow"
else:
color = None
cprint(f"{str(msg)}", color=color)
if len(str(msg)) > 1000:
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
else:
msg_str = str(msg)
cprint(f"{msg_str}", color=color)
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -466,6 +470,13 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason = event.stop_reason
stop_reason = stop_reason or StopReason.out_of_tokens
# If tool calls are parsed successfully,
# if content is not made null the tool call str will also be in the content
# and tokens will have tool call syntax included twice
if tool_calls:
content = ""
message = CompletionMessage(
content=content,
stop_reason=stop_reason,

View file

@ -10,13 +10,14 @@ from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from termcolor import cprint # noqa: F401
from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
from termcolor import cprint # noqa: F401
from llama_stack.apis.inference import * # noqa: F403

View file

@ -16,7 +16,7 @@ from pydantic import BaseModel, Field, field_validator
class MetaReferenceImplConfig(BaseModel):
model: str = Field(
default="Meta-Llama3.1-8B-Instruct",
default="Llama3.1-8B-Instruct",
description="Model descriptor from `llama model list`",
)
quantization: Optional[QuantizationConfig] = None
@ -30,7 +30,7 @@ class MetaReferenceImplConfig(BaseModel):
permitted_models = [
m.descriptor()
for m in all_registered_models()
if m.model_family == ModelFamily.llama3_1
if m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
or m.core_model_id == CoreModelId.llama_guard_3_8b
]
if model not in permitted_models:
@ -42,14 +42,9 @@ class MetaReferenceImplConfig(BaseModel):
@property
def model_parallel_size(self) -> int:
# HUGE HACK ALERT: this will be fixed when we move inference configuration
# HACK ALERT: this will be fixed when we move inference configuration
# to ModelsRegistry and we can explicitly ask for `model_parallel_size`
# as configuration there
gpu_count = 1
resolved = resolve_model(self.model)
assert resolved is not None
descriptor = resolved.descriptor().lower()
if "-70b" in descriptor or "-405b" in descriptor:
gpu_count = 8
return gpu_count
return resolved.pth_file_count

View file

@ -24,21 +24,31 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolPromptFormat,
)
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from termcolor import cprint
from llama_stack.apis.inference import QuantizationType
from llama_stack.distribution.utils.model_utils import model_local_dir
from termcolor import cprint
from .config import MetaReferenceImplConfig
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
if not Path(checkpoint_dir / "consolidated.00.pth").exists():
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), (
@ -134,6 +144,10 @@ class Llama:
# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model = convert_to_quantized_model(model, config)
@ -142,6 +156,10 @@ class Llama:
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
@ -167,7 +185,11 @@ class Llama:
) -> Generator:
params = self.model.params
# cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red")
# input_tokens = [
# self.formatter.vision_token if t == 128256 else t
# for t in model_input.tokens
# ]
# cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
prompt_tokens = [model_input.tokens]
bsz = 1
@ -183,6 +205,21 @@ class Llama:
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
is_vision = isinstance(self.model, CrossAttentionTransformer)
if is_vision:
images = model_input.vision.images if model_input.vision is not None else []
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = (
self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
@ -206,6 +243,18 @@ class Llama:
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(
prev_pos, cur_pos, dtype=torch.long, device="cuda"
)
logits = self.model.forward(
position_ids,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
)
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
@ -222,6 +271,18 @@ class Llama:
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if is_vision:
# the logits space (num_classes) is designed to never contain a media_token
# however our input token stream does contain them. we need to nuke them here
# or else the CUDA kernels will crash with an illegal memory access
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
masks = [target.eq(t) for t in vision_tokens]
if len(masks) > 1:
mask = torch.logical_or(*masks)
else:
mask = masks[0]
target[mask] = 0
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
@ -248,7 +309,7 @@ class Llama:
def text_completion(
self,
prompt: str,
content: InterleavedTextMedia,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
@ -262,10 +323,10 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
model_input = self.formatter.encode_content(content)
yield from self.generate(
model_input=ModelInput(tokens=prompt_tokens),
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,

View file

@ -21,7 +21,9 @@ from llama_stack.apis.inference import (
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import MetaReferenceImplConfig
from .model_parallel import LlamaModelParallelGenerator
@ -57,7 +59,7 @@ class MetaReferenceInferenceImpl(Inference):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = [],
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
@ -70,14 +72,14 @@ class MetaReferenceInferenceImpl(Inference):
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
model = resolve_model(request.model)
if model is None:
raise RuntimeError(

View file

@ -7,11 +7,11 @@
from .config import SafetyConfig
async def get_provider_impl(config: SafetyConfig, _deps):
async def get_provider_impl(config: SafetyConfig, deps):
from .safety import MetaReferenceSafetyImpl
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config)
impl = MetaReferenceSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -31,7 +31,10 @@ class LlamaGuardShieldConfig(BaseModel):
permitted_models = [
m.descriptor()
for m in safety_models()
if m.core_model_id == CoreModelId.llama_guard_3_8b
if (
m.core_model_id
in {CoreModelId.llama_guard_3_8b, CoreModelId.llama_guard_3_11b_vision}
)
]
if model not in permitted_models:
raise ValueError(

View file

@ -7,8 +7,10 @@
from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
@ -34,20 +36,11 @@ def resolve_and_get_path(model_name: str) -> str:
class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None:
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
async def initialize(self) -> None:
shield_cfg = self.config.llama_guard_shield
if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model)
_ = LlamaGuardShield.instance(
model_dir=model_dir,
excluded_categories=shield_cfg.excluded_categories,
disable_input_check=shield_cfg.disable_input_check,
disable_output_check=shield_cfg.disable_output_check,
)
shield_cfg = self.config.prompt_guard_shield
if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model)
@ -91,11 +84,18 @@ class MetaReferenceSafetyImpl(Safety):
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
cfg = self.config
if typ == MetaReferenceShieldType.llama_guard:
cfg = cfg.llama_guard_shield
assert (
cfg.llama_guard_shield is not None
cfg is not None
), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir)
return LlamaGuardShield(
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
disable_input_check=cfg.disable_input_check,
disable_output_check=cfg.disable_output_check,
)
elif typ == MetaReferenceShieldType.jailbreak_shield:
assert (
cfg.prompt_guard_shield is not None

View file

@ -9,9 +9,8 @@ import re
from string import Template
from typing import List, Optional
import torch
from llama_models.llama3.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer
from llama_stack.apis.inference import * # noqa: F403
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@ -100,39 +99,17 @@ PROMPT_TEMPLATE = Template(
class LlamaGuardShield(ShieldBase):
@staticmethod
def instance(
on_violation_action=OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
) -> "LlamaGuardShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = LlamaGuardShield(
on_violation_action,
model_dir,
excluded_categories,
disable_input_check,
disable_output_check,
)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
model_dir: str = None,
model: str,
inference_api: Inference,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
dtype = torch.bfloat16
assert model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None:
excluded_categories = []
@ -140,18 +117,12 @@ class LlamaGuardShield(ShieldBase):
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
self.device = "cuda"
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
# load model
torch_dtype = torch.bfloat16
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
@ -212,26 +183,21 @@ class LlamaGuardShield(ShieldBase):
)
else:
prompt = self.build_prompt(messages)
llama_guard_input = {
"role": "user",
"content": prompt,
}
input_ids = self.tokenizer.apply_chat_template(
[llama_guard_input], return_tensors="pt", tokenize=True
).to(self.device)
prompt_len = input_ids.shape[1]
output = self.model.generate(
input_ids=input_ids,
max_new_tokens=20,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=0,
)
generated_tokens = output.sequences[:, prompt_len:]
response = self.tokenizer.decode(
generated_tokens[0], skip_special_tokens=True
)
response = response.strip()
shield_response = self.get_shield_response(response)
# TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
async for chunk in self.inference_api.chat_completion(
model=self.model,
messages=[
UserMessage(content=prompt),
],
stream=True,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.progress:
assert isinstance(event.delta, str)
content += event.delta
content = content.strip()
shield_response = self.get_shield_response(content)
return shield_response

View file

@ -20,6 +20,7 @@ def available_providers() -> List[ProviderSpec]:
"fairscale",
"fbgemm-gpu==0.8.0",
"torch",
"torchvision",
"transformers",
"zmq",
],
@ -75,15 +76,4 @@ def available_providers() -> List[ProviderSpec]:
header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="bedrock",
pip_packages=[
"boto3",
],
module="llama_stack.providers.adapters.inference.bedrock",
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
),
),
]

View file

@ -21,13 +21,15 @@ def available_providers() -> List[ProviderSpec]:
api=Api.safety,
provider_id="meta-reference",
pip_packages=[
"accelerate",
"codeshield",
"torch",
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
module="llama_stack.providers.impls.meta_reference.safety",
config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
api_dependencies=[
Api.inference,
],
),
remote_provider_spec(
api=Api.safety,

View file

@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_models.datatypes import ModelFamily
from llama_models.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
PythonListCustomToolGenerator,
SystemDefaultGenerator,
)
from llama_models.sku_list import resolve_model
def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
"""Reads chat completion request and augments the messages to handle tools.
For eg. for llama_3_1, add system message with the appropriate tools or
add user messsage for custom tools, etc.
"""
model = resolve_model(request.model)
if model is None:
cprint(f"Could not resolve model {request.model}", color="red")
return request.messages
if model.model_family not in [ModelFamily.llama3_1, ModelFamily.llama3_2]:
cprint(f"Model family {model.model_family} not llama 3_1 or 3_2", color="red")
return request.messages
if model.model_family == ModelFamily.llama3_1 or (
model.model_family == ModelFamily.llama3_2 and is_multimodal(model)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
return augment_messages_for_tools_llama_3_1(request)
elif model.model_family == ModelFamily.llama3_2:
return augment_messages_for_tools_llama_3_2(request)
else:
return request.messages
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if request.tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(request.tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join(
[_process(c) for c in existing_system_message.content]
)
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
if request.tool_prompt_format == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif request.tool_prompt_format == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
# Add back existing messages from the request
messages += existing_messages
return messages
def augment_messages_for_tools_llama_3_2(
request: ChatCompletionRequest,
) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
messages = []
sys_content = ""
custom_tools, builtin_tools = [], []
for t in request.tools:
if isinstance(t.tool_name, str):
custom_tools.append(t)
else:
builtin_tools.append(t)
tool_template = None
if builtin_tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(builtin_tools)
sys_content += tool_template.render()
sys_content += "\n"
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
if request.tool_prompt_format != ToolPromptFormat.python_list:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
tool_gen = PythonListCustomToolGenerator()
tool_template = tool_gen.gen(custom_tools)
sys_content += tool_template.render()
sys_content += "\n"
if existing_system_message:
sys_content += interleaved_text_media_as_str(
existing_system_message.content, sep="\n"
)
messages.append(SystemMessage(content=sys_content))
# Add back existing messages from the request
messages += existing_messages
return messages

View file

@ -1,84 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_models.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
SystemDefaultGenerator,
)
def prepare_messages(request: ChatCompletionRequest) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if request.tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(request.tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join(
[_process(c) for c in existing_system_message.content]
)
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
if request.tool_prompt_format == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif request.tool_prompt_format == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
# Add back existing messages from the request
messages += existing_messages
return messages

View file

@ -7,4 +7,5 @@ prompt-toolkit
python-dotenv
pydantic
requests
rich
termcolor

View file

@ -8,9 +8,9 @@ import unittest
from llama_models.llama3.api import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.prepare_messages import prepare_messages
from llama_stack.inference.augment_messages import augment_messages_for_tools
MODEL = "Meta-Llama3.1-8B-Instruct"
MODEL = "Llama3.1-8B-Instruct"
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
@ -22,7 +22,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
UserMessage(content=content),
],
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@ -39,7 +39,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
ToolDefinition(tool_name=BuiltinTool.brave_search),
],
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@ -67,7 +67,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
tool_prompt_format=ToolPromptFormat.json,
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content)
@ -97,7 +97,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
),
],
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content)
@ -119,7 +119,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))

View file

@ -59,7 +59,7 @@ class TestE2E(unittest.IsolatedAsyncioTestCase):
host=TestE2E.HOST,
port=TestE2E.PORT,
custom_tools=custom_tools,
# model="Meta-Llama3.1-70B-Instruct", # Defaults to 8B
# model="Llama3.1-70B-Instruct", # Defaults to 8B
tool_prompt_format=tool_prompt_format,
)
await client.create_session(__file__)

View file

@ -9,31 +9,15 @@
import asyncio
import os
import textwrap
import unittest
from datetime import datetime
from llama_models.llama3.api.datatypes import (
BuiltinTool,
StopReason,
SystemMessage,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
ToolResponseMessage,
UserMessage,
)
from llama_stack.inference.api import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
)
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.meta_reference.config import MetaReferenceImplConfig
from llama_stack.inference.meta_reference.inference import get_provider_impl
MODEL = "Meta-Llama3.1-8B-Instruct"
MODEL = "Llama3.1-8B-Instruct"
HELPER_MSG = """
This test needs llama-3.1-8b-instruct models.
Please donwload using the llama cli
@ -45,11 +29,10 @@ llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token <
class InferenceTests(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
# This runs the async setup function
asyncio.run(cls.asyncSetUpClass())
@classmethod
async def asyncSetUpClass(cls):
async def asyncSetUpClass(cls): # noqa
# assert model exists on local
model_dir = os.path.expanduser(f"~/.llama/checkpoints/{MODEL}/original/")
assert os.path.isdir(model_dir), HELPER_MSG
@ -67,11 +50,10 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
@classmethod
def tearDownClass(cls):
# This runs the async teardown function
asyncio.run(cls.asyncTearDownClass())
@classmethod
async def asyncTearDownClass(cls):
async def asyncTearDownClass(cls): # noqa
await cls.api.shutdown()
async def asyncSetUp(self):

View file

@ -4,26 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import textwrap
import unittest
from datetime import datetime
from llama_models.llama3.api.datatypes import (
BuiltinTool,
SamplingParams,
SamplingStrategy,
StopReason,
SystemMessage,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
ToolResponseMessage,
UserMessage,
)
from llama_stack.inference.api import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
)
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.ollama.config import OllamaImplConfig
from llama_stack.inference.ollama.ollama import get_provider_impl
@ -52,7 +36,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
),
},
)
self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
self.valid_supported_model = "Llama3.1-8B-Instruct"
async def asyncTearDown(self):
await self.api.shutdown()
@ -272,7 +256,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
ollama_model = self.api.resolve_ollama_model(self.valid_supported_model)
self.assertEqual(ollama_model, "llama3.1:8b-instruct-fp16")
invalid_model = "Meta-Llama3.1-8B"
invalid_model = "Llama3.1-8B"
with self.assertRaisesRegex(
AssertionError, f"Unsupported model: {invalid_model}"
):