mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 15:02:37 +00:00
rebase on top of registry
This commit is contained in:
commit
6abef716dd
107 changed files with 4813 additions and 3587 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -13,3 +13,4 @@ xcuserdata/
|
|||
Package.resolved
|
||||
*.pte
|
||||
*.ipynb_checkpoints*
|
||||
.idea
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Contributing to Llama-Models
|
||||
# Contributing to Llama-Stack
|
||||
We want to make contributing to this project as easy and transparent as
|
||||
possible.
|
||||
|
||||
|
@ -32,7 +32,7 @@ outlined on that page and do not file a public issue.
|
|||
* ...
|
||||
|
||||
## Tips
|
||||
* If you are developing with a llama-models repository checked out and need your distribution to reflect changes from there, set `LLAMA_MODELS_DIR` to that dir when running any of the `llama` CLI commands.
|
||||
* If you are developing with a llama-stack repository checked out and need your distribution to reflect changes from there, set `LLAMA_STACK_DIR` to that dir when running any of the `llama` CLI commands.
|
||||
|
||||
## License
|
||||
By contributing to Llama, you agree that your contributions will be licensed
|
||||
|
|
17
README.md
17
README.md
|
@ -81,11 +81,24 @@ cd llama-stack
|
|||
$CONDA_PREFIX/bin/pip install -e .
|
||||
```
|
||||
|
||||
## The Llama CLI
|
||||
## Documentations
|
||||
|
||||
The `llama` CLI makes it easy to work with the Llama Stack set of tools, including installing and running Distributions, downloading models, studying model prompt formats, etc. Please see the [CLI reference](docs/cli_reference.md) for details. Please see the [Getting Started](docs/getting_started.md) guide for running a Llama Stack server.
|
||||
The `llama` CLI makes it easy to work with the Llama Stack set of tools. Please find the following docs for details.
|
||||
|
||||
* [CLI reference](docs/cli_reference.md)
|
||||
* Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
|
||||
* [Getting Started](docs/getting_started.md)
|
||||
* Guide to build and run a Llama Stack server.
|
||||
* [Contributing](CONTRIBUTING.md)
|
||||
|
||||
|
||||
## Llama Stack Client SDK
|
||||
|
||||
| **Language** | **Client SDK** | **Package** |
|
||||
| :----: | :----: | :----: |
|
||||
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [](https://pypi.org/project/llama_stack_client/)
|
||||
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) |
|
||||
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [](https://npmjs.org/package/llama-stack-client)
|
||||
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) |
|
||||
|
||||
Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [node](https://github.com/meta-llama/llama-stack-client-node), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.
|
||||
|
|
5
SECURITY.md
Normal file
5
SECURITY.md
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Security Policy
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
Please report vulnerabilities to our bug bounty program at https://bugbounty.meta.com/
|
|
@ -1,6 +1,6 @@
|
|||
# Llama CLI Reference
|
||||
|
||||
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
|
||||
The `llama` CLI tool helps you setup and use the Llama Stack & agentic systems. It should be available on your path after installing the `llama-stack` package.
|
||||
|
||||
### Subcommands
|
||||
1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face.
|
||||
|
@ -117,9 +117,9 @@ llama download --source meta --model-id Llama-Guard-3-1B --meta-url META_URL
|
|||
Essentially, the same commands above work, just replace `--source meta` with `--source huggingface`.
|
||||
|
||||
```bash
|
||||
llama download --source huggingface --model-id Meta-Llama3.1-8B-Instruct --hf-token <HF_TOKEN>
|
||||
llama download --source huggingface --model-id Llama3.1-8B-Instruct --hf-token <HF_TOKEN>
|
||||
|
||||
llama download --source huggingface --model-id Meta-Llama3.1-70B-Instruct --hf-token <HF_TOKEN>
|
||||
llama download --source huggingface --model-id Llama3.1-70B-Instruct --hf-token <HF_TOKEN>
|
||||
|
||||
llama download --source huggingface --model-id Llama-Guard-3-1B --ignore-patterns *original*
|
||||
llama download --source huggingface --model-id Prompt-Guard-86M --ignore-patterns *original*
|
||||
|
@ -215,9 +215,8 @@ You can even run `llama model prompt-format` see all of the templates and their
|
|||
```
|
||||
llama model prompt-format -m Llama3.2-3B-Instruct
|
||||
```
|
||||
<p align="center">
|
||||
<img width="719" alt="image" src="https://github.com/user-attachments/assets/c5332026-8c0b-4edc-b438-ec60cd7ca554">
|
||||
</p>
|
||||

