Commit 6abef716dd: rebase on top of registry

107 changed files with 4813 additions and 3587 deletions
.gitignore (vendored): 1 line changed
@@ -13,3 +13,4 @@ xcuserdata/
 Package.resolved
 *.pte
 *.ipynb_checkpoints*
+.idea
CONTRIBUTING.md

@@ -1,4 +1,4 @@
-# Contributing to Llama-Models
+# Contributing to Llama-Stack
 We want to make contributing to this project as easy and transparent as
 possible.

@@ -32,7 +32,7 @@ outlined on that page and do not file a public issue.
 * ...

 ## Tips
-* If you are developing with a llama-models repository checked out and need your distribution to reflect changes from there, set `LLAMA_MODELS_DIR` to that dir when running any of the `llama` CLI commands.
+* If you are developing with a llama-stack repository checked out and need your distribution to reflect changes from there, set `LLAMA_STACK_DIR` to that dir when running any of the `llama` CLI commands.

 ## License
 By contributing to Llama, you agree that your contributions will be licensed
README.md: 17 lines changed

@@ -81,11 +81,24 @@ cd llama-stack
 $CONDA_PREFIX/bin/pip install -e .
 ```

-## The Llama CLI
+## Documentation

-The `llama` CLI makes it easy to work with the Llama Stack set of tools, including installing and running Distributions, downloading models, studying model prompt formats, etc. Please see the [CLI reference](docs/cli_reference.md) for details. Please see the [Getting Started](docs/getting_started.md) guide for running a Llama Stack server.
+The `llama` CLI makes it easy to work with the Llama Stack set of tools. See the following docs for details.
+
+* [CLI reference](docs/cli_reference.md)
+    * Guide to using the `llama` CLI to work with Llama models (download, study prompts) and to build/start a Llama Stack distribution.
+* [Getting Started](docs/getting_started.md)
+    * Guide to building and running a Llama Stack server.
+* [Contributing](CONTRIBUTING.md)

 ## Llama Stack Client SDK

+| **Language** | **Client SDK** | **Package** |
+| :----: | :----: | :----: |
+| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [PyPI](https://pypi.org/project/llama_stack_client/) |
+| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | |
+| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [npm](https://npmjs.org/package/llama-stack-client) |
+| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | |

 Check out our client SDKs for connecting to a Llama Stack server in your preferred language; you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [node](https://github.com/meta-llama/llama-stack-client-node), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) to quickly build your applications.
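For readers who want to try one of the client SDKs listed above, a minimal Python sketch follows. It assumes the `llama_stack_client` package exposes a `LlamaStackClient` class with an `inference.chat_completion` method; those names and arguments are assumptions based only on the package listed in the table, so check the linked repository for the exact API.

```python
# Hypothetical usage sketch for the Python client SDK; method and argument
# names are assumptions -- consult llama-stack-client-python for the real API.
from llama_stack_client import LlamaStackClient

# Points at a locally running Llama Stack server (see the server setup docs).
client = LlamaStackClient(base_url="http://localhost:5000")

response = client.inference.chat_completion(
    model="Llama3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Write a two sentence poem about the moon."}],
)
print(response)
```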
SECURITY.md (new file): 5 lines added

@@ -0,0 +1,5 @@
+# Security Policy
+
+## Reporting a Vulnerability
+
+Please report vulnerabilities to our bug bounty program at https://bugbounty.meta.com/
docs/cli_reference.md

@@ -1,6 +1,6 @@
 # Llama CLI Reference

-The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
+The `llama` CLI tool helps you set up and use the Llama Stack & agentic systems. It should be available on your path after installing the `llama-stack` package.

 ### Subcommands
 1. `download`: the `llama` CLI supports downloading models from Meta or Hugging Face.

@@ -117,9 +117,9 @@ llama download --source meta --model-id Llama-Guard-3-1B --meta-url META_URL
 Essentially, the same commands above work; just replace `--source meta` with `--source huggingface`.

 ```bash
-llama download --source huggingface --model-id Meta-Llama3.1-8B-Instruct --hf-token <HF_TOKEN>
+llama download --source huggingface --model-id Llama3.1-8B-Instruct --hf-token <HF_TOKEN>

-llama download --source huggingface --model-id Meta-Llama3.1-70B-Instruct --hf-token <HF_TOKEN>
+llama download --source huggingface --model-id Llama3.1-70B-Instruct --hf-token <HF_TOKEN>

 llama download --source huggingface --model-id Llama-Guard-3-1B --ignore-patterns *original*
 llama download --source huggingface --model-id Prompt-Guard-86M --ignore-patterns *original*
@@ -215,9 +215,8 @@ You can even run `llama model prompt-format` to see all of the templates and their formats
 ```
 llama model prompt-format -m Llama3.2-3B-Instruct
 ```
-<p align="center">
-<img width="719" alt="image" src="https://github.com/user-attachments/assets/c5332026-8c0b-4edc-b438-ec60cd7ca554">
-</p>
+[image: example `llama model prompt-format` output]

 You will be shown a Markdown-formatted description of the model interface and how prompts / messages are formatted for various scenarios.

@@ -230,7 +229,7 @@ You will be shown a Markdown formatted description of the model interface and how
 - Please see our [Getting Started](getting_started.md) guide for more details on how to build and start a Llama Stack distribution.

 ### Step 3.1 Build
-In the following steps, imagine we'll be working with a `Meta-Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify:
+In the following steps, imagine we'll be working with a `Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will now build our distribution (in the form of a Conda environment or Docker image). In this step, we will specify:
 - `name`: the name for our distribution (e.g. `8b-instruct`)
 - `image_type`: our build image type (`conda | docker`)
 - `distribution_spec`: our distribution specs for specifying API providers

@@ -365,7 +364,7 @@ llama stack configure [ <name> | <docker-image-name> | <path/to/name.build.yaml>
 $ llama stack configure ~/.llama/distributions/conda/8b-instruct-build.yaml

 Configuring API: inference (meta-reference)
-Enter value for model (existing: Meta-Llama3.1-8B-Instruct) (required):
+Enter value for model (existing: Llama3.1-8B-Instruct) (required):
 Enter value for quantization (optional):
 Enter value for torch_seed (optional):
 Enter value for max_seq_len (existing: 4096) (required):

@@ -397,7 +396,7 @@ YAML configuration has been written to ~/.llama/builds/conda/8b-instruct-run.yaml
 After this step is successful, you should be able to find a run configuration spec in `~/.llama/builds/conda/8b-instruct-run.yaml` with the following contents. You may edit this file to change the settings.

 As you can see, we did basic configuration above and configured:
-- inference to run on model `Meta-Llama3.1-8B-Instruct` (obtained from `llama model list`)
+- inference to run on model `Llama3.1-8B-Instruct` (obtained from `llama model list`)
 - Llama Guard safety shield with model `Llama-Guard-3-1B`
 - Prompt Guard safety shield with model `Prompt-Guard-86M`
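To sanity-check the run configuration produced above, you can load the generated YAML and print its top-level structure. This is a small illustrative sketch only: it assumes PyYAML is installed and that the file sits at the default path from the walkthrough; it makes no assumption about the config's internal keys, since they are not shown in this diff.

```python
# Quick sanity check of the generated run config (path from the walkthrough above).
from pathlib import Path

import yaml  # PyYAML

config_path = Path.home() / ".llama" / "builds" / "conda" / "8b-instruct-run.yaml"

with config_path.open() as f:
    config = yaml.safe_load(f)

# Print each top-level key and the type of its value, without assuming a schema.
for key, value in config.items():
    print(f"{key}: {type(value).__name__}")
```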
@@ -1,7 +1,7 @@
 # llama-stack

 [PyPI](https://pypi.org/project/llama-stack/)
-[Discord](https://discord.gg/TZAAYNVtrU)
+[Discord](https://discord.gg/llama-stack)

 This repository contains the specifications and implementations of the APIs which are part of the Llama Stack.

@@ -66,8 +66,17 @@ This guide allows you to quickly get started with building and running a Llama
 You may also check out this [notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb) for trying out the demo scripts.

 ## Quick Cheatsheet
-- Quick 3 line command to build and start a LlamaStack server using our Meta Reference implementation for all API endpoints with `conda` as build type.
+#### Via docker
+```
+docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack-local-gpu
+```
+
+> [!NOTE]
+> `~/.llama` should be the path containing downloaded weights of Llama models.
+
+#### Via conda
 **`llama stack build`**
 - You'll be prompted to enter build information interactively.
 ```
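Once a server is running (for example via the docker command above, which publishes port 5000), you can exercise it over plain HTTP. The sketch below posts to the `/inference/chat_completion` route used by the Python client code later in this commit; the payload fields mirror that client code (`model`, `messages`, `stream`), but treat the exact field names as assumptions and prefer the official client SDKs for real use.

```python
# Minimal HTTP sketch against a locally running Llama Stack server.
# Endpoint and payload shape are inferred from the client code in this commit;
# the field names are assumptions, not a documented wire contract.
import httpx

BASE_URL = "http://localhost:5000"

payload = {
    "model": "Llama3.1-8B-Instruct",
    "messages": [{"role": "user", "content": "Write a two sentence poem about the moon."}],
    "stream": False,
}

response = httpx.post(
    f"{BASE_URL}/inference/chat_completion",
    json=payload,
    headers={"Content-Type": "application/json"},
    timeout=60,
)
response.raise_for_status()
print(response.json())
```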
(One file's diff is suppressed because it is too large.)
|
@ -580,63 +580,6 @@ components:
|
||||||
- uuid
|
- uuid
|
||||||
- dataset
|
- dataset
|
||||||
type: object
|
type: object
|
||||||
CreateMemoryBankRequest:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
config:
|
|
||||||
oneOf:
|
|
||||||
- additionalProperties: false
|
|
||||||
properties:
|
|
||||||
chunk_size_in_tokens:
|
|
||||||
type: integer
|
|
||||||
embedding_model:
|
|
||||||
type: string
|
|
||||||
overlap_size_in_tokens:
|
|
||||||
type: integer
|
|
||||||
type:
|
|
||||||
const: vector
|
|
||||||
default: vector
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
- embedding_model
|
|
||||||
- chunk_size_in_tokens
|
|
||||||
type: object
|
|
||||||
- additionalProperties: false
|
|
||||||
properties:
|
|
||||||
type:
|
|
||||||
const: keyvalue
|
|
||||||
default: keyvalue
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
type: object
|
|
||||||
- additionalProperties: false
|
|
||||||
properties:
|
|
||||||
type:
|
|
||||||
const: keyword
|
|
||||||
default: keyword
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
type: object
|
|
||||||
- additionalProperties: false
|
|
||||||
properties:
|
|
||||||
type:
|
|
||||||
const: graph
|
|
||||||
default: graph
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
type: object
|
|
||||||
name:
|
|
||||||
type: string
|
|
||||||
url:
|
|
||||||
$ref: '#/components/schemas/URL'
|
|
||||||
required:
|
|
||||||
- name
|
|
||||||
- config
|
|
||||||
type: object
|
|
||||||
DPOAlignmentConfig:
|
DPOAlignmentConfig:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -739,14 +682,6 @@ components:
|
||||||
- rank
|
- rank
|
||||||
- alpha
|
- alpha
|
||||||
type: object
|
type: object
|
||||||
DropMemoryBankRequest:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
bank_id:
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- bank_id
|
|
||||||
type: object
|
|
||||||
EmbeddingsRequest:
|
EmbeddingsRequest:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -908,6 +843,21 @@ components:
|
||||||
required:
|
required:
|
||||||
- document_ids
|
- document_ids
|
||||||
type: object
|
type: object
|
||||||
|
GraphMemoryBankDef:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
identifier:
|
||||||
|
type: string
|
||||||
|
provider_id:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
const: graph
|
||||||
|
default: graph
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- identifier
|
||||||
|
- type
|
||||||
|
type: object
|
||||||
HealthInfo:
|
HealthInfo:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -973,6 +923,36 @@ components:
|
||||||
- bank_id
|
- bank_id
|
||||||
- documents
|
- documents
|
||||||
type: object
|
type: object
|
||||||
|
KeyValueMemoryBankDef:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
identifier:
|
||||||
|
type: string
|
||||||
|
provider_id:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
const: keyvalue
|
||||||
|
default: keyvalue
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- identifier
|
||||||
|
- type
|
||||||
|
type: object
|
||||||
|
KeywordMemoryBankDef:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
identifier:
|
||||||
|
type: string
|
||||||
|
provider_id:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
const: keyword
|
||||||
|
default: keyword
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- identifier
|
||||||
|
- type
|
||||||
|
type: object
|
||||||
LogEventRequest:
|
LogEventRequest:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -1015,66 +995,6 @@ components:
|
||||||
- rank
|
- rank
|
||||||
- alpha
|
- alpha
|
||||||
type: object
|
type: object
|
||||||
MemoryBank:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
bank_id:
|
|
||||||
type: string
|
|
||||||
config:
|
|
||||||
oneOf:
|
|
||||||
- additionalProperties: false
|
|
||||||
properties:
|
|
||||||
chunk_size_in_tokens:
|
|
||||||
type: integer
|
|
||||||
embedding_model:
|
|
||||||
type: string
|
|
||||||
overlap_size_in_tokens:
|
|
||||||
type: integer
|
|
||||||
type:
|
|
||||||
const: vector
|
|
||||||
default: vector
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
- embedding_model
|
|
||||||
- chunk_size_in_tokens
|
|
||||||
type: object
|
|
||||||
- additionalProperties: false
|
|
||||||
properties:
|
|
||||||
type:
|
|
||||||
const: keyvalue
|
|
||||||
default: keyvalue
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
type: object
|
|
||||||
- additionalProperties: false
|
|
||||||
properties:
|
|
||||||
type:
|
|
||||||
const: keyword
|
|
||||||
default: keyword
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
type: object
|
|
||||||
- additionalProperties: false
|
|
||||||
properties:
|
|
||||||
type:
|
|
||||||
const: graph
|
|
||||||
default: graph
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
type: object
|
|
||||||
name:
|
|
||||||
type: string
|
|
||||||
url:
|
|
||||||
$ref: '#/components/schemas/URL'
|
|
||||||
required:
|
|
||||||
- bank_id
|
|
||||||
- name
|
|
||||||
- config
|
|
||||||
type: object
|
|
||||||
MemoryBankDocument:
|
MemoryBankDocument:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -1107,41 +1027,6 @@ components:
|
||||||
- content
|
- content
|
||||||
- metadata
|
- metadata
|
||||||
type: object
|
type: object
|
||||||
MemoryBankSpec:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
bank_type:
|
|
||||||
$ref: '#/components/schemas/MemoryBankType'
|
|
||||||
provider_config:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
config:
|
|
||||||
additionalProperties:
|
|
||||||
oneOf:
|
|
||||||
- type: 'null'
|
|
||||||
- type: boolean
|
|
||||||
- type: number
|
|
||||||
- type: string
|
|
||||||
- type: array
|
|
||||||
- type: object
|
|
||||||
type: object
|
|
||||||
provider_type:
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- provider_type
|
|
||||||
- config
|
|
||||||
type: object
|
|
||||||
required:
|
|
||||||
- bank_type
|
|
||||||
- provider_config
|
|
||||||
type: object
|
|
||||||
MemoryBankType:
|
|
||||||
enum:
|
|
||||||
- vector
|
|
||||||
- keyvalue
|
|
||||||
- keyword
|
|
||||||
- graph
|
|
||||||
type: string
|
|
||||||
MemoryRetrievalStep:
|
MemoryRetrievalStep:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -1349,36 +1234,18 @@ components:
|
||||||
- value
|
- value
|
||||||
- unit
|
- unit
|
||||||
type: object
|
type: object
|
||||||
Model:
|
ModelDef:
|
||||||
description: The model family and SKU of the model along with other parameters
|
|
||||||
corresponding to the model.
|
|
||||||
ModelServingSpec:
|
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
identifier:
|
||||||
|
type: string
|
||||||
llama_model:
|
llama_model:
|
||||||
$ref: '#/components/schemas/Model'
|
type: string
|
||||||
provider_config:
|
provider_id:
|
||||||
additionalProperties: false
|
type: string
|
||||||
properties:
|
|
||||||
config:
|
|
||||||
additionalProperties:
|
|
||||||
oneOf:
|
|
||||||
- type: 'null'
|
|
||||||
- type: boolean
|
|
||||||
- type: number
|
|
||||||
- type: string
|
|
||||||
- type: array
|
|
||||||
- type: object
|
|
||||||
type: object
|
|
||||||
provider_type:
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- provider_type
|
|
||||||
- config
|
|
||||||
type: object
|
|
||||||
required:
|
required:
|
||||||
|
- identifier
|
||||||
- llama_model
|
- llama_model
|
||||||
- provider_config
|
|
||||||
type: object
|
type: object
|
||||||
OptimizerConfig:
|
OptimizerConfig:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
|
@ -1554,13 +1421,13 @@ components:
|
||||||
ProviderInfo:
|
ProviderInfo:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
description:
|
provider_id:
|
||||||
type: string
|
type: string
|
||||||
provider_type:
|
provider_type:
|
||||||
type: string
|
type: string
|
||||||
required:
|
required:
|
||||||
|
- provider_id
|
||||||
- provider_type
|
- provider_type
|
||||||
- description
|
|
||||||
type: object
|
type: object
|
||||||
QLoraFinetuningConfig:
|
QLoraFinetuningConfig:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
|
@ -1650,6 +1517,56 @@ components:
|
||||||
enum:
|
enum:
|
||||||
- dpo
|
- dpo
|
||||||
type: string
|
type: string
|
||||||
|
RegisterMemoryBankRequest:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
memory_bank:
|
||||||
|
oneOf:
|
||||||
|
- $ref: '#/components/schemas/VectorMemoryBankDef'
|
||||||
|
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
|
||||||
|
- $ref: '#/components/schemas/KeywordMemoryBankDef'
|
||||||
|
- $ref: '#/components/schemas/GraphMemoryBankDef'
|
||||||
|
required:
|
||||||
|
- memory_bank
|
||||||
|
type: object
|
||||||
|
RegisterModelRequest:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
model:
|
||||||
|
$ref: '#/components/schemas/ModelDef'
|
||||||
|
required:
|
||||||
|
- model
|
||||||
|
type: object
|
||||||
|
RegisterShieldRequest:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
shield:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
identifier:
|
||||||
|
type: string
|
||||||
|
params:
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
|
type: object
|
||||||
|
provider_id:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- identifier
|
||||||
|
- type
|
||||||
|
- params
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- shield
|
||||||
|
type: object
|
||||||
RestAPIExecutionConfig:
|
RestAPIExecutionConfig:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -1728,7 +1645,7 @@ components:
|
||||||
properties:
|
properties:
|
||||||
method:
|
method:
|
||||||
type: string
|
type: string
|
||||||
providers:
|
provider_types:
|
||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
type: array
|
type: array
|
||||||
|
@ -1737,7 +1654,7 @@ components:
|
||||||
required:
|
required:
|
||||||
- route
|
- route
|
||||||
- method
|
- method
|
||||||
- providers
|
- provider_types
|
||||||
type: object
|
type: object
|
||||||
RunShieldRequest:
|
RunShieldRequest:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
|
@ -1892,7 +1809,11 @@ components:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
memory_bank:
|
memory_bank:
|
||||||
$ref: '#/components/schemas/MemoryBank'
|
oneOf:
|
||||||
|
- $ref: '#/components/schemas/VectorMemoryBankDef'
|
||||||
|
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
|
||||||
|
- $ref: '#/components/schemas/KeywordMemoryBankDef'
|
||||||
|
- $ref: '#/components/schemas/GraphMemoryBankDef'
|
||||||
session_id:
|
session_id:
|
||||||
type: string
|
type: string
|
||||||
session_name:
|
session_name:
|
||||||
|
@ -1935,33 +1856,29 @@ components:
|
||||||
- step_id
|
- step_id
|
||||||
- step_type
|
- step_type
|
||||||
type: object
|
type: object
|
||||||
ShieldSpec:
|
ShieldDef:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
provider_config:
|
identifier:
|
||||||
additionalProperties: false
|
type: string
|
||||||
properties:
|
params:
|
||||||
config:
|
additionalProperties:
|
||||||
additionalProperties:
|
oneOf:
|
||||||
oneOf:
|
- type: 'null'
|
||||||
- type: 'null'
|
- type: boolean
|
||||||
- type: boolean
|
- type: number
|
||||||
- type: number
|
- type: string
|
||||||
- type: string
|
- type: array
|
||||||
- type: array
|
- type: object
|
||||||
- type: object
|
|
||||||
type: object
|
|
||||||
provider_type:
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- provider_type
|
|
||||||
- config
|
|
||||||
type: object
|
type: object
|
||||||
shield_type:
|
provider_id:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
type: string
|
type: string
|
||||||
required:
|
required:
|
||||||
- shield_type
|
- identifier
|
||||||
- provider_config
|
- type
|
||||||
|
- params
|
||||||
type: object
|
type: object
|
||||||
SpanEndPayload:
|
SpanEndPayload:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
|
@ -2571,6 +2488,29 @@ components:
|
||||||
- role
|
- role
|
||||||
- content
|
- content
|
||||||
type: object
|
type: object
|
||||||
|
VectorMemoryBankDef:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
chunk_size_in_tokens:
|
||||||
|
type: integer
|
||||||
|
embedding_model:
|
||||||
|
type: string
|
||||||
|
identifier:
|
||||||
|
type: string
|
||||||
|
overlap_size_in_tokens:
|
||||||
|
type: integer
|
||||||
|
provider_id:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
const: vector
|
||||||
|
default: vector
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- identifier
|
||||||
|
- type
|
||||||
|
- embedding_model
|
||||||
|
- chunk_size_in_tokens
|
||||||
|
type: object
|
||||||
ViolationLevel:
|
ViolationLevel:
|
||||||
enum:
|
enum:
|
||||||
- info
|
- info
|
||||||
|
@ -2604,7 +2544,7 @@ info:
|
||||||
description: "This is the specification of the llama stack that provides\n \
|
description: "This is the specification of the llama stack that provides\n \
|
||||||
\ a set of endpoints and their corresponding interfaces that are tailored\
|
\ a set of endpoints and their corresponding interfaces that are tailored\
|
||||||
\ to\n best leverage Llama Models. The specification is still in\
|
\ to\n best leverage Llama Models. The specification is still in\
|
||||||
\ draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
|
\ draft and subject to change.\n Generated at 2024-10-08 15:18:57.600111"
|
||||||
title: '[DRAFT] Llama Stack Specification'
|
title: '[DRAFT] Llama Stack Specification'
|
||||||
version: 0.0.1
|
version: 0.0.1
|
||||||
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
||||||
|
@ -3226,7 +3166,7 @@ paths:
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Inference
|
- Inference
|
||||||
/memory/create:
|
/inference/register_model:
|
||||||
post:
|
post:
|
||||||
parameters:
|
parameters:
|
||||||
- description: JSON-encoded provider data which will be made available to the
|
- description: JSON-encoded provider data which will be made available to the
|
||||||
|
@ -3240,17 +3180,13 @@ paths:
|
||||||
content:
|
content:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/CreateMemoryBankRequest'
|
$ref: '#/components/schemas/RegisterModelRequest'
|
||||||
required: true
|
required: true
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/MemoryBank'
|
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Memory
|
- Models
|
||||||
/memory/documents/delete:
|
/memory/documents/delete:
|
||||||
post:
|
post:
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -3302,57 +3238,6 @@ paths:
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Memory
|
- Memory
|
||||||
/memory/drop:
|
|
||||||
post:
|
|
||||||
parameters:
|
|
||||||
- description: JSON-encoded provider data which will be made available to the
|
|
||||||
adapter servicing the API
|
|
||||||
in: header
|
|
||||||
name: X-LlamaStack-ProviderData
|
|
||||||
required: false
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
requestBody:
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/DropMemoryBankRequest'
|
|
||||||
required: true
|
|
||||||
responses:
|
|
||||||
'200':
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
description: OK
|
|
||||||
tags:
|
|
||||||
- Memory
|
|
||||||
/memory/get:
|
|
||||||
get:
|
|
||||||
parameters:
|
|
||||||
- in: query
|
|
||||||
name: bank_id
|
|
||||||
required: true
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
- description: JSON-encoded provider data which will be made available to the
|
|
||||||
adapter servicing the API
|
|
||||||
in: header
|
|
||||||
name: X-LlamaStack-ProviderData
|
|
||||||
required: false
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
responses:
|
|
||||||
'200':
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
oneOf:
|
|
||||||
- $ref: '#/components/schemas/MemoryBank'
|
|
||||||
- type: 'null'
|
|
||||||
description: OK
|
|
||||||
tags:
|
|
||||||
- Memory
|
|
||||||
/memory/insert:
|
/memory/insert:
|
||||||
post:
|
post:
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -3374,25 +3259,6 @@ paths:
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Memory
|
- Memory
|
||||||
/memory/list:
|
|
||||||
get:
|
|
||||||
parameters:
|
|
||||||
- description: JSON-encoded provider data which will be made available to the
|
|
||||||
adapter servicing the API
|
|
||||||
in: header
|
|
||||||
name: X-LlamaStack-ProviderData
|
|
||||||
required: false
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
responses:
|
|
||||||
'200':
|
|
||||||
content:
|
|
||||||
application/jsonl:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/MemoryBank'
|
|
||||||
description: OK
|
|
||||||
tags:
|
|
||||||
- Memory
|
|
||||||
/memory/query:
|
/memory/query:
|
||||||
post:
|
post:
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -3443,10 +3309,10 @@ paths:
|
||||||
get:
|
get:
|
||||||
parameters:
|
parameters:
|
||||||
- in: query
|
- in: query
|
||||||
name: bank_type
|
name: identifier
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/MemoryBankType'
|
type: string
|
||||||
- description: JSON-encoded provider data which will be made available to the
|
- description: JSON-encoded provider data which will be made available to the
|
||||||
adapter servicing the API
|
adapter servicing the API
|
||||||
in: header
|
in: header
|
||||||
|
@ -3460,7 +3326,11 @@ paths:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
oneOf:
|
oneOf:
|
||||||
- $ref: '#/components/schemas/MemoryBankSpec'
|
- oneOf:
|
||||||
|
- $ref: '#/components/schemas/VectorMemoryBankDef'
|
||||||
|
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
|
||||||
|
- $ref: '#/components/schemas/KeywordMemoryBankDef'
|
||||||
|
- $ref: '#/components/schemas/GraphMemoryBankDef'
|
||||||
- type: 'null'
|
- type: 'null'
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
|
@ -3480,15 +3350,40 @@ paths:
|
||||||
content:
|
content:
|
||||||
application/jsonl:
|
application/jsonl:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/MemoryBankSpec'
|
oneOf:
|
||||||
|
- $ref: '#/components/schemas/VectorMemoryBankDef'
|
||||||
|
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
|
||||||
|
- $ref: '#/components/schemas/KeywordMemoryBankDef'
|
||||||
|
- $ref: '#/components/schemas/GraphMemoryBankDef'
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- MemoryBanks
|
- MemoryBanks
|
||||||
|
/memory_banks/register:
|
||||||
|
post:
|
||||||
|
parameters:
|
||||||
|
- description: JSON-encoded provider data which will be made available to the
|
||||||
|
adapter servicing the API
|
||||||
|
in: header
|
||||||
|
name: X-LlamaStack-ProviderData
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
requestBody:
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/RegisterMemoryBankRequest'
|
||||||
|
required: true
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: OK
|
||||||
|
tags:
|
||||||
|
- Memory
|
||||||
/models/get:
|
/models/get:
|
||||||
get:
|
get:
|
||||||
parameters:
|
parameters:
|
||||||
- in: query
|
- in: query
|
||||||
name: core_model_id
|
name: identifier
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
|
@ -3505,7 +3400,7 @@ paths:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
oneOf:
|
oneOf:
|
||||||
- $ref: '#/components/schemas/ModelServingSpec'
|
- $ref: '#/components/schemas/ModelDef'
|
||||||
- type: 'null'
|
- type: 'null'
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
|
@ -3525,7 +3420,7 @@ paths:
|
||||||
content:
|
content:
|
||||||
application/jsonl:
|
application/jsonl:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/ModelServingSpec'
|
$ref: '#/components/schemas/ModelDef'
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Models
|
- Models
|
||||||
|
@ -3760,6 +3655,27 @@ paths:
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Inspect
|
- Inspect
|
||||||
|
/safety/register_shield:
|
||||||
|
post:
|
||||||
|
parameters:
|
||||||
|
- description: JSON-encoded provider data which will be made available to the
|
||||||
|
adapter servicing the API
|
||||||
|
in: header
|
||||||
|
name: X-LlamaStack-ProviderData
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
requestBody:
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/RegisterShieldRequest'
|
||||||
|
required: true
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: OK
|
||||||
|
tags:
|
||||||
|
- Shields
|
||||||
/safety/run_shield:
|
/safety/run_shield:
|
||||||
post:
|
post:
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -3806,7 +3722,29 @@ paths:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
oneOf:
|
oneOf:
|
||||||
- $ref: '#/components/schemas/ShieldSpec'
|
- additionalProperties: false
|
||||||
|
properties:
|
||||||
|
identifier:
|
||||||
|
type: string
|
||||||
|
params:
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
|
type: object
|
||||||
|
provider_id:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- identifier
|
||||||
|
- type
|
||||||
|
- params
|
||||||
|
type: object
|
||||||
- type: 'null'
|
- type: 'null'
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
|
@ -3826,7 +3764,7 @@ paths:
|
||||||
content:
|
content:
|
||||||
application/jsonl:
|
application/jsonl:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/ShieldSpec'
|
$ref: '#/components/schemas/ShieldDef'
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Shields
|
- Shields
|
||||||
|
@ -3905,21 +3843,21 @@ security:
|
||||||
servers:
|
servers:
|
||||||
- url: http://any-hosted-llama-stack.com
|
- url: http://any-hosted-llama-stack.com
|
||||||
tags:
|
tags:
|
||||||
- name: Datasets
|
|
||||||
- name: Inspect
|
|
||||||
- name: Memory
|
|
||||||
- name: BatchInference
|
- name: BatchInference
|
||||||
- name: Agents
|
- name: Datasets
|
||||||
- name: Inference
|
- name: Inference
|
||||||
- name: Shields
|
|
||||||
- name: SyntheticDataGeneration
|
|
||||||
- name: Models
|
|
||||||
- name: RewardScoring
|
|
||||||
- name: MemoryBanks
|
|
||||||
- name: Safety
|
|
||||||
- name: Evaluations
|
- name: Evaluations
|
||||||
- name: Telemetry
|
- name: Memory
|
||||||
|
- name: Safety
|
||||||
- name: PostTraining
|
- name: PostTraining
|
||||||
|
- name: MemoryBanks
|
||||||
|
- name: Models
|
||||||
|
- name: Shields
|
||||||
|
- name: Inspect
|
||||||
|
- name: SyntheticDataGeneration
|
||||||
|
- name: Telemetry
|
||||||
|
- name: Agents
|
||||||
|
- name: RewardScoring
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
|
||||||
name: BuiltinTool
|
name: BuiltinTool
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
|
||||||
|
@ -4123,11 +4061,6 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/CreateDatasetRequest"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/CreateDatasetRequest"
|
||||||
/>
|
/>
|
||||||
name: CreateDatasetRequest
|
name: CreateDatasetRequest
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/CreateMemoryBankRequest"
|
|
||||||
/>
|
|
||||||
name: CreateMemoryBankRequest
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBank" />
|
|
||||||
name: MemoryBank
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteAgentsRequest"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteAgentsRequest"
|
||||||
/>
|
/>
|
||||||
name: DeleteAgentsRequest
|
name: DeleteAgentsRequest
|
||||||
|
@ -4140,9 +4073,6 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteDocumentsRequest"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteDocumentsRequest"
|
||||||
/>
|
/>
|
||||||
name: DeleteDocumentsRequest
|
name: DeleteDocumentsRequest
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/DropMemoryBankRequest"
|
|
||||||
/>
|
|
||||||
name: DropMemoryBankRequest
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/EmbeddingsRequest"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/EmbeddingsRequest"
|
||||||
/>
|
/>
|
||||||
name: EmbeddingsRequest
|
name: EmbeddingsRequest
|
||||||
|
@ -4163,11 +4093,23 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/GetAgentsSessionRequest"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/GetAgentsSessionRequest"
|
||||||
/>
|
/>
|
||||||
name: GetAgentsSessionRequest
|
name: GetAgentsSessionRequest
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/GraphMemoryBankDef"
|
||||||
|
/>
|
||||||
|
name: GraphMemoryBankDef
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/KeyValueMemoryBankDef"
|
||||||
|
/>
|
||||||
|
name: KeyValueMemoryBankDef
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/KeywordMemoryBankDef"
|
||||||
|
/>
|
||||||
|
name: KeywordMemoryBankDef
|
||||||
- description: 'A single session of an interaction with an Agentic System.
|
- description: 'A single session of an interaction with an Agentic System.
|
||||||
|
|
||||||
|
|
||||||
<SchemaDefinition schemaRef="#/components/schemas/Session" />'
|
<SchemaDefinition schemaRef="#/components/schemas/Session" />'
|
||||||
name: Session
|
name: Session
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/VectorMemoryBankDef"
|
||||||
|
/>
|
||||||
|
name: VectorMemoryBankDef
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/AgentStepResponse"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/AgentStepResponse"
|
||||||
/>
|
/>
|
||||||
name: AgentStepResponse
|
name: AgentStepResponse
|
||||||
|
@ -4189,21 +4131,8 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/EvaluationJobStatusResponse"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/EvaluationJobStatusResponse"
|
||||||
/>
|
/>
|
||||||
name: EvaluationJobStatusResponse
|
name: EvaluationJobStatusResponse
|
||||||
- description: 'The model family and SKU of the model along with other parameters
|
- description: <SchemaDefinition schemaRef="#/components/schemas/ModelDef" />
|
||||||
corresponding to the model.
|
name: ModelDef
|
||||||
|
|
||||||
|
|
||||||
<SchemaDefinition schemaRef="#/components/schemas/Model" />'
|
|
||||||
name: Model
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/ModelServingSpec"
|
|
||||||
/>
|
|
||||||
name: ModelServingSpec
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankType" />
|
|
||||||
name: MemoryBankType
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankSpec" />
|
|
||||||
name: MemoryBankSpec
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldSpec" />
|
|
||||||
name: ShieldSpec
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
|
||||||
name: Trace
|
name: Trace
|
||||||
- description: 'Checkpoint created during training runs
|
- description: 'Checkpoint created during training runs
|
||||||
|
@ -4243,6 +4172,8 @@ tags:
|
||||||
name: ProviderInfo
|
name: ProviderInfo
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/RouteInfo" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/RouteInfo" />
|
||||||
name: RouteInfo
|
name: RouteInfo
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldDef" />
|
||||||
|
name: ShieldDef
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" />
|
||||||
name: LogSeverity
|
name: LogSeverity
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
|
||||||
|
@ -4282,6 +4213,15 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/QueryDocumentsResponse"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/QueryDocumentsResponse"
|
||||||
/>
|
/>
|
||||||
name: QueryDocumentsResponse
|
name: QueryDocumentsResponse
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterMemoryBankRequest"
|
||||||
|
/>
|
||||||
|
name: RegisterMemoryBankRequest
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterModelRequest"
|
||||||
|
/>
|
||||||
|
name: RegisterModelRequest
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterShieldRequest"
|
||||||
|
/>
|
||||||
|
name: RegisterShieldRequest
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/DialogGenerations"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/DialogGenerations"
|
||||||
/>
|
/>
|
||||||
name: DialogGenerations
|
name: DialogGenerations
|
||||||
|
@ -4387,7 +4327,6 @@ x-tagGroups:
|
||||||
- CreateAgentSessionRequest
|
- CreateAgentSessionRequest
|
||||||
- CreateAgentTurnRequest
|
- CreateAgentTurnRequest
|
||||||
- CreateDatasetRequest
|
- CreateDatasetRequest
|
||||||
- CreateMemoryBankRequest
|
|
||||||
- DPOAlignmentConfig
|
- DPOAlignmentConfig
|
||||||
- DeleteAgentsRequest
|
- DeleteAgentsRequest
|
||||||
- DeleteAgentsSessionRequest
|
- DeleteAgentsSessionRequest
|
||||||
|
@ -4395,7 +4334,6 @@ x-tagGroups:
|
||||||
- DeleteDocumentsRequest
|
- DeleteDocumentsRequest
|
||||||
- DialogGenerations
|
- DialogGenerations
|
||||||
- DoraFinetuningConfig
|
- DoraFinetuningConfig
|
||||||
- DropMemoryBankRequest
|
|
||||||
- EmbeddingsRequest
|
- EmbeddingsRequest
|
||||||
- EmbeddingsResponse
|
- EmbeddingsResponse
|
||||||
- EvaluateQuestionAnsweringRequest
|
- EvaluateQuestionAnsweringRequest
|
||||||
|
@ -4409,22 +4347,21 @@ x-tagGroups:
|
||||||
- FunctionCallToolDefinition
|
- FunctionCallToolDefinition
|
||||||
- GetAgentsSessionRequest
|
- GetAgentsSessionRequest
|
||||||
- GetDocumentsRequest
|
- GetDocumentsRequest
|
||||||
|
- GraphMemoryBankDef
|
||||||
- HealthInfo
|
- HealthInfo
|
||||||
- ImageMedia
|
- ImageMedia
|
||||||
- InferenceStep
|
- InferenceStep
|
||||||
- InsertDocumentsRequest
|
- InsertDocumentsRequest
|
||||||
|
- KeyValueMemoryBankDef
|
||||||
|
- KeywordMemoryBankDef
|
||||||
- LogEventRequest
|
- LogEventRequest
|
||||||
- LogSeverity
|
- LogSeverity
|
||||||
- LoraFinetuningConfig
|
- LoraFinetuningConfig
|
||||||
- MemoryBank
|
|
||||||
- MemoryBankDocument
|
- MemoryBankDocument
|
||||||
- MemoryBankSpec
|
|
||||||
- MemoryBankType
|
|
||||||
- MemoryRetrievalStep
|
- MemoryRetrievalStep
|
||||||
- MemoryToolDefinition
|
- MemoryToolDefinition
|
||||||
- MetricEvent
|
- MetricEvent
|
||||||
- Model
|
- ModelDef
|
||||||
- ModelServingSpec
|
|
||||||
- OptimizerConfig
|
- OptimizerConfig
|
||||||
- PhotogenToolDefinition
|
- PhotogenToolDefinition
|
||||||
- PostTrainingJob
|
- PostTrainingJob
|
||||||
|
@ -4438,6 +4375,9 @@ x-tagGroups:
|
||||||
- QueryDocumentsRequest
|
- QueryDocumentsRequest
|
||||||
- QueryDocumentsResponse
|
- QueryDocumentsResponse
|
||||||
- RLHFAlgorithm
|
- RLHFAlgorithm
|
||||||
|
- RegisterMemoryBankRequest
|
||||||
|
- RegisterModelRequest
|
||||||
|
- RegisterShieldRequest
|
||||||
- RestAPIExecutionConfig
|
- RestAPIExecutionConfig
|
||||||
- RestAPIMethod
|
- RestAPIMethod
|
||||||
- RewardScoreRequest
|
- RewardScoreRequest
|
||||||
|
@ -4453,7 +4393,7 @@ x-tagGroups:
|
||||||
- SearchToolDefinition
|
- SearchToolDefinition
|
||||||
- Session
|
- Session
|
||||||
- ShieldCallStep
|
- ShieldCallStep
|
||||||
- ShieldSpec
|
- ShieldDef
|
||||||
- SpanEndPayload
|
- SpanEndPayload
|
||||||
- SpanStartPayload
|
- SpanStartPayload
|
||||||
- SpanStatus
|
- SpanStatus
|
||||||
|
@ -4483,5 +4423,6 @@ x-tagGroups:
|
||||||
- UnstructuredLogEvent
|
- UnstructuredLogEvent
|
||||||
- UpdateDocumentsRequest
|
- UpdateDocumentsRequest
|
||||||
- UserMessage
|
- UserMessage
|
||||||
|
- VectorMemoryBankDef
|
||||||
- ViolationLevel
|
- ViolationLevel
|
||||||
- WolframAlphaToolDefinition
|
- WolframAlphaToolDefinition
|
||||||
|
|
BIN docs/resources/prompt-format.png (new file, binary not shown; 170 KiB)
|
||||||
turns: List[Turn]
|
turns: List[Turn]
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
|
|
||||||
memory_bank: Optional[MemoryBank] = None
|
memory_bank: Optional[MemoryBankDef] = None
|
||||||
|
|
||||||
|
|
||||||
class AgentConfigCommon(BaseModel):
|
class AgentConfigCommon(BaseModel):
|
||||||
|
@ -411,8 +411,10 @@ class Agents(Protocol):
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
) -> AgentCreateResponse: ...
|
) -> AgentCreateResponse: ...
|
||||||
|
|
||||||
|
# This method is not `async def` because it can result in either an
|
||||||
|
# `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
|
||||||
@webmethod(route="/agents/turn/create")
|
@webmethod(route="/agents/turn/create")
|
||||||
async def create_agent_turn(
|
def create_agent_turn(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
|
|
|
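The comment added above captures a pattern used throughout this change: a method declared as a plain `def` that returns either an async generator (when `stream=True`) or a coroutine (when `stream=False`), so the caller picks `async for` or `await`. A minimal, self-contained sketch of the idea follows; the class and method names here are illustrative, not the actual llama-stack types.

```python
import asyncio
from typing import AsyncGenerator


# Hypothetical response type used only for this sketch.
class TurnResponse:
    def __init__(self, text: str) -> None:
        self.text = text


class SketchClient:
    # Plain `def`: the return value depends on `stream`, so the caller decides
    # whether to iterate (streaming) or await (non-streaming).
    def create_turn(self, prompt: str, stream: bool):
        if stream:
            return self._stream_turn(prompt)   # async generator: consume with `async for`
        return self._nonstream_turn(prompt)    # coroutine: consume with `await`

    async def _stream_turn(self, prompt: str) -> AsyncGenerator[str, None]:
        for token in prompt.split():
            await asyncio.sleep(0)             # simulate incremental chunks arriving
            yield token

    async def _nonstream_turn(self, prompt: str) -> TurnResponse:
        return TurnResponse(text=prompt.upper())


async def main() -> None:
    client = SketchClient()
    async for chunk in client.create_turn("hello streaming world", stream=True):
        print("chunk:", chunk)
    response = await client.create_turn("hello", stream=False)
    print("full:", response.text)


if __name__ == "__main__":
    asyncio.run(main())
```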
@@ -7,7 +7,7 @@
 import asyncio
 import json
 import os
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional

 import fire
 import httpx

@@ -67,9 +67,17 @@ class AgentsClient(Agents):
        response.raise_for_status()
        return AgentSessionCreateResponse(**response.json())

-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         request: AgentTurnCreateRequest,
+    ) -> AsyncGenerator:
+        if request.stream:
+            return self._stream_agent_turn(request)
+        else:
+            return self._nonstream_agent_turn(request)
+
+    async def _stream_agent_turn(
+        self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
         async with httpx.AsyncClient() as client:
             async with client.stream(

@@ -93,6 +101,9 @@ class AgentsClient(Agents):
                            print(data)
                            print(f"Error with parsing or validation: {e}")

+    async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
+        raise NotImplementedError("Non-streaming not implemented yet")
+

 async def _run_agent(
     api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None

@@ -132,8 +143,7 @@ async def _run_agent(
        log.print()


-async def run_llama_3_1(host: str, port: int):
-    model = "Llama3.1-8B-Instruct"
+async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     tool_definitions = [

@@ -173,8 +183,7 @@ async def run_llama_3_1(host: str, port: int):
     await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)


-async def run_llama_3_2_rag(host: str, port: int):
-    model = "Llama3.2-3B-Instruct"
+async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     urls = [

@@ -215,8 +224,7 @@ async def run_llama_3_2_rag(host: str, port: int):
     )


-async def run_llama_3_2(host: str, port: int):
-    model = "Llama3.2-3B-Instruct"
+async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     # zero shot tools for llama3.2 text models

@@ -262,7 +270,7 @@ async def run_llama_3_2(host: str, port: int):
     )


-def main(host: str, port: int, run_type: str):
+def main(host: str, port: int, run_type: str, model: Optional[str] = None):
     assert run_type in [
         "tools_llama_3_1",
         "tools_llama_3_2",

@@ -274,7 +282,10 @@ def main(host: str, port: int, run_type: str):
         "tools_llama_3_2": run_llama_3_2,
         "rag_llama_3_2": run_llama_3_2_rag,
     }
-    asyncio.run(fn[run_type](host, port))
+    args = [host, port]
+    if model is not None:
+        args.append(model)
+    asyncio.run(fn[run_type](*args))


 if __name__ == "__main__":
@ -42,10 +42,10 @@ class InferenceClient(Inference):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def chat_completion(
|
def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
|
@ -66,6 +66,29 @@ class InferenceClient(Inference):
|
||||||
stream=stream,
|
stream=stream,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
|
if stream:
|
||||||
|
+                return self._stream_chat_completion(request)
+            else:
+                return self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self.base_url}/inference/chat_completion",
+                json=encodable_dict(request),
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            )
+
+            response.raise_for_status()
+            j = response.json()
+            return ChatCompletionResponse(**j)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         async with httpx.AsyncClient() as client:
             async with client.stream(
                 "POST",
@@ -77,7 +100,8 @@ class InferenceClient(Inference):
                 if response.status_code != 200:
                     content = await response.aread()
                     cprint(
-                        f"Error: HTTP {response.status_code} {content.decode()}", "red"
+                        f"Error: HTTP {response.status_code} {content.decode()}",
+                        "red",
                     )
                     return
@@ -85,40 +109,59 @@ class InferenceClient(Inference):
                     if line.startswith("data:"):
                         data = line[len("data: ") :]
                         try:
-                            if request.stream:
-                                if "error" in data:
-                                    cprint(data, "red")
-                                    continue
-
-                                yield ChatCompletionResponseStreamChunk(
-                                    **json.loads(data)
-                                )
-                            else:
-                                yield ChatCompletionResponse(**json.loads(data))
+                            if "error" in data:
+                                cprint(data, "red")
+                                continue
+
+                            yield ChatCompletionResponseStreamChunk(**json.loads(data))
                         except Exception as e:
                             print(data)
                             print(f"Error with parsing or validation: {e}")


-async def run_main(host: str, port: int, stream: bool):
+async def run_main(
+    host: str, port: int, stream: bool, model: Optional[str], logprobs: bool
+):
     client = InferenceClient(f"http://{host}:{port}")

+    if not model:
+        model = "Llama3.1-8B-Instruct"
+
     message = UserMessage(
         content="hello world, write me a 2 sentence poem about the moon"
     )
     cprint(f"User>{message.content}", "green")
+
+    if logprobs:
+        logprobs_config = LogProbConfig(
+            top_k=1,
+        )
+    else:
+        logprobs_config = None
+
     iterator = client.chat_completion(
-        model="Llama3.1-8B-Instruct",
+        model=model,
         messages=[message],
         stream=stream,
+        logprobs=logprobs_config,
     )
-    async for log in EventLogger().log(iterator):
-        log.print()
+    if logprobs:
+        async for chunk in iterator:
+            cprint(f"Response: {chunk}", "red")
+    else:
+        async for log in EventLogger().log(iterator):
+            log.print()


-async def run_mm_main(host: str, port: int, stream: bool, path: str):
+async def run_mm_main(
+    host: str, port: int, stream: bool, path: Optional[str], model: Optional[str]
+):
     client = InferenceClient(f"http://{host}:{port}")
+
+    if not model:
+        model = "Llama3.2-11B-Vision-Instruct"
+
     message = UserMessage(
         content=[
             ImageMedia(image=URL(uri=f"file://{path}")),
@@ -127,7 +170,7 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
     )
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
-        model="Llama3.2-11B-Vision-Instruct",
+        model=model,
         messages=[message],
         stream=stream,
     )
@@ -135,11 +178,19 @@
         log.print()


-def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
+def main(
+    host: str,
+    port: int,
+    stream: bool = True,
+    mm: bool = False,
+    logprobs: bool = False,
+    file: Optional[str] = None,
+    model: Optional[str] = None,
+):
     if mm:
-        asyncio.run(run_mm_main(host, port, stream, file))
+        asyncio.run(run_mm_main(host, port, stream, file, model))
     else:
-        asyncio.run(run_main(host, port, stream))
+        asyncio.run(run_main(host, port, stream, model, logprobs))


 if __name__ == "__main__":
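For illustration, a minimal sketch of exercising the refactored client directly rather than through `fire`. The server address and the import paths for `InferenceClient`, `EventLogger`, and `UserMessage` are assumptions, not taken from this diff:

```
import asyncio

from llama_stack.apis.inference.client import InferenceClient  # assumed path
from llama_stack.apis.inference.event_logger import EventLogger  # assumed path
from llama_models.llama3.api.datatypes import UserMessage  # assumed path


async def demo() -> None:
    client = InferenceClient("http://localhost:5000")  # assumed server address
    message = UserMessage(content="Write one sentence about the moon.")
    # Mirrors run_main above: without --logprobs the stream is pretty-printed
    # through EventLogger; with logprobs the raw chunks would be printed instead.
    iterator = client.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[message],
        stream=True,
    )
    async for log in EventLogger().log(iterator):
        log.print()


asyncio.run(demo())
```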
@@ -14,6 +14,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Annotated

 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.models import *  # noqa: F403


 class LogProbConfig(BaseModel):
@@ -172,9 +173,17 @@ class EmbeddingsResponse(BaseModel):
     embeddings: List[List[float]]


+class ModelStore(Protocol):
+    def get_model(self, identifier: str) -> ModelDef: ...
+
+
 class Inference(Protocol):
+    model_store: ModelStore
+
+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/completion")
-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -183,8 +192,10 @@ class Inference(Protocol):
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...

+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/chat_completion")
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -203,3 +214,6 @@ class Inference(Protocol):
         model: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse: ...
+
+    @webmethod(route="/inference/register_model")
+    async def register_model(self, model: ModelDef) -> None: ...
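To illustrate the new registry hooks on the Inference protocol, a hypothetical provider-side sketch; the classes below are invented for illustration, and only `ModelDef`, `ModelStore`, and `register_model` come from the diff above:

```
from typing import Dict

from llama_stack.apis.models import ModelDef  # per the new import in this file


class StaticModelStore:
    """Toy ModelStore: keeps registered ModelDefs in a dict."""

    def __init__(self) -> None:
        self._models: Dict[str, ModelDef] = {}

    def add(self, model: ModelDef) -> None:
        self._models[model.identifier] = model

    def get_model(self, identifier: str) -> ModelDef:
        return self._models[identifier]


class ToyInferenceProvider:
    """Hypothetical provider satisfying the registry part of the protocol."""

    def __init__(self) -> None:
        self.model_store = StaticModelStore()

    async def register_model(self, model: ModelDef) -> None:
        # invoked when the stack routes POST /inference/register_model here
        self.model_store.add(model)
```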
@@ -12,15 +12,15 @@ from pydantic import BaseModel

 @json_schema_type
 class ProviderInfo(BaseModel):
+    provider_id: str
     provider_type: str
-    description: str


 @json_schema_type
 class RouteInfo(BaseModel):
     route: str
     method: str
-    providers: List[str]
+    provider_types: List[str]


 @json_schema_type
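A small sketch of the updated inspect payloads after this rename; the values are illustrative:

```
provider = ProviderInfo(
    provider_id="meta-reference-0",
    provider_type="meta-reference",
)
route = RouteInfo(
    route="/inference/chat_completion",
    method="POST",
    provider_types=["meta-reference"],
)
```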
@@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional

 import fire
 import httpx
-from termcolor import cprint

 from llama_stack.distribution.datatypes import RemoteProviderConfig

 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.memory_banks.client import MemoryBanksClient
 from llama_stack.providers.utils.memory.file_utils import data_url_from_file

@@ -35,44 +35,16 @@ class MemoryClient(Memory):
     async def shutdown(self) -> None:
         pass

-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
+    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
         async with httpx.AsyncClient() as client:
-            r = await client.get(
-                f"{self.base_url}/memory/get",
-                params={
-                    "bank_id": bank_id,
-                },
-                headers={"Content-Type": "application/json"},
-                timeout=20,
-            )
-            r.raise_for_status()
-            d = r.json()
-            if not d:
-                return None
-            return MemoryBank(**d)
-
-    async def create_memory_bank(
-        self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        async with httpx.AsyncClient() as client:
-            r = await client.post(
-                f"{self.base_url}/memory/create",
+            response = await client.post(
+                f"{self.base_url}/memory/register_memory_bank",
                 json={
-                    "name": name,
-                    "config": config.dict(),
-                    "url": url,
+                    "memory_bank": json.loads(memory_bank.json()),
                 },
                 headers={"Content-Type": "application/json"},
-                timeout=20,
             )
-            r.raise_for_status()
-            d = r.json()
-            if not d:
-                return None
-            return MemoryBank(**d)
+            response.raise_for_status()

     async def insert_documents(
         self,
@@ -114,22 +86,20 @@ class MemoryClient(Memory):

 async def run_main(host: str, port: int, stream: bool):
     client = MemoryClient(f"http://{host}:{port}")
+    banks_client = MemoryBanksClient(f"http://{host}:{port}")

-    # create a memory bank
-    bank = await client.create_memory_bank(
-        name="test_bank",
-        config=VectorMemoryBankConfig(
-            bank_id="test_bank",
-            embedding_model="all-MiniLM-L6-v2",
-            chunk_size_in_tokens=512,
-            overlap_size_in_tokens=64,
-        ),
+    bank = VectorMemoryBankDef(
+        identifier="test_bank",
+        provider_id="",
+        embedding_model="all-MiniLM-L6-v2",
+        chunk_size_in_tokens=512,
+        overlap_size_in_tokens=64,
     )
-    cprint(json.dumps(bank.dict(), indent=4), "green")
+    await client.register_memory_bank(bank)

-    retrieved_bank = await client.get_memory_bank(bank.bank_id)
+    retrieved_bank = await banks_client.get_memory_bank(bank.identifier)
     assert retrieved_bank is not None
-    assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
+    assert retrieved_bank.embedding_model == "all-MiniLM-L6-v2"

     urls = [
         "memory_optimizations.rst",
@@ -162,13 +132,13 @@ async def run_main(host: str, port: int, stream: bool):

     # insert some documents
     await client.insert_documents(
-        bank_id=bank.bank_id,
+        bank_id=bank.identifier,
         documents=documents,
     )

     # query the documents
     response = await client.query_documents(
-        bank_id=bank.bank_id,
+        bank_id=bank.identifier,
         query=[
             "How do I use Lora?",
         ],
@@ -178,7 +148,7 @@ async def run_main(host: str, port: int, stream: bool):
         print(f"Chunk:\n========\n{chunk}\n========\n")

     response = await client.query_documents(
-        bank_id=bank.bank_id,
+        bank_id=bank.identifier,
         query=[
             "Tell me more about llama3 and torchtune",
         ],
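For clarity, a standalone sketch of the request body the new `/memory/register_memory_bank` route receives, mirroring the client code above; the server URL and import path are assumptions:

```
import asyncio
import json

import httpx

from llama_stack.apis.memory_banks import VectorMemoryBankDef  # assumed import path


async def register_bank() -> None:
    bank = VectorMemoryBankDef(
        identifier="test_bank",
        provider_id="",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
        overlap_size_in_tokens=64,
    )
    async with httpx.AsyncClient() as client:
        # the payload shape matches what register_memory_bank() sends above
        r = await client.post(
            "http://localhost:5000/memory/register_memory_bank",
            json={"memory_bank": json.loads(bank.json())},
            headers={"Content-Type": "application/json"},
        )
        r.raise_for_status()


asyncio.run(register_bank())
```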
@@ -13,9 +13,9 @@ from typing import List, Optional, Protocol
 from llama_models.schema_utils import json_schema_type, webmethod

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.memory_banks import *  # noqa: F403


 @json_schema_type
@@ -26,44 +26,6 @@ class MemoryBankDocument(BaseModel):
     metadata: Dict[str, Any] = Field(default_factory=dict)


-@json_schema_type
-class MemoryBankType(Enum):
-    vector = "vector"
-    keyvalue = "keyvalue"
-    keyword = "keyword"
-    graph = "graph"
-
-
-class VectorMemoryBankConfig(BaseModel):
-    type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
-    embedding_model: str
-    chunk_size_in_tokens: int
-    overlap_size_in_tokens: Optional[int] = None
-
-
-class KeyValueMemoryBankConfig(BaseModel):
-    type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
-
-
-class KeywordMemoryBankConfig(BaseModel):
-    type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
-
-
-class GraphMemoryBankConfig(BaseModel):
-    type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
-
-
-MemoryBankConfig = Annotated[
-    Union[
-        VectorMemoryBankConfig,
-        KeyValueMemoryBankConfig,
-        KeywordMemoryBankConfig,
-        GraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
 class Chunk(BaseModel):
     content: InterleavedTextMedia
     token_count: int
@@ -76,45 +38,12 @@ class QueryDocumentsResponse(BaseModel):
     scores: List[float]


-@json_schema_type
-class QueryAPI(Protocol):
-    @webmethod(route="/query_documents")
-    def query_documents(
-        self,
-        query: InterleavedTextMedia,
-        params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse: ...
-
-
-@json_schema_type
-class MemoryBank(BaseModel):
-    bank_id: str
-    name: str
-    config: MemoryBankConfig
-    # if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
-    url: Optional[URL] = None
+class MemoryBankStore(Protocol):
+    def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...


 class Memory(Protocol):
-    @webmethod(route="/memory/create")
-    async def create_memory_bank(
-        self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank: ...
-
-    @webmethod(route="/memory/list", method="GET")
-    async def list_memory_banks(self) -> List[MemoryBank]: ...
-
-    @webmethod(route="/memory/get", method="GET")
-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
-
-    @webmethod(route="/memory/drop", method="DELETE")
-    async def drop_memory_bank(
-        self,
-        bank_id: str,
-    ) -> str: ...
+    memory_bank_store: MemoryBankStore

     # this will just block now until documents are inserted, but it should
     # probably return a Job instance which can be polled for completion
@@ -154,3 +83,6 @@ class Memory(Protocol):
         bank_id: str,
         document_ids: List[str],
     ) -> None: ...
+
+    @webmethod(route="/memory/register_memory_bank")
+    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
@@ -6,7 +6,7 @@

 import asyncio

-from typing import List, Optional
+from typing import Any, Dict, List, Optional

 import fire
 import httpx
@@ -15,6 +15,25 @@ from termcolor import cprint
 from .memory_banks import *  # noqa: F403


+def deserialize_memory_bank_def(j: Optional[Dict[str, Any]]) -> MemoryBankDef:
+    if j is None:
+        return None
+
+    if "type" not in j:
+        raise ValueError("Memory bank type not specified")
+    type = j["type"]
+    if type == MemoryBankType.vector.value:
+        return VectorMemoryBankDef(**j)
+    elif type == MemoryBankType.keyvalue.value:
+        return KeyValueMemoryBankDef(**j)
+    elif type == MemoryBankType.keyword.value:
+        return KeywordMemoryBankDef(**j)
+    elif type == MemoryBankType.graph.value:
+        return GraphMemoryBankDef(**j)
+    else:
+        raise ValueError(f"Unknown memory bank type: {type}")
+
+
 class MemoryBanksClient(MemoryBanks):
     def __init__(self, base_url: str):
         self.base_url = base_url
@@ -25,37 +44,36 @@ class MemoryBanksClient(MemoryBanks):
     async def shutdown(self) -> None:
         pass

-    async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
+    async def list_memory_banks(self) -> List[MemoryBankDef]:
         async with httpx.AsyncClient() as client:
             response = await client.get(
                 f"{self.base_url}/memory_banks/list",
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return [MemoryBankSpec(**x) for x in response.json()]
+            return [deserialize_memory_bank_def(x) for x in response.json()]

-    async def get_serving_memory_bank(
-        self, bank_type: MemoryBankType
-    ) -> Optional[MemoryBankSpec]:
+    async def get_memory_bank(
+        self,
+        identifier: str,
+    ) -> Optional[MemoryBankDef]:
         async with httpx.AsyncClient() as client:
             response = await client.get(
                 f"{self.base_url}/memory_banks/get",
                 params={
-                    "bank_type": bank_type.value,
+                    "identifier": identifier,
                 },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
             j = response.json()
-            if j is None:
-                return None
-            return MemoryBankSpec(**j)
+            return deserialize_memory_bank_def(j)


 async def run_main(host: str, port: int, stream: bool):
     client = MemoryBanksClient(f"http://{host}:{port}")

-    response = await client.list_available_memory_banks()
+    response = await client.list_memory_banks()
     cprint(f"list_memory_banks response={response}", "green")
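A quick round-trip sketch for `deserialize_memory_bank_def` defined above, run in the same module; the bank values are illustrative:

```
import json

bank = VectorMemoryBankDef(
    identifier="docs",
    embedding_model="all-MiniLM-L6-v2",
    chunk_size_in_tokens=512,
)
payload = json.loads(bank.json())  # includes "type": "vector"
restored = deserialize_memory_bank_def(payload)
assert isinstance(restored, VectorMemoryBankDef)
assert restored.chunk_size_in_tokens == 512
```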
@@ -4,29 +4,67 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Optional, Protocol
+from enum import Enum
+from typing import List, Literal, Optional, Protocol, Union

 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
+from typing_extensions import Annotated

-from llama_stack.apis.memory import MemoryBankType
-
-from llama_stack.distribution.datatypes import GenericProviderConfig


 @json_schema_type
-class MemoryBankSpec(BaseModel):
-    bank_type: MemoryBankType
-    provider_config: GenericProviderConfig = Field(
-        description="Provider config for the model, including provider_type, and corresponding config. ",
-    )
+class MemoryBankType(Enum):
+    vector = "vector"
+    keyvalue = "keyvalue"
+    keyword = "keyword"
+    graph = "graph"
+
+
+class CommonDef(BaseModel):
+    identifier: str
+    provider_id: Optional[str] = None
+
+
+@json_schema_type
+class VectorMemoryBankDef(CommonDef):
+    type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
+    embedding_model: str
+    chunk_size_in_tokens: int
+    overlap_size_in_tokens: Optional[int] = None
+
+
+@json_schema_type
+class KeyValueMemoryBankDef(CommonDef):
+    type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
+
+
+@json_schema_type
+class KeywordMemoryBankDef(CommonDef):
+    type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
+
+
+@json_schema_type
+class GraphMemoryBankDef(CommonDef):
+    type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
+
+
+MemoryBankDef = Annotated[
+    Union[
+        VectorMemoryBankDef,
+        KeyValueMemoryBankDef,
+        KeywordMemoryBankDef,
+        GraphMemoryBankDef,
+    ],
+    Field(discriminator="type"),
+]


 class MemoryBanks(Protocol):
     @webmethod(route="/memory_banks/list", method="GET")
-    async def list_available_memory_banks(self) -> List[MemoryBankSpec]: ...
+    async def list_memory_banks(self) -> List[MemoryBankDef]: ...

     @webmethod(route="/memory_banks/get", method="GET")
-    async def get_serving_memory_bank(
-        self, bank_type: MemoryBankType
-    ) -> Optional[MemoryBankSpec]: ...
+    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
+
+    @webmethod(route="/memory_banks/register", method="POST")
+    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
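Because `MemoryBankDef` is an `Annotated` union discriminated on the `type` field, pydantic can resolve the concrete class straight from JSON. A sketch using the pydantic v1 API this codebase appears to rely on; the values are illustrative:

```
from pydantic import parse_obj_as

j = {
    "type": "keyvalue",
    "identifier": "kv_bank",
}
bank = parse_obj_as(MemoryBankDef, j)
assert isinstance(bank, KeyValueMemoryBankDef)
```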
@@ -56,7 +56,7 @@ async def run_main(host: str, port: int, stream: bool):
     response = await client.list_models()
     cprint(f"list_models response={response}", "green")

-    response = await client.get_model("Meta-Llama3.1-8B-Instruct")
+    response = await client.get_model("Llama3.1-8B-Instruct")
     cprint(f"get_model response={response}", "blue")

     response = await client.get_model("Llama-Guard-3-1B")
@@ -6,27 +6,32 @@

 from typing import List, Optional, Protocol

-from llama_models.llama3.api.datatypes import Model
-
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

-from llama_stack.distribution.datatypes import GenericProviderConfig


 @json_schema_type
-class ModelServingSpec(BaseModel):
-    llama_model: Model = Field(
-        description="All metadatas associated with llama model (defined in llama_models.models.sku_list).",
+class ModelDef(BaseModel):
+    identifier: str = Field(
+        description="A unique identifier for the model type",
     )
-    provider_config: GenericProviderConfig = Field(
-        description="Provider config for the model, including provider_type, and corresponding config. ",
+    llama_model: str = Field(
+        description="Pointer to the core Llama family model",
     )
+    provider_id: Optional[str] = Field(
+        default=None, description="The provider instance which serves this model"
+    )
+    # For now, we are only supporting core llama models but as soon as finetuned
+    # and other custom models (for example various quantizations) are allowed, there
+    # will be more metadata fields here


 class Models(Protocol):
     @webmethod(route="/models/list", method="GET")
-    async def list_models(self) -> List[ModelServingSpec]: ...
+    async def list_models(self) -> List[ModelDef]: ...

     @webmethod(route="/models/get", method="GET")
-    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...
+    async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
+
+    @webmethod(route="/models/register", method="POST")
+    async def register_model(self, model: ModelDef) -> None: ...
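A sketch of the new model registration flow against a `Models` implementation; the `models_impl` object and the identifiers are illustrative, not part of the diff:

```
from llama_stack.apis.models import ModelDef


async def register_served_model(models_impl) -> None:
    model = ModelDef(
        identifier="Llama3.1-8B-Instruct",
        llama_model="Llama3.1-8B-Instruct",
        provider_id="meta-reference-0",
    )
    # POST /models/register, then GET /models/get to confirm
    await models_impl.register_model(model)
    assert await models_impl.get_model("Llama3.1-8B-Instruct") is not None
```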
@@ -96,12 +96,6 @@ async def run_main(host: str, port: int, image_path: str = None):
     )
     print(response)

-    response = await client.run_shield(
-        shield_type="injection_shield",
-        messages=[message],
-    )
-    print(response)
-

 def main(host: str, port: int, image: str = None):
     asyncio.run(run_main(host, port, image))
@@ -11,6 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.shields import *  # noqa: F403


 @json_schema_type
@@ -37,8 +38,17 @@ class RunShieldResponse(BaseModel):
     violation: Optional[SafetyViolation] = None


+class ShieldStore(Protocol):
+    def get_shield(self, identifier: str) -> ShieldDef: ...
+
+
 class Safety(Protocol):
+    shield_store: ShieldStore
+
     @webmethod(route="/safety/run_shield")
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse: ...
+
+    @webmethod(route="/safety/register_shield")
+    async def register_shield(self, shield: ShieldDef) -> None: ...
@@ -4,25 +4,43 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Optional, Protocol
+from enum import Enum
+from typing import Any, Dict, List, Optional, Protocol

 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

-from llama_stack.distribution.datatypes import GenericProviderConfig


 @json_schema_type
-class ShieldSpec(BaseModel):
-    shield_type: str
-    provider_config: GenericProviderConfig = Field(
-        description="Provider config for the model, including provider_type, and corresponding config. ",
+class ShieldType(Enum):
+    generic_content_shield = "generic_content_shield"
+    llama_guard = "llama_guard"
+    code_scanner = "code_scanner"
+    prompt_guard = "prompt_guard"
+
+
+class ShieldDef(BaseModel):
+    identifier: str = Field(
+        description="A unique identifier for the shield type",
+    )
+    type: str = Field(
+        description="The type of shield this is; the value is one of the ShieldType enum"
+    )
+    provider_id: Optional[str] = Field(
+        default=None, description="The provider instance which serves this shield"
+    )
+    params: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Any additional parameters needed for this shield",
     )


 class Shields(Protocol):
     @webmethod(route="/shields/list", method="GET")
-    async def list_shields(self) -> List[ShieldSpec]: ...
+    async def list_shields(self) -> List[ShieldDef]: ...

     @webmethod(route="/shields/get", method="GET")
-    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ...
+    async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: ...
+
+    @webmethod(route="/shields/register", method="POST")
+    async def register_shield(self, shield: ShieldDef) -> None: ...
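Similarly for shields, a sketch of declaring and registering a `ShieldDef`; the `shields_impl` object and the parameter values are illustrative:

```
async def register_guard(shields_impl) -> None:
    shield = ShieldDef(
        identifier="llama_guard",
        type=ShieldType.llama_guard.value,
        params={"excluded_categories": []},
    )
    # POST /shields/register, then GET /shields/list to confirm
    await shields_impl.register_shield(shield)
    print(await shields_impl.list_shields())
```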
@@ -158,19 +158,18 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
         info = prompt_guard_download_info()
     else:
         model = resolve_model(args.model_id)
+        if model is None:
+            parser.error(f"Model {args.model_id} not found")
+            return
         info = llama_meta_net_info(model)

-    if model is None:
-        parser.error(f"Model {args.model_id} not found")
-        return
-
     if args.source == "huggingface":
         _hf_download(model, args.hf_token, args.ignore_patterns, parser)
     else:
         meta_url = args.meta_url
         if not meta_url:
             meta_url = input(
-                "Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
+                "Please provide the signed URL you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
             )
             assert meta_url is not None and "llamameta.net" in meta_url
         _meta_download(model, meta_url, info)
@@ -22,7 +22,7 @@ def available_templates_specs() -> List[BuildConfig]:
     import yaml

     template_specs = []
-    for p in TEMPLATES_PATH.rglob("*.yaml"):
+    for p in TEMPLATES_PATH.rglob("*build.yaml"):
         with open(p, "r") as f:
             build_config = BuildConfig(**yaml.safe_load(f))
             template_specs.append(build_config)
@@ -105,8 +105,7 @@ class StackBuild(Subcommand):

         import yaml

-        from llama_stack.distribution.build import ApiInput, build_image, ImageType
+        from llama_stack.distribution.build import build_image, ImageType

         from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
         from llama_stack.distribution.utils.serialize import EnumEncoder
         from termcolor import cprint
@@ -137,16 +136,19 @@ class StackBuild(Subcommand):
             if build_config.image_type == "conda"
             else (f"llamastack-{build_config.name}")
         )
-        cprint(
-            f"You can now run `llama stack configure {configure_name}`",
-            color="green",
-        )
+        if build_config.image_type == "conda":
+            cprint(
+                f"You can now run `llama stack configure {configure_name}`",
+                color="green",
+            )
+        else:
+            cprint(
+                f"You can now run `llama stack run {build_config.name}`",
+                color="green",
+            )

     def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
         import json

-        import yaml

         from llama_stack.cli.table import print_table

         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
@@ -172,9 +174,11 @@ class StackBuild(Subcommand):
         )

     def _run_stack_build_command(self, args: argparse.Namespace) -> None:
+        import textwrap
         import yaml
         from llama_stack.distribution.distribution import get_provider_registry
         from prompt_toolkit import prompt
+        from prompt_toolkit.completion import WordCompleter
         from prompt_toolkit.validation import Validator
         from termcolor import cprint

@@ -238,26 +242,29 @@ class StackBuild(Subcommand):
             )

             cprint(
-                "\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
+                textwrap.dedent(
+                    """
+                    Llama Stack is composed of several APIs working together. Let's select
+                    the provider types (implementations) you want to use for these APIs.
+                    """,
+                ),
                 color="green",
             )

+            print("Tip: use <TAB> to see options for the providers.\n")
+
             providers = dict()
             for api, providers_for_api in get_provider_registry().items():
+                available_providers = [
+                    x for x in providers_for_api.keys() if x != "remote"
+                ]
                 api_provider = prompt(
-                    "> Enter provider for the {} API: (default=meta-reference): ".format(
-                        api.value
-                    ),
+                    "> Enter provider for API {}: ".format(api.value),
+                    completer=WordCompleter(available_providers),
+                    complete_while_typing=True,
                     validator=Validator.from_callable(
-                        lambda x: x in providers_for_api,
-                        error_message="Invalid provider, please enter one of the following: {}".format(
-                            list(providers_for_api.keys())
-                        ),
+                        lambda x: x in available_providers,
+                        error_message="Invalid provider, use <TAB> to see options",
                     ),
-                    default=(
-                        "meta-reference"
-                        if "meta-reference" in providers_for_api
-                        else list(providers_for_api.keys())[0]
-                    ),
                 )
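The tab-completion pattern the build command now uses can be tried standalone. A sketch with an illustrative provider list; the prompt_toolkit calls match the ones introduced above:

```
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.validation import Validator

available_providers = ["meta-reference", "remote::ollama"]  # illustrative
choice = prompt(
    "> Enter provider for API inference: ",
    completer=WordCompleter(available_providers),
    complete_while_typing=True,
    validator=Validator.from_callable(
        lambda x: x in available_providers,
        error_message="Invalid provider, use <TAB> to see options",
    ),
)
print(f"selected: {choice}")
```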
@@ -71,10 +71,8 @@ class StackConfigure(Subcommand):
         conda_dir = (
             Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}"
         )
-        output = subprocess.check_output(
-            ["bash", "-c", "conda info --json -a | jq '.envs'"]
-        )
-        conda_envs = json.loads(output.decode("utf-8"))
+        output = subprocess.check_output(["bash", "-c", "conda info --json"])
+        conda_envs = json.loads(output.decode("utf-8"))["envs"]

         for x in conda_envs:
             if x.endswith(f"/llamastack-{args.config}"):
@@ -129,7 +127,10 @@ class StackConfigure(Subcommand):
         import yaml
         from termcolor import cprint

-        from llama_stack.distribution.configure import configure_api_providers
+        from llama_stack.distribution.configure import (
+            configure_api_providers,
+            parse_and_maybe_upgrade_config,
+        )
         from llama_stack.distribution.utils.serialize import EnumEncoder

         builds_dir = BUILDS_BASE_DIR / build_config.image_type
@@ -145,13 +146,17 @@ class StackConfigure(Subcommand):
                 "yellow",
                 attrs=["bold"],
             )
-            config = StackRunConfig(**yaml.safe_load(run_config_file.read_text()))
+            config_dict = yaml.safe_load(run_config_file.read_text())
+            config = parse_and_maybe_upgrade_config(config_dict)
         else:
             config = StackRunConfig(
                 built_at=datetime.now(),
                 image_name=image_name,
-                apis_to_serve=[],
-                api_providers={},
+                apis=list(build_config.distribution_spec.providers.keys()),
+                providers={},
+                models=[],
+                shields=[],
+                memory_banks=[],
             )

         config = configure_api_providers(config, build_config.distribution_spec)
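A sketch of the new upgrade path in isolation: load an existing run config and normalize it before use. The function names come from the diff; the file path is illustrative:

```
import os

import yaml

from llama_stack.distribution.configure import parse_and_maybe_upgrade_config

path = os.path.expanduser("~/.llama/builds/conda/my-stack-run.yaml")  # illustrative
with open(path, "r") as f:
    config_dict = yaml.safe_load(f)

# old-format configs (routing_table / api_providers) are upgraded in place
config = parse_and_maybe_upgrade_config(config_dict)
print(config.version, list(config.providers.keys()))
```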
@@ -7,7 +7,6 @@
 import argparse

 from llama_stack.cli.subcommand import Subcommand
-from llama_stack.distribution.datatypes import *  # noqa: F403


 class StackRun(Subcommand):
@@ -46,10 +45,11 @@ class StackRun(Subcommand):

         import pkg_resources
         import yaml
+        from termcolor import cprint

         from llama_stack.distribution.build import ImageType
+        from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
         from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR

         from llama_stack.distribution.utils.exec import run_with_pty

         if not args.config:
@@ -75,8 +75,10 @@ class StackRun(Subcommand):
             )
             return

+        cprint(f"Using config `{config_file}`", "green")
         with open(config_file, "r") as f:
-            config = StackRunConfig(**yaml.safe_load(f))
+            config_dict = yaml.safe_load(config_file.read_text())
+            config = parse_and_maybe_upgrade_config(config_dict)

         if config.docker_image:
             script = pkg_resources.resource_filename(
@@ -1,105 +0,0 @@
-from argparse import Namespace
-from unittest.mock import MagicMock, patch
-
-import pytest
-from llama_stack.distribution.datatypes import BuildConfig
-from llama_stack.cli.stack.build import StackBuild
-
-
-# temporary while we make the tests work
-pytest.skip(allow_module_level=True)
-
-
-@pytest.fixture
-def stack_build():
-    parser = MagicMock()
-    subparsers = MagicMock()
-    return StackBuild(subparsers)
-
-
-def test_stack_build_initialization(stack_build):
-    assert stack_build.parser is not None
-    assert stack_build.parser.set_defaults.called_once_with(
-        func=stack_build._run_stack_build_command
-    )
-
-
-@patch("llama_stack.distribution.build.build_image")
-def test_run_stack_build_command_with_config(
-    mock_build_image, mock_build_config, stack_build
-):
-    args = Namespace(
-        config="test_config.yaml",
-        template=None,
-        list_templates=False,
-        name=None,
-        image_type="conda",
-    )
-
-    with patch("builtins.open", MagicMock()):
-        with patch("yaml.safe_load") as mock_yaml_load:
-            mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
-            mock_build_config.return_value = MagicMock()
-
-            stack_build._run_stack_build_command(args)
-
-            mock_build_config.assert_called_once()
-            mock_build_image.assert_called_once()
-
-
-@patch("llama_stack.cli.table.print_table")
-def test_run_stack_build_command_list_templates(mock_print_table, stack_build):
-    args = Namespace(list_templates=True)
-
-    stack_build._run_stack_build_command(args)
-
-    mock_print_table.assert_called_once()
-
-
-@patch("prompt_toolkit.prompt")
-@patch("llama_stack.distribution.datatypes.BuildConfig")
-@patch("llama_stack.distribution.build.build_image")
-def test_run_stack_build_command_interactive(
-    mock_build_image, mock_build_config, mock_prompt, stack_build
-):
-    args = Namespace(
-        config=None, template=None, list_templates=False, name=None, image_type=None
-    )
-
-    mock_prompt.side_effect = [
-        "test_name",
-        "conda",
-        "meta-reference",
-        "test description",
-    ]
-    mock_build_config.return_value = MagicMock()
-
-    stack_build._run_stack_build_command(args)
-
-    assert mock_prompt.call_count == 4
-    mock_build_config.assert_called_once()
-    mock_build_image.assert_called_once()
-
-
-@patch("llama_stack.distribution.datatypes.BuildConfig")
-@patch("llama_stack.distribution.build.build_image")
-def test_run_stack_build_command_with_template(
-    mock_build_image, mock_build_config, stack_build
-):
-    args = Namespace(
-        config=None,
-        template="test_template",
-        list_templates=False,
-        name="test_name",
-        image_type="docker",
-    )
-
-    with patch("builtins.open", MagicMock()):
-        with patch("yaml.safe_load") as mock_yaml_load:
-            mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
-            mock_build_config.return_value = MagicMock()
-
-            stack_build._run_stack_build_command(args)
-
-            mock_build_config.assert_called_once()
-            mock_build_image.assert_called_once()
llama_stack/cli/tests/test_stack_config.py (new file, 153 lines)
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from datetime import datetime
+
+import pytest
+import yaml
+from llama_stack.distribution.configure import (
+    LLAMA_STACK_RUN_CONFIG_VERSION,
+    parse_and_maybe_upgrade_config,
+)
+
+
+@pytest.fixture
+def up_to_date_config():
+    return yaml.safe_load(
+        """
+        version: {version}
+        image_name: foo
+        apis_to_serve: []
+        built_at: {built_at}
+        models:
+          - identifier: model1
+            provider_id: provider1
+            llama_model: Llama3.1-8B-Instruct
+        shields:
+          - identifier: shield1
+            type: llama_guard
+            provider_id: provider1
+        memory_banks:
+          - identifier: memory1
+            type: vector
+            provider_id: provider1
+            embedding_model: all-MiniLM-L6-v2
+            chunk_size_in_tokens: 512
+        providers:
+          inference:
+            - provider_id: provider1
+              provider_type: meta-reference
+              config: {{}}
+          safety:
+            - provider_id: provider1
+              provider_type: meta-reference
+              config:
+                llama_guard_shield:
+                  model: Llama-Guard-3-1B
+                  excluded_categories: []
+                  disable_input_check: false
+                  disable_output_check: false
+                enable_prompt_guard: false
+          memory:
+            - provider_id: provider1
+              provider_type: meta-reference
+              config: {{}}
+        """.format(
+            version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
+        )
+    )
+
+
+@pytest.fixture
+def old_config():
+    return yaml.safe_load(
+        """
+        image_name: foo
+        built_at: {built_at}
+        apis_to_serve: []
+        routing_table:
+          inference:
+            - provider_type: remote::ollama
+              config:
+                host: localhost
+                port: 11434
+              routing_key: Llama3.2-1B-Instruct
+            - provider_type: meta-reference
+              config:
+                model: Llama3.1-8B-Instruct
+              routing_key: Llama3.1-8B-Instruct
+          safety:
+            - routing_key: ["shield1", "shield2"]
+              provider_type: meta-reference
+              config:
+                llama_guard_shield:
+                  model: Llama-Guard-3-1B
+                  excluded_categories: []
+                  disable_input_check: false
+                  disable_output_check: false
+                enable_prompt_guard: false
+          memory:
+            - routing_key: vector
+              provider_type: meta-reference
+              config: {{}}
+        api_providers:
+          telemetry:
+            provider_type: noop
+            config: {{}}
+        """.format(
+            built_at=datetime.now().isoformat()
+        )
+    )
+
+
+@pytest.fixture
+def invalid_config():
+    return yaml.safe_load(
+        """
+        routing_table: {}
+        api_providers: {}
+        """
+    )
+
+
+def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
+    result = parse_and_maybe_upgrade_config(up_to_date_config)
+    assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
+    assert len(result.models) == 1
+    assert len(result.shields) == 1
+    assert len(result.memory_banks) == 1
+    assert "inference" in result.providers
+
+
+def test_parse_and_maybe_upgrade_config_old_format(old_config):
+    result = parse_and_maybe_upgrade_config(old_config)
+    assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
+    assert len(result.models) == 2
+    assert len(result.shields) == 2
+    assert len(result.memory_banks) == 1
+    assert all(
+        api in result.providers
+        for api in ["inference", "safety", "memory", "telemetry"]
+    )
+    safety_provider = result.providers["safety"][0]
+    assert safety_provider.provider_type == "meta-reference"
+    assert "llama_guard_shield" in safety_provider.config
+
+    inference_providers = result.providers["inference"]
+    assert len(inference_providers) == 2
+    assert set(x.provider_id for x in inference_providers) == {
+        "remote::ollama-00",
+        "meta-reference-01",
+    }
+
+    ollama = inference_providers[0]
+    assert ollama.provider_type == "remote::ollama"
+    assert ollama.config["port"] == 11434
+
+
+def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
+    with pytest.raises(ValueError):
+        parse_and_maybe_upgrade_config(invalid_config)
@@ -8,15 +8,16 @@ from enum import Enum
 from typing import List, Optional

 import pkg_resources

+from llama_stack.distribution.utils.exec import run_with_pty
 from pydantic import BaseModel

 from termcolor import cprint

-from llama_stack.distribution.utils.exec import run_with_pty

 from llama_stack.distribution.datatypes import *  # noqa: F403
 from pathlib import Path

+from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
 from llama_stack.distribution.distribution import get_provider_registry


@@ -95,6 +96,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
             build_config.name,
             package_deps.docker_image,
             str(build_file_path),
+            str(BUILDS_BASE_DIR / ImageType.docker.value),
             " ".join(deps),
         ]
     else:
@@ -23,7 +23,7 @@ if [ "$#" -lt 3 ]; then
  exit 1
 fi

-special_pip_deps="$3"
+special_pip_deps="$4"

 set -euo pipefail
@@ -10,7 +10,7 @@ if [ "$#" -lt 4 ]; then
  exit 1
 fi

-special_pip_deps="$5"
+special_pip_deps="$6"

 set -euo pipefail

@@ -18,7 +18,8 @@ build_name="$1"
 image_name="llamastack-$build_name"
 docker_base=$2
 build_file_path=$3
-pip_dependencies=$4
+host_build_dir=$4
+pip_dependencies=$5

 # Define color codes
 RED='\033[0;31m'
@@ -33,7 +34,8 @@ REPO_CONFIGS_DIR="$REPO_DIR/tmp/configs"

 TEMP_DIR=$(mktemp -d)

-llama stack configure $build_file_path --output-dir $REPO_CONFIGS_DIR
+llama stack configure $build_file_path
+cp $host_build_dir/$build_name-run.yaml $REPO_CONFIGS_DIR

 add_to_docker() {
   local input
@@ -132,6 +134,9 @@ fi

 set -x
 $DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
+
+# clean up tmp/configs
+rm -rf $REPO_CONFIGS_DIR
 set +x

 echo "Success! You can run it with: $DOCKER_BINARY $DOCKER_OPTS run -p 5000:5000 $image_name"
@@ -3,171 +3,369 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import textwrap
+
 from typing import Any

-from pydantic import BaseModel
+from llama_models.sku_list import (
+    llama3_1_family,
+    llama3_2_family,
+    llama3_family,
+    resolve_model,
+    safety_models,
+)

 from llama_stack.distribution.datatypes import *  # noqa: F403
 from prompt_toolkit import prompt
 from prompt_toolkit.validation import Validator
 from termcolor import cprint

-from llama_stack.apis.memory.memory import MemoryBankType
 from llama_stack.distribution.distribution import (
     builtin_automatically_routed_apis,
     get_provider_registry,
-    stack_apis,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type

 from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
-from llama_stack.providers.impls.meta_reference.safety.config import (
-    MetaReferenceShieldType,
-)
+from llama_stack.apis.models import *  # noqa: F403
+from llama_stack.apis.shields import *  # noqa: F403
+from llama_stack.apis.memory_banks import *  # noqa: F403
+
+
+ALLOWED_MODELS = (
+    llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
+)


-def make_routing_entry_type(config_class: Any):
-    class BaseModelWithConfig(BaseModel):
-        routing_key: str
-        config: config_class
-
-    return BaseModelWithConfig
+def configure_single_provider(
+    registry: Dict[str, ProviderSpec], provider: Provider
+) -> Provider:
+    provider_spec = registry[provider.provider_type]
+    config_type = instantiate_class_type(provider_spec.config_class)
+    try:
+        if provider.config:
+            existing = config_type(**provider.config)
+        else:
+            existing = None
+    except Exception:
+        existing = None
+
+    cfg = prompt_for_config(config_type, existing)
+    return Provider(
+        provider_id=provider.provider_id,
+        provider_type=provider.provider_type,
+        config=cfg.dict(),
+    )


-def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
-    """Get corresponding builtin APIs given provider backed APIs"""
-    res = []
-    for inf in builtin_automatically_routed_apis():
-        if inf.router_api.value in provider_backed_apis:
-            res.append(inf.routing_table_api.value)
-
-    return res
-
-
-# TODO: make sure we can deal with existing configuration values correctly
-# instead of just overwriting them
 def configure_api_providers(
-    config: StackRunConfig, spec: DistributionSpec
+    config: StackRunConfig, build_spec: DistributionSpec
 ) -> StackRunConfig:
-    apis = config.apis_to_serve or list(spec.providers.keys())
-    # append the bulitin routing APIs
-    apis += get_builtin_apis(apis)
-
-    router_api2builtin_api = {
-        inf.router_api.value: inf.routing_table_api.value
-        for inf in builtin_automatically_routed_apis()
-    }
+    is_nux = len(config.providers) == 0
+
+    if is_nux:
+        print(
+            textwrap.dedent(
+                """
+        Llama Stack is composed of several APIs working together. For each API served by the Stack,
+        we need to configure the providers (implementations) you want to use for these APIs.
+        """
+            )
+        )

-    config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
+    provider_registry = get_provider_registry()
+    builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]

-    apis = [v.value for v in stack_apis()]
-    all_providers = get_provider_registry()
+    if config.apis:
+        apis_to_serve = config.apis
+    else:
+        apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]

-    # configure simple case for with non-routing providers to api_providers
-    for api_str in spec.providers.keys():
-        if api_str not in apis:
+    for api_str in apis_to_serve:
+        api = Api(api_str)
+        if api in builtin_apis:
+            continue
+        if api not in provider_registry:
             raise ValueError(f"Unknown API `{api_str}`")

-        cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
-        api = Api(api_str)
-
-        p = spec.providers[api_str]
-        cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
-
-        if isinstance(p, list):
+        existing_providers = config.providers.get(api_str, [])
+        if existing_providers:
             cprint(
-                f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
-                "yellow",
+                f"Re-configuring existing providers for API `{api_str}`...",
+                "green",
+                attrs=["bold"],
             )
-            p = p[0]
-
-        provider_spec = all_providers[api][p]
-        config_type = instantiate_class_type(provider_spec.config_class)
-        try:
-            provider_config = config.api_providers.get(api_str)
-            if provider_config:
-                existing = config_type(**provider_config.config)
-            else:
-                existing = None
-        except Exception:
-            existing = None
-        cfg = prompt_for_config(config_type, existing)
+            updated_providers = []
+            for p in existing_providers:
+                print(f"> Configuring provider `({p.provider_type})`")
+                updated_providers.append(
+                    configure_single_provider(provider_registry[api], p)
+                )
+                print("")

-        if api_str in router_api2builtin_api:
-            # a routing api, we need to infer and assign it a routing_key and put it in the routing_table
-            routing_key = "<PLEASE_FILL_ROUTING_KEY>"
-            routing_entries = []
-            if api_str == "inference":
-                if hasattr(cfg, "model"):
-                    routing_key = cfg.model
-                else:
-                    routing_key = prompt(
-                        "> Please enter the supported model your provider has for inference: ",
-                        default="Meta-Llama3.1-8B-Instruct",
-                    )
-                routing_entries.append(
-                    RoutableProviderConfig(
-                        routing_key=routing_key,
-                        provider_type=p,
-                        config=cfg.dict(),
-                    )
-                )
-
-            if api_str == "safety":
# TODO: add support for other safety providers, and simplify safety provider config
|
|
||||||
if p == "meta-reference":
|
|
||||||
routing_entries.append(
|
|
||||||
RoutableProviderConfig(
|
|
||||||
routing_key=[s.value for s in MetaReferenceShieldType],
|
|
||||||
provider_type=p,
|
|
||||||
config=cfg.dict(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cprint(
|
|
||||||
f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
|
|
||||||
"yellow",
|
|
||||||
attrs=["bold"],
|
|
||||||
)
|
|
||||||
routing_entries.append(
|
|
||||||
RoutableProviderConfig(
|
|
||||||
routing_key=routing_key,
|
|
||||||
provider_type=p,
|
|
||||||
config=cfg.dict(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if api_str == "memory":
|
|
||||||
bank_types = list([x.value for x in MemoryBankType])
|
|
||||||
routing_key = prompt(
|
|
||||||
"> Please enter the supported memory bank type your provider has for memory: ",
|
|
||||||
default="vector",
|
|
||||||
validator=Validator.from_callable(
|
|
||||||
lambda x: x in bank_types,
|
|
||||||
error_message="Invalid provider, please enter one of the following: {}".format(
|
|
||||||
bank_types
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
routing_entries.append(
|
|
||||||
RoutableProviderConfig(
|
|
||||||
routing_key=routing_key,
|
|
||||||
provider_type=p,
|
|
||||||
config=cfg.dict(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
config.routing_table[api_str] = routing_entries
|
|
||||||
config.api_providers[api_str] = PlaceholderProviderConfig(
|
|
||||||
providers=p if isinstance(p, list) else [p]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
config.api_providers[api_str] = GenericProviderConfig(
|
# we are newly configuring this API
|
||||||
provider_type=p,
|
plist = build_spec.providers.get(api_str, [])
|
||||||
config=cfg.dict(),
|
plist = plist if isinstance(plist, list) else [plist]
|
||||||
)
|
|
||||||
|
|
||||||
|
if not plist:
|
||||||
|
raise ValueError(f"No provider configured for API {api_str}?")
|
||||||
|
|
||||||
|
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||||
|
updated_providers = []
|
||||||
|
for i, provider_type in enumerate(plist):
|
||||||
|
print(f"> Configuring provider `({provider_type})`")
|
||||||
|
updated_providers.append(
|
||||||
|
configure_single_provider(
|
||||||
|
provider_registry[api],
|
||||||
|
Provider(
|
||||||
|
provider_id=(
|
||||||
|
f"{provider_type}-{i:02d}"
|
||||||
|
if len(plist) > 1
|
||||||
|
else provider_type
|
||||||
|
),
|
||||||
|
provider_type=provider_type,
|
||||||
|
config={},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print("")
|
||||||
|
|
||||||
|
config.providers[api_str] = updated_providers
|
||||||
|
|
||||||
|
if is_nux:
|
||||||
|
print(
|
||||||
|
textwrap.dedent(
|
||||||
|
"""
|
||||||
|
=========================================================================================
|
||||||
|
Now let's configure the `objects` you will be serving via the stack. These are:
|
||||||
|
|
||||||
|
- Models: the Llama model SKUs you expect to inference (e.g., Llama3.2-1B-Instruct)
|
||||||
|
- Shields: the safety models you expect to use for safety (e.g., Llama-Guard-3-1B)
|
||||||
|
- Memory Banks: the memory banks you expect to use for memory (e.g., Vector stores)
|
||||||
|
|
||||||
|
This wizard will guide you through setting up one of each of these objects. You can
|
||||||
|
always add more later by editing the run.yaml file.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
object_types = {
|
||||||
|
"models": (ModelDef, configure_models, "inference"),
|
||||||
|
"shields": (ShieldDef, configure_shields, "safety"),
|
||||||
|
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
|
||||||
|
}
|
||||||
|
safety_providers = config.providers.get("safety", [])
|
||||||
|
|
||||||
|
for otype, (odef, config_method, api_str) in object_types.items():
|
||||||
|
existing_objects = getattr(config, otype)
|
||||||
|
|
||||||
|
if existing_objects:
|
||||||
|
cprint(
|
||||||
|
f"{len(existing_objects)} {otype} exist. Skipping...",
|
||||||
|
"blue",
|
||||||
|
attrs=["bold"],
|
||||||
|
)
|
||||||
|
updated_objects = existing_objects
|
||||||
|
else:
|
||||||
|
providers = config.providers.get(api_str, [])
|
||||||
|
if not providers:
|
||||||
|
updated_objects = []
|
||||||
|
else:
|
||||||
|
# we are newly configuring this API
|
||||||
|
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
|
||||||
|
updated_objects = config_method(
|
||||||
|
config.providers[api_str], safety_providers
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(config, otype, updated_objects)
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
|
||||||
|
if not safety_providers:
|
||||||
|
return None
|
||||||
|
|
||||||
|
provider = safety_providers[0]
|
||||||
|
assert provider.provider_type == "meta-reference"
|
||||||
|
|
||||||
|
cfg = provider.config["llama_guard_shield"]
|
||||||
|
if not cfg:
|
||||||
|
return None
|
||||||
|
return cfg["model"]
|
||||||
|
|
||||||
|
|
||||||
|
def configure_models(
|
||||||
|
providers: List[Provider], safety_providers: List[Provider]
|
||||||
|
) -> List[ModelDef]:
|
||||||
|
model = prompt(
|
||||||
|
"> Please enter the model you want to serve: ",
|
||||||
|
default="Llama3.2-1B-Instruct",
|
||||||
|
validator=Validator.from_callable(
|
||||||
|
lambda x: resolve_model(x) is not None,
|
||||||
|
error_message="Model must be: {}".format(
|
||||||
|
[x.descriptor() for x in ALLOWED_MODELS]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model = ModelDef(
|
||||||
|
identifier=model,
|
||||||
|
llama_model=model,
|
||||||
|
provider_id=providers[0].provider_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
ret = [model]
|
||||||
|
if llama_guard := get_llama_guard_model(safety_providers):
|
||||||
|
ret.append(
|
||||||
|
ModelDef(
|
||||||
|
identifier=llama_guard,
|
||||||
|
llama_model=llama_guard,
|
||||||
|
provider_id=providers[0].provider_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def configure_shields(
|
||||||
|
providers: List[Provider], safety_providers: List[Provider]
|
||||||
|
) -> List[ShieldDef]:
|
||||||
|
if get_llama_guard_model(safety_providers):
|
||||||
|
return [
|
||||||
|
ShieldDef(
|
||||||
|
identifier="llama_guard",
|
||||||
|
type="llama_guard",
|
||||||
|
provider_id=providers[0].provider_id,
|
||||||
|
params={},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def configure_memory_banks(
|
||||||
|
providers: List[Provider], safety_providers: List[Provider]
|
||||||
|
) -> List[MemoryBankDef]:
|
||||||
|
bank_name = prompt(
|
||||||
|
"> Please enter a name for your memory bank: ",
|
||||||
|
default="my-memory-bank",
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
VectorMemoryBankDef(
|
||||||
|
identifier=bank_name,
|
||||||
|
provider_id=providers[0].provider_id,
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
chunk_size_in_tokens=512,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade_from_routing_table_to_registry(
|
||||||
|
config_dict: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
def get_providers(entries):
|
||||||
|
return [
|
||||||
|
Provider(
|
||||||
|
provider_id=(
|
||||||
|
f"{entry['provider_type']}-{i:02d}"
|
||||||
|
if len(entries) > 1
|
||||||
|
else entry["provider_type"]
|
||||||
|
),
|
||||||
|
provider_type=entry["provider_type"],
|
||||||
|
config=entry["config"],
|
||||||
|
)
|
||||||
|
for i, entry in enumerate(entries)
|
||||||
|
]
|
||||||
|
|
||||||
|
providers_by_api = {}
|
||||||
|
models = []
|
||||||
|
shields = []
|
||||||
|
memory_banks = []
|
||||||
|
|
||||||
|
routing_table = config_dict.get("routing_table", {})
|
||||||
|
for api_str, entries in routing_table.items():
|
||||||
|
providers = get_providers(entries)
|
||||||
|
providers_by_api[api_str] = providers
|
||||||
|
|
||||||
|
if api_str == "inference":
|
||||||
|
for entry, provider in zip(entries, providers):
|
||||||
|
key = entry["routing_key"]
|
||||||
|
keys = key if isinstance(key, list) else [key]
|
||||||
|
for key in keys:
|
||||||
|
models.append(
|
||||||
|
ModelDef(
|
||||||
|
identifier=key,
|
||||||
|
provider_id=provider.provider_id,
|
||||||
|
llama_model=key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif api_str == "safety":
|
||||||
|
for entry, provider in zip(entries, providers):
|
||||||
|
key = entry["routing_key"]
|
||||||
|
keys = key if isinstance(key, list) else [key]
|
||||||
|
for key in keys:
|
||||||
|
shields.append(
|
||||||
|
ShieldDef(
|
||||||
|
identifier=key,
|
||||||
|
type=ShieldType.llama_guard.value,
|
||||||
|
provider_id=provider.provider_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif api_str == "memory":
|
||||||
|
for entry, provider in zip(entries, providers):
|
||||||
|
key = entry["routing_key"]
|
||||||
|
keys = key if isinstance(key, list) else [key]
|
||||||
|
for key in keys:
|
||||||
|
# we currently only support Vector memory banks so this is OK
|
||||||
|
memory_banks.append(
|
||||||
|
VectorMemoryBankDef(
|
||||||
|
identifier=key,
|
||||||
|
provider_id=provider.provider_id,
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
chunk_size_in_tokens=512,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
config_dict["models"] = models
|
||||||
|
config_dict["shields"] = shields
|
||||||
|
config_dict["memory_banks"] = memory_banks
|
||||||
|
|
||||||
|
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
|
||||||
|
if provider_map:
|
||||||
|
for api_str, provider in provider_map.items():
|
||||||
|
if isinstance(provider, dict) and "provider_type" in provider:
|
||||||
|
providers_by_api[api_str] = [
|
||||||
|
Provider(
|
||||||
|
provider_id=f"{provider['provider_type']}",
|
||||||
|
provider_type=provider["provider_type"],
|
||||||
|
config=provider["config"],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
config_dict["providers"] = providers_by_api
|
||||||
|
|
||||||
|
config_dict.pop("routing_table", None)
|
||||||
|
config_dict.pop("api_providers", None)
|
||||||
|
config_dict.pop("provider_map", None)
|
||||||
|
|
||||||
|
config_dict["apis"] = config_dict["apis_to_serve"]
|
||||||
|
config_dict.pop("apis_to_serve", None)
|
||||||
|
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
|
||||||
|
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
|
||||||
|
version = config_dict.get("version", None)
|
||||||
|
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||||
|
return StackRunConfig(**config_dict)
|
||||||
|
|
||||||
|
if "models" not in config_dict:
|
||||||
|
print("Upgrading config...")
|
||||||
|
config_dict = upgrade_from_routing_table_to_registry(config_dict)
|
||||||
|
|
||||||
|
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||||
|
config_dict["built_at"] = datetime.now().isoformat()
|
||||||
|
|
||||||
|
return StackRunConfig(**config_dict)
|
||||||
|
|
|
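To make the upgrade path above concrete, here is a minimal sketch (not part of the commit) of feeding a legacy routing-table config through `upgrade_from_routing_table_to_registry`; it assumes the helpers above are in scope, and the model and provider values are illustrative only.

```python
# Hypothetical legacy run-config fragment; field values are placeholders.
legacy = {
    "apis_to_serve": ["inference"],
    "api_providers": {},
    "routing_table": {
        "inference": [
            {
                "routing_key": "Llama3.1-8B-Instruct",
                "provider_type": "meta-reference",
                "config": {"model": "Llama3.1-8B-Instruct"},
            }
        ],
    },
}

upgraded = upgrade_from_routing_table_to_registry(legacy)
# The old routing_table becomes a provider list plus registry objects:
print(upgraded["providers"]["inference"][0].provider_id)  # "meta-reference"
print(upgraded["models"][0].identifier)                   # "Llama3.1-8B-Instruct"
print(upgraded["apis"])                                   # ["inference"]
```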
@@ -11,28 +11,32 @@ from typing import Dict, List, Optional, Union
 from pydantic import BaseModel, Field
 
 from llama_stack.providers.datatypes import *  # noqa: F403
+from llama_stack.apis.models import *  # noqa: F403
+from llama_stack.apis.shields import *  # noqa: F403
+from llama_stack.apis.memory_banks import *  # noqa: F403
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.safety import Safety
 
 
-LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
-LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
+LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
+LLAMA_STACK_RUN_CONFIG_VERSION = "2"
 
 
 RoutingKey = Union[str, List[str]]
 
 
-class GenericProviderConfig(BaseModel):
-    provider_type: str
-    config: Dict[str, Any]
-
-
-class RoutableProviderConfig(GenericProviderConfig):
-    routing_key: RoutingKey
-
-
-class PlaceholderProviderConfig(BaseModel):
-    """Placeholder provider config for API whose provider are defined in routing_table"""
-
-    providers: List[str]
+RoutableObject = Union[
+    ModelDef,
+    ShieldDef,
+    MemoryBankDef,
+]
+
+RoutedProtocol = Union[
+    Inference,
+    Safety,
+    Memory,
+]
 
 
 # Example: /inference, /safety
@@ -53,18 +57,17 @@ class AutoRoutedProviderSpec(ProviderSpec):
 
 
 # Example: /models, /shields
-@json_schema_type
 class RoutingTableProviderSpec(ProviderSpec):
     provider_type: str = "routing_table"
     config_class: str = ""
     docker_image: Optional[str] = None
 
-    inner_specs: List[ProviderSpec]
+    router_api: Api
+    registry: List[RoutableObject]
     module: str
     pip_packages: List[str] = Field(default_factory=list)
 
 
-@json_schema_type
 class DistributionSpec(BaseModel):
     description: Optional[str] = Field(
         default="",
@@ -80,7 +83,12 @@ in the runtime configuration to help route to the correct provider.""",
     )
 
 
-@json_schema_type
+class Provider(BaseModel):
+    provider_id: str
+    provider_type: str
+    config: Dict[str, Any]
+
+
 class StackRunConfig(BaseModel):
     version: str = LLAMA_STACK_RUN_CONFIG_VERSION
     built_at: datetime
@@ -100,36 +108,39 @@ this could be just a hash
         default=None,
         description="Reference to the conda environment if this package refers to a conda environment",
     )
-    apis_to_serve: List[str] = Field(
+    apis: List[str] = Field(
+        default_factory=list,
         description="""
 The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
     )
 
-    api_providers: Dict[
-        str, Union[GenericProviderConfig, PlaceholderProviderConfig]
-    ] = Field(
-        description="""
-Provider configurations for each of the APIs provided by this package.
-""",
-    )
-    routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
-        default_factory=dict,
-        description="""
-
-E.g. The following is a ProviderRoutingEntry for models:
-- routing_key: Meta-Llama3.1-8B-Instruct
-  provider_type: meta-reference
-  config:
-      model: Meta-Llama3.1-8B-Instruct
-      quantization: null
-      torch_seed: null
-      max_seq_len: 4096
-      max_batch_size: 1
-""",
+    providers: Dict[str, List[Provider]] = Field(
+        description="""
+One or more providers to use for each API. The same provider_type (e.g., meta-reference)
+can be instantiated multiple times (with different configs) if necessary.
+""",
+    )
+
+    models: List[ModelDef] = Field(
+        description="""
+List of model definitions to serve. This list may get extended by
+/models/register API calls at runtime.
+""",
+    )
+    shields: List[ShieldDef] = Field(
+        description="""
+List of shield definitions to serve. This list may get extended by
+/shields/register API calls at runtime.
+""",
+    )
+    memory_banks: List[MemoryBankDef] = Field(
+        description="""
+List of memory bank definitions to serve. This list may get extended by
+/memory_banks/register API calls at runtime.
+""",
    )
 
 
-@json_schema_type
 class BuildConfig(BaseModel):
     version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
     name: str
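A small sketch (not part of the commit) of the `Provider` entries that the new `StackRunConfig.providers` mapping expects; the ids and config values below are placeholders, not defaults shipped by the stack.

```python
# Illustrative only; provider_type and config must match an installed provider.
inference_provider = Provider(
    provider_id="meta-reference",
    provider_type="meta-reference",
    config={"model": "Llama3.1-8B-Instruct"},
)

# The same provider_type may appear more than once with different configs,
# distinguished by provider_id (e.g. "meta-reference-00", "meta-reference-01").
providers = {"inference": [inference_provider]}
```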
@@ -6,45 +6,58 @@
 
 from typing import Dict, List
 from llama_stack.apis.inspect import *  # noqa: F403
+from pydantic import BaseModel
 
 
-from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.providers.datatypes import *  # noqa: F403
+from llama_stack.distribution.datatypes import *  # noqa: F403
 
 
-def is_passthrough(spec: ProviderSpec) -> bool:
-    return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
+class DistributionInspectConfig(BaseModel):
+    run_config: StackRunConfig
+
+
+async def get_provider_impl(config, deps):
+    impl = DistributionInspectImpl(config, deps)
+    await impl.initialize()
+    return impl
 
 
 class DistributionInspectImpl(Inspect):
-    def __init__(self):
+    def __init__(self, config, deps):
+        self.config = config
+        self.deps = deps
+
+    async def initialize(self) -> None:
         pass
 
     async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
+        run_config = self.config.run_config
+
         ret = {}
-        all_providers = get_provider_registry()
-        for api, providers in all_providers.items():
-            ret[api.value] = [
+        for api, providers in run_config.providers.items():
+            ret[api] = [
                 ProviderInfo(
+                    provider_id=p.provider_id,
                     provider_type=p.provider_type,
-                    description="Passthrough" if is_passthrough(p) else "",
                 )
-                for p in providers.values()
+                for p in providers
             ]
 
         return ret
 
     async def list_routes(self) -> Dict[str, List[RouteInfo]]:
+        run_config = self.config.run_config
+
         ret = {}
         all_endpoints = get_all_api_endpoints()
 
         for api, endpoints in all_endpoints.items():
+            providers = run_config.providers.get(api.value, [])
             ret[api.value] = [
                 RouteInfo(
                     route=e.route,
                     method=e.method,
-                    providers=[],
+                    provider_types=[p.provider_type for p in providers],
                 )
                 for e in endpoints
             ]
@@ -13,138 +13,207 @@ from llama_stack.distribution.distribution import (
     builtin_automatically_routed_apis,
     get_provider_registry,
 )
-from llama_stack.distribution.inspect import DistributionInspectImpl
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 
 
+# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
+class ProviderWithSpec(Provider):
+    spec: ProviderSpec
+
+
+# TODO: this code is not very straightforward to follow and needs one more round of refactoring
 async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
     """
     Does two things:
     - flatmaps, sorts and resolves the providers in dependency order
     - for each API, produces either a (local, passthrough or router) implementation
     """
-    all_providers = get_provider_registry()
-    specs = {}
-    configs = {}
-
-    for api_str, config in run_config.api_providers.items():
-        api = Api(api_str)
-
-        # TODO: check that these APIs are not in the routing table part of the config
-        providers = all_providers[api]
-
-        # skip checks for API whose provider config is specified in routing_table
-        if isinstance(config, PlaceholderProviderConfig):
-            continue
-
-        if config.provider_type not in providers:
-            raise ValueError(
-                f"Provider `{config.provider_type}` is not available for API `{api}`"
-            )
-        specs[api] = providers[config.provider_type]
-        configs[api] = config
-
-    apis_to_serve = run_config.apis_to_serve or set(
-        list(specs.keys()) + list(run_config.routing_table.keys())
-    )
+    all_api_providers = get_provider_registry()
+
+    routing_table_apis = set(
+        x.routing_table_api for x in builtin_automatically_routed_apis()
+    )
+    router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
+
+    providers_with_specs = {}
+
+    for api_str, providers in run_config.providers.items():
+        api = Api(api_str)
+        if api in routing_table_apis:
+            raise ValueError(
+                f"Provider for `{api_str}` is automatically provided and cannot be overridden"
+            )
+
+        specs = {}
+        for provider in providers:
+            if provider.provider_type not in all_api_providers[api]:
+                raise ValueError(
+                    f"Provider `{provider.provider_type}` is not available for API `{api}`"
+                )
+
+            p = all_api_providers[api][provider.provider_type]
+            p.deps__ = [a.value for a in p.api_dependencies]
+            spec = ProviderWithSpec(
+                spec=p,
+                **(provider.dict()),
+            )
+            specs[provider.provider_id] = spec
+
+        key = api_str if api not in router_apis else f"inner-{api_str}"
+        providers_with_specs[key] = specs
+
+    apis_to_serve = run_config.apis or set(
+        list(providers_with_specs.keys())
+        + [x.value for x in routing_table_apis]
+        + [x.value for x in router_apis]
+    )
 
     for info in builtin_automatically_routed_apis():
-        source_api = info.routing_table_api
-
-        assert (
-            source_api not in specs
-        ), f"Routing table API {source_api} specified in wrong place?"
-        assert (
-            info.router_api not in specs
-        ), f"Auto-routed API {info.router_api} specified in wrong place?"
-
         if info.router_api.value not in apis_to_serve:
             continue
 
-        if info.router_api.value not in run_config.routing_table:
-            raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
-
-        routing_table = run_config.routing_table[info.router_api.value]
-
-        providers = all_providers[info.router_api]
-
-        inner_specs = []
+        available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
+
         inner_deps = []
-        for rt_entry in routing_table:
-            if rt_entry.provider_type not in providers:
-                raise ValueError(
-                    f"Provider `{rt_entry.provider_type}` is not available for API `{api}`"
-                )
-            inner_specs.append(providers[rt_entry.provider_type])
-            inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)
-
-        specs[source_api] = RoutingTableProviderSpec(
-            api=source_api,
-            module="llama_stack.distribution.routers",
-            api_dependencies=inner_deps,
-            inner_specs=inner_specs,
-        )
-        configs[source_api] = routing_table
-
-        specs[info.router_api] = AutoRoutedProviderSpec(
-            api=info.router_api,
-            module="llama_stack.distribution.routers",
-            routing_table_api=source_api,
-            api_dependencies=[source_api],
-        )
-        configs[info.router_api] = {}
-
-    sorted_specs = topological_sort(specs.values())
-    print(f"Resolved {len(sorted_specs)} providers in topological order")
-    for spec in sorted_specs:
-        print(f"  {spec.api}: {spec.provider_type}")
-    print("")
-    impls = {}
-    for spec in sorted_specs:
-        api = spec.api
-        deps = {api: impls[api] for api in spec.api_dependencies}
-        impl = await instantiate_provider(spec, deps, configs[api])
-
-        impls[api] = impl
-
-    impls[Api.inspect] = DistributionInspectImpl()
-    specs[Api.inspect] = InlineProviderSpec(
-        api=Api.inspect,
-        provider_type="__distribution_builtin__",
-        config_class="",
-        module="",
-    )
-
-    return impls, specs
+        registry = getattr(run_config, info.routing_table_api.value)
+        for entry in registry:
+            if entry.provider_id not in available_providers:
+                raise ValueError(
+                    f"Provider `{entry.provider_id}` not found. Available providers: {list(available_providers.keys())}"
+                )
+
+            provider = available_providers[entry.provider_id]
+            inner_deps.extend(provider.spec.api_dependencies)
+
+        providers_with_specs[info.routing_table_api.value] = {
+            "__builtin__": ProviderWithSpec(
+                provider_id="__builtin__",
+                provider_type="__routing_table__",
+                config={},
+                spec=RoutingTableProviderSpec(
+                    api=info.routing_table_api,
+                    router_api=info.router_api,
+                    registry=registry,
+                    module="llama_stack.distribution.routers",
+                    api_dependencies=inner_deps,
+                    deps__=(
+                        [x.value for x in inner_deps]
+                        + [f"inner-{info.router_api.value}"]
+                    ),
+                ),
+            )
+        }
+
+        providers_with_specs[info.router_api.value] = {
+            "__builtin__": ProviderWithSpec(
+                provider_id="__builtin__",
+                provider_type="__autorouted__",
+                config={},
+                spec=AutoRoutedProviderSpec(
+                    api=info.router_api,
+                    module="llama_stack.distribution.routers",
+                    routing_table_api=info.routing_table_api,
+                    api_dependencies=[info.routing_table_api],
+                    deps__=([info.routing_table_api.value]),
+                ),
+            )
+        }
+
+    sorted_providers = topological_sort(
+        {k: v.values() for k, v in providers_with_specs.items()}
+    )
+    apis = [x[1].spec.api for x in sorted_providers]
+    sorted_providers.append(
+        (
+            "inspect",
+            ProviderWithSpec(
+                provider_id="__builtin__",
+                provider_type="__builtin__",
+                config={
+                    "run_config": run_config.dict(),
+                },
+                spec=InlineProviderSpec(
+                    api=Api.inspect,
+                    provider_type="__builtin__",
+                    config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
+                    module="llama_stack.distribution.inspect",
+                    api_dependencies=apis,
+                    deps__=([x.value for x in apis]),
+                ),
+            ),
+        )
    )
 
+    print(f"Resolved {len(sorted_providers)} providers in topological order")
+    for api_str, provider in sorted_providers:
+        print(f"  {api_str}: ({provider.provider_id}) {provider.spec.provider_type}")
+    print("")
+
+    impls = {}
+    inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
+    for api_str, provider in sorted_providers:
+        deps = {a: impls[a] for a in provider.spec.api_dependencies}
+
+        inner_impls = {}
+        if isinstance(provider.spec, RoutingTableProviderSpec):
+            inner_impls = inner_impls_by_provider_id[
+                f"inner-{provider.spec.router_api.value}"
+            ]
+
+        impl = await instantiate_provider(
+            provider,
+            deps,
+            inner_impls,
+        )
+        # TODO: ugh slightly redesign this shady looking code
+        if "inner-" in api_str:
+            inner_impls_by_provider_id[api_str][provider.provider_id] = impl
+        else:
+            api = Api(api_str)
+            impls[api] = impl
+
+    return impls
 
 
-def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
-    by_id = {x.api: x for x in providers}
-
-    def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
-        visited.add(a.api)
-
-        for api in a.api_dependencies:
-            if api not in visited:
-                dfs(by_id[api], visited, stack)
-
-        stack.append(a.api)
+def topological_sort(
+    providers_with_specs: Dict[str, List[ProviderWithSpec]],
+) -> List[ProviderWithSpec]:
+    def dfs(kv, visited: Set[str], stack: List[str]):
+        api_str, providers = kv
+        visited.add(api_str)
+
+        deps = []
+        for provider in providers:
+            for dep in provider.spec.deps__:
+                deps.append(dep)
+
+        for dep in deps:
+            if dep not in visited:
+                dfs((dep, providers_with_specs[dep]), visited, stack)
+
+        stack.append(api_str)
 
     visited = set()
     stack = []
 
-    for a in providers:
-        if a.api not in visited:
-            dfs(a, visited, stack)
-
-    return [by_id[x] for x in stack]
+    for api_str, providers in providers_with_specs.items():
+        if api_str not in visited:
+            dfs((api_str, providers), visited, stack)
+
+    flattened = []
+    for api_str in stack:
+        for provider in providers_with_specs[api_str]:
+            flattened.append((api_str, provider))
+    return flattened
 
 
 # returns a class implementing the protocol corresponding to the Api
 async def instantiate_provider(
-    provider_spec: ProviderSpec,
+    provider: ProviderWithSpec,
     deps: Dict[str, Any],
-    provider_config: Union[GenericProviderConfig, RoutingTable],
+    inner_impls: Dict[str, Any],
 ):
+    provider_spec = provider.spec
     module = importlib.import_module(provider_spec.module)
 
     args = []
@@ -154,9 +223,8 @@ async def instantiate_provider(
         else:
             method = "get_client_impl"
 
-            assert isinstance(provider_config, GenericProviderConfig)
             config_type = instantiate_class_type(provider_spec.config_class)
-            config = config_type(**provider_config.config)
+            config = config_type(**provider.config)
             args = [config, deps]
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"
@@ -166,31 +234,18 @@ async def instantiate_provider(
     elif isinstance(provider_spec, RoutingTableProviderSpec):
         method = "get_routing_table_impl"
 
-        assert isinstance(provider_config, List)
-        routing_table = provider_config
-
-        inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
-        inner_impls = []
-        for routing_entry in routing_table:
-            impl = await instantiate_provider(
-                inner_specs[routing_entry.provider_type],
-                deps,
-                routing_entry,
-            )
-            inner_impls.append((routing_entry.routing_key, impl))
-
         config = None
-        args = [provider_spec.api, inner_impls, routing_table, deps]
+        args = [provider_spec.api, provider_spec.registry, inner_impls, deps]
     else:
         method = "get_provider_impl"
 
-        assert isinstance(provider_config, GenericProviderConfig)
         config_type = instantiate_class_type(provider_spec.config_class)
-        config = config_type(**provider_config.config)
+        config = config_type(**provider.config)
         args = [config, deps]
 
     fn = getattr(module, method)
     impl = await fn(*args)
+    impl.__provider_id__ = provider.provider_id
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config
     return impl
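The resolver's ordering logic can be illustrated with a stripped-down, self-contained version of its topological sort (this sketch is not part of the commit; the API names and dependency edges are hypothetical).

```python
from typing import Dict, List, Set


def order_apis(deps_by_api: Dict[str, List[str]]) -> List[str]:
    """Simplified version of the resolver's topological sort: visit each
    API's dependencies depth-first, then append the API itself."""

    def dfs(api: str, visited: Set[str], stack: List[str]) -> None:
        visited.add(api)
        for dep in deps_by_api.get(api, []):
            if dep not in visited:
                dfs(dep, visited, stack)
        stack.append(api)

    visited: Set[str] = set()
    stack: List[str] = []
    for api in deps_by_api:
        if api not in visited:
            dfs(api, visited, stack)
    return stack


# Hypothetical dependency graph: agents needs inference, memory and safety;
# the auto-routed memory API depends on its "inner" providers.
print(order_apis({
    "agents": ["inference", "memory", "safety"],
    "memory": ["inner-memory"],
    "inner-memory": [],
    "inference": [],
    "safety": [],
}))
# e.g. ['inference', 'inner-memory', 'memory', 'safety', 'agents']
```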
@@ -4,23 +4,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, List, Tuple
+from typing import Any, List
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
+from .routing_tables import (
+    MemoryBanksRoutingTable,
+    ModelsRoutingTable,
+    ShieldsRoutingTable,
+)
 
 
 async def get_routing_table_impl(
     api: Api,
-    inner_impls: List[Tuple[str, Any]],
-    routing_table_config: Dict[str, List[RoutableProviderConfig]],
+    registry: List[RoutableObject],
+    impls_by_provider_id: Dict[str, RoutedProtocol],
     _deps,
 ) -> Any:
-    from .routing_tables import (
-        MemoryBanksRoutingTable,
-        ModelsRoutingTable,
-        ShieldsRoutingTable,
-    )
-
     api_to_tables = {
         "memory_banks": MemoryBanksRoutingTable,
         "models": ModelsRoutingTable,
@@ -29,7 +28,7 @@ async def get_routing_table_impl(
     if api.value not in api_to_tables:
         raise ValueError(f"API {api.value} not found in router map")
 
-    impl = api_to_tables[api.value](inner_impls, routing_table_config)
+    impl = api_to_tables[api.value](registry, impls_by_provider_id)
     await impl.initialize()
     return impl
@@ -14,14 +14,13 @@ from llama_stack.apis.safety import *  # noqa: F403
 
 
 class MemoryRouter(Memory):
-    """Routes to an provider based on the memory bank type"""
+    """Routes to an provider based on the memory bank identifier"""
 
     def __init__(
         self,
         routing_table: RoutingTable,
     ) -> None:
         self.routing_table = routing_table
-        self.bank_id_to_type = {}
 
     async def initialize(self) -> None:
         pass
@@ -29,32 +28,8 @@ class MemoryRouter(Memory):
     async def shutdown(self) -> None:
         pass
 
-    def get_provider_from_bank_id(self, bank_id: str) -> Any:
-        bank_type = self.bank_id_to_type.get(bank_id)
-        if not bank_type:
-            raise ValueError(f"Could not find bank type for {bank_id}")
-
-        provider = self.routing_table.get_provider_impl(bank_type)
-        if not provider:
-            raise ValueError(f"Could not find provider for {bank_type}")
-        return provider
-
-    async def create_memory_bank(
-        self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        bank_type = config.type
-        bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
-            name, config, url
-        )
-        self.bank_id_to_type[bank.bank_id] = bank_type
-        return bank
-
-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        provider = self.get_provider_from_bank_id(bank_id)
-        return await provider.get_memory_bank(bank_id)
+    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
+        await self.routing_table.register_memory_bank(memory_bank)
 
     async def insert_documents(
         self,
@@ -62,7 +37,7 @@ class MemoryRouter(Memory):
         documents: List[MemoryBankDocument],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        return await self.get_provider_from_bank_id(bank_id).insert_documents(
+        return await self.routing_table.get_provider_impl(bank_id).insert_documents(
             bank_id, documents, ttl_seconds
         )
 
@@ -72,7 +47,7 @@ class MemoryRouter(Memory):
         query: InterleavedTextMedia,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
-        return await self.get_provider_from_bank_id(bank_id).query_documents(
+        return await self.routing_table.get_provider_impl(bank_id).query_documents(
             bank_id, query, params
         )
 
@@ -92,7 +67,10 @@ class InferenceRouter(Inference):
     async def shutdown(self) -> None:
         pass
 
-    async def chat_completion(
+    async def register_model(self, model: ModelDef) -> None:
+        await self.routing_table.register_model(model)
+
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -113,27 +91,32 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-        # TODO: we need to fix streaming response to align provider implementations with Protocol.
-        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
-            **params
-        ):
-            yield chunk
+        provider = self.routing_table.get_provider_impl(model)
+        if stream:
+            return (chunk async for chunk in provider.chat_completion(**params))
+        else:
+            return provider.chat_completion(**params)
 
-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        return await self.routing_table.get_provider_impl(model).completion(
+    ) -> AsyncGenerator:
+        provider = self.routing_table.get_provider_impl(model)
+        params = dict(
             model=model,
             content=content,
             sampling_params=sampling_params,
             stream=stream,
             logprobs=logprobs,
         )
+        if stream:
+            return (chunk async for chunk in provider.completion(**params))
+        else:
+            return provider.completion(**params)
 
     async def embeddings(
         self,
@@ -159,6 +142,9 @@ class SafetyRouter(Safety):
     async def shutdown(self) -> None:
         pass
 
+    async def register_shield(self, shield: ShieldDef) -> None:
+        await self.routing_table.register_shield(shield)
+
     async def run_shield(
         self,
         shield_type: str,
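With `chat_completion` and `completion` now plain methods that return either an awaitable response or an async generator of chunks, the caller picks one of the two shapes. A hedged usage sketch (not part of the commit; `router`, the model name and the messages are placeholders):

```python
async def ask(router, model: str, messages):
    # Non-streaming: the call returns an awaitable response object.
    response = await router.chat_completion(model=model, messages=messages)

    # Streaming: the same call returns an async generator of chunks instead.
    async for chunk in router.chat_completion(model=model, messages=messages, stream=True):
        print(chunk)

    return response
```

Returning the provider's result directly keeps the router from re-yielding every chunk, which is what the removed `async for ... yield chunk` wrapper used to do.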
@@ -4,9 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional
 
-from llama_models.sku_list import resolve_model
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 
 from llama_stack.apis.models import *  # noqa: F403
@@ -16,129 +15,129 @@ from llama_stack.apis.memory_banks import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
 
+def get_impl_api(p: Any) -> Api:
+    return p.__provider_spec__.api
+
+
+async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
+    api = get_impl_api(p)
+    if api == Api.inference:
+        await p.register_model(obj)
+    elif api == Api.safety:
+        await p.register_shield(obj)
+    elif api == Api.memory:
+        await p.register_memory_bank(obj)
+
+
+# TODO: this routing table maintains state in memory purely. We need to
+# add persistence to it when we add dynamic registration of objects.
 class CommonRoutingTableImpl(RoutingTable):
     def __init__(
         self,
-        inner_impls: List[Tuple[RoutingKey, Any]],
-        routing_table_config: Dict[str, List[RoutableProviderConfig]],
+        registry: List[RoutableObject],
+        impls_by_provider_id: Dict[str, RoutedProtocol],
     ) -> None:
-        self.unique_providers = []
-        self.providers = {}
-        self.routing_keys = []
-
-        for key, impl in inner_impls:
-            keys = key if isinstance(key, list) else [key]
-            self.unique_providers.append((keys, impl))
-
-            for k in keys:
-                if k in self.providers:
-                    raise ValueError(f"Duplicate routing key {k}")
-                self.providers[k] = impl
-                self.routing_keys.append(k)
-
-        self.routing_table_config = routing_table_config
+        for obj in registry:
+            if obj.provider_id not in impls_by_provider_id:
+                print(f"{impls_by_provider_id=}")
+                raise ValueError(
+                    f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
+                )
+
+        self.impls_by_provider_id = impls_by_provider_id
+        self.registry = registry
+
+        for p in self.impls_by_provider_id.values():
+            api = get_impl_api(p)
+            if api == Api.inference:
+                p.model_store = self
+            elif api == Api.safety:
+                p.shield_store = self
+            elif api == Api.memory:
+                p.memory_bank_store = self
+
+        self.routing_key_to_object = {}
+        for obj in self.registry:
+            self.routing_key_to_object[obj.identifier] = obj
 
     async def initialize(self) -> None:
-        for keys, p in self.unique_providers:
-            spec = p.__provider_spec__
-            if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
-                continue
-
-            await p.validate_routing_keys(keys)
+        for obj in self.registry:
+            p = self.impls_by_provider_id[obj.provider_id]
+            await register_object_with_provider(obj, p)
 
     async def shutdown(self) -> None:
-        for _, p in self.unique_providers:
+        for p in self.impls_by_provider_id.values():
             await p.shutdown()
 
     def get_provider_impl(self, routing_key: str) -> Any:
-        if routing_key not in self.providers:
-            raise ValueError(f"Could not find provider for {routing_key}")
-        return self.providers[routing_key]
-
-    def get_routing_keys(self) -> List[str]:
-        return self.routing_keys
-
-    def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
-        for entry in self.routing_table_config:
-            if entry.routing_key == routing_key:
-                return entry
+        if routing_key not in self.routing_key_to_object:
+            raise ValueError(f"`{routing_key}` not registered")
+
+        obj = self.routing_key_to_object[routing_key]
+        if obj.provider_id not in self.impls_by_provider_id:
+            raise ValueError(f"Provider `{obj.provider_id}` not found")
+
+        return self.impls_by_provider_id[obj.provider_id]
+
+    def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
+        for obj in self.registry:
+            if obj.identifier == identifier:
+                return obj
         return None
 
+    async def register_object(self, obj: RoutableObject):
+        if obj.identifier in self.routing_key_to_object:
+            print(f"`{obj.identifier}` is already registered")
+            return
+
+        if not obj.provider_id:
+            provider_ids = list(self.impls_by_provider_id.keys())
+            if not provider_ids:
+                raise ValueError("No providers found")
+
+            print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
+            obj.provider_id = provider_ids[0]
+        else:
+            if obj.provider_id not in self.impls_by_provider_id:
+                raise ValueError(f"Provider `{obj.provider_id}` not found")
+
+        p = self.impls_by_provider_id[obj.provider_id]
+        await register_object_with_provider(obj, p)
+
+        self.routing_key_to_object[obj.identifier] = obj
+        self.registry.append(obj)
+
+        # TODO: persist this to a store
+
 
 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
-    async def list_models(self) -> List[ModelServingSpec]:
-        specs = []
-        for entry in self.routing_table_config:
-            model_id = entry.routing_key
-            specs.append(
-                ModelServingSpec(
-                    llama_model=resolve_model(model_id),
-                    provider_config=entry,
-                )
-            )
-        return specs
-
-    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
-        for entry in self.routing_table_config:
-            if entry.routing_key == core_model_id:
-                return ModelServingSpec(
-                    llama_model=resolve_model(core_model_id),
-                    provider_config=entry,
-                )
-        return None
+    async def list_models(self) -> List[ModelDef]:
+        return self.registry
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        return self.get_object_by_identifier(identifier)
+
+    async def register_model(self, model: ModelDef) -> None:
+        await self.register_object(model)
 
 
 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
-    async def list_shields(self) -> List[ShieldSpec]:
-        specs = []
-        for entry in self.routing_table_config:
-            if isinstance(entry.routing_key, list):
-                for k in entry.routing_key:
-                    specs.append(
-                        ShieldSpec(
-                            shield_type=k,
-                            provider_config=entry,
-                        )
-                    )
-            else:
-                specs.append(
-                    ShieldSpec(
-                        shield_type=entry.routing_key,
-                        provider_config=entry,
-                    )
-                )
-        return specs
-
-    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
-        for entry in self.routing_table_config:
-            if entry.routing_key == shield_type:
-                return ShieldSpec(
-                    shield_type=entry.routing_key,
-                    provider_config=entry,
-                )
-        return None
+    async def list_shields(self) -> List[ShieldDef]:
+        return self.registry
+
+    async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
+        return self.get_object_by_identifier(shield_type)
+
+    async def register_shield(self, shield: ShieldDef) -> None:
+        await self.register_object(shield)
 
 
 class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
-    async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
-        specs = []
-        for entry in self.routing_table_config:
-            specs.append(
-                MemoryBankSpec(
-                    bank_type=entry.routing_key,
-                    provider_config=entry,
-                )
-            )
-        return specs
-
-    async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
-        for entry in self.routing_table_config:
-            if entry.routing_key == bank_type:
-                return MemoryBankSpec(
-                    bank_type=entry.routing_key,
-                    provider_config=entry,
-                )
-        return None
+    async def list_memory_banks(self) -> List[MemoryBankDef]:
+        return self.registry
+
+    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
+        return self.get_object_by_identifier(identifier)
+
+    async def register_memory_bank(self, bank: MemoryBankDef) -> None:
+        await self.register_object(bank)
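A toy, self-contained sketch of the identifier-to-provider lookup the routing table above implements (not part of the commit; the dataclasses are stand-ins, not the real `ModelDef` or `RoutingTable` types).

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class _Model:
    identifier: str
    provider_id: str


@dataclass
class _ToyRoutingTable:
    """Registered object identifiers map to objects, which map to provider impls."""
    impls_by_provider_id: Dict[str, object]
    registry: List[_Model] = field(default_factory=list)
    routing_key_to_object: Dict[str, _Model] = field(default_factory=dict)

    def register(self, obj: _Model) -> None:
        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")
        self.routing_key_to_object[obj.identifier] = obj
        self.registry.append(obj)

    def get_provider_impl(self, routing_key: str) -> object:
        obj = self.routing_key_to_object[routing_key]
        return self.impls_by_provider_id[obj.provider_id]


table = _ToyRoutingTable(impls_by_provider_id={"meta-reference": object()})
table.register(_Model(identifier="Llama3.1-8B-Instruct", provider_id="meta-reference"))
impl = table.get_provider_impl("Llama3.1-8B-Instruct")
```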
@@ -5,18 +5,15 @@
 # the root directory of this source tree.

 import asyncio
+import functools
 import inspect
 import json
 import signal
 import traceback

-from collections.abc import (
-    AsyncGenerator as AsyncGeneratorABC,
-    AsyncIterator as AsyncIteratorABC,
-)
-
 from contextlib import asynccontextmanager
 from ssl import SSLError
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
+from typing import Any, Dict, Optional

 import fire
 import httpx

@@ -43,20 +40,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
 from .endpoints import get_all_api_endpoints


-def is_async_iterator_type(typ):
-    if hasattr(typ, "__origin__"):
-        origin = typ.__origin__
-        if isinstance(origin, type):
-            return issubclass(
-                origin,
-                (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
-            )
-        return False
-    return isinstance(
-        typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
-    )
-
-
 def create_sse_event(data: Any) -> str:
     if isinstance(data, BaseModel):
         data = data.json()

@@ -169,11 +152,20 @@ async def passthrough(
     await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)


-def handle_sigint(*args, **kwargs):
+def handle_sigint(app, *args, **kwargs):
     print("SIGINT or CTRL-C detected. Exiting gracefully...")

+    async def run_shutdown():
+        for impl in app.__llama_stack_impls__.values():
+            print(f"Shutting down {impl}")
+            await impl.shutdown()
+
+    asyncio.run(run_shutdown())
+
     loop = asyncio.get_event_loop()
     for task in asyncio.all_tasks(loop):
         task.cancel()

     loop.stop()


@@ -181,7 +173,10 @@ def handle_sigint(*args, **kwargs):
 async def lifespan(app: FastAPI):
     print("Starting up")
     yield

     print("Shutting down")
+    for impl in app.__llama_stack_impls__.values():
+        await impl.shutdown()


 def create_dynamic_passthrough(

@@ -193,65 +188,59 @@ def create_dynamic_passthrough(
     return endpoint


+def is_streaming_request(func_name: str, request: Request, **kwargs):
+    # TODO: pass the api method and punt it to the Protocol definition directly
+    return kwargs.get("stream", False)
+
+
+async def maybe_await(value):
+    if inspect.iscoroutine(value):
+        return await value
+    return value
+
+
+async def sse_generator(event_gen):
+    try:
+        async for item in event_gen:
+            yield create_sse_event(item)
+            await asyncio.sleep(0.01)
+    except asyncio.CancelledError:
+        print("Generator cancelled")
+        await event_gen.aclose()
+    except Exception as e:
+        traceback.print_exception(e)
+        yield create_sse_event(
+            {
+                "error": {
+                    "message": str(translate_exception(e)),
+                },
+            }
+        )
+    finally:
+        await end_trace()
+
+
 def create_dynamic_typed_route(func: Any, method: str):
-    hints = get_type_hints(func)
-    response_model = hints.get("return")
-
-    # NOTE: I think it is better to just add a method within each Api
-    # "Protocol" / adapter-impl to tell what sort of a response this request
-    # is going to produce. /chat_completion can produce a streaming or
-    # non-streaming response depending on if request.stream is True / False.
-    is_streaming = is_async_iterator_type(response_model)
-
-    if is_streaming:
-
-        async def endpoint(request: Request, **kwargs):
-            await start_trace(func.__name__)
-
-            set_request_provider_data(request.headers)
-
-            async def sse_generator(event_gen):
-                try:
-                    async for item in event_gen:
-                        yield create_sse_event(item)
-                        await asyncio.sleep(0.01)
-                except asyncio.CancelledError:
-                    print("Generator cancelled")
-                    await event_gen.aclose()
-                except Exception as e:
-                    traceback.print_exception(e)
-                    yield create_sse_event(
-                        {
-                            "error": {
-                                "message": str(translate_exception(e)),
-                            },
-                        }
-                    )
-                finally:
-                    await end_trace()
-
-            return StreamingResponse(
-                sse_generator(func(**kwargs)), media_type="text/event-stream"
-            )
-
-    else:
-
-        async def endpoint(request: Request, **kwargs):
-            await start_trace(func.__name__)
-
-            set_request_provider_data(request.headers)
-
-            try:
-                return (
-                    await func(**kwargs)
-                    if asyncio.iscoroutinefunction(func)
-                    else func(**kwargs)
-                )
-            except Exception as e:
-                traceback.print_exception(e)
-                raise translate_exception(e) from e
-            finally:
-                await end_trace()
+    async def endpoint(request: Request, **kwargs):
+        await start_trace(func.__name__)
+
+        set_request_provider_data(request.headers)
+
+        is_streaming = is_streaming_request(func.__name__, request, **kwargs)
+        try:
+            if is_streaming:
+                return StreamingResponse(
+                    sse_generator(func(**kwargs)), media_type="text/event-stream"
+                )
+            else:
+                value = func(**kwargs)
+                return await maybe_await(value)
+        except Exception as e:
+            traceback.print_exception(e)
+            raise translate_exception(e) from e
+        finally:
+            await end_trace()

     sig = inspect.signature(func)
     new_params = [

@@ -285,29 +274,25 @@ def main(

     app = FastAPI()

-    impls, specs = asyncio.run(resolve_impls_with_routing(config))
+    impls = asyncio.run(resolve_impls_with_routing(config))
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])

     all_endpoints = get_all_api_endpoints()

-    if config.apis_to_serve:
-        apis_to_serve = set(config.apis_to_serve)
+    if config.apis:
+        apis_to_serve = set(config.apis)
     else:
         apis_to_serve = set(impls.keys())

-    apis_to_serve.add(Api.inspect)
+    apis_to_serve.add("inspect")
     for api_str in apis_to_serve:
         api = Api(api_str)

         endpoints = all_endpoints[api]
         impl = impls[api]

-        provider_spec = specs[api]
-        if (
-            isinstance(provider_spec, RemoteProviderSpec)
-            and provider_spec.adapter is None
-        ):
+        if is_passthrough(impl.__provider_spec__):
             for endpoint in endpoints:
                 url = impl.__provider_config__.url.rstrip("/") + endpoint.route
                 getattr(app, endpoint.method)(endpoint.route)(

@@ -337,7 +322,9 @@ def main(
     print("")
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
-    signal.signal(signal.SIGINT, handle_sigint)
+    signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))

+    app.__llama_stack_impls__ = impls

     import uvicorn

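With this change the server decides between streaming and non-streaming from the request's `stream` flag rather than from the return type annotation, and streams results back as server-sent events. A minimal, self-contained sketch of that shape (the route name and payload are illustrative, not the real Llama Stack endpoints):

```python
# Sketch only: assumes fastapi and uvicorn are installed.
import asyncio
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


def create_sse_event(data) -> str:
    # one SSE frame per chunk
    return f"data: {json.dumps(data)}\n\n"


async def fake_chunks():
    for token in ["hello", " ", "world"]:
        yield create_sse_event({"delta": token})
        await asyncio.sleep(0.01)


@app.post("/demo/chat_completion")
async def chat_completion(stream: bool = False):
    if stream:  # mirrors is_streaming_request(): the "stream" kwarg decides
        return StreamingResponse(fake_chunks(), media_type="text/event-stream")
    return {"completion": "hello world"}
```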
@@ -1,8 +1,9 @@
-built_at: '2024-09-30T09:04:30.533391'
+version: '2'
+built_at: '2024-10-08T17:42:07.505267'
 image_name: local-cpu
 docker_image: local-cpu
 conda_env: null
-apis_to_serve:
+apis:
 - agents
 - inference
 - models
@@ -10,40 +11,48 @@ apis_to_serve:
 - safety
 - shields
 - memory_banks
-api_providers:
-  inference:
-    providers:
-    - remote::ollama
-  safety:
-    providers:
-    - meta-reference
-  agents:
-    provider_type: meta-reference
-    config:
-      persistence_store:
-        namespace: null
-        type: sqlite
-        db_path: /home/xiyan/.llama/runtime/kvstore.db
-  memory:
-    providers:
-    - meta-reference
-  telemetry:
-    provider_type: meta-reference
-    config: {}
-routing_table:
-  inference:
-  - provider_type: remote::ollama
-    config:
-      host: localhost
-      port: 6000
-    routing_key: Meta-Llama3.1-8B-Instruct
-  safety:
-  - provider_type: meta-reference
-    config:
-      llama_guard_shield: null
-      prompt_guard_shield: null
-    routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
-  memory:
-  - provider_type: meta-reference
-    config: {}
-    routing_key: vector
+providers:
+  inference:
+  - provider_id: remote::ollama
+    provider_type: remote::ollama
+    config:
+      host: localhost
+      port: 6000
+  safety:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config:
+      llama_guard_shield: null
+      prompt_guard_shield: null
+  memory:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config: {}
+  agents:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config:
+      persistence_store:
+        namespace: null
+        type: sqlite
+        db_path: ~/.llama/runtime/kvstore.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config: {}
+models:
+- identifier: Llama3.1-8B-Instruct
+  llama_model: Llama3.1-8B-Instruct
+  provider_id: remote::ollama
+shields:
+- identifier: llama_guard
+  type: llama_guard
+  provider_id: meta-reference
+  params: {}
+memory_banks:
+- identifier: vector
+  provider_id: meta-reference
+  type: vector
+  embedding_model: all-MiniLM-L6-v2
+  chunk_size_in_tokens: 512
+  overlap_size_in_tokens: null

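In the version-2 run config above, providers are keyed by API and carry an explicit `provider_id`, `provider_type`, and `config`, while models, shields, and memory banks are declared at the top level. As a rough illustration, a config in this shape could be inspected with PyYAML like so (the file path is a placeholder):

```python
# Sketch assuming PyYAML is installed and "run.yaml" follows the v2 layout shown above.
import yaml

with open("run.yaml") as f:
    run_config = yaml.safe_load(f)

print("serving APIs:", run_config["apis"])
for api, providers in run_config["providers"].items():
    for p in providers:
        print(f"{api}: {p['provider_id']} ({p['provider_type']})")
for model in run_config.get("models", []):
    print("model:", model["identifier"], "->", model["provider_id"])
```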
@@ -1,8 +1,9 @@
-built_at: '2024-09-30T09:00:56.693751'
+version: '2'
+built_at: '2024-10-08T17:42:33.690666'
 image_name: local-gpu
 docker_image: local-gpu
 conda_env: null
-apis_to_serve:
+apis:
 - memory
 - inference
 - agents
@@ -10,43 +11,51 @@ apis_to_serve:
 - safety
 - models
 - memory_banks
-api_providers:
-  inference:
-    providers:
-    - meta-reference
-  safety:
-    providers:
-    - meta-reference
-  agents:
-    provider_type: meta-reference
-    config:
-      persistence_store:
-        namespace: null
-        type: sqlite
-        db_path: /home/xiyan/.llama/runtime/kvstore.db
-  memory:
-    providers:
-    - meta-reference
-  telemetry:
-    provider_type: meta-reference
-    config: {}
-routing_table:
-  inference:
-  - provider_type: meta-reference
-    config:
-      model: Llama3.1-8B-Instruct
-      quantization: null
-      torch_seed: null
-      max_seq_len: 4096
-      max_batch_size: 1
-    routing_key: Llama3.1-8B-Instruct
-  safety:
-  - provider_type: meta-reference
-    config:
-      llama_guard_shield: null
-      prompt_guard_shield: null
-    routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
-  memory:
-  - provider_type: meta-reference
-    config: {}
-    routing_key: vector
+providers:
+  inference:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config:
+      model: Llama3.1-8B-Instruct
+      quantization: null
+      torch_seed: null
+      max_seq_len: 4096
+      max_batch_size: 1
+  safety:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config:
+      llama_guard_shield: null
+      prompt_guard_shield: null
+  memory:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config: {}
+  agents:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config:
+      persistence_store:
+        namespace: null
+        type: sqlite
+        db_path: ~/.llama/runtime/kvstore.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config: {}
+models:
+- identifier: Llama3.1-8B-Instruct
+  llama_model: Llama3.1-8B-Instruct
+  provider_id: meta-reference
+shields:
+- identifier: llama_guard
+  type: llama_guard
+  provider_id: meta-reference
+  params: {}
+memory_banks:
+- identifier: vector
+  provider_id: meta-reference
+  type: vector
+  embedding_model: all-MiniLM-L6-v2
+  chunk_size_in_tokens: 512
+  overlap_size_in_tokens: null

@@ -0,0 +1,10 @@
+name: local-databricks
+distribution_spec:
+  description: Use Databricks for running LLM inference
+  providers:
+    inference: remote::databricks
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference
+image_type: conda
llama_stack/distribution/templates/local-vllm-build.yaml
Normal file
10
llama_stack/distribution/templates/local-vllm-build.yaml
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
name: local-vllm
|
||||||
|
distribution_spec:
|
||||||
|
description: Like local, but use vLLM for running LLM inference
|
||||||
|
providers:
|
||||||
|
inference: vllm
|
||||||
|
memory: meta-reference
|
||||||
|
safety: meta-reference
|
||||||
|
agents: meta-reference
|
||||||
|
telemetry: meta-reference
|
||||||
|
image_type: conda
|
|
@@ -1,445 +1,445 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

 from typing import *  # noqa: F403

 import boto3
 from botocore.client import BaseClient
 from botocore.config import Config

 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig


 BEDROCK_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
     "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
     "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
 }


-class BedrockInferenceAdapter(Inference, RoutableProviderForModels):
+class BedrockInferenceAdapter(ModelRegistryHelper, Inference):

     @staticmethod
     def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
         retries_config = {
             k: v
             for k, v in dict(
                 total_max_attempts=config.total_max_attempts,
                 mode=config.retry_mode,
             ).items()
             if v is not None
         }

         config_args = {
             k: v
             for k, v in dict(
                 region_name=config.region_name,
                 retries=retries_config if retries_config else None,
                 connect_timeout=config.connect_timeout,
                 read_timeout=config.read_timeout,
             ).items()
             if v is not None
         }

         boto3_config = Config(**config_args)

         session_args = {
             k: v
             for k, v in dict(
                 aws_access_key_id=config.aws_access_key_id,
                 aws_secret_access_key=config.aws_secret_access_key,
                 aws_session_token=config.aws_session_token,
                 region_name=config.region_name,
                 profile_name=config.profile_name,
             ).items()
             if v is not None
         }

         boto3_session = boto3.session.Session(**session_args)

         return boto3_session.client("bedrock-runtime", config=boto3_config)

     def __init__(self, config: BedrockConfig) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
         )
         self._config = config

         self._client = BedrockInferenceAdapter._create_bedrock_client(config)
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)

     @property
     def client(self) -> BaseClient:
         return self._client

     async def initialize(self) -> None:
         pass

     async def shutdown(self) -> None:
         self.client.close()

     async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
         raise NotImplementedError()

     @staticmethod
     def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
         if bedrock_stop_reason == "max_tokens":
             return StopReason.out_of_tokens
         return StopReason.end_of_turn

     @staticmethod
     def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
         for builtin_tool in BuiltinTool:
             if builtin_tool.value == tool_name_str:
                 return builtin_tool
         else:
             return tool_name_str

     @staticmethod
     def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
         stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
             converse_api_res["stopReason"]
         )

         bedrock_message = converse_api_res["output"]["message"]

         role = bedrock_message["role"]
         contents = bedrock_message["content"]

         tool_calls = []
         text_content = []
         for content in contents:
             if "toolUse" in content:
                 tool_use = content["toolUse"]
                 tool_calls.append(
                     ToolCall(
                         tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
                             tool_use["name"]
                         ),
                         arguments=tool_use["input"] if "input" in tool_use else None,
                         call_id=tool_use["toolUseId"],
                     )
                 )
             elif "text" in content:
                 text_content.append(content["text"])

         return CompletionMessage(
             role=role,
             content=text_content,
             stop_reason=stop_reason,
             tool_calls=tool_calls,
         )

     @staticmethod
     def _messages_to_bedrock_messages(
         messages: List[Message],
     ) -> Tuple[List[Dict], Optional[List[Dict]]]:
         bedrock_messages = []
         system_bedrock_messages = []

         user_contents = []
         assistant_contents = None
         for message in messages:
             role = message.role
             content_list = (
                 message.content
                 if isinstance(message.content, list)
                 else [message.content]
             )
             if role == "ipython" or role == "user":
                 if not user_contents:
                     user_contents = []

                 if role == "ipython":
                     user_contents.extend(
                         [
                             {
                                 "toolResult": {
                                     "toolUseId": message.call_id,
                                     "content": [
                                         {"text": content} for content in content_list
                                     ],
                                 }
                             }
                         ]
                     )
                 else:
                     user_contents.extend(
                         [{"text": content} for content in content_list]
                     )

                 if assistant_contents:
                     bedrock_messages.append(
                         {"role": "assistant", "content": assistant_contents}
                     )
                     assistant_contents = None
             elif role == "system":
                 system_bedrock_messages.extend(
                     [{"text": content} for content in content_list]
                 )
             elif role == "assistant":
                 if not assistant_contents:
                     assistant_contents = []

                 assistant_contents.extend(
                     [
                         {
                             "text": content,
                         }
                         for content in content_list
                     ]
                     + [
                         {
                             "toolUse": {
                                 "input": tool_call.arguments,
                                 "name": (
                                     tool_call.tool_name
                                     if isinstance(tool_call.tool_name, str)
                                     else tool_call.tool_name.value
                                 ),
                                 "toolUseId": tool_call.call_id,
                             }
                         }
                         for tool_call in message.tool_calls
                     ]
                 )

                 if user_contents:
                     bedrock_messages.append({"role": "user", "content": user_contents})
                     user_contents = None
             else:
                 # Unknown role
                 pass

         if user_contents:
             bedrock_messages.append({"role": "user", "content": user_contents})
         if assistant_contents:
             bedrock_messages.append(
                 {"role": "assistant", "content": assistant_contents}
             )

         if system_bedrock_messages:
             return bedrock_messages, system_bedrock_messages

         return bedrock_messages, None

     @staticmethod
     def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
         inference_config = {}
         if sampling_params:
             param_mapping = {
                 "max_tokens": "maxTokens",
                 "temperature": "temperature",
                 "top_p": "topP",
             }

             for k, v in param_mapping.items():
                 if getattr(sampling_params, k):
                     inference_config[v] = getattr(sampling_params, k)

         return inference_config

     @staticmethod
     def _tool_parameters_to_input_schema(
         tool_parameters: Optional[Dict[str, ToolParamDefinition]]
     ) -> Dict:
         input_schema = {"type": "object"}
         if not tool_parameters:
             return input_schema

         json_properties = {}
         required = []
         for name, param in tool_parameters.items():
             json_property = {
                 "type": param.param_type,
             }

             if param.description:
                 json_property["description"] = param.description
             if param.required:
                 required.append(name)
             json_properties[name] = json_property

         input_schema["properties"] = json_properties
         if required:
             input_schema["required"] = required
         return input_schema

     @staticmethod
     def _tools_to_tool_config(
         tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
     ) -> Optional[Dict]:
         if not tools:
             return None

         bedrock_tools = []
         for tool in tools:
             tool_name = (
                 tool.tool_name
                 if isinstance(tool.tool_name, str)
                 else tool.tool_name.value
             )

             tool_spec = {
                 "toolSpec": {
                     "name": tool_name,
                     "inputSchema": {
                         "json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
                             tool.parameters
                         ),
                     },
                 }
             }

             if tool.description:
                 tool_spec["toolSpec"]["description"] = tool.description

             bedrock_tools.append(tool_spec)
         tool_config = {
             "tools": bedrock_tools,
         }

         if tool_choice:
             tool_config["toolChoice"] = (
                 {"any": {}}
                 if tool_choice.value == ToolChoice.required
                 else {"auto": {}}
             )
         return tool_config

     async def chat_completion(
         self,
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> (
         AsyncGenerator
     ):  # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
         bedrock_model = self.map_to_provider_model(model)
         inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
             sampling_params
         )

         tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
         bedrock_messages, system_bedrock_messages = (
             BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
         )

         converse_api_params = {
             "modelId": bedrock_model,
             "messages": bedrock_messages,
         }
         if inference_config:
             converse_api_params["inferenceConfig"] = inference_config

         # Tool use is not supported in streaming mode
         if tool_config and not stream:
             converse_api_params["toolConfig"] = tool_config
         if system_bedrock_messages:
             converse_api_params["system"] = system_bedrock_messages

         if not stream:
             converse_api_res = self.client.converse(**converse_api_params)

             output_message = BedrockInferenceAdapter._bedrock_message_to_message(
                 converse_api_res
             )

             yield ChatCompletionResponse(
                 completion_message=output_message,
                 logprobs=None,
             )
         else:
             converse_stream_api_res = self.client.converse_stream(**converse_api_params)
             event_stream = converse_stream_api_res["stream"]

             for chunk in event_stream:
                 if "messageStart" in chunk:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.start,
                             delta="",
                         )
                     )
                 elif "contentBlockStart" in chunk:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
                             delta=ToolCallDelta(
                                 content=ToolCall(
                                     tool_name=chunk["contentBlockStart"]["toolUse"][
                                         "name"
                                     ],
                                     call_id=chunk["contentBlockStart"]["toolUse"][
                                         "toolUseId"
                                     ],
                                 ),
                                 parse_status=ToolCallParseStatus.started,
                             ),
                         )
                     )
                 elif "contentBlockDelta" in chunk:
                     if "text" in chunk["contentBlockDelta"]["delta"]:
                         delta = chunk["contentBlockDelta"]["delta"]["text"]
                     else:
                         delta = ToolCallDelta(
                             content=ToolCall(
                                 arguments=chunk["contentBlockDelta"]["delta"][
                                     "toolUse"
                                 ]["input"]
                             ),
                             parse_status=ToolCallParseStatus.success,
                         )

                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
                             delta=delta,
                         )
                     )
                 elif "contentBlockStop" in chunk:
                     # Ignored
                     pass
                 elif "messageStop" in chunk:
                     stop_reason = (
                         BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
                             chunk["messageStop"]["stopReason"]
                         )
                     )

                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.complete,
                             delta="",
                             stop_reason=stop_reason,
                         )
                     )
                 elif "metadata" in chunk:
                     # Ignored
                     pass
                 else:
                     # Ignored
                     pass

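A hedged usage sketch for the adapter above, not part of this commit: it assumes AWS credentials are available in the environment, that the listed config fields are accepted by `BedrockConfig` (they are the ones read in `_create_bedrock_client`), and that `SamplingParams` and `UserMessage` are importable from `llama_stack.apis.inference` via its wildcard re-exports.

```python
# Sketch only: import locations and config fields are assumptions.
import asyncio

from llama_stack.apis.inference import SamplingParams, UserMessage
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
from llama_stack.providers.adapters.inference.bedrock.bedrock import (
    BedrockInferenceAdapter,
)


async def main():
    adapter = BedrockInferenceAdapter(BedrockConfig(region_name="us-east-1"))
    await adapter.initialize()
    # chat_completion is an async generator; non-streaming yields one response
    async for response in adapter.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Hello!")],
        sampling_params=SamplingParams(max_tokens=64),
        stream=False,
    ):
        print(response.completion_message.content)
    await adapter.shutdown()


asyncio.run(main())
```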
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import DatabricksImplConfig
+from .databricks import DatabricksInferenceAdapter
+
+async def get_adapter_impl(config: DatabricksImplConfig, _deps):
+    assert isinstance(
+        config, DatabricksImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = DatabricksInferenceAdapter(config)
+    await impl.initialize()
+    return impl
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class DatabricksImplConfig(BaseModel):
+    url: str = Field(
+        default=None,
+        description="The URL for the Databricks model serving endpoint",
+    )
+    api_token: str = Field(
+        default=None,
+        description="The Databricks API token",
+    )
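For illustration, constructing this config is a plain pydantic instantiation; the module path and both values below are placeholders, not part of the commit.

```python
# Sketch only: the import path mirrors the bedrock adapter layout and is an assumption.
from llama_stack.providers.adapters.inference.databricks.config import (
    DatabricksImplConfig,
)

config = DatabricksImplConfig(
    url="https://YOUR-WORKSPACE.cloud.databricks.com/serving-endpoints",
    api_token="dapi-PLACEHOLDER",
)
print(config)
```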
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import AsyncGenerator
+
+from llama_models.llama3.api.chat_format import ChatFormat
+
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from openai import OpenAI
+
+from llama_stack.apis.inference import *  # noqa: F403
+
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
+
+from .config import DatabricksImplConfig
+
+
+DATABRICKS_SUPPORTED_MODELS = {
+    "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
+    "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
+}
+
+
+class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
+    def __init__(self, config: DatabricksImplConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
+        )
+        self.config = config
+        self.formatter = ChatFormat(Tokenizer.get_instance())
+
+    async def initialize(self) -> None:
+        return
+
+    async def shutdown(self) -> None:
+        pass
+
+    def completion(self, request: CompletionRequest) -> AsyncGenerator:
+        raise NotImplementedError()
+
+    def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+        if stream:
+            return self._stream_chat_completion(request, client)
+        else:
+            return self._nonstream_chat_completion(request, client)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = client.completions.create(**params)
+        return process_chat_completion_response(request, r, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        async def _to_async_generator():
+            s = client.completions.create(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_chat_completion_stream_response(
+            request, stream, self.formatter
+        ):
+            yield chunk
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        return {
+            "model": self.map_to_provider_model(request.model),
+            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            "stream": request.stream,
+            **get_sampling_options(request),
+        }
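A possible way to exercise the new adapter through its `get_adapter_impl` factory, shown only as a sketch: the package path, endpoint URL, and token are assumptions, and `stream=True` is used because `chat_completion` then returns an async generator directly.

```python
# Sketch only: module path and values are placeholders.
import asyncio

from llama_stack.apis.inference import UserMessage
from llama_stack.providers.adapters.inference.databricks import get_adapter_impl
from llama_stack.providers.adapters.inference.databricks.config import (
    DatabricksImplConfig,
)


async def main():
    impl = await get_adapter_impl(
        DatabricksImplConfig(
            url="https://YOUR-WORKSPACE.cloud.databricks.com/serving-endpoints",
            api_token="dapi-PLACEHOLDER",
        ),
        _deps={},
    )
    async for chunk in impl.chat_completion(
        model="Llama3.1-70B-Instruct",
        messages=[UserMessage(content="Say hi")],
        stream=True,
    ):
        print(chunk)


asyncio.run(main())
```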
@ -10,14 +10,19 @@ from fireworks.client import Fireworks
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
from llama_models.llama3.api.datatypes import Message
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.providers.utils.inference.augment_messages import (
|
|
||||||
augment_messages_for_tools,
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
|
get_sampling_options,
|
||||||
|
process_chat_completion_response,
|
||||||
|
process_chat_completion_stream_response,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
chat_completion_request_to_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import FireworksImplConfig
|
from .config import FireworksImplConfig
|
||||||
|
@ -27,21 +32,18 @@ FIREWORKS_SUPPORTED_MODELS = {
|
||||||
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
||||||
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
||||||
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
|
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
|
||||||
|
"Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
|
||||||
|
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
def __init__(self, config: FireworksImplConfig) -> None:
|
def __init__(self, config: FireworksImplConfig) -> None:
|
||||||
RoutableProviderForModels.__init__(
|
ModelRegistryHelper.__init__(
|
||||||
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
|
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
|
||||||
)
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
tokenizer = Tokenizer.get_instance()
|
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||||
self.formatter = ChatFormat(tokenizer)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def client(self) -> Fireworks:
|
|
||||||
return Fireworks(api_key=self.config.api_key)
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
return
|
return
|
||||||
|
@ -49,7 +51,7 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def completion(
|
def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
content: InterleavedTextMedia,
|
content: InterleavedTextMedia,
|
||||||
|
@ -59,27 +61,7 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
|
def chat_completion(
|
||||||
fireworks_messages = []
|
|
||||||
for message in messages:
|
|
||||||
if message.role == "ipython":
|
|
||||||
role = "tool"
|
|
||||||
else:
|
|
||||||
role = message.role
|
|
||||||
fireworks_messages.append({"role": role, "content": message.content})
|
|
||||||
|
|
||||||
return fireworks_messages
|
|
||||||
|
|
||||||
def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
|
|
||||||
options = {}
|
|
||||||
if request.sampling_params is not None:
|
|
||||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
|
||||||
if getattr(request.sampling_params, attr):
|
|
||||||
options[attr] = getattr(request.sampling_params, attr)
|
|
||||||
|
|
||||||
return options
|
|
||||||
|
|
||||||
async def chat_completion(
|
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
|
@ -101,147 +83,41 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = augment_messages_for_tools(request)
|
client = Fireworks(api_key=self.config.api_key)
|
||||||
|
if stream:
|
||||||
# accumulate sampling params and other options to pass to fireworks
|
return self._stream_chat_completion(request, client)
|
||||||
options = self.get_fireworks_chat_options(request)
|
|
||||||
fireworks_model = self.map_to_provider_model(request.model)
|
|
||||||
|
|
||||||
if not request.stream:
|
|
||||||
r = await self.client.chat.completions.acreate(
|
|
||||||
model=fireworks_model,
|
|
||||||
messages=self._messages_to_fireworks_messages(messages),
|
|
||||||
stream=False,
|
|
||||||
**options,
|
|
||||||
)
|
|
||||||
stop_reason = None
|
|
||||||
if r.choices[0].finish_reason:
|
|
||||||
if r.choices[0].finish_reason == "stop":
|
|
||||||
stop_reason = StopReason.end_of_turn
|
|
||||||
elif r.choices[0].finish_reason == "length":
|
|
||||||
stop_reason = StopReason.out_of_tokens
|
|
||||||
|
|
||||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
|
||||||
r.choices[0].message.content, stop_reason
|
|
||||||
)
|
|
||||||
|
|
||||||
yield ChatCompletionResponse(
|
|
||||||
completion_message=completion_message,
|
|
||||||
logprobs=None,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
yield ChatCompletionResponseStreamChunk(
|
return self._nonstream_chat_completion(request, client)
|
||||||
event=ChatCompletionResponseEvent(
|
|
||||||
event_type=ChatCompletionResponseEventType.start,
|
|
||||||
delta="",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
buffer = ""
|
async def _nonstream_chat_completion(
|
||||||
ipython = False
|
        self, request: ChatCompletionRequest, client: Fireworks
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = await client.completion.acreate(**params)
        return process_chat_completion_response(request, r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: Fireworks
    ) -> AsyncGenerator:
        params = self._get_params(request)

        stream = client.completion.acreate(**params)
        async for chunk in process_chat_completion_stream_response(
            request, stream, self.formatter
        ):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        prompt = chat_completion_request_to_prompt(request, self.formatter)
        # Fireworks always prepends with BOS
        if prompt.startswith("<|begin_of_text|>"):
            prompt = prompt[len("<|begin_of_text|>") :]

        options = get_sampling_options(request)
        options.setdefault("max_tokens", 512)
        return {
            "model": self.map_to_provider_model(request.model),
            "prompt": prompt,
            "stream": request.stream,
            **options,
        }
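Before handing the prompt to Fireworks, `_get_params` strips a leading `<|begin_of_text|>` because the provider prepends BOS itself. A minimal standalone sketch of that prefix handling (the example strings are made up for illustration):

```python
BOS = "<|begin_of_text|>"


def strip_bos(prompt: str) -> str:
    # Drop a leading BOS token so the provider does not see it twice.
    return prompt[len(BOS):] if prompt.startswith(BOS) else prompt


assert strip_bos(BOS + "Hello") == "Hello"
assert strip_bos("Hello") == "Hello"
```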
@@ -7,6 +7,10 @@
from llama_stack.distribution.datatypes import RemoteProviderConfig


class OllamaImplConfig(RemoteProviderConfig):
    port: int = 11434


async def get_adapter_impl(config: RemoteProviderConfig, _deps):
    from .ollama import OllamaInferenceAdapter
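`OllamaImplConfig` only changes the default port of the remote provider config. A rough sketch of the same pydantic pattern, using a stand-in base class since the real `RemoteProviderConfig` fields are not shown here (the field names below are assumptions):

```python
from pydantic import BaseModel


class FakeRemoteProviderConfig(BaseModel):
    # Stand-in for llama_stack's RemoteProviderConfig; fields are illustrative only.
    host: str = "localhost"
    port: int = 80


class FakeOllamaImplConfig(FakeRemoteProviderConfig):
    # Same idea as OllamaImplConfig: keep everything, change the default port.
    port: int = 11434


print(FakeOllamaImplConfig())            # host='localhost' port=11434
print(FakeOllamaImplConfig(port=12345))  # explicit override still works
```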
@@ -9,34 +9,36 @@ from typing import AsyncGenerator
import httpx

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from ollama import AsyncClient

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)

OLLAMA_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
    "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
    "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
    "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
    "Llama-Guard-3-8B": "xe/llamaguard3:latest",
}


class OllamaInferenceAdapter(Inference):
    def __init__(self, url: str) -> None:
        self.url = url
        self.formatter = ChatFormat(Tokenizer.get_instance())

    @property
    def client(self) -> AsyncClient:

@@ -54,7 +56,29 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
    async def shutdown(self) -> None:
        pass

    async def register_model(self, model: ModelDef) -> None:
        if model.identifier not in OLLAMA_SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported model {model.identifier}. Supported models: {OLLAMA_SUPPORTED_MODELS.keys()}"
            )

        ollama_model = OLLAMA_SUPPORTED_MODELS[model.identifier]
        res = await self.client.ps()
        need_model_pull = True
        for r in res["models"]:
            if ollama_model == r["model"]:
                need_model_pull = False
                break

        print(f"Ollama model `{ollama_model}` needs pull -> {need_model_pull}")
        if need_model_pull:
            print(f"Pulling model: {ollama_model}")
            status = await self.client.pull(ollama_model)
            assert (
                status["status"] == "success"
            ), f"Failed to pull model {self.model} in ollama"

    def completion(
        self,
        model: str,
        content: InterleavedTextMedia,

@@ -64,32 +88,7 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
    ) -> AsyncGenerator:
        raise NotImplementedError()

    def chat_completion(
        self,
        model: str,
        messages: List[Message],

@@ -110,156 +109,54 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_chat_completion(request)
        else:
            return self._nonstream_chat_completion(request)

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        return {
            "model": OLLAMA_SUPPORTED_MODELS[request.model],
            "prompt": chat_completion_request_to_prompt(request, self.formatter),
            "options": get_sampling_options(request),
            "raw": True,
            "stream": request.stream,
        }

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = await self.client.generate(**params)
        assert isinstance(r, dict)

        choice = OpenAICompatCompletionChoice(
            finish_reason=r["done_reason"] if r["done"] else None,
            text=r["response"],
        )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
        return process_chat_completion_response(request, response, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        params = self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.generate(**params)
            async for chunk in s:
                choice = OpenAICompatCompletionChoice(
                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
                    text=chunk["response"],
                )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(
            request, stream, self.formatter
        ):
            yield chunk
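The new streaming path no longer parses tool-call tokens by hand; it converts Ollama's native chunks into OpenAI-compatible choices and lets the shared `process_chat_completion_stream_response` helper do the rest. A self-contained sketch of just that conversion step, with stand-in dataclasses and a fake chunk stream instead of the real llama_stack types:

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, List, Optional


@dataclass
class CompatChoice:  # stand-in for OpenAICompatCompletionChoice
    text: str
    finish_reason: Optional[str] = None


@dataclass
class CompatResponse:  # stand-in for OpenAICompatCompletionResponse
    choices: List[CompatChoice]


async def fake_ollama_stream() -> AsyncIterator[dict]:
    # Shaped like ollama's streaming dicts: {"response": ..., "done": ..., "done_reason": ...}
    for piece in ("Hello", ", ", "world"):
        yield {"response": piece, "done": False, "done_reason": None}
    yield {"response": "", "done": True, "done_reason": "stop"}


async def to_openai_compat(chunks: AsyncIterator[dict]) -> AsyncIterator[CompatResponse]:
    # Mirrors _generate_and_convert_to_openai_compat: one compat response per provider chunk.
    async for chunk in chunks:
        finish = chunk["done_reason"] if chunk["done"] else None
        yield CompatResponse(choices=[CompatChoice(text=chunk["response"], finish_reason=finish)])


async def main() -> None:
    async for resp in to_openai_compat(fake_ollama_stream()):
        print(resp.choices[0])


asyncio.run(main())
```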
@@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.inference import *  # noqa: F403


class SampleInferenceImpl(Inference):
    def __init__(self, config: SampleConfig):
        self.config = config

    async def register_model(self, model: ModelDef) -> None:
        # these are the model names the Llama Stack will use to route requests to this provider
        # perform validation here if necessary
        pass
@@ -34,7 +34,7 @@ class InferenceEndpointImplConfig(BaseModel):

@json_schema_type
class InferenceAPIImplConfig(BaseModel):
    huggingface_repo: str = Field(
        description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
    )
    api_token: Optional[str] = Field(
@@ -10,14 +10,19 @@ from typing import AsyncGenerator

from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_model_input_info,
)

from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig

@@ -25,24 +30,31 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
logger = logging.getLogger(__name__)


class _HfAdapter(Inference):
    client: AsyncInferenceClient
    max_tokens: int
    model_id: str

    def __init__(self) -> None:
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def register_model(self, model: ModelDef) -> None:
        resolved_model = resolve_model(model.identifier)
        if resolved_model is None:
            raise ValueError(f"Unknown model: {model.identifier}")

        if not resolved_model.huggingface_repo:
            raise ValueError(
                f"Model {model.identifier} does not have a HuggingFace repo"
            )

        if self.model_id != resolved_model.huggingface_repo:
            raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")

    async def shutdown(self) -> None:
        pass

    def completion(
        self,
        model: str,
        content: InterleavedTextMedia,

@@ -52,16 +64,7 @@ class _HfAdapter(Inference, RoutableProvider):
    ) -> AsyncGenerator:
        raise NotImplementedError()

    def chat_completion(
        self,
        model: str,
        messages: List[Message],

@@ -83,146 +86,64 @@ class _HfAdapter(Inference, RoutableProvider):
            logprobs=logprobs,
        )

        if stream:
            return self._stream_chat_completion(request)
        else:
            return self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = await self.client.text_generation(**params)

        choice = OpenAICompatCompletionChoice(
            finish_reason=r.details.finish_reason,
            text="".join(t.text for t in r.details.tokens),
        )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
        return process_chat_completion_response(request, response, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        params = self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.text_generation(**params)
            async for chunk in s:
                token_result = chunk.token

                choice = OpenAICompatCompletionChoice(text=token_result.text)
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(
            request, stream, self.formatter
        ):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        prompt, input_tokens = chat_completion_request_to_model_input_info(
            request, self.formatter
        )
        max_new_tokens = min(
            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
            self.max_tokens - input_tokens - 1,
        )
        options = get_sampling_options(request)
        return dict(
            prompt=prompt,
            stream=request.stream,
            details=True,
            max_new_tokens=max_new_tokens,
            stop_sequences=["<|eom_id|>", "<|eot_id|>"],
            **options,
        )


class TGIAdapter(_HfAdapter):

@@ -236,7 +157,7 @@ class TGIAdapter(_HfAdapter):
class InferenceAPIAdapter(_HfAdapter):
    async def initialize(self, config: InferenceAPIImplConfig) -> None:
        self.client = AsyncInferenceClient(
            model=config.huggingface_repo, token=config.api_token
        )
        endpoint_info = await self.client.get_endpoint_info()
        self.max_tokens = endpoint_info["max_total_tokens"]
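`_get_params` above budgets `max_new_tokens` against the endpoint's total context size. A small sketch of that arithmetic with made-up numbers:

```python
from typing import Optional


def budget_new_tokens(requested: Optional[int], max_total: int, input_tokens: int) -> int:
    # Same min() used in _get_params: honor the request if given, otherwise take
    # whatever room is left, but always leave one token of headroom.
    remaining = max_total - input_tokens
    return min(requested or remaining, remaining - 1)


# Hypothetical numbers: a 4096-token endpoint with a 1000-token prompt.
print(budget_new_tokens(None, 4096, 1000))  # 3095
print(budget_new_tokens(512, 4096, 1000))   # 512
```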
@@ -8,17 +8,22 @@ from typing import AsyncGenerator

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from together import Together

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)

from .config import TogetherImplConfig

@@ -34,19 +39,14 @@ TOGETHER_SUPPORTED_MODELS = {


class TogetherInferenceAdapter(
    ModelRegistryHelper, Inference, NeedsRequestProviderData
):
    def __init__(self, config: TogetherImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
        )
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        return

@@ -64,27 +64,7 @@ class TogetherInferenceAdapter(
    ) -> AsyncGenerator:
        raise NotImplementedError()

    def chat_completion(
        self,
        model: str,
        messages: List[Message],

@@ -95,7 +75,6 @@ class TogetherInferenceAdapter(
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        together_api_key = None
        if self.config.api_key is not None:
            together_api_key = self.config.api_key

@@ -108,7 +87,6 @@ class TogetherInferenceAdapter(
            together_api_key = provider_data.together_api_key

        client = Together(api_key=together_api_key)
        request = ChatCompletionRequest(
            model=model,
            messages=messages,

@@ -120,146 +98,39 @@ class TogetherInferenceAdapter(
            logprobs=logprobs,
        )

        if stream:
            return self._stream_chat_completion(request, client)
        else:
            return self._nonstream_chat_completion(request, client)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, client: Together
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = client.completions.create(**params)
        return process_chat_completion_response(request, r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: Together
    ) -> AsyncGenerator:
        params = self._get_params(request)

        # if we shift to TogetherAsyncClient, we won't need this wrapper
        async def _to_async_generator():
            s = client.completions.create(**params)
            for chunk in s:
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(
            request, stream, self.formatter
        ):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        return {
            "model": self.map_to_provider_model(request.model),
            "prompt": chat_completion_request_to_prompt(request, self.formatter),
            "stream": request.stream,
            **get_sampling_options(request),
        }
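The Together client used here is synchronous, so the streaming path wraps its iterator in a tiny async generator (the code notes a TogetherAsyncClient would make this unnecessary). A standalone sketch of that wrapper with a dummy sync stream:

```python
import asyncio
from typing import AsyncIterator, Iterable, Iterator


def fake_sync_stream() -> Iterator[str]:
    # Stand-in for the blocking iterator returned by a streaming completions call.
    yield from ("one", "two", "three")


async def to_async(chunks: Iterable[str]) -> AsyncIterator[str]:
    # Same shape as _to_async_generator: re-yield each chunk inside a coroutine.
    for chunk in chunks:
        yield chunk


async def main() -> None:
    async for chunk in to_async(fake_sync_stream()):
        print(chunk)


asyncio.run(main())
```

Note that this adapter only changes the interface; the blocking iteration itself still runs on the event loop thread.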
@@ -5,7 +5,6 @@
# the root directory of this source tree.

import json
from typing import List
from urllib.parse import urlparse

@@ -13,7 +12,6 @@ import chromadb
from numpy.typing import NDArray

from llama_stack.apis.memory import *  # noqa: F403
from llama_stack.providers.utils.memory.vector_store import (
    BankWithIndex,

@@ -65,7 +63,7 @@ class ChromaIndex(EmbeddingIndex):
        return QueryDocumentsResponse(chunks=chunks, scores=scores)


class ChromaMemoryAdapter(Memory):
    def __init__(self, url: str) -> None:
        print(f"Initializing ChromaMemoryAdapter with url: {url}")
        url = url.rstrip("/")

@@ -93,48 +91,33 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
    async def shutdown(self) -> None:
        pass

    async def register_memory_bank(
        self,
        memory_bank: MemoryBankDef,
    ) -> None:
        assert (
            memory_bank.type == MemoryBankType.vector.value
        ), f"Only vector banks are supported {memory_bank.type}"

        collection = await self.client.get_or_create_collection(
            name=memory_bank.identifier,
        )
        bank_index = BankWithIndex(
            bank=memory_bank, index=ChromaIndex(self.client, collection)
        )
        self.cache[memory_bank.identifier] = bank_index

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        if bank_id in self.cache:
            return self.cache[bank_id]

        bank = await self.memory_bank_store.get_memory_bank(bank_id)
        if bank is None:
            raise ValueError(f"Bank {bank_id} not found")

        collections = await self.client.list_collections()
        for collection in collections:
            if collection.name == bank_id:
                index = BankWithIndex(
                    bank=bank,
                    index=ChromaIndex(self.client, collection),
                )
@@ -4,18 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import execute_values, Json

from llama_stack.apis.memory import *  # noqa: F403
from llama_stack.providers.utils.memory.vector_store import (
    ALL_MINILM_L6_V2_DIMENSION,

@@ -32,33 +28,6 @@ def check_extension_version(cur):
    return result[0] if result else None


class PGVectorIndex(EmbeddingIndex):
    def __init__(self, bank: MemoryBank, dimension: int, cursor):
        self.cursor = cursor

@@ -119,7 +88,7 @@ class PGVectorIndex(EmbeddingIndex):
        return QueryDocumentsResponse(chunks=chunks, scores=scores)


class PGVectorMemoryAdapter(Memory):
    def __init__(self, config: PGVectorConfig) -> None:
        print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
        self.config = config

@@ -144,14 +113,6 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
            else:
                raise RuntimeError("Vector extension is not installed.")
        except Exception as e:
            import traceback

@@ -161,51 +122,28 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
    async def shutdown(self) -> None:
        pass

    async def register_memory_bank(
        self,
        memory_bank: MemoryBankDef,
    ) -> None:
        assert (
            memory_bank.type == MemoryBankType.vector.value
        ), f"Only vector banks are supported {memory_bank.type}"

        index = BankWithIndex(
            bank=memory_bank,
            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
        )
        self.cache[bank_id] = index

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        if bank_id in self.cache:
            return self.cache[bank_id]

        bank = await self.memory_bank_store.get_memory_bank(bank_id)
        if not bank:
            raise ValueError(f"Bank {bank_id} not found")

        index = BankWithIndex(
            bank=bank,
            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
        )
@@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.memory import *  # noqa: F403


class SampleMemoryImpl(Memory):
    def __init__(self, config: SampleConfig):
        self.config = config

    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
        # these are the memory banks the Llama Stack will use to route requests to this provider
        # perform validation here if necessary
        pass
llama_stack/providers/adapters/memory/weaviate/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import WeaviateConfig, WeaviateRequestProviderData  # noqa: F401


async def get_adapter_impl(config: WeaviateConfig, _deps):
    from .weaviate import WeaviateMemoryAdapter

    impl = WeaviateMemoryAdapter(config)
    await impl.initialize()
    return impl
llama_stack/providers/adapters/memory/weaviate/config.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel


class WeaviateRequestProviderData(BaseModel):
    weaviate_api_key: str
    weaviate_cluster_url: str


class WeaviateConfig(BaseModel):
    pass
180
llama_stack/providers/adapters/memory/weaviate/weaviate.py
Normal file
180
llama_stack/providers/adapters/memory/weaviate/weaviate.py
Normal file
|
@ -0,0 +1,180 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
import json
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import weaviate
|
||||||
|
import weaviate.classes as wvc
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from weaviate.classes.init import Auth
|
||||||
|
|
||||||
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
BankWithIndex,
|
||||||
|
EmbeddingIndex,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .config import WeaviateConfig, WeaviateRequestProviderData
|
||||||
|
|
||||||
|
|
class WeaviateIndex(EmbeddingIndex):
    def __init__(self, client: weaviate.Client, collection_name: str):
        self.client = client
        self.collection_name = collection_name

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        assert len(chunks) == len(
            embeddings
        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"

        data_objects = []
        for i, chunk in enumerate(chunks):
            data_objects.append(
                wvc.data.DataObject(
                    properties={
                        "chunk_content": chunk.json(),
                    },
                    vector=embeddings[i].tolist(),
                )
            )

        # Inserting chunks into a prespecified Weaviate collection
        collection = self.client.collections.get(self.collection_name)

        # TODO: make this async friendly
        collection.data.insert_many(data_objects)

    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        collection = self.client.collections.get(self.collection_name)

        results = collection.query.near_vector(
            near_vector=embedding.tolist(),
            limit=k,
            return_metadata=wvc.query.MetadataQuery(distance=True),
        )

        chunks = []
        scores = []
        for doc in results.objects:
            chunk_json = doc.properties["chunk_content"]
            try:
                chunk_dict = json.loads(chunk_json)
                chunk = Chunk(**chunk_dict)
            except Exception:
                import traceback

                traceback.print_exc()
                print(f"Failed to parse document: {chunk_json}")
                continue

            chunks.append(chunk)
            scores.append(1.0 / doc.metadata.distance)

        return QueryDocumentsResponse(chunks=chunks, scores=scores)


class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
    def __init__(self, config: WeaviateConfig) -> None:
        self.config = config
        self.client_cache = {}
        self.cache = {}

    def _get_client(self) -> weaviate.Client:
        provider_data = self.get_request_provider_data()
        assert provider_data is not None, "Request provider data must be set"
        assert isinstance(provider_data, WeaviateRequestProviderData)

        key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}"
        if key in self.client_cache:
            return self.client_cache[key]

        client = weaviate.connect_to_weaviate_cloud(
            cluster_url=provider_data.weaviate_cluster_url,
            auth_credentials=Auth.api_key(provider_data.weaviate_api_key),
        )
        self.client_cache[key] = client
        return client

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        for client in self.client_cache.values():
            client.close()

    async def register_memory_bank(
        self,
        memory_bank: MemoryBankDef,
    ) -> None:
        assert (
            memory_bank.type == MemoryBankType.vector.value
        ), f"Only vector banks are supported {memory_bank.type}"

        client = self._get_client()

        # Create collection if it doesn't exist
        if not client.collections.exists(memory_bank.identifier):
            client.collections.create(
                name=memory_bank.identifier,
                vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                properties=[
                    wvc.config.Property(
                        name="chunk_content",
                        data_type=wvc.config.DataType.TEXT,
                    ),
                ],
            )

        index = BankWithIndex(
            bank=memory_bank,
            index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
        )
        self.cache[memory_bank.identifier] = index

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        if bank_id in self.cache:
            return self.cache[bank_id]

        bank = await self.memory_bank_store.get_memory_bank(bank_id)
        if not bank:
            raise ValueError(f"Bank {bank_id} not found")

        client = self._get_client()
        if not client.collections.exists(bank_id):
            raise ValueError(f"Collection with name `{bank_id}` not found")

        index = BankWithIndex(
            bank=bank,
            index=WeaviateIndex(client=client, collection_name=bank_id),
        )
        self.cache[bank_id] = index
        return index

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        await index.insert_documents(documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        return await index.query_documents(query, params)
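Aside (not part of the diff): the adapter above stores each chunk as a JSON string in the `chunk_content` property and parses it back on query. A minimal, self-contained sketch of that round trip follows; the `Chunk` dataclass here is a hypothetical stand-in for the stack's pydantic `Chunk` model, used only for illustration.

```python
import json
from dataclasses import dataclass, asdict


# Stand-in for the real Chunk model (assumption: fields chosen for illustration only).
@dataclass
class Chunk:
    content: str
    token_count: int
    document_id: str

    def json(self) -> str:
        return json.dumps(asdict(self))


# Serialize the way add_chunks() stores it into the Weaviate property ...
stored = {"chunk_content": Chunk("hello world", 2, "doc-1").json()}

# ... and parse it back the way query() does.
parsed = Chunk(**json.loads(stored["chunk_content"]))
assert parsed.document_id == "doc-1"
```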
@@ -7,14 +7,12 @@
 import json
 import logging
 
-import traceback
 from typing import Any, Dict, List
 
 import boto3
 
 from llama_stack.apis.safety import *  # noqa
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 
 from .config import BedrockSafetyConfig
 
@@ -22,16 +20,17 @@ from .config import BedrockSafetyConfig
 logger = logging.getLogger(__name__)
 
 
-SUPPORTED_SHIELD_TYPES = [
-    "bedrock_guardrail",
+BEDROCK_SUPPORTED_SHIELDS = [
+    ShieldType.generic_content_shield.value,
 ]
 
 
-class BedrockSafetyAdapter(Safety, RoutableProvider):
+class BedrockSafetyAdapter(Safety):
     def __init__(self, config: BedrockSafetyConfig) -> None:
         if not config.aws_profile:
             raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
         self.config = config
+        self.registered_shields = []
 
     async def initialize(self) -> None:
         try:
@@ -45,16 +44,27 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
     async def shutdown(self) -> None:
         pass
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        for key in routing_keys:
-            if key not in SUPPORTED_SHIELD_TYPES:
-                raise ValueError(f"Unknown safety shield type: {key}")
+    async def register_shield(self, shield: ShieldDef) -> None:
+        if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
+            raise ValueError(f"Unsupported safety shield type: {shield.type}")
+
+        shield_params = shield.params
+        if "guardrailIdentifier" not in shield_params:
+            raise ValueError(
+                "Error running request for BedrockGaurdrails:Missing GuardrailID in request"
+            )
+
+        if "guardrailVersion" not in shield_params:
+            raise ValueError(
+                "Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
+            )
 
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        if shield_type not in SUPPORTED_SHIELD_TYPES:
-            raise ValueError(f"Unknown safety shield type: {shield_type}")
+        shield_def = await self.shield_store.get_shield(shield_type)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {shield_type}")
 
         """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
         ```content = [
@@ -69,52 +79,38 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
 
         They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
         """
-        try:
-            logger.debug(f"run_shield::{params}::messages={messages}")
-            if "guardrailIdentifier" not in params:
-                raise RuntimeError(
-                    "Error running request for BedrockGaurdrails:Missing GuardrailID in request"
-                )
-
-            if "guardrailVersion" not in params:
-                raise RuntimeError(
-                    "Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
-                )
-
-            # - convert the messages into format Bedrock expects
-            content_messages = []
-            for message in messages:
-                content_messages.append({"text": {"text": message.content}})
-            logger.debug(
-                f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
-            )
-
-            response = self.boto_client.apply_guardrail(
-                guardrailIdentifier=params.get("guardrailIdentifier"),
-                guardrailVersion=params.get("guardrailVersion"),
-                source="OUTPUT",  # or 'INPUT' depending on your use case
-                content=content_messages,
-            )
-            logger.debug(f"run_shield:: response: {response}::")
-            if response["action"] == "GUARDRAIL_INTERVENED":
-                user_message = ""
-                metadata = {}
-                for output in response["outputs"]:
-                    # guardrails returns a list - however for this implementation we will leverage the last values
-                    user_message = output["text"]
-                for assessment in response["assessments"]:
-                    # guardrails returns a list - however for this implementation we will leverage the last values
-                    metadata = dict(assessment)
-                return SafetyViolation(
-                    user_message=user_message,
-                    violation_level=ViolationLevel.ERROR,
-                    metadata=metadata,
-                )
-
-        except Exception:
-            error_str = traceback.format_exc()
-            logger.error(
-                f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
-            )
+        shield_params = shield_def.params
+        logger.debug(f"run_shield::{shield_params}::messages={messages}")
+
+        # - convert the messages into format Bedrock expects
+        content_messages = []
+        for message in messages:
+            content_messages.append({"text": {"text": message.content}})
+        logger.debug(
+            f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
+        )
+
+        response = self.boto_client.apply_guardrail(
+            guardrailIdentifier=shield_params["guardrailIdentifier"],
+            guardrailVersion=shield_params["guardrailVersion"],
+            source="OUTPUT",  # or 'INPUT' depending on your use case
+            content=content_messages,
+        )
+        if response["action"] == "GUARDRAIL_INTERVENED":
+            user_message = ""
+            metadata = {}
+            for output in response["outputs"]:
+                # guardrails returns a list - however for this implementation we will leverage the last values
+                user_message = output["text"]
+            for assessment in response["assessments"]:
+                # guardrails returns a list - however for this implementation we will leverage the last values
+                metadata = dict(assessment)
+
+            return SafetyViolation(
+                user_message=user_message,
+                violation_level=ViolationLevel.ERROR,
+                metadata=metadata,
+            )
 
         return None
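Aside (not part of the diff): a small sketch of the inputs this adapter now expects. The shield params must carry a guardrail identifier and version, and messages are converted into the Bedrock guardrail content format shown above. The guardrail ID and version below are placeholders, and the boto3 call is left commented because it requires a configured AWS profile and an existing guardrail.

```python
# Hypothetical ShieldDef.params; both values are placeholders.
shield_params = {
    "guardrailIdentifier": "gr-0123456789ab",
    "guardrailVersion": "1",
}

# Messages are converted to the guardrail content format used in run_shield().
messages = ["please review this output for policy violations"]
content = [{"text": {"text": m}} for m in messages]

# With a configured AWS profile, the request would look roughly like this
# (not executed here; sketch only):
# import boto3
# client = boto3.Session(profile_name="default").client("bedrock-runtime")
# response = client.apply_guardrail(
#     guardrailIdentifier=shield_params["guardrailIdentifier"],
#     guardrailVersion=shield_params["guardrailVersion"],
#     source="OUTPUT",
#     content=content,
# )
print(content)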
@@ -9,14 +9,12 @@ from .config import SampleConfig
 
 from llama_stack.apis.safety import *  # noqa: F403
 
-from llama_stack.distribution.datatypes import RoutableProvider
-
 
-class SampleSafetyImpl(Safety, RoutableProvider):
+class SampleSafetyImpl(Safety):
     def __init__(self, config: SampleConfig):
         self.config = config
 
-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
+    async def register_shield(self, shield: ShieldDef) -> None:
        # these are the safety shields the Llama Stack will use to route requests to this provider
        # perform validation here if necessary
        pass
@@ -6,26 +6,20 @@
 from together import Together
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.safety import (
-    RunShieldResponse,
-    Safety,
-    SafetyViolation,
-    ViolationLevel,
-)
-from llama_stack.distribution.datatypes import RoutableProvider
+from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 
 from .config import TogetherSafetyConfig
 
 
-SAFETY_SHIELD_TYPES = {
+TOGETHER_SHIELD_MODEL_MAP = {
     "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
     "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
     "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
 }
 
 
-class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
+class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
     def __init__(self, config: TogetherSafetyConfig) -> None:
         self.config = config
 
@@ -35,16 +29,20 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
     async def shutdown(self) -> None:
         pass
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        for key in routing_keys:
-            if key not in SAFETY_SHIELD_TYPES:
-                raise ValueError(f"Unknown safety shield type: {key}")
+    async def register_shield(self, shield: ShieldDef) -> None:
+        if shield.type != ShieldType.llama_guard.value:
+            raise ValueError(f"Unsupported safety shield type: {shield.type}")
 
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        if shield_type not in SAFETY_SHIELD_TYPES:
-            raise ValueError(f"Unknown safety shield type: {shield_type}")
+        shield_def = await self.shield_store.get_shield(shield_type)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {shield_type}")
+
+        model = shield_def.params.get("model", "llama_guard")
+        if model not in TOGETHER_SHIELD_MODEL_MAP:
+            raise ValueError(f"Unsupported safety model: {model}")
 
         together_api_key = None
         if self.config.api_key is not None:
@@ -57,8 +55,6 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
             )
             together_api_key = provider_data.together_api_key
 
-        model_name = SAFETY_SHIELD_TYPES[shield_type]
-
         # messages can have role assistant or user
         api_messages = []
         for message in messages:
@@ -66,7 +62,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
             api_messages.append({"role": message.role, "content": message.content})
 
         violation = await get_safety_response(
-            together_api_key, model_name, api_messages
+            together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
         )
         return RunShieldResponse(violation=violation)
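Aside (not part of the diff): the shield model is now chosen from the shield definition's params rather than from the routing key. A tiny sketch of that lookup, mirroring the mapping above; the `shield_params` dict is a hypothetical stand-in for `ShieldDef.params`.

```python
# Same mapping as TOGETHER_SHIELD_MODEL_MAP in the diff above.
TOGETHER_SHIELD_MODEL_MAP = {
    "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
}

shield_params = {"model": "Llama-Guard-3-8B"}  # hypothetical ShieldDef.params
model = shield_params.get("model", "llama_guard")
if model not in TOGETHER_SHIELD_MODEL_MAP:
    raise ValueError(f"Unsupported safety model: {model}")
print(TOGETHER_SHIELD_MODEL_MAP[model])
```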
@@ -43,23 +43,14 @@ class ProviderSpec(BaseModel):
         description="Higher-level API surfaces may depend on other providers to provide their functionality",
     )
 
+    # used internally by the resolver; this is a hack for now
+    deps__: List[str] = Field(default_factory=list)
+
 
 class RoutingTable(Protocol):
-    def get_routing_keys(self) -> List[str]: ...
-
     def get_provider_impl(self, routing_key: str) -> Any: ...
 
 
-class RoutableProvider(Protocol):
-    """
-    A provider which sits behind the RoutingTable and can get routed to.
-
-    All Inference / Safety / Memory providers fall into this bucket.
-    """
-
-    async def validate_routing_keys(self, keys: List[str]) -> None: ...
-
-
 @json_schema_type
 class AdapterSpec(BaseModel):
     adapter_type: str = Field(
@@ -156,6 +147,10 @@ as being "Llama Stack compatible"
         return None
 
 
+def is_passthrough(spec: ProviderSpec) -> bool:
+    return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
+
+
 # Can avoid this by using Pydantic computed_field
 def remote_provider_spec(
     api: Api, adapter: Optional[AdapterSpec] = None
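Aside (not part of the diff): after this change, `RoutingTable` only promises `get_provider_impl`. A minimal sketch of something satisfying that protocol, using an in-memory dict; this is purely illustrative and not the stack's own routing table implementation.

```python
from typing import Any, Dict, Protocol


class RoutingTable(Protocol):
    def get_provider_impl(self, routing_key: str) -> Any: ...


class DictRoutingTable:
    """Illustrative in-memory routing table keyed by routing key."""

    def __init__(self, impls: Dict[str, Any]) -> None:
        self._impls = impls

    def get_provider_impl(self, routing_key: str) -> Any:
        if routing_key not in self._impls:
            raise ValueError(f"No provider registered for routing key {routing_key}")
        return self._impls[routing_key]


table: RoutingTable = DictRoutingTable({"Llama-Guard-3-8B": object()})
print(table.get_provider_impl("Llama-Guard-3-8B"))
```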
@@ -144,6 +144,8 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
+        assert request.stream is True, "Non-streaming not supported"
+
         session_info = await self.storage.get_session_info(request.session_id)
         if session_info is None:
             raise ValueError(f"Session {request.session_id} not found")
@@ -635,14 +637,13 @@ class ChatAgent(ShieldRunnerMixin):
             raise ValueError(f"Session {session_id} not found")
 
         if session_info.memory_bank_id is None:
-            memory_bank = await self.memory_api.create_memory_bank(
-                name=f"memory_bank_{session_id}",
-                config=VectorMemoryBankConfig(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
+            bank_id = f"memory_bank_{session_id}"
+            memory_bank = VectorMemoryBankDef(
+                identifier=bank_id,
+                embedding_model="all-MiniLM-L6-v2",
+                chunk_size_in_tokens=512,
             )
-            bank_id = memory_bank.bank_id
+            await self.memory_api.register_memory_bank(memory_bank)
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
             bank_id = session_info.memory_bank_id
@@ -673,7 +674,7 @@ class ChatAgent(ShieldRunnerMixin):
 
     async def _retrieve_context(
         self, session_id: str, messages: List[Message], attachments: List[Attachment]
-    ) -> Tuple[List[str], List[int]]:  # (rag_context, bank_ids)
+    ) -> Tuple[Optional[List[str]], Optional[List[int]]]:  # (rag_context, bank_ids)
         bank_ids = []
 
         memory = self._memory_tool_definition()
@@ -722,12 +723,13 @@ class ChatAgent(ShieldRunnerMixin):
         chunks = [c for r in results for c in r.chunks]
         scores = [s for r in results for s in r.scores]
 
+        if not chunks:
+            return None, bank_ids
+
         # sort by score
         chunks, scores = zip(
             *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
         )
-        if not chunks:
-            return None, bank_ids
 
         tokens = 0
         picked = []
@@ -100,7 +100,7 @@ class MetaReferenceAgentsImpl(Agents):
             session_id=session_id,
         )
 
-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
@@ -113,16 +113,22 @@ class MetaReferenceAgentsImpl(Agents):
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
-        agent = await self.get_agent(agent_id)
-
-        # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
             session_id=session_id,
             messages=messages,
             attachments=attachments,
-            stream=stream,
+            stream=True,
         )
+        if stream:
+            return self._create_agent_turn_streaming(request)
+        else:
+            raise NotImplementedError("Non-streaming agent turns not yet implemented")
+
+    async def _create_agent_turn_streaming(
+        self,
+        request: AgentTurnCreateRequest,
+    ) -> AsyncGenerator:
+        agent = await self.get_agent(request.agent_id)
         async for event in agent.create_and_execute_turn(request):
             yield event
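Aside (not part of the diff): the public entry point above is now a plain function that hands back an async generator when `stream=True` and rejects non-streaming calls. A self-contained sketch of that dispatch pattern, with placeholder event names standing in for the real turn events.

```python
import asyncio
from typing import AsyncGenerator


async def _turn_events() -> AsyncGenerator:
    # Placeholder events; the real implementation yields ChatCompletion/turn chunks.
    for event in ("start", "progress", "complete"):
        yield event


def create_agent_turn(stream: bool = False) -> AsyncGenerator:
    # Return the generator without awaiting anything here, as in the diff above.
    if stream:
        return _turn_events()
    raise NotImplementedError("Non-streaming agent turns not yet implemented")


async def main() -> None:
    async for event in create_agent_turn(stream=True):
        print(event)


asyncio.run(main())
```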
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import CodeShieldConfig
+
+
+async def get_provider_impl(config: CodeShieldConfig, deps):
+    from .code_scanner import MetaReferenceCodeScannerSafetyImpl
+
+    impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
+    await impl.initialize()
+    return impl
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List
+
+from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
+from termcolor import cprint
+
+from .config import CodeScannerConfig
+
+from llama_stack.apis.safety import *  # noqa: F403
+
+
+class MetaReferenceCodeScannerSafetyImpl(Safety):
+    def __init__(self, config: CodeScannerConfig, deps) -> None:
+        self.config = config
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def register_shield(self, shield: ShieldDef) -> None:
+        if shield.type != ShieldType.code_scanner.value:
+            raise ValueError(f"Unsupported safety shield type: {shield.type}")
+
+    async def run_shield(
+        self,
+        shield_type: str,
+        messages: List[Message],
+        params: Dict[str, Any] = None,
+    ) -> RunShieldResponse:
+        shield_def = await self.shield_store.get_shield(shield_type)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {shield_type}")
+
+        from codeshield.cs import CodeShield
+
+        text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
+        cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
+        result = await CodeShield.scan_code(text)
+
+        violation = None
+        if result.is_insecure:
+            violation = SafetyViolation(
+                violation_level=(ViolationLevel.ERROR),
+                user_message="Sorry, I found security concerns in the code.",
+                metadata={
+                    "violation_type": ",".join(
+                        [issue.pattern_id for issue in result.issues_found]
+                    )
+                },
+            )
+        return RunShieldResponse(violation=violation)
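Aside (not part of the diff): the scan itself uses the same `CodeShield.scan_code` call shown above. A standalone sketch of that call, assuming the `codeshield` package is installed; the sample snippet being scanned is arbitrary.

```python
import asyncio


async def scan(text: str) -> None:
    # Same API as used by the provider above (requires the codeshield package).
    from codeshield.cs import CodeShield

    result = await CodeShield.scan_code(text)
    if result.is_insecure:
        print("insecure:", [issue.pattern_id for issue in result.issues_found])
    else:
        print("no issues found")


asyncio.run(scan("import hashlib\nhashlib.md5(b'data')"))
```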
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class CodeShieldConfig(BaseModel):
+    pass
@@ -43,13 +43,12 @@ class MetaReferenceEvalsImpl(Evals):
         print("generation start")
         for msg in x1[:5]:
             print("generation for msg: ", msg)
-            response = self.inference_api.chat_completion(
+            response = await self.inference_api.chat_completion(
                 model=model,
                 messages=[msg],
                 stream=False,
             )
-            async for x in response:
-                generation_outputs.append(x.completion_message.content)
+            generation_outputs.append(response.completion_message.content)
 
         x2 = task_impl.postprocess(generation_outputs)
         eval_results = task_impl.score(x2)
@@ -297,7 +297,7 @@ class Llama:
                     token=next_token[0].item(),
                     text=self.tokenizer.decode(next_token.tolist()),
                     logprobs=(
-                        token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
+                        token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
                         if logprobs
                         else None
                     ),
@@ -6,15 +6,14 @@
 
 import asyncio
 
-from typing import AsyncIterator, List, Union
+from typing import AsyncGenerator, List
 
 from llama_models.sku_list import resolve_model
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
-from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
 )
 
 from .config import MetaReferenceImplConfig
@@ -25,7 +24,7 @@ from .model_parallel import LlamaModelParallelGenerator
 SEMAPHORE = asyncio.Semaphore(1)
 
 
-class MetaReferenceInferenceImpl(Inference, RoutableProvider):
+class MetaReferenceInferenceImpl(Inference):
     def __init__(self, config: MetaReferenceImplConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
@@ -35,21 +34,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
         # verify that the checkpoint actually is for this model lol
 
     async def initialize(self) -> None:
+        print(f"Loading model `{self.model.descriptor()}`")
         self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        assert (
-            len(routing_keys) == 1
-        ), f"Only one routing key is supported {routing_keys}"
-        assert routing_keys[0] == self.config.model
+    async def register_model(self, model: ModelDef) -> None:
+        if model.identifier != self.model.descriptor():
+            raise RuntimeError(
+                f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
+            )
 
     async def shutdown(self) -> None:
         self.generator.stop()
 
-    # hm, when stream=False, we should not be doing SSE :/ which is what the
-    # top-level server is going to do. make the typing more specific here
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -59,9 +57,10 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncIterator[
-        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
-    ]:
+    ) -> AsyncGenerator:
+        if logprobs:
+            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
@@ -74,7 +73,6 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
             logprobs=logprobs,
         )
 
-        messages = augment_messages_for_tools(request)
         model = resolve_model(request.model)
         if model is None:
             raise RuntimeError(
@@ -88,21 +86,74 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
         if SEMAPHORE.locked():
             raise RuntimeError("Only one concurrent request is supported")
 
+        if request.stream:
+            return self._stream_chat_completion(request)
+        else:
+            return self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         async with SEMAPHORE:
-            if request.stream:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.start,
-                        delta="",
-                    )
-                )
+            messages = chat_completion_request_to_messages(request)
 
             tokens = []
             logprobs = []
 
             stop_reason = None
 
-            buffer = ""
+            for token_result in self.generator.chat_completion(
+                messages=messages,
+                temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+                tool_prompt_format=request.tool_prompt_format,
+            ):
+                tokens.append(token_result.token)
+
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
+
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+            return ChatCompletionResponse(
+                completion_message=message,
+                logprobs=logprobs if request.logprobs else None,
+            )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        async with SEMAPHORE:
+            messages = chat_completion_request_to_messages(request)
+
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.start,
+                    delta="",
+                )
+            )
+
+            tokens = []
+            logprobs = []
+            stop_reason = None
             ipython = False
 
             for token_result in self.generator.chat_completion(
@@ -113,10 +164,9 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
                 logprobs=request.logprobs,
                 tool_prompt_format=request.tool_prompt_format,
             ):
-                buffer += token_result.text
                 tokens.append(token_result.token)
 
-                if not ipython and buffer.startswith("<|python_tag|>"):
+                if not ipython and token_result.text.startswith("<|python_tag|>"):
                     ipython = True
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
@@ -127,13 +177,6 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
                             ),
                         )
                     )
-                    buffer = buffer[len("<|python_tag|>") :]
-                    continue
-
-                if not request.stream:
-                    if request.logprobs:
-                        logprobs.append(token_result.logprob)
-
                     continue
 
                 if token_result.text == "<|eot_id|>":
@@ -154,59 +197,61 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
                     delta = text
 
                 if stop_reason is None:
+                    if request.logprobs:
+                        assert len(token_result.logprobs) == 1
+
+                        logprobs.append(
+                            TokenLogProbs(
+                                logprobs_by_token={
+                                    token_result.text: token_result.logprobs[0]
+                                }
+                            )
+                        )
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
                             delta=delta,
                             stop_reason=stop_reason,
+                            logprobs=logprobs if request.logprobs else None,
                         )
                     )
 
             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens
 
-            # TODO(ashwin): parse tool calls separately here and report errors?
-            # if someone breaks the iteration before coming here we are toast
             message = self.generator.formatter.decode_assistant_message(
                 tokens, stop_reason
             )
-            if request.stream:
-                parsed_tool_calls = len(message.tool_calls) > 0
-                if ipython and not parsed_tool_calls:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.failure,
-                            ),
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-                for tool_call in message.tool_calls:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content=tool_call,
-                                parse_status=ToolCallParseStatus.success,
-                            ),
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.complete,
-                        delta="",
-                        stop_reason=stop_reason,
-                    )
-                )
-
-                # TODO(ashwin): what else do we need to send out here when everything finishes?
-            else:
-                yield ChatCompletionResponse(
-                    completion_message=message,
-                    logprobs=logprobs if request.logprobs else None,
-                )
+
+            parsed_tool_calls = len(message.tool_calls) > 0
+            if ipython and not parsed_tool_calls:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content="",
+                            parse_status=ToolCallParseStatus.failure,
+                        ),
+                        stop_reason=stop_reason,
+                    )
+                )
+
+            for tool_call in message.tool_calls:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content=tool_call,
+                            parse_status=ToolCallParseStatus.success,
+                        ),
+                        stop_reason=stop_reason,
+                    )
+                )
+
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta="",
+                    stop_reason=stop_reason,
+                )
+            )
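Aside (not part of the diff): the split above means `chat_completion` hands back either something to `await` (non-streaming) or something to iterate with `async for` (streaming), without awaiting anything itself. A small self-contained sketch of that calling convention, with simplified stand-in return values.

```python
import asyncio
from typing import AsyncGenerator


async def _nonstream() -> str:
    return "full response"


async def _stream() -> AsyncGenerator:
    for piece in ("full ", "response"):
        yield piece


def chat_completion(stream: bool = False):
    # Mirror the dispatch pattern: return a coroutine or an async generator.
    return _stream() if stream else _nonstream()


async def main() -> None:
    print(await chat_completion(stream=False))      # awaited, one response object
    async for piece in chat_completion(stream=True):  # iterated, many chunks
        print(piece, end="")
    print()


asyncio.run(main())
```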
@@ -13,15 +13,15 @@ from typing import Optional
 import torch
 
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from llama_models.llama3.api.model import Transformer, TransformerBlock
+
+from llama_models.datatypes import CheckpointQuantizationFormat
+from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from termcolor import cprint
 from torch import Tensor
 
 from llama_stack.apis.inference import QuantizationType
 
-from llama_stack.apis.inference.config import (
-    CheckpointQuantizationFormat,
+from llama_stack.providers.impls.meta_reference.inference.config import (
     MetaReferenceImplConfig,
 )
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import logging
-import uuid
 
 from typing import Any, Dict, List, Optional
 
@@ -14,7 +13,6 @@ import numpy as np
 from numpy.typing import NDArray
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.providers.utils.memory.vector_store import (
@@ -63,7 +61,7 @@ class FaissIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class FaissMemoryImpl(Memory, RoutableProvider):
+class FaissMemoryImpl(Memory):
     def __init__(self, config: FaissImplConfig) -> None:
         self.config = config
         self.cache = {}
@@ -72,37 +70,18 @@ class FaissMemoryImpl(Memory, RoutableProvider):
 
     async def shutdown(self) -> None: ...
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
-        pass
-
-    async def create_memory_bank(
+    async def register_memory_bank(
         self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        assert url is None, "URL is not supported for this implementation"
+        memory_bank: MemoryBankDef,
+    ) -> None:
         assert (
-            config.type == MemoryBankType.vector.value
-        ), f"Only vector banks are supported {config.type}"
+            memory_bank.type == MemoryBankType.vector.value
+        ), f"Only vector banks are supported {memory_bank.type}"
 
-        bank_id = str(uuid.uuid4())
-        bank = MemoryBank(
-            bank_id=bank_id,
-            name=name,
-            config=config,
-            url=url,
+        index = BankWithIndex(
+            bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
         )
-        index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
-        self.cache[bank_id] = index
-        return bank
-
-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        index = self.cache.get(bank_id)
-        if index is None:
-            return None
-        return index.bank
+        self.cache[memory_bank.identifier] = index
 
     async def insert_documents(
         self,
@@ -44,7 +44,6 @@ def message_content_as_str(message: Message) -> str:
     return interleaved_text_media_as_str(message.content)
 
 
-# For shields that operate on simple strings
 class TextShield(ShieldBase):
     def convert_messages_to_text(self, messages: List[Message]) -> str:
         return "\n".join([message_content_as_str(m) for m in messages])
@@ -56,9 +55,3 @@ class TextShield(ShieldBase):
     @abstractmethod
     async def run_impl(self, text: str) -> ShieldResponse:
         raise NotImplementedError()
-
-
-class DummyShield(TextShield):
-    async def run_impl(self, text: str) -> ShieldResponse:
-        # Dummy return LOW to test e2e
-        return ShieldResponse(is_violation=False)
@@ -9,23 +9,19 @@ from typing import List, Optional
 
 from llama_models.sku_list import CoreModelId, safety_models
 
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, field_validator
 
 
-class MetaReferenceShieldType(Enum):
-    llama_guard = "llama_guard"
-    code_scanner_guard = "code_scanner_guard"
-    injection_shield = "injection_shield"
-    jailbreak_shield = "jailbreak_shield"
+class PromptGuardType(Enum):
+    injection = "injection"
+    jailbreak = "jailbreak"
 
 
 class LlamaGuardShieldConfig(BaseModel):
     model: str = "Llama-Guard-3-1B"
     excluded_categories: List[str] = []
-    disable_input_check: bool = False
-    disable_output_check: bool = False
 
-    @validator("model")
+    @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = [
@@ -47,10 +43,6 @@ class LlamaGuardShieldConfig(BaseModel):
         return model
 
 
-class PromptGuardShieldConfig(BaseModel):
-    model: str = "Prompt-Guard-86M"
-
-
 class SafetyConfig(BaseModel):
     llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
-    prompt_guard_shield: Optional[PromptGuardShieldConfig] = None
+    enable_prompt_guard: Optional[bool] = False
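Aside (not part of the diff): the new config shape drops the dedicated Prompt Guard config in favor of a boolean toggle. A minimal sketch of constructing it, mirroring the fields shown above (pydantic v2 assumed; the model validator is omitted here for brevity, so this is not the stack's exact class).

```python
from typing import List, Optional

from pydantic import BaseModel


class LlamaGuardShieldConfig(BaseModel):
    model: str = "Llama-Guard-3-1B"
    excluded_categories: List[str] = []


class SafetyConfig(BaseModel):
    llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
    enable_prompt_guard: Optional[bool] = False


cfg = SafetyConfig(
    llama_guard_shield=LlamaGuardShieldConfig(model="Llama-Guard-3-1B"),
    enable_prompt_guard=True,
)
print(cfg.model_dump())
```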
@@ -113,8 +113,6 @@ class LlamaGuardShield(ShieldBase):
         model: str,
         inference_api: Inference,
         excluded_categories: List[str] = None,
-        disable_input_check: bool = False,
-        disable_output_check: bool = False,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
     ):
         super().__init__(on_violation_action)
@@ -132,8 +130,6 @@ class LlamaGuardShield(ShieldBase):
         self.model = model
         self.inference_api = inference_api
         self.excluded_categories = excluded_categories
-        self.disable_input_check = disable_input_check
-        self.disable_output_check = disable_output_check
 
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
@@ -180,12 +176,6 @@ class LlamaGuardShield(ShieldBase):
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
         messages = self.validate_messages(messages)
-        if self.disable_input_check and messages[-1].role == Role.user.value:
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
-            return ShieldResponse(
-                is_violation=False,
-            )
 
         if self.model == CoreModelId.llama_guard_3_11b_vision.value:
             shield_input_message = self.build_vision_shield_input(messages)
@@ -6,56 +6,43 @@
 
 from typing import Any, Dict, List
 
-from llama_models.sku_list import resolve_model
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import Api, RoutableProvider
+from llama_stack.distribution.datatypes import Api
 
-from llama_stack.providers.impls.meta_reference.safety.shields.base import (
-    OnViolationAction,
-)
-
-from .config import MetaReferenceShieldType, SafetyConfig
-
-from .shields import (
-    CodeScannerShield,
-    InjectionShield,
-    JailbreakShield,
-    LlamaGuardShield,
-    PromptGuardShield,
-    ShieldBase,
-)
+from .base import OnViolationAction, ShieldBase
+from .config import SafetyConfig
+from .llama_guard import LlamaGuardShield
+from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
 
 
-def resolve_and_get_path(model_name: str) -> str:
-    model = resolve_model(model_name)
-    assert model is not None, f"Could not resolve model {model_name}"
-    model_dir = model_local_dir(model.descriptor())
-    return model_dir
+PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
 
 
-class MetaReferenceSafetyImpl(Safety, RoutableProvider):
+class MetaReferenceSafetyImpl(Safety):
     def __init__(self, config: SafetyConfig, deps) -> None:
         self.config = config
         self.inference_api = deps[Api.inference]
 
+        self.available_shields = []
+        if config.llama_guard_shield:
+            self.available_shields.append(ShieldType.llama_guard.value)
+        if config.enable_prompt_guard:
+            self.available_shields.append(ShieldType.prompt_guard.value)
+
     async def initialize(self) -> None:
-        shield_cfg = self.config.prompt_guard_shield
-        if shield_cfg is not None:
-            model_dir = resolve_and_get_path(shield_cfg.model)
+        if self.config.enable_prompt_guard:
+            model_dir = model_local_dir(PROMPT_GUARD_MODEL)
             _ = PromptGuardShield.instance(model_dir)
 
     async def shutdown(self) -> None:
         pass
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        available_shields = [v.value for v in MetaReferenceShieldType]
-        for key in routing_keys:
-            if key not in available_shields:
-                raise ValueError(f"Unknown safety shield type: {key}")
+    async def register_shield(self, shield: ShieldDef) -> None:
+        if shield.type not in self.available_shields:
+            raise ValueError(f"Unsupported safety shield type: {shield.type}")
 
     async def run_shield(
         self,
@@ -63,10 +50,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
-        available_shields = [v.value for v in MetaReferenceShieldType]
-        assert shield_type in available_shields, f"Unknown shield {shield_type}"
+        shield_def = await self.shield_store.get_shield(shield_type)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {shield_type}")
 
-        shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))
+        shield = self.get_shield_impl(shield_def)
 
         messages = messages.copy()
         # some shields like llama-guard require the first message to be a user message
@@ -92,34 +80,22 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
 
         return RunShieldResponse(violation=violation)
 
-    def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
-        cfg = self.config
-        if typ == MetaReferenceShieldType.llama_guard:
-            cfg = cfg.llama_guard_shield
-            assert (
-                cfg is not None
-            ), "Cannot use LlamaGuardShield since not present in config"
-
+    def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
+        if shield.type == ShieldType.llama_guard.value:
+            cfg = self.config.llama_guard_shield
             return LlamaGuardShield(
                 model=cfg.model,
                 inference_api=self.inference_api,
                 excluded_categories=cfg.excluded_categories,
-                disable_input_check=cfg.disable_input_check,
-                disable_output_check=cfg.disable_output_check,
             )
-        elif typ == MetaReferenceShieldType.jailbreak_shield:
-            assert (
-                cfg.prompt_guard_shield is not None
-            ), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
-            model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
-            return JailbreakShield.instance(model_dir)
-        elif typ == MetaReferenceShieldType.injection_shield:
-            assert (
-                cfg.prompt_guard_shield is not None
-            ), "Cannot use PromptGuardShield since not present in config"
-            model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
-            return InjectionShield.instance(model_dir)
-        elif typ == MetaReferenceShieldType.code_scanner_guard:
-            return CodeScannerShield.instance()
+        elif shield.type == ShieldType.prompt_guard.value:
+            model_dir = model_local_dir(PROMPT_GUARD_MODEL)
+            subtype = shield.params.get("prompt_guard_type", "injection")
+            if subtype == "injection":
+                return InjectionShield.instance(model_dir)
+            elif subtype == "jailbreak":
+                return JailbreakShield.instance(model_dir)
+            else:
+                raise ValueError(f"Unknown prompt guard type: {subtype}")
         else:
-            raise ValueError(f"Unknown shield type: {typ}")
+            raise ValueError(f"Unknown shield type: {shield.type}")
@@ -1,33 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-# supress warnings and spew of logs from hugging face
-import transformers
-
-from .base import (  # noqa: F401
-    DummyShield,
-    OnViolationAction,
-    ShieldBase,
-    ShieldResponse,
-    TextShield,
-)
-from .code_scanner import CodeScannerShield  # noqa: F401
-from .llama_guard import LlamaGuardShield  # noqa: F401
-from .prompt_guard import (  # noqa: F401
-    InjectionShield,
-    JailbreakShield,
-    PromptGuardShield,
-)
-
-transformers.logging.set_verbosity_error()
-
-import os
-
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-import warnings
-
-warnings.filterwarnings("ignore")
@@ -1,27 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from termcolor import cprint
-
-from .base import ShieldResponse, TextShield
-
-
-class CodeScannerShield(TextShield):
-    async def run_impl(self, text: str) -> ShieldResponse:
-        from codeshield.cs import CodeShield
-
-        cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
-        result = await CodeShield.scan_code(text)
-        if result.is_insecure:
-            return ShieldResponse(
-                is_violation=True,
-                violation_type=",".join(
-                    [issue.pattern_id for issue in result.issues_found]
-                ),
-                violation_return_message="Sorry, I found security concerns in the code.",
-            )
-        else:
-            return ShieldResponse(is_violation=False)
llama_stack/providers/impls/vllm/__init__.py (new file, 11 lines)
@ -0,0 +1,11 @@
from typing import Any

from .config import VLLMConfig


async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
    from .vllm import VLLMInferenceImpl

    impl = VLLMInferenceImpl(config)
    await impl.initialize()
    return impl
llama_stack/providers/impls/vllm/config.py (new file, 35 lines)
@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, field_validator

from llama_stack.providers.utils.inference import supported_inference_models


@json_schema_type
class VLLMConfig(BaseModel):
    """Configuration for the vLLM inference provider."""

    model: str = Field(
        default="Llama3.1-8B-Instruct",
        description="Model descriptor from `llama model list`",
    )
    tensor_parallel_size: int = Field(
        default=1,
        description="Number of tensor parallel replicas (number of GPUs to use).",
    )

    @field_validator("model")
    @classmethod
    def validate_model(cls, model: str) -> str:
        permitted_models = supported_inference_models()
        if model not in permitted_models:
            model_list = "\n\t".join(permitted_models)
            raise ValueError(
                f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
            )
        return model
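To make the validator's behaviour concrete, here is a small usage sketch. It assumes the `llama_stack` package introduced by this diff is installed and that `Llama3.2-3B-Instruct` appears in `supported_inference_models()`; treat the model names as illustrative rather than a guaranteed-valid list.

```python
# Illustrative only: requires the llama_stack package from this diff.
from pydantic import ValidationError

from llama_stack.providers.impls.vllm.config import VLLMConfig

# A descriptor as shown by `llama model list` (assumed to be supported here).
cfg = VLLMConfig(model="Llama3.2-3B-Instruct", tensor_parallel_size=2)
print(cfg.model, cfg.tensor_parallel_size)

# An unknown descriptor trips the @field_validator and surfaces as a pydantic error.
try:
    VLLMConfig(model="not-a-real-model")
except ValidationError as e:
    print("rejected:", e.errors()[0]["msg"])
```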
llama_stack/providers/impls/vllm/vllm.py (new file, 241 lines)
@ -0,0 +1,241 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import os
import uuid
from typing import Any

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams

from llama_stack.apis.inference import *  # noqa: F403

from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)

from .config import VLLMConfig


log = logging.getLogger(__name__)


def _random_uuid() -> str:
    return str(uuid.uuid4().hex)


def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
    """Convert sampling params to vLLM sampling params."""
    if sampling_params is None:
        return SamplingParams()

    # TODO convert what I saw in my first test ... but surely there's more to do here
    kwargs = {
        "temperature": sampling_params.temperature,
    }
    if sampling_params.top_k >= 1:
        kwargs["top_k"] = sampling_params.top_k
    if sampling_params.top_p:
        kwargs["top_p"] = sampling_params.top_p
    if sampling_params.max_tokens >= 1:
        kwargs["max_tokens"] = sampling_params.max_tokens
    if sampling_params.repetition_penalty > 0:
        kwargs["repetition_penalty"] = sampling_params.repetition_penalty

    return SamplingParams(**kwargs)


class VLLMInferenceImpl(ModelRegistryHelper, Inference):
    """Inference implementation for vLLM."""

    HF_MODEL_MAPPINGS = {
        # TODO: seems like we should be able to build this table dynamically ...
        "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
        "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
        "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
        "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
        "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
        "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
        "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
        "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
        "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
        "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
        "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
        "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
        "Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
        "Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
        "Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
        "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
        "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
        "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
        "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
        "Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
        "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
        "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
        "Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
        "Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
        "Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
    }

    def __init__(self, config: VLLMConfig):
        Inference.__init__(self)
        ModelRegistryHelper.__init__(
            self,
            stack_to_provider_models_map=self.HF_MODEL_MAPPINGS,
        )
        self.config = config
        self.engine = None

        tokenizer = Tokenizer.get_instance()
        self.formatter = ChatFormat(tokenizer)

    async def initialize(self):
        """Initialize the vLLM inference adapter."""

        log.info("Initializing vLLM inference adapter")

        # Disable usage stats reporting. This would be a surprising thing for most
        # people to find out was on by default.
        # https://docs.vllm.ai/en/latest/serving/usage_stats.html
        if "VLLM_NO_USAGE_STATS" not in os.environ:
            os.environ["VLLM_NO_USAGE_STATS"] = "1"

        hf_model = self.HF_MODEL_MAPPINGS.get(self.config.model)

        # TODO -- there are a ton of options supported here ...
        engine_args = AsyncEngineArgs()
        engine_args.model = hf_model
        # We will need a new config item for this in the future if model support is more broad
        # than it is today (llama only)
        engine_args.tokenizer = hf_model
        engine_args.tensor_parallel_size = self.config.tensor_parallel_size

        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

    async def shutdown(self):
        """Shutdown the vLLM inference adapter."""
        log.info("Shutting down vLLM inference adapter")
        if self.engine:
            self.engine.shutdown_background_loop()

    def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Any | None = ...,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
    ) -> CompletionResponse | CompletionResponseStreamChunk:
        log.info("vLLM completion")
        messages = [UserMessage(content=content)]
        return self.chat_completion(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            stream=stream,
            logprobs=logprobs,
        )

    def chat_completion(
        self,
        model: str,
        messages: list[Message],
        sampling_params: Any | None = ...,
        tools: list[ToolDefinition] | None = ...,
        tool_choice: ToolChoice | None = ...,
        tool_prompt_format: ToolPromptFormat | None = ...,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
    ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
        log.info("vLLM chat completion")

        assert self.engine is not None

        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        log.info("Sampling params: %s", sampling_params)
        request_id = _random_uuid()

        prompt = chat_completion_request_to_prompt(request, self.formatter)
        vllm_sampling_params = _vllm_sampling_params(request.sampling_params)
        results_generator = self.engine.generate(
            prompt, vllm_sampling_params, request_id
        )
        if stream:
            return self._stream_chat_completion(request, results_generator)
        else:
            return self._nonstream_chat_completion(request, results_generator)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, results_generator: AsyncGenerator
    ) -> ChatCompletionResponse:
        outputs = [o async for o in results_generator]
        final_output = outputs[-1]

        assert final_output is not None
        outputs = final_output.outputs
        finish_reason = outputs[-1].stop_reason
        choice = OpenAICompatCompletionChoice(
            finish_reason=finish_reason,
            text="".join([output.text for output in outputs]),
        )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
        return process_chat_completion_response(request, response, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, results_generator: AsyncGenerator
    ) -> AsyncGenerator:
        async def _generate_and_convert_to_openai_compat():
            async for chunk in results_generator:
                if not chunk.outputs:
                    log.warning("Empty chunk received")
                    continue

                text = "".join([output.text for output in chunk.outputs])
                choice = OpenAICompatCompletionChoice(
                    finish_reason=chunk.outputs[-1].stop_reason,
                    text=text,
                )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(
            request, stream, self.formatter
        ):
            yield chunk

    async def embeddings(
        self, model: str, contents: list[InterleavedTextMedia]
    ) -> EmbeddingsResponse:
        log.info("vLLM embeddings")
        # TODO
        raise NotImplementedError()
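The `_vllm_sampling_params` helper above only forwards a handful of fields into vLLM. The following standalone sketch mirrors that translation with a plain dataclass standing in for the Llama Stack sampling-params type (the dataclass and its defaults are assumptions for illustration); the returned dict is what would be splatted into `vllm.sampling_params.SamplingParams(**kwargs)`.

```python
# Standalone sketch of the field translation performed by _vllm_sampling_params.
from dataclasses import dataclass


@dataclass
class StackSamplingParams:  # stand-in for the Llama Stack type; fields assumed
    temperature: float = 0.0
    top_k: int = 0
    top_p: float = 1.0
    max_tokens: int = 0
    repetition_penalty: float = 1.0


def to_vllm_kwargs(p: StackSamplingParams | None) -> dict:
    if p is None:
        return {}
    kwargs = {"temperature": p.temperature}
    if p.top_k >= 1:                      # only forward a real top_k
        kwargs["top_k"] = p.top_k
    if p.top_p:                           # skip falsy (0 / None) top_p
        kwargs["top_p"] = p.top_p
    if p.max_tokens >= 1:
        kwargs["max_tokens"] = p.max_tokens
    if p.repetition_penalty > 0:
        kwargs["repetition_penalty"] = p.repetition_penalty
    return kwargs


print(to_vllm_kwargs(StackSamplingParams(temperature=0.7, top_p=0.9, max_tokens=128)))
```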
@ -41,6 +41,7 @@ def available_providers() -> List[ProviderSpec]:
         adapter=AdapterSpec(
             adapter_type="ollama",
             pip_packages=["ollama"],
+            config_class="llama_stack.providers.adapters.inference.ollama.OllamaImplConfig",
             module="llama_stack.providers.adapters.inference.ollama",
         ),
     ),
@ -103,4 +104,24 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="databricks",
+            pip_packages=[
+                "openai",
+            ],
+            module="llama_stack.providers.adapters.inference.databricks",
+            config_class="llama_stack.providers.adapters.inference.databricks.DatabricksImplConfig",
+        ),
+    ),
+    InlineProviderSpec(
+        api=Api.inference,
+        provider_type="vllm",
+        pip_packages=[
+            "vllm",
+        ],
+        module="llama_stack.providers.impls.vllm",
+        config_class="llama_stack.providers.impls.vllm.VLLMConfig",
+    ),
 ]
@ -56,6 +56,16 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
         ),
     ),
+    remote_provider_spec(
+        Api.memory,
+        AdapterSpec(
+            adapter_type="weaviate",
+            pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
+            module="llama_stack.providers.adapters.memory.weaviate",
+            config_class="llama_stack.providers.adapters.memory.weaviate.WeaviateConfig",
+            provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",
+        ),
+    ),
     remote_provider_spec(
         api=Api.memory,
         adapter=AdapterSpec(
@ -21,7 +21,6 @@ def available_providers() -> List[ProviderSpec]:
         api=Api.safety,
         provider_type="meta-reference",
         pip_packages=[
-            "codeshield",
             "transformers",
             "torch --index-url https://download.pytorch.org/whl/cpu",
         ],
@ -61,4 +60,14 @@ def available_providers() -> List[ProviderSpec]:
             provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
         ),
     ),
+    InlineProviderSpec(
+        api=Api.safety,
+        provider_type="meta-reference/codeshield",
+        pip_packages=[
+            "codeshield",
+        ],
+        module="llama_stack.providers.impls.meta_reference.codeshield",
+        config_class="llama_stack.providers.impls.meta_reference.codeshield.CodeShieldConfig",
+        api_dependencies=[],
+    ),
 ]
llama_stack/providers/tests/__init__.py (new file, 5 lines)
@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

llama_stack/providers/tests/agents/__init__.py (new file, 5 lines)
@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

llama_stack/providers/tests/inference/__init__.py (new file, 5 lines)
@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@ -0,0 +1,25 @@
providers:
  - provider_id: test-ollama
    provider_type: remote::ollama
    config:
      host: localhost
      port: 11434
  - provider_id: test-tgi
    provider_type: remote::tgi
    config:
      url: http://localhost:7001
  - provider_id: test-remote
    provider_type: remote
    config:
      host: localhost
      port: 7002
  - provider_id: test-together
    provider_type: remote::together
    config: {}

# if a provider needs private keys from the client, they use the
# "get_request_provider_data" function (see distribution/request_headers.py)
# this is a place to provide such data.
provider_data:
  "test-together":
    together_api_key: 0xdeadbeefputrealapikeyhere
llama_stack/providers/tests/inference/test_inference.py (new file, 243 lines)
@ -0,0 +1,243 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import itertools

import pytest
import pytest_asyncio

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403

from llama_stack.distribution.datatypes import *  # noqa: F403
from llama_stack.providers.tests.resolver import resolve_impls_for_test

# How to run this test:
#
# 1. Ensure you have a conda env with the right dependencies installed. This is a bit tricky
#    since it depends on the provider you are testing. On top of that you need
#    `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
#   PROVIDER_CONFIG=provider_config.yaml \
#   pytest -s llama_stack/providers/tests/inference/test_inference.py \
#   --tb=short --disable-warnings
# ```


def group_chunks(response):
    return {
        event_type: list(group)
        for event_type, group in itertools.groupby(
            response, key=lambda chunk: chunk.event.event_type
        )
    }


Llama_8B = "Llama3.1-8B-Instruct"
Llama_3B = "Llama3.2-3B-Instruct"


def get_expected_stop_reason(model: str):
    return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn


# This is going to create multiple Stack impls without tearing down the previous one
# Fix that!
@pytest_asyncio.fixture(
    scope="session",
    params=[
        {"model": Llama_8B},
        {"model": Llama_3B},
    ],
    ids=lambda d: d["model"],
)
async def inference_settings(request):
    model = request.param["model"]
    impls = await resolve_impls_for_test(
        Api.inference,
        models=[
            ModelDef(
                identifier=model,
                llama_model=model,
            )
        ],
    )

    return {
        "impl": impls[Api.inference],
        "common_params": {
            "model": model,
            "tool_choice": ToolChoice.auto,
            "tool_prompt_format": (
                ToolPromptFormat.json
                if "Llama3.1" in model
                else ToolPromptFormat.python_list
            ),
        },
    }


@pytest.fixture
def sample_messages():
    return [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What's the weather like today?"),
    ]


@pytest.fixture
def sample_tool_definition():
    return ToolDefinition(
        tool_name="get_weather",
        description="Get the current weather",
        parameters={
            "location": ToolParamDefinition(
                param_type="string",
                description="The city and state, e.g. San Francisco, CA",
            ),
        },
    )


@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
    inference_impl = inference_settings["impl"]
    response = await inference_impl.chat_completion(
        messages=sample_messages,
        stream=False,
        **inference_settings["common_params"],
    )

    assert isinstance(response, ChatCompletionResponse)
    assert response.completion_message.role == "assistant"
    assert isinstance(response.completion_message.content, str)
    assert len(response.completion_message.content) > 0


@pytest.mark.asyncio
async def test_chat_completion_streaming(inference_settings, sample_messages):
    inference_impl = inference_settings["impl"]
    response = [
        r
        async for r in inference_impl.chat_completion(
            messages=sample_messages,
            stream=True,
            **inference_settings["common_params"],
        )
    ]

    assert len(response) > 0
    assert all(
        isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
    )
    grouped = group_chunks(response)
    assert len(grouped[ChatCompletionResponseEventType.start]) == 1
    assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
    assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

    end = grouped[ChatCompletionResponseEventType.complete][0]
    assert end.event.stop_reason == StopReason.end_of_turn


@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling(
    inference_settings,
    sample_messages,
    sample_tool_definition,
):
    inference_impl = inference_settings["impl"]
    messages = sample_messages + [
        UserMessage(
            content="What's the weather like in San Francisco?",
        )
    ]

    response = await inference_impl.chat_completion(
        messages=messages,
        tools=[sample_tool_definition],
        stream=False,
        **inference_settings["common_params"],
    )

    assert isinstance(response, ChatCompletionResponse)

    message = response.completion_message

    # This is not supported in most providers :/ they don't return eom_id / eot_id
    # stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
    # assert message.stop_reason == stop_reason
    assert message.tool_calls is not None
    assert len(message.tool_calls) > 0

    call = message.tool_calls[0]
    assert call.tool_name == "get_weather"
    assert "location" in call.arguments
    assert "San Francisco" in call.arguments["location"]


@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling_streaming(
    inference_settings,
    sample_messages,
    sample_tool_definition,
):
    inference_impl = inference_settings["impl"]
    messages = sample_messages + [
        UserMessage(
            content="What's the weather like in San Francisco?",
        )
    ]

    response = [
        r
        async for r in inference_impl.chat_completion(
            messages=messages,
            tools=[sample_tool_definition],
            stream=True,
            **inference_settings["common_params"],
        )
    ]

    assert len(response) > 0
    assert all(
        isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
    )
    grouped = group_chunks(response)
    assert len(grouped[ChatCompletionResponseEventType.start]) == 1
    assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
    assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

    # This is not supported in most providers :/ they don't return eom_id / eot_id
    # expected_stop_reason = get_expected_stop_reason(
    #     inference_settings["common_params"]["model"]
    # )
    # end = grouped[ChatCompletionResponseEventType.complete][0]
    # assert end.event.stop_reason == expected_stop_reason

    model = inference_settings["common_params"]["model"]
    if "Llama3.1" in model:
        assert all(
            isinstance(chunk.event.delta, ToolCallDelta)
            for chunk in grouped[ChatCompletionResponseEventType.progress]
        )
        first = grouped[ChatCompletionResponseEventType.progress][0]
        assert first.event.delta.parse_status == ToolCallParseStatus.started

    last = grouped[ChatCompletionResponseEventType.progress][-1]
    # assert last.event.stop_reason == expected_stop_reason
    assert last.event.delta.parse_status == ToolCallParseStatus.success
    assert isinstance(last.event.delta.content, ToolCall)

    call = last.event.delta.content
    assert call.tool_name == "get_weather"
    assert "location" in call.arguments
    assert "San Francisco" in call.arguments["location"]
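The streaming tests above lean on `group_chunks` to bucket stream chunks by event type, which only works because events arrive in order (start, progress..., complete) and `itertools.groupby` groups consecutive items. A self-contained sketch with namedtuples standing in for the real chunk and event classes (those stand-ins are assumptions) shows why the tests expect exactly one `start` and one `complete` group:

```python
# Standalone illustration of group_chunks; Chunk/Event are stand-ins for the
# ChatCompletionResponseStreamChunk and event classes used in the real tests.
import itertools
from collections import namedtuple

Event = namedtuple("Event", "event_type")
Chunk = namedtuple("Chunk", "event")


def group_chunks(response):
    return {
        event_type: list(group)
        for event_type, group in itertools.groupby(
            response, key=lambda chunk: chunk.event.event_type
        )
    }


stream = (
    [Chunk(Event("start"))] + [Chunk(Event("progress"))] * 3 + [Chunk(Event("complete"))]
)
grouped = group_chunks(stream)
assert len(grouped["start"]) == 1
assert len(grouped["progress"]) == 3
assert len(grouped["complete"]) == 1
```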
@ -8,7 +8,7 @@ import unittest

 from llama_models.llama3.api import *  # noqa: F403
 from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.augment_messages import augment_messages_for_tools
+from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages

 MODEL = "Llama3.1-8B-Instruct"

@ -22,7 +22,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 UserMessage(content=content),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@ -39,7 +39,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.brave_search),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@ -67,7 +67,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
             ],
             tool_prompt_format=ToolPromptFormat.json,
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)

@ -97,7 +97,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 3)

         self.assertTrue("Environment: ipython" in messages[0].content)
@ -119,7 +119,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
             ],
         )
-        messages = augment_messages_for_tools(request)
+        messages = chat_completion_request_to_messages(request)
         self.assertEqual(len(messages), 2, messages)
         self.assertTrue(messages[0].content.endswith(system_prompt))
llama_stack/providers/tests/memory/__init__.py (new file, 5 lines)
@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@ -0,0 +1,24 @@
providers:
  - provider_id: test-faiss
    provider_type: meta-reference
    config: {}
  - provider_id: test-chroma
    provider_type: remote::chroma
    config:
      host: localhost
      port: 6001
  - provider_id: test-remote
    provider_type: remote
    config:
      host: localhost
      port: 7002
  - provider_id: test-weaviate
    provider_type: remote::weaviate
    config: {}

# if a provider needs private keys from the client, they use the
# "get_request_provider_data" function (see distribution/request_headers.py)
# this is a place to provide such data.
provider_data:
  "test-weaviate":
    weaviate_api_key: 0xdeadbeefputrealapikeyhere
    weaviate_cluster_url: http://foobarbaz
llama_stack/providers/tests/memory/test_memory.py (new file, 119 lines)
@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
import pytest_asyncio

from llama_stack.apis.memory import *  # noqa: F403
from llama_stack.distribution.datatypes import *  # noqa: F403
from llama_stack.providers.tests.resolver import resolve_impls_for_test

# How to run this test:
#
# 1. Ensure you have a conda env with the right dependencies installed. This is a bit tricky
#    since it depends on the provider you are testing. On top of that you need
#    `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
#   PROVIDER_CONFIG=provider_config.yaml \
#   pytest -s llama_stack/providers/tests/memory/test_memory.py \
#   --tb=short --disable-warnings
# ```


@pytest_asyncio.fixture(scope="session")
async def memory_impl():
    impls = await resolve_impls_for_test(
        Api.memory,
        memory_banks=[],
    )
    return impls[Api.memory]


@pytest.fixture
def sample_documents():
    return [
        MemoryBankDocument(
            document_id="doc1",
            content="Python is a high-level programming language.",
            metadata={"category": "programming", "difficulty": "beginner"},
        ),
        MemoryBankDocument(
            document_id="doc2",
            content="Machine learning is a subset of artificial intelligence.",
            metadata={"category": "AI", "difficulty": "advanced"},
        ),
        MemoryBankDocument(
            document_id="doc3",
            content="Data structures are fundamental to computer science.",
            metadata={"category": "computer science", "difficulty": "intermediate"},
        ),
        MemoryBankDocument(
            document_id="doc4",
            content="Neural networks are inspired by biological neural networks.",
            metadata={"category": "AI", "difficulty": "advanced"},
        ),
    ]


async def register_memory_bank(memory_impl: Memory):
    bank = VectorMemoryBankDef(
        identifier="test_bank",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
        overlap_size_in_tokens=64,
    )

    await memory_impl.register_memory_bank(bank)


@pytest.mark.asyncio
async def test_query_documents(memory_impl, sample_documents):
    with pytest.raises(ValueError):
        await memory_impl.insert_documents("test_bank", sample_documents)

    await register_memory_bank(memory_impl)
    await memory_impl.insert_documents("test_bank", sample_documents)

    query1 = "programming language"
    response1 = await memory_impl.query_documents("test_bank", query1)
    assert_valid_response(response1)
    assert any("Python" in chunk.content for chunk in response1.chunks)

    # Test case 3: Query with semantic similarity
    query3 = "AI and brain-inspired computing"
    response3 = await memory_impl.query_documents("test_bank", query3)
    assert_valid_response(response3)
    assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)

    # Test case 4: Query with limit on number of results
    query4 = "computer"
    params4 = {"max_chunks": 2}
    response4 = await memory_impl.query_documents("test_bank", query4, params4)
    assert_valid_response(response4)
    assert len(response4.chunks) <= 2

    # Test case 5: Query with threshold on similarity score
    query5 = "quantum computing"  # Not directly related to any document
    params5 = {"score_threshold": 0.5}
    response5 = await memory_impl.query_documents("test_bank", query5, params5)
    assert_valid_response(response5)
    assert all(score >= 0.5 for score in response5.scores)


def assert_valid_response(response: QueryDocumentsResponse):
    assert isinstance(response, QueryDocumentsResponse)
    assert len(response.chunks) > 0
    assert len(response.scores) > 0
    assert len(response.chunks) == len(response.scores)
    for chunk in response.chunks:
        assert isinstance(chunk.content, str)
        assert chunk.document_id is not None
llama_stack/providers/tests/resolver.py (new file, 100 lines)
@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import os
from datetime import datetime

import yaml

from llama_stack.distribution.datatypes import *  # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing


async def resolve_impls_for_test(
    api: Api,
    models: List[ModelDef] = None,
    memory_banks: List[MemoryBankDef] = None,
    shields: List[ShieldDef] = None,
):
    if "PROVIDER_CONFIG" not in os.environ:
        raise ValueError(
            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
        )

    with open(os.environ["PROVIDER_CONFIG"], "r") as f:
        config_dict = yaml.safe_load(f)

    if "providers" not in config_dict:
        raise ValueError("Config file should contain a `providers` key")

    providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
    if len(providers_by_id) == 0:
        raise ValueError("No providers found in config file")

    if "PROVIDER_ID" in os.environ:
        provider_id = os.environ["PROVIDER_ID"]
        if provider_id not in providers_by_id:
            raise ValueError(f"Provider ID {provider_id} not found in config file")
        provider = providers_by_id[provider_id]
    else:
        provider = list(providers_by_id.values())[0]
        provider_id = provider["provider_id"]
        print(f"No provider ID specified, picking first `{provider_id}`")

    models = models or []
    shields = shields or []
    memory_banks = memory_banks or []

    models = [
        ModelDef(
            **{
                **m.dict(),
                "provider_id": provider_id,
            }
        )
        for m in models
    ]
    shields = [
        ShieldDef(
            **{
                **s.dict(),
                "provider_id": provider_id,
            }
        )
        for s in shields
    ]
    memory_banks = [
        MemoryBankDef(
            **{
                **m.dict(),
                "provider_id": provider_id,
            }
        )
        for m in memory_banks
    ]
    run_config = dict(
        built_at=datetime.now(),
        image_name="test-fixture",
        apis=[api],
        providers={api.value: [Provider(**provider)]},
        models=models,
        memory_banks=memory_banks,
        shields=shields,
    )
    run_config = parse_and_maybe_upgrade_config(run_config)
    impls = await resolve_impls_with_routing(run_config)

    if "provider_data" in config_dict:
        provider_data = config_dict["provider_data"].get(provider_id, {})
        if provider_data:
            set_request_provider_data(
                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
            )

    return impls
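For context, the test fixtures in this diff call `resolve_impls_for_test` roughly as in the sketch below. It assumes the `llama_stack` package from this diff is importable, that `PROVIDER_CONFIG` points at a YAML file shaped like the example configs above, and that a `provider_config.yaml` path exists; all of those are assumptions for illustration.

```python
# Illustrative only: requires the llama_stack package from this diff and a
# PROVIDER_CONFIG (and optionally PROVIDER_ID) environment variable.
import asyncio
import os

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.resolver import resolve_impls_for_test


async def main():
    # Hypothetical config path; in the tests this comes from the environment.
    os.environ.setdefault("PROVIDER_CONFIG", "provider_config.yaml")
    impls = await resolve_impls_for_test(Api.inference, models=[])
    print(list(impls.keys()))  # expect Api.inference among the resolved APIs


if __name__ == "__main__":
    asyncio.run(main())
```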
llama_stack/providers/utils/inference/model_registry.py (new file, 35 lines)
@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Dict

from llama_models.sku_list import resolve_model

from llama_stack.apis.models import *  # noqa: F403


class ModelRegistryHelper:

    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
        self.stack_to_provider_models_map = stack_to_provider_models_map

    def map_to_provider_model(self, identifier: str) -> str:
        model = resolve_model(identifier)
        if not model:
            raise ValueError(f"Unknown model: `{identifier}`")

        if identifier not in self.stack_to_provider_models_map:
            raise ValueError(
                f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
            )

        return self.stack_to_provider_models_map[identifier]

    async def register_model(self, model: ModelDef) -> None:
        if model.identifier not in self.stack_to_provider_models_map:
            raise ValueError(
                f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
            )
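This helper is what lets a provider such as `VLLMInferenceImpl` translate a stack model descriptor (e.g. `Llama3.1-8B-Instruct`) into the provider-side identifier it actually loads (the Hugging Face repo id in `HF_MODEL_MAPPINGS`). The toy sketch below reproduces just the lookup step with a stand-in class and a one-entry map, skipping `resolve_model()` so it stays self-contained; the class name and map contents are assumptions for illustration.

```python
# Standalone sketch of the descriptor -> provider-model lookup done by
# ModelRegistryHelper.map_to_provider_model (sku_list validation omitted).
from typing import Dict


class ToyRegistryHelper:
    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
        self.stack_to_provider_models_map = stack_to_provider_models_map

    def map_to_provider_model(self, identifier: str) -> str:
        if identifier not in self.stack_to_provider_models_map:
            raise ValueError(f"Model {identifier} not found in map")
        return self.stack_to_provider_models_map[identifier]


helper = ToyRegistryHelper({"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct"})
print(helper.map_to_provider_model("Llama3.1-8B-Instruct"))
```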
llama_stack/providers/utils/inference/openai_compat.py (new file, 189 lines)
@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator, Optional

from llama_models.llama3.api.chat_format import ChatFormat

from llama_models.llama3.api.datatypes import StopReason

from llama_stack.apis.inference import *  # noqa: F403

from pydantic import BaseModel


class OpenAICompatCompletionChoiceDelta(BaseModel):
    content: str


class OpenAICompatCompletionChoice(BaseModel):
    finish_reason: Optional[str] = None
    text: Optional[str] = None
    delta: Optional[OpenAICompatCompletionChoiceDelta] = None


class OpenAICompatCompletionResponse(BaseModel):
    choices: List[OpenAICompatCompletionChoice]


def get_sampling_options(request: ChatCompletionRequest) -> dict:
    options = {}
    if params := request.sampling_params:
        for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
            if getattr(params, attr):
                options[attr] = getattr(params, attr)

        if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
            options["repeat_penalty"] = params.repetition_penalty

    return options


def text_from_choice(choice) -> str:
    if hasattr(choice, "delta") and choice.delta:
        return choice.delta.content

    return choice.text


def process_chat_completion_response(
    request: ChatCompletionRequest,
    response: OpenAICompatCompletionResponse,
    formatter: ChatFormat,
) -> ChatCompletionResponse:
    choice = response.choices[0]

    stop_reason = None
    if reason := choice.finish_reason:
        if reason in ["stop", "eos"]:
            stop_reason = StopReason.end_of_turn
        elif reason == "eom":
            stop_reason = StopReason.end_of_message
        elif reason == "length":
            stop_reason = StopReason.out_of_tokens

    if stop_reason is None:
        stop_reason = StopReason.out_of_tokens

    completion_message = formatter.decode_assistant_message_from_content(
        text_from_choice(choice), stop_reason
    )
    return ChatCompletionResponse(
        completion_message=completion_message,
        logprobs=None,
    )


async def process_chat_completion_stream_response(
    request: ChatCompletionRequest,
    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
    formatter: ChatFormat,
) -> AsyncGenerator:
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.start,
            delta="",
        )
    )

    buffer = ""
    ipython = False
    stop_reason = None

    async for chunk in stream:
        choice = chunk.choices[0]
        finish_reason = choice.finish_reason

        if finish_reason:
            if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
                stop_reason = StopReason.end_of_turn
            elif stop_reason is None and finish_reason == "length":
                stop_reason = StopReason.out_of_tokens
            break

        text = text_from_choice(choice)
        # check if its a tool call ( aka starts with <|python_tag|> )
        if not ipython and text.startswith("<|python_tag|>"):
            ipython = True
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=ToolCallDelta(
                        content="",
                        parse_status=ToolCallParseStatus.started,
                    ),
                )
            )
            buffer += text
            continue

        if text == "<|eot_id|>":
            stop_reason = StopReason.end_of_turn
            text = ""
            continue
        elif text == "<|eom_id|>":
            stop_reason = StopReason.end_of_message
            text = ""
            continue

        if ipython:
            buffer += text
            delta = ToolCallDelta(
                content=text,
                parse_status=ToolCallParseStatus.in_progress,
            )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=delta,
                    stop_reason=stop_reason,
                )
            )
        else:
            buffer += text
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=text,
                    stop_reason=stop_reason,
                )
            )

    # parse tool calls and report errors
    message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
    parsed_tool_calls = len(message.tool_calls) > 0
    if ipython and not parsed_tool_calls:
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=ToolCallDelta(
                    content="",
                    parse_status=ToolCallParseStatus.failure,
                ),
                stop_reason=stop_reason,
            )
        )

    for tool_call in message.tool_calls:
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=ToolCallDelta(
                    content=tool_call,
                    parse_status=ToolCallParseStatus.success,
                ),
                stop_reason=stop_reason,
            )
        )

    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
            delta="",
            stop_reason=stop_reason,
        )
    )
Some files were not shown because too many files have changed in this diff.