rebase on top of registry

Xi Yan 2024-10-08 23:41:03 -07:00
commit 6abef716dd
107 changed files with 4813 additions and 3587 deletions

1
.gitignore vendored

@ -13,3 +13,4 @@ xcuserdata/
Package.resolved Package.resolved
*.pte *.pte
*.ipynb_checkpoints* *.ipynb_checkpoints*
.idea


@ -1,4 +1,4 @@
# Contributing to Llama-Models # Contributing to Llama-Stack
We want to make contributing to this project as easy and transparent as We want to make contributing to this project as easy and transparent as
possible. possible.
@ -32,7 +32,7 @@ outlined on that page and do not file a public issue.
* ... * ...
## Tips ## Tips
* If you are developing with a llama-models repository checked out and need your distribution to reflect changes from there, set `LLAMA_MODELS_DIR` to that dir when running any of the `llama` CLI commands. * If you are developing with a llama-stack repository checked out and need your distribution to reflect changes from there, set `LLAMA_STACK_DIR` to that dir when running any of the `llama` CLI commands.
## License ## License
By contributing to Llama, you agree that your contributions will be licensed By contributing to Llama, you agree that your contributions will be licensed


@ -81,11 +81,24 @@ cd llama-stack
$CONDA_PREFIX/bin/pip install -e . $CONDA_PREFIX/bin/pip install -e .
``` ```
## The Llama CLI ## Documentations
The `llama` CLI makes it easy to work with the Llama Stack set of tools, including installing and running Distributions, downloading models, studying model prompt formats, etc. Please see the [CLI reference](docs/cli_reference.md) for details. Please see the [Getting Started](docs/getting_started.md) guide for running a Llama Stack server. The `llama` CLI makes it easy to work with the Llama Stack set of tools. Please find the following docs for details.
* [CLI reference](docs/cli_reference.md)
* Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
* [Getting Started](docs/getting_started.md)
* Guide to build and run a Llama Stack server.
* [Contributing](CONTRIBUTING.md)
## Llama Stack Client SDK ## Llama Stack Client SDK
| **Language** | **Client SDK** | **Package** |
| :----: | :----: | :----: |
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/)
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) |
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client)
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) |
Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [node](https://github.com/meta-llama/llama-stack-client-node), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications. Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [node](https://github.com/meta-llama/llama-stack-client-node), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.

5
SECURITY.md Normal file

@ -0,0 +1,5 @@
# Security Policy
## Reporting a Vulnerability
Please report vulnerabilities to our bug bounty program at https://bugbounty.meta.com/


@ -1,6 +1,6 @@
# Llama CLI Reference # Llama CLI Reference
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package. The `llama` CLI tool helps you setup and use the Llama Stack & agentic systems. It should be available on your path after installing the `llama-stack` package.
### Subcommands ### Subcommands
1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face. 1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face.
@ -117,9 +117,9 @@ llama download --source meta --model-id Llama-Guard-3-1B --meta-url META_URL
Essentially, the same commands above work, just replace `--source meta` with `--source huggingface`. Essentially, the same commands above work, just replace `--source meta` with `--source huggingface`.
```bash ```bash
llama download --source huggingface --model-id Meta-Llama3.1-8B-Instruct --hf-token <HF_TOKEN> llama download --source huggingface --model-id Llama3.1-8B-Instruct --hf-token <HF_TOKEN>
llama download --source huggingface --model-id Meta-Llama3.1-70B-Instruct --hf-token <HF_TOKEN> llama download --source huggingface --model-id Llama3.1-70B-Instruct --hf-token <HF_TOKEN>
llama download --source huggingface --model-id Llama-Guard-3-1B --ignore-patterns *original* llama download --source huggingface --model-id Llama-Guard-3-1B --ignore-patterns *original*
llama download --source huggingface --model-id Prompt-Guard-86M --ignore-patterns *original* llama download --source huggingface --model-id Prompt-Guard-86M --ignore-patterns *original*
@ -215,9 +215,8 @@ You can even run `llama model prompt-format` see all of the templates and their
``` ```
llama model prompt-format -m Llama3.2-3B-Instruct llama model prompt-format -m Llama3.2-3B-Instruct
``` ```
<p align="center"> ![alt text](resources/prompt-format.png)
<img width="719" alt="image" src="https://github.com/user-attachments/assets/c5332026-8c0b-4edc-b438-ec60cd7ca554">
</p>
You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios. You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
@ -230,7 +229,7 @@ You will be shown a Markdown formatted description of the model interface and ho
- Please see our [Getting Started](getting_started.md) guide for more details on how to build and start a Llama Stack distribution. - Please see our [Getting Started](getting_started.md) guide for more details on how to build and start a Llama Stack distribution.
### Step 3.1 Build ### Step 3.1 Build
In the following steps, imagine we'll be working with a `Meta-Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify: In the following steps, imagine we'll be working with a `Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify:
- `name`: the name for our distribution (e.g. `8b-instruct`) - `name`: the name for our distribution (e.g. `8b-instruct`)
- `image_type`: our build image type (`conda | docker`) - `image_type`: our build image type (`conda | docker`)
- `distribution_spec`: our distribution specs for specifying API providers - `distribution_spec`: our distribution specs for specifying API providers
@ -365,7 +364,7 @@ llama stack configure [ <name> | <docker-image-name> | <path/to/name.build.yaml>
$ llama stack configure ~/.llama/distributions/conda/8b-instruct-build.yaml $ llama stack configure ~/.llama/distributions/conda/8b-instruct-build.yaml
Configuring API: inference (meta-reference) Configuring API: inference (meta-reference)
Enter value for model (existing: Meta-Llama3.1-8B-Instruct) (required): Enter value for model (existing: Llama3.1-8B-Instruct) (required):
Enter value for quantization (optional): Enter value for quantization (optional):
Enter value for torch_seed (optional): Enter value for torch_seed (optional):
Enter value for max_seq_len (existing: 4096) (required): Enter value for max_seq_len (existing: 4096) (required):
@ -397,7 +396,7 @@ YAML configuration has been written to ~/.llama/builds/conda/8b-instruct-run.yam
After this step is successful, you should be able to find a run configuration spec in `~/.llama/builds/conda/8b-instruct-run.yaml` with the following contents. You may edit this file to change the settings. After this step is successful, you should be able to find a run configuration spec in `~/.llama/builds/conda/8b-instruct-run.yaml` with the following contents. You may edit this file to change the settings.
As you can see, we did basic configuration above and configured: As you can see, we did basic configuration above and configured:
- inference to run on model `Meta-Llama3.1-8B-Instruct` (obtained from `llama model list`) - inference to run on model `Llama3.1-8B-Instruct` (obtained from `llama model list`)
- Llama Guard safety shield with model `Llama-Guard-3-1B` - Llama Guard safety shield with model `Llama-Guard-3-1B`
- Prompt Guard safety shield with model `Prompt-Guard-86M` - Prompt Guard safety shield with model `Prompt-Guard-86M`
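
The hunks above rename `Meta-Llama3.1-8B-Instruct` and `Meta-Llama3.1-70B-Instruct` by dropping the `Meta-` prefix. A minimal sketch for updating scripts that still pass the old IDs follows. The helper name `normalize_model_id` is hypothetical, and the rule "strip a leading `Meta-` before `Llama`" is an assumption inferred only from the two IDs shown in this diff.

```python
def normalize_model_id(model_id: str) -> str:
    """Map an old-style model ID to the new style shown in this diff.

    Assumption: only a leading "Meta-" directly followed by "Llama" is dropped.
    IDs such as "Llama-Guard-3-1B" are returned unchanged.
    """
    prefix = "Meta-"
    if model_id.startswith(prefix + "Llama"):
        return model_id[len(prefix):]
    return model_id


# IDs taken from the diff above
old_ids = ["Meta-Llama3.1-8B-Instruct", "Meta-Llama3.1-70B-Instruct", "Llama-Guard-3-1B"]
new_ids = [normalize_model_id(m) for m in old_ids]
```

The result can then be passed as `--model-id` to `llama download`.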


@ -1,7 +1,7 @@
# llama-stack # llama-stack
[![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/) [![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/)
[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/TZAAYNVtrU) [![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/llama-stack)
This repository contains the specifications and implementations of the APIs which are part of the Llama Stack. This repository contains the specifications and implementations of the APIs which are part of the Llama Stack.
@ -66,8 +66,17 @@ This guides allows you to quickly get started with building and running a Llama
You may also checkout this [notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb) for trying out out demo scripts. You may also checkout this [notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb) for trying out out demo scripts.
## Quick Cheatsheet ## Quick Cheatsheet
- Quick 3 line command to build and start a LlamaStack server using our Meta Reference implementation for all API endpoints with `conda` as build type.
#### Via docker
```
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack-local-gpu
```
> [!NOTE]
> `~/.llama` should be the path containing downloaded weights of Llama models.
#### Via conda
**`llama stack build`** **`llama stack build`**
- You'll be prompted to enter build information interactively. - You'll be prompted to enter build information interactively.
``` ```

File diff suppressed because it is too large


@ -580,63 +580,6 @@ components:
- uuid - uuid
- dataset - dataset
type: object type: object
CreateMemoryBankRequest:
additionalProperties: false
properties:
config:
oneOf:
- additionalProperties: false
properties:
chunk_size_in_tokens:
type: integer
embedding_model:
type: string
overlap_size_in_tokens:
type: integer
type:
const: vector
default: vector
type: string
required:
- type
- embedding_model
- chunk_size_in_tokens
type: object
- additionalProperties: false
properties:
type:
const: keyvalue
default: keyvalue
type: string
required:
- type
type: object
- additionalProperties: false
properties:
type:
const: keyword
default: keyword
type: string
required:
- type
type: object
- additionalProperties: false
properties:
type:
const: graph
default: graph
type: string
required:
- type
type: object
name:
type: string
url:
$ref: '#/components/schemas/URL'
required:
- name
- config
type: object
DPOAlignmentConfig: DPOAlignmentConfig:
additionalProperties: false additionalProperties: false
properties: properties:
@ -739,14 +682,6 @@ components:
- rank - rank
- alpha - alpha
type: object type: object
DropMemoryBankRequest:
additionalProperties: false
properties:
bank_id:
type: string
required:
- bank_id
type: object
EmbeddingsRequest: EmbeddingsRequest:
additionalProperties: false additionalProperties: false
properties: properties:
@ -908,6 +843,21 @@ components:
required: required:
- document_ids - document_ids
type: object type: object
GraphMemoryBankDef:
additionalProperties: false
properties:
identifier:
type: string
provider_id:
type: string
type:
const: graph
default: graph
type: string
required:
- identifier
- type
type: object
HealthInfo: HealthInfo:
additionalProperties: false additionalProperties: false
properties: properties:
@ -973,6 +923,36 @@ components:
- bank_id - bank_id
- documents - documents
type: object type: object
KeyValueMemoryBankDef:
additionalProperties: false
properties:
identifier:
type: string
provider_id:
type: string
type:
const: keyvalue
default: keyvalue
type: string
required:
- identifier
- type
type: object
KeywordMemoryBankDef:
additionalProperties: false
properties:
identifier:
type: string
provider_id:
type: string
type:
const: keyword
default: keyword
type: string
required:
- identifier
- type
type: object
LogEventRequest: LogEventRequest:
additionalProperties: false additionalProperties: false
properties: properties:
@ -1015,66 +995,6 @@ components:
- rank - rank
- alpha - alpha
type: object type: object
MemoryBank:
additionalProperties: false
properties:
bank_id:
type: string
config:
oneOf:
- additionalProperties: false
properties:
chunk_size_in_tokens:
type: integer
embedding_model:
type: string
overlap_size_in_tokens:
type: integer
type:
const: vector
default: vector
type: string
required:
- type
- embedding_model
- chunk_size_in_tokens
type: object
- additionalProperties: false
properties:
type:
const: keyvalue
default: keyvalue
type: string
required:
- type
type: object
- additionalProperties: false
properties:
type:
const: keyword
default: keyword
type: string
required:
- type
type: object
- additionalProperties: false
properties:
type:
const: graph
default: graph
type: string
required:
- type
type: object
name:
type: string
url:
$ref: '#/components/schemas/URL'
required:
- bank_id
- name
- config
type: object
MemoryBankDocument: MemoryBankDocument:
additionalProperties: false additionalProperties: false
properties: properties:
@ -1107,41 +1027,6 @@ components:
- content - content
- metadata - metadata
type: object type: object
MemoryBankSpec:
additionalProperties: false
properties:
bank_type:
$ref: '#/components/schemas/MemoryBankType'
provider_config:
additionalProperties: false
properties:
config:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
provider_type:
type: string
required:
- provider_type
- config
type: object
required:
- bank_type
- provider_config
type: object
MemoryBankType:
enum:
- vector
- keyvalue
- keyword
- graph
type: string
MemoryRetrievalStep: MemoryRetrievalStep:
additionalProperties: false additionalProperties: false
properties: properties:
@ -1349,36 +1234,18 @@ components:
- value - value
- unit - unit
type: object type: object
Model: ModelDef:
description: The model family and SKU of the model along with other parameters
corresponding to the model.
ModelServingSpec:
additionalProperties: false additionalProperties: false
properties: properties:
identifier:
type: string
llama_model: llama_model:
$ref: '#/components/schemas/Model' type: string
provider_config: provider_id:
additionalProperties: false
properties:
config:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
provider_type:
type: string type: string
required: required:
- provider_type - identifier
- config
type: object
required:
- llama_model - llama_model
- provider_config
type: object type: object
OptimizerConfig: OptimizerConfig:
additionalProperties: false additionalProperties: false
@ -1554,13 +1421,13 @@ components:
ProviderInfo: ProviderInfo:
additionalProperties: false additionalProperties: false
properties: properties:
description: provider_id:
type: string type: string
provider_type: provider_type:
type: string type: string
required: required:
- provider_id
- provider_type - provider_type
- description
type: object type: object
QLoraFinetuningConfig: QLoraFinetuningConfig:
additionalProperties: false additionalProperties: false
@ -1650,6 +1517,56 @@ components:
enum: enum:
- dpo - dpo
type: string type: string
RegisterMemoryBankRequest:
additionalProperties: false
properties:
memory_bank:
oneOf:
- $ref: '#/components/schemas/VectorMemoryBankDef'
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
- $ref: '#/components/schemas/KeywordMemoryBankDef'
- $ref: '#/components/schemas/GraphMemoryBankDef'
required:
- memory_bank
type: object
RegisterModelRequest:
additionalProperties: false
properties:
model:
$ref: '#/components/schemas/ModelDef'
required:
- model
type: object
RegisterShieldRequest:
additionalProperties: false
properties:
shield:
additionalProperties: false
properties:
identifier:
type: string
params:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
provider_id:
type: string
type:
type: string
required:
- identifier
- type
- params
type: object
required:
- shield
type: object
RestAPIExecutionConfig: RestAPIExecutionConfig:
additionalProperties: false additionalProperties: false
properties: properties:
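
The new `RegisterModelRequest` and `RegisterShieldRequest` schemas in this hunk wrap a single object under `model` or `shield`. The sketch below builds request bodies that match the required fields as they appear in this diff (`identifier` and `llama_model` for `ModelDef`; `identifier`, `type`, and `params` for the shield). The function names and the example values are hypothetical, not part of the spec.

```python
from typing import Any, Dict, Optional


def register_model_body(identifier: str, llama_model: str,
                        provider_id: Optional[str] = None) -> Dict[str, Any]:
    """Build a body shaped like RegisterModelRequest: {"model": <ModelDef>}."""
    model: Dict[str, Any] = {"identifier": identifier, "llama_model": llama_model}
    if provider_id is not None:
        # provider_id is listed under properties but not under required
        model["provider_id"] = provider_id
    return {"model": model}


def register_shield_body(identifier: str, shield_type: str,
                         params: Optional[Dict[str, Any]] = None,
                         provider_id: Optional[str] = None) -> Dict[str, Any]:
    """Build a body shaped like RegisterShieldRequest: {"shield": {...}}."""
    shield: Dict[str, Any] = {
        "identifier": identifier,
        "type": shield_type,
        # params is required by the schema, so send an empty object if none given
        "params": params if params is not None else {},
    }
    if provider_id is not None:
        shield["provider_id"] = provider_id
    return {"shield": shield}


# Example values are placeholders chosen for illustration
model_body = register_model_body("Llama3.1-8B-Instruct", "Llama3.1-8B-Instruct")
shield_body = register_shield_body("example-shield", "example-type")
```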
@ -1728,7 +1645,7 @@ components:
properties: properties:
method: method:
type: string type: string
providers: provider_types:
items: items:
type: string type: string
type: array type: array
@ -1737,7 +1654,7 @@ components:
required: required:
- route - route
- method - method
- providers - provider_types
type: object type: object
RunShieldRequest: RunShieldRequest:
additionalProperties: false additionalProperties: false
@ -1892,7 +1809,11 @@ components:
additionalProperties: false additionalProperties: false
properties: properties:
memory_bank: memory_bank:
$ref: '#/components/schemas/MemoryBank' oneOf:
- $ref: '#/components/schemas/VectorMemoryBankDef'
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
- $ref: '#/components/schemas/KeywordMemoryBankDef'
- $ref: '#/components/schemas/GraphMemoryBankDef'
session_id: session_id:
type: string type: string
session_name: session_name:
@ -1935,13 +1856,12 @@ components:
- step_id - step_id
- step_type - step_type
type: object type: object
ShieldSpec: ShieldDef:
additionalProperties: false additionalProperties: false
properties: properties:
provider_config: identifier:
additionalProperties: false type: string
properties: params:
config:
additionalProperties: additionalProperties:
oneOf: oneOf:
- type: 'null' - type: 'null'
@ -1951,17 +1871,14 @@ components:
- type: array - type: array
- type: object - type: object
type: object type: object
provider_type: provider_id:
type: string
type:
type: string type: string
required: required:
- provider_type - identifier
- config - type
type: object - params
shield_type:
type: string
required:
- shield_type
- provider_config
type: object type: object
SpanEndPayload: SpanEndPayload:
additionalProperties: false additionalProperties: false
@ -2571,6 +2488,29 @@ components:
- role - role
- content - content
type: object type: object
VectorMemoryBankDef:
additionalProperties: false
properties:
chunk_size_in_tokens:
type: integer
embedding_model:
type: string
identifier:
type: string
overlap_size_in_tokens:
type: integer
provider_id:
type: string
type:
const: vector
default: vector
type: string
required:
- identifier
- type
- embedding_model
- chunk_size_in_tokens
type: object
ViolationLevel: ViolationLevel:
enum: enum:
- info - info
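
The four `*MemoryBankDef` schemas added in this diff replace the nested `config` object with flat definitions keyed by `type`. A minimal sketch of a client-side check for the `required` lists shown above follows. The table `REQUIRED_BY_TYPE` and the helper `missing_fields` are hypothetical names, and the example values (bank identifier, embedding model name, chunk size) are placeholders rather than values taken from the spec.

```python
from typing import Any, Dict, List

# Required keys per bank type, copied from the schemas in this diff
REQUIRED_BY_TYPE: Dict[str, List[str]] = {
    "vector": ["identifier", "type", "embedding_model", "chunk_size_in_tokens"],
    "keyvalue": ["identifier", "type"],
    "keyword": ["identifier", "type"],
    "graph": ["identifier", "type"],
}


def missing_fields(bank_def: Dict[str, Any]) -> List[str]:
    """Return the required keys that are absent from a memory bank definition."""
    bank_type = bank_def.get("type")
    if bank_type not in REQUIRED_BY_TYPE:
        raise ValueError(f"unknown memory bank type: {bank_type!r}")
    return [key for key in REQUIRED_BY_TYPE[bank_type] if key not in bank_def]


# Placeholder values for illustration only
vector_bank = {
    "identifier": "example-bank",
    "type": "vector",
    "embedding_model": "example-embedding-model",
    "chunk_size_in_tokens": 512,
    "overlap_size_in_tokens": 64,  # optional per the schema
}
problems = missing_fields(vector_bank)
```

This only checks key presence. A JSON Schema validator run against the generated spec would also catch type errors and unexpected properties.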
@ -2604,7 +2544,7 @@ info:
description: "This is the specification of the llama stack that provides\n \ description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\ \ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\ \ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257" \ draft and subject to change.\n Generated at 2024-10-08 15:18:57.600111"
title: '[DRAFT] Llama Stack Specification' title: '[DRAFT] Llama Stack Specification'
version: 0.0.1 version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@ -3226,7 +3166,7 @@ paths:
description: OK description: OK
tags: tags:
- Inference - Inference
/memory/create: /inference/register_model:
post: post:
parameters: parameters:
- description: JSON-encoded provider data which will be made available to the - description: JSON-encoded provider data which will be made available to the
@ -3240,17 +3180,13 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/CreateMemoryBankRequest' $ref: '#/components/schemas/RegisterModelRequest'
required: true required: true
responses: responses:
'200': '200':
content:
application/json:
schema:
$ref: '#/components/schemas/MemoryBank'
description: OK description: OK
tags: tags:
- Memory - Models
/memory/documents/delete: /memory/documents/delete:
post: post:
parameters: parameters:
@ -3302,57 +3238,6 @@ paths:
description: OK description: OK
tags: tags:
- Memory - Memory
/memory/drop:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/DropMemoryBankRequest'
required: true
responses:
'200':
content:
application/json:
schema:
type: string
description: OK
tags:
- Memory
/memory/get:
get:
parameters:
- in: query
name: bank_id
required: true
schema:
type: string
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/MemoryBank'
- type: 'null'
description: OK
tags:
- Memory
/memory/insert: /memory/insert:
post: post:
parameters: parameters:
@ -3374,25 +3259,6 @@ paths:
description: OK description: OK
tags: tags:
- Memory - Memory
/memory/list:
get:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/jsonl:
schema:
$ref: '#/components/schemas/MemoryBank'
description: OK
tags:
- Memory
/memory/query: /memory/query:
post: post:
parameters: parameters:
@ -3443,10 +3309,10 @@ paths:
get: get:
parameters: parameters:
- in: query - in: query
name: bank_type name: identifier
required: true required: true
schema: schema:
$ref: '#/components/schemas/MemoryBankType' type: string
- description: JSON-encoded provider data which will be made available to the - description: JSON-encoded provider data which will be made available to the
adapter servicing the API adapter servicing the API
in: header in: header
@ -3460,7 +3326,11 @@ paths:
application/json: application/json:
schema: schema:
oneOf: oneOf:
- $ref: '#/components/schemas/MemoryBankSpec' - oneOf:
- $ref: '#/components/schemas/VectorMemoryBankDef'
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
- $ref: '#/components/schemas/KeywordMemoryBankDef'
- $ref: '#/components/schemas/GraphMemoryBankDef'
- type: 'null' - type: 'null'
description: OK description: OK
tags: tags:
@ -3480,15 +3350,40 @@ paths:
content: content:
application/jsonl: application/jsonl:
schema: schema:
$ref: '#/components/schemas/MemoryBankSpec' oneOf:
- $ref: '#/components/schemas/VectorMemoryBankDef'
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
- $ref: '#/components/schemas/KeywordMemoryBankDef'
- $ref: '#/components/schemas/GraphMemoryBankDef'
description: OK description: OK
tags: tags:
- MemoryBanks - MemoryBanks
/memory_banks/register:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterMemoryBankRequest'
required: true
responses:
'200':
description: OK
tags:
- Memory
/models/get: /models/get:
get: get:
parameters: parameters:
- in: query - in: query
name: core_model_id name: identifier
required: true required: true
schema: schema:
type: string type: string
@ -3505,7 +3400,7 @@ paths:
application/json: application/json:
schema: schema:
oneOf: oneOf:
- $ref: '#/components/schemas/ModelServingSpec' - $ref: '#/components/schemas/ModelDef'
- type: 'null' - type: 'null'
description: OK description: OK
tags: tags:
@ -3525,7 +3420,7 @@ paths:
content: content:
application/jsonl: application/jsonl:
schema: schema:
$ref: '#/components/schemas/ModelServingSpec' $ref: '#/components/schemas/ModelDef'
description: OK description: OK
tags: tags:
- Models - Models
@ -3760,6 +3655,27 @@ paths:
description: OK description: OK
tags: tags:
- Inspect - Inspect
/safety/register_shield:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterShieldRequest'
required: true
responses:
'200':
description: OK
tags:
- Shields
/safety/run_shield: /safety/run_shield:
post: post:
parameters: parameters:
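
This commit adds three registration routes: `/inference/register_model`, `/memory_banks/register`, and `/safety/register_shield`. Each takes a JSON body and an optional `X-LlamaStack-ProviderData` header. The sketch below builds such a request with the standard library without sending it. The base URL `http://localhost:5000` is an assumption based on the port used in the docker example elsewhere in this diff, and `build_register_request` is a hypothetical helper name.

```python
import json
import urllib.request
from typing import Any, Dict, Optional

BASE_URL = "http://localhost:5000"  # assumption: local server on the docker example's port


def build_register_request(path: str, body: Dict[str, Any],
                           provider_data: Optional[Dict[str, Any]] = None,
                           base_url: str = BASE_URL) -> urllib.request.Request:
    """Build (but do not send) a POST request for one of the register routes."""
    headers = {"Content-Type": "application/json"}
    if provider_data is not None:
        # The spec describes this header as JSON-encoded provider data
        headers["X-LlamaStack-ProviderData"] = json.dumps(provider_data)
    return urllib.request.Request(
        url=base_url + path,
        data=json.dumps(body).encode("utf-8"),
        headers=headers,
        method="POST",
    )


# Placeholder shield definition for illustration
req = build_register_request(
    "/safety/register_shield",
    {"shield": {"identifier": "example-shield", "type": "example-type", "params": {}}},
)
# urllib.request.urlopen(req) would send it; omitted so the sketch runs offline.
```

According to the spec in this diff, the `200` responses for these routes carry no body schema, so a caller would check only the status code.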
@ -3806,7 +3722,29 @@ paths:
application/json: application/json:
schema: schema:
oneOf: oneOf:
- $ref: '#/components/schemas/ShieldSpec' - additionalProperties: false
properties:
identifier:
type: string
params:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
provider_id:
type: string
type:
type: string
required:
- identifier
- type
- params
type: object
- type: 'null' - type: 'null'
description: OK description: OK
tags: tags:
@ -3826,7 +3764,7 @@ paths:
content: content:
application/jsonl: application/jsonl:
schema: schema:
$ref: '#/components/schemas/ShieldSpec' $ref: '#/components/schemas/ShieldDef'
description: OK description: OK
tags: tags:
- Shields - Shields
@ -3905,21 +3843,21 @@ security:
servers: servers:
- url: http://any-hosted-llama-stack.com - url: http://any-hosted-llama-stack.com
tags: tags:
- name: Datasets
- name: Inspect
- name: Memory
- name: BatchInference - name: BatchInference
- name: Agents - name: Datasets
- name: Inference - name: Inference
- name: Shields
- name: SyntheticDataGeneration
- name: Models
- name: RewardScoring
- name: MemoryBanks
- name: Safety
- name: Evaluations - name: Evaluations
- name: Telemetry - name: Memory
- name: Safety
- name: PostTraining - name: PostTraining
- name: MemoryBanks
- name: Models
- name: Shields
- name: Inspect
- name: SyntheticDataGeneration
- name: Telemetry
- name: Agents
- name: RewardScoring
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" /> - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage" - description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
@ -4123,11 +4061,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/CreateDatasetRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/CreateDatasetRequest"
/> />
name: CreateDatasetRequest name: CreateDatasetRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/CreateMemoryBankRequest"
/>
name: CreateMemoryBankRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBank" />
name: MemoryBank
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteAgentsRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/DeleteAgentsRequest"
/> />
name: DeleteAgentsRequest name: DeleteAgentsRequest
@ -4140,9 +4073,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteDocumentsRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/DeleteDocumentsRequest"
/> />
name: DeleteDocumentsRequest name: DeleteDocumentsRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/DropMemoryBankRequest"
/>
name: DropMemoryBankRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/EmbeddingsRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/EmbeddingsRequest"
/> />
name: EmbeddingsRequest name: EmbeddingsRequest
@ -4163,11 +4093,23 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/GetAgentsSessionRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/GetAgentsSessionRequest"
/> />
name: GetAgentsSessionRequest name: GetAgentsSessionRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/GraphMemoryBankDef"
/>
name: GraphMemoryBankDef
- description: <SchemaDefinition schemaRef="#/components/schemas/KeyValueMemoryBankDef"
/>
name: KeyValueMemoryBankDef
- description: <SchemaDefinition schemaRef="#/components/schemas/KeywordMemoryBankDef"
/>
name: KeywordMemoryBankDef
- description: 'A single session of an interaction with an Agentic System. - description: 'A single session of an interaction with an Agentic System.
<SchemaDefinition schemaRef="#/components/schemas/Session" />' <SchemaDefinition schemaRef="#/components/schemas/Session" />'
name: Session name: Session
- description: <SchemaDefinition schemaRef="#/components/schemas/VectorMemoryBankDef"
/>
name: VectorMemoryBankDef
- description: <SchemaDefinition schemaRef="#/components/schemas/AgentStepResponse" - description: <SchemaDefinition schemaRef="#/components/schemas/AgentStepResponse"
/> />
name: AgentStepResponse name: AgentStepResponse
@ -4189,21 +4131,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/EvaluationJobStatusResponse" - description: <SchemaDefinition schemaRef="#/components/schemas/EvaluationJobStatusResponse"
/> />
name: EvaluationJobStatusResponse name: EvaluationJobStatusResponse
- description: 'The model family and SKU of the model along with other parameters - description: <SchemaDefinition schemaRef="#/components/schemas/ModelDef" />
corresponding to the model. name: ModelDef
<SchemaDefinition schemaRef="#/components/schemas/Model" />'
name: Model
- description: <SchemaDefinition schemaRef="#/components/schemas/ModelServingSpec"
/>
name: ModelServingSpec
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankType" />
name: MemoryBankType
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankSpec" />
name: MemoryBankSpec
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldSpec" />
name: ShieldSpec
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" /> - description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
name: Trace name: Trace
- description: 'Checkpoint created during training runs - description: 'Checkpoint created during training runs
@ -4243,6 +4172,8 @@ tags:
name: ProviderInfo name: ProviderInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/RouteInfo" /> - description: <SchemaDefinition schemaRef="#/components/schemas/RouteInfo" />
name: RouteInfo name: RouteInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldDef" />
name: ShieldDef
- description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" /> - description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" />
name: LogSeverity name: LogSeverity
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" /> - description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
@ -4282,6 +4213,15 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/QueryDocumentsResponse" - description: <SchemaDefinition schemaRef="#/components/schemas/QueryDocumentsResponse"
/> />
name: QueryDocumentsResponse name: QueryDocumentsResponse
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterMemoryBankRequest"
/>
name: RegisterMemoryBankRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterModelRequest"
/>
name: RegisterModelRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterShieldRequest"
/>
name: RegisterShieldRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/DialogGenerations" - description: <SchemaDefinition schemaRef="#/components/schemas/DialogGenerations"
/> />
name: DialogGenerations name: DialogGenerations
@ -4387,7 +4327,6 @@ x-tagGroups:
- CreateAgentSessionRequest - CreateAgentSessionRequest
- CreateAgentTurnRequest - CreateAgentTurnRequest
- CreateDatasetRequest - CreateDatasetRequest
- CreateMemoryBankRequest
- DPOAlignmentConfig - DPOAlignmentConfig
- DeleteAgentsRequest - DeleteAgentsRequest
- DeleteAgentsSessionRequest - DeleteAgentsSessionRequest
@ -4395,7 +4334,6 @@ x-tagGroups:
- DeleteDocumentsRequest - DeleteDocumentsRequest
- DialogGenerations - DialogGenerations
- DoraFinetuningConfig - DoraFinetuningConfig
- DropMemoryBankRequest
- EmbeddingsRequest - EmbeddingsRequest
- EmbeddingsResponse - EmbeddingsResponse
- EvaluateQuestionAnsweringRequest - EvaluateQuestionAnsweringRequest
@ -4409,22 +4347,21 @@ x-tagGroups:
- FunctionCallToolDefinition - FunctionCallToolDefinition
- GetAgentsSessionRequest - GetAgentsSessionRequest
- GetDocumentsRequest - GetDocumentsRequest
- GraphMemoryBankDef
- HealthInfo - HealthInfo
- ImageMedia - ImageMedia
- InferenceStep - InferenceStep
- InsertDocumentsRequest - InsertDocumentsRequest
- KeyValueMemoryBankDef
- KeywordMemoryBankDef
- LogEventRequest - LogEventRequest
- LogSeverity - LogSeverity
- LoraFinetuningConfig - LoraFinetuningConfig
- MemoryBank
- MemoryBankDocument - MemoryBankDocument
- MemoryBankSpec
- MemoryBankType
- MemoryRetrievalStep - MemoryRetrievalStep
- MemoryToolDefinition - MemoryToolDefinition
- MetricEvent - MetricEvent
- Model - ModelDef
- ModelServingSpec
- OptimizerConfig - OptimizerConfig
- PhotogenToolDefinition - PhotogenToolDefinition
- PostTrainingJob - PostTrainingJob
@ -4438,6 +4375,9 @@ x-tagGroups:
- QueryDocumentsRequest - QueryDocumentsRequest
- QueryDocumentsResponse - QueryDocumentsResponse
- RLHFAlgorithm - RLHFAlgorithm
- RegisterMemoryBankRequest
- RegisterModelRequest
- RegisterShieldRequest
- RestAPIExecutionConfig - RestAPIExecutionConfig
- RestAPIMethod - RestAPIMethod
- RewardScoreRequest - RewardScoreRequest
@ -4453,7 +4393,7 @@ x-tagGroups:
- SearchToolDefinition - SearchToolDefinition
- Session - Session
- ShieldCallStep - ShieldCallStep
- ShieldSpec - ShieldDef
- SpanEndPayload - SpanEndPayload
- SpanStartPayload - SpanStartPayload
- SpanStatus - SpanStatus
@ -4483,5 +4423,6 @@ x-tagGroups:
- UnstructuredLogEvent - UnstructuredLogEvent
- UpdateDocumentsRequest - UpdateDocumentsRequest
- UserMessage - UserMessage
- VectorMemoryBankDef
- ViolationLevel - ViolationLevel
- WolframAlphaToolDefinition - WolframAlphaToolDefinition

Binary file not shown (image, 170 KiB).

View file

@ -261,7 +261,7 @@ class Session(BaseModel):
turns: List[Turn] turns: List[Turn]
started_at: datetime started_at: datetime
memory_bank: Optional[MemoryBank] = None memory_bank: Optional[MemoryBankDef] = None
class AgentConfigCommon(BaseModel): class AgentConfigCommon(BaseModel):
@ -411,8 +411,10 @@ class Agents(Protocol):
agent_config: AgentConfig, agent_config: AgentConfig,
) -> AgentCreateResponse: ... ) -> AgentCreateResponse: ...
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or an `AgentTurnCreateResponse` depending on the value of `stream`.
@webmethod(route="/agents/turn/create") @webmethod(route="/agents/turn/create")
async def create_agent_turn( def create_agent_turn(
self, self,
agent_id: str, agent_id: str,
session_id: str, session_id: str,

View file

@ -7,7 +7,7 @@
import asyncio import asyncio
import json import json
import os import os
from typing import AsyncGenerator from typing import AsyncGenerator, Optional
import fire import fire
import httpx import httpx
@ -67,9 +67,17 @@ class AgentsClient(Agents):
response.raise_for_status() response.raise_for_status()
return AgentSessionCreateResponse(**response.json()) return AgentSessionCreateResponse(**response.json())
async def create_agent_turn( def create_agent_turn(
self, self,
request: AgentTurnCreateRequest, request: AgentTurnCreateRequest,
) -> AsyncGenerator:
if request.stream:
return self._stream_agent_turn(request)
else:
return self._nonstream_agent_turn(request)
async def _stream_agent_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator: ) -> AsyncGenerator:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
async with client.stream( async with client.stream(
@ -93,6 +101,9 @@ class AgentsClient(Agents):
print(data) print(data)
print(f"Error with parsing or validation: {e}") print(f"Error with parsing or validation: {e}")
async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
raise NotImplementedError("Non-streaming not implemented yet")
async def _run_agent( async def _run_agent(
api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
@ -132,8 +143,7 @@ async def _run_agent(
log.print() log.print()
async def run_llama_3_1(host: str, port: int): async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
model = "Llama3.1-8B-Instruct"
api = AgentsClient(f"http://{host}:{port}") api = AgentsClient(f"http://{host}:{port}")
tool_definitions = [ tool_definitions = [
@ -173,8 +183,7 @@ async def run_llama_3_1(host: str, port: int):
await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts) await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
async def run_llama_3_2_rag(host: str, port: int): async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
model = "Llama3.2-3B-Instruct"
api = AgentsClient(f"http://{host}:{port}") api = AgentsClient(f"http://{host}:{port}")
urls = [ urls = [
@ -215,8 +224,7 @@ async def run_llama_3_2_rag(host: str, port: int):
) )
async def run_llama_3_2(host: str, port: int): async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
model = "Llama3.2-3B-Instruct"
api = AgentsClient(f"http://{host}:{port}") api = AgentsClient(f"http://{host}:{port}")
# zero shot tools for llama3.2 text models # zero shot tools for llama3.2 text models
@ -262,7 +270,7 @@ async def run_llama_3_2(host: str, port: int):
) )
def main(host: str, port: int, run_type: str): def main(host: str, port: int, run_type: str, model: Optional[str] = None):
assert run_type in [ assert run_type in [
"tools_llama_3_1", "tools_llama_3_1",
"tools_llama_3_2", "tools_llama_3_2",
@ -274,7 +282,10 @@ def main(host: str, port: int, run_type: str):
"tools_llama_3_2": run_llama_3_2, "tools_llama_3_2": run_llama_3_2,
"rag_llama_3_2": run_llama_3_2_rag, "rag_llama_3_2": run_llama_3_2_rag,
} }
asyncio.run(fn[run_type](host, port)) args = [host, port]
if model is not None:
args.append(model)
asyncio.run(fn[run_type](*args))
if __name__ == "__main__": if __name__ == "__main__":
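Review note: the `main` change above appends `model` to the argument list only when it was supplied, so each `run_*` function keeps its own default. A minimal sketch of that forwarding pattern follows; `run_example` and `build_args` are hypothetical names used for illustration and are not part of this commit.

```python
import asyncio
from typing import Optional


async def run_example(host: str, port: int, model: str = "Llama3.1-8B-Instruct") -> str:
    # Hypothetical stand-in for run_llama_3_1 and friends; it only reports its inputs.
    return f"{host}:{port} -> {model}"


def build_args(host: str, port: int, model: Optional[str] = None) -> list:
    # Append the model only when the caller supplied one, so the callee's
    # own default applies otherwise.
    args = [host, port]
    if model is not None:
        args.append(model)
    return args


def main(host: str, port: int, model: Optional[str] = None) -> str:
    return asyncio.run(run_example(*build_args(host, port, model)))
```

Passing `model=None` straight through would instead override the callee's default with `None`, which is presumably why the commit builds the list conditionally.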

View file

@ -42,10 +42,10 @@ class InferenceClient(Inference):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator: def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError() raise NotImplementedError()
async def chat_completion( def chat_completion(
self, self,
model: str, model: str,
messages: List[Message], messages: List[Message],
@ -66,6 +66,29 @@ class InferenceClient(Inference):
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
) )
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
)
response.raise_for_status()
j = response.json()
return ChatCompletionResponse(**j)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
async with client.stream( async with client.stream(
"POST", "POST",
@ -77,7 +100,8 @@ class InferenceClient(Inference):
if response.status_code != 200: if response.status_code != 200:
content = await response.aread() content = await response.aread()
cprint( cprint(
f"Error: HTTP {response.status_code} {content.decode()}", "red" f"Error: HTTP {response.status_code} {content.decode()}",
"red",
) )
return return
@ -85,40 +109,59 @@ class InferenceClient(Inference):
if line.startswith("data:"): if line.startswith("data:"):
data = line[len("data: ") :] data = line[len("data: ") :]
try: try:
if request.stream:
if "error" in data: if "error" in data:
cprint(data, "red") cprint(data, "red")
continue continue
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(**json.loads(data))
**json.loads(data)
)
else:
yield ChatCompletionResponse(**json.loads(data))
except Exception as e: except Exception as e:
print(data) print(data)
print(f"Error with parsing or validation: {e}") print(f"Error with parsing or validation: {e}")
async def run_main(host: str, port: int, stream: bool): async def run_main(
host: str, port: int, stream: bool, model: Optional[str], logprobs: bool
):
client = InferenceClient(f"http://{host}:{port}") client = InferenceClient(f"http://{host}:{port}")
if not model:
model = "Llama3.1-8B-Instruct"
message = UserMessage( message = UserMessage(
content="hello world, write me a 2 sentence poem about the moon" content="hello world, write me a 2 sentence poem about the moon"
) )
cprint(f"User>{message.content}", "green") cprint(f"User>{message.content}", "green")
if logprobs:
logprobs_config = LogProbConfig(
top_k=1,
)
else:
logprobs_config = None
iterator = client.chat_completion( iterator = client.chat_completion(
model="Llama3.1-8B-Instruct", model=model,
messages=[message], messages=[message],
stream=stream, stream=stream,
logprobs=logprobs_config,
) )
if logprobs:
async for chunk in iterator:
cprint(f"Response: {chunk}", "red")
else:
async for log in EventLogger().log(iterator): async for log in EventLogger().log(iterator):
log.print() log.print()
async def run_mm_main(host: str, port: int, stream: bool, path: str): async def run_mm_main(
host: str, port: int, stream: bool, path: Optional[str], model: Optional[str]
):
client = InferenceClient(f"http://{host}:{port}") client = InferenceClient(f"http://{host}:{port}")
if not model:
model = "Llama3.2-11B-Vision-Instruct"
message = UserMessage( message = UserMessage(
content=[ content=[
ImageMedia(image=URL(uri=f"file://{path}")), ImageMedia(image=URL(uri=f"file://{path}")),
@ -127,7 +170,7 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
) )
cprint(f"User>{message.content}", "green") cprint(f"User>{message.content}", "green")
iterator = client.chat_completion( iterator = client.chat_completion(
model="Llama3.2-11B-Vision-Instruct", model=model,
messages=[message], messages=[message],
stream=stream, stream=stream,
) )
@ -135,11 +178,19 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
log.print() log.print()
def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None): def main(
host: str,
port: int,
stream: bool = True,
mm: bool = False,
logprobs: bool = False,
file: Optional[str] = None,
model: Optional[str] = None,
):
if mm: if mm:
asyncio.run(run_mm_main(host, port, stream, file)) asyncio.run(run_mm_main(host, port, stream, file, model))
else: else:
asyncio.run(run_main(host, port, stream)) asyncio.run(run_main(host, port, stream, model, logprobs))
if __name__ == "__main__": if __name__ == "__main__":
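Review note: the streaming loop in this file checks `line.startswith("data:")` but slices with `len("data: ")`, which assumes a space always follows the colon. A small helper that tolerates both forms might look like the sketch below; `parse_sse_data_line` is a hypothetical name and not something this commit defines.

```python
import json
from typing import Any, Optional


def parse_sse_data_line(line: str) -> Optional[Any]:
    """Return the decoded JSON payload of a `data:` line, or None for other lines."""
    prefix = "data:"
    if not line.startswith(prefix):
        return None
    # Drop the field name, then at most one leading space after the colon.
    payload = line[len(prefix):]
    if payload.startswith(" "):
        payload = payload[1:]
    return json.loads(payload)
```

Whether the server ever emits `data:` without the space is not shown in this diff, so this is a defensive option rather than a confirmed fix.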

View file

@ -14,6 +14,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
class LogProbConfig(BaseModel): class LogProbConfig(BaseModel):
@ -172,9 +173,17 @@ class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]] embeddings: List[List[float]]
class ModelStore(Protocol):
def get_model(self, identifier: str) -> ModelDef: ...
class Inference(Protocol): class Inference(Protocol):
model_store: ModelStore
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/completion") @webmethod(route="/inference/completion")
async def completion( def completion(
self, self,
model: str, model: str,
content: InterleavedTextMedia, content: InterleavedTextMedia,
@ -183,8 +192,10 @@ class Inference(Protocol):
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ... ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/chat_completion") @webmethod(route="/inference/chat_completion")
async def chat_completion( def chat_completion(
self, self,
model: str, model: str,
messages: List[Message], messages: List[Message],
@ -203,3 +214,6 @@ class Inference(Protocol):
model: str, model: str,
contents: List[InterleavedTextMedia], contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ... ) -> EmbeddingsResponse: ...
@webmethod(route="/inference/register_model")
async def register_model(self, model: ModelDef) -> None: ...
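Review note: the comments above explain why `completion` and `chat_completion` are no longer `async def`. The sketch below illustrates the idea with a toy client: a plain `def` returns either an async generator or a coroutine, and the caller picks `async for` or `await` accordingly. `EchoClient` and its methods are hypothetical and perform no network I/O.

```python
import asyncio
from typing import AsyncGenerator, Awaitable, List, Union


class EchoClient:
    """Hypothetical stand-in for a client; not part of llama_stack."""

    def complete(
        self, text: str, stream: bool
    ) -> Union[AsyncGenerator[str, None], Awaitable[str]]:
        # A plain `def`: calling it returns either an async generator
        # (consume with `async for`) or a coroutine (consume with `await`).
        if stream:
            return self._stream(text)
        return self._nonstream(text)

    async def _stream(self, text: str) -> AsyncGenerator[str, None]:
        for word in text.split():
            yield word

    async def _nonstream(self, text: str) -> str:
        return text.upper()


async def demo() -> tuple:
    client = EchoClient()
    chunks: List[str] = [c async for c in client.complete("hello moon", stream=True)]
    full = await client.complete("hello moon", stream=False)
    return chunks, full
```

If `complete` were declared `async def`, the streaming caller would have to `await` it first just to obtain the generator, which is the awkwardness the comments in the diff appear to be avoiding.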

View file

@ -12,15 +12,15 @@ from pydantic import BaseModel
@json_schema_type @json_schema_type
class ProviderInfo(BaseModel): class ProviderInfo(BaseModel):
provider_id: str
provider_type: str provider_type: str
description: str
@json_schema_type @json_schema_type
class RouteInfo(BaseModel): class RouteInfo(BaseModel):
route: str route: str
method: str method: str
providers: List[str] provider_types: List[str]
@json_schema_type @json_schema_type

View file

@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional
import fire import fire
import httpx import httpx
from termcolor import cprint
from llama_stack.distribution.datatypes import RemoteProviderConfig from llama_stack.distribution.datatypes import RemoteProviderConfig
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks.client import MemoryBanksClient
from llama_stack.providers.utils.memory.file_utils import data_url_from_file from llama_stack.providers.utils.memory.file_utils import data_url_from_file
@ -35,44 +35,16 @@ class MemoryClient(Memory):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get( response = await client.post(
f"{self.base_url}/memory/get", f"{self.base_url}/memory/register_memory_bank",
params={
"bank_id": bank_id,
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory/create",
json={ json={
"name": name, "memory_bank": json.loads(memory_bank.json()),
"config": config.dict(),
"url": url,
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20,
) )
r.raise_for_status() response.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def insert_documents( async def insert_documents(
self, self,
@ -114,22 +86,20 @@ class MemoryClient(Memory):
async def run_main(host: str, port: int, stream: bool): async def run_main(host: str, port: int, stream: bool):
client = MemoryClient(f"http://{host}:{port}") client = MemoryClient(f"http://{host}:{port}")
banks_client = MemoryBanksClient(f"http://{host}:{port}")
# create a memory bank bank = VectorMemoryBankDef(
bank = await client.create_memory_bank( identifier="test_bank",
name="test_bank", provider_id="",
config=VectorMemoryBankConfig(
bank_id="test_bank",
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
),
) )
cprint(json.dumps(bank.dict(), indent=4), "green") await client.register_memory_bank(bank)
retrieved_bank = await client.get_memory_bank(bank.bank_id) retrieved_bank = await banks_client.get_memory_bank(bank.identifier)
assert retrieved_bank is not None assert retrieved_bank is not None
assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2" assert retrieved_bank.embedding_model == "all-MiniLM-L6-v2"
urls = [ urls = [
"memory_optimizations.rst", "memory_optimizations.rst",
@ -162,13 +132,13 @@ async def run_main(host: str, port: int, stream: bool):
# insert some documents # insert some documents
await client.insert_documents( await client.insert_documents(
bank_id=bank.bank_id, bank_id=bank.identifier,
documents=documents, documents=documents,
) )
# query the documents # query the documents
response = await client.query_documents( response = await client.query_documents(
bank_id=bank.bank_id, bank_id=bank.identifier,
query=[ query=[
"How do I use Lora?", "How do I use Lora?",
], ],
@ -178,7 +148,7 @@ async def run_main(host: str, port: int, stream: bool):
print(f"Chunk:\n========\n{chunk}\n========\n") print(f"Chunk:\n========\n{chunk}\n========\n")
response = await client.query_documents( response = await client.query_documents(
bank_id=bank.bank_id, bank_id=bank.identifier,
query=[ query=[
"Tell me more about llama3 and torchtune", "Tell me more about llama3 and torchtune",
], ],

View file

@ -13,9 +13,9 @@ from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
@json_schema_type @json_schema_type
@ -26,44 +26,6 @@ class MemoryBankDocument(BaseModel):
metadata: Dict[str, Any] = Field(default_factory=dict) metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class MemoryBankType(Enum):
vector = "vector"
keyvalue = "keyvalue"
keyword = "keyword"
graph = "graph"
class VectorMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
class KeyValueMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
class KeywordMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class GraphMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class Chunk(BaseModel): class Chunk(BaseModel):
content: InterleavedTextMedia content: InterleavedTextMedia
token_count: int token_count: int
@ -76,45 +38,12 @@ class QueryDocumentsResponse(BaseModel):
scores: List[float] scores: List[float]
@json_schema_type class MemoryBankStore(Protocol):
class QueryAPI(Protocol): def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
@webmethod(route="/query_documents")
def query_documents(
self,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@json_schema_type
class MemoryBank(BaseModel):
bank_id: str
name: str
config: MemoryBankConfig
# if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
url: Optional[URL] = None
class Memory(Protocol): class Memory(Protocol):
@webmethod(route="/memory/create") memory_bank_store: MemoryBankStore
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank: ...
@webmethod(route="/memory/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory/get", method="GET")
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory/drop", method="DELETE")
async def drop_memory_bank(
self,
bank_id: str,
) -> str: ...
# this will just block now until documents are inserted, but it should # this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion # probably return a Job instance which can be polled for completion
@ -154,3 +83,6 @@ class Memory(Protocol):
bank_id: str, bank_id: str,
document_ids: List[str], document_ids: List[str],
) -> None: ... ) -> None: ...
@webmethod(route="/memory/register_memory_bank")
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...

View file

@ -6,7 +6,7 @@
import asyncio import asyncio
from typing import List, Optional from typing import Any, Dict, List, Optional
import fire import fire
import httpx import httpx
@ -15,6 +15,25 @@ from termcolor import cprint
from .memory_banks import * # noqa: F403 from .memory_banks import * # noqa: F403
def deserialize_memory_bank_def(j: Optional[Dict[str, Any]]) -> Optional[MemoryBankDef]:
if j is None:
return None
if "type" not in j:
raise ValueError("Memory bank type not specified")
type = j["type"]
if type == MemoryBankType.vector.value:
return VectorMemoryBankDef(**j)
elif type == MemoryBankType.keyvalue.value:
return KeyValueMemoryBankDef(**j)
elif type == MemoryBankType.keyword.value:
return KeywordMemoryBankDef(**j)
elif type == MemoryBankType.graph.value:
return GraphMemoryBankDef(**j)
else:
raise ValueError(f"Unknown memory bank type: {type}")
class MemoryBanksClient(MemoryBanks): class MemoryBanksClient(MemoryBanks):
def __init__(self, base_url: str): def __init__(self, base_url: str):
self.base_url = base_url self.base_url = base_url
@ -25,37 +44,36 @@ class MemoryBanksClient(MemoryBanks):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def list_available_memory_banks(self) -> List[MemoryBankSpec]: async def list_memory_banks(self) -> List[MemoryBankDef]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/memory_banks/list", f"{self.base_url}/memory_banks/list",
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
return [MemoryBankSpec(**x) for x in response.json()] return [deserialize_memory_bank_def(x) for x in response.json()]
async def get_serving_memory_bank( async def get_memory_bank(
self, bank_type: MemoryBankType self,
) -> Optional[MemoryBankSpec]: identifier: str,
) -> Optional[MemoryBankDef]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/memory_banks/get", f"{self.base_url}/memory_banks/get",
params={ params={
"bank_type": bank_type.value, "identifier": identifier,
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
j = response.json() j = response.json()
if j is None: return deserialize_memory_bank_def(j)
return None
return MemoryBankSpec(**j)
async def run_main(host: str, port: int, stream: bool): async def run_main(host: str, port: int, stream: bool):
client = MemoryBanksClient(f"http://{host}:{port}") client = MemoryBanksClient(f"http://{host}:{port}")
response = await client.list_available_memory_banks() response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green") cprint(f"list_memory_banks response={response}", "green")
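Review note: `deserialize_memory_bank_def` dispatches on the `type` field with an if/elif chain. One alternative is a lookup table, sketched below with standard-library dataclasses standing in for the pydantic models. `VectorBank`, `KeyValueBank`, `BANK_CLASSES` and `deserialize_bank` are hypothetical names, and only two of the four bank types are shown.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class VectorBank:
    identifier: str
    embedding_model: str
    chunk_size_in_tokens: int
    type: str = "vector"


@dataclass
class KeyValueBank:
    identifier: str
    type: str = "keyvalue"


# Maps the value of the "type" field to the class that should be built.
BANK_CLASSES = {"vector": VectorBank, "keyvalue": KeyValueBank}


def deserialize_bank(j: Optional[Dict[str, Any]]):
    if j is None:
        return None
    if "type" not in j:
        raise ValueError("Memory bank type not specified")
    cls = BANK_CLASSES.get(j["type"])
    if cls is None:
        raise ValueError(f"Unknown memory bank type: {j['type']}")
    return cls(**j)
```

Adding a bank type then means adding one table entry rather than another branch; the behavior for `None`, a missing `type`, and an unknown `type` matches the function in the diff.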

View file

@ -4,29 +4,67 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import List, Optional, Protocol from enum import Enum
from typing import List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.memory import MemoryBankType
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type @json_schema_type
class MemoryBankSpec(BaseModel): class MemoryBankType(Enum):
bank_type: MemoryBankType vector = "vector"
provider_config: GenericProviderConfig = Field( keyvalue = "keyvalue"
description="Provider config for the model, including provider_type, and corresponding config. ", keyword = "keyword"
) graph = "graph"
class CommonDef(BaseModel):
identifier: str
provider_id: Optional[str] = None
@json_schema_type
class VectorMemoryBankDef(CommonDef):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
@json_schema_type
class KeyValueMemoryBankDef(CommonDef):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
@json_schema_type
class KeywordMemoryBankDef(CommonDef):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
@json_schema_type
class GraphMemoryBankDef(CommonDef):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBankDef = Annotated[
Union[
VectorMemoryBankDef,
KeyValueMemoryBankDef,
KeywordMemoryBankDef,
GraphMemoryBankDef,
],
Field(discriminator="type"),
]
class MemoryBanks(Protocol): class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/list", method="GET") @webmethod(route="/memory_banks/list", method="GET")
async def list_available_memory_banks(self) -> List[MemoryBankSpec]: ... async def list_memory_banks(self) -> List[MemoryBankDef]: ...
@webmethod(route="/memory_banks/get", method="GET") @webmethod(route="/memory_banks/get", method="GET")
async def get_serving_memory_bank( async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]: ... @webmethod(route="/memory_banks/register", method="POST")
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...

View file

@ -56,7 +56,7 @@ async def run_main(host: str, port: int, stream: bool):
response = await client.list_models() response = await client.list_models()
cprint(f"list_models response={response}", "green") cprint(f"list_models response={response}", "green")
response = await client.get_model("Meta-Llama3.1-8B-Instruct") response = await client.get_model("Llama3.1-8B-Instruct")
cprint(f"get_model response={response}", "blue") cprint(f"get_model response={response}", "blue")
response = await client.get_model("Llama-Guard-3-1B") response = await client.get_model("Llama-Guard-3-1B")

View file

@ -6,27 +6,32 @@
from typing import List, Optional, Protocol from typing import List, Optional, Protocol
from llama_models.llama3.api.datatypes import Model
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type @json_schema_type
class ModelServingSpec(BaseModel): class ModelDef(BaseModel):
llama_model: Model = Field( identifier: str = Field(
description="All metadatas associated with llama model (defined in llama_models.models.sku_list).", description="A unique identifier for the model type",
) )
provider_config: GenericProviderConfig = Field( llama_model: str = Field(
description="Provider config for the model, including provider_type, and corresponding config. ", description="Pointer to the core Llama family model",
) )
provider_id: Optional[str] = Field(
default=None, description="The provider instance which serves this model"
)
# For now, we are only supporting core llama models but as soon as finetuned
# and other custom models (for example various quantizations) are allowed, there
# will be more metadata fields here
class Models(Protocol): class Models(Protocol):
@webmethod(route="/models/list", method="GET") @webmethod(route="/models/list", method="GET")
async def list_models(self) -> List[ModelServingSpec]: ... async def list_models(self) -> List[ModelDef]: ...
@webmethod(route="/models/get", method="GET") @webmethod(route="/models/get", method="GET")
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ... async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
@webmethod(route="/models/register", method="POST")
async def register_model(self, model: ModelDef) -> None: ...
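Review note: the `Models` protocol now exposes `list_models`, `get_model` and `register_model` keyed by `identifier`. The sketch below shows one way an in-memory registry could satisfy that shape. `ModelRecord` and `InMemoryModels` are hypothetical, a dataclass is used in place of the pydantic `ModelDef` to avoid third-party imports, and nothing here reflects how the project actually stores models.

```python
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ModelRecord:
    # Mirrors the three ModelDef fields shown in the diff.
    identifier: str
    llama_model: str
    provider_id: Optional[str] = None


class InMemoryModels:
    """Hypothetical registry; not the project's actual implementation."""

    def __init__(self) -> None:
        self._models: Dict[str, ModelRecord] = {}

    async def register_model(self, model: ModelRecord) -> None:
        # Re-registering the same identifier replaces the earlier entry.
        self._models[model.identifier] = model

    async def get_model(self, identifier: str) -> Optional[ModelRecord]:
        return self._models.get(identifier)

    async def list_models(self) -> List[ModelRecord]:
        return list(self._models.values())


async def demo():
    models = InMemoryModels()
    await models.register_model(ModelRecord("Llama3.1-8B-Instruct", "Llama3.1-8B-Instruct"))
    found = await models.get_model("Llama3.1-8B-Instruct")
    missing = await models.get_model("unknown")
    return found, missing, len(await models.list_models())
```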

View file

@ -96,12 +96,6 @@ async def run_main(host: str, port: int, image_path: str = None):
) )
print(response) print(response)
response = await client.run_shield(
shield_type="injection_shield",
messages=[message],
)
print(response)
def main(host: str, port: int, image: str = None): def main(host: str, port: int, image: str = None):
asyncio.run(run_main(host, port, image)) asyncio.run(run_main(host, port, image))

View file

@ -11,6 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
@json_schema_type @json_schema_type
@ -37,8 +38,17 @@ class RunShieldResponse(BaseModel):
violation: Optional[SafetyViolation] = None violation: Optional[SafetyViolation] = None
class ShieldStore(Protocol):
def get_shield(self, identifier: str) -> ShieldDef: ...
class Safety(Protocol): class Safety(Protocol):
shield_store: ShieldStore
@webmethod(route="/safety/run_shield") @webmethod(route="/safety/run_shield")
async def run_shield( async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ... ) -> RunShieldResponse: ...
@webmethod(route="/safety/register_shield")
async def register_shield(self, shield: ShieldDef) -> None: ...

View file

@ -4,25 +4,43 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import List, Optional, Protocol from enum import Enum
from typing import Any, Dict, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type @json_schema_type
class ShieldSpec(BaseModel): class ShieldType(Enum):
shield_type: str generic_content_shield = "generic_content_shield"
provider_config: GenericProviderConfig = Field( llama_guard = "llama_guard"
description="Provider config for the model, including provider_type, and corresponding config. ", code_scanner = "code_scanner"
prompt_guard = "prompt_guard"
class ShieldDef(BaseModel):
identifier: str = Field(
description="A unique identifier for the shield type",
)
type: str = Field(
description="The type of shield this is; the value is one of the ShieldType enum"
)
provider_id: Optional[str] = Field(
default=None, description="The provider instance which serves this shield"
)
params: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional parameters needed for this shield",
) )
class Shields(Protocol): class Shields(Protocol):
@webmethod(route="/shields/list", method="GET") @webmethod(route="/shields/list", method="GET")
async def list_shields(self) -> List[ShieldSpec]: ... async def list_shields(self) -> List[ShieldDef]: ...
@webmethod(route="/shields/get", method="GET") @webmethod(route="/shields/get", method="GET")
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ... async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: ...
@webmethod(route="/shields/register", method="POST")
async def register_shield(self, shield: ShieldDef) -> None: ...
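
Review note: the `Shields` protocol above now exposes list, get, and register. A minimal in-memory sketch of how those three calls might fit together is given below. It is not the implementation in this commit: `ShieldDefSketch` is a dataclass stand-in for the pydantic `ShieldDef`, `InMemoryShields` is a hypothetical class, and keying the lookup by `identifier` is an assumption (the protocol names the argument `shield_type`).

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class ShieldDefSketch:
    """Stand-in for ShieldDef (the real class is a pydantic BaseModel)."""

    identifier: str
    type: str
    provider_id: Optional[str] = None
    params: Dict[str, Any] = field(default_factory=dict)


class InMemoryShields:
    """Hypothetical in-memory take on the three Shields methods."""

    def __init__(self) -> None:
        self._shields: Dict[str, ShieldDefSketch] = {}

    async def list_shields(self) -> List[ShieldDefSketch]:
        return list(self._shields.values())

    async def get_shield(self, shield_type: str) -> Optional[ShieldDefSketch]:
        # Assumption: shields are looked up by their identifier.
        return self._shields.get(shield_type)

    async def register_shield(self, shield: ShieldDefSketch) -> None:
        # Registering the same identifier twice overwrites the earlier entry.
        self._shields[shield.identifier] = shield


async def demo() -> List[str]:
    registry = InMemoryShields()
    await registry.register_shield(
        ShieldDefSketch(
            identifier="llama_guard",
            type="llama_guard",
            provider_id="meta-reference",
        )
    )
    found = await registry.get_shield("llama_guard")
    missing = await registry.get_shield("code_scanner")
    assert found is not None and missing is None
    return [s.identifier for s in await registry.list_shields()]


print(asyncio.run(demo()))
```

Returning `None` for an unknown shield matches the `Optional[ShieldDef]` return type in the protocol; whether a real provider should raise instead is left open here.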

View file

@ -158,11 +158,10 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
info = prompt_guard_download_info() info = prompt_guard_download_info()
else: else:
model = resolve_model(args.model_id) model = resolve_model(args.model_id)
info = llama_meta_net_info(model)
if model is None: if model is None:
parser.error(f"Model {args.model_id} not found") parser.error(f"Model {args.model_id} not found")
return return
info = llama_meta_net_info(model)
if args.source == "huggingface": if args.source == "huggingface":
_hf_download(model, args.hf_token, args.ignore_patterns, parser) _hf_download(model, args.hf_token, args.ignore_patterns, parser)
@ -170,7 +169,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
meta_url = args.meta_url meta_url = args.meta_url
if not meta_url: if not meta_url:
meta_url = input( meta_url = input(
"Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): " "Please provide the signed URL you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
) )
assert meta_url is not None and "llamameta.net" in meta_url assert meta_url is not None and "llamameta.net" in meta_url
_meta_download(model, meta_url, info) _meta_download(model, meta_url, info)

View file

@ -22,7 +22,7 @@ def available_templates_specs() -> List[BuildConfig]:
import yaml import yaml
template_specs = [] template_specs = []
for p in TEMPLATES_PATH.rglob("*.yaml"): for p in TEMPLATES_PATH.rglob("*build.yaml"):
with open(p, "r") as f: with open(p, "r") as f:
build_config = BuildConfig(**yaml.safe_load(f)) build_config = BuildConfig(**yaml.safe_load(f))
template_specs.append(build_config) template_specs.append(build_config)
@ -105,8 +105,7 @@ class StackBuild(Subcommand):
import yaml import yaml
from llama_stack.distribution.build import ApiInput, build_image, ImageType from llama_stack.distribution.build import build_image, ImageType
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.serialize import EnumEncoder from llama_stack.distribution.utils.serialize import EnumEncoder
from termcolor import cprint from termcolor import cprint
@ -137,16 +136,19 @@ class StackBuild(Subcommand):
if build_config.image_type == "conda" if build_config.image_type == "conda"
else (f"llamastack-{build_config.name}") else (f"llamastack-{build_config.name}")
) )
if build_config.image_type == "conda":
cprint( cprint(
f"You can now run `llama stack configure {configure_name}`", f"You can now run `llama stack configure {configure_name}`",
color="green", color="green",
) )
else:
cprint(
f"You can now run `llama stack run {build_config.name}`",
color="green",
)
def _run_template_list_cmd(self, args: argparse.Namespace) -> None: def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
import json import json
import yaml
from llama_stack.cli.table import print_table from llama_stack.cli.table import print_table
# eventually, this should query a registry at llama.meta.com/llamastack/distributions # eventually, this should query a registry at llama.meta.com/llamastack/distributions
@ -172,9 +174,11 @@ class StackBuild(Subcommand):
) )
def _run_stack_build_command(self, args: argparse.Namespace) -> None: def _run_stack_build_command(self, args: argparse.Namespace) -> None:
import textwrap
import yaml import yaml
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from prompt_toolkit import prompt from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.validation import Validator from prompt_toolkit.validation import Validator
from termcolor import cprint from termcolor import cprint
@ -238,26 +242,29 @@ class StackBuild(Subcommand):
) )
cprint( cprint(
"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.", textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs.
""",
),
color="green", color="green",
) )
print("Tip: use <TAB> to see options for the providers.\n")
providers = dict() providers = dict()
for api, providers_for_api in get_provider_registry().items(): for api, providers_for_api in get_provider_registry().items():
available_providers = [
x for x in providers_for_api.keys() if x != "remote"
]
api_provider = prompt( api_provider = prompt(
"> Enter provider for the {} API: (default=meta-reference): ".format( "> Enter provider for API {}: ".format(api.value),
api.value completer=WordCompleter(available_providers),
), complete_while_typing=True,
validator=Validator.from_callable( validator=Validator.from_callable(
lambda x: x in providers_for_api, lambda x: x in available_providers,
error_message="Invalid provider, please enter one of the following: {}".format( error_message="Invalid provider, use <TAB> to see options",
list(providers_for_api.keys())
),
),
default=(
"meta-reference"
if "meta-reference" in providers_for_api
else list(providers_for_api.keys())[0]
), ),
) )

View file

@ -71,10 +71,8 @@ class StackConfigure(Subcommand):
conda_dir = ( conda_dir = (
Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}" Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}"
) )
output = subprocess.check_output( output = subprocess.check_output(["bash", "-c", "conda info --json"])
["bash", "-c", "conda info --json -a | jq '.envs'"] conda_envs = json.loads(output.decode("utf-8"))["envs"]
)
conda_envs = json.loads(output.decode("utf-8"))
for x in conda_envs: for x in conda_envs:
if x.endswith(f"/llamastack-{args.config}"): if x.endswith(f"/llamastack-{args.config}"):
@ -129,7 +127,10 @@ class StackConfigure(Subcommand):
import yaml import yaml
from termcolor import cprint from termcolor import cprint
from llama_stack.distribution.configure import configure_api_providers from llama_stack.distribution.configure import (
configure_api_providers,
parse_and_maybe_upgrade_config,
)
from llama_stack.distribution.utils.serialize import EnumEncoder from llama_stack.distribution.utils.serialize import EnumEncoder
builds_dir = BUILDS_BASE_DIR / build_config.image_type builds_dir = BUILDS_BASE_DIR / build_config.image_type
@ -145,13 +146,17 @@ class StackConfigure(Subcommand):
"yellow", "yellow",
attrs=["bold"], attrs=["bold"],
) )
config = StackRunConfig(**yaml.safe_load(run_config_file.read_text())) config_dict = yaml.safe_load(run_config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
else: else:
config = StackRunConfig( config = StackRunConfig(
built_at=datetime.now(), built_at=datetime.now(),
image_name=image_name, image_name=image_name,
apis_to_serve=[], apis=list(build_config.distribution_spec.providers.keys()),
api_providers={}, providers={},
models=[],
shields=[],
memory_banks=[],
) )
config = configure_api_providers(config, build_config.distribution_spec) config = configure_api_providers(config, build_config.distribution_spec)
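
Review note: earlier in this file's diff, the conda lookup stops piping through `jq` and reads the `envs` key from the `conda info --json` output in Python instead. A small sketch of that lookup, separated from the subprocess call so it can be exercised without conda installed, might look like the following. `find_conda_env` and the sample JSON are hypothetical; only the `envs` key and the `endswith` check come from the diff.

```python
import json
from typing import List, Optional


def find_conda_env(conda_info_json: str, env_name: str) -> Optional[str]:
    """Return the path of the env whose last path component is env_name."""
    envs: List[str] = json.loads(conda_info_json)["envs"]
    for path in envs:
        if path.endswith(f"/{env_name}"):
            return path
    return None


# Hypothetical output, trimmed to the only key the lookup reads.
sample = json.dumps({"envs": ["/opt/conda", "/opt/conda/envs/llamastack-foo"]})
print(find_conda_env(sample, "llamastack-foo"))
```

Parsing in Python removes the runtime dependency on `jq`, which is presumably the motivation for the change.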

View file

@ -7,7 +7,6 @@
import argparse import argparse
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import * # noqa: F403
class StackRun(Subcommand): class StackRun(Subcommand):
@ -46,10 +45,11 @@ class StackRun(Subcommand):
import pkg_resources import pkg_resources
import yaml import yaml
from termcolor import cprint
from llama_stack.distribution.build import ImageType from llama_stack.distribution.build import ImageType
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_with_pty from llama_stack.distribution.utils.exec import run_with_pty
if not args.config: if not args.config:
@ -75,8 +75,10 @@ class StackRun(Subcommand):
) )
return return
cprint(f"Using config `{config_file}`", "green")
with open(config_file, "r") as f: with open(config_file, "r") as f:
config = StackRunConfig(**yaml.safe_load(f)) config_dict = yaml.safe_load(config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if config.docker_image: if config.docker_image:
script = pkg_resources.resource_filename( script = pkg_resources.resource_filename(

View file

@ -1,105 +0,0 @@
from argparse import Namespace
from unittest.mock import MagicMock, patch
import pytest
from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.cli.stack.build import StackBuild
# temporary while we make the tests work
pytest.skip(allow_module_level=True)
@pytest.fixture
def stack_build():
parser = MagicMock()
subparsers = MagicMock()
return StackBuild(subparsers)
def test_stack_build_initialization(stack_build):
assert stack_build.parser is not None
assert stack_build.parser.set_defaults.called_once_with(
func=stack_build._run_stack_build_command
)
@patch("llama_stack.distribution.build.build_image")
def test_run_stack_build_command_with_config(
mock_build_image, mock_build_config, stack_build
):
args = Namespace(
config="test_config.yaml",
template=None,
list_templates=False,
name=None,
image_type="conda",
)
with patch("builtins.open", MagicMock()):
with patch("yaml.safe_load") as mock_yaml_load:
mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
mock_build_config.return_value = MagicMock()
stack_build._run_stack_build_command(args)
mock_build_config.assert_called_once()
mock_build_image.assert_called_once()
@patch("llama_stack.cli.table.print_table")
def test_run_stack_build_command_list_templates(mock_print_table, stack_build):
args = Namespace(list_templates=True)
stack_build._run_stack_build_command(args)
mock_print_table.assert_called_once()
@patch("prompt_toolkit.prompt")
@patch("llama_stack.distribution.datatypes.BuildConfig")
@patch("llama_stack.distribution.build.build_image")
def test_run_stack_build_command_interactive(
mock_build_image, mock_build_config, mock_prompt, stack_build
):
args = Namespace(
config=None, template=None, list_templates=False, name=None, image_type=None
)
mock_prompt.side_effect = [
"test_name",
"conda",
"meta-reference",
"test description",
]
mock_build_config.return_value = MagicMock()
stack_build._run_stack_build_command(args)
assert mock_prompt.call_count == 4
mock_build_config.assert_called_once()
mock_build_image.assert_called_once()
@patch("llama_stack.distribution.datatypes.BuildConfig")
@patch("llama_stack.distribution.build.build_image")
def test_run_stack_build_command_with_template(
mock_build_image, mock_build_config, stack_build
):
args = Namespace(
config=None,
template="test_template",
list_templates=False,
name="test_name",
image_type="docker",
)
with patch("builtins.open", MagicMock()):
with patch("yaml.safe_load") as mock_yaml_load:
mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
mock_build_config.return_value = MagicMock()
stack_build._run_stack_build_command(args)
mock_build_config.assert_called_once()
mock_build_image.assert_called_once()

View file

@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
import pytest
import yaml
from llama_stack.distribution.configure import (
LLAMA_STACK_RUN_CONFIG_VERSION,
parse_and_maybe_upgrade_config,
)
@pytest.fixture
def up_to_date_config():
return yaml.safe_load(
"""
version: {version}
image_name: foo
apis_to_serve: []
built_at: {built_at}
models:
- identifier: model1
provider_id: provider1
llama_model: Llama3.1-8B-Instruct
shields:
- identifier: shield1
type: llama_guard
provider_id: provider1
memory_banks:
- identifier: memory1
type: vector
provider_id: provider1
embedding_model: all-MiniLM-L6-v2
chunk_size_in_tokens: 512
providers:
inference:
- provider_id: provider1
provider_type: meta-reference
config: {{}}
safety:
- provider_id: provider1
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
enable_prompt_guard: false
memory:
- provider_id: provider1
provider_type: meta-reference
config: {{}}
""".format(
version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
)
)
@pytest.fixture
def old_config():
return yaml.safe_load(
"""
image_name: foo
built_at: {built_at}
apis_to_serve: []
routing_table:
inference:
- provider_type: remote::ollama
config:
host: localhost
port: 11434
routing_key: Llama3.2-1B-Instruct
- provider_type: meta-reference
config:
model: Llama3.1-8B-Instruct
routing_key: Llama3.1-8B-Instruct
safety:
- routing_key: ["shield1", "shield2"]
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
enable_prompt_guard: false
memory:
- routing_key: vector
provider_type: meta-reference
config: {{}}
api_providers:
telemetry:
provider_type: noop
config: {{}}
""".format(
built_at=datetime.now().isoformat()
)
)
@pytest.fixture
def invalid_config():
return yaml.safe_load(
"""
routing_table: {}
api_providers: {}
"""
)
def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
result = parse_and_maybe_upgrade_config(up_to_date_config)
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
assert len(result.models) == 1
assert len(result.shields) == 1
assert len(result.memory_banks) == 1
assert "inference" in result.providers
def test_parse_and_maybe_upgrade_config_old_format(old_config):
result = parse_and_maybe_upgrade_config(old_config)
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
assert len(result.models) == 2
assert len(result.shields) == 2
assert len(result.memory_banks) == 1
assert all(
api in result.providers
for api in ["inference", "safety", "memory", "telemetry"]
)
safety_provider = result.providers["safety"][0]
assert safety_provider.provider_type == "meta-reference"
assert "llama_guard_shield" in safety_provider.config
inference_providers = result.providers["inference"]
assert len(inference_providers) == 2
assert set(x.provider_id for x in inference_providers) == {
"remote::ollama-00",
"meta-reference-01",
}
ollama = inference_providers[0]
assert ollama.provider_type == "remote::ollama"
assert ollama.config["port"] == 11434
def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
with pytest.raises(ValueError):
parse_and_maybe_upgrade_config(invalid_config)

View file

@ -8,15 +8,16 @@ from enum import Enum
from typing import List, Optional from typing import List, Optional
import pkg_resources import pkg_resources
from llama_stack.distribution.utils.exec import run_with_pty
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from pathlib import Path from pathlib import Path
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
@ -95,6 +96,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
build_config.name, build_config.name,
package_deps.docker_image, package_deps.docker_image,
str(build_file_path), str(build_file_path),
str(BUILDS_BASE_DIR / ImageType.docker.value),
" ".join(deps), " ".join(deps),
] ]
else: else:

View file

@ -23,7 +23,7 @@ if [ "$#" -lt 3 ]; then
exit 1 exit 1
fi fi
special_pip_deps="$3" special_pip_deps="$4"
set -euo pipefail set -euo pipefail

View file

@ -10,7 +10,7 @@ if [ "$#" -lt 4 ]; then
exit 1 exit 1
fi fi
special_pip_deps="$5" special_pip_deps="$6"
set -euo pipefail set -euo pipefail
@ -18,7 +18,8 @@ build_name="$1"
image_name="llamastack-$build_name" image_name="llamastack-$build_name"
docker_base=$2 docker_base=$2
build_file_path=$3 build_file_path=$3
pip_dependencies=$4 host_build_dir=$4
pip_dependencies=$5
# Define color codes # Define color codes
RED='\033[0;31m' RED='\033[0;31m'
@ -33,7 +34,8 @@ REPO_CONFIGS_DIR="$REPO_DIR/tmp/configs"
TEMP_DIR=$(mktemp -d) TEMP_DIR=$(mktemp -d)
llama stack configure $build_file_path --output-dir $REPO_CONFIGS_DIR llama stack configure $build_file_path
cp $host_build_dir/$build_name-run.yaml $REPO_CONFIGS_DIR
add_to_docker() { add_to_docker() {
local input local input
@ -132,6 +134,9 @@ fi
set -x set -x
$DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts $DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
# clean up tmp/configs
rm -rf $REPO_CONFIGS_DIR
set +x set +x
echo "Success! You can run it with: $DOCKER_BINARY $DOCKER_OPTS run -p 5000:5000 $image_name" echo "Success! You can run it with: $DOCKER_BINARY $DOCKER_OPTS run -p 5000:5000 $image_name"

View file

@ -3,171 +3,369 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import textwrap
from typing import Any from typing import Any
from pydantic import BaseModel from llama_models.sku_list import (
llama3_1_family,
llama3_2_family,
llama3_family,
resolve_model,
safety_models,
)
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from prompt_toolkit import prompt from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator from prompt_toolkit.validation import Validator
from termcolor import cprint from termcolor import cprint
from llama_stack.apis.memory.memory import MemoryBankType
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
get_provider_registry, get_provider_registry,
stack_apis,
) )
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.impls.meta_reference.safety.config import (
MetaReferenceShieldType,
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
ALLOWED_MODELS = (
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
) )
def make_routing_entry_type(config_class: Any): def configure_single_provider(
class BaseModelWithConfig(BaseModel): registry: Dict[str, ProviderSpec], provider: Provider
routing_key: str ) -> Provider:
config: config_class provider_spec = registry[provider.provider_type]
return BaseModelWithConfig
def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
"""Get corresponding builtin APIs given provider backed APIs"""
res = []
for inf in builtin_automatically_routed_apis():
if inf.router_api.value in provider_backed_apis:
res.append(inf.routing_table_api.value)
return res
# TODO: make sure we can deal with existing configuration values correctly
# instead of just overwriting them
def configure_api_providers(
config: StackRunConfig, spec: DistributionSpec
) -> StackRunConfig:
apis = config.apis_to_serve or list(spec.providers.keys())
# append the bulitin routing APIs
apis += get_builtin_apis(apis)
router_api2builtin_api = {
inf.router_api.value: inf.routing_table_api.value
for inf in builtin_automatically_routed_apis()
}
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
apis = [v.value for v in stack_apis()]
all_providers = get_provider_registry()
# configure simple case for with non-routing providers to api_providers
for api_str in spec.providers.keys():
if api_str not in apis:
raise ValueError(f"Unknown API `{api_str}`")
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
api = Api(api_str)
p = spec.providers[api_str]
cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
if isinstance(p, list):
cprint(
f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
"yellow",
)
p = p[0]
provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
try: try:
provider_config = config.api_providers.get(api_str) if provider.config:
if provider_config: existing = config_type(**provider.config)
existing = config_type(**provider_config.config)
else: else:
existing = None existing = None
except Exception: except Exception:
existing = None existing = None
cfg = prompt_for_config(config_type, existing) cfg = prompt_for_config(config_type, existing)
return Provider(
if api_str in router_api2builtin_api: provider_id=provider.provider_id,
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table provider_type=provider.provider_type,
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
routing_entries = []
if api_str == "inference":
if hasattr(cfg, "model"):
routing_key = cfg.model
else:
routing_key = prompt(
"> Please enter the supported model your provider has for inference: ",
default="Meta-Llama3.1-8B-Instruct",
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(), config=cfg.dict(),
) )
def configure_api_providers(
config: StackRunConfig, build_spec: DistributionSpec
) -> StackRunConfig:
is_nux = len(config.providers) == 0
if is_nux:
print(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. For each API served by the Stack,
we need to configure the providers (implementations) you want to use for these APIs.
"""
)
) )
if api_str == "safety": provider_registry = get_provider_registry()
# TODO: add support for other safety providers, and simplify safety provider config builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
if p == "meta-reference":
routing_entries.append( if config.apis:
RoutableProviderConfig( apis_to_serve = config.apis
routing_key=[s.value for s in MetaReferenceShieldType],
provider_type=p,
config=cfg.dict(),
)
)
else: else:
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
for api_str in apis_to_serve:
api = Api(api_str)
if api in builtin_apis:
continue
if api not in provider_registry:
raise ValueError(f"Unknown API `{api_str}`")
existing_providers = config.providers.get(api_str, [])
if existing_providers:
cprint( cprint(
f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.", f"Re-configuring existing providers for API `{api_str}`...",
"yellow", "green",
attrs=["bold"], attrs=["bold"],
) )
routing_entries.append( updated_providers = []
RoutableProviderConfig( for p in existing_providers:
routing_key=routing_key, print(f"> Configuring provider `({p.provider_type})`")
provider_type=p, updated_providers.append(
config=cfg.dict(), configure_single_provider(provider_registry[api], p)
)
)
if api_str == "memory":
bank_types = list([x.value for x in MemoryBankType])
routing_key = prompt(
"> Please enter the supported memory bank type your provider has for memory: ",
default="vector",
validator=Validator.from_callable(
lambda x: x in bank_types,
error_message="Invalid provider, please enter one of the following: {}".format(
bank_types
),
),
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
)
)
config.routing_table[api_str] = routing_entries
config.api_providers[api_str] = PlaceholderProviderConfig(
providers=p if isinstance(p, list) else [p]
) )
print("")
else: else:
config.api_providers[api_str] = GenericProviderConfig( # we are newly configuring this API
provider_type=p, plist = build_spec.providers.get(api_str, [])
config=cfg.dict(), plist = plist if isinstance(plist, list) else [plist]
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
updated_providers = []
for i, provider_type in enumerate(plist):
print(f"> Configuring provider `({provider_type})`")
updated_providers.append(
configure_single_provider(
provider_registry[api],
Provider(
provider_id=(
f"{provider_type}-{i:02d}"
if len(plist) > 1
else provider_type
),
provider_type=provider_type,
config={},
),
)
)
print("")
config.providers[api_str] = updated_providers
if is_nux:
print(
textwrap.dedent(
"""
=========================================================================================
Now let's configure the `objects` you will be serving via the stack. These are:
- Models: the Llama model SKUs you expect to inference (e.g., Llama3.2-1B-Instruct)
- Shields: the safety models you expect to use for safety (e.g., Llama-Guard-3-1B)
- Memory Banks: the memory banks you expect to use for memory (e.g., Vector stores)
This wizard will guide you through setting up one of each of these objects. You can
always add more later by editing the run.yaml file.
"""
)
) )
object_types = {
"models": (ModelDef, configure_models, "inference"),
"shields": (ShieldDef, configure_shields, "safety"),
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
}
safety_providers = config.providers.get("safety", [])
for otype, (odef, config_method, api_str) in object_types.items():
existing_objects = getattr(config, otype)
if existing_objects:
cprint(
f"{len(existing_objects)} {otype} exist. Skipping...",
"blue",
attrs=["bold"],
)
updated_objects = existing_objects
else:
providers = config.providers.get(api_str, [])
if not providers:
updated_objects = []
else:
# we are newly configuring this API
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
updated_objects = config_method(
config.providers[api_str], safety_providers
)
setattr(config, otype, updated_objects)
print("") print("")
return config return config
def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
if not safety_providers:
return None
provider = safety_providers[0]
assert provider.provider_type == "meta-reference"
cfg = provider.config["llama_guard_shield"]
if not cfg:
return None
return cfg["model"]
def configure_models(
providers: List[Provider], safety_providers: List[Provider]
) -> List[ModelDef]:
model = prompt(
"> Please enter the model you want to serve: ",
default="Llama3.2-1B-Instruct",
validator=Validator.from_callable(
lambda x: resolve_model(x) is not None,
error_message="Model must be: {}".format(
[x.descriptor() for x in ALLOWED_MODELS]
),
),
)
model = ModelDef(
identifier=model,
llama_model=model,
provider_id=providers[0].provider_id,
)
ret = [model]
if llama_guard := get_llama_guard_model(safety_providers):
ret.append(
ModelDef(
identifier=llama_guard,
llama_model=llama_guard,
provider_id=providers[0].provider_id,
)
)
return ret
def configure_shields(
providers: List[Provider], safety_providers: List[Provider]
) -> List[ShieldDef]:
if get_llama_guard_model(safety_providers):
return [
ShieldDef(
identifier="llama_guard",
type="llama_guard",
provider_id=providers[0].provider_id,
params={},
)
]
return []
def configure_memory_banks(
providers: List[Provider], safety_providers: List[Provider]
) -> List[MemoryBankDef]:
bank_name = prompt(
"> Please enter a name for your memory bank: ",
default="my-memory-bank",
)
return [
VectorMemoryBankDef(
identifier=bank_name,
provider_id=providers[0].provider_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
)
]
def upgrade_from_routing_table_to_registry(
config_dict: Dict[str, Any],
) -> Dict[str, Any]:
def get_providers(entries):
return [
Provider(
provider_id=(
f"{entry['provider_type']}-{i:02d}"
if len(entries) > 1
else entry["provider_type"]
),
provider_type=entry["provider_type"],
config=entry["config"],
)
for i, entry in enumerate(entries)
]
providers_by_api = {}
models = []
shields = []
memory_banks = []
routing_table = config_dict.get("routing_table", {})
for api_str, entries in routing_table.items():
providers = get_providers(entries)
providers_by_api[api_str] = providers
if api_str == "inference":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
models.append(
ModelDef(
identifier=key,
provider_id=provider.provider_id,
llama_model=key,
)
)
elif api_str == "safety":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
shields.append(
ShieldDef(
identifier=key,
type=ShieldType.llama_guard.value,
provider_id=provider.provider_id,
)
)
elif api_str == "memory":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
# we currently only support Vector memory banks so this is OK
memory_banks.append(
VectorMemoryBankDef(
identifier=key,
provider_id=provider.provider_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
)
)
config_dict["models"] = models
config_dict["shields"] = shields
config_dict["memory_banks"] = memory_banks
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
if provider_map:
for api_str, provider in provider_map.items():
if isinstance(provider, dict) and "provider_type" in provider:
providers_by_api[api_str] = [
Provider(
provider_id=f"{provider['provider_type']}",
provider_type=provider["provider_type"],
config=provider["config"],
)
]
config_dict["providers"] = providers_by_api
config_dict.pop("routing_table", None)
config_dict.pop("api_providers", None)
config_dict.pop("provider_map", None)
config_dict["apis"] = config_dict.pop("apis_to_serve", [])
return config_dict
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**config_dict)
if "models" not in config_dict:
print("Upgrading config...")
config_dict = upgrade_from_routing_table_to_registry(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
config_dict["built_at"] = datetime.now().isoformat()
return StackRunConfig(**config_dict)
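The provider naming rule used by the upgrade path above can be sketched in isolation. This is a minimal illustration on plain dicts, not the project's code: `provider_ids` and `upgrade_inference` are hypothetical helper names, and the sample entries are invented. It assumes only what the diff shows: an index suffix is added when an API has more than one routing entry, and a `routing_key` may be a string or a list of strings.

```python
from typing import Any, Dict, List, Tuple


def provider_ids(entries: List[Dict[str, Any]]) -> List[str]:
    """Suffix the provider type with a two-digit index only when there are several entries."""
    return [
        f"{entry['provider_type']}-{i:02d}" if len(entries) > 1 else entry["provider_type"]
        for i, entry in enumerate(entries)
    ]


def upgrade_inference(
    entries: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Hypothetical helper: turn old routing entries into (providers, models) as plain dicts."""
    ids = provider_ids(entries)
    providers = [
        {"provider_id": pid, "provider_type": e["provider_type"], "config": e["config"]}
        for pid, e in zip(ids, entries)
    ]
    models = []
    for pid, e in zip(ids, entries):
        key = e["routing_key"]
        # A routing key may be a single string or a list of strings.
        keys = key if isinstance(key, list) else [key]
        for k in keys:
            models.append({"identifier": k, "llama_model": k, "provider_id": pid})
    return providers, models


# Invented sample data in the old routing_table shape.
old_entries = [
    {"routing_key": "Llama3.1-8B-Instruct", "provider_type": "meta-reference", "config": {}},
    {
        "routing_key": ["Llama3.2-1B-Instruct", "Llama3.2-3B-Instruct"],
        "provider_type": "remote::ollama",
        "config": {},
    },
]
providers, models = upgrade_inference(old_entries)
```

With two entries, the providers are named `meta-reference-00` and `remote::ollama-01`, and each routing key becomes one model definition pointing at its provider. A single entry keeps the bare provider type as its ID.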

@ -11,28 +11,32 @@ from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
LLAMA_STACK_BUILD_CONFIG_VERSION = "v1" LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "v1" LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = Union[str, List[str]] RoutingKey = Union[str, List[str]]
class GenericProviderConfig(BaseModel): RoutableObject = Union[
provider_type: str ModelDef,
config: Dict[str, Any] ShieldDef,
MemoryBankDef,
]
RoutedProtocol = Union[
class RoutableProviderConfig(GenericProviderConfig): Inference,
routing_key: RoutingKey Safety,
Memory,
]
class PlaceholderProviderConfig(BaseModel):
"""Placeholder provider config for API whose provider are defined in routing_table"""
providers: List[str]
# Example: /inference, /safety # Example: /inference, /safety
@ -53,18 +57,17 @@ class AutoRoutedProviderSpec(ProviderSpec):
# Example: /models, /shields # Example: /models, /shields
@json_schema_type
class RoutingTableProviderSpec(ProviderSpec): class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table" provider_type: str = "routing_table"
config_class: str = "" config_class: str = ""
docker_image: Optional[str] = None docker_image: Optional[str] = None
inner_specs: List[ProviderSpec] router_api: Api
registry: List[RoutableObject]
module: str module: str
pip_packages: List[str] = Field(default_factory=list) pip_packages: List[str] = Field(default_factory=list)
@json_schema_type
class DistributionSpec(BaseModel): class DistributionSpec(BaseModel):
description: Optional[str] = Field( description: Optional[str] = Field(
default="", default="",
@ -80,7 +83,12 @@ in the runtime configuration to help route to the correct provider.""",
) )
@json_schema_type class Provider(BaseModel):
provider_id: str
provider_type: str
config: Dict[str, Any]
class StackRunConfig(BaseModel): class StackRunConfig(BaseModel):
version: str = LLAMA_STACK_RUN_CONFIG_VERSION version: str = LLAMA_STACK_RUN_CONFIG_VERSION
built_at: datetime built_at: datetime
@ -100,36 +108,39 @@ this could be just a hash
default=None, default=None,
description="Reference to the conda environment if this package refers to a conda environment", description="Reference to the conda environment if this package refers to a conda environment",
) )
apis_to_serve: List[str] = Field( apis: List[str] = Field(
default_factory=list,
description=""" description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""", The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
) )
api_providers: Dict[ providers: Dict[str, List[Provider]] = Field(
str, Union[GenericProviderConfig, PlaceholderProviderConfig]
] = Field(
description=""" description="""
Provider configurations for each of the APIs provided by this package. One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary.
""", """,
) )
routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
default_factory=dict,
description="""
E.g. The following is a ProviderRoutingEntry for models: models: List[ModelDef] = Field(
- routing_key: Meta-Llama3.1-8B-Instruct description="""
provider_type: meta-reference List of model definitions to serve. This list may get extended by
config: /models/register API calls at runtime.
model: Meta-Llama3.1-8B-Instruct """,
quantization: null )
torch_seed: null shields: List[ShieldDef] = Field(
max_seq_len: 4096 description="""
max_batch_size: 1 List of shield definitions to serve. This list may get extended by
/shields/register API calls at runtime.
""",
)
memory_banks: List[MemoryBankDef] = Field(
description="""
List of memory bank definitions to serve. This list may get extended by
/memory_banks/register API calls at runtime.
""", """,
) )
@json_schema_type
class BuildConfig(BaseModel): class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
name: str name: str

@ -6,45 +6,58 @@
from typing import Dict, List from typing import Dict, List
from llama_stack.apis.inspect import * # noqa: F403 from llama_stack.apis.inspect import * # noqa: F403
from pydantic import BaseModel
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
def is_passthrough(spec: ProviderSpec) -> bool: class DistributionInspectConfig(BaseModel):
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None run_config: StackRunConfig
async def get_provider_impl(config, deps):
impl = DistributionInspectImpl(config, deps)
await impl.initialize()
return impl
class DistributionInspectImpl(Inspect): class DistributionInspectImpl(Inspect):
def __init__(self): def __init__(self, config, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass pass
async def list_providers(self) -> Dict[str, List[ProviderInfo]]: async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
run_config = self.config.run_config
ret = {} ret = {}
all_providers = get_provider_registry() for api, providers in run_config.providers.items():
for api, providers in all_providers.items(): ret[api] = [
ret[api.value] = [
ProviderInfo( ProviderInfo(
provider_id=p.provider_id,
provider_type=p.provider_type, provider_type=p.provider_type,
description="Passthrough" if is_passthrough(p) else "",
) )
for p in providers.values() for p in providers
] ]
return ret return ret
async def list_routes(self) -> Dict[str, List[RouteInfo]]: async def list_routes(self) -> Dict[str, List[RouteInfo]]:
run_config = self.config.run_config
ret = {} ret = {}
all_endpoints = get_all_api_endpoints() all_endpoints = get_all_api_endpoints()
for api, endpoints in all_endpoints.items(): for api, endpoints in all_endpoints.items():
providers = run_config.providers.get(api.value, [])
ret[api.value] = [ ret[api.value] = [
RouteInfo( RouteInfo(
route=e.route, route=e.route,
method=e.method, method=e.method,
providers=[], provider_types=[p.provider_type for p in providers],
) )
for e in endpoints for e in endpoints
] ]

@ -13,138 +13,207 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
get_provider_registry, get_provider_registry,
) )
from llama_stack.distribution.inspect import DistributionInspectImpl
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
class ProviderWithSpec(Provider):
spec: ProviderSpec
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]: async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
""" """
Does two things: Does two things:
- flatmaps, sorts and resolves the providers in dependency order - flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation - for each API, produces either a (local, passthrough or router) implementation
""" """
all_providers = get_provider_registry() all_api_providers = get_provider_registry()
specs = {}
configs = {}
for api_str, config in run_config.api_providers.items(): routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
providers_with_specs = {}
for api_str, providers in run_config.providers.items():
api = Api(api_str) api = Api(api_str)
if api in routing_table_apis:
# TODO: check that these APIs are not in the routing table part of the config
providers = all_providers[api]
# skip checks for API whose provider config is specified in routing_table
if isinstance(config, PlaceholderProviderConfig):
continue
if config.provider_type not in providers:
raise ValueError( raise ValueError(
f"Provider `{config.provider_type}` is not available for API `{api}`" f"Provider for `{api_str}` is automatically provided and cannot be overridden"
) )
specs[api] = providers[config.provider_type]
configs[api] = config
apis_to_serve = run_config.apis_to_serve or set( specs = {}
list(specs.keys()) + list(run_config.routing_table.keys()) for provider in providers:
if provider.provider_type not in all_api_providers[api]:
raise ValueError(
f"Provider `{provider.provider_type}` is not available for API `{api}`"
) )
p = all_api_providers[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec(
spec=p,
**(provider.dict()),
)
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys())
+ [x.value for x in routing_table_apis]
+ [x.value for x in router_apis]
)
for info in builtin_automatically_routed_apis(): for info in builtin_automatically_routed_apis():
source_api = info.routing_table_api
assert (
source_api not in specs
), f"Routing table API {source_api} specified in wrong place?"
assert (
info.router_api not in specs
), f"Auto-routed API {info.router_api} specified in wrong place?"
if info.router_api.value not in apis_to_serve: if info.router_api.value not in apis_to_serve:
continue continue
if info.router_api.value not in run_config.routing_table: available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
routing_table = run_config.routing_table[info.router_api.value]
providers = all_providers[info.router_api]
inner_specs = []
inner_deps = [] inner_deps = []
for rt_entry in routing_table: registry = getattr(run_config, info.routing_table_api.value)
if rt_entry.provider_type not in providers: for entry in registry:
if entry.provider_id not in available_providers:
raise ValueError( raise ValueError(
f"Provider `{rt_entry.provider_type}` is not available for API `{api}`" f"Provider `{entry.provider_id}` not found. Available providers: {list(available_providers.keys())}"
) )
inner_specs.append(providers[rt_entry.provider_type])
inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)
specs[source_api] = RoutingTableProviderSpec( provider = available_providers[entry.provider_id]
api=source_api, inner_deps.extend(provider.spec.api_dependencies)
providers_with_specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__builtin__",
provider_type="__routing_table__",
config={},
spec=RoutingTableProviderSpec(
api=info.routing_table_api,
router_api=info.router_api,
registry=registry,
module="llama_stack.distribution.routers", module="llama_stack.distribution.routers",
api_dependencies=inner_deps, api_dependencies=inner_deps,
inner_specs=inner_specs, deps__=(
[x.value for x in inner_deps]
+ [f"inner-{info.router_api.value}"]
),
),
) )
configs[source_api] = routing_table }
specs[info.router_api] = AutoRoutedProviderSpec( providers_with_specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__builtin__",
provider_type="__autorouted__",
config={},
spec=AutoRoutedProviderSpec(
api=info.router_api, api=info.router_api,
module="llama_stack.distribution.routers", module="llama_stack.distribution.routers",
routing_table_api=source_api, routing_table_api=info.routing_table_api,
api_dependencies=[source_api], api_dependencies=[info.routing_table_api],
deps__=([info.routing_table_api.value]),
),
) )
configs[info.router_api] = {} }
sorted_specs = topological_sort(specs.values()) sorted_providers = topological_sort(
print(f"Resolved {len(sorted_specs)} providers in topological order") {k: v.values() for k, v in providers_with_specs.items()}
for spec in sorted_specs: )
print(f" {spec.api}: {spec.provider_type}") apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
(
"inspect",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={
"run_config": run_config.dict(),
},
spec=InlineProviderSpec(
api=Api.inspect,
provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect",
api_dependencies=apis,
deps__=([x.value for x in apis]),
),
),
)
)
print(f"Resolved {len(sorted_providers)} providers in topological order")
for api_str, provider in sorted_providers:
print(f" {api_str}: ({provider.provider_id}) {provider.spec.provider_type}")
print("") print("")
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, configs[api])
impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[
f"inner-{provider.spec.router_api.value}"
]
impl = await instantiate_provider(
provider,
deps,
inner_impls,
)
# TODO: ugh slightly redesign this shady looking code
if "inner-" in api_str:
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
else:
api = Api(api_str)
impls[api] = impl impls[api] = impl
impls[Api.inspect] = DistributionInspectImpl() return impls
specs[Api.inspect] = InlineProviderSpec(
api=Api.inspect,
provider_type="__distribution_builtin__",
config_class="",
module="",
)
return impls, specs
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]: def topological_sort(
by_id = {x.api: x for x in providers} providers_with_specs: Dict[str, List[ProviderWithSpec]],
) -> List[ProviderWithSpec]:
def dfs(kv, visited: Set[str], stack: List[str]):
api_str, providers = kv
visited.add(api_str)
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]): deps = []
visited.add(a.api) for provider in providers:
for dep in provider.spec.deps__:
deps.append(dep)
for api in a.api_dependencies: for dep in deps:
if api not in visited: if dep not in visited:
dfs(by_id[api], visited, stack) dfs((dep, providers_with_specs[dep]), visited, stack)
stack.append(a.api) stack.append(api_str)
visited = set() visited = set()
stack = [] stack = []
for a in providers: for api_str, providers in providers_with_specs.items():
if a.api not in visited: if api_str not in visited:
dfs(a, visited, stack) dfs((api_str, providers), visited, stack)
return [by_id[x] for x in stack] flattened = []
for api_str in stack:
for provider in providers_with_specs[api_str]:
flattened.append((api_str, provider))
return flattened
# returns a class implementing the protocol corresponding to the Api # returns a class implementing the protocol corresponding to the Api
async def instantiate_provider( async def instantiate_provider(
provider_spec: ProviderSpec, provider: ProviderWithSpec,
deps: Dict[str, Any], deps: Dict[str, Any],
provider_config: Union[GenericProviderConfig, RoutingTable], inner_impls: Dict[str, Any],
): ):
provider_spec = provider.spec
module = importlib.import_module(provider_spec.module) module = importlib.import_module(provider_spec.module)
args = [] args = []
@ -154,9 +223,8 @@ async def instantiate_provider(
else: else:
method = "get_client_impl" method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config) config = config_type(**provider.config)
args = [config, deps] args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec): elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl" method = "get_auto_router_impl"
@ -166,31 +234,18 @@ async def instantiate_provider(
elif isinstance(provider_spec, RoutingTableProviderSpec): elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl" method = "get_routing_table_impl"
assert isinstance(provider_config, List)
routing_table = provider_config
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_type],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None config = None
args = [provider_spec.api, inner_impls, routing_table, deps] args = [provider_spec.api, provider_spec.registry, inner_impls, deps]
else: else:
method = "get_provider_impl" method = "get_provider_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config) config = config_type(**provider.config)
args = [config, deps] args = [config, deps]
fn = getattr(module, method) fn = getattr(module, method)
impl = await fn(*args) impl = await fn(*args)
impl.__provider_id__ = provider.provider_id
impl.__provider_spec__ = provider_spec impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config impl.__provider_config__ = config
return impl return impl
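The dependency ordering performed by the new `topological_sort` can be illustrated with a reduced sketch. It keeps only the depth-first traversal over string keys and drops the provider objects. The function name `topological_order` and the sample dependency map are assumptions made for illustration. In particular, the edge from `inner-safety` to `inference` is a made-up example of an inner provider that depends on another API. Like the code in the diff, this sketch does no cycle detection.

```python
from typing import Dict, List, Set


def topological_order(deps: Dict[str, List[str]]) -> List[str]:
    """Return keys so that every key appears after all of its dependencies."""
    visited: Set[str] = set()
    stack: List[str] = []

    def dfs(key: str) -> None:
        visited.add(key)
        for dep in deps[key]:
            if dep not in visited:
                dfs(dep)
        # A key is emitted only after all of its dependencies have been emitted.
        stack.append(key)

    for key in deps:
        if key not in visited:
            dfs(key)
    return stack


# Hypothetical dependency map: each router depends on its routing table,
# and each routing table depends on the "inner-" providers it routes to.
deps = {
    "inference": ["models"],
    "models": ["inner-inference"],
    "inner-inference": [],
    "safety": ["shields"],
    "shields": ["inference", "inner-safety"],
    "inner-safety": ["inference"],
}
order = topological_order(deps)
```

The resulting order instantiates inner providers first, then the routing table, then the router, which matches the way `deps__` is populated for the `__routing_table__` and `__autorouted__` entries above.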

@ -4,23 +4,22 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, List, Tuple from typing import Any, List
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
async def get_routing_table_impl(
api: Api,
inner_impls: List[Tuple[str, Any]],
routing_table_config: Dict[str, List[RoutableProviderConfig]],
_deps,
) -> Any:
from .routing_tables import ( from .routing_tables import (
MemoryBanksRoutingTable, MemoryBanksRoutingTable,
ModelsRoutingTable, ModelsRoutingTable,
ShieldsRoutingTable, ShieldsRoutingTable,
) )
async def get_routing_table_impl(
api: Api,
registry: List[RoutableObject],
impls_by_provider_id: Dict[str, RoutedProtocol],
_deps,
) -> Any:
api_to_tables = { api_to_tables = {
"memory_banks": MemoryBanksRoutingTable, "memory_banks": MemoryBanksRoutingTable,
"models": ModelsRoutingTable, "models": ModelsRoutingTable,
@ -29,7 +28,7 @@ async def get_routing_table_impl(
if api.value not in api_to_tables: if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map") raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](inner_impls, routing_table_config) impl = api_to_tables[api.value](registry, impls_by_provider_id)
await impl.initialize() await impl.initialize()
return impl return impl
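The new signature passes a registry of objects plus a map of implementations keyed by provider ID. A minimal sketch of how such a table might resolve an identifier to an implementation follows. `RoutableObj` and `RegistryTable` are stand-in names, not classes from the repository, and the implementations are plain strings so the example runs on its own.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class RoutableObj:
    """Stand-in for a registry entry: only the two fields needed for routing."""

    identifier: str
    provider_id: str


class RegistryTable:
    def __init__(self, registry: List[RoutableObj], impls_by_provider_id: Dict[str, Any]) -> None:
        # Reject registry entries that point at a provider that was never instantiated.
        for obj in registry:
            if obj.provider_id not in impls_by_provider_id:
                raise ValueError(
                    f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
                )
        self.impls_by_provider_id = impls_by_provider_id
        self.by_identifier = {obj.identifier: obj for obj in registry}

    def get_provider_impl(self, identifier: str) -> Any:
        # Two-step lookup: identifier -> registry object -> provider implementation.
        if identifier not in self.by_identifier:
            raise ValueError(f"`{identifier}` not registered")
        obj = self.by_identifier[identifier]
        return self.impls_by_provider_id[obj.provider_id]


impls = {"meta-reference-00": "impl-a", "remote::ollama-01": "impl-b"}
table = RegistryTable(
    [
        RoutableObj("Llama3.1-8B-Instruct", "meta-reference-00"),
        RoutableObj("Llama3.2-1B-Instruct", "remote::ollama-01"),
    ],
    impls,
)
```

Keying implementations by provider ID rather than by routing key lets several registered objects share one provider instance, and lets the same provider type appear more than once with different configs.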

@ -14,14 +14,13 @@ from llama_stack.apis.safety import * # noqa: F403
class MemoryRouter(Memory): class MemoryRouter(Memory):
"""Routes to a provider based on the memory bank type""" """Routes to a provider based on the memory bank identifier"""
def __init__( def __init__(
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
self.routing_table = routing_table self.routing_table = routing_table
self.bank_id_to_type = {}
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
@ -29,32 +28,8 @@ class MemoryRouter(Memory):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
def get_provider_from_bank_id(self, bank_id: str) -> Any: async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
bank_type = self.bank_id_to_type.get(bank_id) await self.routing_table.register_memory_bank(memory_bank)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.routing_table.get_provider_impl(bank_type)
if not provider:
raise ValueError(f"Could not find provider for {bank_type}")
return provider
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_type = config.type
bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
name, config, url
)
self.bank_id_to_type[bank.bank_id] = bank_type
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
provider = self.get_provider_from_bank_id(bank_id)
return await provider.get_memory_bank(bank_id)
async def insert_documents( async def insert_documents(
self, self,
@ -62,7 +37,7 @@ class MemoryRouter(Memory):
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None, ttl_seconds: Optional[int] = None,
) -> None: ) -> None:
return await self.get_provider_from_bank_id(bank_id).insert_documents( return await self.routing_table.get_provider_impl(bank_id).insert_documents(
bank_id, documents, ttl_seconds bank_id, documents, ttl_seconds
) )
@ -72,7 +47,7 @@ class MemoryRouter(Memory):
query: InterleavedTextMedia, query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ) -> QueryDocumentsResponse:
return await self.get_provider_from_bank_id(bank_id).query_documents( return await self.routing_table.get_provider_impl(bank_id).query_documents(
bank_id, query, params bank_id, query, params
) )
@ -92,7 +67,10 @@ class InferenceRouter(Inference):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def chat_completion( async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)
def chat_completion(
self, self,
model: str, model: str,
messages: List[Message], messages: List[Message],
@ -113,27 +91,32 @@ class InferenceRouter(Inference):
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
) )
# TODO: we need to fix streaming response to align provider implementations with Protocol. provider = self.routing_table.get_provider_impl(model)
async for chunk in self.routing_table.get_provider_impl(model).chat_completion( if stream:
**params return (chunk async for chunk in provider.chat_completion(**params))
): else:
yield chunk return provider.chat_completion(**params)
async def completion( def completion(
self, self,
model: str, model: str,
content: InterleavedTextMedia, content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ) -> AsyncGenerator:
return await self.routing_table.get_provider_impl(model).completion( provider = self.routing_table.get_provider_impl(model)
params = dict(
model=model, model=model,
content=content, content=content,
sampling_params=sampling_params, sampling_params=sampling_params,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
) )
if stream:
return (chunk async for chunk in provider.completion(**params))
else:
return provider.completion(**params)
async def embeddings( async def embeddings(
self, self,
@ -159,6 +142,9 @@ class SafetyRouter(Safety):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_shield(self, shield: ShieldDef) -> None:
await self.routing_table.register_shield(shield)
async def run_shield( async def run_shield(
self, self,
shield_type: str, shield_type: str,

@ -4,9 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, List, Optional, Tuple from typing import Any, List, Optional
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403
@ -16,129 +15,129 @@ from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.inference:
await p.register_model(obj)
elif api == Api.safety:
await p.register_shield(obj)
elif api == Api.memory:
await p.register_memory_bank(obj)
# TODO: this routing table maintains state in memory purely. We need to
# add persistence to it when we add dynamic registration of objects.
class CommonRoutingTableImpl(RoutingTable): class CommonRoutingTableImpl(RoutingTable):
def __init__( def __init__(
self, self,
inner_impls: List[Tuple[RoutingKey, Any]], registry: List[RoutableObject],
routing_table_config: Dict[str, List[RoutableProviderConfig]], impls_by_provider_id: Dict[str, RoutedProtocol],
) -> None: ) -> None:
self.unique_providers = [] for obj in registry:
self.providers = {} if obj.provider_id not in impls_by_provider_id:
self.routing_keys = [] print(f"{impls_by_provider_id=}")
raise ValueError(
f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
)
for key, impl in inner_impls: self.impls_by_provider_id = impls_by_provider_id
keys = key if isinstance(key, list) else [key] self.registry = registry
self.unique_providers.append((keys, impl))
for k in keys: for p in self.impls_by_provider_id.values():
if k in self.providers: api = get_impl_api(p)
raise ValueError(f"Duplicate routing key {k}") if api == Api.inference:
self.providers[k] = impl p.model_store = self
self.routing_keys.append(k) elif api == Api.safety:
p.shield_store = self
elif api == Api.memory:
p.memory_bank_store = self
self.routing_table_config = routing_table_config self.routing_key_to_object = {}
for obj in self.registry:
self.routing_key_to_object[obj.identifier] = obj
async def initialize(self) -> None: async def initialize(self) -> None:
for keys, p in self.unique_providers: for obj in self.registry:
spec = p.__provider_spec__ p = self.impls_by_provider_id[obj.provider_id]
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None: await register_object_with_provider(obj, p)
continue
await p.validate_routing_keys(keys)
async def shutdown(self) -> None: async def shutdown(self) -> None:
for _, p in self.unique_providers: for p in self.impls_by_provider_id.values():
await p.shutdown() await p.shutdown()
def get_provider_impl(self, routing_key: str) -> Any: def get_provider_impl(self, routing_key: str) -> Any:
if routing_key not in self.providers: if routing_key not in self.routing_key_to_object:
raise ValueError(f"Could not find provider for {routing_key}") raise ValueError(f"`{routing_key}` not registered")
return self.providers[routing_key]
def get_routing_keys(self) -> List[str]: obj = self.routing_key_to_object[routing_key]
return self.routing_keys if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]: return self.impls_by_provider_id[obj.provider_id]
for entry in self.routing_table_config:
if entry.routing_key == routing_key: def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
return entry for obj in self.registry:
if obj.identifier == identifier:
return obj
return None return None
async def register_object(self, obj: RoutableObject):
if obj.identifier in self.routing_key_to_object:
print(f"`{obj.identifier}` is already registered")
return
if not obj.provider_id:
provider_ids = list(self.impls_by_provider_id.keys())
if not provider_ids:
raise ValueError("No providers found")
print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
obj.provider_id = provider_ids[0]
else:
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
await register_object_with_provider(obj, p)
self.routing_key_to_object[obj.identifier] = obj
self.registry.append(obj)
# TODO: persist this to a store
class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelDef]:
return self.registry
async def list_models(self) -> List[ModelServingSpec]: async def get_model(self, identifier: str) -> Optional[ModelDef]:
specs = [] return self.get_object_by_identifier(identifier)
for entry in self.routing_table_config:
model_id = entry.routing_key
specs.append(
ModelServingSpec(
llama_model=resolve_model(model_id),
provider_config=entry,
)
)
return specs
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: async def register_model(self, model: ModelDef) -> None:
for entry in self.routing_table_config: await self.register_object(model)
if entry.routing_key == core_model_id:
return ModelServingSpec(
llama_model=resolve_model(core_model_id),
provider_config=entry,
)
return None
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldDef]:
return self.registry
async def list_shields(self) -> List[ShieldSpec]: async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
specs = [] return self.get_object_by_identifier(shield_type)
for entry in self.routing_table_config:
if isinstance(entry.routing_key, list):
for k in entry.routing_key:
specs.append(
ShieldSpec(
shield_type=k,
provider_config=entry,
)
)
else:
specs.append(
ShieldSpec(
shield_type=entry.routing_key,
provider_config=entry,
)
)
return specs
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: async def register_shield(self, shield: ShieldDef) -> None:
for entry in self.routing_table_config: await self.register_object(shield)
if entry.routing_key == shield_type:
return ShieldSpec(
shield_type=entry.routing_key,
provider_config=entry,
)
return None
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankDef]:
return self.registry
async def list_available_memory_banks(self) -> List[MemoryBankSpec]: async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
specs = [] return self.get_object_by_identifier(identifier)
for entry in self.routing_table_config:
specs.append(
MemoryBankSpec(
bank_type=entry.routing_key,
provider_config=entry,
)
)
return specs
async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]: async def register_memory_bank(self, bank: MemoryBankDef) -> None:
for entry in self.routing_table_config: await self.register_object(bank)
if entry.routing_key == bank_type:
return MemoryBankSpec(
bank_type=entry.routing_key,
provider_config=entry,
)
return None
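Reviewer note: the registry flow introduced in this file (default provider selection, duplicate detection, lookup by identifier) can be sketched in isolation as below. This is a minimal illustration, not the code from this commit: `RoutableObject` is reduced to two fields, and `FakeProvider` with its `register` method is a hypothetical stand-in for a real provider implementation and for `register_object_with_provider`.

```python
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class RoutableObject:
    # Hypothetical stand-in for ModelDef / ShieldDef / MemoryBankDef.
    identifier: str
    provider_id: Optional[str] = None


class FakeProvider:
    # Hypothetical provider that only records what was registered with it.
    def __init__(self) -> None:
        self.registered: List[str] = []

    async def register(self, obj: RoutableObject) -> None:
        self.registered.append(obj.identifier)


class RoutingTable:
    def __init__(self, impls_by_provider_id: Dict[str, FakeProvider]) -> None:
        self.impls_by_provider_id = impls_by_provider_id
        self.routing_key_to_object: Dict[str, RoutableObject] = {}
        self.registry: List[RoutableObject] = []

    def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
        return self.routing_key_to_object.get(identifier)

    async def register_object(self, obj: RoutableObject) -> None:
        if obj.identifier in self.routing_key_to_object:
            return  # already registered; the first registration wins
        if not obj.provider_id:
            provider_ids = list(self.impls_by_provider_id)
            if not provider_ids:
                raise ValueError("No providers found")
            obj.provider_id = provider_ids[0]  # fall back to the first provider
        elif obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")
        await self.impls_by_provider_id[obj.provider_id].register(obj)
        self.routing_key_to_object[obj.identifier] = obj
        self.registry.append(obj)
```

In this sketch the lookup goes through the dictionary rather than scanning the list; either works for small registries.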

@ -5,18 +5,15 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import functools
import inspect import inspect
import json import json
import signal import signal
import traceback import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from ssl import SSLError from ssl import SSLError
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional from typing import Any, Dict, Optional
import fire import fire
import httpx import httpx
@ -43,20 +40,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
from .endpoints import get_all_api_endpoints from .endpoints import get_all_api_endpoints
def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"):
origin = typ.__origin__
if isinstance(origin, type):
return issubclass(
origin,
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
)
return False
return isinstance(
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
)
def create_sse_event(data: Any) -> str: def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel): if isinstance(data, BaseModel):
data = data.json() data = data.json()
@ -169,11 +152,20 @@ async def passthrough(
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR) await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
def handle_sigint(*args, **kwargs): def handle_sigint(app, *args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...") print("SIGINT or CTRL-C detected. Exiting gracefully...")
async def run_shutdown():
for impl in app.__llama_stack_impls__.values():
print(f"Shutting down {impl}")
await impl.shutdown()
asyncio.run(run_shutdown())
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop): for task in asyncio.all_tasks(loop):
task.cancel() task.cancel()
loop.stop() loop.stop()
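Reviewer note: the handler now takes `app` as its first argument, which is why the registration further down binds it with `functools.partial`. A minimal sketch of that binding follows; `FakeApp` and its `shutdown_requested` flag are hypothetical and only stand in for the FastAPI app and its implementations.

```python
import functools
import signal


class FakeApp:
    # Hypothetical stand-in for the FastAPI app object.
    def __init__(self) -> None:
        self.shutdown_requested = False


def handle_sigint(app: FakeApp, signum, frame) -> None:
    # Python calls signal handlers with (signum, frame); `app` is pre-bound.
    app.shutdown_requested = True


app = FakeApp()
handler = functools.partial(handle_sigint, app)

# Installing it would look like: signal.signal(signal.SIGINT, handler)
# Here the handler is invoked directly, the way the interpreter would on Ctrl-C.
handler(signal.SIGINT, None)
```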
@ -181,7 +173,10 @@ def handle_sigint(*args, **kwargs):
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
print("Starting up") print("Starting up")
yield yield
print("Shutting down") print("Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
def create_dynamic_passthrough( def create_dynamic_passthrough(
@ -193,22 +188,16 @@ def create_dynamic_passthrough(
return endpoint return endpoint
def create_dynamic_typed_route(func: Any, method: str): def is_streaming_request(func_name: str, request: Request, **kwargs):
hints = get_type_hints(func) # TODO: pass the api method and punt it to the Protocol definition directly
response_model = hints.get("return") return kwargs.get("stream", False)
# NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False.
is_streaming = is_async_iterator_type(response_model)
if is_streaming: async def maybe_await(value):
if inspect.iscoroutine(value):
return await value
return value
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers)
async def sse_generator(event_gen): async def sse_generator(event_gen):
try: try:
@ -230,23 +219,23 @@ def create_dynamic_typed_route(func: Any, method: str):
finally: finally:
await end_trace() await end_trace()
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else: def create_dynamic_typed_route(func: Any, method: str):
async def endpoint(request: Request, **kwargs): async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__) await start_trace(func.__name__)
set_request_provider_data(request.headers) set_request_provider_data(request.headers)
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try: try:
return ( if is_streaming:
await func(**kwargs) return StreamingResponse(
if asyncio.iscoroutinefunction(func) sse_generator(func(**kwargs)), media_type="text/event-stream"
else func(**kwargs)
) )
else:
value = func(**kwargs)
return await maybe_await(value)
except Exception as e: except Exception as e:
traceback.print_exception(e) traceback.print_exception(e)
raise translate_exception(e) from e raise translate_exception(e) from e
@ -285,29 +274,25 @@ def main(
app = FastAPI() app = FastAPI()
impls, specs = asyncio.run(resolve_impls_with_routing(config)) impls = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])
all_endpoints = get_all_api_endpoints() all_endpoints = get_all_api_endpoints()
if config.apis_to_serve: if config.apis:
apis_to_serve = set(config.apis_to_serve) apis_to_serve = set(config.apis)
else: else:
apis_to_serve = set(impls.keys()) apis_to_serve = set(impls.keys())
apis_to_serve.add(Api.inspect) apis_to_serve.add("inspect")
for api_str in apis_to_serve: for api_str in apis_to_serve:
api = Api(api_str) api = Api(api_str)
endpoints = all_endpoints[api] endpoints = all_endpoints[api]
impl = impls[api] impl = impls[api]
provider_spec = specs[api] if is_passthrough(impl.__provider_spec__):
if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None
):
for endpoint in endpoints: for endpoint in endpoints:
url = impl.__provider_config__.url.rstrip("/") + endpoint.route url = impl.__provider_config__.url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)( getattr(app, endpoint.method)(endpoint.route)(
@ -337,7 +322,9 @@ def main(
print("") print("")
app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint) signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
app.__llama_stack_impls__ = impls
import uvicorn import uvicorn
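Reviewer note: the route handler no longer inspects return type hints. It reads the `stream` keyword and either iterates an async generator or awaits whatever the implementation returned. The sketch below illustrates that dispatch under simplifying assumptions: `chat`, `_stream_chat`, `_nonstream_chat` and `endpoint` are hypothetical names, and the streaming branch collects chunks into a list instead of wrapping them in an SSE `StreamingResponse`.

```python
import asyncio
import inspect
from typing import Any, AsyncIterator, List, Union


async def maybe_await(value: Any) -> Any:
    # Await coroutines, pass plain values through unchanged.
    if inspect.iscoroutine(value):
        return await value
    return value


async def _nonstream_chat(prompt: str) -> str:
    return prompt.upper()


async def _stream_chat(prompt: str) -> AsyncIterator[str]:
    for word in prompt.split():
        yield word


def chat(prompt: str, stream: bool = False):
    # Plain (non-async) function: returns an async generator or a coroutine.
    if stream:
        return _stream_chat(prompt)
    return _nonstream_chat(prompt)


async def endpoint(**kwargs: Any) -> Union[str, List[str]]:
    if kwargs.get("stream", False):
        # A real server would stream these chunks; here they are collected.
        return [chunk async for chunk in chat(**kwargs)]
    return await maybe_await(chat(**kwargs))
```

This also shows why the adapters in this commit change `chat_completion` from `async def` to `def`: the caller decides how to consume the returned object.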

@ -1,8 +1,9 @@
built_at: '2024-09-30T09:04:30.533391' version: '2'
built_at: '2024-10-08T17:42:07.505267'
image_name: local-cpu image_name: local-cpu
docker_image: local-cpu docker_image: local-cpu
conda_env: null conda_env: null
apis_to_serve: apis:
- agents - agents
- inference - inference
- models - models
@ -10,40 +11,48 @@ apis_to_serve:
- safety - safety
- shields - shields
- memory_banks - memory_banks
api_providers: providers:
inference: inference:
providers: - provider_id: remote::ollama
- remote::ollama provider_type: remote::ollama
config:
host: localhost
port: 6000
safety: safety:
providers: - provider_id: meta-reference
- meta-reference provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
memory:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
agents: agents:
- provider_id: meta-reference
provider_type: meta-reference provider_type: meta-reference
config: config:
persistence_store: persistence_store:
namespace: null namespace: null
type: sqlite type: sqlite
db_path: /home/xiyan/.llama/runtime/kvstore.db db_path: ~/.llama/runtime/kvstore.db
memory:
providers:
- meta-reference
telemetry: telemetry:
- provider_id: meta-reference
provider_type: meta-reference provider_type: meta-reference
config: {} config: {}
routing_table: models:
inference: - identifier: Llama3.1-8B-Instruct
- provider_type: remote::ollama llama_model: Llama3.1-8B-Instruct
config: provider_id: remote::ollama
host: localhost shields:
port: 6000 - identifier: llama_guard
routing_key: Meta-Llama3.1-8B-Instruct type: llama_guard
safety: provider_id: meta-reference
- provider_type: meta-reference params: {}
config: memory_banks:
llama_guard_shield: null - identifier: vector
prompt_guard_shield: null provider_id: meta-reference
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"] type: vector
memory: embedding_model: all-MiniLM-L6-v2
- provider_type: meta-reference chunk_size_in_tokens: 512
config: {} overlap_size_in_tokens: null
routing_key: vector
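Reviewer note: in the version 2 layout, provider instances live under `providers` and each model entry points at one of them by `provider_id`, replacing the old `routing_table` keys. The sketch below resolves a model to its provider entry using a dictionary that mirrors a subset of the YAML above; `provider_for_model` is a hypothetical helper, not part of the stack.

```python
from typing import Any, Dict

# Mirrors a subset of the version 2 run config shown above.
run_config: Dict[str, Any] = {
    "version": "2",
    "providers": {
        "inference": [
            {
                "provider_id": "remote::ollama",
                "provider_type": "remote::ollama",
                "config": {"host": "localhost", "port": 6000},
            },
        ],
    },
    "models": [
        {
            "identifier": "Llama3.1-8B-Instruct",
            "llama_model": "Llama3.1-8B-Instruct",
            "provider_id": "remote::ollama",
        },
    ],
}


def provider_for_model(config: Dict[str, Any], identifier: str) -> Dict[str, Any]:
    # Find the model entry, then the inference provider it references.
    model = next((m for m in config["models"] if m["identifier"] == identifier), None)
    if model is None:
        raise KeyError(f"Model `{identifier}` is not registered")
    for provider in config["providers"]["inference"]:
        if provider["provider_id"] == model["provider_id"]:
            return provider
    raise ValueError(f"Provider `{model['provider_id']}` not found")
```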

@ -1,8 +1,9 @@
built_at: '2024-09-30T09:00:56.693751' version: '2'
built_at: '2024-10-08T17:42:33.690666'
image_name: local-gpu image_name: local-gpu
docker_image: local-gpu docker_image: local-gpu
conda_env: null conda_env: null
apis_to_serve: apis:
- memory - memory
- inference - inference
- agents - agents
@ -10,43 +11,51 @@ apis_to_serve:
- safety - safety
- models - models
- memory_banks - memory_banks
api_providers: providers:
inference: inference:
providers: - provider_id: meta-reference
- meta-reference
safety:
providers:
- meta-reference
agents:
provider_type: meta-reference provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: /home/xiyan/.llama/runtime/kvstore.db
memory:
providers:
- meta-reference
telemetry:
provider_type: meta-reference
config: {}
routing_table:
inference:
- provider_type: meta-reference
config: config:
model: Llama3.1-8B-Instruct model: Llama3.1-8B-Instruct
quantization: null quantization: null
torch_seed: null torch_seed: null
max_seq_len: 4096 max_seq_len: 4096
max_batch_size: 1 max_batch_size: 1
routing_key: Llama3.1-8B-Instruct
safety: safety:
- provider_type: meta-reference - provider_id: meta-reference
provider_type: meta-reference
config: config:
llama_guard_shield: null llama_guard_shield: null
prompt_guard_shield: null prompt_guard_shield: null
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory: memory:
- provider_type: meta-reference - provider_id: meta-reference
provider_type: meta-reference
config: {} config: {}
routing_key: vector agents:
- provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
models:
- identifier: Llama3.1-8B-Instruct
llama_model: Llama3.1-8B-Instruct
provider_id: meta-reference
shields:
- identifier: llama_guard
type: llama_guard
provider_id: meta-reference
params: {}
memory_banks:
- identifier: vector
provider_id: meta-reference
type: vector
embedding_model: all-MiniLM-L6-v2
chunk_size_in_tokens: 512
overlap_size_in_tokens: null

@ -0,0 +1,10 @@
name: local-databricks
distribution_spec:
description: Use Databricks for running LLM inference
providers:
inference: remote::databricks
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

@ -0,0 +1,10 @@
name: local-vllm
distribution_spec:
description: Like local, but use vLLM for running LLM inference
providers:
inference: vllm
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

@ -13,7 +13,7 @@ from botocore.config import Config
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
@ -26,7 +26,7 @@ BEDROCK_SUPPORTED_MODELS = {
} }
class BedrockInferenceAdapter(Inference, RoutableProviderForModels): class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
@staticmethod @staticmethod
def _create_bedrock_client(config: BedrockConfig) -> BaseClient: def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
@ -69,7 +69,7 @@ class BedrockInferenceAdapter(Inference, RoutableProviderForModels):
return boto3_session.client("bedrock-runtime", config=boto3_config) return boto3_session.client("bedrock-runtime", config=boto3_config)
def __init__(self, config: BedrockConfig) -> None: def __init__(self, config: BedrockConfig) -> None:
RoutableProviderForModels.__init__( ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
) )
self._config = config self._config = config

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import DatabricksImplConfig
from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
assert isinstance(
config, DatabricksImplConfig
), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
await impl.initialize()
return impl

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class DatabricksImplConfig(BaseModel):
url: Optional[str] = Field(
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: Optional[str] = Field(
default=None,
description="The Databricks API token",
)

@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import DatabricksImplConfig
DATABRICKS_SUPPORTED_MODELS = {
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
}
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(request, r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream,
**get_sampling_options(request),
}
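Reviewer note: `_stream_chat_completion` adapts the synchronous iterator returned by the OpenAI client into an async generator so it can be fed to `process_chat_completion_stream_response`. The pattern in isolation looks roughly like the sketch below; `to_async_generator` and `collect` are hypothetical helpers, and the `asyncio.sleep(0)` between items is an addition that is not in the commit.

```python
import asyncio
from typing import AsyncIterator, Iterable, List, TypeVar

T = TypeVar("T")


async def to_async_generator(items: Iterable[T]) -> AsyncIterator[T]:
    # Re-yield each item from a synchronous iterable as an async generator.
    for item in items:
        yield item
        # Give other tasks a chance to run between chunks (not in the commit).
        await asyncio.sleep(0)


async def collect(items: Iterable[T]) -> List[T]:
    return [item async for item in to_async_generator(items)]
```

Note that this only changes the interface. If fetching the next item performs blocking network I/O, the event loop is still blocked while each chunk is being fetched.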

@ -10,14 +10,19 @@ from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools, from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
) )
from .config import FireworksImplConfig from .config import FireworksImplConfig
@ -27,21 +32,18 @@ FIREWORKS_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct", "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct", "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct", "Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
"Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
} }
class FireworksInferenceAdapter(Inference, RoutableProviderForModels): class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: FireworksImplConfig) -> None: def __init__(self, config: FireworksImplConfig) -> None:
RoutableProviderForModels.__init__( ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
) )
self.config = config self.config = config
tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(Tokenizer.get_instance())
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> Fireworks:
return Fireworks(api_key=self.config.api_key)
async def initialize(self) -> None: async def initialize(self) -> None:
return return
@ -49,7 +51,7 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def completion( def completion(
self, self,
model: str, model: str,
content: InterleavedTextMedia, content: InterleavedTextMedia,
@ -59,27 +61,7 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
) -> AsyncGenerator: ) -> AsyncGenerator:
raise NotImplementedError() raise NotImplementedError()
def _messages_to_fireworks_messages(self, messages: list[Message]) -> list: def chat_completion(
fireworks_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
fireworks_messages.append({"role": role, "content": message.content})
return fireworks_messages
def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
self, self,
model: str, model: str,
messages: List[Message], messages: List[Message],
@ -101,147 +83,41 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
logprobs=logprobs, logprobs=logprobs,
) )
messages = augment_messages_for_tools(request) client = Fireworks(api_key=self.config.api_key)
if stream:
# accumulate sampling params and other options to pass to fireworks return self._stream_chat_completion(request, client)
options = self.get_fireworks_chat_options(request)
fireworks_model = self.map_to_provider_model(request.model)
if not request.stream:
r = await self.client.chat.completions.acreate(
model=fireworks_model,
messages=self._messages_to_fireworks_messages(messages),
stream=False,
**options,
)
stop_reason = None
if r.choices[0].finish_reason:
if r.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif r.choices[0].finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r.choices[0].message.content, stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else: else:
yield ChatCompletionResponseStreamChunk( return self._nonstream_chat_completion(request, client)
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = "" async def _nonstream_chat_completion(
ipython = False self, request: ChatCompletionRequest, client: Fireworks
stop_reason = None ) -> ChatCompletionResponse:
params = self._get_params(request)
r = await client.completion.acreate(**params)
return process_chat_completion_response(request, r, self.formatter)
async for chunk in self.client.chat.completions.acreate( async def _stream_chat_completion(
model=fireworks_model, self, request: ChatCompletionRequest, client: Fireworks
messages=self._messages_to_fireworks_messages(messages), ) -> AsyncGenerator:
stream=True, params = self._get_params(request)
stream = client.completion.acreate(**params)
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt = chat_completion_request_to_prompt(request, self.formatter)
# Fireworks always prepends with BOS
if prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
options = get_sampling_options(request)
options.setdefault("max_tokens", 512)
return {
"model": self.map_to_provider_model(request.model),
"prompt": prompt,
"stream": request.stream,
**options, **options,
): }
if chunk.choices[0].finish_reason:
if stop_reason is None and chunk.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif (
stop_reason is None
and chunk.choices[0].finish_reason == "length"
):
stop_reason = StopReason.out_of_tokens
break
text = chunk.choices[0].delta.content
if text is None:
continue
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
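Reviewer note: the new `_get_params` does two things worth isolating: it drops a leading `<|begin_of_text|>` token (the commit's comment says Fireworks always prepends BOS itself) and it applies a default `max_tokens` of 512 when the caller gave none. A minimal sketch follows, with `build_params` as a hypothetical free function in place of the adapter method and a plain dict in place of `get_sampling_options(request)`.

```python
from typing import Any, Dict

BOS = "<|begin_of_text|>"


def build_params(
    model: str, prompt: str, stream: bool, options: Dict[str, Any]
) -> Dict[str, Any]:
    # Strip one leading BOS token so the provider does not see it twice.
    if prompt.startswith(BOS):
        prompt = prompt[len(BOS):]
    # Copy before mutating so the caller's dict is left untouched.
    options = dict(options)
    options.setdefault("max_tokens", 512)
    return {"model": model, "prompt": prompt, "stream": stream, **options}
```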

@ -7,6 +7,10 @@
from llama_stack.distribution.datatypes import RemoteProviderConfig from llama_stack.distribution.datatypes import RemoteProviderConfig
class OllamaImplConfig(RemoteProviderConfig):
port: int = 11434
async def get_adapter_impl(config: RemoteProviderConfig, _deps): async def get_adapter_impl(config: RemoteProviderConfig, _deps):
from .ollama import OllamaInferenceAdapter from .ollama import OllamaInferenceAdapter

@ -9,34 +9,36 @@ from typing import AsyncGenerator
import httpx import httpx
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient from ollama import AsyncClient
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import ( from llama_stack.providers.utils.inference.openai_compat import (
augment_messages_for_tools, get_sampling_options,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
) )
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
# TODO: Eventually this will move to the llama cli model list command OLLAMA_SUPPORTED_MODELS = {
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
# "Llama3.1-8B-Instruct": "llama3.1",
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16", "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16", "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
"Llama-Guard-3-8B": "xe/llamaguard3:latest",
} }
class OllamaInferenceAdapter(Inference, RoutableProviderForModels): class OllamaInferenceAdapter(Inference):
def __init__(self, url: str) -> None: def __init__(self, url: str) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
)
self.url = url self.url = url
tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(Tokenizer.get_instance())
self.formatter = ChatFormat(tokenizer)
@property @property
def client(self) -> AsyncClient: def client(self) -> AsyncClient:
@ -54,7 +56,29 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def completion( async def register_model(self, model: ModelDef) -> None:
if model.identifier not in OLLAMA_SUPPORTED_MODELS:
raise ValueError(
f"Unsupported model {model.identifier}. Supported models: {OLLAMA_SUPPORTED_MODELS.keys()}"
)
ollama_model = OLLAMA_SUPPORTED_MODELS[model.identifier]
res = await self.client.ps()
need_model_pull = True
for r in res["models"]:
if ollama_model == r["model"]:
need_model_pull = False
break
print(f"Ollama model `{ollama_model}` needs pull -> {need_model_pull}")
if need_model_pull:
print(f"Pulling model: {ollama_model}")
status = await self.client.pull(ollama_model)
assert (
status["status"] == "success"
), f"Failed to pull model {ollama_model} in ollama"
def completion(
self, self,
model: str, model: str,
content: InterleavedTextMedia, content: InterleavedTextMedia,
@ -64,32 +88,7 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
) -> AsyncGenerator: ) -> AsyncGenerator:
raise NotImplementedError() raise NotImplementedError()
def _messages_to_ollama_messages(self, messages: list[Message]) -> list: def chat_completion(
ollama_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
ollama_messages.append({"role": role, "content": message.content})
return ollama_messages
def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
if (
request.sampling_params.repetition_penalty is not None
and request.sampling_params.repetition_penalty != 1.0
):
options["repeat_penalty"] = request.sampling_params.repetition_penalty
return options
async def chat_completion(
self, self,
model: str, model: str,
messages: List[Message], messages: List[Message],
@ -110,156 +109,54 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
) )
if stream:
messages = augment_messages_for_tools(request) return self._stream_chat_completion(request)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
ollama_model = self.map_to_provider_model(request.model)
res = await self.client.ps()
need_model_pull = True
for r in res["models"]:
if ollama_model == r["model"]:
need_model_pull = False
break
if need_model_pull:
print(f"Pulling model: {ollama_model}")
status = await self.client.pull(ollama_model)
assert (
status["status"] == "success"
), f"Failed to pull model {self.model} in ollama"
if not request.stream:
r = await self.client.chat(
model=ollama_model,
messages=self._messages_to_ollama_messages(messages),
stream=False,
options=options,
)
stop_reason = None
if r["done"]:
if r["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif r["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r["message"]["content"], stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else: else:
yield ChatCompletionResponseStreamChunk( return self._nonstream_chat_completion(request)
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start, def _get_params(self, request: ChatCompletionRequest) -> dict:
delta="", return {
"model": OLLAMA_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"options": get_sampling_options(request),
"raw": True,
"stream": request.stream,
}
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await self.client.generate(**params)
assert isinstance(r, dict)
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["response"],
) )
response = OpenAICompatCompletionResponse(
choices=[choice],
) )
stream = await self.client.chat( return process_chat_completion_response(request, response, self.formatter)
model=ollama_model,
messages=self._messages_to_ollama_messages(messages), async def _stream_chat_completion(
stream=True, self, request: ChatCompletionRequest
options=options, ) -> AsyncGenerator:
params = self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.generate(**params)
async for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["response"],
)
yield OpenAICompatCompletionResponse(
choices=[choice],
) )
buffer = "" stream = _generate_and_convert_to_openai_compat()
ipython = False async for chunk in process_chat_completion_stream_response(
stop_reason = None request, stream, self.formatter
):
async for chunk in stream: yield chunk
if chunk["done"]:
if stop_reason is None and chunk["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif stop_reason is None and chunk["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
break
text = chunk["message"]["content"]
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
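
Review note on the Ollama adapter above: `chat_completion` becomes a plain `def` that hands back either an async generator (streaming) or a coroutine (non-streaming), so the caller decides between `async for` and `await`. A minimal sketch of that dispatch shape, assuming nothing about the real classes; `SketchAdapter` and its methods are hypothetical stand-ins, and the upper-casing merely takes the place of a provider call:

```python
import asyncio
from typing import AsyncGenerator


class SketchAdapter:
    """Hypothetical stand-in for an inference adapter (not the real class)."""

    def chat_completion(self, prompt: str, stream: bool = False):
        # Plain `def`: no work happens here, we only choose what to return.
        if stream:
            return self._stream_chat_completion(prompt)  # async generator
        return self._nonstream_chat_completion(prompt)  # coroutine

    async def _nonstream_chat_completion(self, prompt: str) -> str:
        await asyncio.sleep(0)  # placeholder for the provider round trip
        return prompt.upper()

    async def _stream_chat_completion(
        self, prompt: str
    ) -> AsyncGenerator[str, None]:
        for word in prompt.split():
            await asyncio.sleep(0)  # placeholder for waiting on the next chunk
            yield word.upper()


async def demo():
    adapter = SketchAdapter()
    # Non-streaming: the returned coroutine is awaited once.
    full = await adapter.chat_completion("hello there", stream=False)
    # Streaming: the returned async generator is iterated.
    chunks = [c async for c in adapter.chat_completion("hello there", stream=True)]
    return full, chunks
```

A caller that awaits the streaming result, or iterates the non-streaming one, fails at runtime, so the `stream` flag and the consumption style have to agree.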


@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
class SampleInferenceImpl(Inference):
class SampleInferenceImpl(Inference, RoutableProvider):
def __init__(self, config: SampleConfig): def __init__(self, config: SampleConfig):
self.config = config self.config = config
async def validate_routing_keys(self, routing_keys: list[str]) -> None: async def register_model(self, model: ModelDef) -> None:
# these are the model names the Llama Stack will use to route requests to this provider # these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary # perform validation here if necessary
pass pass
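
Review note: the sample provider leaves `register_model` empty, while the Ollama adapter in this commit rejects identifiers missing from its supported-model map. A minimal sketch of that validation, assuming a hypothetical `SUPPORTED_MODELS` map and a bare string identifier in place of `ModelDef`:

```python
import asyncio

# Hypothetical mapping from stack model identifiers to provider model names.
SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "provider/llama3.1-8b",
    "Llama3.1-70B-Instruct": "provider/llama3.1-70b",
}


class SketchProvider:
    """Hypothetical provider; only the registration check is modelled."""

    def __init__(self) -> None:
        self.registered = {}

    async def register_model(self, identifier: str) -> None:
        # Reject anything this provider cannot serve before routing starts.
        if identifier not in SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported model {identifier}. "
                f"Supported models: {sorted(SUPPORTED_MODELS)}"
            )
        self.registered[identifier] = SUPPORTED_MODELS[identifier]
```

Failing at registration time surfaces a misconfigured model name when the stack starts rather than on the first inference request.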


@ -34,7 +34,7 @@ class InferenceEndpointImplConfig(BaseModel):
@json_schema_type @json_schema_type
class InferenceAPIImplConfig(BaseModel): class InferenceAPIImplConfig(BaseModel):
model_id: str = Field( huggingface_repo: str = Field(
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')", description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
) )
api_token: Optional[str] = Field( api_token: Optional[str] = Field(


@ -10,14 +10,19 @@ from typing import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import ( from llama_stack.providers.utils.inference.openai_compat import (
augment_messages_for_tools, get_sampling_options,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info,
) )
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
@ -25,24 +30,31 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _HfAdapter(Inference, RoutableProvider): class _HfAdapter(Inference):
client: AsyncInferenceClient client: AsyncInferenceClient
max_tokens: int max_tokens: int
model_id: str model_id: str
def __init__(self) -> None: def __init__(self) -> None:
self.tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(Tokenizer.get_instance())
self.formatter = ChatFormat(self.tokenizer)
async def validate_routing_keys(self, routing_keys: list[str]) -> None: async def register_model(self, model: ModelDef) -> None:
# these are the model names the Llama Stack will use to route requests to this provider resolved_model = resolve_model(model.identifier)
# perform validation here if necessary if resolved_model is None:
pass raise ValueError(f"Unknown model: {model.identifier}")
if not resolved_model.huggingface_repo:
raise ValueError(
f"Model {model.identifier} does not have a HuggingFace repo"
)
if self.model_id != resolved_model.huggingface_repo:
raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def completion( def completion(
self, self,
model: str, model: str,
content: InterleavedTextMedia, content: InterleavedTextMedia,
@ -52,16 +64,7 @@ class _HfAdapter(Inference, RoutableProvider):
) -> AsyncGenerator: ) -> AsyncGenerator:
raise NotImplementedError() raise NotImplementedError()
def get_chat_options(self, request: ChatCompletionRequest) -> dict: def chat_completion(
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
self, self,
model: str, model: str,
messages: List[Message], messages: List[Message],
@ -83,146 +86,64 @@ class _HfAdapter(Inference, RoutableProvider):
logprobs=logprobs, logprobs=logprobs,
) )
messages = augment_messages_for_tools(request) if stream:
model_input = self.formatter.encode_dialog_prompt(messages) return self._stream_chat_completion(request)
prompt = self.tokenizer.decode(model_input.tokens) else:
return self._nonstream_chat_completion(request)
input_tokens = len(model_input.tokens) async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await self.client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
text="".join(t.text for t in r.details.tokens),
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(request, response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
choice = OpenAICompatCompletionChoice(text=token_result.text)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = chat_completion_request_to_model_input_info(
request, self.formatter
)
max_new_tokens = min( max_new_tokens = min(
request.sampling_params.max_tokens or (self.max_tokens - input_tokens), request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1, self.max_tokens - input_tokens - 1,
) )
options = get_sampling_options(request)
print(f"Calculated max_new_tokens: {max_new_tokens}") return dict(
options = self.get_chat_options(request)
if not request.stream:
response = await self.client.text_generation(
prompt=prompt, prompt=prompt,
stream=False, stream=request.stream,
details=True, details=True,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
stop_sequences=["<|eom_id|>", "<|eot_id|>"], stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**options, **options,
) )
stop_reason = None
if response.details.finish_reason:
if response.details.finish_reason in ["stop", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif response.details.finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
response.generated_text,
stop_reason,
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
ipython = False
stop_reason = None
tokens = []
async for response in await self.client.text_generation(
prompt=prompt,
stream=True,
details=True,
max_new_tokens=max_new_tokens,
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**options,
):
token_result = response.token
buffer += token_result.text
tokens.append(token_result.id)
if not ipython and buffer.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
if stop_reason is None:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# parse tool calls and report errors
message = self.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
class TGIAdapter(_HfAdapter): class TGIAdapter(_HfAdapter):
@ -236,7 +157,7 @@ class TGIAdapter(_HfAdapter):
class InferenceAPIAdapter(_HfAdapter): class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None: async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient( self.client = AsyncInferenceClient(
model=config.model_id, token=config.api_token model=config.huggingface_repo, token=config.api_token
) )
endpoint_info = await self.client.get_endpoint_info() endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"] self.max_tokens = endpoint_info["max_total_tokens"]
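
Review note: the `max_new_tokens` expression in `_get_params` is easier to check with numbers. The sketch below restates the same arithmetic as a standalone function; the function name and the sample values are assumptions chosen for illustration:

```python
from typing import Optional


def compute_max_new_tokens(
    requested: Optional[int], max_total_tokens: int, input_tokens: int
) -> int:
    # Tokens left in the context window once the prompt is accounted for.
    remaining = max_total_tokens - input_tokens
    # Use the caller's limit when given (None or 0 falls back to `remaining`),
    # but never exceed `remaining - 1`.
    return min(requested or remaining, remaining - 1)


no_limit = compute_max_new_tokens(None, 4096, 100)
small_limit = compute_max_new_tokens(512, 4096, 100)
oversized_limit = compute_max_new_tokens(8000, 4096, 100)
```

With a 4096-token window and a 100-token prompt, the three calls give 3995, 512 and 3995. When the prompt already fills the window the result is zero or negative, so a caller may want to guard that case before sending the request.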


@ -8,17 +8,22 @@ from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together from together import Together
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.augment_messages import ( from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
augment_messages_for_tools, from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
) )
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from .config import TogetherImplConfig from .config import TogetherImplConfig
@ -34,19 +39,14 @@ TOGETHER_SUPPORTED_MODELS = {
class TogetherInferenceAdapter( class TogetherInferenceAdapter(
Inference, NeedsRequestProviderData, RoutableProviderForModels ModelRegistryHelper, Inference, NeedsRequestProviderData
): ):
def __init__(self, config: TogetherImplConfig) -> None: def __init__(self, config: TogetherImplConfig) -> None:
RoutableProviderForModels.__init__( ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
) )
self.config = config self.config = config
tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(Tokenizer.get_instance())
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> Together:
return Together(api_key=self.config.api_key)
async def initialize(self) -> None: async def initialize(self) -> None:
return return
@ -64,27 +64,7 @@ class TogetherInferenceAdapter(
) -> AsyncGenerator: ) -> AsyncGenerator:
raise NotImplementedError() raise NotImplementedError()
def _messages_to_together_messages(self, messages: list[Message]) -> list: def chat_completion(
together_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
together_messages.append({"role": role, "content": message.content})
return together_messages
def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
self, self,
model: str, model: str,
messages: List[Message], messages: List[Message],
@ -95,7 +75,6 @@ class TogetherInferenceAdapter(
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
together_api_key = None together_api_key = None
if self.config.api_key is not None: if self.config.api_key is not None:
together_api_key = self.config.api_key together_api_key = self.config.api_key
@ -108,7 +87,6 @@ class TogetherInferenceAdapter(
together_api_key = provider_data.together_api_key together_api_key = provider_data.together_api_key
client = Together(api_key=together_api_key) client = Together(api_key=together_api_key)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest( request = ChatCompletionRequest(
model=model, model=model,
messages=messages, messages=messages,
@ -120,146 +98,39 @@ class TogetherInferenceAdapter(
logprobs=logprobs, logprobs=logprobs,
) )
# accumulate sampling params and other options to pass to together if stream:
options = self.get_together_chat_options(request) return self._stream_chat_completion(request, client)
together_model = self.map_to_provider_model(request.model)
messages = augment_messages_for_tools(request)
if not request.stream:
# TODO: might need to add back an async here
r = client.chat.completions.create(
model=together_model,
messages=self._messages_to_together_messages(messages),
stream=False,
**options,
)
stop_reason = None
if r.choices[0].finish_reason:
if (
r.choices[0].finish_reason == "stop"
or r.choices[0].finish_reason == "eos"
):
stop_reason = StopReason.end_of_turn
elif r.choices[0].finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r.choices[0].message.content, stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else: else:
yield ChatCompletionResponseStreamChunk( return self._nonstream_chat_completion(request, client)
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = "" async def _nonstream_chat_completion(
ipython = False self, request: ChatCompletionRequest, client: Together
stop_reason = None ) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(request, r, self.formatter)
for chunk in client.chat.completions.create( async def _stream_chat_completion(
model=together_model, self, request: ChatCompletionRequest, client: Together
messages=self._messages_to_together_messages(messages), ) -> AsyncGenerator:
stream=True, params = self._get_params(request)
**options,
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
): ):
if finish_reason := chunk.choices[0].finish_reason: yield chunk
if stop_reason is None and finish_reason in ["stop", "eos"]:
stop_reason = StopReason.end_of_turn
elif stop_reason is None and finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
text = chunk.choices[0].delta.content def _get_params(self, request: ChatCompletionRequest) -> dict:
if text is None: return {
continue "model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request, self.formatter),
# check if its a tool call ( aka starts with <|python_tag|> ) "stream": request.stream,
if not ipython and text.startswith("<|python_tag|>"): **get_sampling_options(request),
ipython = True }
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
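
Review note: `_to_async_generator` in the Together adapter wraps a synchronous iterator so it can feed `process_chat_completion_stream_response`. Each `next()` still runs on the event loop thread, so a slow network read blocks other tasks. The sketch below shows the same wrapper next to one possible variant that pulls each item in a worker thread with `asyncio.to_thread` (Python 3.9+); all names are hypothetical and this is not the Together client API:

```python
import asyncio
from typing import AsyncGenerator, Iterable, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel marking iterator exhaustion


async def to_async_generator(items: Iterable[T]) -> AsyncGenerator[T, None]:
    # Same shape as the wrapper in the diff: iteration blocks the event loop.
    for item in items:
        yield item


async def to_async_generator_threaded(
    items: Iterable[T],
) -> AsyncGenerator[T, None]:
    # Variant: fetch each item in a worker thread so the loop stays responsive.
    it = iter(items)
    while True:
        item = await asyncio.to_thread(next, it, _DONE)
        if item is _DONE:
            return
        yield item


async def collect(agen: AsyncGenerator[T, None]) -> list:
    return [x async for x in agen]
```

Both wrappers yield the items in order; they differ only in where the blocking `next()` call executes.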


@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import uuid
from typing import List from typing import List
from urllib.parse import urlparse from urllib.parse import urlparse
@ -13,7 +12,6 @@ import chromadb
from numpy.typing import NDArray from numpy.typing import NDArray
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex, BankWithIndex,
@ -65,7 +63,7 @@ class ChromaIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores) return QueryDocumentsResponse(chunks=chunks, scores=scores)
class ChromaMemoryAdapter(Memory, RoutableProvider): class ChromaMemoryAdapter(Memory):
def __init__(self, url: str) -> None: def __init__(self, url: str) -> None:
print(f"Initializing ChromaMemoryAdapter with url: {url}") print(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/") url = url.rstrip("/")
@ -93,48 +91,33 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None: async def register_memory_bank(
print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
pass
async def create_memory_bank(
self, self,
name: str, memory_bank: MemoryBankDef,
config: MemoryBankConfig, ) -> None:
url: Optional[URL] = None, assert (
) -> MemoryBank: memory_bank.type == MemoryBankType.vector.value
bank_id = str(uuid.uuid4()) ), f"Only vector banks are supported, got {memory_bank.type}"
bank = MemoryBank(
bank_id=bank_id, collection = await self.client.get_or_create_collection(
name=name, name=memory_bank.identifier,
config=config,
url=url,
)
collection = await self.client.create_collection(
name=bank_id,
metadata={"bank": bank.json()},
) )
bank_index = BankWithIndex( bank_index = BankWithIndex(
bank=bank, index=ChromaIndex(self.client, collection) bank=memory_bank, index=ChromaIndex(self.client, collection)
) )
self.cache[bank_id] = bank_index self.cache[memory_bank.identifier] = bank_index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache: if bank_id in self.cache:
return self.cache[bank_id] return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if bank is None:
raise ValueError(f"Bank {bank_id} not found")
collections = await self.client.list_collections() collections = await self.client.list_collections()
for collection in collections: for collection in collections:
if collection.name == bank_id: if collection.name == bank_id:
print(collection.metadata)
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
index = BankWithIndex( index = BankWithIndex(
bank=bank, bank=bank,
index=ChromaIndex(self.client, collection), index=ChromaIndex(self.client, collection),


@ -4,18 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import uuid from typing import List
from typing import List, Tuple
import psycopg2 import psycopg2
from numpy.typing import NDArray from numpy.typing import NDArray
from psycopg2 import sql from psycopg2 import sql
from psycopg2.extras import execute_values, Json from psycopg2.extras import execute_values, Json
from pydantic import BaseModel
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION, ALL_MINILM_L6_V2_DIMENSION,
@ -32,33 +28,6 @@ def check_extension_version(cur):
return result[0] if result else None return result[0] if result else None
def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
query = sql.SQL(
"""
INSERT INTO metadata_store (key, data)
VALUES %s
ON CONFLICT (key) DO UPDATE
SET data = EXCLUDED.data
"""
)
values = [(key, Json(model.dict())) for key, model in keys_models]
execute_values(cur, query, values, template="(%s, %s)")
def load_models(cur, keys: List[str], cls):
query = "SELECT key, data FROM metadata_store"
if keys:
placeholders = ",".join(["%s"] * len(keys))
query += f" WHERE key IN ({placeholders})"
cur.execute(query, keys)
else:
cur.execute(query)
rows = cur.fetchall()
return [cls(**row["data"]) for row in rows]
class PGVectorIndex(EmbeddingIndex): class PGVectorIndex(EmbeddingIndex):
def __init__(self, bank: MemoryBank, dimension: int, cursor): def __init__(self, bank: MemoryBank, dimension: int, cursor):
self.cursor = cursor self.cursor = cursor
@ -119,7 +88,7 @@ class PGVectorIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores) return QueryDocumentsResponse(chunks=chunks, scores=scores)
class PGVectorMemoryAdapter(Memory, RoutableProvider): class PGVectorMemoryAdapter(Memory):
def __init__(self, config: PGVectorConfig) -> None: def __init__(self, config: PGVectorConfig) -> None:
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}") print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
self.config = config self.config = config
@ -144,14 +113,6 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
else: else:
raise RuntimeError("Vector extension is not installed.") raise RuntimeError("Vector extension is not installed.")
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS metadata_store (
key TEXT PRIMARY KEY,
data JSONB
)
"""
)
except Exception as e: except Exception as e:
import traceback import traceback
@ -161,51 +122,28 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None: async def register_memory_bank(
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
pass
async def create_memory_bank(
self, self,
name: str, memory_bank: MemoryBankDef,
config: MemoryBankConfig, ) -> None:
url: Optional[URL] = None, assert (
) -> MemoryBank: memory_bank.type == MemoryBankType.vector.value
bank_id = str(uuid.uuid4()) ), f"Only vector banks are supported, got {memory_bank.type}"
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
upsert_models(
self.cursor,
[
(bank.bank_id, bank),
],
)
index = BankWithIndex( index = BankWithIndex(
bank=bank, bank=memory_bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
) )
self.cache[bank_id] = index self.cache[memory_bank.identifier] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache: if bank_id in self.cache:
return self.cache[bank_id] return self.cache[bank_id]
banks = load_models(self.cursor, [bank_id], MemoryBank) bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not banks: if not bank:
return None raise ValueError(f"Bank {bank_id} not found")
bank = banks[0]
index = BankWithIndex( index = BankWithIndex(
bank=bank, bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),


@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
class SampleMemoryImpl(Memory):
class SampleMemoryImpl(Memory, RoutableProvider):
def __init__(self, config: SampleConfig): def __init__(self, config: SampleConfig):
self.config = config self.config = config
async def validate_routing_keys(self, routing_keys: list[str]) -> None: async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
# these are the memory banks the Llama Stack will use to route requests to this provider # these are the memory banks the Llama Stack will use to route requests to this provider
# perform validation here if necessary # perform validation here if necessary
pass pass


@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
async def get_adapter_impl(config: WeaviateConfig, _deps):
from .weaviate import WeaviateMemoryAdapter
impl = WeaviateMemoryAdapter(config)
await impl.initialize()
return impl


@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class WeaviateRequestProviderData(BaseModel):
weaviate_api_key: str
weaviate_cluster_url: str
class WeaviateConfig(BaseModel):
pass


@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
import weaviate
import weaviate.classes as wvc
from numpy.typing import NDArray
from weaviate.classes.init import Auth
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from .config import WeaviateConfig, WeaviateRequestProviderData
class WeaviateIndex(EmbeddingIndex):
def __init__(self, client: weaviate.Client, collection_name: str):
self.client = client
self.collection_name = collection_name
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
data_objects = []
for i, chunk in enumerate(chunks):
data_objects.append(
wvc.data.DataObject(
properties={
"chunk_content": chunk.json(),
},
vector=embeddings[i].tolist(),
)
)
# Inserting chunks into a prespecified Weaviate collection
collection = self.client.collections.get(self.collection_name)
# TODO: make this async friendly
collection.data.insert_many(data_objects)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
collection = self.client.collections.get(self.collection_name)
results = collection.query.near_vector(
near_vector=embedding.tolist(),
limit=k,
return_metadata=wvc.query.MetadataQuery(distance=True),
)
chunks = []
scores = []
for doc in results.objects:
chunk_json = doc.properties["chunk_content"]
try:
chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict)
except Exception:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {chunk_json}")
continue
chunks.append(chunk)
scores.append(1.0 / doc.metadata.distance)
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
def __init__(self, config: WeaviateConfig) -> None:
self.config = config
self.client_cache = {}
self.cache = {}
def _get_client(self) -> weaviate.Client:
provider_data = self.get_request_provider_data()
assert provider_data is not None, "Request provider data must be set"
assert isinstance(provider_data, WeaviateRequestProviderData)
key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}"
if key in self.client_cache:
return self.client_cache[key]
client = weaviate.connect_to_weaviate_cloud(
cluster_url=provider_data.weaviate_cluster_url,
auth_credentials=Auth.api_key(provider_data.weaviate_api_key),
)
self.client_cache[key] = client
return client
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
for client in self.client_cache.values():
client.close()
async def register_memory_bank(
self,
memory_bank: MemoryBankDef,
) -> None:
assert (
memory_bank.type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
client = self._get_client()
# Create collection if it doesn't exist
if not client.collections.exists(memory_bank.identifier):
client.collections.create(
name=memory_bank.identifier,
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
],
)
index = BankWithIndex(
bank=memory_bank,
index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
)
self.cache[memory_bank.identifier] = index
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found")
client = self._get_client()
if not client.collections.exists(bank_id):
raise ValueError(f"Collection with name `{bank_id}` not found")
index = BankWithIndex(
bank=bank,
index=WeaviateIndex(client=client, collection_name=bank_id),
)
self.cache[bank_id] = index
return index
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)
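Note on `WeaviateIndex.query` above: the score is computed as `1.0 / doc.metadata.distance`, which raises `ZeroDivisionError` when a stored vector matches the query exactly (distance 0). A minimal sketch of a guarded conversion follows. The helper name `distance_to_score` and the epsilon value are illustrative assumptions and are not part of this commit.

```python
import math
from typing import Optional


def distance_to_score(distance: Optional[float], eps: float = 1e-9) -> float:
    """Map a non-negative distance to a similarity score (hypothetical helper)."""
    if distance is None or math.isnan(distance):
        # No usable distance was returned: treat it as the weakest match.
        return 0.0
    # Clamp the divisor so an exact match (distance == 0) yields a large
    # finite score instead of raising ZeroDivisionError.
    return 1.0 / max(distance, eps)
```

Inside `query`, the append would then read `scores.append(distance_to_score(doc.metadata.distance))`, assuming the same inverse-distance scoring is still wanted.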


@ -7,14 +7,12 @@
import json import json
import logging import logging
import traceback
from typing import Any, Dict, List from typing import Any, Dict, List
import boto3 import boto3
from llama_stack.apis.safety import * # noqa from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from .config import BedrockSafetyConfig from .config import BedrockSafetyConfig
@ -22,16 +20,17 @@ from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SUPPORTED_SHIELD_TYPES = [ BEDROCK_SUPPORTED_SHIELDS = [
"bedrock_guardrail", ShieldType.generic_content_shield.value,
] ]
class BedrockSafetyAdapter(Safety, RoutableProvider): class BedrockSafetyAdapter(Safety):
def __init__(self, config: BedrockSafetyConfig) -> None: def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile: if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}") raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
self.config = config self.config = config
self.registered_shields = []
async def initialize(self) -> None: async def initialize(self) -> None:
try: try:
@ -45,16 +44,27 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None: async def register_shield(self, shield: ShieldDef) -> None:
for key in routing_keys: if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
if key not in SUPPORTED_SHIELD_TYPES: raise ValueError(f"Unsupported safety shield type: {shield.type}")
raise ValueError(f"Unknown safety shield type: {key}")
shield_params = shield.params
if "guardrailIdentifier" not in shield_params:
raise ValueError(
"Error running request for BedrockGuardrails: Missing guardrailIdentifier in request"
)
if "guardrailVersion" not in shield_params:
raise ValueError(
"Error running request for BedrockGuardrails: Missing guardrailVersion in request"
)
async def run_shield( async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ) -> RunShieldResponse:
if shield_type not in SUPPORTED_SHIELD_TYPES: shield_def = await self.shield_store.get_shield(shield_type)
raise ValueError(f"Unknown safety shield type: {shield_type}") if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [ ```content = [
@ -69,17 +79,9 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"] They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
""" """
try:
logger.debug(f"run_shield::{params}::messages={messages}")
if "guardrailIdentifier" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in params: shield_params = shield_def.params
raise RuntimeError( logger.debug(f"run_shield::{shield_params}::messages={messages}")
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
# - convert the messages into format Bedrock expects # - convert the messages into format Bedrock expects
content_messages = [] content_messages = []
@ -90,12 +92,11 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
) )
response = self.boto_client.apply_guardrail( response = self.boto_client.apply_guardrail(
guardrailIdentifier=params.get("guardrailIdentifier"), guardrailIdentifier=shield_params["guardrailIdentifier"],
guardrailVersion=params.get("guardrailVersion"), guardrailVersion=shield_params["guardrailVersion"],
source="OUTPUT", # or 'INPUT' depending on your use case source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages, content=content_messages,
) )
logger.debug(f"run_shield:: response: {response}::")
if response["action"] == "GUARDRAIL_INTERVENED": if response["action"] == "GUARDRAIL_INTERVENED":
user_message = "" user_message = ""
metadata = {} metadata = {}
@ -105,16 +106,11 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
for assessment in response["assessments"]: for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values # guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment) metadata = dict(assessment)
return SafetyViolation( return SafetyViolation(
user_message=user_message, user_message=user_message,
violation_level=ViolationLevel.ERROR, violation_level=ViolationLevel.ERROR,
metadata=metadata, metadata=metadata,
) )
except Exception:
error_str = traceback.format_exc()
logger.error(
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
)
return None return None
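Note on `register_shield` in the Bedrock adapter: the two required guardrail parameters are now checked once at registration instead of on every `run_shield` call. A small sketch of that check as a standalone function is shown below. The names `REQUIRED_GUARDRAIL_PARAMS`, `missing_guardrail_params`, and `validate_guardrail_params` are hypothetical and only illustrate the idea; the parameter keys themselves come from the diff.

```python
from typing import Any, Dict, List

# Keys the adapter reads from shield.params before calling apply_guardrail.
REQUIRED_GUARDRAIL_PARAMS = ["guardrailIdentifier", "guardrailVersion"]


def missing_guardrail_params(params: Dict[str, Any]) -> List[str]:
    """Return the required keys that are absent, in declaration order."""
    return [key for key in REQUIRED_GUARDRAIL_PARAMS if key not in params]


def validate_guardrail_params(params: Dict[str, Any]) -> None:
    """Raise ValueError naming every missing key (hypothetical helper)."""
    missing = missing_guardrail_params(params)
    if missing:
        raise ValueError(
            f"Missing Bedrock guardrail params: {', '.join(missing)}"
        )
```

Reporting all missing keys in one error saves a second registration attempt when both are absent.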


@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
class SampleSafetyImpl(Safety):
class SampleSafetyImpl(Safety, RoutableProvider):
def __init__(self, config: SampleConfig): def __init__(self, config: SampleConfig):
self.config = config self.config = config
async def validate_routing_keys(self, routing_keys: list[str]) -> None: async def register_shield(self, shield: ShieldDef) -> None:
# these are the safety shields the Llama Stack will use to route requests to this provider # these are the safety shields the Llama Stack will use to route requests to this provider
# perform validation here if necessary # perform validation here if necessary
pass pass


@ -6,26 +6,20 @@
from together import Together from together import Together
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.safety import ( from llama_stack.apis.safety import * # noqa: F403
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from .config import TogetherSafetyConfig from .config import TogetherSafetyConfig
SAFETY_SHIELD_TYPES = { TOGETHER_SHIELD_MODEL_MAP = {
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B", "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B", "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo", "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
} }
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider): class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
def __init__(self, config: TogetherSafetyConfig) -> None: def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config self.config = config
@ -35,16 +29,20 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None: async def register_shield(self, shield: ShieldDef) -> None:
for key in routing_keys: if shield.type != ShieldType.llama_guard.value:
if key not in SAFETY_SHIELD_TYPES: raise ValueError(f"Unsupported safety shield type: {shield.type}")
raise ValueError(f"Unknown safety shield type: {key}")
async def run_shield( async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ) -> RunShieldResponse:
if shield_type not in SAFETY_SHIELD_TYPES: shield_def = await self.shield_store.get_shield(shield_type)
raise ValueError(f"Unknown safety shield type: {shield_type}") if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
model = shield_def.params.get("model", "llama_guard")
if model not in TOGETHER_SHIELD_MODEL_MAP:
raise ValueError(f"Unsupported safety model: {model}")
together_api_key = None together_api_key = None
if self.config.api_key is not None: if self.config.api_key is not None:
@ -57,8 +55,6 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
) )
together_api_key = provider_data.together_api_key together_api_key = provider_data.together_api_key
model_name = SAFETY_SHIELD_TYPES[shield_type]
# messages can have role assistant or user # messages can have role assistant or user
api_messages = [] api_messages = []
for message in messages: for message in messages:
@ -66,7 +62,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
api_messages.append({"role": message.role, "content": message.content}) api_messages.append({"role": message.role, "content": message.content})
violation = await get_safety_response( violation = await get_safety_response(
together_api_key, model_name, api_messages together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
) )
return RunShieldResponse(violation=violation) return RunShieldResponse(violation=violation)
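Note on the Together adapter: the shield type no longer selects the model directly. The model is read from `shield_def.params` with `"llama_guard"` as the default and then looked up in `TOGETHER_SHIELD_MODEL_MAP`. The lookup can be sketched in isolation as follows. The function name `resolve_shield_model` is an assumption for illustration; the map contents are copied from the diff.

```python
from typing import Any, Dict

TOGETHER_SHIELD_MODEL_MAP = {
    "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
}


def resolve_shield_model(params: Dict[str, Any]) -> str:
    """Pick the Together model name for a shield (hypothetical helper)."""
    # Fall back to the default Llama Guard alias when no model is given.
    model = params.get("model", "llama_guard")
    if model not in TOGETHER_SHIELD_MODEL_MAP:
        raise ValueError(f"Unsupported safety model: {model}")
    return TOGETHER_SHIELD_MODEL_MAP[model]
```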


@ -43,23 +43,14 @@ class ProviderSpec(BaseModel):
description="Higher-level API surfaces may depend on other providers to provide their functionality", description="Higher-level API surfaces may depend on other providers to provide their functionality",
) )
# used internally by the resolver; this is a hack for now
deps__: List[str] = Field(default_factory=list)
class RoutingTable(Protocol): class RoutingTable(Protocol):
def get_routing_keys(self) -> List[str]: ...
def get_provider_impl(self, routing_key: str) -> Any: ... def get_provider_impl(self, routing_key: str) -> Any: ...
class RoutableProvider(Protocol):
"""
A provider which sits behind the RoutingTable and can get routed to.
All Inference / Safety / Memory providers fall into this bucket.
"""
async def validate_routing_keys(self, keys: List[str]) -> None: ...
@json_schema_type @json_schema_type
class AdapterSpec(BaseModel): class AdapterSpec(BaseModel):
adapter_type: str = Field( adapter_type: str = Field(
@ -156,6 +147,10 @@ as being "Llama Stack compatible"
return None return None
def is_passthrough(spec: ProviderSpec) -> bool:
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
# Can avoid this by using Pydantic computed_field # Can avoid this by using Pydantic computed_field
def remote_provider_spec( def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None api: Api, adapter: Optional[AdapterSpec] = None


@ -144,6 +144,8 @@ class ChatAgent(ShieldRunnerMixin):
async def create_and_execute_turn( async def create_and_execute_turn(
self, request: AgentTurnCreateRequest self, request: AgentTurnCreateRequest
) -> AsyncGenerator: ) -> AsyncGenerator:
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id) session_info = await self.storage.get_session_info(request.session_id)
if session_info is None: if session_info is None:
raise ValueError(f"Session {request.session_id} not found") raise ValueError(f"Session {request.session_id} not found")
@ -635,14 +637,13 @@ class ChatAgent(ShieldRunnerMixin):
raise ValueError(f"Session {session_id} not found") raise ValueError(f"Session {session_id} not found")
if session_info.memory_bank_id is None: if session_info.memory_bank_id is None:
memory_bank = await self.memory_api.create_memory_bank( bank_id = f"memory_bank_{session_id}"
name=f"memory_bank_{session_id}", memory_bank = VectorMemoryBankDef(
config=VectorMemoryBankConfig( identifier=bank_id,
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
),
) )
bank_id = memory_bank.bank_id await self.memory_api.register_memory_bank(memory_bank)
await self.storage.add_memory_bank_to_session(session_id, bank_id) await self.storage.add_memory_bank_to_session(session_id, bank_id)
else: else:
bank_id = session_info.memory_bank_id bank_id = session_info.memory_bank_id
@ -673,7 +674,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _retrieve_context( async def _retrieve_context(
self, session_id: str, messages: List[Message], attachments: List[Attachment] self, session_id: str, messages: List[Message], attachments: List[Attachment]
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids) ) -> Tuple[Optional[List[str]], Optional[List[int]]]: # (rag_context, bank_ids)
bank_ids = [] bank_ids = []
memory = self._memory_tool_definition() memory = self._memory_tool_definition()
@ -722,12 +723,13 @@ class ChatAgent(ShieldRunnerMixin):
chunks = [c for r in results for c in r.chunks] chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores] scores = [s for r in results for s in r.scores]
if not chunks:
return None, bank_ids
# sort by score # sort by score
chunks, scores = zip( chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True) *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
) )
if not chunks:
return None, bank_ids
tokens = 0 tokens = 0
picked = [] picked = []
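Note on the reordered `if not chunks` check in `_retrieve_context`: with no results, `zip(*sorted(...))` produces nothing to unpack, so `chunks, scores = zip(...)` raises `ValueError` before the old guard could run. Moving the guard ahead of the sort avoids that. A minimal sketch, with plain strings standing in for chunk objects and a hypothetical function name `sort_by_score`:

```python
from typing import List, Optional, Tuple


def sort_by_score(
    chunks: List[str], scores: List[float]
) -> Optional[Tuple[Tuple[str, ...], Tuple[float, ...]]]:
    """Sort chunks by descending score; return None when there is nothing to sort."""
    if not chunks:
        # Guard first: unpacking zip(*[]) into two names would raise ValueError.
        return None
    sorted_chunks, sorted_scores = zip(
        *sorted(zip(chunks, scores), key=lambda pair: pair[1], reverse=True)
    )
    return sorted_chunks, sorted_scores
```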


@ -100,7 +100,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id=session_id, session_id=session_id,
) )
async def create_agent_turn( def create_agent_turn(
self, self,
agent_id: str, agent_id: str,
session_id: str, session_id: str,
@ -113,16 +113,22 @@ class MetaReferenceAgentsImpl(Agents):
attachments: Optional[List[Attachment]] = None, attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
) -> AsyncGenerator: ) -> AsyncGenerator:
agent = await self.get_agent(agent_id)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = AgentTurnCreateRequest( request = AgentTurnCreateRequest(
agent_id=agent_id, agent_id=agent_id,
session_id=session_id, session_id=session_id,
messages=messages, messages=messages,
attachments=attachments, attachments=attachments,
stream=stream, stream=True,
) )
if stream:
return self._create_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _create_agent_turn_streaming(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
async for event in agent.create_and_execute_turn(request): async for event in agent.create_and_execute_turn(request):
yield event yield event
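Note on `create_agent_turn` changing from `async def` to `def`: a plain function can return an async generator object directly, so callers iterate with `async for` without awaiting first, and the unsupported non-streaming path fails at call time rather than at first iteration. The pattern can be sketched as below. `TurnService`, `create_turn`, and `collect` are hypothetical names used only for this illustration.

```python
import asyncio
from typing import AsyncGenerator, List


class TurnService:
    def create_turn(self, n: int, stream: bool = False) -> AsyncGenerator:
        # Plain (non-async) method: it only decides which path to take.
        if stream:
            return self._create_turn_streaming(n)
        raise NotImplementedError("Non-streaming turns not yet implemented")

    async def _create_turn_streaming(self, n: int) -> AsyncGenerator:
        # Async generator: any awaiting happens here, during iteration.
        for i in range(n):
            await asyncio.sleep(0)
            yield f"event-{i}"


async def collect(n: int) -> List[str]:
    service = TurnService()
    # No await on the call itself; the result is consumed with async for.
    return [event async for event in service.create_turn(n, stream=True)]
```

Calling `asyncio.run(collect(3))` drives the generator to completion and gathers the three events in order.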


@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import CodeShieldConfig
async def get_provider_impl(config: CodeShieldConfig, deps):
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
await impl.initialize()
return impl


@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from termcolor import cprint
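# config.py in this package defines CodeShieldConfig; alias it to the name used below
from .config import CodeShieldConfig as CodeScannerConfig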
from llama_stack.apis.safety import * # noqa: F403
class MetaReferenceCodeScannerSafetyImpl(Safety):
def __init__(self, config: CodeScannerConfig, deps) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: ShieldDef) -> None:
if shield.type != ShieldType.code_scanner.value:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
async def run_shield(
self,
shield_type: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
from codeshield.cs import CodeShield
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
cprint(f"Running CodeScannerShield on {text[:50]}", color="magenta")
result = await CodeShield.scan_code(text)
violation = None
if result.is_insecure:
violation = SafetyViolation(
violation_level=(ViolationLevel.ERROR),
user_message="Sorry, I found security concerns in the code.",
metadata={
"violation_type": ",".join(
[issue.pattern_id for issue in result.issues_found]
)
},
)
return RunShieldResponse(violation=violation)


@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class CodeShieldConfig(BaseModel):
pass


@ -43,13 +43,12 @@ class MetaReferenceEvalsImpl(Evals):
print("generation start") print("generation start")
for msg in x1[:5]: for msg in x1[:5]:
print("generation for msg: ", msg) print("generation for msg: ", msg)
response = self.inference_api.chat_completion( response = await self.inference_api.chat_completion(
model=model, model=model,
messages=[msg], messages=[msg],
stream=False, stream=False,
) )
async for x in response: generation_outputs.append(response.completion_message.content)
generation_outputs.append(x.completion_message.content)
x2 = task_impl.postprocess(generation_outputs) x2 = task_impl.postprocess(generation_outputs)
eval_results = task_impl.score(x2) eval_results = task_impl.score(x2)


@ -297,7 +297,7 @@ class Llama:
token=next_token[0].item(), token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()), text=self.tokenizer.decode(next_token.tolist()),
logprobs=( logprobs=(
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist() token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
if logprobs if logprobs
else None else None
), ),


@ -6,15 +6,14 @@
import asyncio import asyncio
from typing import AsyncIterator, List, Union from typing import AsyncGenerator, List
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider from llama_stack.providers.utils.inference.prompt_adapter import (
from llama_stack.providers.utils.inference.augment_messages import ( chat_completion_request_to_messages,
augment_messages_for_tools,
) )
from .config import MetaReferenceImplConfig from .config import MetaReferenceImplConfig
@ -25,7 +24,7 @@ from .model_parallel import LlamaModelParallelGenerator
SEMAPHORE = asyncio.Semaphore(1) SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, RoutableProvider): class MetaReferenceInferenceImpl(Inference):
def __init__(self, config: MetaReferenceImplConfig) -> None: def __init__(self, config: MetaReferenceImplConfig) -> None:
self.config = config self.config = config
model = resolve_model(config.model) model = resolve_model(config.model)
@ -35,21 +34,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
# verify that the checkpoint actually is for this model lol # verify that the checkpoint actually is for this model lol
async def initialize(self) -> None: async def initialize(self) -> None:
print(f"Loading model `{self.model.descriptor()}`")
self.generator = LlamaModelParallelGenerator(self.config) self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start() self.generator.start()
async def validate_routing_keys(self, routing_keys: List[str]) -> None: async def register_model(self, model: ModelDef) -> None:
assert ( if model.identifier != self.model.descriptor():
len(routing_keys) == 1 raise RuntimeError(
), f"Only one routing key is supported {routing_keys}" f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
assert routing_keys[0] == self.config.model )
async def shutdown(self) -> None: async def shutdown(self) -> None:
self.generator.stop() self.generator.stop()
# hm, when stream=False, we should not be doing SSE :/ which is what the def chat_completion(
# top-level server is going to do. make the typing more specific here
async def chat_completion(
self, self,
model: str, model: str,
messages: List[Message], messages: List[Message],
@ -59,9 +57,10 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncIterator[ ) -> AsyncGenerator:
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse] if logprobs:
]: assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API) # wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest( request = ChatCompletionRequest(
model=model, model=model,
@ -74,7 +73,6 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
logprobs=logprobs, logprobs=logprobs,
) )
messages = augment_messages_for_tools(request)
model = resolve_model(request.model) model = resolve_model(request.model)
if model is None: if model is None:
raise RuntimeError( raise RuntimeError(
@ -88,8 +86,64 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
if SEMAPHORE.locked(): if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported") raise RuntimeError("Only one concurrent request is supported")
async with SEMAPHORE:
if request.stream: if request.stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async with SEMAPHORE:
messages = chat_completion_request_to_messages(request)
tokens = []
logprobs = []
stop_reason = None
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
return ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async with SEMAPHORE:
messages = chat_completion_request_to_messages(request)
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start, event_type=ChatCompletionResponseEventType.start,
@ -99,10 +153,7 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
tokens = [] tokens = []
logprobs = [] logprobs = []
stop_reason = None stop_reason = None
buffer = ""
ipython = False ipython = False
for token_result in self.generator.chat_completion( for token_result in self.generator.chat_completion(
@ -113,10 +164,9 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
buffer += token_result.text
tokens.append(token_result.token)
if not ipython and buffer.startswith("<|python_tag|>"): if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -127,13 +177,6 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if not request.stream:
if request.logprobs:
logprobs.append(token_result.logprob)
continue
if token_result.text == "<|eot_id|>":
@ -154,23 +197,32 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
delta = text
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# TODO(ashwin): parse tool calls separately here and report errors?
# if someone breaks the iteration before coming here we are toast
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
if request.stream:
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
@ -203,10 +255,3 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
stop_reason=stop_reason,
)
)
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)

View file

@ -13,15 +13,15 @@ from typing import Optional
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3.api.model import Transformer, TransformerBlock
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from termcolor import cprint
from torch import Tensor
from llama_stack.apis.inference import QuantizationType
from llama_stack.apis.inference.config import ( from llama_stack.providers.impls.meta_reference.inference.config import (
CheckpointQuantizationFormat,
MetaReferenceImplConfig,
)

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import logging
import uuid
from typing import Any, Dict, List, Optional
@ -14,7 +13,6 @@ import numpy as np
from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.utils.memory.vector_store import (
@ -63,7 +61,7 @@ class FaissIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class FaissMemoryImpl(Memory, RoutableProvider): class FaissMemoryImpl(Memory):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.cache = {}
@ -72,37 +70,18 @@ class FaissMemoryImpl(Memory, RoutableProvider):
async def shutdown(self) -> None: ...
async def validate_routing_keys(self, routing_keys: List[str]) -> None: async def register_memory_bank(
print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
pass
async def create_memory_bank(
self,
name: str, memory_bank: MemoryBankDef,
config: MemoryBankConfig, ) -> None:
url: Optional[URL] = None,
) -> MemoryBank:
assert url is None, "URL is not supported for this implementation"
assert (
config.type == MemoryBankType.vector.value memory_bank.type == MemoryBankType.vector.value
), f"Only vector banks are supported {config.type}" ), f"Only vector banks are supported {memory_bank.type}"
bank_id = str(uuid.uuid4()) index = BankWithIndex(
bank = MemoryBank( bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
bank_id=bank_id,
name=name,
config=config,
url=url,
) )
index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)) self.cache[memory_bank.identifier] = index
self.cache[bank_id] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
index = self.cache.get(bank_id)
if index is None:
return None
return index.bank
async def insert_documents(
self,

View file

@ -44,7 +44,6 @@ def message_content_as_str(message: Message) -> str:
return interleaved_text_media_as_str(message.content)
# For shields that operate on simple strings
class TextShield(ShieldBase):
def convert_messages_to_text(self, messages: List[Message]) -> str:
return "\n".join([message_content_as_str(m) for m in messages])
@ -56,9 +55,3 @@ class TextShield(ShieldBase):
@abstractmethod
async def run_impl(self, text: str) -> ShieldResponse:
raise NotImplementedError()
class DummyShield(TextShield):
async def run_impl(self, text: str) -> ShieldResponse:
# Dummy return LOW to test e2e
return ShieldResponse(is_violation=False)

View file

@ -9,23 +9,19 @@ from typing import List, Optional
from llama_models.sku_list import CoreModelId, safety_models
from pydantic import BaseModel, validator from pydantic import BaseModel, field_validator
class MetaReferenceShieldType(Enum): class PromptGuardType(Enum):
llama_guard = "llama_guard" injection = "injection"
code_scanner_guard = "code_scanner_guard" jailbreak = "jailbreak"
injection_shield = "injection_shield"
jailbreak_shield = "jailbreak_shield"
class LlamaGuardShieldConfig(BaseModel):
model: str = "Llama-Guard-3-1B"
excluded_categories: List[str] = []
disable_input_check: bool = False
disable_output_check: bool = False
@validator("model") @field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = [
@ -47,10 +43,6 @@ class LlamaGuardShieldConfig(BaseModel):
return model
class PromptGuardShieldConfig(BaseModel):
model: str = "Prompt-Guard-86M"
class SafetyConfig(BaseModel):
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
prompt_guard_shield: Optional[PromptGuardShieldConfig] = None enable_prompt_guard: Optional[bool] = False

View file

@ -113,8 +113,6 @@ class LlamaGuardShield(ShieldBase):
model: str,
inference_api: Inference,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
@ -132,8 +130,6 @@ class LlamaGuardShield(ShieldBase):
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
@ -180,12 +176,6 @@ class LlamaGuardShield(ShieldBase):
async def run(self, messages: List[Message]) -> ShieldResponse:
messages = self.validate_messages(messages)
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(is_violation=False)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
is_violation=False,
)
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
shield_input_message = self.build_vision_shield_input(messages)

View file

@ -6,56 +6,43 @@
from typing import Any, Dict, List
from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api, RoutableProvider from llama_stack.distribution.datatypes import Api
from llama_stack.providers.impls.meta_reference.safety.shields.base import ( from .base import OnViolationAction, ShieldBase
OnViolationAction, from .config import SafetyConfig
) from .llama_guard import LlamaGuardShield
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
from .config import MetaReferenceShieldType, SafetyConfig
from .shields import (
CodeScannerShield,
InjectionShield,
JailbreakShield,
LlamaGuardShield,
PromptGuardShield,
ShieldBase,
)
def resolve_and_get_path(model_name: str) -> str: PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
model = resolve_model(model_name)
assert model is not None, f"Could not resolve model {model_name}"
model_dir = model_local_dir(model.descriptor())
return model_dir
class MetaReferenceSafetyImpl(Safety, RoutableProvider): class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
self.available_shields = []
if config.llama_guard_shield:
self.available_shields.append(ShieldType.llama_guard.value)
if config.enable_prompt_guard:
self.available_shields.append(ShieldType.prompt_guard.value)
async def initialize(self) -> None:
shield_cfg = self.config.prompt_guard_shield if self.config.enable_prompt_guard:
if shield_cfg is not None: model_dir = model_local_dir(PROMPT_GUARD_MODEL)
model_dir = resolve_and_get_path(shield_cfg.model)
_ = PromptGuardShield.instance(model_dir)
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None: async def register_shield(self, shield: ShieldDef) -> None:
available_shields = [v.value for v in MetaReferenceShieldType] if shield.type not in self.available_shields:
for key in routing_keys: raise ValueError(f"Unsupported safety shield type: {shield.type}")
if key not in available_shields:
raise ValueError(f"Unknown safety shield type: {key}")
async def run_shield(
self,
@ -63,10 +50,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
available_shields = [v.value for v in MetaReferenceShieldType] shield_def = await self.shield_store.get_shield(shield_type)
assert shield_type in available_shields, f"Unknown shield {shield_type}" if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
shield = self.get_shield_impl(MetaReferenceShieldType(shield_type)) shield = self.get_shield_impl(shield_def)
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
@ -92,34 +80,22 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
return RunShieldResponse(violation=violation)
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
cfg = self.config if shield.type == ShieldType.llama_guard.value:
if typ == MetaReferenceShieldType.llama_guard: cfg = self.config.llama_guard_shield
cfg = cfg.llama_guard_shield
assert (
cfg is not None
), "Cannot use LlamaGuardShield since not present in config"
return LlamaGuardShield(
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
disable_input_check=cfg.disable_input_check,
disable_output_check=cfg.disable_output_check,
)
elif typ == MetaReferenceShieldType.jailbreak_shield: elif shield.type == ShieldType.prompt_guard.value:
assert ( model_dir = model_local_dir(PROMPT_GUARD_MODEL)
cfg.prompt_guard_shield is not None subtype = shield.params.get("prompt_guard_type", "injection")
), "Cannot use Jailbreak Shield since Prompt Guard not present in config" if subtype == "injection":
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
return JailbreakShield.instance(model_dir)
elif typ == MetaReferenceShieldType.injection_shield:
assert (
cfg.prompt_guard_shield is not None
), "Cannot use PromptGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
return InjectionShield.instance(model_dir)
elif typ == MetaReferenceShieldType.code_scanner_guard: elif subtype == "jailbreak":
return CodeScannerShield.instance() return JailbreakShield.instance(model_dir)
else:
raise ValueError(f"Unknown shield type: {typ}") raise ValueError(f"Unknown prompt guard type: {subtype}")
else:
raise ValueError(f"Unknown shield type: {shield.type}")
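The new prompt-guard branch of `get_shield_impl` picks a shield from an optional `prompt_guard_type` parameter and defaults to `"injection"`. A minimal sketch of that dispatch follows. It returns class names as plain strings instead of calling `InjectionShield.instance(model_dir)`, and `select_prompt_guard` is a hypothetical name, so this illustrates the lookup only and is not the provider code.

```python
from typing import Any, Dict


def select_prompt_guard(params: Dict[str, Any]) -> str:
    # Hypothetical helper: mirrors the subtype lookup in get_shield_impl.
    # A missing key falls back to the injection shield, as in the diff.
    subtype = params.get("prompt_guard_type", "injection")
    if subtype == "injection":
        return "InjectionShield"
    elif subtype == "jailbreak":
        return "JailbreakShield"
    else:
        raise ValueError(f"Unknown prompt guard type: {subtype}")


default_choice = select_prompt_guard({})
jailbreak_choice = select_prompt_guard({"prompt_guard_type": "jailbreak"})

try:
    select_prompt_guard({"prompt_guard_type": "code_scanner"})
    rejected = False
except ValueError:
    rejected = True
```

Raising on an unknown subtype, instead of silently falling back to the default, surfaces a misconfigured shield definition the first time that shield is run.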

View file

@ -1,33 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# suppress warnings and spew of logs from hugging face
import transformers
from .base import ( # noqa: F401
DummyShield,
OnViolationAction,
ShieldBase,
ShieldResponse,
TextShield,
)
from .code_scanner import CodeScannerShield # noqa: F401
from .llama_guard import LlamaGuardShield # noqa: F401
from .prompt_guard import ( # noqa: F401
InjectionShield,
JailbreakShield,
PromptGuardShield,
)
transformers.logging.set_verbosity_error()
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings("ignore")

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from .base import ShieldResponse, TextShield
class CodeScannerShield(TextShield):
async def run_impl(self, text: str) -> ShieldResponse:
from codeshield.cs import CodeShield
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
result = await CodeShield.scan_code(text)
if result.is_insecure:
return ShieldResponse(
is_violation=True,
violation_type=",".join(
[issue.pattern_id for issue in result.issues_found]
),
violation_return_message="Sorry, I found security concerns in the code.",
)
else:
return ShieldResponse(is_violation=False)

View file

@ -0,0 +1,11 @@
from typing import Any
from .config import VLLMConfig
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
from .vllm import VLLMInferenceImpl
impl = VLLMInferenceImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
@json_schema_type
class VLLMConfig(BaseModel):
"""Configuration for the vLLM inference provider."""
model: str = Field(
default="Llama3.1-8B-Instruct",
description="Model descriptor from `llama model list`",
)
tensor_parallel_size: int = Field(
default=1,
description="Number of tensor parallel replicas (number of GPUs to use).",
)
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
return model

View file

@ -0,0 +1,241 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import uuid
from typing import Any
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import VLLMConfig
log = logging.getLogger(__name__)
def _random_uuid() -> str:
return str(uuid.uuid4().hex)
def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
"""Convert sampling params to vLLM sampling params."""
if sampling_params is None:
return SamplingParams()
# TODO convert what I saw in my first test ... but surely there's more to do here
kwargs = {
"temperature": sampling_params.temperature,
}
if sampling_params.top_k >= 1:
kwargs["top_k"] = sampling_params.top_k
if sampling_params.top_p:
kwargs["top_p"] = sampling_params.top_p
if sampling_params.max_tokens >= 1:
kwargs["max_tokens"] = sampling_params.max_tokens
if sampling_params.repetition_penalty > 0:
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
return SamplingParams(**kwargs)
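The conversion in `_vllm_sampling_params` only forwards a field when it carries a usable value, so the engine's own defaults apply otherwise. The sketch below reproduces that filtering without importing vLLM. `StackSamplingParams` and its default values are hypothetical stand-ins for the Llama Stack sampling params object, and `to_engine_kwargs` returns the plain dict that would be splatted into `SamplingParams(**kwargs)`.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class StackSamplingParams:
    # Hypothetical stand-in; the field names follow the diff, the defaults are assumed.
    temperature: float = 0.0
    top_k: int = 0
    top_p: float = 0.95
    max_tokens: int = 0
    repetition_penalty: float = 1.0


def to_engine_kwargs(params: Optional[StackSamplingParams]) -> Dict[str, Any]:
    # No params at all: pass nothing and let the engine use its defaults.
    if params is None:
        return {}
    kwargs: Dict[str, Any] = {"temperature": params.temperature}
    # Zero or negative values are treated as "unset" and are not forwarded.
    if params.top_k >= 1:
        kwargs["top_k"] = params.top_k
    if params.top_p:
        kwargs["top_p"] = params.top_p
    if params.max_tokens >= 1:
        kwargs["max_tokens"] = params.max_tokens
    if params.repetition_penalty > 0:
        kwargs["repetition_penalty"] = params.repetition_penalty
    return kwargs


empty = to_engine_kwargs(None)
forwarded = to_engine_kwargs(StackSamplingParams(temperature=0.7, max_tokens=64))
```

In this sketch `top_k=0` is dropped instead of being forwarded, which matches the `>= 1` guard in the diff.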
class VLLMInferenceImpl(ModelRegistryHelper, Inference):
"""Inference implementation for vLLM."""
HF_MODEL_MAPPINGS = {
# TODO: seems like we should be able to build this table dynamically ...
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
"Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
"Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
"Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
"Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
"Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
"Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
"Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
"Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
"Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
}
def __init__(self, config: VLLMConfig):
Inference.__init__(self)
ModelRegistryHelper.__init__(
self,
stack_to_provider_models_map=self.HF_MODEL_MAPPINGS,
)
self.config = config
self.engine = None
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
async def initialize(self):
"""Initialize the vLLM inference adapter."""
log.info("Initializing vLLM inference adapter")
# Disable usage stats reporting. This would be a surprising thing for most
# people to find out was on by default.
# https://docs.vllm.ai/en/latest/serving/usage_stats.html
if "VLLM_NO_USAGE_STATS" not in os.environ:
os.environ["VLLM_NO_USAGE_STATS"] = "1"
hf_model = self.HF_MODEL_MAPPINGS.get(self.config.model)
# TODO -- there are a ton of options supported here ...
engine_args = AsyncEngineArgs()
engine_args.model = hf_model
# We will need a new config item for this in the future if model support is more broad
# than it is today (llama only)
engine_args.tokenizer = hf_model
engine_args.tensor_parallel_size = self.config.tensor_parallel_size
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
async def shutdown(self):
"""Shutdown the vLLM inference adapter."""
log.info("Shutting down vLLM inference adapter")
if self.engine:
self.engine.shutdown_background_loop()
def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Any | None = ...,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
log.info("vLLM completion")
messages = [UserMessage(content=content)]
return self.chat_completion(
model=model,
messages=messages,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
def chat_completion(
self,
model: str,
messages: list[Message],
sampling_params: Any | None = ...,
tools: list[ToolDefinition] | None = ...,
tool_choice: ToolChoice | None = ...,
tool_prompt_format: ToolPromptFormat | None = ...,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
log.info("vLLM chat completion")
assert self.engine is not None
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = chat_completion_request_to_prompt(request, self.formatter)
vllm_sampling_params = _vllm_sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id
)
if stream:
return self._stream_chat_completion(request, results_generator)
else:
return self._nonstream_chat_completion(request, results_generator)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> ChatCompletionResponse:
outputs = [o async for o in results_generator]
final_output = outputs[-1]
assert final_output is not None
outputs = final_output.outputs
finish_reason = outputs[-1].stop_reason
choice = OpenAICompatCompletionChoice(
finish_reason=finish_reason,
text="".join([output.text for output in outputs]),
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(request, response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> AsyncGenerator:
async def _generate_and_convert_to_openai_compat():
async for chunk in results_generator:
if not chunk.outputs:
log.warning("Empty chunk received")
continue
text = "".join([output.text for output in chunk.outputs])
choice = OpenAICompatCompletionChoice(
finish_reason=chunk.outputs[-1].stop_reason,
text=text,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
):
yield chunk
async def embeddings(
self, model: str, contents: list[InterleavedTextMedia]
) -> EmbeddingsResponse:
log.info("vLLM embeddings")
# TODO
raise NotImplementedError()

View file

@ -41,6 +41,7 @@ def available_providers() -> List[ProviderSpec]:
adapter=AdapterSpec(
adapter_type="ollama",
pip_packages=["ollama"],
config_class="llama_stack.providers.adapters.inference.ollama.OllamaImplConfig",
module="llama_stack.providers.adapters.inference.ollama",
),
),
@ -103,4 +104,24 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="databricks",
pip_packages=[
"openai",
],
module="llama_stack.providers.adapters.inference.databricks",
config_class="llama_stack.providers.adapters.inference.databricks.DatabricksImplConfig",
),
),
InlineProviderSpec(
api=Api.inference,
provider_type="vllm",
pip_packages=[
"vllm",
],
module="llama_stack.providers.impls.vllm",
config_class="llama_stack.providers.impls.vllm.VLLMConfig",
),
]

View file

@ -56,6 +56,16 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
),
),
remote_provider_spec(
Api.memory,
AdapterSpec(
adapter_type="weaviate",
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
module="llama_stack.providers.adapters.memory.weaviate",
config_class="llama_stack.providers.adapters.memory.weaviate.WeaviateConfig",
provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",
),
),
remote_provider_spec(
api=Api.memory,
adapter=AdapterSpec(

View file

@ -21,7 +21,6 @@ def available_providers() -> List[ProviderSpec]:
api=Api.safety,
provider_type="meta-reference",
pip_packages=[
"codeshield",
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
@ -61,4 +60,14 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
),
),
InlineProviderSpec(
api=Api.safety,
provider_type="meta-reference/codeshield",
pip_packages=[
"codeshield",
],
module="llama_stack.providers.impls.meta_reference.codeshield",
config_class="llama_stack.providers.impls.meta_reference.codeshield.CodeShieldConfig",
api_dependencies=[],
),
]

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,25 @@
providers:
- provider_id: test-ollama
provider_type: remote::ollama
config:
host: localhost
port: 11434
- provider_id: test-tgi
provider_type: remote::tgi
config:
url: http://localhost:7001
- provider_id: test-remote
provider_type: remote
config:
host: localhost
port: 7002
- provider_id: test-together
provider_type: remote::together
config: {}
# If a provider needs private keys from the client, it reads them through the
# "get_request_provider_data" function (see distribution/request_headers.py).
# This section is the place to provide such data.
provider_data:
"test-together":
together_api_key:
0xdeadbeefputrealapikeyhere

View file

@ -0,0 +1,243 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import itertools
import pytest
import pytest_asyncio
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
#
# 1. Ensure you have a conda environment with the right dependencies installed. This is a bit tricky
# since it depends on the provider you are testing. On top of that you need
# `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/inference/test_inference.py \
# --tb=short --disable-warnings
# ```


def group_chunks(response):
    return {
        event_type: list(group)
        for event_type, group in itertools.groupby(
            response, key=lambda chunk: chunk.event.event_type
        )
    }
Llama_8B = "Llama3.1-8B-Instruct"
Llama_3B = "Llama3.2-3B-Instruct"


def get_expected_stop_reason(model: str):
    return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
# This is going to create multiple Stack impls without tearing down the previous one
# Fix that!
@pytest_asyncio.fixture(
    scope="session",
    params=[
        {"model": Llama_8B},
        {"model": Llama_3B},
    ],
    ids=lambda d: d["model"],
)
async def inference_settings(request):
    model = request.param["model"]
    impls = await resolve_impls_for_test(
        Api.inference,
        models=[
            ModelDef(
                identifier=model,
                llama_model=model,
            )
        ],
    )

    return {
        "impl": impls[Api.inference],
        "common_params": {
            "model": model,
            "tool_choice": ToolChoice.auto,
            "tool_prompt_format": (
                ToolPromptFormat.json
                if "Llama3.1" in model
                else ToolPromptFormat.python_list
            ),
        },
    }


@pytest.fixture
def sample_messages():
    return [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What's the weather like today?"),
    ]


@pytest.fixture
def sample_tool_definition():
    return ToolDefinition(
        tool_name="get_weather",
        description="Get the current weather",
        parameters={
            "location": ToolParamDefinition(
                param_type="string",
                description="The city and state, e.g. San Francisco, CA",
            ),
        },
    )


@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
    inference_impl = inference_settings["impl"]
    response = await inference_impl.chat_completion(
        messages=sample_messages,
        stream=False,
        **inference_settings["common_params"],
    )

    assert isinstance(response, ChatCompletionResponse)
    assert response.completion_message.role == "assistant"
    assert isinstance(response.completion_message.content, str)
    assert len(response.completion_message.content) > 0


@pytest.mark.asyncio
async def test_chat_completion_streaming(inference_settings, sample_messages):
    inference_impl = inference_settings["impl"]
    response = [
        r
        async for r in inference_impl.chat_completion(
            messages=sample_messages,
            stream=True,
            **inference_settings["common_params"],
        )
    ]

    assert len(response) > 0
    assert all(
        isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
    )
    grouped = group_chunks(response)
    assert len(grouped[ChatCompletionResponseEventType.start]) == 1
    assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
    assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

    end = grouped[ChatCompletionResponseEventType.complete][0]
    assert end.event.stop_reason == StopReason.end_of_turn


@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling(
    inference_settings,
    sample_messages,
    sample_tool_definition,
):
    inference_impl = inference_settings["impl"]
    messages = sample_messages + [
        UserMessage(
            content="What's the weather like in San Francisco?",
        )
    ]

    response = await inference_impl.chat_completion(
        messages=messages,
        tools=[sample_tool_definition],
        stream=False,
        **inference_settings["common_params"],
    )

    assert isinstance(response, ChatCompletionResponse)

    message = response.completion_message

    # This is not supported in most providers :/ they don't return eom_id / eot_id
    # stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
    # assert message.stop_reason == stop_reason
    assert message.tool_calls is not None
    assert len(message.tool_calls) > 0

    call = message.tool_calls[0]
    assert call.tool_name == "get_weather"
    assert "location" in call.arguments
    assert "San Francisco" in call.arguments["location"]


@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling_streaming(
    inference_settings,
    sample_messages,
    sample_tool_definition,
):
    inference_impl = inference_settings["impl"]
    messages = sample_messages + [
        UserMessage(
            content="What's the weather like in San Francisco?",
        )
    ]

    response = [
        r
        async for r in inference_impl.chat_completion(
            messages=messages,
            tools=[sample_tool_definition],
            stream=True,
            **inference_settings["common_params"],
        )
    ]

    assert len(response) > 0
    assert all(
        isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
    )
    grouped = group_chunks(response)
    assert len(grouped[ChatCompletionResponseEventType.start]) == 1
    assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
    assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

    # This is not supported in most providers :/ they don't return eom_id / eot_id
    # expected_stop_reason = get_expected_stop_reason(
    #     inference_settings["common_params"]["model"]
    # )
    # end = grouped[ChatCompletionResponseEventType.complete][0]
    # assert end.event.stop_reason == expected_stop_reason

    model = inference_settings["common_params"]["model"]
    if "Llama3.1" in model:
        assert all(
            isinstance(chunk.event.delta, ToolCallDelta)
            for chunk in grouped[ChatCompletionResponseEventType.progress]
        )
        first = grouped[ChatCompletionResponseEventType.progress][0]
        assert first.event.delta.parse_status == ToolCallParseStatus.started

    last = grouped[ChatCompletionResponseEventType.progress][-1]
    # assert last.event.stop_reason == expected_stop_reason
    assert last.event.delta.parse_status == ToolCallParseStatus.success
    assert isinstance(last.event.delta.content, ToolCall)

    call = last.event.delta.content
    assert call.tool_name == "get_weather"
    assert "location" in call.arguments
    assert "San Francisco" in call.arguments["location"]
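
The `group_chunks` helper in this test file builds its dict with `itertools.groupby`, which only merges consecutive items that share a key. A minimal sketch of that behavior follows, using plain `(event_type, payload)` tuples as hypothetical stand-ins for the real stream chunk objects; `group_by_key` is an illustrative name, not part of the test file.

```python
import itertools


def group_by_key(items, key):
    """Same shape as group_chunks: key -> list of one consecutive run of items."""
    return {k: list(g) for k, g in itertools.groupby(items, key=key)}


# Hypothetical event stream in the order the tests expect:
# one start, some progress events, one complete.
ordered = [("start", ""), ("progress", "a"), ("progress", "b"), ("complete", "")]
grouped = group_by_key(ordered, key=lambda e: e[0])

# If an event type reappears after a different one, groupby yields a second
# run for it, and the dict comprehension keeps only the later run.
interleaved = [("progress", "a"), ("start", ""), ("progress", "b")]
regrouped = group_by_key(interleaved, key=lambda e: e[0])
```

The assertions on `start`, `progress` and `complete` counts therefore assume the provider emits each event type as one contiguous run.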

View file

@ -8,7 +8,7 @@ import unittest
 from llama_models.llama3.api import * # noqa: F403
 from llama_stack.inference.api import * # noqa: F403
-from llama_stack.inference.augment_messages import augment_messages_for_tools
+from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages
 MODEL = "Llama3.1-8B-Instruct"
@ -22,7 +22,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
 UserMessage(content=content),
 ],
 )
-messages = augment_messages_for_tools(request)
+messages = chat_completion_request_to_messages(request)
 self.assertEqual(len(messages), 2)
 self.assertEqual(messages[-1].content, content)
 self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@ -39,7 +39,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
 ToolDefinition(tool_name=BuiltinTool.brave_search),
 ],
 )
-messages = augment_messages_for_tools(request)
+messages = chat_completion_request_to_messages(request)
 self.assertEqual(len(messages), 2)
 self.assertEqual(messages[-1].content, content)
 self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@ -67,7 +67,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
 ],
 tool_prompt_format=ToolPromptFormat.json,
 )
-messages = augment_messages_for_tools(request)
+messages = chat_completion_request_to_messages(request)
 self.assertEqual(len(messages), 3)
 self.assertTrue("Environment: ipython" in messages[0].content)
@ -97,7 +97,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
 ),
 ],
 )
-messages = augment_messages_for_tools(request)
+messages = chat_completion_request_to_messages(request)
 self.assertEqual(len(messages), 3)
 self.assertTrue("Environment: ipython" in messages[0].content)
@ -119,7 +119,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
 ],
 )
-messages = augment_messages_for_tools(request)
+messages = chat_completion_request_to_messages(request)
 self.assertEqual(len(messages), 2, messages)
 self.assertTrue(messages[0].content.endswith(system_prompt))

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,24 @@
providers:
  - provider_id: test-faiss
    provider_type: meta-reference
    config: {}
  - provider_id: test-chroma
    provider_type: remote::chroma
    config:
      host: localhost
      port: 6001
  - provider_id: test-remote
    provider_type: remote
    config:
      host: localhost
      port: 7002
  - provider_id: test-weaviate
    provider_type: remote::weaviate
    config: {}
# if a provider needs private keys from the client, they use the
# "get_request_provider_data" function (see distribution/request_headers.py)
# this is a place to provide such data.
provider_data:
  "test-weaviate":
    weaviate_api_key: 0xdeadbeefputrealapikeyhere
    weaviate_cluster_url: http://foobarbaz

View file

@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
#
# 1. Ensure you have a conda environment with the right dependencies installed. This is a bit tricky
# since it depends on the provider you are testing. On top of that you need
# `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/memory/test_memory.py \
# --tb=short --disable-warnings
# ```


@pytest_asyncio.fixture(scope="session")
async def memory_impl():
    impls = await resolve_impls_for_test(
        Api.memory,
        memory_banks=[],
    )
    return impls[Api.memory]


@pytest.fixture
def sample_documents():
    return [
        MemoryBankDocument(
            document_id="doc1",
            content="Python is a high-level programming language.",
            metadata={"category": "programming", "difficulty": "beginner"},
        ),
        MemoryBankDocument(
            document_id="doc2",
            content="Machine learning is a subset of artificial intelligence.",
            metadata={"category": "AI", "difficulty": "advanced"},
        ),
        MemoryBankDocument(
            document_id="doc3",
            content="Data structures are fundamental to computer science.",
            metadata={"category": "computer science", "difficulty": "intermediate"},
        ),
        MemoryBankDocument(
            document_id="doc4",
            content="Neural networks are inspired by biological neural networks.",
            metadata={"category": "AI", "difficulty": "advanced"},
        ),
    ]


async def register_memory_bank(memory_impl: Memory):
    bank = VectorMemoryBankDef(
        identifier="test_bank",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
        overlap_size_in_tokens=64,
    )

    await memory_impl.register_memory_bank(bank)


@pytest.mark.asyncio
async def test_query_documents(memory_impl, sample_documents):
    with pytest.raises(ValueError):
        await memory_impl.insert_documents("test_bank", sample_documents)

    await register_memory_bank(memory_impl)
    await memory_impl.insert_documents("test_bank", sample_documents)

    query1 = "programming language"
    response1 = await memory_impl.query_documents("test_bank", query1)
    assert_valid_response(response1)
    assert any("Python" in chunk.content for chunk in response1.chunks)

    # Test case 3: Query with semantic similarity
    query3 = "AI and brain-inspired computing"
    response3 = await memory_impl.query_documents("test_bank", query3)
    assert_valid_response(response3)
    assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)

    # Test case 4: Query with limit on number of results
    query4 = "computer"
    params4 = {"max_chunks": 2}
    response4 = await memory_impl.query_documents("test_bank", query4, params4)
    assert_valid_response(response4)
    assert len(response4.chunks) <= 2

    # Test case 5: Query with threshold on similarity score
    query5 = "quantum computing"  # Not directly related to any document
    params5 = {"score_threshold": 0.5}
    response5 = await memory_impl.query_documents("test_bank", query5, params5)
    assert_valid_response(response5)
    assert all(score >= 0.5 for score in response5.scores)


def assert_valid_response(response: QueryDocumentsResponse):
    assert isinstance(response, QueryDocumentsResponse)
    assert len(response.chunks) > 0
    assert len(response.scores) > 0
    assert len(response.chunks) == len(response.scores)
    for chunk in response.chunks:
        assert isinstance(chunk.content, str)
        assert chunk.document_id is not None

View file

@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
from datetime import datetime
import yaml
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing


async def resolve_impls_for_test(
    api: Api,
    models: List[ModelDef] = None,
    memory_banks: List[MemoryBankDef] = None,
    shields: List[ShieldDef] = None,
):
    if "PROVIDER_CONFIG" not in os.environ:
        raise ValueError(
            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
        )

    with open(os.environ["PROVIDER_CONFIG"], "r") as f:
        config_dict = yaml.safe_load(f)

    if "providers" not in config_dict:
        raise ValueError("Config file should contain a `providers` key")

    providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
    if len(providers_by_id) == 0:
        raise ValueError("No providers found in config file")

    if "PROVIDER_ID" in os.environ:
        provider_id = os.environ["PROVIDER_ID"]
        if provider_id not in providers_by_id:
            raise ValueError(f"Provider ID {provider_id} not found in config file")
        provider = providers_by_id[provider_id]
    else:
        provider = list(providers_by_id.values())[0]
        provider_id = provider["provider_id"]
        print(f"No provider ID specified, picking first `{provider_id}`")

    models = models or []
    shields = shields or []
    memory_banks = memory_banks or []

    models = [
        ModelDef(
            **{
                **m.dict(),
                "provider_id": provider_id,
            }
        )
        for m in models
    ]
    shields = [
        ShieldDef(
            **{
                **s.dict(),
                "provider_id": provider_id,
            }
        )
        for s in shields
    ]
    memory_banks = [
        MemoryBankDef(
            **{
                **m.dict(),
                "provider_id": provider_id,
            }
        )
        for m in memory_banks
    ]
    run_config = dict(
        built_at=datetime.now(),
        image_name="test-fixture",
        apis=[api],
        providers={api.value: [Provider(**provider)]},
        models=models,
        memory_banks=memory_banks,
        shields=shields,
    )
    run_config = parse_and_maybe_upgrade_config(run_config)
    impls = await resolve_impls_with_routing(run_config)

    if "provider_data" in config_dict:
        provider_data = config_dict["provider_data"].get(provider_id, {})
        if provider_data:
            set_request_provider_data(
                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
            )

    return impls
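
To make the provider selection in `resolve_impls_for_test` easier to follow, here is a minimal standalone sketch of the same two steps: pick a provider entry by ID (or fall back to the first one), then serialize that provider's private data into the `X-LlamaStack-ProviderData` header. The helper names `pick_provider` and `provider_data_header` are hypothetical, and the config is an inline dict rather than a YAML file, so that the sketch needs only the standard library.

```python
import json
from typing import Optional


def pick_provider(config: dict, provider_id: Optional[str] = None) -> dict:
    """Return the provider entry for provider_id, or the first entry if None."""
    providers_by_id = {p["provider_id"]: p for p in config["providers"]}
    if not providers_by_id:
        raise ValueError("No providers found in config")
    if provider_id is None:
        # Dicts preserve insertion order, so this is the first listed provider.
        return next(iter(providers_by_id.values()))
    if provider_id not in providers_by_id:
        raise ValueError(f"Provider ID {provider_id} not found in config")
    return providers_by_id[provider_id]


def provider_data_header(config: dict, provider_id: str) -> dict:
    """Build the request header carrying per-provider private data, if any."""
    data = config.get("provider_data", {}).get(provider_id, {})
    return {"X-LlamaStack-ProviderData": json.dumps(data)} if data else {}


# Inline stand-in for a parsed provider config file (values are placeholders).
config = {
    "providers": [
        {"provider_id": "test-ollama", "provider_type": "remote::ollama",
         "config": {"host": "localhost", "port": 11434}},
        {"provider_id": "test-together", "provider_type": "remote::together",
         "config": {}},
    ],
    "provider_data": {"test-together": {"together_api_key": "placeholder-key"}},
}

default_provider = pick_provider(config)
together_header = provider_data_header(config, "test-together")
ollama_header = provider_data_header(config, "test-ollama")
```

Providers without an entry under `provider_data` get no header at all, which matches the `if provider_data:` guard in the function above.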

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_models.sku_list import resolve_model
from llama_stack.apis.models import * # noqa: F403


class ModelRegistryHelper:

    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
        self.stack_to_provider_models_map = stack_to_provider_models_map

    def map_to_provider_model(self, identifier: str) -> str:
        model = resolve_model(identifier)
        if not model:
            raise ValueError(f"Unknown model: `{identifier}`")

        if identifier not in self.stack_to_provider_models_map:
            raise ValueError(
                f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
            )

        return self.stack_to_provider_models_map[identifier]

    async def register_model(self, model: ModelDef) -> None:
        if model.identifier not in self.stack_to_provider_models_map:
            raise ValueError(
                f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
            )
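
As a rough illustration of how a provider adapter might use such a helper, the sketch below reproduces only the dictionary lookup. It deliberately omits the `resolve_model` check from `llama_models` so that it runs standalone; the class name `ProviderModelMap` and the provider-side model names are hypothetical placeholders, not identifiers from any real provider.

```python
from typing import Dict


class ProviderModelMap:
    """Stand-in for ModelRegistryHelper without the llama_models dependency."""

    def __init__(self, stack_to_provider: Dict[str, str]):
        self.stack_to_provider = stack_to_provider

    def map_to_provider_model(self, identifier: str) -> str:
        # Translate a Llama Stack model identifier into the provider's own name.
        if identifier not in self.stack_to_provider:
            supported = sorted(self.stack_to_provider)
            raise ValueError(
                f"Model {identifier} not found. Supported models: {supported}"
            )
        return self.stack_to_provider[identifier]


# Placeholder provider-side names, chosen only for illustration.
model_map = ProviderModelMap(
    {
        "Llama3.1-8B-Instruct": "example-provider/llama-3.1-8b",
        "Llama3.2-3B-Instruct": "example-provider/llama-3.2-3b",
    }
)

provider_name = model_map.map_to_provider_model("Llama3.1-8B-Instruct")
```

Keeping the mapping in one place means an adapter can reject unsupported models at registration time instead of failing later inside a request.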

View file

@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_stack.apis.inference import * # noqa: F403
from pydantic import BaseModel


class OpenAICompatCompletionChoiceDelta(BaseModel):
    content: str


class OpenAICompatCompletionChoice(BaseModel):
    finish_reason: Optional[str] = None
    text: Optional[str] = None
    delta: Optional[OpenAICompatCompletionChoiceDelta] = None


class OpenAICompatCompletionResponse(BaseModel):
    choices: List[OpenAICompatCompletionChoice]


def get_sampling_options(request: ChatCompletionRequest) -> dict:
    options = {}
    if params := request.sampling_params:
        for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
            if getattr(params, attr):
                options[attr] = getattr(params, attr)

        if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
            options["repeat_penalty"] = params.repetition_penalty

    return options


def text_from_choice(choice) -> str:
    if hasattr(choice, "delta") and choice.delta:
        return choice.delta.content

    return choice.text


def process_chat_completion_response(
    request: ChatCompletionRequest,
    response: OpenAICompatCompletionResponse,
    formatter: ChatFormat,
) -> ChatCompletionResponse:
    choice = response.choices[0]

    stop_reason = None
    if reason := choice.finish_reason:
        if reason in ["stop", "eos"]:
            stop_reason = StopReason.end_of_turn
        elif reason == "eom":
            stop_reason = StopReason.end_of_message
        elif reason == "length":
            stop_reason = StopReason.out_of_tokens

    if stop_reason is None:
        stop_reason = StopReason.out_of_tokens

    completion_message = formatter.decode_assistant_message_from_content(
        text_from_choice(choice), stop_reason
    )
    return ChatCompletionResponse(
        completion_message=completion_message,
        logprobs=None,
    )


async def process_chat_completion_stream_response(
    request: ChatCompletionRequest,
    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
    formatter: ChatFormat,
) -> AsyncGenerator:
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.start,
            delta="",
        )
    )

    buffer = ""
    ipython = False
    stop_reason = None

    async for chunk in stream:
        choice = chunk.choices[0]
        finish_reason = choice.finish_reason

        if finish_reason:
            if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
                stop_reason = StopReason.end_of_turn
            elif stop_reason is None and finish_reason == "length":
                stop_reason = StopReason.out_of_tokens
            break

        text = text_from_choice(choice)
        # check if it's a tool call (i.e. starts with <|python_tag|>)
        if not ipython and text.startswith("<|python_tag|>"):
            ipython = True
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=ToolCallDelta(
                        content="",
                        parse_status=ToolCallParseStatus.started,
                    ),
                )
            )
            buffer += text
            continue

        if text == "<|eot_id|>":
            stop_reason = StopReason.end_of_turn
            text = ""
            continue
        elif text == "<|eom_id|>":
            stop_reason = StopReason.end_of_message
            text = ""
            continue

        if ipython:
            buffer += text
            delta = ToolCallDelta(
                content=text,
                parse_status=ToolCallParseStatus.in_progress,
            )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=delta,
                    stop_reason=stop_reason,
                )
            )
        else:
            buffer += text
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=text,
                    stop_reason=stop_reason,
                )
            )

    # parse tool calls and report errors
    message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
    parsed_tool_calls = len(message.tool_calls) > 0
    if ipython and not parsed_tool_calls:
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=ToolCallDelta(
                    content="",
                    parse_status=ToolCallParseStatus.failure,
                ),
                stop_reason=stop_reason,
            )
        )

    for tool_call in message.tool_calls:
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=ToolCallDelta(
                    content=tool_call,
                    parse_status=ToolCallParseStatus.success,
                ),
                stop_reason=stop_reason,
            )
        )

    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
            delta="",
            stop_reason=stop_reason,
        )
    )
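
The non-streaming path in this file translates an OpenAI-style `finish_reason` string into a `StopReason`, defaulting to `out_of_tokens` when the reason is missing or unrecognized. The sketch below isolates that mapping so it can be read and tested on its own. The `StopReason` enum here is a local stand-in for the one in `llama_models`, and `to_stop_reason` is a hypothetical helper name.

```python
from enum import Enum
from typing import Optional


class StopReason(Enum):
    """Local stand-in for llama_models' StopReason, for illustration only."""

    end_of_turn = "end_of_turn"
    end_of_message = "end_of_message"
    out_of_tokens = "out_of_tokens"


def to_stop_reason(finish_reason: Optional[str]) -> StopReason:
    """Mirror the mapping used by process_chat_completion_response."""
    if finish_reason in ("stop", "eos"):
        return StopReason.end_of_turn
    if finish_reason == "eom":
        return StopReason.end_of_message
    # "length", None, and any unrecognized value all fall through to here.
    return StopReason.out_of_tokens


reasons = {r: to_stop_reason(r) for r in ("stop", "eos", "eom", "length", None)}
```

Note that the streaming path additionally accepts `"eos_token"` as an end-of-turn signal, while the non-streaming mapping above does not; whether that difference is intentional is not stated in this diff.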