|
||||
|
||||
|
||||
|
||||
You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
|
||||
|
@ -230,7 +229,7 @@ You will be shown a Markdown formatted description of the model interface and ho
|
|||
- Please see our [Getting Started](getting_started.md) guide for more details on how to build and start a Llama Stack distribution.
|
||||
|
||||
### Step 3.1 Build
|
||||
In the following steps, imagine we'll be working with a `Meta-Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify:
|
||||
In the following steps, imagine we'll be working with a `Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify:
|
||||
- `name`: the name for our distribution (e.g. `8b-instruct`)
|
||||
- `image_type`: our build image type (`conda | docker`)
|
||||
- `distribution_spec`: our distribution specs for specifying API providers
|
||||
|
@ -365,7 +364,7 @@ llama stack configure [ <name> | <docker-image-name> | <path/to/name.build.yaml>
|
|||
$ llama stack configure ~/.llama/distributions/conda/8b-instruct-build.yaml
|
||||
|
||||
Configuring API: inference (meta-reference)
|
||||
Enter value for model (existing: Meta-Llama3.1-8B-Instruct) (required):
|
||||
Enter value for model (existing: Llama3.1-8B-Instruct) (required):
|
||||
Enter value for quantization (optional):
|
||||
Enter value for torch_seed (optional):
|
||||
Enter value for max_seq_len (existing: 4096) (required):
|
||||
|
@ -397,7 +396,7 @@ YAML configuration has been written to ~/.llama/builds/conda/8b-instruct-run.yam
|
|||
After this step is successful, you should be able to find a run configuration spec in `~/.llama/builds/conda/8b-instruct-run.yaml` with the following contents. You may edit this file to change the settings.
|
||||
|
||||
As you can see, we did basic configuration above and configured:
|
||||
- inference to run on model `Meta-Llama3.1-8B-Instruct` (obtained from `llama model list`)
|
||||
- inference to run on model `Llama3.1-8B-Instruct` (obtained from `llama model list`)
|
||||
- Llama Guard safety shield with model `Llama-Guard-3-1B`
|
||||
- Prompt Guard safety shield with model `Prompt-Guard-86M`
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# llama-stack
|
||||
|
||||
[](https://pypi.org/project/llama-stack/)
|
||||
[](https://discord.gg/TZAAYNVtrU)
|
||||
[](https://discord.gg/llama-stack)
|
||||
|
||||
This repository contains the specifications and implementations of the APIs which are part of the Llama Stack.
|
||||
|
||||
|
@ -66,8 +66,17 @@ This guides allows you to quickly get started with building and running a Llama
|
|||
You may also checkout this [notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb) for trying out out demo scripts.
|
||||
|
||||
## Quick Cheatsheet
|
||||
- Quick 3 line command to build and start a LlamaStack server using our Meta Reference implementation for all API endpoints with `conda` as build type.
|
||||
|
||||
#### Via docker
|
||||
```
|
||||
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack-local-gpu
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> `~/.llama` should be the path containing downloaded weights of Llama models.
|
||||
|
||||
|
||||
#### Via conda
|
||||
**`llama stack build`**
|
||||
- You'll be prompted to enter build information interactively.
|
||||
```
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -580,63 +580,6 @@ components:
|
|||
- uuid
|
||||
- dataset
|
||||
type: object
|
||||
CreateMemoryBankRequest:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
config:
|
||||
oneOf:
|
||||
- additionalProperties: false
|
||||
properties:
|
||||
chunk_size_in_tokens:
|
||||
type: integer
|
||||
embedding_model:
|
||||
type: string
|
||||
overlap_size_in_tokens:
|
||||
type: integer
|
||||
type:
|
||||
const: vector
|
||||
default: vector
|
||||
type: string
|
||||
required:
|
||||
- type
|
||||
- embedding_model
|
||||
- chunk_size_in_tokens
|
||||
type: object
|
||||
- additionalProperties: false
|
||||
properties:
|
||||
type:
|
||||
const: keyvalue
|
||||
default: keyvalue
|
||||
type: string
|
||||
required:
|
||||
- type
|
||||
type: object
|
||||
- additionalProperties: false
|
||||
properties:
|
||||
type:
|
||||
const: keyword
|
||||
default: keyword
|
||||
type: string
|
||||
required:
|
||||
- type
|
||||
type: object
|
||||
- additionalProperties: false
|
||||
properties:
|
||||
type:
|
||||
const: graph
|
||||
default: graph
|
||||
type: string
|
||||
required:
|
||||
- type
|
||||
type: object
|
||||
name:
|
||||
type: string
|
||||
url:
|
||||
$ref: '#/components/schemas/URL'
|
||||
required:
|
||||
- name
|
||||
- config
|
||||
type: object
|
||||
DPOAlignmentConfig:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -739,14 +682,6 @@ components:
|
|||
- rank
|
||||
- alpha
|
||||
type: object
|
||||
DropMemoryBankRequest:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
bank_id:
|
||||
type: string
|
||||
required:
|
||||
- bank_id
|
||||
type: object
|
||||
EmbeddingsRequest:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -908,6 +843,21 @@ components:
|
|||
required:
|
||||
- document_ids
|
||||
type: object
|
||||
GraphMemoryBankDef:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
identifier:
|
||||
type: string
|
||||
provider_id:
|
||||
type: string
|
||||
type:
|
||||
const: graph
|
||||
default: graph
|
||||
type: string
|
||||
required:
|
||||
- identifier
|
||||
- type
|
||||
type: object
|
||||
HealthInfo:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -973,6 +923,36 @@ components:
|
|||
- bank_id
|
||||
- documents
|
||||
type: object
|
||||
KeyValueMemoryBankDef:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
identifier:
|
||||
type: string
|
||||
provider_id:
|
||||
type: string
|
||||
type:
|
||||
const: keyvalue
|
||||
default: keyvalue
|
||||
type: string
|
||||
required:
|
||||
- identifier
|
||||
- type
|
||||
type: object
|
||||
KeywordMemoryBankDef:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
identifier:
|
||||
type: string
|
||||
provider_id:
|
||||
type: string
|
||||
type:
|
||||
const: keyword
|
||||
default: keyword
|
||||
type: string
|
||||
required:
|
||||
- identifier
|
||||
- type
|
||||
type: object
|
||||
LogEventRequest:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -1015,66 +995,6 @@ components:
|
|||
- rank
|
||||
- alpha
|
||||
type: object
|
||||
MemoryBank:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
bank_id:
|
||||
type: string
|
||||
config:
|
||||
oneOf:
|
||||
- additionalProperties: false
|
||||
properties:
|
||||
chunk_size_in_tokens:
|
||||
type: integer
|
||||
embedding_model:
|
||||
type: string
|
||||
overlap_size_in_tokens:
|
||||
type: integer
|
||||
type:
|
||||
const: vector
|
||||
default: vector
|
||||
type: string
|
||||
required:
|
||||
- type
|
||||
- embedding_model
|
||||
- chunk_size_in_tokens
|
||||
type: object
|
||||
- additionalProperties: false
|
||||
properties:
|
||||
type:
|
||||
const: keyvalue
|
||||
default: keyvalue
|
||||
type: string
|
||||
required:
|
||||
- type
|
||||
type: object
|
||||
- additionalProperties: false
|
||||
properties:
|
||||
type:
|
||||
const: keyword
|
||||
default: keyword
|
||||
type: string
|
||||
required:
|
||||
- type
|
||||
type: object
|
||||
- additionalProperties: false
|
||||
properties:
|
||||
type:
|
||||
const: graph
|
||||
default: graph
|
||||
type: string
|
||||
required:
|
||||
- type
|
||||
type: object
|
||||
name:
|
||||
type: string
|
||||
url:
|
||||
$ref: '#/components/schemas/URL'
|
||||
required:
|
||||
- bank_id
|
||||
- name
|
||||
- config
|
||||
type: object
|
||||
MemoryBankDocument:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -1107,41 +1027,6 @@ components:
|
|||
- content
|
||||
- metadata
|
||||
type: object
|
||||
MemoryBankSpec:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
bank_type:
|
||||
$ref: '#/components/schemas/MemoryBankType'
|
||||
provider_config:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
config:
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
type: object
|
||||
provider_type:
|
||||
type: string
|
||||
required:
|
||||
- provider_type
|
||||
- config
|
||||
type: object
|
||||
required:
|
||||
- bank_type
|
||||
- provider_config
|
||||
type: object
|
||||
MemoryBankType:
|
||||
enum:
|
||||
- vector
|
||||
- keyvalue
|
||||
- keyword
|
||||
- graph
|
||||
type: string
|
||||
MemoryRetrievalStep:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -1349,36 +1234,18 @@ components:
|
|||
- value
|
||||
- unit
|
||||
type: object
|
||||
Model:
|
||||
description: The model family and SKU of the model along with other parameters
|
||||
corresponding to the model.
|
||||
ModelServingSpec:
|
||||
ModelDef:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
identifier:
|
||||
type: string
|
||||
llama_model:
|
||||
$ref: '#/components/schemas/Model'
|
||||
provider_config:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
config:
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
type: object
|
||||
provider_type:
|
||||
type: string
|
||||
required:
|
||||
- provider_type
|
||||
- config
|
||||
type: object
|
||||
type: string
|
||||
provider_id:
|
||||
type: string
|
||||
required:
|
||||
- identifier
|
||||
- llama_model
|
||||
- provider_config
|
||||
type: object
|
||||
OptimizerConfig:
|
||||
additionalProperties: false
|
||||
|
@ -1554,13 +1421,13 @@ components:
|
|||
ProviderInfo:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
description:
|
||||
provider_id:
|
||||
type: string
|
||||
provider_type:
|
||||
type: string
|
||||
required:
|
||||
- provider_id
|
||||
- provider_type
|
||||
- description
|
||||
type: object
|
||||
QLoraFinetuningConfig:
|
||||
additionalProperties: false
|
||||
|
@ -1650,6 +1517,56 @@ components:
|
|||
enum:
|
||||
- dpo
|
||||
type: string
|
||||
RegisterMemoryBankRequest:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
memory_bank:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/VectorMemoryBankDef'
|
||||
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
|
||||
- $ref: '#/components/schemas/KeywordMemoryBankDef'
|
||||
- $ref: '#/components/schemas/GraphMemoryBankDef'
|
||||
required:
|
||||
- memory_bank
|
||||
type: object
|
||||
RegisterModelRequest:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
model:
|
||||
$ref: '#/components/schemas/ModelDef'
|
||||
required:
|
||||
- model
|
||||
type: object
|
||||
RegisterShieldRequest:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
shield:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
identifier:
|
||||
type: string
|
||||
params:
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
type: object
|
||||
provider_id:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
required:
|
||||
- identifier
|
||||
- type
|
||||
- params
|
||||
type: object
|
||||
required:
|
||||
- shield
|
||||
type: object
|
||||
RestAPIExecutionConfig:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -1728,7 +1645,7 @@ components:
|
|||
properties:
|
||||
method:
|
||||
type: string
|
||||
providers:
|
||||
provider_types:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
|
@ -1737,7 +1654,7 @@ components:
|
|||
required:
|
||||
- route
|
||||
- method
|
||||
- providers
|
||||
- provider_types
|
||||
type: object
|
||||
RunShieldRequest:
|
||||
additionalProperties: false
|
||||
|
@ -1892,7 +1809,11 @@ components:
|
|||
additionalProperties: false
|
||||
properties:
|
||||
memory_bank:
|
||||
$ref: '#/components/schemas/MemoryBank'
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/VectorMemoryBankDef'
|
||||
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
|
||||
- $ref: '#/components/schemas/KeywordMemoryBankDef'
|
||||
- $ref: '#/components/schemas/GraphMemoryBankDef'
|
||||
session_id:
|
||||
type: string
|
||||
session_name:
|
||||
|
@ -1935,33 +1856,29 @@ components:
|
|||
- step_id
|
||||
- step_type
|
||||
type: object
|
||||
ShieldSpec:
|
||||
ShieldDef:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
provider_config:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
config:
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
type: object
|
||||
provider_type:
|
||||
type: string
|
||||
required:
|
||||
- provider_type
|
||||
- config
|
||||
identifier:
|
||||
type: string
|
||||
params:
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
type: object
|
||||
shield_type:
|
||||
provider_id:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
required:
|
||||
- shield_type
|
||||
- provider_config
|
||||
- identifier
|
||||
- type
|
||||
- params
|
||||
type: object
|
||||
SpanEndPayload:
|
||||
additionalProperties: false
|
||||
|
@ -2571,6 +2488,29 @@ components:
|
|||
- role
|
||||
- content
|
||||
type: object
|
||||
VectorMemoryBankDef:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
chunk_size_in_tokens:
|
||||
type: integer
|
||||
embedding_model:
|
||||
type: string
|
||||
identifier:
|
||||
type: string
|
||||
overlap_size_in_tokens:
|
||||
type: integer
|
||||
provider_id:
|
||||
type: string
|
||||
type:
|
||||
const: vector
|
||||
default: vector
|
||||
type: string
|
||||
required:
|
||||
- identifier
|
||||
- type
|
||||
- embedding_model
|
||||
- chunk_size_in_tokens
|
||||
type: object
|
||||
ViolationLevel:
|
||||
enum:
|
||||
- info
|
||||
|
@ -2604,7 +2544,7 @@ info:
|
|||
description: "This is the specification of the llama stack that provides\n \
|
||||
\ a set of endpoints and their corresponding interfaces that are tailored\
|
||||
\ to\n best leverage Llama Models. The specification is still in\
|
||||
\ draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
|
||||
\ draft and subject to change.\n Generated at 2024-10-08 15:18:57.600111"
|
||||
title: '[DRAFT] Llama Stack Specification'
|
||||
version: 0.0.1
|
||||
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
||||
|
@ -3226,7 +3166,7 @@ paths:
|
|||
description: OK
|
||||
tags:
|
||||
- Inference
|
||||
/memory/create:
|
||||
/inference/register_model:
|
||||
post:
|
||||
parameters:
|
||||
- description: JSON-encoded provider data which will be made available to the
|
||||
|
@ -3240,17 +3180,13 @@ paths:
|
|||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CreateMemoryBankRequest'
|
||||
$ref: '#/components/schemas/RegisterModelRequest'
|
||||
required: true
|
||||
responses:
|
||||
'200':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/MemoryBank'
|
||||
description: OK
|
||||
tags:
|
||||
- Memory
|
||||
- Models
|
||||
/memory/documents/delete:
|
||||
post:
|
||||
parameters:
|
||||
|
@ -3302,57 +3238,6 @@ paths:
|
|||
description: OK
|
||||
tags:
|
||||
- Memory
|
||||
/memory/drop:
|
||||
post:
|
||||
parameters:
|
||||
- description: JSON-encoded provider data which will be made available to the
|
||||
adapter servicing the API
|
||||
in: header
|
||||
name: X-LlamaStack-ProviderData
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DropMemoryBankRequest'
|
||||
required: true
|
||||
responses:
|
||||
'200':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: string
|
||||
description: OK
|
||||
tags:
|
||||
- Memory
|
||||
/memory/get:
|
||||
get:
|
||||
parameters:
|
||||
- in: query
|
||||
name: bank_id
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- description: JSON-encoded provider data which will be made available to the
|
||||
adapter servicing the API
|
||||
in: header
|
||||
name: X-LlamaStack-ProviderData
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/MemoryBank'
|
||||
- type: 'null'
|
||||
description: OK
|
||||
tags:
|
||||
- Memory
|
||||
/memory/insert:
|
||||
post:
|
||||
parameters:
|
||||
|
@ -3374,25 +3259,6 @@ paths:
|
|||
description: OK
|
||||
tags:
|
||||
- Memory
|
||||
/memory/list:
|
||||
get:
|
||||
parameters:
|
||||
- description: JSON-encoded provider data which will be made available to the
|
||||
adapter servicing the API
|
||||
in: header
|
||||
name: X-LlamaStack-ProviderData
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
content:
|
||||
application/jsonl:
|
||||
schema:
|
||||
$ref: '#/components/schemas/MemoryBank'
|
||||
description: OK
|
||||
tags:
|
||||
- Memory
|
||||
/memory/query:
|
||||
post:
|
||||
parameters:
|
||||
|
@ -3443,10 +3309,10 @@ paths:
|
|||
get:
|
||||
parameters:
|
||||
- in: query
|
||||
name: bank_type
|
||||
name: identifier
|
||||
required: true
|
||||
schema:
|
||||
$ref: '#/components/schemas/MemoryBankType'
|
||||
type: string
|
||||
- description: JSON-encoded provider data which will be made available to the
|
||||
adapter servicing the API
|
||||
in: header
|
||||
|
@ -3460,7 +3326,11 @@ paths:
|
|||
application/json:
|
||||
schema:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/MemoryBankSpec'
|
||||
- oneOf:
|
||||
- $ref: '#/components/schemas/VectorMemoryBankDef'
|
||||
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
|
||||
- $ref: '#/components/schemas/KeywordMemoryBankDef'
|
||||
- $ref: '#/components/schemas/GraphMemoryBankDef'
|
||||
- type: 'null'
|
||||
description: OK
|
||||
tags:
|
||||
|
@ -3480,15 +3350,40 @@ paths:
|
|||
content:
|
||||
application/jsonl:
|
||||
schema:
|
||||
$ref: '#/components/schemas/MemoryBankSpec'
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/VectorMemoryBankDef'
|
||||
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
|
||||
- $ref: '#/components/schemas/KeywordMemoryBankDef'
|
||||
- $ref: '#/components/schemas/GraphMemoryBankDef'
|
||||
description: OK
|
||||
tags:
|
||||
- MemoryBanks
|
||||
/memory_banks/register:
|
||||
post:
|
||||
parameters:
|
||||
- description: JSON-encoded provider data which will be made available to the
|
||||
adapter servicing the API
|
||||
in: header
|
||||
name: X-LlamaStack-ProviderData
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterMemoryBankRequest'
|
||||
required: true
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
tags:
|
||||
- Memory
|
||||
/models/get:
|
||||
get:
|
||||
parameters:
|
||||
- in: query
|
||||
name: core_model_id
|
||||
name: identifier
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
@ -3505,7 +3400,7 @@ paths:
|
|||
application/json:
|
||||
schema:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/ModelServingSpec'
|
||||
- $ref: '#/components/schemas/ModelDef'
|
||||
- type: 'null'
|
||||
description: OK
|
||||
tags:
|
||||
|
@ -3525,7 +3420,7 @@ paths:
|
|||
content:
|
||||
application/jsonl:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ModelServingSpec'
|
||||
$ref: '#/components/schemas/ModelDef'
|
||||
description: OK
|
||||
tags:
|
||||
- Models
|
||||
|
@ -3760,6 +3655,27 @@ paths:
|
|||
description: OK
|
||||
tags:
|
||||
- Inspect
|
||||
/safety/register_shield:
|
||||
post:
|
||||
parameters:
|
||||
- description: JSON-encoded provider data which will be made available to the
|
||||
adapter servicing the API
|
||||
in: header
|
||||
name: X-LlamaStack-ProviderData
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterShieldRequest'
|
||||
required: true
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
tags:
|
||||
- Shields
|
||||
/safety/run_shield:
|
||||
post:
|
||||
parameters:
|
||||
|
@ -3806,7 +3722,29 @@ paths:
|
|||
application/json:
|
||||
schema:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/ShieldSpec'
|
||||
- additionalProperties: false
|
||||
properties:
|
||||
identifier:
|
||||
type: string
|
||||
params:
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
type: object
|
||||
provider_id:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
required:
|
||||
- identifier
|
||||
- type
|
||||
- params
|
||||
type: object
|
||||
- type: 'null'
|
||||
description: OK
|
||||
tags:
|
||||
|
@ -3826,7 +3764,7 @@ paths:
|
|||
content:
|
||||
application/jsonl:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ShieldSpec'
|
||||
$ref: '#/components/schemas/ShieldDef'
|
||||
description: OK
|
||||
tags:
|
||||
- Shields
|
||||
|
@ -3905,21 +3843,21 @@ security:
|
|||
servers:
|
||||
- url: http://any-hosted-llama-stack.com
|
||||
tags:
|
||||
- name: Datasets
|
||||
- name: Inspect
|
||||
- name: Memory
|
||||
- name: BatchInference
|
||||
- name: Agents
|
||||
- name: Datasets
|
||||
- name: Inference
|
||||
- name: Shields
|
||||
- name: SyntheticDataGeneration
|
||||
- name: Models
|
||||
- name: RewardScoring
|
||||
- name: MemoryBanks
|
||||
- name: Safety
|
||||
- name: Evaluations
|
||||
- name: Telemetry
|
||||
- name: Memory
|
||||
- name: Safety
|
||||
- name: PostTraining
|
||||
- name: MemoryBanks
|
||||
- name: Models
|
||||
- name: Shields
|
||||
- name: Inspect
|
||||
- name: SyntheticDataGeneration
|
||||
- name: Telemetry
|
||||
- name: Agents
|
||||
- name: RewardScoring
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
|
||||
name: BuiltinTool
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
|
||||
|
@ -4123,11 +4061,6 @@ tags:
|
|||
- description: <SchemaDefinition schemaRef="#/components/schemas/CreateDatasetRequest"
|
||||
/>
|
||||
name: CreateDatasetRequest
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/CreateMemoryBankRequest"
|
||||
/>
|
||||
name: CreateMemoryBankRequest
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBank" />
|
||||
name: MemoryBank
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteAgentsRequest"
|
||||
/>
|
||||
name: DeleteAgentsRequest
|
||||
|
@ -4140,9 +4073,6 @@ tags:
|
|||
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteDocumentsRequest"
|
||||
/>
|
||||
name: DeleteDocumentsRequest
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/DropMemoryBankRequest"
|
||||
/>
|
||||
name: DropMemoryBankRequest
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/EmbeddingsRequest"
|
||||
/>
|
||||
name: EmbeddingsRequest
|
||||
|
@ -4163,11 +4093,23 @@ tags:
|
|||
- description: <SchemaDefinition schemaRef="#/components/schemas/GetAgentsSessionRequest"
|
||||
/>
|
||||
name: GetAgentsSessionRequest
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/GraphMemoryBankDef"
|
||||
/>
|
||||
name: GraphMemoryBankDef
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/KeyValueMemoryBankDef"
|
||||
/>
|
||||
name: KeyValueMemoryBankDef
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/KeywordMemoryBankDef"
|
||||
/>
|
||||
name: KeywordMemoryBankDef
|
||||
- description: 'A single session of an interaction with an Agentic System.
|
||||
|
||||
|
||||
<SchemaDefinition schemaRef="#/components/schemas/Session" />'
|
||||
name: Session
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/VectorMemoryBankDef"
|
||||
/>
|
||||
name: VectorMemoryBankDef
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/AgentStepResponse"
|
||||
/>
|
||||
name: AgentStepResponse
|
||||
|
@ -4189,21 +4131,8 @@ tags:
|
|||
- description: <SchemaDefinition schemaRef="#/components/schemas/EvaluationJobStatusResponse"
|
||||
/>
|
||||
name: EvaluationJobStatusResponse
|
||||
- description: 'The model family and SKU of the model along with other parameters
|
||||
corresponding to the model.
|
||||
|
||||
|
||||
<SchemaDefinition schemaRef="#/components/schemas/Model" />'
|
||||
name: Model
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/ModelServingSpec"
|
||||
/>
|
||||
name: ModelServingSpec
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankType" />
|
||||
name: MemoryBankType
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankSpec" />
|
||||
name: MemoryBankSpec
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldSpec" />
|
||||
name: ShieldSpec
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/ModelDef" />
|
||||
name: ModelDef
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
|
||||
name: Trace
|
||||
- description: 'Checkpoint created during training runs
|
||||
|
@ -4243,6 +4172,8 @@ tags:
|
|||
name: ProviderInfo
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/RouteInfo" />
|
||||
name: RouteInfo
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldDef" />
|
||||
name: ShieldDef
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" />
|
||||
name: LogSeverity
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
|
||||
|
@ -4282,6 +4213,15 @@ tags:
|
|||
- description: <SchemaDefinition schemaRef="#/components/schemas/QueryDocumentsResponse"
|
||||
/>
|
||||
name: QueryDocumentsResponse
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterMemoryBankRequest"
|
||||
/>
|
||||
name: RegisterMemoryBankRequest
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterModelRequest"
|
||||
/>
|
||||
name: RegisterModelRequest
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterShieldRequest"
|
||||
/>
|
||||
name: RegisterShieldRequest
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/DialogGenerations"
|
||||
/>
|
||||
name: DialogGenerations
|
||||
|
@ -4387,7 +4327,6 @@ x-tagGroups:
|
|||
- CreateAgentSessionRequest
|
||||
- CreateAgentTurnRequest
|
||||
- CreateDatasetRequest
|
||||
- CreateMemoryBankRequest
|
||||
- DPOAlignmentConfig
|
||||
- DeleteAgentsRequest
|
||||
- DeleteAgentsSessionRequest
|
||||
|
@ -4395,7 +4334,6 @@ x-tagGroups:
|
|||
- DeleteDocumentsRequest
|
||||
- DialogGenerations
|
||||
- DoraFinetuningConfig
|
||||
- DropMemoryBankRequest
|
||||
- EmbeddingsRequest
|
||||
- EmbeddingsResponse
|
||||
- EvaluateQuestionAnsweringRequest
|
||||
|
@ -4409,22 +4347,21 @@ x-tagGroups:
|
|||
- FunctionCallToolDefinition
|
||||
- GetAgentsSessionRequest
|
||||
- GetDocumentsRequest
|
||||
- GraphMemoryBankDef
|
||||
- HealthInfo
|
||||
- ImageMedia
|
||||
- InferenceStep
|
||||
- InsertDocumentsRequest
|
||||
- KeyValueMemoryBankDef
|
||||
- KeywordMemoryBankDef
|
||||
- LogEventRequest
|
||||
- LogSeverity
|
||||
- LoraFinetuningConfig
|
||||
- MemoryBank
|
||||
- MemoryBankDocument
|
||||
- MemoryBankSpec
|
||||
- MemoryBankType
|
||||
- MemoryRetrievalStep
|
||||
- MemoryToolDefinition
|
||||
- MetricEvent
|
||||
- Model
|
||||
- ModelServingSpec
|
||||
- ModelDef
|
||||
- OptimizerConfig
|
||||
- PhotogenToolDefinition
|
||||
- PostTrainingJob
|
||||
|
@ -4438,6 +4375,9 @@ x-tagGroups:
|
|||
- QueryDocumentsRequest
|
||||
- QueryDocumentsResponse
|
||||
- RLHFAlgorithm
|
||||
- RegisterMemoryBankRequest
|
||||
- RegisterModelRequest
|
||||
- RegisterShieldRequest
|
||||
- RestAPIExecutionConfig
|
||||
- RestAPIMethod
|
||||
- RewardScoreRequest
|
||||
|
@ -4453,7 +4393,7 @@ x-tagGroups:
|
|||
- SearchToolDefinition
|
||||
- Session
|
||||
- ShieldCallStep
|
||||
- ShieldSpec
|
||||
- ShieldDef
|
||||
- SpanEndPayload
|
||||
- SpanStartPayload
|
||||
- SpanStatus
|
||||
|
@ -4483,5 +4423,6 @@ x-tagGroups:
|
|||
- UnstructuredLogEvent
|
||||
- UpdateDocumentsRequest
|
||||
- UserMessage
|
||||
- VectorMemoryBankDef
|
||||
- ViolationLevel
|
||||
- WolframAlphaToolDefinition
|
||||
|
|
BIN
docs/resources/prompt-format.png
Normal file
BIN
docs/resources/prompt-format.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 170 KiB |
|
@ -261,7 +261,7 @@ class Session(BaseModel):
|
|||
turns: List[Turn]
|
||||
started_at: datetime
|
||||
|
||||
memory_bank: Optional[MemoryBank] = None
|
||||
memory_bank: Optional[MemoryBankDef] = None
|
||||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
|
@ -411,8 +411,10 @@ class Agents(Protocol):
|
|||
agent_config: AgentConfig,
|
||||
) -> AgentCreateResponse: ...
|
||||
|
||||
# This method is not `async def` because it can result in either an
|
||||
# `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
|
||||
@webmethod(route="/agents/turn/create")
|
||||
async def create_agent_turn(
|
||||
def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
@ -67,9 +67,17 @@ class AgentsClient(Agents):
|
|||
response.raise_for_status()
|
||||
return AgentSessionCreateResponse(**response.json())
|
||||
|
||||
async def create_agent_turn(
|
||||
def create_agent_turn(
|
||||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
if request.stream:
|
||||
return self._stream_agent_turn(request)
|
||||
else:
|
||||
return self._nonstream_agent_turn(request)
|
||||
|
||||
async def _stream_agent_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
|
@ -93,6 +101,9 @@ class AgentsClient(Agents):
|
|||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
|
||||
async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
|
||||
raise NotImplementedError("Non-streaming not implemented yet")
|
||||
|
||||
|
||||
async def _run_agent(
|
||||
api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
|
||||
|
@ -132,8 +143,7 @@ async def _run_agent(
|
|||
log.print()
|
||||
|
||||
|
||||
async def run_llama_3_1(host: str, port: int):
|
||||
model = "Llama3.1-8B-Instruct"
|
||||
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
tool_definitions = [
|
||||
|
@ -173,8 +183,7 @@ async def run_llama_3_1(host: str, port: int):
|
|||
await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
|
||||
|
||||
|
||||
async def run_llama_3_2_rag(host: str, port: int):
|
||||
model = "Llama3.2-3B-Instruct"
|
||||
async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
urls = [
|
||||
|
@ -215,8 +224,7 @@ async def run_llama_3_2_rag(host: str, port: int):
|
|||
)
|
||||
|
||||
|
||||
async def run_llama_3_2(host: str, port: int):
|
||||
model = "Llama3.2-3B-Instruct"
|
||||
async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
# zero shot tools for llama3.2 text models
|
||||
|
@ -262,7 +270,7 @@ async def run_llama_3_2(host: str, port: int):
|
|||
)
|
||||
|
||||
|
||||
def main(host: str, port: int, run_type: str):
|
||||
def main(host: str, port: int, run_type: str, model: Optional[str] = None):
|
||||
assert run_type in [
|
||||
"tools_llama_3_1",
|
||||
"tools_llama_3_2",
|
||||
|
@ -274,7 +282,10 @@ def main(host: str, port: int, run_type: str):
|
|||
"tools_llama_3_2": run_llama_3_2,
|
||||
"rag_llama_3_2": run_llama_3_2_rag,
|
||||
}
|
||||
asyncio.run(fn[run_type](host, port))
|
||||
args = [host, port]
|
||||
if model is not None:
|
||||
args.append(model)
|
||||
asyncio.run(fn[run_type](*args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -42,10 +42,10 @@ class InferenceClient(Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -66,6 +66,29 @@ class InferenceClient(Inference):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/inference/chat_completion",
|
||||
json=encodable_dict(request),
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
j = response.json()
|
||||
return ChatCompletionResponse(**j)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
|
@ -77,7 +100,8 @@ class InferenceClient(Inference):
|
|||
if response.status_code != 200:
|
||||
content = await response.aread()
|
||||
cprint(
|
||||
f"Error: HTTP {response.status_code} {content.decode()}", "red"
|
||||
f"Error: HTTP {response.status_code} {content.decode()}",
|
||||
"red",
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -85,40 +109,59 @@ class InferenceClient(Inference):
|
|||
if line.startswith("data:"):
|
||||
data = line[len("data: ") :]
|
||||
try:
|
||||
if request.stream:
|
||||
if "error" in data:
|
||||
cprint(data, "red")
|
||||
continue
|
||||
if "error" in data:
|
||||
cprint(data, "red")
|
||||
continue
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
**json.loads(data)
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponse(**json.loads(data))
|
||||
yield ChatCompletionResponseStreamChunk(**json.loads(data))
|
||||
except Exception as e:
|
||||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
async def run_main(
|
||||
host: str, port: int, stream: bool, model: Optional[str], logprobs: bool
|
||||
):
|
||||
client = InferenceClient(f"http://{host}:{port}")
|
||||
|
||||
if not model:
|
||||
model = "Llama3.1-8B-Instruct"
|
||||
|
||||
message = UserMessage(
|
||||
content="hello world, write me a 2 sentence poem about the moon"
|
||||
)
|
||||
cprint(f"User>{message.content}", "green")
|
||||
|
||||
if logprobs:
|
||||
logprobs_config = LogProbConfig(
|
||||
top_k=1,
|
||||
)
|
||||
else:
|
||||
logprobs_config = None
|
||||
|
||||
iterator = client.chat_completion(
|
||||
model="Llama3.1-8B-Instruct",
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=stream,
|
||||
logprobs=logprobs_config,
|
||||
)
|
||||
async for log in EventLogger().log(iterator):
|
||||
log.print()
|
||||
|
||||
if logprobs:
|
||||
async for chunk in iterator:
|
||||
cprint(f"Response: {chunk}", "red")
|
||||
else:
|
||||
async for log in EventLogger().log(iterator):
|
||||
log.print()
|
||||
|
||||
|
||||
async def run_mm_main(host: str, port: int, stream: bool, path: str):
|
||||
async def run_mm_main(
|
||||
host: str, port: int, stream: bool, path: Optional[str], model: Optional[str]
|
||||
):
|
||||
client = InferenceClient(f"http://{host}:{port}")
|
||||
|
||||
if not model:
|
||||
model = "Llama3.2-11B-Vision-Instruct"
|
||||
|
||||
message = UserMessage(
|
||||
content=[
|
||||
ImageMedia(image=URL(uri=f"file://{path}")),
|
||||
|
@ -127,7 +170,7 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
|
|||
)
|
||||
cprint(f"User>{message.content}", "green")
|
||||
iterator = client.chat_completion(
|
||||
model="Llama3.2-11B-Vision-Instruct",
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=stream,
|
||||
)
|
||||
|
@ -135,11 +178,19 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
|
|||
log.print()
|
||||
|
||||
|
||||
def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
|
||||
def main(
|
||||
host: str,
|
||||
port: int,
|
||||
stream: bool = True,
|
||||
mm: bool = False,
|
||||
logprobs: bool = False,
|
||||
file: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
):
|
||||
if mm:
|
||||
asyncio.run(run_mm_main(host, port, stream, file))
|
||||
asyncio.run(run_mm_main(host, port, stream, file, model))
|
||||
else:
|
||||
asyncio.run(run_main(host, port, stream))
|
||||
asyncio.run(run_main(host, port, stream, model, logprobs))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -14,6 +14,7 @@ from pydantic import BaseModel, Field
|
|||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
|
||||
class LogProbConfig(BaseModel):
|
||||
|
@ -172,9 +173,17 @@ class EmbeddingsResponse(BaseModel):
|
|||
embeddings: List[List[float]]
|
||||
|
||||
|
||||
class ModelStore(Protocol):
|
||||
def get_model(self, identifier: str) -> ModelDef: ...
|
||||
|
||||
|
||||
class Inference(Protocol):
|
||||
model_store: ModelStore
|
||||
|
||||
# This method is not `async def` because it can result in either an
|
||||
# `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
|
||||
@webmethod(route="/inference/completion")
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
@ -183,8 +192,10 @@ class Inference(Protocol):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
|
||||
|
||||
# This method is not `async def` because it can result in either an
|
||||
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
|
||||
@webmethod(route="/inference/chat_completion")
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -203,3 +214,6 @@ class Inference(Protocol):
|
|||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse: ...
|
||||
|
||||
@webmethod(route="/inference/register_model")
|
||||
async def register_model(self, model: ModelDef) -> None: ...
|
||||
|
|
|
@ -12,15 +12,15 @@ from pydantic import BaseModel
|
|||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
description: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouteInfo(BaseModel):
|
||||
route: str
|
||||
method: str
|
||||
providers: List[str]
|
||||
provider_types: List[str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
import fire
|
||||
import httpx
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks.client import MemoryBanksClient
|
||||
from llama_stack.providers.utils.memory.file_utils import data_url_from_file
|
||||
|
||||
|
||||
|
@ -35,44 +35,16 @@ class MemoryClient(Memory):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(
|
||||
f"{self.base_url}/memory/get",
|
||||
params={
|
||||
"bank_id": bank_id,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if not d:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.post(
|
||||
f"{self.base_url}/memory/create",
|
||||
response = await client.post(
|
||||
f"{self.base_url}/memory/register_memory_bank",
|
||||
json={
|
||||
"name": name,
|
||||
"config": config.dict(),
|
||||
"url": url,
|
||||
"memory_bank": json.loads(memory_bank.json()),
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if not d:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
response.raise_for_status()
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
@ -114,22 +86,20 @@ class MemoryClient(Memory):
|
|||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
client = MemoryClient(f"http://{host}:{port}")
|
||||
banks_client = MemoryBanksClient(f"http://{host}:{port}")
|
||||
|
||||
# create a memory bank
|
||||
bank = await client.create_memory_bank(
|
||||
name="test_bank",
|
||||
config=VectorMemoryBankConfig(
|
||||
bank_id="test_bank",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
provider_id="",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
cprint(json.dumps(bank.dict(), indent=4), "green")
|
||||
await client.register_memory_bank(bank)
|
||||
|
||||
retrieved_bank = await client.get_memory_bank(bank.bank_id)
|
||||
retrieved_bank = await banks_client.get_memory_bank(bank.identifier)
|
||||
assert retrieved_bank is not None
|
||||
assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
|
||||
assert retrieved_bank.embedding_model == "all-MiniLM-L6-v2"
|
||||
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
|
@ -162,13 +132,13 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
|
||||
# insert some documents
|
||||
await client.insert_documents(
|
||||
bank_id=bank.bank_id,
|
||||
bank_id=bank.identifier,
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
# query the documents
|
||||
response = await client.query_documents(
|
||||
bank_id=bank.bank_id,
|
||||
bank_id=bank.identifier,
|
||||
query=[
|
||||
"How do I use Lora?",
|
||||
],
|
||||
|
@ -178,7 +148,7 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
print(f"Chunk:\n========\n{chunk}\n========\n")
|
||||
|
||||
response = await client.query_documents(
|
||||
bank_id=bank.bank_id,
|
||||
bank_id=bank.identifier,
|
||||
query=[
|
||||
"Tell me more about llama3 and torchtune",
|
||||
],
|
||||
|
|
|
@ -13,9 +13,9 @@ from typing import List, Optional, Protocol
|
|||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -26,44 +26,6 @@ class MemoryBankDocument(BaseModel):
|
|||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryBankType(Enum):
|
||||
vector = "vector"
|
||||
keyvalue = "keyvalue"
|
||||
keyword = "keyword"
|
||||
graph = "graph"
|
||||
|
||||
|
||||
class VectorMemoryBankConfig(BaseModel):
|
||||
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
embedding_model: str
|
||||
chunk_size_in_tokens: int
|
||||
overlap_size_in_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class KeyValueMemoryBankConfig(BaseModel):
|
||||
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
|
||||
|
||||
|
||||
class KeywordMemoryBankConfig(BaseModel):
|
||||
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
|
||||
|
||||
|
||||
class GraphMemoryBankConfig(BaseModel):
|
||||
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||
|
||||
|
||||
MemoryBankConfig = Annotated[
|
||||
Union[
|
||||
VectorMemoryBankConfig,
|
||||
KeyValueMemoryBankConfig,
|
||||
KeywordMemoryBankConfig,
|
||||
GraphMemoryBankConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class Chunk(BaseModel):
|
||||
content: InterleavedTextMedia
|
||||
token_count: int
|
||||
|
@ -76,45 +38,12 @@ class QueryDocumentsResponse(BaseModel):
|
|||
scores: List[float]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryAPI(Protocol):
|
||||
@webmethod(route="/query_documents")
|
||||
def query_documents(
|
||||
self,
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse: ...
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryBank(BaseModel):
|
||||
bank_id: str
|
||||
name: str
|
||||
config: MemoryBankConfig
|
||||
# if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
|
||||
url: Optional[URL] = None
|
||||
class MemoryBankStore(Protocol):
|
||||
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
|
||||
|
||||
|
||||
class Memory(Protocol):
|
||||
@webmethod(route="/memory/create")
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank: ...
|
||||
|
||||
@webmethod(route="/memory/list", method="GET")
|
||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||
|
||||
@webmethod(route="/memory/get", method="GET")
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
|
||||
|
||||
@webmethod(route="/memory/drop", method="DELETE")
|
||||
async def drop_memory_bank(
|
||||
self,
|
||||
bank_id: str,
|
||||
) -> str: ...
|
||||
memory_bank_store: MemoryBankStore
|
||||
|
||||
# this will just block now until documents are inserted, but it should
|
||||
# probably return a Job instance which can be polled for completion
|
||||
|
@ -154,3 +83,6 @@ class Memory(Protocol):
|
|||
bank_id: str,
|
||||
document_ids: List[str],
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/memory/register_memory_bank")
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
@ -15,6 +15,25 @@ from termcolor import cprint
|
|||
from .memory_banks import * # noqa: F403
|
||||
|
||||
|
||||
def deserialize_memory_bank_def(j: Optional[Dict[str, Any]]) -> MemoryBankDef:
|
||||
if j is None:
|
||||
return None
|
||||
|
||||
if "type" not in j:
|
||||
raise ValueError("Memory bank type not specified")
|
||||
type = j["type"]
|
||||
if type == MemoryBankType.vector.value:
|
||||
return VectorMemoryBankDef(**j)
|
||||
elif type == MemoryBankType.keyvalue.value:
|
||||
return KeyValueMemoryBankDef(**j)
|
||||
elif type == MemoryBankType.keyword.value:
|
||||
return KeywordMemoryBankDef(**j)
|
||||
elif type == MemoryBankType.graph.value:
|
||||
return GraphMemoryBankDef(**j)
|
||||
else:
|
||||
raise ValueError(f"Unknown memory bank type: {type}")
|
||||
|
||||
|
||||
class MemoryBanksClient(MemoryBanks):
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
|
@ -25,37 +44,36 @@ class MemoryBanksClient(MemoryBanks):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/memory_banks/list",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [MemoryBankSpec(**x) for x in response.json()]
|
||||
return [deserialize_memory_bank_def(x) for x in response.json()]
|
||||
|
||||
async def get_serving_memory_bank(
|
||||
self, bank_type: MemoryBankType
|
||||
) -> Optional[MemoryBankSpec]:
|
||||
async def get_memory_bank(
|
||||
self,
|
||||
identifier: str,
|
||||
) -> Optional[MemoryBankDef]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/memory_banks/get",
|
||||
params={
|
||||
"bank_type": bank_type.value,
|
||||
"identifier": identifier,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
j = response.json()
|
||||
if j is None:
|
||||
return None
|
||||
return MemoryBankSpec(**j)
|
||||
return deserialize_memory_bank_def(j)
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
client = MemoryBanksClient(f"http://{host}:{port}")
|
||||
|
||||
response = await client.list_available_memory_banks()
|
||||
response = await client.list_memory_banks()
|
||||
cprint(f"list_memory_banks response={response}", "green")
|
||||
|
||||
|
||||
|
|
|
@ -4,29 +4,67 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.memory import MemoryBankType
|
||||
|
||||
from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryBankSpec(BaseModel):
|
||||
bank_type: MemoryBankType
|
||||
provider_config: GenericProviderConfig = Field(
|
||||
description="Provider config for the model, including provider_type, and corresponding config. ",
|
||||
)
|
||||
class MemoryBankType(Enum):
|
||||
vector = "vector"
|
||||
keyvalue = "keyvalue"
|
||||
keyword = "keyword"
|
||||
graph = "graph"
|
||||
|
||||
|
||||
class CommonDef(BaseModel):
|
||||
identifier: str
|
||||
provider_id: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorMemoryBankDef(CommonDef):
|
||||
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
embedding_model: str
|
||||
chunk_size_in_tokens: int
|
||||
overlap_size_in_tokens: Optional[int] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class KeyValueMemoryBankDef(CommonDef):
|
||||
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class KeywordMemoryBankDef(CommonDef):
|
||||
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GraphMemoryBankDef(CommonDef):
|
||||
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||
|
||||
|
||||
MemoryBankDef = Annotated[
|
||||
Union[
|
||||
VectorMemoryBankDef,
|
||||
KeyValueMemoryBankDef,
|
||||
KeywordMemoryBankDef,
|
||||
GraphMemoryBankDef,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class MemoryBanks(Protocol):
|
||||
@webmethod(route="/memory_banks/list", method="GET")
|
||||
async def list_available_memory_banks(self) -> List[MemoryBankSpec]: ...
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
|
||||
|
||||
@webmethod(route="/memory_banks/get", method="GET")
|
||||
async def get_serving_memory_bank(
|
||||
self, bank_type: MemoryBankType
|
||||
) -> Optional[MemoryBankSpec]: ...
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
|
||||
|
||||
@webmethod(route="/memory_banks/register", method="POST")
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
||||
|
|
|
@ -56,7 +56,7 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
response = await client.list_models()
|
||||
cprint(f"list_models response={response}", "green")
|
||||
|
||||
response = await client.get_model("Meta-Llama3.1-8B-Instruct")
|
||||
response = await client.get_model("Llama3.1-8B-Instruct")
|
||||
cprint(f"get_model response={response}", "blue")
|
||||
|
||||
response = await client.get_model("Llama-Guard-3-1B")
|
||||
|
|
|
@ -6,27 +6,32 @@
|
|||
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
from llama_models.llama3.api.datatypes import Model
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelServingSpec(BaseModel):
|
||||
llama_model: Model = Field(
|
||||
description="All metadatas associated with llama model (defined in llama_models.models.sku_list).",
|
||||
class ModelDef(BaseModel):
|
||||
identifier: str = Field(
|
||||
description="A unique identifier for the model type",
|
||||
)
|
||||
provider_config: GenericProviderConfig = Field(
|
||||
description="Provider config for the model, including provider_type, and corresponding config. ",
|
||||
llama_model: str = Field(
|
||||
description="Pointer to the core Llama family model",
|
||||
)
|
||||
provider_id: Optional[str] = Field(
|
||||
default=None, description="The provider instance which serves this model"
|
||||
)
|
||||
# For now, we are only supporting core llama models but as soon as finetuned
|
||||
# and other custom models (for example various quantizations) are allowed, there
|
||||
# will be more metadata fields here
|
||||
|
||||
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models/list", method="GET")
|
||||
async def list_models(self) -> List[ModelServingSpec]: ...
|
||||
async def list_models(self) -> List[ModelDef]: ...
|
||||
|
||||
@webmethod(route="/models/get", method="GET")
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
|
||||
|
||||
@webmethod(route="/models/register", method="POST")
|
||||
async def register_model(self, model: ModelDef) -> None: ...
|
||||
|
|
|
@ -96,12 +96,6 @@ async def run_main(host: str, port: int, image_path: str = None):
|
|||
)
|
||||
print(response)
|
||||
|
||||
response = await client.run_shield(
|
||||
shield_type="injection_shield",
|
||||
messages=[message],
|
||||
)
|
||||
print(response)
|
||||
|
||||
|
||||
def main(host: str, port: int, image: str = None):
|
||||
asyncio.run(run_main(host, port, image))
|
||||
|
|
|
@ -11,6 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -37,8 +38,17 @@ class RunShieldResponse(BaseModel):
|
|||
violation: Optional[SafetyViolation] = None
|
||||
|
||||
|
||||
class ShieldStore(Protocol):
|
||||
def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||
|
||||
|
||||
class Safety(Protocol):
|
||||
shield_store: ShieldStore
|
||||
|
||||
@webmethod(route="/safety/run_shield")
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse: ...
|
||||
|
||||
@webmethod(route="/safety/register_shield")
|
||||
async def register_shield(self, shield: ShieldDef) -> None: ...
|
||||
|
|
|
@ -4,25 +4,43 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldSpec(BaseModel):
|
||||
shield_type: str
|
||||
provider_config: GenericProviderConfig = Field(
|
||||
description="Provider config for the model, including provider_type, and corresponding config. ",
|
||||
class ShieldType(Enum):
|
||||
generic_content_shield = "generic_content_shield"
|
||||
llama_guard = "llama_guard"
|
||||
code_scanner = "code_scanner"
|
||||
prompt_guard = "prompt_guard"
|
||||
|
||||
|
||||
class ShieldDef(BaseModel):
|
||||
identifier: str = Field(
|
||||
description="A unique identifier for the shield type",
|
||||
)
|
||||
type: str = Field(
|
||||
description="The type of shield this is; the value is one of the ShieldType enum"
|
||||
)
|
||||
provider_id: Optional[str] = Field(
|
||||
default=None, description="The provider instance which serves this shield"
|
||||
)
|
||||
params: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional parameters needed for this shield",
|
||||
)
|
||||
|
||||
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields/list", method="GET")
|
||||
async def list_shields(self) -> List[ShieldSpec]: ...
|
||||
async def list_shields(self) -> List[ShieldDef]: ...
|
||||
|
||||
@webmethod(route="/shields/get", method="GET")
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ...
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: ...
|
||||
|
||||
@webmethod(route="/shields/register", method="POST")
|
||||
async def register_shield(self, shield: ShieldDef) -> None: ...
|
||||
|
|
|
@ -158,19 +158,18 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
|||
info = prompt_guard_download_info()
|
||||
else:
|
||||
model = resolve_model(args.model_id)
|
||||
if model is None:
|
||||
parser.error(f"Model {args.model_id} not found")
|
||||
return
|
||||
info = llama_meta_net_info(model)
|
||||
|
||||
if model is None:
|
||||
parser.error(f"Model {args.model_id} not found")
|
||||
return
|
||||
|
||||
if args.source == "huggingface":
|
||||
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
|
||||
else:
|
||||
meta_url = args.meta_url
|
||||
if not meta_url:
|
||||
meta_url = input(
|
||||
"Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
|
||||
"Please provide the signed URL you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
|
||||
)
|
||||
assert meta_url is not None and "llamameta.net" in meta_url
|
||||
_meta_download(model, meta_url, info)
|
||||
|
|
|
@ -22,7 +22,7 @@ def available_templates_specs() -> List[BuildConfig]:
|
|||
import yaml
|
||||
|
||||
template_specs = []
|
||||
for p in TEMPLATES_PATH.rglob("*.yaml"):
|
||||
for p in TEMPLATES_PATH.rglob("*build.yaml"):
|
||||
with open(p, "r") as f:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
template_specs.append(build_config)
|
||||
|
@ -105,8 +105,7 @@ class StackBuild(Subcommand):
|
|||
|
||||
import yaml
|
||||
|
||||
from llama_stack.distribution.build import ApiInput, build_image, ImageType
|
||||
|
||||
from llama_stack.distribution.build import build_image, ImageType
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.serialize import EnumEncoder
|
||||
from termcolor import cprint
|
||||
|
@ -137,16 +136,19 @@ class StackBuild(Subcommand):
|
|||
if build_config.image_type == "conda"
|
||||
else (f"llamastack-{build_config.name}")
|
||||
)
|
||||
cprint(
|
||||
f"You can now run `llama stack configure {configure_name}`",
|
||||
color="green",
|
||||
)
|
||||
if build_config.image_type == "conda":
|
||||
cprint(
|
||||
f"You can now run `llama stack configure {configure_name}`",
|
||||
color="green",
|
||||
)
|
||||
else:
|
||||
cprint(
|
||||
f"You can now run `llama stack run {build_config.name}`",
|
||||
color="green",
|
||||
)
|
||||
|
||||
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
import json
|
||||
|
||||
import yaml
|
||||
|
||||
from llama_stack.cli.table import print_table
|
||||
|
||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||
|
@ -172,9 +174,11 @@ class StackBuild(Subcommand):
|
|||
)
|
||||
|
||||
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
|
||||
import textwrap
|
||||
import yaml
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit.completion import WordCompleter
|
||||
from prompt_toolkit.validation import Validator
|
||||
from termcolor import cprint
|
||||
|
||||
|
@ -238,26 +242,29 @@ class StackBuild(Subcommand):
|
|||
)
|
||||
|
||||
cprint(
|
||||
"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
|
||||
textwrap.dedent(
|
||||
"""
|
||||
Llama Stack is composed of several APIs working together. Let's select
|
||||
the provider types (implementations) you want to use for these APIs.
|
||||
""",
|
||||
),
|
||||
color="green",
|
||||
)
|
||||
|
||||
print("Tip: use <TAB> to see options for the providers.\n")
|
||||
|
||||
providers = dict()
|
||||
for api, providers_for_api in get_provider_registry().items():
|
||||
available_providers = [
|
||||
x for x in providers_for_api.keys() if x != "remote"
|
||||
]
|
||||
api_provider = prompt(
|
||||
"> Enter provider for the {} API: (default=meta-reference): ".format(
|
||||
api.value
|
||||
),
|
||||
"> Enter provider for API {}: ".format(api.value),
|
||||
completer=WordCompleter(available_providers),
|
||||
complete_while_typing=True,
|
||||
validator=Validator.from_callable(
|
||||
lambda x: x in providers_for_api,
|
||||
error_message="Invalid provider, please enter one of the following: {}".format(
|
||||
list(providers_for_api.keys())
|
||||
),
|
||||
),
|
||||
default=(
|
||||
"meta-reference"
|
||||
if "meta-reference" in providers_for_api
|
||||
else list(providers_for_api.keys())[0]
|
||||
lambda x: x in available_providers,
|
||||
error_message="Invalid provider, use <TAB> to see options",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
@ -71,10 +71,8 @@ class StackConfigure(Subcommand):
|
|||
conda_dir = (
|
||||
Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}"
|
||||
)
|
||||
output = subprocess.check_output(
|
||||
["bash", "-c", "conda info --json -a | jq '.envs'"]
|
||||
)
|
||||
conda_envs = json.loads(output.decode("utf-8"))
|
||||
output = subprocess.check_output(["bash", "-c", "conda info --json"])
|
||||
conda_envs = json.loads(output.decode("utf-8"))["envs"]
|
||||
|
||||
for x in conda_envs:
|
||||
if x.endswith(f"/llamastack-{args.config}"):
|
||||
|
@ -129,7 +127,10 @@ class StackConfigure(Subcommand):
|
|||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.configure import configure_api_providers
|
||||
from llama_stack.distribution.configure import (
|
||||
configure_api_providers,
|
||||
parse_and_maybe_upgrade_config,
|
||||
)
|
||||
from llama_stack.distribution.utils.serialize import EnumEncoder
|
||||
|
||||
builds_dir = BUILDS_BASE_DIR / build_config.image_type
|
||||
|
@ -145,13 +146,17 @@ class StackConfigure(Subcommand):
|
|||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
config = StackRunConfig(**yaml.safe_load(run_config_file.read_text()))
|
||||
config_dict = yaml.safe_load(run_config_file.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
else:
|
||||
config = StackRunConfig(
|
||||
built_at=datetime.now(),
|
||||
image_name=image_name,
|
||||
apis_to_serve=[],
|
||||
api_providers={},
|
||||
apis=list(build_config.distribution_spec.providers.keys()),
|
||||
providers={},
|
||||
models=[],
|
||||
shields=[],
|
||||
memory_banks=[],
|
||||
)
|
||||
|
||||
config = configure_api_providers(config, build_config.distribution_spec)
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
class StackRun(Subcommand):
|
||||
|
@ -46,10 +45,11 @@ class StackRun(Subcommand):
|
|||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import ImageType
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
|
||||
if not args.config:
|
||||
|
@ -75,8 +75,10 @@ class StackRun(Subcommand):
|
|||
)
|
||||
return
|
||||
|
||||
cprint(f"Using config `{config_file}`", "green")
|
||||
with open(config_file, "r") as f:
|
||||
config = StackRunConfig(**yaml.safe_load(f))
|
||||
config_dict = yaml.safe_load(config_file.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
|
||||
if config.docker_image:
|
||||
script = pkg_resources.resource_filename(
|
||||
|
|
|
@ -1,105 +0,0 @@
|
|||
from argparse import Namespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from llama_stack.distribution.datatypes import BuildConfig
|
||||
from llama_stack.cli.stack.build import StackBuild
|
||||
|
||||
|
||||
# temporary while we make the tests work
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stack_build():
|
||||
parser = MagicMock()
|
||||
subparsers = MagicMock()
|
||||
return StackBuild(subparsers)
|
||||
|
||||
|
||||
def test_stack_build_initialization(stack_build):
|
||||
assert stack_build.parser is not None
|
||||
assert stack_build.parser.set_defaults.called_once_with(
|
||||
func=stack_build._run_stack_build_command
|
||||
)
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.build.build_image")
|
||||
def test_run_stack_build_command_with_config(
|
||||
mock_build_image, mock_build_config, stack_build
|
||||
):
|
||||
args = Namespace(
|
||||
config="test_config.yaml",
|
||||
template=None,
|
||||
list_templates=False,
|
||||
name=None,
|
||||
image_type="conda",
|
||||
)
|
||||
|
||||
with patch("builtins.open", MagicMock()):
|
||||
with patch("yaml.safe_load") as mock_yaml_load:
|
||||
mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
|
||||
mock_build_config.return_value = MagicMock()
|
||||
|
||||
stack_build._run_stack_build_command(args)
|
||||
|
||||
mock_build_config.assert_called_once()
|
||||
mock_build_image.assert_called_once()
|
||||
|
||||
|
||||
@patch("llama_stack.cli.table.print_table")
|
||||
def test_run_stack_build_command_list_templates(mock_print_table, stack_build):
|
||||
args = Namespace(list_templates=True)
|
||||
|
||||
stack_build._run_stack_build_command(args)
|
||||
|
||||
mock_print_table.assert_called_once()
|
||||
|
||||
|
||||
@patch("prompt_toolkit.prompt")
|
||||
@patch("llama_stack.distribution.datatypes.BuildConfig")
|
||||
@patch("llama_stack.distribution.build.build_image")
|
||||
def test_run_stack_build_command_interactive(
|
||||
mock_build_image, mock_build_config, mock_prompt, stack_build
|
||||
):
|
||||
args = Namespace(
|
||||
config=None, template=None, list_templates=False, name=None, image_type=None
|
||||
)
|
||||
|
||||
mock_prompt.side_effect = [
|
||||
"test_name",
|
||||
"conda",
|
||||
"meta-reference",
|
||||
"test description",
|
||||
]
|
||||
mock_build_config.return_value = MagicMock()
|
||||
|
||||
stack_build._run_stack_build_command(args)
|
||||
|
||||
assert mock_prompt.call_count == 4
|
||||
mock_build_config.assert_called_once()
|
||||
mock_build_image.assert_called_once()
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.datatypes.BuildConfig")
|
||||
@patch("llama_stack.distribution.build.build_image")
|
||||
def test_run_stack_build_command_with_template(
|
||||
mock_build_image, mock_build_config, stack_build
|
||||
):
|
||||
args = Namespace(
|
||||
config=None,
|
||||
template="test_template",
|
||||
list_templates=False,
|
||||
name="test_name",
|
||||
image_type="docker",
|
||||
)
|
||||
|
||||
with patch("builtins.open", MagicMock()):
|
||||
with patch("yaml.safe_load") as mock_yaml_load:
|
||||
mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
|
||||
mock_build_config.return_value = MagicMock()
|
||||
|
||||
stack_build._run_stack_build_command(args)
|
||||
|
||||
mock_build_config.assert_called_once()
|
||||
mock_build_image.assert_called_once()
|
153
llama_stack/cli/tests/test_stack_config.py
Normal file
153
llama_stack/cli/tests/test_stack_config.py
Normal file
|
@ -0,0 +1,153 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from llama_stack.distribution.configure import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
parse_and_maybe_upgrade_config,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def up_to_date_config():
|
||||
return yaml.safe_load(
|
||||
"""
|
||||
version: {version}
|
||||
image_name: foo
|
||||
apis_to_serve: []
|
||||
built_at: {built_at}
|
||||
models:
|
||||
- identifier: model1
|
||||
provider_id: provider1
|
||||
llama_model: Llama3.1-8B-Instruct
|
||||
shields:
|
||||
- identifier: shield1
|
||||
type: llama_guard
|
||||
provider_id: provider1
|
||||
memory_banks:
|
||||
- identifier: memory1
|
||||
type: vector
|
||||
provider_id: provider1
|
||||
embedding_model: all-MiniLM-L6-v2
|
||||
chunk_size_in_tokens: 512
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: provider1
|
||||
provider_type: meta-reference
|
||||
config: {{}}
|
||||
safety:
|
||||
- provider_id: provider1
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
enable_prompt_guard: false
|
||||
memory:
|
||||
- provider_id: provider1
|
||||
provider_type: meta-reference
|
||||
config: {{}}
|
||||
""".format(
|
||||
version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def old_config():
|
||||
return yaml.safe_load(
|
||||
"""
|
||||
image_name: foo
|
||||
built_at: {built_at}
|
||||
apis_to_serve: []
|
||||
routing_table:
|
||||
inference:
|
||||
- provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 11434
|
||||
routing_key: Llama3.2-1B-Instruct
|
||||
- provider_type: meta-reference
|
||||
config:
|
||||
model: Llama3.1-8B-Instruct
|
||||
routing_key: Llama3.1-8B-Instruct
|
||||
safety:
|
||||
- routing_key: ["shield1", "shield2"]
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
enable_prompt_guard: false
|
||||
memory:
|
||||
- routing_key: vector
|
||||
provider_type: meta-reference
|
||||
config: {{}}
|
||||
api_providers:
|
||||
telemetry:
|
||||
provider_type: noop
|
||||
config: {{}}
|
||||
""".format(
|
||||
built_at=datetime.now().isoformat()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_config():
|
||||
return yaml.safe_load(
|
||||
"""
|
||||
routing_table: {}
|
||||
api_providers: {}
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
|
||||
result = parse_and_maybe_upgrade_config(up_to_date_config)
|
||||
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
assert len(result.models) == 1
|
||||
assert len(result.shields) == 1
|
||||
assert len(result.memory_banks) == 1
|
||||
assert "inference" in result.providers
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_old_format(old_config):
|
||||
result = parse_and_maybe_upgrade_config(old_config)
|
||||
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
assert len(result.models) == 2
|
||||
assert len(result.shields) == 2
|
||||
assert len(result.memory_banks) == 1
|
||||
assert all(
|
||||
api in result.providers
|
||||
for api in ["inference", "safety", "memory", "telemetry"]
|
||||
)
|
||||
safety_provider = result.providers["safety"][0]
|
||||
assert safety_provider.provider_type == "meta-reference"
|
||||
assert "llama_guard_shield" in safety_provider.config
|
||||
|
||||
inference_providers = result.providers["inference"]
|
||||
assert len(inference_providers) == 2
|
||||
assert set(x.provider_id for x in inference_providers) == {
|
||||
"remote::ollama-00",
|
||||
"meta-reference-01",
|
||||
}
|
||||
|
||||
ollama = inference_providers[0]
|
||||
assert ollama.provider_type == "remote::ollama"
|
||||
assert ollama.config["port"] == 11434
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
|
||||
with pytest.raises(ValueError):
|
||||
parse_and_maybe_upgrade_config(invalid_config)
|
|
@ -8,15 +8,16 @@ from enum import Enum
|
|||
from typing import List, Optional
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
from pydantic import BaseModel
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
|
||||
|
||||
|
@ -95,6 +96,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
|||
build_config.name,
|
||||
package_deps.docker_image,
|
||||
str(build_file_path),
|
||||
str(BUILDS_BASE_DIR / ImageType.docker.value),
|
||||
" ".join(deps),
|
||||
]
|
||||
else:
|
||||
|
|
|
@ -23,7 +23,7 @@ if [ "$#" -lt 3 ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
special_pip_deps="$3"
|
||||
special_pip_deps="$4"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ if [ "$#" -lt 4 ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
special_pip_deps="$5"
|
||||
special_pip_deps="$6"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
|
@ -18,7 +18,8 @@ build_name="$1"
|
|||
image_name="llamastack-$build_name"
|
||||
docker_base=$2
|
||||
build_file_path=$3
|
||||
pip_dependencies=$4
|
||||
host_build_dir=$4
|
||||
pip_dependencies=$5
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
|
@ -33,7 +34,8 @@ REPO_CONFIGS_DIR="$REPO_DIR/tmp/configs"
|
|||
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
|
||||
llama stack configure $build_file_path --output-dir $REPO_CONFIGS_DIR
|
||||
llama stack configure $build_file_path
|
||||
cp $host_build_dir/$build_name-run.yaml $REPO_CONFIGS_DIR
|
||||
|
||||
add_to_docker() {
|
||||
local input
|
||||
|
@ -132,6 +134,9 @@ fi
|
|||
|
||||
set -x
|
||||
$DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
|
||||
|
||||
# clean up tmp/configs
|
||||
rm -rf $REPO_CONFIGS_DIR
|
||||
set +x
|
||||
|
||||
echo "Success! You can run it with: $DOCKER_BINARY $DOCKER_OPTS run -p 5000:5000 $image_name"
|
||||
|
|
|
@ -3,171 +3,369 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import textwrap
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from llama_models.sku_list import (
|
||||
llama3_1_family,
|
||||
llama3_2_family,
|
||||
llama3_family,
|
||||
resolve_model,
|
||||
safety_models,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit.validation import Validator
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.memory.memory import MemoryBankType
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
stack_apis,
|
||||
)
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.providers.impls.meta_reference.safety.config import (
|
||||
MetaReferenceShieldType,
|
||||
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
|
||||
|
||||
ALLOWED_MODELS = (
|
||||
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
|
||||
)
|
||||
|
||||
|
||||
def make_routing_entry_type(config_class: Any):
|
||||
class BaseModelWithConfig(BaseModel):
|
||||
routing_key: str
|
||||
config: config_class
|
||||
def configure_single_provider(
|
||||
registry: Dict[str, ProviderSpec], provider: Provider
|
||||
) -> Provider:
|
||||
provider_spec = registry[provider.provider_type]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
if provider.config:
|
||||
existing = config_type(**provider.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
|
||||
return BaseModelWithConfig
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
return Provider(
|
||||
provider_id=provider.provider_id,
|
||||
provider_type=provider.provider_type,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
|
||||
|
||||
def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
|
||||
"""Get corresponding builtin APIs given provider backed APIs"""
|
||||
res = []
|
||||
for inf in builtin_automatically_routed_apis():
|
||||
if inf.router_api.value in provider_backed_apis:
|
||||
res.append(inf.routing_table_api.value)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
# TODO: make sure we can deal with existing configuration values correctly
|
||||
# instead of just overwriting them
|
||||
def configure_api_providers(
|
||||
config: StackRunConfig, spec: DistributionSpec
|
||||
config: StackRunConfig, build_spec: DistributionSpec
|
||||
) -> StackRunConfig:
|
||||
apis = config.apis_to_serve or list(spec.providers.keys())
|
||||
# append the bulitin routing APIs
|
||||
apis += get_builtin_apis(apis)
|
||||
is_nux = len(config.providers) == 0
|
||||
|
||||
router_api2builtin_api = {
|
||||
inf.router_api.value: inf.routing_table_api.value
|
||||
for inf in builtin_automatically_routed_apis()
|
||||
}
|
||||
if is_nux:
|
||||
print(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
Llama Stack is composed of several APIs working together. For each API served by the Stack,
|
||||
we need to configure the providers (implementations) you want to use for these APIs.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
|
||||
provider_registry = get_provider_registry()
|
||||
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
|
||||
|
||||
apis = [v.value for v in stack_apis()]
|
||||
all_providers = get_provider_registry()
|
||||
if config.apis:
|
||||
apis_to_serve = config.apis
|
||||
else:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
|
||||
# configure simple case for with non-routing providers to api_providers
|
||||
for api_str in spec.providers.keys():
|
||||
if api_str not in apis:
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
if api in builtin_apis:
|
||||
continue
|
||||
if api not in provider_registry:
|
||||
raise ValueError(f"Unknown API `{api_str}`")
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||
api = Api(api_str)
|
||||
|
||||
p = spec.providers[api_str]
|
||||
cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
|
||||
|
||||
if isinstance(p, list):
|
||||
existing_providers = config.providers.get(api_str, [])
|
||||
if existing_providers:
|
||||
cprint(
|
||||
f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
|
||||
"yellow",
|
||||
f"Re-configuring existing providers for API `{api_str}`...",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
p = p[0]
|
||||
|
||||
provider_spec = all_providers[api][p]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
provider_config = config.api_providers.get(api_str)
|
||||
if provider_config:
|
||||
existing = config_type(**provider_config.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
|
||||
if api_str in router_api2builtin_api:
|
||||
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table
|
||||
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
|
||||
routing_entries = []
|
||||
if api_str == "inference":
|
||||
if hasattr(cfg, "model"):
|
||||
routing_key = cfg.model
|
||||
else:
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported model your provider has for inference: ",
|
||||
default="Meta-Llama3.1-8B-Instruct",
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
updated_providers = []
|
||||
for p in existing_providers:
|
||||
print(f"> Configuring provider `({p.provider_type})`")
|
||||
updated_providers.append(
|
||||
configure_single_provider(provider_registry[api], p)
|
||||
)
|
||||
|
||||
if api_str == "safety":
|
||||
# TODO: add support for other safety providers, and simplify safety provider config
|
||||
if p == "meta-reference":
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=[s.value for s in MetaReferenceShieldType],
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
cprint(
|
||||
f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
|
||||
if api_str == "memory":
|
||||
bank_types = list([x.value for x in MemoryBankType])
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported memory bank type your provider has for memory: ",
|
||||
default="vector",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: x in bank_types,
|
||||
error_message="Invalid provider, please enter one of the following: {}".format(
|
||||
bank_types
|
||||
),
|
||||
),
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
|
||||
config.routing_table[api_str] = routing_entries
|
||||
config.api_providers[api_str] = PlaceholderProviderConfig(
|
||||
providers=p if isinstance(p, list) else [p]
|
||||
)
|
||||
print("")
|
||||
else:
|
||||
config.api_providers[api_str] = GenericProviderConfig(
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
# we are newly configuring this API
|
||||
plist = build_spec.providers.get(api_str, [])
|
||||
plist = plist if isinstance(plist, list) else [plist]
|
||||
|
||||
if not plist:
|
||||
raise ValueError(f"No provider configured for API {api_str}?")
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||
updated_providers = []
|
||||
for i, provider_type in enumerate(plist):
|
||||
print(f"> Configuring provider `({provider_type})`")
|
||||
updated_providers.append(
|
||||
configure_single_provider(
|
||||
provider_registry[api],
|
||||
Provider(
|
||||
provider_id=(
|
||||
f"{provider_type}-{i:02d}"
|
||||
if len(plist) > 1
|
||||
else provider_type
|
||||
),
|
||||
provider_type=provider_type,
|
||||
config={},
|
||||
),
|
||||
)
|
||||
)
|
||||
print("")
|
||||
|
||||
config.providers[api_str] = updated_providers
|
||||
|
||||
if is_nux:
|
||||
print(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
=========================================================================================
|
||||
Now let's configure the `objects` you will be serving via the stack. These are:
|
||||
|
||||
- Models: the Llama model SKUs you expect to inference (e.g., Llama3.2-1B-Instruct)
|
||||
- Shields: the safety models you expect to use for safety (e.g., Llama-Guard-3-1B)
|
||||
- Memory Banks: the memory banks you expect to use for memory (e.g., Vector stores)
|
||||
|
||||
This wizard will guide you through setting up one of each of these objects. You can
|
||||
always add more later by editing the run.yaml file.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
object_types = {
|
||||
"models": (ModelDef, configure_models, "inference"),
|
||||
"shields": (ShieldDef, configure_shields, "safety"),
|
||||
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
|
||||
}
|
||||
safety_providers = config.providers.get("safety", [])
|
||||
|
||||
for otype, (odef, config_method, api_str) in object_types.items():
|
||||
existing_objects = getattr(config, otype)
|
||||
|
||||
if existing_objects:
|
||||
cprint(
|
||||
f"{len(existing_objects)} {otype} exist. Skipping...",
|
||||
"blue",
|
||||
attrs=["bold"],
|
||||
)
|
||||
updated_objects = existing_objects
|
||||
else:
|
||||
providers = config.providers.get(api_str, [])
|
||||
if not providers:
|
||||
updated_objects = []
|
||||
else:
|
||||
# we are newly configuring this API
|
||||
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
|
||||
updated_objects = config_method(
|
||||
config.providers[api_str], safety_providers
|
||||
)
|
||||
|
||||
setattr(config, otype, updated_objects)
|
||||
print("")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
|
||||
if not safety_providers:
|
||||
return None
|
||||
|
||||
provider = safety_providers[0]
|
||||
assert provider.provider_type == "meta-reference"
|
||||
|
||||
cfg = provider.config["llama_guard_shield"]
|
||||
if not cfg:
|
||||
return None
|
||||
return cfg["model"]
|
||||
|
||||
|
||||
def configure_models(
|
||||
providers: List[Provider], safety_providers: List[Provider]
|
||||
) -> List[ModelDef]:
|
||||
model = prompt(
|
||||
"> Please enter the model you want to serve: ",
|
||||
default="Llama3.2-1B-Instruct",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: resolve_model(x) is not None,
|
||||
error_message="Model must be: {}".format(
|
||||
[x.descriptor() for x in ALLOWED_MODELS]
|
||||
),
|
||||
),
|
||||
)
|
||||
model = ModelDef(
|
||||
identifier=model,
|
||||
llama_model=model,
|
||||
provider_id=providers[0].provider_id,
|
||||
)
|
||||
|
||||
ret = [model]
|
||||
if llama_guard := get_llama_guard_model(safety_providers):
|
||||
ret.append(
|
||||
ModelDef(
|
||||
identifier=llama_guard,
|
||||
llama_model=llama_guard,
|
||||
provider_id=providers[0].provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def configure_shields(
|
||||
providers: List[Provider], safety_providers: List[Provider]
|
||||
) -> List[ShieldDef]:
|
||||
if get_llama_guard_model(safety_providers):
|
||||
return [
|
||||
ShieldDef(
|
||||
identifier="llama_guard",
|
||||
type="llama_guard",
|
||||
provider_id=providers[0].provider_id,
|
||||
params={},
|
||||
)
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def configure_memory_banks(
|
||||
providers: List[Provider], safety_providers: List[Provider]
|
||||
) -> List[MemoryBankDef]:
|
||||
bank_name = prompt(
|
||||
"> Please enter a name for your memory bank: ",
|
||||
default="my-memory-bank",
|
||||
)
|
||||
|
||||
return [
|
||||
VectorMemoryBankDef(
|
||||
identifier=bank_name,
|
||||
provider_id=providers[0].provider_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def upgrade_from_routing_table_to_registry(
|
||||
config_dict: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
def get_providers(entries):
|
||||
return [
|
||||
Provider(
|
||||
provider_id=(
|
||||
f"{entry['provider_type']}-{i:02d}"
|
||||
if len(entries) > 1
|
||||
else entry["provider_type"]
|
||||
),
|
||||
provider_type=entry["provider_type"],
|
||||
config=entry["config"],
|
||||
)
|
||||
for i, entry in enumerate(entries)
|
||||
]
|
||||
|
||||
providers_by_api = {}
|
||||
models = []
|
||||
shields = []
|
||||
memory_banks = []
|
||||
|
||||
routing_table = config_dict.get("routing_table", {})
|
||||
for api_str, entries in routing_table.items():
|
||||
providers = get_providers(entries)
|
||||
providers_by_api[api_str] = providers
|
||||
|
||||
if api_str == "inference":
|
||||
for entry, provider in zip(entries, providers):
|
||||
key = entry["routing_key"]
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
for key in keys:
|
||||
models.append(
|
||||
ModelDef(
|
||||
identifier=key,
|
||||
provider_id=provider.provider_id,
|
||||
llama_model=key,
|
||||
)
|
||||
)
|
||||
elif api_str == "safety":
|
||||
for entry, provider in zip(entries, providers):
|
||||
key = entry["routing_key"]
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
for key in keys:
|
||||
shields.append(
|
||||
ShieldDef(
|
||||
identifier=key,
|
||||
type=ShieldType.llama_guard.value,
|
||||
provider_id=provider.provider_id,
|
||||
)
|
||||
)
|
||||
elif api_str == "memory":
|
||||
for entry, provider in zip(entries, providers):
|
||||
key = entry["routing_key"]
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
for key in keys:
|
||||
# we currently only support Vector memory banks so this is OK
|
||||
memory_banks.append(
|
||||
VectorMemoryBankDef(
|
||||
identifier=key,
|
||||
provider_id=provider.provider_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
)
|
||||
config_dict["models"] = models
|
||||
config_dict["shields"] = shields
|
||||
config_dict["memory_banks"] = memory_banks
|
||||
|
||||
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
|
||||
if provider_map:
|
||||
for api_str, provider in provider_map.items():
|
||||
if isinstance(provider, dict) and "provider_type" in provider:
|
||||
providers_by_api[api_str] = [
|
||||
Provider(
|
||||
provider_id=f"{provider['provider_type']}",
|
||||
provider_type=provider["provider_type"],
|
||||
config=provider["config"],
|
||||
)
|
||||
]
|
||||
|
||||
config_dict["providers"] = providers_by_api
|
||||
|
||||
config_dict.pop("routing_table", None)
|
||||
config_dict.pop("api_providers", None)
|
||||
config_dict.pop("provider_map", None)
|
||||
|
||||
config_dict["apis"] = config_dict["apis_to_serve"]
|
||||
config_dict.pop("apis_to_serve", None)
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
|
||||
version = config_dict.get("version", None)
|
||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
return StackRunConfig(**config_dict)
|
||||
|
||||
if "models" not in config_dict:
|
||||
print("Upgrading config...")
|
||||
config_dict = upgrade_from_routing_table_to_registry(config_dict)
|
||||
|
||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
config_dict["built_at"] = datetime.now().isoformat()
|
||||
|
||||
return StackRunConfig(**config_dict)
|
||||
|
|
|
@ -11,28 +11,32 @@ from typing import Dict, List, Optional, Union
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
||||
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
||||
|
||||
|
||||
RoutingKey = Union[str, List[str]]
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
RoutableObject = Union[
|
||||
ModelDef,
|
||||
ShieldDef,
|
||||
MemoryBankDef,
|
||||
]
|
||||
|
||||
|
||||
class RoutableProviderConfig(GenericProviderConfig):
|
||||
routing_key: RoutingKey
|
||||
|
||||
|
||||
class PlaceholderProviderConfig(BaseModel):
|
||||
"""Placeholder provider config for API whose provider are defined in routing_table"""
|
||||
|
||||
providers: List[str]
|
||||
RoutedProtocol = Union[
|
||||
Inference,
|
||||
Safety,
|
||||
Memory,
|
||||
]
|
||||
|
||||
|
||||
# Example: /inference, /safety
|
||||
|
@ -53,18 +57,17 @@ class AutoRoutedProviderSpec(ProviderSpec):
|
|||
|
||||
|
||||
# Example: /models, /shields
|
||||
@json_schema_type
|
||||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
provider_type: str = "routing_table"
|
||||
config_class: str = ""
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
inner_specs: List[ProviderSpec]
|
||||
router_api: Api
|
||||
registry: List[RoutableObject]
|
||||
module: str
|
||||
pip_packages: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DistributionSpec(BaseModel):
|
||||
description: Optional[str] = Field(
|
||||
default="",
|
||||
|
@ -80,7 +83,12 @@ in the runtime configuration to help route to the correct provider.""",
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Provider(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
built_at: datetime
|
||||
|
@ -100,36 +108,39 @@ this could be just a hash
|
|||
default=None,
|
||||
description="Reference to the conda environment if this package refers to a conda environment",
|
||||
)
|
||||
apis_to_serve: List[str] = Field(
|
||||
apis: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="""
|
||||
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
|
||||
)
|
||||
|
||||
api_providers: Dict[
|
||||
str, Union[GenericProviderConfig, PlaceholderProviderConfig]
|
||||
] = Field(
|
||||
providers: Dict[str, List[Provider]] = Field(
|
||||
description="""
|
||||
Provider configurations for each of the APIs provided by this package.
|
||||
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
|
||||
can be instantiated multiple times (with different configs) if necessary.
|
||||
""",
|
||||
)
|
||||
routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
|
||||
E.g. The following is a ProviderRoutingEntry for models:
|
||||
- routing_key: Meta-Llama3.1-8B-Instruct
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
model: Meta-Llama3.1-8B-Instruct
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
""",
|
||||
models: List[ModelDef] = Field(
|
||||
description="""
|
||||
List of model definitions to serve. This list may get extended by
|
||||
/models/register API calls at runtime.
|
||||
""",
|
||||
)
|
||||
shields: List[ShieldDef] = Field(
|
||||
description="""
|
||||
List of shield definitions to serve. This list may get extended by
|
||||
/shields/register API calls at runtime.
|
||||
""",
|
||||
)
|
||||
memory_banks: List[MemoryBankDef] = Field(
|
||||
description="""
|
||||
List of memory bank definitions to serve. This list may get extended by
|
||||
/memory_banks/register API calls at runtime.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BuildConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||
name: str
|
||||
|
|
|
@ -6,45 +6,58 @@
|
|||
|
||||
from typing import Dict, List
|
||||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def is_passthrough(spec: ProviderSpec) -> bool:
|
||||
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
impl = DistributionInspectImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class DistributionInspectImpl(Inspect):
|
||||
def __init__(self):
|
||||
def __init__(self, config, deps):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = {}
|
||||
all_providers = get_provider_registry()
|
||||
for api, providers in all_providers.items():
|
||||
ret[api.value] = [
|
||||
for api, providers in run_config.providers.items():
|
||||
ret[api] = [
|
||||
ProviderInfo(
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
description="Passthrough" if is_passthrough(p) else "",
|
||||
)
|
||||
for p in providers.values()
|
||||
for p in providers
|
||||
]
|
||||
|
||||
return ret
|
||||
|
||||
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = {}
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
for api, endpoints in all_endpoints.items():
|
||||
providers = run_config.providers.get(api.value, [])
|
||||
ret[api.value] = [
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
providers=[],
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
]
|
||||
|
|
|
@ -13,138 +13,207 @@ from llama_stack.distribution.distribution import (
|
|||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.distribution.inspect import DistributionInspectImpl
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
|
||||
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
|
||||
class ProviderWithSpec(Provider):
|
||||
spec: ProviderSpec
|
||||
|
||||
|
||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
- for each API, produces either a (local, passthrough or router) implementation
|
||||
"""
|
||||
all_providers = get_provider_registry()
|
||||
specs = {}
|
||||
configs = {}
|
||||
all_api_providers = get_provider_registry()
|
||||
|
||||
for api_str, config in run_config.api_providers.items():
|
||||
api = Api(api_str)
|
||||
|
||||
# TODO: check that these APIs are not in the routing table part of the config
|
||||
providers = all_providers[api]
|
||||
|
||||
# skip checks for API whose provider config is specified in routing_table
|
||||
if isinstance(config, PlaceholderProviderConfig):
|
||||
continue
|
||||
|
||||
if config.provider_type not in providers:
|
||||
raise ValueError(
|
||||
f"Provider `{config.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[config.provider_type]
|
||||
configs[api] = config
|
||||
|
||||
apis_to_serve = run_config.apis_to_serve or set(
|
||||
list(specs.keys()) + list(run_config.routing_table.keys())
|
||||
routing_table_apis = set(
|
||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||
)
|
||||
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
|
||||
|
||||
providers_with_specs = {}
|
||||
|
||||
for api_str, providers in run_config.providers.items():
|
||||
api = Api(api_str)
|
||||
if api in routing_table_apis:
|
||||
raise ValueError(
|
||||
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
|
||||
)
|
||||
|
||||
specs = {}
|
||||
for provider in providers:
|
||||
if provider.provider_type not in all_api_providers[api]:
|
||||
raise ValueError(
|
||||
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
|
||||
p = all_api_providers[api][provider.provider_type]
|
||||
p.deps__ = [a.value for a in p.api_dependencies]
|
||||
spec = ProviderWithSpec(
|
||||
spec=p,
|
||||
**(provider.dict()),
|
||||
)
|
||||
specs[provider.provider_id] = spec
|
||||
|
||||
key = api_str if api not in router_apis else f"inner-{api_str}"
|
||||
providers_with_specs[key] = specs
|
||||
|
||||
apis_to_serve = run_config.apis or set(
|
||||
list(providers_with_specs.keys())
|
||||
+ [x.value for x in routing_table_apis]
|
||||
+ [x.value for x in router_apis]
|
||||
)
|
||||
|
||||
for info in builtin_automatically_routed_apis():
|
||||
source_api = info.routing_table_api
|
||||
|
||||
assert (
|
||||
source_api not in specs
|
||||
), f"Routing table API {source_api} specified in wrong place?"
|
||||
assert (
|
||||
info.router_api not in specs
|
||||
), f"Auto-routed API {info.router_api} specified in wrong place?"
|
||||
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
|
||||
if info.router_api.value not in run_config.routing_table:
|
||||
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
|
||||
available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
|
||||
|
||||
routing_table = run_config.routing_table[info.router_api.value]
|
||||
|
||||
providers = all_providers[info.router_api]
|
||||
|
||||
inner_specs = []
|
||||
inner_deps = []
|
||||
for rt_entry in routing_table:
|
||||
if rt_entry.provider_type not in providers:
|
||||
registry = getattr(run_config, info.routing_table_api.value)
|
||||
for entry in registry:
|
||||
if entry.provider_id not in available_providers:
|
||||
raise ValueError(
|
||||
f"Provider `{rt_entry.provider_type}` is not available for API `{api}`"
|
||||
f"Provider `{entry.provider_id}` not found. Available providers: {list(available_providers.keys())}"
|
||||
)
|
||||
inner_specs.append(providers[rt_entry.provider_type])
|
||||
inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)
|
||||
|
||||
specs[source_api] = RoutingTableProviderSpec(
|
||||
api=source_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=inner_deps,
|
||||
inner_specs=inner_specs,
|
||||
provider = available_providers[entry.provider_id]
|
||||
inner_deps.extend(provider.spec.api_dependencies)
|
||||
|
||||
providers_with_specs[info.routing_table_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__routing_table__",
|
||||
config={},
|
||||
spec=RoutingTableProviderSpec(
|
||||
api=info.routing_table_api,
|
||||
router_api=info.router_api,
|
||||
registry=registry,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=inner_deps,
|
||||
deps__=(
|
||||
[x.value for x in inner_deps]
|
||||
+ [f"inner-{info.router_api.value}"]
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
providers_with_specs[info.router_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__autorouted__",
|
||||
config={},
|
||||
spec=AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=info.routing_table_api,
|
||||
api_dependencies=[info.routing_table_api],
|
||||
deps__=([info.routing_table_api.value]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
sorted_providers = topological_sort(
|
||||
{k: v.values() for k, v in providers_with_specs.items()}
|
||||
)
|
||||
apis = [x[1].spec.api for x in sorted_providers]
|
||||
sorted_providers.append(
|
||||
(
|
||||
"inspect",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={
|
||||
"run_config": run_config.dict(),
|
||||
},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||
module="llama_stack.distribution.inspect",
|
||||
api_dependencies=apis,
|
||||
deps__=([x.value for x in apis]),
|
||||
),
|
||||
),
|
||||
)
|
||||
configs[source_api] = routing_table
|
||||
|
||||
specs[info.router_api] = AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=source_api,
|
||||
api_dependencies=[source_api],
|
||||
)
|
||||
configs[info.router_api] = {}
|
||||
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
print(f"Resolved {len(sorted_specs)} providers in topological order")
|
||||
for spec in sorted_specs:
|
||||
print(f" {spec.api}: {spec.provider_type}")
|
||||
print("")
|
||||
impls = {}
|
||||
for spec in sorted_specs:
|
||||
api = spec.api
|
||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||
impl = await instantiate_provider(spec, deps, configs[api])
|
||||
|
||||
impls[api] = impl
|
||||
|
||||
impls[Api.inspect] = DistributionInspectImpl()
|
||||
specs[Api.inspect] = InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__distribution_builtin__",
|
||||
config_class="",
|
||||
module="",
|
||||
)
|
||||
|
||||
return impls, specs
|
||||
print(f"Resolved {len(sorted_providers)} providers in topological order")
|
||||
for api_str, provider in sorted_providers:
|
||||
print(f" {api_str}: ({provider.provider_id}) {provider.spec.provider_type}")
|
||||
print("")
|
||||
|
||||
impls = {}
|
||||
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
|
||||
for api_str, provider in sorted_providers:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
|
||||
inner_impls = {}
|
||||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[
|
||||
f"inner-{provider.spec.router_api.value}"
|
||||
]
|
||||
|
||||
impl = await instantiate_provider(
|
||||
provider,
|
||||
deps,
|
||||
inner_impls,
|
||||
)
|
||||
# TODO: ugh slightly redesign this shady looking code
|
||||
if "inner-" in api_str:
|
||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||
else:
|
||||
api = Api(api_str)
|
||||
impls[api] = impl
|
||||
|
||||
return impls
|
||||
|
||||
|
||||
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||
by_id = {x.api: x for x in providers}
|
||||
def topological_sort(
|
||||
providers_with_specs: Dict[str, List[ProviderWithSpec]],
|
||||
) -> List[ProviderWithSpec]:
|
||||
def dfs(kv, visited: Set[str], stack: List[str]):
|
||||
api_str, providers = kv
|
||||
visited.add(api_str)
|
||||
|
||||
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
|
||||
visited.add(a.api)
|
||||
deps = []
|
||||
for provider in providers:
|
||||
for dep in provider.spec.deps__:
|
||||
deps.append(dep)
|
||||
|
||||
for api in a.api_dependencies:
|
||||
if api not in visited:
|
||||
dfs(by_id[api], visited, stack)
|
||||
for dep in deps:
|
||||
if dep not in visited:
|
||||
dfs((dep, providers_with_specs[dep]), visited, stack)
|
||||
|
||||
stack.append(a.api)
|
||||
stack.append(api_str)
|
||||
|
||||
visited = set()
|
||||
stack = []
|
||||
|
||||
for a in providers:
|
||||
if a.api not in visited:
|
||||
dfs(a, visited, stack)
|
||||
for api_str, providers in providers_with_specs.items():
|
||||
if api_str not in visited:
|
||||
dfs((api_str, providers), visited, stack)
|
||||
|
||||
return [by_id[x] for x in stack]
|
||||
flattened = []
|
||||
for api_str in stack:
|
||||
for provider in providers_with_specs[api_str]:
|
||||
flattened.append((api_str, provider))
|
||||
return flattened
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider_spec: ProviderSpec,
|
||||
provider: ProviderWithSpec,
|
||||
deps: Dict[str, Any],
|
||||
provider_config: Union[GenericProviderConfig, RoutingTable],
|
||||
inner_impls: Dict[str, Any],
|
||||
):
|
||||
provider_spec = provider.spec
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
args = []
|
||||
|
@ -154,9 +223,8 @@ async def instantiate_provider(
|
|||
else:
|
||||
method = "get_client_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
config = config_type(**provider.config)
|
||||
args = [config, deps]
|
||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
method = "get_auto_router_impl"
|
||||
|
@ -166,31 +234,18 @@ async def instantiate_provider(
|
|||
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
||||
method = "get_routing_table_impl"
|
||||
|
||||
assert isinstance(provider_config, List)
|
||||
routing_table = provider_config
|
||||
|
||||
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
|
||||
inner_impls = []
|
||||
for routing_entry in routing_table:
|
||||
impl = await instantiate_provider(
|
||||
inner_specs[routing_entry.provider_type],
|
||||
deps,
|
||||
routing_entry,
|
||||
)
|
||||
inner_impls.append((routing_entry.routing_key, impl))
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, inner_impls, routing_table, deps]
|
||||
args = [provider_spec.api, provider_spec.registry, inner_impls, deps]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
config = config_type(**provider.config)
|
||||
args = [config, deps]
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
impl.__provider_id__ = provider.provider_id
|
||||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
return impl
|
||||
|
|
|
@ -4,23 +4,22 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any, List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from .routing_tables import (
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
)
|
||||
|
||||
|
||||
async def get_routing_table_impl(
|
||||
api: Api,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||
registry: List[RoutableObject],
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
_deps,
|
||||
) -> Any:
|
||||
from .routing_tables import (
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
)
|
||||
|
||||
api_to_tables = {
|
||||
"memory_banks": MemoryBanksRoutingTable,
|
||||
"models": ModelsRoutingTable,
|
||||
|
@ -29,7 +28,7 @@ async def get_routing_table_impl(
|
|||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_tables[api.value](inner_impls, routing_table_config)
|
||||
impl = api_to_tables[api.value](registry, impls_by_provider_id)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
|
|
@ -14,14 +14,13 @@ from llama_stack.apis.safety import * # noqa: F403
|
|||
|
||||
|
||||
class MemoryRouter(Memory):
|
||||
"""Routes to an provider based on the memory bank type"""
|
||||
"""Routes to an provider based on the memory bank identifier"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
self.bank_id_to_type = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
@ -29,32 +28,8 @@ class MemoryRouter(Memory):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def get_provider_from_bank_id(self, bank_id: str) -> Any:
|
||||
bank_type = self.bank_id_to_type.get(bank_id)
|
||||
if not bank_type:
|
||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||
|
||||
provider = self.routing_table.get_provider_impl(bank_type)
|
||||
if not provider:
|
||||
raise ValueError(f"Could not find provider for {bank_type}")
|
||||
return provider
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_type = config.type
|
||||
bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
|
||||
name, config, url
|
||||
)
|
||||
self.bank_id_to_type[bank.bank_id] = bank_type
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
provider = self.get_provider_from_bank_id(bank_id)
|
||||
return await provider.get_memory_bank(bank_id)
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
||||
await self.routing_table.register_memory_bank(memory_bank)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
@ -62,7 +37,7 @@ class MemoryRouter(Memory):
|
|||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
return await self.get_provider_from_bank_id(bank_id).insert_documents(
|
||||
return await self.routing_table.get_provider_impl(bank_id).insert_documents(
|
||||
bank_id, documents, ttl_seconds
|
||||
)
|
||||
|
||||
|
@ -72,7 +47,7 @@ class MemoryRouter(Memory):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
return await self.get_provider_from_bank_id(bank_id).query_documents(
|
||||
return await self.routing_table.get_provider_impl(bank_id).query_documents(
|
||||
bank_id, query, params
|
||||
)
|
||||
|
||||
|
@ -92,7 +67,10 @@ class InferenceRouter(Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def chat_completion(
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
await self.routing_table.register_model(model)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -113,27 +91,32 @@ class InferenceRouter(Inference):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
# TODO: we need to fix streaming response to align provider implementations with Protocol.
|
||||
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
|
||||
**params
|
||||
):
|
||||
yield chunk
|
||||
provider = self.routing_table.get_provider_impl(model)
|
||||
if stream:
|
||||
return (chunk async for chunk in provider.chat_completion(**params))
|
||||
else:
|
||||
return provider.chat_completion(**params)
|
||||
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
return await self.routing_table.get_provider_impl(model).completion(
|
||||
) -> AsyncGenerator:
|
||||
provider = self.routing_table.get_provider_impl(model)
|
||||
params = dict(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return (chunk async for chunk in provider.completion(**params))
|
||||
else:
|
||||
return provider.completion(**params)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
@ -159,6 +142,9 @@ class SafetyRouter(Safety):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
await self.routing_table.register_shield(shield)
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
|
|
|
@ -4,9 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
@ -16,129 +15,129 @@ from llama_stack.apis.memory_banks import * # noqa: F403
|
|||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def get_impl_api(p: Any) -> Api:
|
||||
return p.__provider_spec__.api
|
||||
|
||||
|
||||
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
||||
api = get_impl_api(p)
|
||||
if api == Api.inference:
|
||||
await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
await p.register_memory_bank(obj)
|
||||
|
||||
|
||||
# TODO: this routing table maintains state in memory purely. We need to
|
||||
# add persistence to it when we add dynamic registration of objects.
|
||||
class CommonRoutingTableImpl(RoutingTable):
|
||||
def __init__(
|
||||
self,
|
||||
inner_impls: List[Tuple[RoutingKey, Any]],
|
||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||
registry: List[RoutableObject],
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
) -> None:
|
||||
self.unique_providers = []
|
||||
self.providers = {}
|
||||
self.routing_keys = []
|
||||
for obj in registry:
|
||||
if obj.provider_id not in impls_by_provider_id:
|
||||
print(f"{impls_by_provider_id=}")
|
||||
raise ValueError(
|
||||
f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
|
||||
)
|
||||
|
||||
for key, impl in inner_impls:
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
self.unique_providers.append((keys, impl))
|
||||
self.impls_by_provider_id = impls_by_provider_id
|
||||
self.registry = registry
|
||||
|
||||
for k in keys:
|
||||
if k in self.providers:
|
||||
raise ValueError(f"Duplicate routing key {k}")
|
||||
self.providers[k] = impl
|
||||
self.routing_keys.append(k)
|
||||
for p in self.impls_by_provider_id.values():
|
||||
api = get_impl_api(p)
|
||||
if api == Api.inference:
|
||||
p.model_store = self
|
||||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
elif api == Api.memory:
|
||||
p.memory_bank_store = self
|
||||
|
||||
self.routing_table_config = routing_table_config
|
||||
self.routing_key_to_object = {}
|
||||
for obj in self.registry:
|
||||
self.routing_key_to_object[obj.identifier] = obj
|
||||
|
||||
async def initialize(self) -> None:
|
||||
for keys, p in self.unique_providers:
|
||||
spec = p.__provider_spec__
|
||||
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
|
||||
continue
|
||||
|
||||
await p.validate_routing_keys(keys)
|
||||
for obj in self.registry:
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
await register_object_with_provider(obj, p)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for _, p in self.unique_providers:
|
||||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str) -> Any:
|
||||
if routing_key not in self.providers:
|
||||
raise ValueError(f"Could not find provider for {routing_key}")
|
||||
return self.providers[routing_key]
|
||||
if routing_key not in self.routing_key_to_object:
|
||||
raise ValueError(f"`{routing_key}` not registered")
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
obj = self.routing_key_to_object[routing_key]
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == routing_key:
|
||||
return entry
|
||||
return self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
|
||||
for obj in self.registry:
|
||||
if obj.identifier == identifier:
|
||||
return obj
|
||||
return None
|
||||
|
||||
async def register_object(self, obj: RoutableObject):
|
||||
if obj.identifier in self.routing_key_to_object:
|
||||
print(f"`{obj.identifier}` is already registered")
|
||||
return
|
||||
|
||||
if not obj.provider_id:
|
||||
provider_ids = list(self.impls_by_provider_id.keys())
|
||||
if not provider_ids:
|
||||
raise ValueError("No providers found")
|
||||
|
||||
print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
|
||||
obj.provider_id = provider_ids[0]
|
||||
else:
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
await register_object_with_provider(obj, p)
|
||||
|
||||
self.routing_key_to_object[obj.identifier] = obj
|
||||
self.registry.append(obj)
|
||||
|
||||
# TODO: persist this to a store
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
return self.registry
|
||||
|
||||
async def list_models(self) -> List[ModelServingSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
model_id = entry.routing_key
|
||||
specs.append(
|
||||
ModelServingSpec(
|
||||
llama_model=resolve_model(model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
return self.get_object_by_identifier(identifier)
|
||||
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == core_model_id:
|
||||
return ModelServingSpec(
|
||||
llama_model=resolve_model(core_model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
await self.register_object(model)
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return self.registry
|
||||
|
||||
async def list_shields(self) -> List[ShieldSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
if isinstance(entry.routing_key, list):
|
||||
for k in entry.routing_key:
|
||||
specs.append(
|
||||
ShieldSpec(
|
||||
shield_type=k,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
else:
|
||||
specs.append(
|
||||
ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
|
||||
return self.get_object_by_identifier(shield_type)
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == shield_type:
|
||||
return ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
await self.register_object(shield)
|
||||
|
||||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
return self.registry
|
||||
|
||||
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
specs.append(
|
||||
MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
|
||||
return self.get_object_by_identifier(identifier)
|
||||
|
||||
async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == bank_type:
|
||||
return MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
|
||||
await self.register_object(bank)
|
||||
|
|
|
@ -5,18 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import signal
|
||||
import traceback
|
||||
|
||||
from collections.abc import (
|
||||
AsyncGenerator as AsyncGeneratorABC,
|
||||
AsyncIterator as AsyncIteratorABC,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from ssl import SSLError
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
@ -43,20 +40,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
|
|||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
|
||||
def is_async_iterator_type(typ):
|
||||
if hasattr(typ, "__origin__"):
|
||||
origin = typ.__origin__
|
||||
if isinstance(origin, type):
|
||||
return issubclass(
|
||||
origin,
|
||||
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
|
||||
)
|
||||
return False
|
||||
return isinstance(
|
||||
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
|
||||
)
|
||||
|
||||
|
||||
def create_sse_event(data: Any) -> str:
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.json()
|
||||
|
@ -169,11 +152,20 @@ async def passthrough(
|
|||
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
|
||||
|
||||
|
||||
def handle_sigint(*args, **kwargs):
|
||||
def handle_sigint(app, *args, **kwargs):
|
||||
print("SIGINT or CTRL-C detected. Exiting gracefully...")
|
||||
|
||||
async def run_shutdown():
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
print(f"Shutting down {impl}")
|
||||
await impl.shutdown()
|
||||
|
||||
asyncio.run(run_shutdown())
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for task in asyncio.all_tasks(loop):
|
||||
task.cancel()
|
||||
|
||||
loop.stop()
|
||||
|
||||
|
||||
|
@ -181,7 +173,10 @@ def handle_sigint(*args, **kwargs):
|
|||
async def lifespan(app: FastAPI):
|
||||
print("Starting up")
|
||||
yield
|
||||
|
||||
print("Shutting down")
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
await impl.shutdown()
|
||||
|
||||
|
||||
def create_dynamic_passthrough(
|
||||
|
@ -193,65 +188,59 @@ def create_dynamic_passthrough(
|
|||
return endpoint
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||
return kwargs.get("stream", False)
|
||||
|
||||
|
||||
async def maybe_await(value):
|
||||
if inspect.iscoroutine(value):
|
||||
return await value
|
||||
return value
|
||||
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
"message": str(translate_exception(e)),
|
||||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
hints = get_type_hints(func)
|
||||
response_model = hints.get("return")
|
||||
|
||||
# NOTE: I think it is better to just add a method within each Api
|
||||
# "Protocol" / adapter-impl to tell what sort of a response this request
|
||||
# is going to produce. /chat_completion can produce a streaming or
|
||||
# non-streaming response depending on if request.stream is True / False.
|
||||
is_streaming = is_async_iterator_type(response_model)
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
if is_streaming:
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
"message": str(translate_exception(e)),
|
||||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
try:
|
||||
return (
|
||||
await func(**kwargs)
|
||||
if asyncio.iscoroutinefunction(func)
|
||||
else func(**kwargs)
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
try:
|
||||
if is_streaming:
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
sig = inspect.signature(func)
|
||||
new_params = [
|
||||
|
@ -285,29 +274,25 @@ def main(
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||
impls = asyncio.run(resolve_impls_with_routing(config))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
if config.apis_to_serve:
|
||||
apis_to_serve = set(config.apis_to_serve)
|
||||
if config.apis:
|
||||
apis_to_serve = set(config.apis)
|
||||
else:
|
||||
apis_to_serve = set(impls.keys())
|
||||
|
||||
apis_to_serve.add(Api.inspect)
|
||||
apis_to_serve.add("inspect")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
provider_spec = specs[api]
|
||||
if (
|
||||
isinstance(provider_spec, RemoteProviderSpec)
|
||||
and provider_spec.adapter is None
|
||||
):
|
||||
if is_passthrough(impl.__provider_spec__):
|
||||
for endpoint in endpoints:
|
||||
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
|
@ -337,7 +322,9 @@ def main(
|
|||
print("")
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
built_at: '2024-09-30T09:04:30.533391'
|
||||
version: '2'
|
||||
built_at: '2024-10-08T17:42:07.505267'
|
||||
image_name: local-cpu
|
||||
docker_image: local-cpu
|
||||
conda_env: null
|
||||
apis_to_serve:
|
||||
apis:
|
||||
- agents
|
||||
- inference
|
||||
- models
|
||||
|
@ -10,40 +11,48 @@ apis_to_serve:
|
|||
- safety
|
||||
- shields
|
||||
- memory_banks
|
||||
api_providers:
|
||||
providers:
|
||||
inference:
|
||||
providers:
|
||||
- remote::ollama
|
||||
- provider_id: remote::ollama
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 6000
|
||||
safety:
|
||||
providers:
|
||||
- meta-reference
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield: null
|
||||
prompt_guard_shield: null
|
||||
memory:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: /home/xiyan/.llama/runtime/kvstore.db
|
||||
memory:
|
||||
providers:
|
||||
- meta-reference
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
routing_table:
|
||||
inference:
|
||||
- provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 6000
|
||||
routing_key: Meta-Llama3.1-8B-Instruct
|
||||
safety:
|
||||
- provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield: null
|
||||
prompt_guard_shield: null
|
||||
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
||||
memory:
|
||||
- provider_type: meta-reference
|
||||
config: {}
|
||||
routing_key: vector
|
||||
models:
|
||||
- identifier: Llama3.1-8B-Instruct
|
||||
llama_model: Llama3.1-8B-Instruct
|
||||
provider_id: remote::ollama
|
||||
shields:
|
||||
- identifier: llama_guard
|
||||
type: llama_guard
|
||||
provider_id: meta-reference
|
||||
params: {}
|
||||
memory_banks:
|
||||
- identifier: vector
|
||||
provider_id: meta-reference
|
||||
type: vector
|
||||
embedding_model: all-MiniLM-L6-v2
|
||||
chunk_size_in_tokens: 512
|
||||
overlap_size_in_tokens: null
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
built_at: '2024-09-30T09:00:56.693751'
|
||||
version: '2'
|
||||
built_at: '2024-10-08T17:42:33.690666'
|
||||
image_name: local-gpu
|
||||
docker_image: local-gpu
|
||||
conda_env: null
|
||||
apis_to_serve:
|
||||
apis:
|
||||
- memory
|
||||
- inference
|
||||
- agents
|
||||
|
@ -10,43 +11,51 @@ apis_to_serve:
|
|||
- safety
|
||||
- models
|
||||
- memory_banks
|
||||
api_providers:
|
||||
providers:
|
||||
inference:
|
||||
providers:
|
||||
- meta-reference
|
||||
safety:
|
||||
providers:
|
||||
- meta-reference
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: /home/xiyan/.llama/runtime/kvstore.db
|
||||
memory:
|
||||
providers:
|
||||
- meta-reference
|
||||
telemetry:
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
routing_table:
|
||||
inference:
|
||||
- provider_type: meta-reference
|
||||
config:
|
||||
model: Llama3.1-8B-Instruct
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
routing_key: Llama3.1-8B-Instruct
|
||||
safety:
|
||||
- provider_type: meta-reference
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield: null
|
||||
prompt_guard_shield: null
|
||||
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
||||
memory:
|
||||
- provider_type: meta-reference
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
routing_key: vector
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
models:
|
||||
- identifier: Llama3.1-8B-Instruct
|
||||
llama_model: Llama3.1-8B-Instruct
|
||||
provider_id: meta-reference
|
||||
shields:
|
||||
- identifier: llama_guard
|
||||
type: llama_guard
|
||||
provider_id: meta-reference
|
||||
params: {}
|
||||
memory_banks:
|
||||
- identifier: vector
|
||||
provider_id: meta-reference
|
||||
type: vector
|
||||
embedding_model: all-MiniLM-L6-v2
|
||||
chunk_size_in_tokens: 512
|
||||
overlap_size_in_tokens: null
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
name: local-databricks
|
||||
distribution_spec:
|
||||
description: Use Databricks for running LLM inference
|
||||
providers:
|
||||
inference: remote::databricks
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
10
llama_stack/distribution/templates/local-vllm-build.yaml
Normal file
10
llama_stack/distribution/templates/local-vllm-build.yaml
Normal file
|
@ -0,0 +1,10 @@
|
|||
name: local-vllm
|
||||
distribution_spec:
|
||||
description: Like local, but use vLLM for running LLM inference
|
||||
providers:
|
||||
inference: vllm
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
|
@ -1,445 +1,445 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import * # noqa: F403
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
|
||||
|
||||
BEDROCK_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
|
||||
}
|
||||
|
||||
|
||||
class BedrockInferenceAdapter(Inference, RoutableProviderForModels):
|
||||
|
||||
@staticmethod
|
||||
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
||||
retries_config = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
total_max_attempts=config.total_max_attempts,
|
||||
mode=config.retry_mode,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
config_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
region_name=config.region_name,
|
||||
retries=retries_config if retries_config else None,
|
||||
connect_timeout=config.connect_timeout,
|
||||
read_timeout=config.read_timeout,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_config = Config(**config_args)
|
||||
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
|
||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
|
||||
)
|
||||
self._config = config
|
||||
|
||||
self._client = BedrockInferenceAdapter._create_bedrock_client(config)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
|
||||
if bedrock_stop_reason == "max_tokens":
|
||||
return StopReason.out_of_tokens
|
||||
return StopReason.end_of_turn
|
||||
|
||||
@staticmethod
|
||||
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
|
||||
for builtin_tool in BuiltinTool:
|
||||
if builtin_tool.value == tool_name_str:
|
||||
return builtin_tool
|
||||
else:
|
||||
return tool_name_str
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
|
||||
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
converse_api_res["stopReason"]
|
||||
)
|
||||
|
||||
bedrock_message = converse_api_res["output"]["message"]
|
||||
|
||||
role = bedrock_message["role"]
|
||||
contents = bedrock_message["content"]
|
||||
|
||||
tool_calls = []
|
||||
text_content = []
|
||||
for content in contents:
|
||||
if "toolUse" in content:
|
||||
tool_use = content["toolUse"]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
|
||||
tool_use["name"]
|
||||
),
|
||||
arguments=tool_use["input"] if "input" in tool_use else None,
|
||||
call_id=tool_use["toolUseId"],
|
||||
)
|
||||
)
|
||||
elif "text" in content:
|
||||
text_content.append(content["text"])
|
||||
|
||||
return CompletionMessage(
|
||||
role=role,
|
||||
content=text_content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _messages_to_bedrock_messages(
|
||||
messages: List[Message],
|
||||
) -> Tuple[List[Dict], Optional[List[Dict]]]:
|
||||
bedrock_messages = []
|
||||
system_bedrock_messages = []
|
||||
|
||||
user_contents = []
|
||||
assistant_contents = None
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content_list = (
|
||||
message.content
|
||||
if isinstance(message.content, list)
|
||||
else [message.content]
|
||||
)
|
||||
if role == "ipython" or role == "user":
|
||||
if not user_contents:
|
||||
user_contents = []
|
||||
|
||||
if role == "ipython":
|
||||
user_contents.extend(
|
||||
[
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message.call_id,
|
||||
"content": [
|
||||
{"text": content} for content in content_list
|
||||
],
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
user_contents.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
assistant_contents = None
|
||||
elif role == "system":
|
||||
system_bedrock_messages.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
elif role == "assistant":
|
||||
if not assistant_contents:
|
||||
assistant_contents = []
|
||||
|
||||
assistant_contents.extend(
|
||||
[
|
||||
{
|
||||
"text": content,
|
||||
}
|
||||
for content in content_list
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"toolUse": {
|
||||
"input": tool_call.arguments,
|
||||
"name": (
|
||||
tool_call.tool_name
|
||||
if isinstance(tool_call.tool_name, str)
|
||||
else tool_call.tool_name.value
|
||||
),
|
||||
"toolUseId": tool_call.call_id,
|
||||
}
|
||||
}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
)
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
user_contents = None
|
||||
else:
|
||||
# Unknown role
|
||||
pass
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
|
||||
if system_bedrock_messages:
|
||||
return bedrock_messages, system_bedrock_messages
|
||||
|
||||
return bedrock_messages, None
|
||||
|
||||
@staticmethod
|
||||
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
|
||||
inference_config = {}
|
||||
if sampling_params:
|
||||
param_mapping = {
|
||||
"max_tokens": "maxTokens",
|
||||
"temperature": "temperature",
|
||||
"top_p": "topP",
|
||||
}
|
||||
|
||||
for k, v in param_mapping.items():
|
||||
if getattr(sampling_params, k):
|
||||
inference_config[v] = getattr(sampling_params, k)
|
||||
|
||||
return inference_config
|
||||
|
||||
@staticmethod
|
||||
def _tool_parameters_to_input_schema(
|
||||
tool_parameters: Optional[Dict[str, ToolParamDefinition]]
|
||||
) -> Dict:
|
||||
input_schema = {"type": "object"}
|
||||
if not tool_parameters:
|
||||
return input_schema
|
||||
|
||||
json_properties = {}
|
||||
required = []
|
||||
for name, param in tool_parameters.items():
|
||||
json_property = {
|
||||
"type": param.param_type,
|
||||
}
|
||||
|
||||
if param.description:
|
||||
json_property["description"] = param.description
|
||||
if param.required:
|
||||
required.append(name)
|
||||
json_properties[name] = json_property
|
||||
|
||||
input_schema["properties"] = json_properties
|
||||
if required:
|
||||
input_schema["required"] = required
|
||||
return input_schema
|
||||
|
||||
@staticmethod
|
||||
def _tools_to_tool_config(
|
||||
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
|
||||
) -> Optional[Dict]:
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
bedrock_tools = []
|
||||
for tool in tools:
|
||||
tool_name = (
|
||||
tool.tool_name
|
||||
if isinstance(tool.tool_name, str)
|
||||
else tool.tool_name.value
|
||||
)
|
||||
|
||||
tool_spec = {
|
||||
"toolSpec": {
|
||||
"name": tool_name,
|
||||
"inputSchema": {
|
||||
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
|
||||
tool.parameters
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if tool.description:
|
||||
tool_spec["toolSpec"]["description"] = tool.description
|
||||
|
||||
bedrock_tools.append(tool_spec)
|
||||
tool_config = {
|
||||
"tools": bedrock_tools,
|
||||
}
|
||||
|
||||
if tool_choice:
|
||||
tool_config["toolChoice"] = (
|
||||
{"any": {}}
|
||||
if tool_choice.value == ToolChoice.required
|
||||
else {"auto": {}}
|
||||
)
|
||||
return tool_config
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> (
|
||||
AsyncGenerator
|
||||
): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
|
||||
bedrock_model = self.map_to_provider_model(model)
|
||||
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
|
||||
sampling_params
|
||||
)
|
||||
|
||||
tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
|
||||
bedrock_messages, system_bedrock_messages = (
|
||||
BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
|
||||
)
|
||||
|
||||
converse_api_params = {
|
||||
"modelId": bedrock_model,
|
||||
"messages": bedrock_messages,
|
||||
}
|
||||
if inference_config:
|
||||
converse_api_params["inferenceConfig"] = inference_config
|
||||
|
||||
# Tool use is not supported in streaming mode
|
||||
if tool_config and not stream:
|
||||
converse_api_params["toolConfig"] = tool_config
|
||||
if system_bedrock_messages:
|
||||
converse_api_params["system"] = system_bedrock_messages
|
||||
|
||||
if not stream:
|
||||
converse_api_res = self.client.converse(**converse_api_params)
|
||||
|
||||
output_message = BedrockInferenceAdapter._bedrock_message_to_message(
|
||||
converse_api_res
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=output_message,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
converse_stream_api_res = self.client.converse_stream(**converse_api_params)
|
||||
event_stream = converse_stream_api_res["stream"]
|
||||
|
||||
for chunk in event_stream:
|
||||
if "messageStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
elif "contentBlockStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=ToolCall(
|
||||
tool_name=chunk["contentBlockStart"]["toolUse"][
|
||||
"name"
|
||||
],
|
||||
call_id=chunk["contentBlockStart"]["toolUse"][
|
||||
"toolUseId"
|
||||
],
|
||||
),
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif "contentBlockDelta" in chunk:
|
||||
if "text" in chunk["contentBlockDelta"]["delta"]:
|
||||
delta = chunk["contentBlockDelta"]["delta"]["text"]
|
||||
else:
|
||||
delta = ToolCallDelta(
|
||||
content=ToolCall(
|
||||
arguments=chunk["contentBlockDelta"]["delta"][
|
||||
"toolUse"
|
||||
]["input"]
|
||||
),
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
elif "contentBlockStop" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
elif "messageStop" in chunk:
|
||||
stop_reason = (
|
||||
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
chunk["messageStop"]["stopReason"]
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
elif "metadata" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
else:
|
||||
# Ignored
|
||||
pass
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import * # noqa: F403
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
|
||||
|
||||
BEDROCK_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
|
||||
}
|
||||
|
||||
|
||||
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
|
||||
@staticmethod
|
||||
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
||||
retries_config = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
total_max_attempts=config.total_max_attempts,
|
||||
mode=config.retry_mode,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
config_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
region_name=config.region_name,
|
||||
retries=retries_config if retries_config else None,
|
||||
connect_timeout=config.connect_timeout,
|
||||
read_timeout=config.read_timeout,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_config = Config(**config_args)
|
||||
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
|
||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
|
||||
)
|
||||
self._config = config
|
||||
|
||||
self._client = BedrockInferenceAdapter._create_bedrock_client(config)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
|
||||
if bedrock_stop_reason == "max_tokens":
|
||||
return StopReason.out_of_tokens
|
||||
return StopReason.end_of_turn
|
||||
|
||||
@staticmethod
|
||||
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
|
||||
for builtin_tool in BuiltinTool:
|
||||
if builtin_tool.value == tool_name_str:
|
||||
return builtin_tool
|
||||
else:
|
||||
return tool_name_str
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
|
||||
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
converse_api_res["stopReason"]
|
||||
)
|
||||
|
||||
bedrock_message = converse_api_res["output"]["message"]
|
||||
|
||||
role = bedrock_message["role"]
|
||||
contents = bedrock_message["content"]
|
||||
|
||||
tool_calls = []
|
||||
text_content = []
|
||||
for content in contents:
|
||||
if "toolUse" in content:
|
||||
tool_use = content["toolUse"]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
|
||||
tool_use["name"]
|
||||
),
|
||||
arguments=tool_use["input"] if "input" in tool_use else None,
|
||||
call_id=tool_use["toolUseId"],
|
||||
)
|
||||
)
|
||||
elif "text" in content:
|
||||
text_content.append(content["text"])
|
||||
|
||||
return CompletionMessage(
|
||||
role=role,
|
||||
content=text_content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _messages_to_bedrock_messages(
|
||||
messages: List[Message],
|
||||
) -> Tuple[List[Dict], Optional[List[Dict]]]:
|
||||
bedrock_messages = []
|
||||
system_bedrock_messages = []
|
||||
|
||||
user_contents = []
|
||||
assistant_contents = None
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content_list = (
|
||||
message.content
|
||||
if isinstance(message.content, list)
|
||||
else [message.content]
|
||||
)
|
||||
if role == "ipython" or role == "user":
|
||||
if not user_contents:
|
||||
user_contents = []
|
||||
|
||||
if role == "ipython":
|
||||
user_contents.extend(
|
||||
[
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message.call_id,
|
||||
"content": [
|
||||
{"text": content} for content in content_list
|
||||
],
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
user_contents.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
assistant_contents = None
|
||||
elif role == "system":
|
||||
system_bedrock_messages.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
elif role == "assistant":
|
||||
if not assistant_contents:
|
||||
assistant_contents = []
|
||||
|
||||
assistant_contents.extend(
|
||||
[
|
||||
{
|
||||
"text": content,
|
||||
}
|
||||
for content in content_list
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"toolUse": {
|
||||
"input": tool_call.arguments,
|
||||
"name": (
|
||||
tool_call.tool_name
|
||||
if isinstance(tool_call.tool_name, str)
|
||||
else tool_call.tool_name.value
|
||||
),
|
||||
"toolUseId": tool_call.call_id,
|
||||
}
|
||||
}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
)
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
user_contents = None
|
||||
else:
|
||||
# Unknown role
|
||||
pass
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
|
||||
if system_bedrock_messages:
|
||||
return bedrock_messages, system_bedrock_messages
|
||||
|
||||
return bedrock_messages, None
|
||||
|
||||
@staticmethod
|
||||
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
|
||||
inference_config = {}
|
||||
if sampling_params:
|
||||
param_mapping = {
|
||||
"max_tokens": "maxTokens",
|
||||
"temperature": "temperature",
|
||||
"top_p": "topP",
|
||||
}
|
||||
|
||||
for k, v in param_mapping.items():
|
||||
if getattr(sampling_params, k):
|
||||
inference_config[v] = getattr(sampling_params, k)
|
||||
|
||||
return inference_config
|
||||
|
||||
@staticmethod
|
||||
def _tool_parameters_to_input_schema(
|
||||
tool_parameters: Optional[Dict[str, ToolParamDefinition]]
|
||||
) -> Dict:
|
||||
input_schema = {"type": "object"}
|
||||
if not tool_parameters:
|
||||
return input_schema
|
||||
|
||||
json_properties = {}
|
||||
required = []
|
||||
for name, param in tool_parameters.items():
|
||||
json_property = {
|
||||
"type": param.param_type,
|
||||
}
|
||||
|
||||
if param.description:
|
||||
json_property["description"] = param.description
|
||||
if param.required:
|
||||
required.append(name)
|
||||
json_properties[name] = json_property
|
||||
|
||||
input_schema["properties"] = json_properties
|
||||
if required:
|
||||
input_schema["required"] = required
|
||||
return input_schema
|
||||
|
||||
@staticmethod
|
||||
def _tools_to_tool_config(
|
||||
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
|
||||
) -> Optional[Dict]:
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
bedrock_tools = []
|
||||
for tool in tools:
|
||||
tool_name = (
|
||||
tool.tool_name
|
||||
if isinstance(tool.tool_name, str)
|
||||
else tool.tool_name.value
|
||||
)
|
||||
|
||||
tool_spec = {
|
||||
"toolSpec": {
|
||||
"name": tool_name,
|
||||
"inputSchema": {
|
||||
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
|
||||
tool.parameters
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if tool.description:
|
||||
tool_spec["toolSpec"]["description"] = tool.description
|
||||
|
||||
bedrock_tools.append(tool_spec)
|
||||
tool_config = {
|
||||
"tools": bedrock_tools,
|
||||
}
|
||||
|
||||
if tool_choice:
|
||||
tool_config["toolChoice"] = (
|
||||
{"any": {}}
|
||||
if tool_choice.value == ToolChoice.required
|
||||
else {"auto": {}}
|
||||
)
|
||||
return tool_config
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> (
|
||||
AsyncGenerator
|
||||
): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
|
||||
bedrock_model = self.map_to_provider_model(model)
|
||||
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
|
||||
sampling_params
|
||||
)
|
||||
|
||||
tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
|
||||
bedrock_messages, system_bedrock_messages = (
|
||||
BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
|
||||
)
|
||||
|
||||
converse_api_params = {
|
||||
"modelId": bedrock_model,
|
||||
"messages": bedrock_messages,
|
||||
}
|
||||
if inference_config:
|
||||
converse_api_params["inferenceConfig"] = inference_config
|
||||
|
||||
# Tool use is not supported in streaming mode
|
||||
if tool_config and not stream:
|
||||
converse_api_params["toolConfig"] = tool_config
|
||||
if system_bedrock_messages:
|
||||
converse_api_params["system"] = system_bedrock_messages
|
||||
|
||||
if not stream:
|
||||
converse_api_res = self.client.converse(**converse_api_params)
|
||||
|
||||
output_message = BedrockInferenceAdapter._bedrock_message_to_message(
|
||||
converse_api_res
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=output_message,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
converse_stream_api_res = self.client.converse_stream(**converse_api_params)
|
||||
event_stream = converse_stream_api_res["stream"]
|
||||
|
||||
for chunk in event_stream:
|
||||
if "messageStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
elif "contentBlockStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=ToolCall(
|
||||
tool_name=chunk["contentBlockStart"]["toolUse"][
|
||||
"name"
|
||||
],
|
||||
call_id=chunk["contentBlockStart"]["toolUse"][
|
||||
"toolUseId"
|
||||
],
|
||||
),
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif "contentBlockDelta" in chunk:
|
||||
if "text" in chunk["contentBlockDelta"]["delta"]:
|
||||
delta = chunk["contentBlockDelta"]["delta"]["text"]
|
||||
else:
|
||||
delta = ToolCallDelta(
|
||||
content=ToolCall(
|
||||
arguments=chunk["contentBlockDelta"]["delta"][
|
||||
"toolUse"
|
||||
]["input"]
|
||||
),
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
elif "contentBlockStop" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
elif "messageStop" in chunk:
|
||||
stop_reason = (
|
||||
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
chunk["messageStop"]["stopReason"]
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
elif "metadata" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
else:
|
||||
# Ignored
|
||||
pass
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
||||
assert isinstance(
|
||||
config, DatabricksImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatabricksImplConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
api_token: str = Field(
|
||||
default=None,
|
||||
description="The Databricks API token",
|
||||
)
|
|
@ -0,0 +1,111 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
|
||||
DATABRICKS_SUPPORTED_MODELS = {
|
||||
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
|
||||
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
|
||||
}
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request, client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request),
|
||||
}
|
|
@ -10,14 +10,19 @@ from fireworks.client import Fireworks
|
|||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
@ -27,21 +32,18 @@ FIREWORKS_SUPPORTED_MODELS = {
|
|||
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
||||
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
||||
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
|
||||
"Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
|
||||
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
|
||||
}
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> Fireworks:
|
||||
return Fireworks(api_key=self.config.api_key)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
@ -49,7 +51,7 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
@ -59,27 +61,7 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
|
||||
fireworks_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
fireworks_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return fireworks_messages
|
||||
|
||||
def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -101,147 +83,41 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
|
||||
# accumulate sampling params and other options to pass to fireworks
|
||||
options = self.get_fireworks_chat_options(request)
|
||||
fireworks_model = self.map_to_provider_model(request.model)
|
||||
|
||||
if not request.stream:
|
||||
r = await self.client.chat.completions.acreate(
|
||||
model=fireworks_model,
|
||||
messages=self._messages_to_fireworks_messages(messages),
|
||||
stream=False,
|
||||
**options,
|
||||
)
|
||||
stop_reason = None
|
||||
if r.choices[0].finish_reason:
|
||||
if r.choices[0].finish_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r.choices[0].finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r.choices[0].message.content, stop_reason
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
client = Fireworks(api_key=self.config.api_key)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
return self._nonstream_chat_completion(request, client)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await client.completion.acreate(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
|
||||
async for chunk in self.client.chat.completions.acreate(
|
||||
model=fireworks_model,
|
||||
messages=self._messages_to_fireworks_messages(messages),
|
||||
stream=True,
|
||||
**options,
|
||||
):
|
||||
if chunk.choices[0].finish_reason:
|
||||
if stop_reason is None and chunk.choices[0].finish_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif (
|
||||
stop_reason is None
|
||||
and chunk.choices[0].finish_reason == "length"
|
||||
):
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
text = chunk.choices[0].delta.content
|
||||
if text is None:
|
||||
continue
|
||||
stream = client.completion.acreate(**params)
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
prompt = chat_completion_request_to_prompt(request, self.formatter)
|
||||
# Fireworks always prepends with BOS
|
||||
if prompt.startswith("<|begin_of_text|>"):
|
||||
prompt = prompt[len("<|begin_of_text|>") :]
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
options = get_sampling_options(request)
|
||||
options.setdefault("max_tokens", 512)
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": prompt,
|
||||
"stream": request.stream,
|
||||
**options,
|
||||
}
|
||||
|
|
|
@ -7,6 +7,10 @@
|
|||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
|
||||
|
||||
class OllamaImplConfig(RemoteProviderConfig):
|
||||
port: int = 11434
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
|
||||
from .ollama import OllamaInferenceAdapter
|
||||
|
||||
|
|
|
@ -9,34 +9,36 @@ from typing import AsyncGenerator
|
|||
import httpx
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from ollama import AsyncClient
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
# TODO: Eventually this will move to the llama cli model list command
|
||||
# mapping of Model SKUs to ollama models
|
||||
OLLAMA_SUPPORTED_SKUS = {
|
||||
# "Llama3.1-8B-Instruct": "llama3.1",
|
||||
OLLAMA_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
|
||||
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
|
||||
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
|
||||
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
|
||||
"Llama-Guard-3-8B": "xe/llamaguard3:latest",
|
||||
}
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
||||
class OllamaInferenceAdapter(Inference):
|
||||
def __init__(self, url: str) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
|
||||
)
|
||||
self.url = url
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
|
@ -54,7 +56,29 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
if model.identifier not in OLLAMA_SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Unsupported model {model.identifier}. Supported models: {OLLAMA_SUPPORTED_MODELS.keys()}"
|
||||
)
|
||||
|
||||
ollama_model = OLLAMA_SUPPORTED_MODELS[model.identifier]
|
||||
res = await self.client.ps()
|
||||
need_model_pull = True
|
||||
for r in res["models"]:
|
||||
if ollama_model == r["model"]:
|
||||
need_model_pull = False
|
||||
break
|
||||
|
||||
print(f"Ollama model `{ollama_model}` needs pull -> {need_model_pull}")
|
||||
if need_model_pull:
|
||||
print(f"Pulling model: {ollama_model}")
|
||||
status = await self.client.pull(ollama_model)
|
||||
assert (
|
||||
status["status"] == "success"
|
||||
), f"Failed to pull model {self.model} in ollama"
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
@ -64,32 +88,7 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
|
||||
ollama_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
ollama_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return ollama_messages
|
||||
|
||||
def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
if (
|
||||
request.sampling_params.repetition_penalty is not None
|
||||
and request.sampling_params.repetition_penalty != 1.0
|
||||
):
|
||||
options["repeat_penalty"] = request.sampling_params.repetition_penalty
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -110,156 +109,54 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
# accumulate sampling params and other options to pass to ollama
|
||||
options = self.get_ollama_chat_options(request)
|
||||
ollama_model = self.map_to_provider_model(request.model)
|
||||
|
||||
res = await self.client.ps()
|
||||
need_model_pull = True
|
||||
for r in res["models"]:
|
||||
if ollama_model == r["model"]:
|
||||
need_model_pull = False
|
||||
break
|
||||
|
||||
if need_model_pull:
|
||||
print(f"Pulling model: {ollama_model}")
|
||||
status = await self.client.pull(ollama_model)
|
||||
assert (
|
||||
status["status"] == "success"
|
||||
), f"Failed to pull model {self.model} in ollama"
|
||||
|
||||
if not request.stream:
|
||||
r = await self.client.chat(
|
||||
model=ollama_model,
|
||||
messages=self._messages_to_ollama_messages(messages),
|
||||
stream=False,
|
||||
options=options,
|
||||
)
|
||||
stop_reason = None
|
||||
if r["done"]:
|
||||
if r["done_reason"] == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r["done_reason"] == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r["message"]["content"], stop_reason
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": OLLAMA_SUPPORTED_MODELS[request.model],
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"options": get_sampling_options(request),
|
||||
"raw": True,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await self.client.generate(**params)
|
||||
assert isinstance(r, dict)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
text=r["response"],
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(request, response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.generate(**params)
|
||||
async for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
text=chunk["response"],
|
||||
)
|
||||
)
|
||||
stream = await self.client.chat(
|
||||
model=ollama_model,
|
||||
messages=self._messages_to_ollama_messages(messages),
|
||||
stream=True,
|
||||
options=options,
|
||||
)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk["done"]:
|
||||
if stop_reason is None and chunk["done_reason"] == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif stop_reason is None and chunk["done_reason"] == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
text = chunk["message"]["content"]
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
|
|
@ -9,14 +9,12 @@ from .config import SampleConfig
|
|||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
|
||||
class SampleInferenceImpl(Inference, RoutableProvider):
|
||||
class SampleInferenceImpl(Inference):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
|
|
@ -34,7 +34,7 @@ class InferenceEndpointImplConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class InferenceAPIImplConfig(BaseModel):
|
||||
model_id: str = Field(
|
||||
huggingface_repo: str = Field(
|
||||
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
|
|
|
@ -10,14 +10,19 @@ from typing import AsyncGenerator
|
|||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import StopReason
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_model_input_info,
|
||||
)
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
@ -25,24 +30,31 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _HfAdapter(Inference, RoutableProvider):
|
||||
class _HfAdapter(Inference):
|
||||
client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
resolved_model = resolve_model(model.identifier)
|
||||
if resolved_model is None:
|
||||
raise ValueError(f"Unknown model: {model.identifier}")
|
||||
|
||||
if not resolved_model.huggingface_repo:
|
||||
raise ValueError(
|
||||
f"Model {model.identifier} does not have a HuggingFace repo"
|
||||
)
|
||||
|
||||
if self.model_id != resolved_model.huggingface_repo:
|
||||
raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
@ -52,16 +64,7 @@ class _HfAdapter(Inference, RoutableProvider):
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -83,146 +86,64 @@ class _HfAdapter(Inference, RoutableProvider):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
model_input = self.formatter.encode_dialog_prompt(messages)
|
||||
prompt = self.tokenizer.decode(model_input.tokens)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
input_tokens = len(model_input.tokens)
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await self.client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
text="".join(t.text for t in r.details.tokens),
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(request, response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
|
||||
choice = OpenAICompatCompletionChoice(text=token_result.text)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
prompt, input_tokens = chat_completion_request_to_model_input_info(
|
||||
request, self.formatter
|
||||
)
|
||||
max_new_tokens = min(
|
||||
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
|
||||
self.max_tokens - input_tokens - 1,
|
||||
)
|
||||
|
||||
print(f"Calculated max_new_tokens: {max_new_tokens}")
|
||||
|
||||
options = self.get_chat_options(request)
|
||||
if not request.stream:
|
||||
response = await self.client.text_generation(
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
details=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**options,
|
||||
)
|
||||
stop_reason = None
|
||||
if response.details.finish_reason:
|
||||
if response.details.finish_reason in ["stop", "eos_token"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif response.details.finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
response.generated_text,
|
||||
stop_reason,
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
tokens = []
|
||||
|
||||
async for response in await self.client.text_generation(
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
details=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**options,
|
||||
):
|
||||
token_result = response.token
|
||||
|
||||
buffer += token_result.text
|
||||
tokens.append(token_result.id)
|
||||
|
||||
if not ipython and buffer.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer = buffer[len("<|python_tag|>") :]
|
||||
continue
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
if ipython:
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
else:
|
||||
delta = text
|
||||
|
||||
if stop_reason is None:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message(tokens, stop_reason)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
options = get_sampling_options(request)
|
||||
return dict(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
details=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**options,
|
||||
)
|
||||
|
||||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
|
@ -236,7 +157,7 @@ class TGIAdapter(_HfAdapter):
|
|||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.model_id, token=config.api_token
|
||||
model=config.huggingface_repo, token=config.api_token
|
||||
)
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
|
|
|
@ -8,17 +8,22 @@ from typing import AsyncGenerator
|
|||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from together import Together
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
|
@ -34,19 +39,14 @@ TOGETHER_SUPPORTED_MODELS = {
|
|||
|
||||
|
||||
class TogetherInferenceAdapter(
|
||||
Inference, NeedsRequestProviderData, RoutableProviderForModels
|
||||
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> Together:
|
||||
return Together(api_key=self.config.api_key)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
@ -64,27 +64,7 @@ class TogetherInferenceAdapter(
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_together_messages(self, messages: list[Message]) -> list:
|
||||
together_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
together_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return together_messages
|
||||
|
||||
def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -95,7 +75,6 @@ class TogetherInferenceAdapter(
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
together_api_key = self.config.api_key
|
||||
|
@ -108,7 +87,6 @@ class TogetherInferenceAdapter(
|
|||
together_api_key = provider_data.together_api_key
|
||||
|
||||
client = Together(api_key=together_api_key)
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -120,146 +98,39 @@ class TogetherInferenceAdapter(
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
# accumulate sampling params and other options to pass to together
|
||||
options = self.get_together_chat_options(request)
|
||||
together_model = self.map_to_provider_model(request.model)
|
||||
messages = augment_messages_for_tools(request)
|
||||
|
||||
if not request.stream:
|
||||
# TODO: might need to add back an async here
|
||||
r = client.chat.completions.create(
|
||||
model=together_model,
|
||||
messages=self._messages_to_together_messages(messages),
|
||||
stream=False,
|
||||
**options,
|
||||
)
|
||||
stop_reason = None
|
||||
if r.choices[0].finish_reason:
|
||||
if (
|
||||
r.choices[0].finish_reason == "stop"
|
||||
or r.choices[0].finish_reason == "eos"
|
||||
):
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r.choices[0].finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r.choices[0].message.content, stop_reason
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
return self._nonstream_chat_completion(request, client)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Together
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
|
||||
for chunk in client.chat.completions.create(
|
||||
model=together_model,
|
||||
messages=self._messages_to_together_messages(messages),
|
||||
stream=True,
|
||||
**options,
|
||||
):
|
||||
if finish_reason := chunk.choices[0].finish_reason:
|
||||
if stop_reason is None and finish_reason in ["stop", "eos"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif stop_reason is None and finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Together
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
text = chunk.choices[0].delta.content
|
||||
if text is None:
|
||||
continue
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request),
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
@ -13,7 +12,6 @@ import chromadb
|
|||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
|
@ -65,7 +63,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class ChromaMemoryAdapter(Memory, RoutableProvider):
|
||||
class ChromaMemoryAdapter(Memory):
|
||||
def __init__(self, url: str) -> None:
|
||||
print(f"Initializing ChromaMemoryAdapter with url: {url}")
|
||||
url = url.rstrip("/")
|
||||
|
@ -93,48 +91,33 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
|
||||
pass
|
||||
|
||||
async def create_memory_bank(
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
)
|
||||
collection = await self.client.create_collection(
|
||||
name=bank_id,
|
||||
metadata={"bank": bank.json()},
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=memory_bank.identifier,
|
||||
)
|
||||
bank_index = BankWithIndex(
|
||||
bank=bank, index=ChromaIndex(self.client, collection)
|
||||
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
||||
)
|
||||
self.cache[bank_id] = bank_index
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
bank_index = await self._get_and_cache_bank_index(bank_id)
|
||||
if bank_index is None:
|
||||
return None
|
||||
return bank_index.bank
|
||||
self.cache[memory_bank.identifier] = bank_index
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if bank is None:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
collections = await self.client.list_collections()
|
||||
for collection in collections:
|
||||
if collection.name == bank_id:
|
||||
print(collection.metadata)
|
||||
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=ChromaIndex(self.client, collection),
|
||||
|
|
|
@ -4,18 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from typing import List, Tuple
|
||||
from typing import List
|
||||
|
||||
import psycopg2
|
||||
from numpy.typing import NDArray
|
||||
from psycopg2 import sql
|
||||
from psycopg2.extras import execute_values, Json
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
|
@ -32,33 +28,6 @@ def check_extension_version(cur):
|
|||
return result[0] if result else None
|
||||
|
||||
|
||||
def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
|
||||
query = sql.SQL(
|
||||
"""
|
||||
INSERT INTO metadata_store (key, data)
|
||||
VALUES %s
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET data = EXCLUDED.data
|
||||
"""
|
||||
)
|
||||
|
||||
values = [(key, Json(model.dict())) for key, model in keys_models]
|
||||
execute_values(cur, query, values, template="(%s, %s)")
|
||||
|
||||
|
||||
def load_models(cur, keys: List[str], cls):
|
||||
query = "SELECT key, data FROM metadata_store"
|
||||
if keys:
|
||||
placeholders = ",".join(["%s"] * len(keys))
|
||||
query += f" WHERE key IN ({placeholders})"
|
||||
cur.execute(query, keys)
|
||||
else:
|
||||
cur.execute(query)
|
||||
|
||||
rows = cur.fetchall()
|
||||
return [cls(**row["data"]) for row in rows]
|
||||
|
||||
|
||||
class PGVectorIndex(EmbeddingIndex):
|
||||
def __init__(self, bank: MemoryBank, dimension: int, cursor):
|
||||
self.cursor = cursor
|
||||
|
@ -119,7 +88,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
||||
class PGVectorMemoryAdapter(Memory):
|
||||
def __init__(self, config: PGVectorConfig) -> None:
|
||||
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
|
||||
self.config = config
|
||||
|
@ -144,14 +113,6 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
|||
else:
|
||||
raise RuntimeError("Vector extension is not installed.")
|
||||
|
||||
self.cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS metadata_store (
|
||||
key TEXT PRIMARY KEY,
|
||||
data JSONB
|
||||
)
|
||||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
|
@ -161,51 +122,28 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
|
||||
pass
|
||||
|
||||
async def create_memory_bank(
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
)
|
||||
upsert_models(
|
||||
self.cursor,
|
||||
[
|
||||
(bank.bank_id, bank),
|
||||
],
|
||||
)
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
bank=memory_bank,
|
||||
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
bank_index = await self._get_and_cache_bank_index(bank_id)
|
||||
if bank_index is None:
|
||||
return None
|
||||
return bank_index.bank
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
banks = load_models(self.cursor, [bank_id], MemoryBank)
|
||||
if not banks:
|
||||
return None
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if not bank:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
bank = banks[0]
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
|
|
|
@ -9,14 +9,12 @@ from .config import SampleConfig
|
|||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
|
||||
class SampleMemoryImpl(Memory, RoutableProvider):
|
||||
class SampleMemoryImpl(Memory):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
||||
# these are the memory banks the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
|
15
llama_stack/providers/adapters/memory/weaviate/__init__.py
Normal file
15
llama_stack/providers/adapters/memory/weaviate/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WeaviateConfig, _deps):
|
||||
from .weaviate import WeaviateMemoryAdapter
|
||||
|
||||
impl = WeaviateMemoryAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
16
llama_stack/providers/adapters/memory/weaviate/config.py
Normal file
16
llama_stack/providers/adapters/memory/weaviate/config.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class WeaviateRequestProviderData(BaseModel):
|
||||
weaviate_api_key: str
|
||||
weaviate_cluster_url: str
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
pass
|
180
llama_stack/providers/adapters/memory/weaviate/weaviate.py
Normal file
180
llama_stack/providers/adapters/memory/weaviate/weaviate.py
Normal file
|
@ -0,0 +1,180 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import weaviate
|
||||
import weaviate.classes as wvc
|
||||
from numpy.typing import NDArray
|
||||
from weaviate.classes.init import Auth
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
|
||||
from .config import WeaviateConfig, WeaviateRequestProviderData
|
||||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(self, client: weaviate.Client, collection_name: str):
|
||||
self.client = client
|
||||
self.collection_name = collection_name
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(
|
||||
embeddings
|
||||
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
|
||||
data_objects = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
data_objects.append(
|
||||
wvc.data.DataObject(
|
||||
properties={
|
||||
"chunk_content": chunk.json(),
|
||||
},
|
||||
vector=embeddings[i].tolist(),
|
||||
)
|
||||
)
|
||||
|
||||
# Inserting chunks into a prespecified Weaviate collection
|
||||
collection = self.client.collections.get(self.collection_name)
|
||||
|
||||
# TODO: make this async friendly
|
||||
collection.data.insert_many(data_objects)
|
||||
|
||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
||||
collection = self.client.collections.get(self.collection_name)
|
||||
|
||||
results = collection.query.near_vector(
|
||||
near_vector=embedding.tolist(),
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(distance=True),
|
||||
)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc in results.objects:
|
||||
chunk_json = doc.properties["chunk_content"]
|
||||
try:
|
||||
chunk_dict = json.loads(chunk_json)
|
||||
chunk = Chunk(**chunk_dict)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print(f"Failed to parse document: {chunk_json}")
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(1.0 / doc.metadata.distance)
|
||||
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
|
||||
def __init__(self, config: WeaviateConfig) -> None:
|
||||
self.config = config
|
||||
self.client_cache = {}
|
||||
self.cache = {}
|
||||
|
||||
def _get_client(self) -> weaviate.Client:
|
||||
provider_data = self.get_request_provider_data()
|
||||
assert provider_data is not None, "Request provider data must be set"
|
||||
assert isinstance(provider_data, WeaviateRequestProviderData)
|
||||
|
||||
key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}"
|
||||
if key in self.client_cache:
|
||||
return self.client_cache[key]
|
||||
|
||||
client = weaviate.connect_to_weaviate_cloud(
|
||||
cluster_url=provider_data.weaviate_cluster_url,
|
||||
auth_credentials=Auth.api_key(provider_data.weaviate_api_key),
|
||||
)
|
||||
self.client_cache[key] = client
|
||||
return client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for client in self.client_cache.values():
|
||||
client.close()
|
||||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
client = self._get_client()
|
||||
|
||||
# Create collection if it doesn't exist
|
||||
if not client.collections.exists(memory_bank.identifier):
|
||||
client.collections.create(
|
||||
name=memory_bank.identifier,
|
||||
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
|
||||
properties=[
|
||||
wvc.config.Property(
|
||||
name="chunk_content",
|
||||
data_type=wvc.config.DataType.TEXT,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
|
||||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if not bank:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
client = self._get_client()
|
||||
if not client.collections.exists(bank_id):
|
||||
raise ValueError(f"Collection with name `{bank_id}` not found")
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=WeaviateIndex(client=client, collection_name=bank_id),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
await index.insert_documents(documents)
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
return await index.query_documents(query, params)
|
|
@ -7,14 +7,12 @@
|
|||
import json
|
||||
import logging
|
||||
|
||||
import traceback
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import boto3
|
||||
|
||||
from llama_stack.apis.safety import * # noqa
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
|
@ -22,16 +20,17 @@ from .config import BedrockSafetyConfig
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SUPPORTED_SHIELD_TYPES = [
|
||||
"bedrock_guardrail",
|
||||
BEDROCK_SUPPORTED_SHIELDS = [
|
||||
ShieldType.generic_content_shield.value,
|
||||
]
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety, RoutableProvider):
|
||||
class BedrockSafetyAdapter(Safety):
|
||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||
if not config.aws_profile:
|
||||
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
|
||||
self.config = config
|
||||
self.registered_shields = []
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
|
@ -45,16 +44,27 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
for key in routing_keys:
|
||||
if key not in SUPPORTED_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
||||
|
||||
shield_params = shield.params
|
||||
if "guardrailIdentifier" not in shield_params:
|
||||
raise ValueError(
|
||||
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
|
||||
)
|
||||
|
||||
if "guardrailVersion" not in shield_params:
|
||||
raise ValueError(
|
||||
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
|
||||
)
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
if shield_type not in SUPPORTED_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {shield_type}")
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
|
||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
```content = [
|
||||
|
@ -69,52 +79,38 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
|
|||
|
||||
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"run_shield::{params}::messages={messages}")
|
||||
if "guardrailIdentifier" not in params:
|
||||
raise RuntimeError(
|
||||
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
|
||||
)
|
||||
|
||||
if "guardrailVersion" not in params:
|
||||
raise RuntimeError(
|
||||
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
|
||||
)
|
||||
shield_params = shield_def.params
|
||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
|
||||
# - convert the messages into format Bedrock expects
|
||||
content_messages = []
|
||||
for message in messages:
|
||||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(
|
||||
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||
)
|
||||
# - convert the messages into format Bedrock expects
|
||||
content_messages = []
|
||||
for message in messages:
|
||||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(
|
||||
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||
)
|
||||
|
||||
response = self.boto_client.apply_guardrail(
|
||||
guardrailIdentifier=params.get("guardrailIdentifier"),
|
||||
guardrailVersion=params.get("guardrailVersion"),
|
||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||
content=content_messages,
|
||||
)
|
||||
logger.debug(f"run_shield:: response: {response}::")
|
||||
if response["action"] == "GUARDRAIL_INTERVENED":
|
||||
user_message = ""
|
||||
metadata = {}
|
||||
for output in response["outputs"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
user_message = output["text"]
|
||||
for assessment in response["assessments"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
return SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
response = self.boto_client.apply_guardrail(
|
||||
guardrailIdentifier=shield_params["guardrailIdentifier"],
|
||||
guardrailVersion=shield_params["guardrailVersion"],
|
||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||
content=content_messages,
|
||||
)
|
||||
if response["action"] == "GUARDRAIL_INTERVENED":
|
||||
user_message = ""
|
||||
metadata = {}
|
||||
for output in response["outputs"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
user_message = output["text"]
|
||||
for assessment in response["assessments"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
|
||||
except Exception:
|
||||
error_str = traceback.format_exc()
|
||||
logger.error(
|
||||
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
|
||||
return SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
|
@ -9,14 +9,12 @@ from .config import SampleConfig
|
|||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
|
||||
class SampleSafetyImpl(Safety, RoutableProvider):
|
||||
class SampleSafetyImpl(Safety):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
# these are the safety shields the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
|
|
@ -6,26 +6,20 @@
|
|||
from together import Together
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
|
||||
from .config import TogetherSafetyConfig
|
||||
|
||||
|
||||
SAFETY_SHIELD_TYPES = {
|
||||
TOGETHER_SHIELD_MODEL_MAP = {
|
||||
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
}
|
||||
|
||||
|
||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
|
||||
def __init__(self, config: TogetherSafetyConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
@ -35,16 +29,20 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
for key in routing_keys:
|
||||
if key not in SAFETY_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
if shield.type != ShieldType.llama_guard.value:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
if shield_type not in SAFETY_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {shield_type}")
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
|
||||
model = shield_def.params.get("model", "llama_guard")
|
||||
if model not in TOGETHER_SHIELD_MODEL_MAP:
|
||||
raise ValueError(f"Unsupported safety model: {model}")
|
||||
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
|
@ -57,8 +55,6 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
|||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
|
||||
model_name = SAFETY_SHIELD_TYPES[shield_type]
|
||||
|
||||
# messages can have role assistant or user
|
||||
api_messages = []
|
||||
for message in messages:
|
||||
|
@ -66,7 +62,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
|||
api_messages.append({"role": message.role, "content": message.content})
|
||||
|
||||
violation = await get_safety_response(
|
||||
together_api_key, model_name, api_messages
|
||||
together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
|
||||
)
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
|
|
|
@ -43,23 +43,14 @@ class ProviderSpec(BaseModel):
|
|||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||
)
|
||||
|
||||
# used internally by the resolver; this is a hack for now
|
||||
deps__: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class RoutingTable(Protocol):
|
||||
def get_routing_keys(self) -> List[str]: ...
|
||||
|
||||
def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||
|
||||
|
||||
class RoutableProvider(Protocol):
|
||||
"""
|
||||
A provider which sits behind the RoutingTable and can get routed to.
|
||||
|
||||
All Inference / Safety / Memory providers fall into this bucket.
|
||||
"""
|
||||
|
||||
async def validate_routing_keys(self, keys: List[str]) -> None: ...
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_type: str = Field(
|
||||
|
@ -156,6 +147,10 @@ as being "Llama Stack compatible"
|
|||
return None
|
||||
|
||||
|
||||
def is_passthrough(spec: ProviderSpec) -> bool:
|
||||
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
||||
|
||||
|
||||
# Can avoid this by using Pydantic computed_field
|
||||
def remote_provider_spec(
|
||||
api: Api, adapter: Optional[AdapterSpec] = None
|
||||
|
|
|
@ -144,6 +144,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_and_execute_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
@ -635,14 +637,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
if session_info.memory_bank_id is None:
|
||||
memory_bank = await self.memory_api.create_memory_bank(
|
||||
name=f"memory_bank_{session_id}",
|
||||
config=VectorMemoryBankConfig(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
),
|
||||
bank_id = f"memory_bank_{session_id}"
|
||||
memory_bank = VectorMemoryBankDef(
|
||||
identifier=bank_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
bank_id = memory_bank.bank_id
|
||||
await self.memory_api.register_memory_bank(memory_bank)
|
||||
await self.storage.add_memory_bank_to_session(session_id, bank_id)
|
||||
else:
|
||||
bank_id = session_info.memory_bank_id
|
||||
|
@ -673,7 +674,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _retrieve_context(
|
||||
self, session_id: str, messages: List[Message], attachments: List[Attachment]
|
||||
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
|
||||
) -> Tuple[Optional[List[str]], Optional[List[int]]]: # (rag_context, bank_ids)
|
||||
bank_ids = []
|
||||
|
||||
memory = self._memory_tool_definition()
|
||||
|
@ -722,12 +723,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
chunks = [c for r in results for c in r.chunks]
|
||||
scores = [s for r in results for s in r.scores]
|
||||
|
||||
if not chunks:
|
||||
return None, bank_ids
|
||||
|
||||
# sort by score
|
||||
chunks, scores = zip(
|
||||
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
|
||||
)
|
||||
if not chunks:
|
||||
return None, bank_ids
|
||||
|
||||
tokens = 0
|
||||
picked = []
|
||||
|
|
|
@ -100,7 +100,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def create_agent_turn(
|
||||
def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
|
@ -113,16 +113,22 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
attachments: Optional[List[Attachment]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(agent_id)
|
||||
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=messages,
|
||||
attachments=attachments,
|
||||
stream=stream,
|
||||
stream=True,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
else:
|
||||
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||
|
||||
async def _create_agent_turn_streaming(
|
||||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import CodeShieldConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeShieldConfig, deps):
|
||||
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
||||
|
||||
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
|
||||
from termcolor import cprint
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
|
||||
class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||
def __init__(self, config: CodeScannerConfig, deps) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
if shield.type != ShieldType.code_scanner.value:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
|
||||
from codeshield.cs import CodeShield
|
||||
|
||||
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
|
||||
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
|
||||
result = await CodeShield.scan_code(text)
|
||||
|
||||
violation = None
|
||||
if result.is_insecure:
|
||||
violation = SafetyViolation(
|
||||
violation_level=(ViolationLevel.ERROR),
|
||||
user_message="Sorry, I found security concerns in the code.",
|
||||
metadata={
|
||||
"violation_type": ",".join(
|
||||
[issue.pattern_id for issue in result.issues_found]
|
||||
)
|
||||
},
|
||||
)
|
||||
return RunShieldResponse(violation=violation)
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeShieldConfig(BaseModel):
|
||||
pass
|
|
@ -43,13 +43,12 @@ class MetaReferenceEvalsImpl(Evals):
|
|||
print("generation start")
|
||||
for msg in x1[:5]:
|
||||
print("generation for msg: ", msg)
|
||||
response = self.inference_api.chat_completion(
|
||||
response = await self.inference_api.chat_completion(
|
||||
model=model,
|
||||
messages=[msg],
|
||||
stream=False,
|
||||
)
|
||||
async for x in response:
|
||||
generation_outputs.append(x.completion_message.content)
|
||||
generation_outputs.append(response.completion_message.content)
|
||||
|
||||
x2 = task_impl.postprocess(generation_outputs)
|
||||
eval_results = task_impl.score(x2)
|
||||
|
|
|
@ -297,7 +297,7 @@ class Llama:
|
|||
token=next_token[0].item(),
|
||||
text=self.tokenizer.decode(next_token.tolist()),
|
||||
logprobs=(
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
|
||||
token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
|
||||
if logprobs
|
||||
else None
|
||||
),
|
||||
|
|
|
@ -6,15 +6,14 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncIterator, List, Union
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
|
@ -25,7 +24,7 @@ from .model_parallel import LlamaModelParallelGenerator
|
|||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
||||
class MetaReferenceInferenceImpl(Inference):
|
||||
def __init__(self, config: MetaReferenceImplConfig) -> None:
|
||||
self.config = config
|
||||
model = resolve_model(config.model)
|
||||
|
@ -35,21 +34,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
# verify that the checkpoint actually is for this model lol
|
||||
|
||||
async def initialize(self) -> None:
|
||||
print(f"Loading model `{self.model.descriptor()}`")
|
||||
self.generator = LlamaModelParallelGenerator(self.config)
|
||||
self.generator.start()
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
assert (
|
||||
len(routing_keys) == 1
|
||||
), f"Only one routing key is supported {routing_keys}"
|
||||
assert routing_keys[0] == self.config.model
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
if model.identifier != self.model.descriptor():
|
||||
raise RuntimeError(
|
||||
f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.generator.stop()
|
||||
|
||||
# hm, when stream=False, we should not be doing SSE :/ which is what the
|
||||
# top-level server is going to do. make the typing more specific here
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -59,9 +57,10 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncIterator[
|
||||
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
|
||||
]:
|
||||
) -> AsyncGenerator:
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
|
@ -74,7 +73,6 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
model = resolve_model(request.model)
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
|
@ -88,21 +86,74 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
if SEMAPHORE.locked():
|
||||
raise RuntimeError("Only one concurrent request is supported")
|
||||
|
||||
if request.stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
async with SEMAPHORE:
|
||||
if request.stream:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
|
||||
tokens = []
|
||||
logprobs = []
|
||||
|
||||
stop_reason = None
|
||||
|
||||
buffer = ""
|
||||
for token_result in self.generator.chat_completion(
|
||||
messages=messages,
|
||||
temperature=request.sampling_params.temperature,
|
||||
top_p=request.sampling_params.top_p,
|
||||
max_gen_len=request.sampling_params.max_tokens,
|
||||
logprobs=request.logprobs,
|
||||
tool_prompt_format=request.tool_prompt_format,
|
||||
):
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={
|
||||
token_result.text: token_result.logprobs[0]
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
message = self.generator.formatter.decode_assistant_message(
|
||||
tokens, stop_reason
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=message,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
async with SEMAPHORE:
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
|
||||
tokens = []
|
||||
logprobs = []
|
||||
stop_reason = None
|
||||
ipython = False
|
||||
|
||||
for token_result in self.generator.chat_completion(
|
||||
|
@ -113,10 +164,9 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
logprobs=request.logprobs,
|
||||
tool_prompt_format=request.tool_prompt_format,
|
||||
):
|
||||
buffer += token_result.text
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if not ipython and buffer.startswith("<|python_tag|>"):
|
||||
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
@ -127,13 +177,6 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
),
|
||||
)
|
||||
)
|
||||
buffer = buffer[len("<|python_tag|>") :]
|
||||
continue
|
||||
|
||||
if not request.stream:
|
||||
if request.logprobs:
|
||||
logprobs.append(token_result.logprob)
|
||||
|
||||
continue
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
|
@ -154,59 +197,61 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
delta = text
|
||||
|
||||
if stop_reason is None:
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={
|
||||
token_result.text: token_result.logprobs[0]
|
||||
}
|
||||
)
|
||||
)
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
# TODO(ashwin): parse tool calls separately here and report errors?
|
||||
# if someone breaks the iteration before coming here we are toast
|
||||
message = self.generator.formatter.decode_assistant_message(
|
||||
tokens, stop_reason
|
||||
)
|
||||
if request.stream:
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO(ashwin): what else do we need to send out here when everything finishes?
|
||||
else:
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=message,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -13,15 +13,15 @@ from typing import Optional
|
|||
import torch
|
||||
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from llama_models.llama3.api.model import Transformer, TransformerBlock
|
||||
|
||||
from llama_models.datatypes import CheckpointQuantizationFormat
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
from termcolor import cprint
|
||||
from torch import Tensor
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from llama_stack.apis.inference.config import (
|
||||
CheckpointQuantizationFormat,
|
||||
from llama_stack.providers.impls.meta_reference.inference.config import (
|
||||
MetaReferenceImplConfig,
|
||||
)
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
@ -14,7 +13,6 @@ import numpy as np
|
|||
from numpy.typing import NDArray
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
|
@ -63,7 +61,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class FaissMemoryImpl(Memory, RoutableProvider):
|
||||
class FaissMemoryImpl(Memory):
|
||||
def __init__(self, config: FaissImplConfig) -> None:
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
|
@ -72,37 +70,18 @@ class FaissMemoryImpl(Memory, RoutableProvider):
|
|||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
|
||||
pass
|
||||
|
||||
async def create_memory_bank(
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
assert url is None, "URL is not supported for this implementation"
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
config.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {config.type}"
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
|
||||
)
|
||||
index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
|
||||
self.cache[bank_id] = index
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
index = self.cache.get(bank_id)
|
||||
if index is None:
|
||||
return None
|
||||
return index.bank
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
|
|
@ -44,7 +44,6 @@ def message_content_as_str(message: Message) -> str:
|
|||
return interleaved_text_media_as_str(message.content)
|
||||
|
||||
|
||||
# For shields that operate on simple strings
|
||||
class TextShield(ShieldBase):
|
||||
def convert_messages_to_text(self, messages: List[Message]) -> str:
|
||||
return "\n".join([message_content_as_str(m) for m in messages])
|
||||
|
@ -56,9 +55,3 @@ class TextShield(ShieldBase):
|
|||
@abstractmethod
|
||||
async def run_impl(self, text: str) -> ShieldResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DummyShield(TextShield):
|
||||
async def run_impl(self, text: str) -> ShieldResponse:
|
||||
# Dummy return LOW to test e2e
|
||||
return ShieldResponse(is_violation=False)
|
|
@ -9,23 +9,19 @@ from typing import List, Optional
|
|||
|
||||
from llama_models.sku_list import CoreModelId, safety_models
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class MetaReferenceShieldType(Enum):
|
||||
llama_guard = "llama_guard"
|
||||
code_scanner_guard = "code_scanner_guard"
|
||||
injection_shield = "injection_shield"
|
||||
jailbreak_shield = "jailbreak_shield"
|
||||
class PromptGuardType(Enum):
|
||||
injection = "injection"
|
||||
jailbreak = "jailbreak"
|
||||
|
||||
|
||||
class LlamaGuardShieldConfig(BaseModel):
|
||||
model: str = "Llama-Guard-3-1B"
|
||||
excluded_categories: List[str] = []
|
||||
disable_input_check: bool = False
|
||||
disable_output_check: bool = False
|
||||
|
||||
@validator("model")
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = [
|
||||
|
@ -47,10 +43,6 @@ class LlamaGuardShieldConfig(BaseModel):
|
|||
return model
|
||||
|
||||
|
||||
class PromptGuardShieldConfig(BaseModel):
|
||||
model: str = "Prompt-Guard-86M"
|
||||
|
||||
|
||||
class SafetyConfig(BaseModel):
|
||||
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
|
||||
prompt_guard_shield: Optional[PromptGuardShieldConfig] = None
|
||||
enable_prompt_guard: Optional[bool] = False
|
||||
|
|
|
@ -113,8 +113,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
super().__init__(on_violation_action)
|
||||
|
@ -132,8 +130,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
self.model = model
|
||||
self.inference_api = inference_api
|
||||
self.excluded_categories = excluded_categories
|
||||
self.disable_input_check = disable_input_check
|
||||
self.disable_output_check = disable_output_check
|
||||
|
||||
def check_unsafe_response(self, response: str) -> Optional[str]:
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
|
@ -180,12 +176,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
messages = self.validate_messages(messages)
|
||||
if self.disable_input_check and messages[-1].role == Role.user.value:
|
||||
return ShieldResponse(is_violation=False)
|
||||
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
|
||||
return ShieldResponse(
|
||||
is_violation=False,
|
||||
)
|
||||
|
||||
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
|
||||
shield_input_message = self.build_vision_shield_input(messages)
|
|
@ -6,56 +6,43 @@
|
|||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import Api, RoutableProvider
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceShieldType, SafetyConfig
|
||||
|
||||
from .shields import (
|
||||
CodeScannerShield,
|
||||
InjectionShield,
|
||||
JailbreakShield,
|
||||
LlamaGuardShield,
|
||||
PromptGuardShield,
|
||||
ShieldBase,
|
||||
)
|
||||
from .base import OnViolationAction, ShieldBase
|
||||
from .config import SafetyConfig
|
||||
from .llama_guard import LlamaGuardShield
|
||||
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
|
||||
|
||||
|
||||
def resolve_and_get_path(model_name: str) -> str:
|
||||
model = resolve_model(model_name)
|
||||
assert model is not None, f"Could not resolve model {model_name}"
|
||||
model_dir = model_local_dir(model.descriptor())
|
||||
return model_dir
|
||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||
|
||||
|
||||
class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
||||
class MetaReferenceSafetyImpl(Safety):
|
||||
def __init__(self, config: SafetyConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
|
||||
self.available_shields = []
|
||||
if config.llama_guard_shield:
|
||||
self.available_shields.append(ShieldType.llama_guard.value)
|
||||
if config.enable_prompt_guard:
|
||||
self.available_shields.append(ShieldType.prompt_guard.value)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
shield_cfg = self.config.prompt_guard_shield
|
||||
if shield_cfg is not None:
|
||||
model_dir = resolve_and_get_path(shield_cfg.model)
|
||||
if self.config.enable_prompt_guard:
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
_ = PromptGuardShield.instance(model_dir)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
available_shields = [v.value for v in MetaReferenceShieldType]
|
||||
for key in routing_keys:
|
||||
if key not in available_shields:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
if shield.type not in self.available_shields:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
@ -63,10 +50,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
|||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
available_shields = [v.value for v in MetaReferenceShieldType]
|
||||
assert shield_type in available_shields, f"Unknown shield {shield_type}"
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
|
||||
shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))
|
||||
shield = self.get_shield_impl(shield_def)
|
||||
|
||||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
|
@ -92,34 +80,22 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
|||
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||
cfg = self.config
|
||||
if typ == MetaReferenceShieldType.llama_guard:
|
||||
cfg = cfg.llama_guard_shield
|
||||
assert (
|
||||
cfg is not None
|
||||
), "Cannot use LlamaGuardShield since not present in config"
|
||||
|
||||
def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
|
||||
if shield.type == ShieldType.llama_guard.value:
|
||||
cfg = self.config.llama_guard_shield
|
||||
return LlamaGuardShield(
|
||||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
disable_input_check=cfg.disable_input_check,
|
||||
disable_output_check=cfg.disable_output_check,
|
||||
)
|
||||
elif typ == MetaReferenceShieldType.jailbreak_shield:
|
||||
assert (
|
||||
cfg.prompt_guard_shield is not None
|
||||
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
|
||||
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
|
||||
return JailbreakShield.instance(model_dir)
|
||||
elif typ == MetaReferenceShieldType.injection_shield:
|
||||
assert (
|
||||
cfg.prompt_guard_shield is not None
|
||||
), "Cannot use PromptGuardShield since not present in config"
|
||||
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
|
||||
return InjectionShield.instance(model_dir)
|
||||
elif typ == MetaReferenceShieldType.code_scanner_guard:
|
||||
return CodeScannerShield.instance()
|
||||
elif shield.type == ShieldType.prompt_guard.value:
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
subtype = shield.params.get("prompt_guard_type", "injection")
|
||||
if subtype == "injection":
|
||||
return InjectionShield.instance(model_dir)
|
||||
elif subtype == "jailbreak":
|
||||
return JailbreakShield.instance(model_dir)
|
||||
else:
|
||||
raise ValueError(f"Unknown prompt guard type: {subtype}")
|
||||
else:
|
||||
raise ValueError(f"Unknown shield type: {typ}")
|
||||
raise ValueError(f"Unknown shield type: {shield.type}")
|
||||
|
|
|
@ -1,33 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# supress warnings and spew of logs from hugging face
|
||||
import transformers
|
||||
|
||||
from .base import ( # noqa: F401
|
||||
DummyShield,
|
||||
OnViolationAction,
|
||||
ShieldBase,
|
||||
ShieldResponse,
|
||||
TextShield,
|
||||
)
|
||||
from .code_scanner import CodeScannerShield # noqa: F401
|
||||
from .llama_guard import LlamaGuardShield # noqa: F401
|
||||
from .prompt_guard import ( # noqa: F401
|
||||
InjectionShield,
|
||||
JailbreakShield,
|
||||
PromptGuardShield,
|
||||
)
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
import os
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from .base import ShieldResponse, TextShield
|
||||
|
||||
|
||||
class CodeScannerShield(TextShield):
|
||||
async def run_impl(self, text: str) -> ShieldResponse:
|
||||
from codeshield.cs import CodeShield
|
||||
|
||||
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
|
||||
result = await CodeShield.scan_code(text)
|
||||
if result.is_insecure:
|
||||
return ShieldResponse(
|
||||
is_violation=True,
|
||||
violation_type=",".join(
|
||||
[issue.pattern_id for issue in result.issues_found]
|
||||
),
|
||||
violation_return_message="Sorry, I found security concerns in the code.",
|
||||
)
|
||||
else:
|
||||
return ShieldResponse(is_violation=False)
|
11
llama_stack/providers/impls/vllm/__init__.py
Normal file
11
llama_stack/providers/impls/vllm/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from typing import Any
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
|
||||
from .vllm import VLLMInferenceImpl
|
||||
|
||||
impl = VLLMInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
35
llama_stack/providers/impls/vllm/config.py
Normal file
35
llama_stack/providers/impls/vllm/config.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMConfig(BaseModel):
|
||||
"""Configuration for the vLLM inference provider."""
|
||||
|
||||
model: str = Field(
|
||||
default="Llama3.1-8B-Instruct",
|
||||
description="Model descriptor from `llama model list`",
|
||||
)
|
||||
tensor_parallel_size: int = Field(
|
||||
default=1,
|
||||
description="Number of tensor parallel replicas (number of GPUs to use).",
|
||||
)
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = supported_inference_models()
|
||||
if model not in permitted_models:
|
||||
model_list = "\n\t".join(permitted_models)
|
||||
raise ValueError(
|
||||
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
|
||||
)
|
||||
return model
|
241
llama_stack/providers/impls/vllm/vllm.py
Normal file
241
llama_stack/providers/impls/vllm/vllm.py
Normal file
|
@ -0,0 +1,241 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _random_uuid() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
|
||||
"""Convert sampling params to vLLM sampling params."""
|
||||
if sampling_params is None:
|
||||
return SamplingParams()
|
||||
|
||||
# TODO convert what I saw in my first test ... but surely there's more to do here
|
||||
kwargs = {
|
||||
"temperature": sampling_params.temperature,
|
||||
}
|
||||
if sampling_params.top_k >= 1:
|
||||
kwargs["top_k"] = sampling_params.top_k
|
||||
if sampling_params.top_p:
|
||||
kwargs["top_p"] = sampling_params.top_p
|
||||
if sampling_params.max_tokens >= 1:
|
||||
kwargs["max_tokens"] = sampling_params.max_tokens
|
||||
if sampling_params.repetition_penalty > 0:
|
||||
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
|
||||
|
||||
return SamplingParams(**kwargs)
|
||||
|
||||
|
||||
class VLLMInferenceImpl(ModelRegistryHelper, Inference):
|
||||
"""Inference implementation for vLLM."""
|
||||
|
||||
HF_MODEL_MAPPINGS = {
|
||||
# TODO: seems like we should be able to build this table dynamically ...
|
||||
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
|
||||
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
|
||||
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
|
||||
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
|
||||
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
|
||||
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
|
||||
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
|
||||
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
|
||||
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
|
||||
"Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
|
||||
"Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
|
||||
"Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
|
||||
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
|
||||
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
|
||||
"Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
|
||||
"Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
|
||||
"Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
|
||||
"Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
|
||||
"Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
|
||||
}
|
||||
|
||||
def __init__(self, config: VLLMConfig):
|
||||
Inference.__init__(self)
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
stack_to_provider_models_map=self.HF_MODEL_MAPPINGS,
|
||||
)
|
||||
self.config = config
|
||||
self.engine = None
|
||||
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the vLLM inference adapter."""
|
||||
|
||||
log.info("Initializing vLLM inference adapter")
|
||||
|
||||
# Disable usage stats reporting. This would be a surprising thing for most
|
||||
# people to find out was on by default.
|
||||
# https://docs.vllm.ai/en/latest/serving/usage_stats.html
|
||||
if "VLLM_NO_USAGE_STATS" not in os.environ:
|
||||
os.environ["VLLM_NO_USAGE_STATS"] = "1"
|
||||
|
||||
hf_model = self.HF_MODEL_MAPPINGS.get(self.config.model)
|
||||
|
||||
# TODO -- there are a ton of options supported here ...
|
||||
engine_args = AsyncEngineArgs()
|
||||
engine_args.model = hf_model
|
||||
# We will need a new config item for this in the future if model support is more broad
|
||||
# than it is today (llama only)
|
||||
engine_args.tokenizer = hf_model
|
||||
engine_args.tensor_parallel_size = self.config.tensor_parallel_size
|
||||
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown the vLLM inference adapter."""
|
||||
log.info("Shutting down vLLM inference adapter")
|
||||
if self.engine:
|
||||
self.engine.shutdown_background_loop()
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Any | None = ...,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | CompletionResponseStreamChunk:
|
||||
log.info("vLLM completion")
|
||||
messages = [UserMessage(content=content)]
|
||||
return self.chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[Message],
|
||||
sampling_params: Any | None = ...,
|
||||
tools: list[ToolDefinition] | None = ...,
|
||||
tool_choice: ToolChoice | None = ...,
|
||||
tool_prompt_format: ToolPromptFormat | None = ...,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
log.info("vLLM chat completion")
|
||||
|
||||
assert self.engine is not None
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
log.info("Sampling params: %s", sampling_params)
|
||||
request_id = _random_uuid()
|
||||
|
||||
prompt = chat_completion_request_to_prompt(request, self.formatter)
|
||||
vllm_sampling_params = _vllm_sampling_params(request.sampling_params)
|
||||
results_generator = self.engine.generate(
|
||||
prompt, vllm_sampling_params, request_id
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, results_generator)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request, results_generator)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
|
||||
) -> ChatCompletionResponse:
|
||||
outputs = [o async for o in results_generator]
|
||||
final_output = outputs[-1]
|
||||
|
||||
assert final_output is not None
|
||||
outputs = final_output.outputs
|
||||
finish_reason = outputs[-1].stop_reason
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=finish_reason,
|
||||
text="".join([output.text for output in outputs]),
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(request, response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
|
||||
) -> AsyncGenerator:
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
async for chunk in results_generator:
|
||||
if not chunk.outputs:
|
||||
log.warning("Empty chunk received")
|
||||
continue
|
||||
|
||||
text = "".join([output.text for output in chunk.outputs])
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk.outputs[-1].stop_reason,
|
||||
text=text,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self, model: str, contents: list[InterleavedTextMedia]
|
||||
) -> EmbeddingsResponse:
|
||||
log.info("vLLM embeddings")
|
||||
# TODO
|
||||
raise NotImplementedError()
|
|
@ -41,6 +41,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="ollama",
|
||||
pip_packages=["ollama"],
|
||||
config_class="llama_stack.providers.adapters.inference.ollama.OllamaImplConfig",
|
||||
module="llama_stack.providers.adapters.inference.ollama",
|
||||
),
|
||||
),
|
||||
|
@ -103,4 +104,24 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="databricks",
|
||||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
module="llama_stack.providers.adapters.inference.databricks",
|
||||
config_class="llama_stack.providers.adapters.inference.databricks.DatabricksImplConfig",
|
||||
),
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="vllm",
|
||||
pip_packages=[
|
||||
"vllm",
|
||||
],
|
||||
module="llama_stack.providers.impls.vllm",
|
||||
config_class="llama_stack.providers.impls.vllm.VLLMConfig",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -56,6 +56,16 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
AdapterSpec(
|
||||
adapter_type="weaviate",
|
||||
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
|
||||
module="llama_stack.providers.adapters.memory.weaviate",
|
||||
config_class="llama_stack.providers.adapters.memory.weaviate.WeaviateConfig",
|
||||
provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.memory,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
@ -21,7 +21,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.safety,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=[
|
||||
"codeshield",
|
||||
"transformers",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
],
|
||||
|
@ -61,4 +60,14 @@ def available_providers() -> List[ProviderSpec]:
|
|||
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
|
||||
),
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="meta-reference/codeshield",
|
||||
pip_packages=[
|
||||
"codeshield",
|
||||
],
|
||||
module="llama_stack.providers.impls.meta_reference.codeshield",
|
||||
config_class="llama_stack.providers.impls.meta_reference.codeshield.CodeShieldConfig",
|
||||
api_dependencies=[],
|
||||
),
|
||||
]
|
||||
|
|
5
llama_stack/providers/tests/__init__.py
Normal file
5
llama_stack/providers/tests/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
5
llama_stack/providers/tests/agents/__init__.py
Normal file
5
llama_stack/providers/tests/agents/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
5
llama_stack/providers/tests/inference/__init__.py
Normal file
5
llama_stack/providers/tests/inference/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,25 @@
|
|||
providers:
|
||||
- provider_id: test-ollama
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 11434
|
||||
- provider_id: test-tgi
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: http://localhost:7001
|
||||
- provider_id: test-remote
|
||||
provider_type: remote
|
||||
config:
|
||||
host: localhost
|
||||
port: 7002
|
||||
- provider_id: test-together
|
||||
provider_type: remote::together
|
||||
config: {}
|
||||
# if a provider needs private keys from the client, they use the
|
||||
# "get_request_provider_data" function (see distribution/request_headers.py)
|
||||
# this is a place to provide such data.
|
||||
provider_data:
|
||||
"test-together":
|
||||
together_api_key:
|
||||
0xdeadbeefputrealapikeyhere
|
243
llama_stack/providers/tests/inference/test_inference.py
Normal file
243
llama_stack/providers/tests/inference/test_inference.py
Normal file
|
@ -0,0 +1,243 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
|
||||
# since it depends on the provider you are testing. On top of that you need
|
||||
# `pytest` and `pytest-asyncio` installed.
|
||||
#
|
||||
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
||||
#
|
||||
# 3. Run:
|
||||
#
|
||||
# ```bash
|
||||
# PROVIDER_ID=<your_provider> \
|
||||
# PROVIDER_CONFIG=provider_config.yaml \
|
||||
# pytest -s llama_stack/providers/tests/memory/test_inference.py \
|
||||
# --tb=short --disable-warnings
|
||||
# ```
|
||||
|
||||
|
||||
def group_chunks(response):
|
||||
return {
|
||||
event_type: list(group)
|
||||
for event_type, group in itertools.groupby(
|
||||
response, key=lambda chunk: chunk.event.event_type
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
Llama_8B = "Llama3.1-8B-Instruct"
|
||||
Llama_3B = "Llama3.2-3B-Instruct"
|
||||
|
||||
|
||||
def get_expected_stop_reason(model: str):
|
||||
return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
|
||||
|
||||
|
||||
# This is going to create multiple Stack impls without tearing down the previous one
|
||||
# Fix that!
|
||||
@pytest_asyncio.fixture(
|
||||
scope="session",
|
||||
params=[
|
||||
{"model": Llama_8B},
|
||||
{"model": Llama_3B},
|
||||
],
|
||||
ids=lambda d: d["model"],
|
||||
)
|
||||
async def inference_settings(request):
|
||||
model = request.param["model"]
|
||||
impls = await resolve_impls_for_test(
|
||||
Api.inference,
|
||||
models=[
|
||||
ModelDef(
|
||||
identifier=model,
|
||||
llama_model=model,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return {
|
||||
"impl": impls[Api.inference],
|
||||
"common_params": {
|
||||
"model": model,
|
||||
"tool_choice": ToolChoice.auto,
|
||||
"tool_prompt_format": (
|
||||
ToolPromptFormat.json
|
||||
if "Llama3.1" in model
|
||||
else ToolPromptFormat.python_list
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_messages():
|
||||
return [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(content="What's the weather like today?"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tool_definition():
|
||||
return ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get the current weather",
|
||||
parameters={
|
||||
"location": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The city and state, e.g. San Francisco, CA",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
|
||||
inference_impl = inference_settings["impl"]
|
||||
response = await inference_impl.chat_completion(
|
||||
messages=sample_messages,
|
||||
stream=False,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
assert response.completion_message.role == "assistant"
|
||||
assert isinstance(response.completion_message.content, str)
|
||||
assert len(response.completion_message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_streaming(inference_settings, sample_messages):
|
||||
inference_impl = inference_settings["impl"]
|
||||
response = [
|
||||
r
|
||||
async for r in inference_impl.chat_completion(
|
||||
messages=sample_messages,
|
||||
stream=True,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
||||
)
|
||||
grouped = group_chunks(response)
|
||||
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
||||
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
||||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||
|
||||
end = grouped[ChatCompletionResponseEventType.complete][0]
|
||||
assert end.event.stop_reason == StopReason.end_of_turn
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tool_calling(
|
||||
inference_settings,
|
||||
sample_messages,
|
||||
sample_tool_definition,
|
||||
):
|
||||
inference_impl = inference_settings["impl"]
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
)
|
||||
]
|
||||
|
||||
response = await inference_impl.chat_completion(
|
||||
messages=messages,
|
||||
tools=[sample_tool_definition],
|
||||
stream=False,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
|
||||
message = response.completion_message
|
||||
|
||||
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
||||
# stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
|
||||
# assert message.stop_reason == stop_reason
|
||||
assert message.tool_calls is not None
|
||||
assert len(message.tool_calls) > 0
|
||||
|
||||
call = message.tool_calls[0]
|
||||
assert call.tool_name == "get_weather"
|
||||
assert "location" in call.arguments
|
||||
assert "San Francisco" in call.arguments["location"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tool_calling_streaming(
|
||||
inference_settings,
|
||||
sample_messages,
|
||||
sample_tool_definition,
|
||||
):
|
||||
inference_impl = inference_settings["impl"]
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
)
|
||||
]
|
||||
|
||||
response = [
|
||||
r
|
||||
async for r in inference_impl.chat_completion(
|
||||
messages=messages,
|
||||
tools=[sample_tool_definition],
|
||||
stream=True,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
||||
)
|
||||
grouped = group_chunks(response)
|
||||
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
||||
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
||||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||
|
||||
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
||||
# expected_stop_reason = get_expected_stop_reason(
|
||||
# inference_settings["common_params"]["model"]
|
||||
# )
|
||||
# end = grouped[ChatCompletionResponseEventType.complete][0]
|
||||
# assert end.event.stop_reason == expected_stop_reason
|
||||
|
||||
model = inference_settings["common_params"]["model"]
|
||||
if "Llama3.1" in model:
|
||||
assert all(
|
||||
isinstance(chunk.event.delta, ToolCallDelta)
|
||||
for chunk in grouped[ChatCompletionResponseEventType.progress]
|
||||
)
|
||||
first = grouped[ChatCompletionResponseEventType.progress][0]
|
||||
assert first.event.delta.parse_status == ToolCallParseStatus.started
|
||||
|
||||
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
||||
# assert last.event.stop_reason == expected_stop_reason
|
||||
assert last.event.delta.parse_status == ToolCallParseStatus.success
|
||||
assert isinstance(last.event.delta.content, ToolCall)
|
||||
|
||||
call = last.event.delta.content
|
||||
assert call.tool_name == "get_weather"
|
||||
assert "location" in call.arguments
|
||||
assert "San Francisco" in call.arguments["location"]
|
|
@ -8,7 +8,7 @@ import unittest
|
|||
|
||||
from llama_models.llama3.api import * # noqa: F403
|
||||
from llama_stack.inference.api import * # noqa: F403
|
||||
from llama_stack.inference.augment_messages import augment_messages_for_tools
|
||||
from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages
|
||||
|
||||
MODEL = "Llama3.1-8B-Instruct"
|
||||
|
||||
|
@ -22,7 +22,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
|||
UserMessage(content=content),
|
||||
],
|
||||
)
|
||||
messages = augment_messages_for_tools(request)
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
self.assertEqual(len(messages), 2)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
|
||||
|
@ -39,7 +39,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
|||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
],
|
||||
)
|
||||
messages = augment_messages_for_tools(request)
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
self.assertEqual(len(messages), 2)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
|
||||
|
@ -67,7 +67,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
|||
],
|
||||
tool_prompt_format=ToolPromptFormat.json,
|
||||
)
|
||||
messages = augment_messages_for_tools(request)
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
self.assertEqual(len(messages), 3)
|
||||
self.assertTrue("Environment: ipython" in messages[0].content)
|
||||
|
||||
|
@ -97,7 +97,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
|||
),
|
||||
],
|
||||
)
|
||||
messages = augment_messages_for_tools(request)
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
self.assertEqual(len(messages), 3)
|
||||
|
||||
self.assertTrue("Environment: ipython" in messages[0].content)
|
||||
|
@ -119,7 +119,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
|||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
)
|
||||
messages = augment_messages_for_tools(request)
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertTrue(messages[0].content.endswith(system_prompt))
|
||||
|
5
llama_stack/providers/tests/memory/__init__.py
Normal file
5
llama_stack/providers/tests/memory/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,24 @@
|
|||
providers:
|
||||
- provider_id: test-faiss
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
- provider_id: test-chroma
|
||||
provider_type: remote::chroma
|
||||
config:
|
||||
host: localhost
|
||||
port: 6001
|
||||
- provider_id: test-remote
|
||||
provider_type: remote
|
||||
config:
|
||||
host: localhost
|
||||
port: 7002
|
||||
- provider_id: test-weaviate
|
||||
provider_type: remote::weaviate
|
||||
config: {}
|
||||
# if a provider needs private keys from the client, they use the
|
||||
# "get_request_provider_data" function (see distribution/request_headers.py)
|
||||
# this is a place to provide such data.
|
||||
provider_data:
|
||||
"test-weaviate":
|
||||
weaviate_api_key: 0xdeadbeefputrealapikeyhere
|
||||
weaviate_cluster_url: http://foobarbaz
|
119
llama_stack/providers/tests/memory/test_memory.py
Normal file
119
llama_stack/providers/tests/memory/test_memory.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
|
||||
# since it depends on the provider you are testing. On top of that you need
|
||||
# `pytest` and `pytest-asyncio` installed.
|
||||
#
|
||||
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
||||
#
|
||||
# 3. Run:
|
||||
#
|
||||
# ```bash
|
||||
# PROVIDER_ID=<your_provider> \
|
||||
# PROVIDER_CONFIG=provider_config.yaml \
|
||||
# pytest -s llama_stack/providers/tests/memory/test_memory.py \
|
||||
# --tb=short --disable-warnings
|
||||
# ```
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def memory_impl():
|
||||
impls = await resolve_impls_for_test(
|
||||
Api.memory,
|
||||
memory_banks=[],
|
||||
)
|
||||
return impls[Api.memory]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_documents():
|
||||
return [
|
||||
MemoryBankDocument(
|
||||
document_id="doc1",
|
||||
content="Python is a high-level programming language.",
|
||||
metadata={"category": "programming", "difficulty": "beginner"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="doc2",
|
||||
content="Machine learning is a subset of artificial intelligence.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="doc3",
|
||||
content="Data structures are fundamental to computer science.",
|
||||
metadata={"category": "computer science", "difficulty": "intermediate"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="doc4",
|
||||
content="Neural networks are inspired by biological neural networks.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def register_memory_bank(memory_impl: Memory):
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
await memory_impl.register_memory_bank(bank)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(memory_impl, sample_documents):
|
||||
with pytest.raises(ValueError):
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
|
||||
await register_memory_bank(memory_impl)
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
|
||||
query1 = "programming language"
|
||||
response1 = await memory_impl.query_documents("test_bank", query1)
|
||||
assert_valid_response(response1)
|
||||
assert any("Python" in chunk.content for chunk in response1.chunks)
|
||||
|
||||
# Test case 3: Query with semantic similarity
|
||||
query3 = "AI and brain-inspired computing"
|
||||
response3 = await memory_impl.query_documents("test_bank", query3)
|
||||
assert_valid_response(response3)
|
||||
assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)
|
||||
|
||||
# Test case 4: Query with limit on number of results
|
||||
query4 = "computer"
|
||||
params4 = {"max_chunks": 2}
|
||||
response4 = await memory_impl.query_documents("test_bank", query4, params4)
|
||||
assert_valid_response(response4)
|
||||
assert len(response4.chunks) <= 2
|
||||
|
||||
# Test case 5: Query with threshold on similarity score
|
||||
query5 = "quantum computing" # Not directly related to any document
|
||||
params5 = {"score_threshold": 0.5}
|
||||
response5 = await memory_impl.query_documents("test_bank", query5, params5)
|
||||
assert_valid_response(response5)
|
||||
assert all(score >= 0.5 for score in response5.scores)
|
||||
|
||||
|
||||
def assert_valid_response(response: QueryDocumentsResponse):
|
||||
assert isinstance(response, QueryDocumentsResponse)
|
||||
assert len(response.chunks) > 0
|
||||
assert len(response.scores) > 0
|
||||
assert len(response.chunks) == len(response.scores)
|
||||
for chunk in response.chunks:
|
||||
assert isinstance(chunk.content, str)
|
||||
assert chunk.document_id is not None
|
100
llama_stack/providers/tests/resolver.py
Normal file
100
llama_stack/providers/tests/resolver.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import yaml
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import resolve_impls_with_routing
|
||||
|
||||
|
||||
async def resolve_impls_for_test(
|
||||
api: Api,
|
||||
models: List[ModelDef] = None,
|
||||
memory_banks: List[MemoryBankDef] = None,
|
||||
shields: List[ShieldDef] = None,
|
||||
):
|
||||
if "PROVIDER_CONFIG" not in os.environ:
|
||||
raise ValueError(
|
||||
"You must set PROVIDER_CONFIG to a YAML file containing provider config"
|
||||
)
|
||||
|
||||
with open(os.environ["PROVIDER_CONFIG"], "r") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
if "providers" not in config_dict:
|
||||
raise ValueError("Config file should contain a `providers` key")
|
||||
|
||||
providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
|
||||
if len(providers_by_id) == 0:
|
||||
raise ValueError("No providers found in config file")
|
||||
|
||||
if "PROVIDER_ID" in os.environ:
|
||||
provider_id = os.environ["PROVIDER_ID"]
|
||||
if provider_id not in providers_by_id:
|
||||
raise ValueError(f"Provider ID {provider_id} not found in config file")
|
||||
provider = providers_by_id[provider_id]
|
||||
else:
|
||||
provider = list(providers_by_id.values())[0]
|
||||
provider_id = provider["provider_id"]
|
||||
print(f"No provider ID specified, picking first `{provider_id}`")
|
||||
|
||||
models = models or []
|
||||
shields = shields or []
|
||||
memory_banks = memory_banks or []
|
||||
|
||||
models = [
|
||||
ModelDef(
|
||||
**{
|
||||
**m.dict(),
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
for m in models
|
||||
]
|
||||
shields = [
|
||||
ShieldDef(
|
||||
**{
|
||||
**s.dict(),
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
for s in shields
|
||||
]
|
||||
memory_banks = [
|
||||
MemoryBankDef(
|
||||
**{
|
||||
**m.dict(),
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
for m in memory_banks
|
||||
]
|
||||
run_config = dict(
|
||||
built_at=datetime.now(),
|
||||
image_name="test-fixture",
|
||||
apis=[api],
|
||||
providers={api.value: [Provider(**provider)]},
|
||||
models=models,
|
||||
memory_banks=memory_banks,
|
||||
shields=shields,
|
||||
)
|
||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||
impls = await resolve_impls_with_routing(run_config)
|
||||
|
||||
if "provider_data" in config_dict:
|
||||
provider_data = config_dict["provider_data"].get(provider_id, {})
|
||||
if provider_data:
|
||||
set_request_provider_data(
|
||||
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
|
||||
)
|
||||
|
||||
return impls
|
35
llama_stack/providers/utils/inference/model_registry.py
Normal file
35
llama_stack/providers/utils/inference/model_registry.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
|
||||
class ModelRegistryHelper:
|
||||
|
||||
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
||||
self.stack_to_provider_models_map = stack_to_provider_models_map
|
||||
|
||||
def map_to_provider_model(self, identifier: str) -> str:
|
||||
model = resolve_model(identifier)
|
||||
if not model:
|
||||
raise ValueError(f"Unknown model: `{identifier}`")
|
||||
|
||||
if identifier not in self.stack_to_provider_models_map:
|
||||
raise ValueError(
|
||||
f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
|
||||
)
|
||||
|
||||
return self.stack_to_provider_models_map[identifier]
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
if model.identifier not in self.stack_to_provider_models_map:
|
||||
raise ValueError(
|
||||
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
|
||||
)
|
189
llama_stack/providers/utils/inference/openai_compat.py
Normal file
189
llama_stack/providers/utils/inference/openai_compat.py
Normal file
|
@ -0,0 +1,189 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import StopReason
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OpenAICompatCompletionChoiceDelta(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class OpenAICompatCompletionChoice(BaseModel):
|
||||
finish_reason: Optional[str] = None
|
||||
text: Optional[str] = None
|
||||
delta: Optional[OpenAICompatCompletionChoiceDelta] = None
|
||||
|
||||
|
||||
class OpenAICompatCompletionResponse(BaseModel):
|
||||
choices: List[OpenAICompatCompletionChoice]
|
||||
|
||||
|
||||
def get_sampling_options(request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if params := request.sampling_params:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(params, attr):
|
||||
options[attr] = getattr(params, attr)
|
||||
|
||||
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
|
||||
options["repeat_penalty"] = params.repetition_penalty
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def text_from_choice(choice) -> str:
|
||||
if hasattr(choice, "delta") and choice.delta:
|
||||
return choice.delta.content
|
||||
|
||||
return choice.text
|
||||
|
||||
|
||||
def process_chat_completion_response(
|
||||
request: ChatCompletionRequest,
|
||||
response: OpenAICompatCompletionResponse,
|
||||
formatter: ChatFormat,
|
||||
) -> ChatCompletionResponse:
|
||||
choice = response.choices[0]
|
||||
|
||||
stop_reason = None
|
||||
if reason := choice.finish_reason:
|
||||
if reason in ["stop", "eos"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif reason == "eom":
|
||||
stop_reason = StopReason.end_of_message
|
||||
elif reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = formatter.decode_assistant_message_from_content(
|
||||
text_from_choice(choice), stop_reason
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
|
||||
|
||||
async def process_chat_completion_stream_response(
|
||||
request: ChatCompletionRequest,
|
||||
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
|
||||
formatter: ChatFormat,
|
||||
) -> AsyncGenerator:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
if finish_reason:
|
||||
if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif stop_reason is None and finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
text = text_from_choice(choice)
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
if ipython:
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue