rebase on top of registry

This commit is contained in:
Xi Yan 2024-10-08 23:41:03 -07:00
commit 6abef716dd
107 changed files with 4813 additions and 3587 deletions

1
.gitignore vendored
View file

@ -13,3 +13,4 @@ xcuserdata/
Package.resolved
*.pte
*.ipynb_checkpoints*
.idea

View file

@ -1,4 +1,4 @@
# Contributing to Llama-Models
# Contributing to Llama-Stack
We want to make contributing to this project as easy and transparent as
possible.
@ -32,7 +32,7 @@ outlined on that page and do not file a public issue.
* ...
## Tips
* If you are developing with a llama-models repository checked out and need your distribution to reflect changes from there, set `LLAMA_MODELS_DIR` to that dir when running any of the `llama` CLI commands.
* If you are developing with a llama-stack repository checked out and need your distribution to reflect changes from there, set `LLAMA_STACK_DIR` to that dir when running any of the `llama` CLI commands.
## License
By contributing to Llama, you agree that your contributions will be licensed

View file

@ -81,11 +81,24 @@ cd llama-stack
$CONDA_PREFIX/bin/pip install -e .
```
## The Llama CLI
## Documentations
The `llama` CLI makes it easy to work with the Llama Stack set of tools, including installing and running Distributions, downloading models, studying model prompt formats, etc. Please see the [CLI reference](docs/cli_reference.md) for details. Please see the [Getting Started](docs/getting_started.md) guide for running a Llama Stack server.
The `llama` CLI makes it easy to work with the Llama Stack set of tools. Please find the following docs for details.
* [CLI reference](docs/cli_reference.md)
* Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
* [Getting Started](docs/getting_started.md)
* Guide to build and run a Llama Stack server.
* [Contributing](CONTRIBUTING.md)
## Llama Stack Client SDK
| **Language** | **Client SDK** | **Package** |
| :----: | :----: | :----: |
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/)
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) |
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client)
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) |
Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [node](https://github.com/meta-llama/llama-stack-client-node), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.

5
SECURITY.md Normal file
View file

@ -0,0 +1,5 @@
# Security Policy
## Reporting a Vulnerability
Please report vulnerabilities to our bug bounty program at https://bugbounty.meta.com/

View file

@ -1,6 +1,6 @@
# Llama CLI Reference
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
The `llama` CLI tool helps you setup and use the Llama Stack & agentic systems. It should be available on your path after installing the `llama-stack` package.
### Subcommands
1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face.
@ -117,9 +117,9 @@ llama download --source meta --model-id Llama-Guard-3-1B --meta-url META_URL
Essentially, the same commands above work, just replace `--source meta` with `--source huggingface`.
```bash
llama download --source huggingface --model-id Meta-Llama3.1-8B-Instruct --hf-token <HF_TOKEN>
llama download --source huggingface --model-id Llama3.1-8B-Instruct --hf-token <HF_TOKEN>
llama download --source huggingface --model-id Meta-Llama3.1-70B-Instruct --hf-token <HF_TOKEN>
llama download --source huggingface --model-id Llama3.1-70B-Instruct --hf-token <HF_TOKEN>
llama download --source huggingface --model-id Llama-Guard-3-1B --ignore-patterns *original*
llama download --source huggingface --model-id Prompt-Guard-86M --ignore-patterns *original*
@ -215,9 +215,8 @@ You can even run `llama model prompt-format` see all of the templates and their
```
llama model prompt-format -m Llama3.2-3B-Instruct
```
<p align="center">
<img width="719" alt="image" src="https://github.com/user-attachments/assets/c5332026-8c0b-4edc-b438-ec60cd7ca554">
</p>
![alt text](resources/prompt-format.png)
You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios.
@ -230,7 +229,7 @@ You will be shown a Markdown formatted description of the model interface and ho
- Please see our [Getting Started](getting_started.md) guide for more details on how to build and start a Llama Stack distribution.
### Step 3.1 Build
In the following steps, imagine we'll be working with a `Meta-Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify:
In the following steps, imagine we'll be working with a `Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify:
- `name`: the name for our distribution (e.g. `8b-instruct`)
- `image_type`: our build image type (`conda | docker`)
- `distribution_spec`: our distribution specs for specifying API providers
@ -365,7 +364,7 @@ llama stack configure [ <name> | <docker-image-name> | <path/to/name.build.yaml>
$ llama stack configure ~/.llama/distributions/conda/8b-instruct-build.yaml
Configuring API: inference (meta-reference)
Enter value for model (existing: Meta-Llama3.1-8B-Instruct) (required):
Enter value for model (existing: Llama3.1-8B-Instruct) (required):
Enter value for quantization (optional):
Enter value for torch_seed (optional):
Enter value for max_seq_len (existing: 4096) (required):
@ -397,7 +396,7 @@ YAML configuration has been written to ~/.llama/builds/conda/8b-instruct-run.yam
After this step is successful, you should be able to find a run configuration spec in `~/.llama/builds/conda/8b-instruct-run.yaml` with the following contents. You may edit this file to change the settings.
As you can see, we did basic configuration above and configured:
- inference to run on model `Meta-Llama3.1-8B-Instruct` (obtained from `llama model list`)
- inference to run on model `Llama3.1-8B-Instruct` (obtained from `llama model list`)
- Llama Guard safety shield with model `Llama-Guard-3-1B`
- Prompt Guard safety shield with model `Prompt-Guard-86M`

View file

@ -1,7 +1,7 @@
# llama-stack
[![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/)
[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/TZAAYNVtrU)
[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/llama-stack)
This repository contains the specifications and implementations of the APIs which are part of the Llama Stack.
@ -66,8 +66,17 @@ This guides allows you to quickly get started with building and running a Llama
You may also checkout this [notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb) for trying out out demo scripts.
## Quick Cheatsheet
- Quick 3 line command to build and start a LlamaStack server using our Meta Reference implementation for all API endpoints with `conda` as build type.
#### Via docker
```
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack-local-gpu
```
> [!NOTE]
> `~/.llama` should be the path containing downloaded weights of Llama models.
#### Via conda
**`llama stack build`**
- You'll be prompted to enter build information interactively.
```

File diff suppressed because it is too large Load diff

View file

@ -580,63 +580,6 @@ components:
- uuid
- dataset
type: object
CreateMemoryBankRequest:
additionalProperties: false
properties:
config:
oneOf:
- additionalProperties: false
properties:
chunk_size_in_tokens:
type: integer
embedding_model:
type: string
overlap_size_in_tokens:
type: integer
type:
const: vector
default: vector
type: string
required:
- type
- embedding_model
- chunk_size_in_tokens
type: object
- additionalProperties: false
properties:
type:
const: keyvalue
default: keyvalue
type: string
required:
- type
type: object
- additionalProperties: false
properties:
type:
const: keyword
default: keyword
type: string
required:
- type
type: object
- additionalProperties: false
properties:
type:
const: graph
default: graph
type: string
required:
- type
type: object
name:
type: string
url:
$ref: '#/components/schemas/URL'
required:
- name
- config
type: object
DPOAlignmentConfig:
additionalProperties: false
properties:
@ -739,14 +682,6 @@ components:
- rank
- alpha
type: object
DropMemoryBankRequest:
additionalProperties: false
properties:
bank_id:
type: string
required:
- bank_id
type: object
EmbeddingsRequest:
additionalProperties: false
properties:
@ -908,6 +843,21 @@ components:
required:
- document_ids
type: object
GraphMemoryBankDef:
additionalProperties: false
properties:
identifier:
type: string
provider_id:
type: string
type:
const: graph
default: graph
type: string
required:
- identifier
- type
type: object
HealthInfo:
additionalProperties: false
properties:
@ -973,6 +923,36 @@ components:
- bank_id
- documents
type: object
KeyValueMemoryBankDef:
additionalProperties: false
properties:
identifier:
type: string
provider_id:
type: string
type:
const: keyvalue
default: keyvalue
type: string
required:
- identifier
- type
type: object
KeywordMemoryBankDef:
additionalProperties: false
properties:
identifier:
type: string
provider_id:
type: string
type:
const: keyword
default: keyword
type: string
required:
- identifier
- type
type: object
LogEventRequest:
additionalProperties: false
properties:
@ -1015,66 +995,6 @@ components:
- rank
- alpha
type: object
MemoryBank:
additionalProperties: false
properties:
bank_id:
type: string
config:
oneOf:
- additionalProperties: false
properties:
chunk_size_in_tokens:
type: integer
embedding_model:
type: string
overlap_size_in_tokens:
type: integer
type:
const: vector
default: vector
type: string
required:
- type
- embedding_model
- chunk_size_in_tokens
type: object
- additionalProperties: false
properties:
type:
const: keyvalue
default: keyvalue
type: string
required:
- type
type: object
- additionalProperties: false
properties:
type:
const: keyword
default: keyword
type: string
required:
- type
type: object
- additionalProperties: false
properties:
type:
const: graph
default: graph
type: string
required:
- type
type: object
name:
type: string
url:
$ref: '#/components/schemas/URL'
required:
- bank_id
- name
- config
type: object
MemoryBankDocument:
additionalProperties: false
properties:
@ -1107,41 +1027,6 @@ components:
- content
- metadata
type: object
MemoryBankSpec:
additionalProperties: false
properties:
bank_type:
$ref: '#/components/schemas/MemoryBankType'
provider_config:
additionalProperties: false
properties:
config:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
provider_type:
type: string
required:
- provider_type
- config
type: object
required:
- bank_type
- provider_config
type: object
MemoryBankType:
enum:
- vector
- keyvalue
- keyword
- graph
type: string
MemoryRetrievalStep:
additionalProperties: false
properties:
@ -1349,36 +1234,18 @@ components:
- value
- unit
type: object
Model:
description: The model family and SKU of the model along with other parameters
corresponding to the model.
ModelServingSpec:
ModelDef:
additionalProperties: false
properties:
identifier:
type: string
llama_model:
$ref: '#/components/schemas/Model'
provider_config:
additionalProperties: false
properties:
config:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
provider_type:
type: string
required:
- provider_type
- config
type: object
type: string
provider_id:
type: string
required:
- identifier
- llama_model
- provider_config
type: object
OptimizerConfig:
additionalProperties: false
@ -1554,13 +1421,13 @@ components:
ProviderInfo:
additionalProperties: false
properties:
description:
provider_id:
type: string
provider_type:
type: string
required:
- provider_id
- provider_type
- description
type: object
QLoraFinetuningConfig:
additionalProperties: false
@ -1650,6 +1517,56 @@ components:
enum:
- dpo
type: string
RegisterMemoryBankRequest:
additionalProperties: false
properties:
memory_bank:
oneOf:
- $ref: '#/components/schemas/VectorMemoryBankDef'
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
- $ref: '#/components/schemas/KeywordMemoryBankDef'
- $ref: '#/components/schemas/GraphMemoryBankDef'
required:
- memory_bank
type: object
RegisterModelRequest:
additionalProperties: false
properties:
model:
$ref: '#/components/schemas/ModelDef'
required:
- model
type: object
RegisterShieldRequest:
additionalProperties: false
properties:
shield:
additionalProperties: false
properties:
identifier:
type: string
params:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
provider_id:
type: string
type:
type: string
required:
- identifier
- type
- params
type: object
required:
- shield
type: object
RestAPIExecutionConfig:
additionalProperties: false
properties:
@ -1728,7 +1645,7 @@ components:
properties:
method:
type: string
providers:
provider_types:
items:
type: string
type: array
@ -1737,7 +1654,7 @@ components:
required:
- route
- method
- providers
- provider_types
type: object
RunShieldRequest:
additionalProperties: false
@ -1892,7 +1809,11 @@ components:
additionalProperties: false
properties:
memory_bank:
$ref: '#/components/schemas/MemoryBank'
oneOf:
- $ref: '#/components/schemas/VectorMemoryBankDef'
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
- $ref: '#/components/schemas/KeywordMemoryBankDef'
- $ref: '#/components/schemas/GraphMemoryBankDef'
session_id:
type: string
session_name:
@ -1935,33 +1856,29 @@ components:
- step_id
- step_type
type: object
ShieldSpec:
ShieldDef:
additionalProperties: false
properties:
provider_config:
additionalProperties: false
properties:
config:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
provider_type:
type: string
required:
- provider_type
- config
identifier:
type: string
params:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
shield_type:
provider_id:
type: string
type:
type: string
required:
- shield_type
- provider_config
- identifier
- type
- params
type: object
SpanEndPayload:
additionalProperties: false
@ -2571,6 +2488,29 @@ components:
- role
- content
type: object
VectorMemoryBankDef:
additionalProperties: false
properties:
chunk_size_in_tokens:
type: integer
embedding_model:
type: string
identifier:
type: string
overlap_size_in_tokens:
type: integer
provider_id:
type: string
type:
const: vector
default: vector
type: string
required:
- identifier
- type
- embedding_model
- chunk_size_in_tokens
type: object
ViolationLevel:
enum:
- info
@ -2604,7 +2544,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
\ draft and subject to change.\n Generated at 2024-10-08 15:18:57.600111"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@ -3226,7 +3166,7 @@ paths:
description: OK
tags:
- Inference
/memory/create:
/inference/register_model:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
@ -3240,17 +3180,13 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/CreateMemoryBankRequest'
$ref: '#/components/schemas/RegisterModelRequest'
required: true
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/MemoryBank'
description: OK
tags:
- Memory
- Models
/memory/documents/delete:
post:
parameters:
@ -3302,57 +3238,6 @@ paths:
description: OK
tags:
- Memory
/memory/drop:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/DropMemoryBankRequest'
required: true
responses:
'200':
content:
application/json:
schema:
type: string
description: OK
tags:
- Memory
/memory/get:
get:
parameters:
- in: query
name: bank_id
required: true
schema:
type: string
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/MemoryBank'
- type: 'null'
description: OK
tags:
- Memory
/memory/insert:
post:
parameters:
@ -3374,25 +3259,6 @@ paths:
description: OK
tags:
- Memory
/memory/list:
get:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/jsonl:
schema:
$ref: '#/components/schemas/MemoryBank'
description: OK
tags:
- Memory
/memory/query:
post:
parameters:
@ -3443,10 +3309,10 @@ paths:
get:
parameters:
- in: query
name: bank_type
name: identifier
required: true
schema:
$ref: '#/components/schemas/MemoryBankType'
type: string
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
@ -3460,7 +3326,11 @@ paths:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/MemoryBankSpec'
- oneOf:
- $ref: '#/components/schemas/VectorMemoryBankDef'
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
- $ref: '#/components/schemas/KeywordMemoryBankDef'
- $ref: '#/components/schemas/GraphMemoryBankDef'
- type: 'null'
description: OK
tags:
@ -3480,15 +3350,40 @@ paths:
content:
application/jsonl:
schema:
$ref: '#/components/schemas/MemoryBankSpec'
oneOf:
- $ref: '#/components/schemas/VectorMemoryBankDef'
- $ref: '#/components/schemas/KeyValueMemoryBankDef'
- $ref: '#/components/schemas/KeywordMemoryBankDef'
- $ref: '#/components/schemas/GraphMemoryBankDef'
description: OK
tags:
- MemoryBanks
/memory_banks/register:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterMemoryBankRequest'
required: true
responses:
'200':
description: OK
tags:
- Memory
/models/get:
get:
parameters:
- in: query
name: core_model_id
name: identifier
required: true
schema:
type: string
@ -3505,7 +3400,7 @@ paths:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/ModelServingSpec'
- $ref: '#/components/schemas/ModelDef'
- type: 'null'
description: OK
tags:
@ -3525,7 +3420,7 @@ paths:
content:
application/jsonl:
schema:
$ref: '#/components/schemas/ModelServingSpec'
$ref: '#/components/schemas/ModelDef'
description: OK
tags:
- Models
@ -3760,6 +3655,27 @@ paths:
description: OK
tags:
- Inspect
/safety/register_shield:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterShieldRequest'
required: true
responses:
'200':
description: OK
tags:
- Shields
/safety/run_shield:
post:
parameters:
@ -3806,7 +3722,29 @@ paths:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/ShieldSpec'
- additionalProperties: false
properties:
identifier:
type: string
params:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
provider_id:
type: string
type:
type: string
required:
- identifier
- type
- params
type: object
- type: 'null'
description: OK
tags:
@ -3826,7 +3764,7 @@ paths:
content:
application/jsonl:
schema:
$ref: '#/components/schemas/ShieldSpec'
$ref: '#/components/schemas/ShieldDef'
description: OK
tags:
- Shields
@ -3905,21 +3843,21 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
- name: Datasets
- name: Inspect
- name: Memory
- name: BatchInference
- name: Agents
- name: Datasets
- name: Inference
- name: Shields
- name: SyntheticDataGeneration
- name: Models
- name: RewardScoring
- name: MemoryBanks
- name: Safety
- name: Evaluations
- name: Telemetry
- name: Memory
- name: Safety
- name: PostTraining
- name: MemoryBanks
- name: Models
- name: Shields
- name: Inspect
- name: SyntheticDataGeneration
- name: Telemetry
- name: Agents
- name: RewardScoring
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
@ -4123,11 +4061,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/CreateDatasetRequest"
/>
name: CreateDatasetRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/CreateMemoryBankRequest"
/>
name: CreateMemoryBankRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBank" />
name: MemoryBank
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteAgentsRequest"
/>
name: DeleteAgentsRequest
@ -4140,9 +4073,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteDocumentsRequest"
/>
name: DeleteDocumentsRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/DropMemoryBankRequest"
/>
name: DropMemoryBankRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/EmbeddingsRequest"
/>
name: EmbeddingsRequest
@ -4163,11 +4093,23 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/GetAgentsSessionRequest"
/>
name: GetAgentsSessionRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/GraphMemoryBankDef"
/>
name: GraphMemoryBankDef
- description: <SchemaDefinition schemaRef="#/components/schemas/KeyValueMemoryBankDef"
/>
name: KeyValueMemoryBankDef
- description: <SchemaDefinition schemaRef="#/components/schemas/KeywordMemoryBankDef"
/>
name: KeywordMemoryBankDef
- description: 'A single session of an interaction with an Agentic System.
<SchemaDefinition schemaRef="#/components/schemas/Session" />'
name: Session
- description: <SchemaDefinition schemaRef="#/components/schemas/VectorMemoryBankDef"
/>
name: VectorMemoryBankDef
- description: <SchemaDefinition schemaRef="#/components/schemas/AgentStepResponse"
/>
name: AgentStepResponse
@ -4189,21 +4131,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/EvaluationJobStatusResponse"
/>
name: EvaluationJobStatusResponse
- description: 'The model family and SKU of the model along with other parameters
corresponding to the model.
<SchemaDefinition schemaRef="#/components/schemas/Model" />'
name: Model
- description: <SchemaDefinition schemaRef="#/components/schemas/ModelServingSpec"
/>
name: ModelServingSpec
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankType" />
name: MemoryBankType
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankSpec" />
name: MemoryBankSpec
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldSpec" />
name: ShieldSpec
- description: <SchemaDefinition schemaRef="#/components/schemas/ModelDef" />
name: ModelDef
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
name: Trace
- description: 'Checkpoint created during training runs
@ -4243,6 +4172,8 @@ tags:
name: ProviderInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/RouteInfo" />
name: RouteInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldDef" />
name: ShieldDef
- description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" />
name: LogSeverity
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
@ -4282,6 +4213,15 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/QueryDocumentsResponse"
/>
name: QueryDocumentsResponse
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterMemoryBankRequest"
/>
name: RegisterMemoryBankRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterModelRequest"
/>
name: RegisterModelRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterShieldRequest"
/>
name: RegisterShieldRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/DialogGenerations"
/>
name: DialogGenerations
@ -4387,7 +4327,6 @@ x-tagGroups:
- CreateAgentSessionRequest
- CreateAgentTurnRequest
- CreateDatasetRequest
- CreateMemoryBankRequest
- DPOAlignmentConfig
- DeleteAgentsRequest
- DeleteAgentsSessionRequest
@ -4395,7 +4334,6 @@ x-tagGroups:
- DeleteDocumentsRequest
- DialogGenerations
- DoraFinetuningConfig
- DropMemoryBankRequest
- EmbeddingsRequest
- EmbeddingsResponse
- EvaluateQuestionAnsweringRequest
@ -4409,22 +4347,21 @@ x-tagGroups:
- FunctionCallToolDefinition
- GetAgentsSessionRequest
- GetDocumentsRequest
- GraphMemoryBankDef
- HealthInfo
- ImageMedia
- InferenceStep
- InsertDocumentsRequest
- KeyValueMemoryBankDef
- KeywordMemoryBankDef
- LogEventRequest
- LogSeverity
- LoraFinetuningConfig
- MemoryBank
- MemoryBankDocument
- MemoryBankSpec
- MemoryBankType
- MemoryRetrievalStep
- MemoryToolDefinition
- MetricEvent
- Model
- ModelServingSpec
- ModelDef
- OptimizerConfig
- PhotogenToolDefinition
- PostTrainingJob
@ -4438,6 +4375,9 @@ x-tagGroups:
- QueryDocumentsRequest
- QueryDocumentsResponse
- RLHFAlgorithm
- RegisterMemoryBankRequest
- RegisterModelRequest
- RegisterShieldRequest
- RestAPIExecutionConfig
- RestAPIMethod
- RewardScoreRequest
@ -4453,7 +4393,7 @@ x-tagGroups:
- SearchToolDefinition
- Session
- ShieldCallStep
- ShieldSpec
- ShieldDef
- SpanEndPayload
- SpanStartPayload
- SpanStatus
@ -4483,5 +4423,6 @@ x-tagGroups:
- UnstructuredLogEvent
- UpdateDocumentsRequest
- UserMessage
- VectorMemoryBankDef
- ViolationLevel
- WolframAlphaToolDefinition

Binary file not shown.

After

Width:  |  Height:  |  Size: 170 KiB

View file

@ -261,7 +261,7 @@ class Session(BaseModel):
turns: List[Turn]
started_at: datetime
memory_bank: Optional[MemoryBank] = None
memory_bank: Optional[MemoryBankDef] = None
class AgentConfigCommon(BaseModel):
@ -411,8 +411,10 @@ class Agents(Protocol):
agent_config: AgentConfig,
) -> AgentCreateResponse: ...
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
@webmethod(route="/agents/turn/create")
async def create_agent_turn(
def create_agent_turn(
self,
agent_id: str,
session_id: str,

View file

@ -7,7 +7,7 @@
import asyncio
import json
import os
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional
import fire
import httpx
@ -67,9 +67,17 @@ class AgentsClient(Agents):
response.raise_for_status()
return AgentSessionCreateResponse(**response.json())
async def create_agent_turn(
def create_agent_turn(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
if request.stream:
return self._stream_agent_turn(request)
else:
return self._nonstream_agent_turn(request)
async def _stream_agent_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
@ -93,6 +101,9 @@ class AgentsClient(Agents):
print(data)
print(f"Error with parsing or validation: {e}")
async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
raise NotImplementedError("Non-streaming not implemented yet")
async def _run_agent(
api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
@ -132,8 +143,7 @@ async def _run_agent(
log.print()
async def run_llama_3_1(host: str, port: int):
model = "Llama3.1-8B-Instruct"
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
tool_definitions = [
@ -173,8 +183,7 @@ async def run_llama_3_1(host: str, port: int):
await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
async def run_llama_3_2_rag(host: str, port: int):
model = "Llama3.2-3B-Instruct"
async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
urls = [
@ -215,8 +224,7 @@ async def run_llama_3_2_rag(host: str, port: int):
)
async def run_llama_3_2(host: str, port: int):
model = "Llama3.2-3B-Instruct"
async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
# zero shot tools for llama3.2 text models
@ -262,7 +270,7 @@ async def run_llama_3_2(host: str, port: int):
)
def main(host: str, port: int, run_type: str):
def main(host: str, port: int, run_type: str, model: Optional[str] = None):
assert run_type in [
"tools_llama_3_1",
"tools_llama_3_2",
@ -274,7 +282,10 @@ def main(host: str, port: int, run_type: str):
"tools_llama_3_2": run_llama_3_2,
"rag_llama_3_2": run_llama_3_2_rag,
}
asyncio.run(fn[run_type](host, port))
args = [host, port]
if model is not None:
args.append(model)
asyncio.run(fn[run_type](*args))
if __name__ == "__main__":

View file

@ -42,10 +42,10 @@ class InferenceClient(Inference):
async def shutdown(self) -> None:
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@ -66,6 +66,29 @@ class InferenceClient(Inference):
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
)
response.raise_for_status()
j = response.json()
return ChatCompletionResponse(**j)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
@ -77,7 +100,8 @@ class InferenceClient(Inference):
if response.status_code != 200:
content = await response.aread()
cprint(
f"Error: HTTP {response.status_code} {content.decode()}", "red"
f"Error: HTTP {response.status_code} {content.decode()}",
"red",
)
return
@ -85,40 +109,59 @@ class InferenceClient(Inference):
if line.startswith("data:"):
data = line[len("data: ") :]
try:
if request.stream:
if "error" in data:
cprint(data, "red")
continue
if "error" in data:
cprint(data, "red")
continue
yield ChatCompletionResponseStreamChunk(
**json.loads(data)
)
else:
yield ChatCompletionResponse(**json.loads(data))
yield ChatCompletionResponseStreamChunk(**json.loads(data))
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
async def run_main(host: str, port: int, stream: bool):
async def run_main(
host: str, port: int, stream: bool, model: Optional[str], logprobs: bool
):
client = InferenceClient(f"http://{host}:{port}")
if not model:
model = "Llama3.1-8B-Instruct"
message = UserMessage(
content="hello world, write me a 2 sentence poem about the moon"
)
cprint(f"User>{message.content}", "green")
if logprobs:
logprobs_config = LogProbConfig(
top_k=1,
)
else:
logprobs_config = None
iterator = client.chat_completion(
model="Llama3.1-8B-Instruct",
model=model,
messages=[message],
stream=stream,
logprobs=logprobs_config,
)
async for log in EventLogger().log(iterator):
log.print()
if logprobs:
async for chunk in iterator:
cprint(f"Response: {chunk}", "red")
else:
async for log in EventLogger().log(iterator):
log.print()
async def run_mm_main(host: str, port: int, stream: bool, path: str):
async def run_mm_main(
host: str, port: int, stream: bool, path: Optional[str], model: Optional[str]
):
client = InferenceClient(f"http://{host}:{port}")
if not model:
model = "Llama3.2-11B-Vision-Instruct"
message = UserMessage(
content=[
ImageMedia(image=URL(uri=f"file://{path}")),
@ -127,7 +170,7 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
)
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
model="Llama3.2-11B-Vision-Instruct",
model=model,
messages=[message],
stream=stream,
)
@ -135,11 +178,19 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
log.print()
def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
def main(
host: str,
port: int,
stream: bool = True,
mm: bool = False,
logprobs: bool = False,
file: Optional[str] = None,
model: Optional[str] = None,
):
if mm:
asyncio.run(run_mm_main(host, port, stream, file))
asyncio.run(run_mm_main(host, port, stream, file, model))
else:
asyncio.run(run_main(host, port, stream))
asyncio.run(run_main(host, port, stream, model, logprobs))
if __name__ == "__main__":

View file

@ -14,6 +14,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
class LogProbConfig(BaseModel):
@ -172,9 +173,17 @@ class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
class ModelStore(Protocol):
def get_model(self, identifier: str) -> ModelDef: ...
class Inference(Protocol):
model_store: ModelStore
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/completion")
async def completion(
def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -183,8 +192,10 @@ class Inference(Protocol):
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/chat_completion")
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@ -203,3 +214,6 @@ class Inference(Protocol):
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ...
@webmethod(route="/inference/register_model")
async def register_model(self, model: ModelDef) -> None: ...

View file

@ -12,15 +12,15 @@ from pydantic import BaseModel
@json_schema_type
class ProviderInfo(BaseModel):
provider_id: str
provider_type: str
description: str
@json_schema_type
class RouteInfo(BaseModel):
route: str
method: str
providers: List[str]
provider_types: List[str]
@json_schema_type

View file

@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional
import fire
import httpx
from termcolor import cprint
from llama_stack.distribution.datatypes import RemoteProviderConfig
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks.client import MemoryBanksClient
from llama_stack.providers.utils.memory.file_utils import data_url_from_file
@ -35,44 +35,16 @@ class MemoryClient(Memory):
async def shutdown(self) -> None:
pass
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
async with httpx.AsyncClient() as client:
r = await client.get(
f"{self.base_url}/memory/get",
params={
"bank_id": bank_id,
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory/create",
response = await client.post(
f"{self.base_url}/memory/register_memory_bank",
json={
"name": name,
"config": config.dict(),
"url": url,
"memory_bank": json.loads(memory_bank.json()),
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
response.raise_for_status()
async def insert_documents(
self,
@ -114,22 +86,20 @@ class MemoryClient(Memory):
async def run_main(host: str, port: int, stream: bool):
client = MemoryClient(f"http://{host}:{port}")
banks_client = MemoryBanksClient(f"http://{host}:{port}")
# create a memory bank
bank = await client.create_memory_bank(
name="test_bank",
config=VectorMemoryBankConfig(
bank_id="test_bank",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
bank = VectorMemoryBankDef(
identifier="test_bank",
provider_id="",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
cprint(json.dumps(bank.dict(), indent=4), "green")
await client.register_memory_bank(bank)
retrieved_bank = await client.get_memory_bank(bank.bank_id)
retrieved_bank = await banks_client.get_memory_bank(bank.identifier)
assert retrieved_bank is not None
assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
assert retrieved_bank.embedding_model == "all-MiniLM-L6-v2"
urls = [
"memory_optimizations.rst",
@ -162,13 +132,13 @@ async def run_main(host: str, port: int, stream: bool):
# insert some documents
await client.insert_documents(
bank_id=bank.bank_id,
bank_id=bank.identifier,
documents=documents,
)
# query the documents
response = await client.query_documents(
bank_id=bank.bank_id,
bank_id=bank.identifier,
query=[
"How do I use Lora?",
],
@ -178,7 +148,7 @@ async def run_main(host: str, port: int, stream: bool):
print(f"Chunk:\n========\n{chunk}\n========\n")
response = await client.query_documents(
bank_id=bank.bank_id,
bank_id=bank.identifier,
query=[
"Tell me more about llama3 and torchtune",
],

View file

@ -13,9 +13,9 @@ from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
@json_schema_type
@ -26,44 +26,6 @@ class MemoryBankDocument(BaseModel):
metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class MemoryBankType(Enum):
vector = "vector"
keyvalue = "keyvalue"
keyword = "keyword"
graph = "graph"
class VectorMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
class KeyValueMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
class KeywordMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class GraphMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class Chunk(BaseModel):
content: InterleavedTextMedia
token_count: int
@ -76,45 +38,12 @@ class QueryDocumentsResponse(BaseModel):
scores: List[float]
@json_schema_type
class QueryAPI(Protocol):
@webmethod(route="/query_documents")
def query_documents(
self,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@json_schema_type
class MemoryBank(BaseModel):
bank_id: str
name: str
config: MemoryBankConfig
# if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
url: Optional[URL] = None
class MemoryBankStore(Protocol):
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
class Memory(Protocol):
@webmethod(route="/memory/create")
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank: ...
@webmethod(route="/memory/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory/get", method="GET")
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory/drop", method="DELETE")
async def drop_memory_bank(
self,
bank_id: str,
) -> str: ...
memory_bank_store: MemoryBankStore
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@ -154,3 +83,6 @@ class Memory(Protocol):
bank_id: str,
document_ids: List[str],
) -> None: ...
@webmethod(route="/memory/register_memory_bank")
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...

View file

@ -6,7 +6,7 @@
import asyncio
from typing import List, Optional
from typing import Any, Dict, List, Optional
import fire
import httpx
@ -15,6 +15,25 @@ from termcolor import cprint
from .memory_banks import * # noqa: F403
def deserialize_memory_bank_def(j: Optional[Dict[str, Any]]) -> MemoryBankDef:
if j is None:
return None
if "type" not in j:
raise ValueError("Memory bank type not specified")
type = j["type"]
if type == MemoryBankType.vector.value:
return VectorMemoryBankDef(**j)
elif type == MemoryBankType.keyvalue.value:
return KeyValueMemoryBankDef(**j)
elif type == MemoryBankType.keyword.value:
return KeywordMemoryBankDef(**j)
elif type == MemoryBankType.graph.value:
return GraphMemoryBankDef(**j)
else:
raise ValueError(f"Unknown memory bank type: {type}")
class MemoryBanksClient(MemoryBanks):
def __init__(self, base_url: str):
self.base_url = base_url
@ -25,37 +44,36 @@ class MemoryBanksClient(MemoryBanks):
async def shutdown(self) -> None:
pass
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
async def list_memory_banks(self) -> List[MemoryBankDef]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/list",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return [MemoryBankSpec(**x) for x in response.json()]
return [deserialize_memory_bank_def(x) for x in response.json()]
async def get_serving_memory_bank(
self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]:
async def get_memory_bank(
self,
identifier: str,
) -> Optional[MemoryBankDef]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/get",
params={
"bank_type": bank_type.value,
"identifier": identifier,
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
j = response.json()
if j is None:
return None
return MemoryBankSpec(**j)
return deserialize_memory_bank_def(j)
async def run_main(host: str, port: int, stream: bool):
client = MemoryBanksClient(f"http://{host}:{port}")
response = await client.list_available_memory_banks()
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")

View file

@ -4,29 +4,67 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from enum import Enum
from typing import List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.memory import MemoryBankType
from llama_stack.distribution.datatypes import GenericProviderConfig
from typing_extensions import Annotated
@json_schema_type
class MemoryBankSpec(BaseModel):
bank_type: MemoryBankType
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_type, and corresponding config. ",
)
class MemoryBankType(Enum):
vector = "vector"
keyvalue = "keyvalue"
keyword = "keyword"
graph = "graph"
class CommonDef(BaseModel):
identifier: str
provider_id: Optional[str] = None
@json_schema_type
class VectorMemoryBankDef(CommonDef):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
@json_schema_type
class KeyValueMemoryBankDef(CommonDef):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
@json_schema_type
class KeywordMemoryBankDef(CommonDef):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
@json_schema_type
class GraphMemoryBankDef(CommonDef):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBankDef = Annotated[
Union[
VectorMemoryBankDef,
KeyValueMemoryBankDef,
KeywordMemoryBankDef,
GraphMemoryBankDef,
],
Field(discriminator="type"),
]
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/list", method="GET")
async def list_available_memory_banks(self) -> List[MemoryBankSpec]: ...
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
@webmethod(route="/memory_banks/get", method="GET")
async def get_serving_memory_bank(
self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]: ...
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
@webmethod(route="/memory_banks/register", method="POST")
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...

View file

@ -56,7 +56,7 @@ async def run_main(host: str, port: int, stream: bool):
response = await client.list_models()
cprint(f"list_models response={response}", "green")
response = await client.get_model("Meta-Llama3.1-8B-Instruct")
response = await client.get_model("Llama3.1-8B-Instruct")
cprint(f"get_model response={response}", "blue")
response = await client.get_model("Llama-Guard-3-1B")

View file

@ -6,27 +6,32 @@
from typing import List, Optional, Protocol
from llama_models.llama3.api.datatypes import Model
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type
class ModelServingSpec(BaseModel):
llama_model: Model = Field(
description="All metadatas associated with llama model (defined in llama_models.models.sku_list).",
class ModelDef(BaseModel):
identifier: str = Field(
description="A unique identifier for the model type",
)
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_type, and corresponding config. ",
llama_model: str = Field(
description="Pointer to the core Llama family model",
)
provider_id: Optional[str] = Field(
default=None, description="The provider instance which serves this model"
)
# For now, we are only supporting core llama models but as soon as finetuned
# and other custom models (for example various quantizations) are allowed, there
# will be more metadata fields here
class Models(Protocol):
@webmethod(route="/models/list", method="GET")
async def list_models(self) -> List[ModelServingSpec]: ...
async def list_models(self) -> List[ModelDef]: ...
@webmethod(route="/models/get", method="GET")
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
@webmethod(route="/models/register", method="POST")
async def register_model(self, model: ModelDef) -> None: ...

View file

@ -96,12 +96,6 @@ async def run_main(host: str, port: int, image_path: str = None):
)
print(response)
response = await client.run_shield(
shield_type="injection_shield",
messages=[message],
)
print(response)
def main(host: str, port: int, image: str = None):
asyncio.run(run_main(host, port, image))

View file

@ -11,6 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
@json_schema_type
@ -37,8 +38,17 @@ class RunShieldResponse(BaseModel):
violation: Optional[SafetyViolation] = None
class ShieldStore(Protocol):
def get_shield(self, identifier: str) -> ShieldDef: ...
class Safety(Protocol):
shield_store: ShieldStore
@webmethod(route="/safety/run_shield")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ...
@webmethod(route="/safety/register_shield")
async def register_shield(self, shield: ShieldDef) -> None: ...

View file

@ -4,25 +4,43 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type
class ShieldSpec(BaseModel):
shield_type: str
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_type, and corresponding config. ",
class ShieldType(Enum):
generic_content_shield = "generic_content_shield"
llama_guard = "llama_guard"
code_scanner = "code_scanner"
prompt_guard = "prompt_guard"
class ShieldDef(BaseModel):
identifier: str = Field(
description="A unique identifier for the shield type",
)
type: str = Field(
description="The type of shield this is; the value is one of the ShieldType enum"
)
provider_id: Optional[str] = Field(
default=None, description="The provider instance which serves this shield"
)
params: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional parameters needed for this shield",
)
class Shields(Protocol):
@webmethod(route="/shields/list", method="GET")
async def list_shields(self) -> List[ShieldSpec]: ...
async def list_shields(self) -> List[ShieldDef]: ...
@webmethod(route="/shields/get", method="GET")
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ...
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: ...
@webmethod(route="/shields/register", method="POST")
async def register_shield(self, shield: ShieldDef) -> None: ...

View file

@ -158,19 +158,18 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
info = prompt_guard_download_info()
else:
model = resolve_model(args.model_id)
if model is None:
parser.error(f"Model {args.model_id} not found")
return
info = llama_meta_net_info(model)
if model is None:
parser.error(f"Model {args.model_id} not found")
return
if args.source == "huggingface":
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
else:
meta_url = args.meta_url
if not meta_url:
meta_url = input(
"Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
"Please provide the signed URL you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
)
assert meta_url is not None and "llamameta.net" in meta_url
_meta_download(model, meta_url, info)

View file

@ -22,7 +22,7 @@ def available_templates_specs() -> List[BuildConfig]:
import yaml
template_specs = []
for p in TEMPLATES_PATH.rglob("*.yaml"):
for p in TEMPLATES_PATH.rglob("*build.yaml"):
with open(p, "r") as f:
build_config = BuildConfig(**yaml.safe_load(f))
template_specs.append(build_config)
@ -105,8 +105,7 @@ class StackBuild(Subcommand):
import yaml
from llama_stack.distribution.build import ApiInput, build_image, ImageType
from llama_stack.distribution.build import build_image, ImageType
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.serialize import EnumEncoder
from termcolor import cprint
@ -137,16 +136,19 @@ class StackBuild(Subcommand):
if build_config.image_type == "conda"
else (f"llamastack-{build_config.name}")
)
cprint(
f"You can now run `llama stack configure {configure_name}`",
color="green",
)
if build_config.image_type == "conda":
cprint(
f"You can now run `llama stack configure {configure_name}`",
color="green",
)
else:
cprint(
f"You can now run `llama stack run {build_config.name}`",
color="green",
)
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
import json
import yaml
from llama_stack.cli.table import print_table
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
@ -172,9 +174,11 @@ class StackBuild(Subcommand):
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
import textwrap
import yaml
from llama_stack.distribution.distribution import get_provider_registry
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.validation import Validator
from termcolor import cprint
@ -238,26 +242,29 @@ class StackBuild(Subcommand):
)
cprint(
"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs.
""",
),
color="green",
)
print("Tip: use <TAB> to see options for the providers.\n")
providers = dict()
for api, providers_for_api in get_provider_registry().items():
available_providers = [
x for x in providers_for_api.keys() if x != "remote"
]
api_provider = prompt(
"> Enter provider for the {} API: (default=meta-reference): ".format(
api.value
),
"> Enter provider for API {}: ".format(api.value),
completer=WordCompleter(available_providers),
complete_while_typing=True,
validator=Validator.from_callable(
lambda x: x in providers_for_api,
error_message="Invalid provider, please enter one of the following: {}".format(
list(providers_for_api.keys())
),
),
default=(
"meta-reference"
if "meta-reference" in providers_for_api
else list(providers_for_api.keys())[0]
lambda x: x in available_providers,
error_message="Invalid provider, use <TAB> to see options",
),
)

View file

@ -71,10 +71,8 @@ class StackConfigure(Subcommand):
conda_dir = (
Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}"
)
output = subprocess.check_output(
["bash", "-c", "conda info --json -a | jq '.envs'"]
)
conda_envs = json.loads(output.decode("utf-8"))
output = subprocess.check_output(["bash", "-c", "conda info --json"])
conda_envs = json.loads(output.decode("utf-8"))["envs"]
for x in conda_envs:
if x.endswith(f"/llamastack-{args.config}"):
@ -129,7 +127,10 @@ class StackConfigure(Subcommand):
import yaml
from termcolor import cprint
from llama_stack.distribution.configure import configure_api_providers
from llama_stack.distribution.configure import (
configure_api_providers,
parse_and_maybe_upgrade_config,
)
from llama_stack.distribution.utils.serialize import EnumEncoder
builds_dir = BUILDS_BASE_DIR / build_config.image_type
@ -145,13 +146,17 @@ class StackConfigure(Subcommand):
"yellow",
attrs=["bold"],
)
config = StackRunConfig(**yaml.safe_load(run_config_file.read_text()))
config_dict = yaml.safe_load(run_config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
else:
config = StackRunConfig(
built_at=datetime.now(),
image_name=image_name,
apis_to_serve=[],
api_providers={},
apis=list(build_config.distribution_spec.providers.keys()),
providers={},
models=[],
shields=[],
memory_banks=[],
)
config = configure_api_providers(config, build_config.distribution_spec)

View file

@ -7,7 +7,6 @@
import argparse
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import * # noqa: F403
class StackRun(Subcommand):
@ -46,10 +45,11 @@ class StackRun(Subcommand):
import pkg_resources
import yaml
from termcolor import cprint
from llama_stack.distribution.build import ImageType
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_with_pty
if not args.config:
@ -75,8 +75,10 @@ class StackRun(Subcommand):
)
return
cprint(f"Using config `{config_file}`", "green")
with open(config_file, "r") as f:
config = StackRunConfig(**yaml.safe_load(f))
config_dict = yaml.safe_load(config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if config.docker_image:
script = pkg_resources.resource_filename(

View file

@ -1,105 +0,0 @@
from argparse import Namespace
from unittest.mock import MagicMock, patch
import pytest
from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.cli.stack.build import StackBuild
# temporary while we make the tests work
pytest.skip(allow_module_level=True)
@pytest.fixture
def stack_build():
parser = MagicMock()
subparsers = MagicMock()
return StackBuild(subparsers)
def test_stack_build_initialization(stack_build):
assert stack_build.parser is not None
assert stack_build.parser.set_defaults.called_once_with(
func=stack_build._run_stack_build_command
)
@patch("llama_stack.distribution.build.build_image")
def test_run_stack_build_command_with_config(
mock_build_image, mock_build_config, stack_build
):
args = Namespace(
config="test_config.yaml",
template=None,
list_templates=False,
name=None,
image_type="conda",
)
with patch("builtins.open", MagicMock()):
with patch("yaml.safe_load") as mock_yaml_load:
mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
mock_build_config.return_value = MagicMock()
stack_build._run_stack_build_command(args)
mock_build_config.assert_called_once()
mock_build_image.assert_called_once()
@patch("llama_stack.cli.table.print_table")
def test_run_stack_build_command_list_templates(mock_print_table, stack_build):
args = Namespace(list_templates=True)
stack_build._run_stack_build_command(args)
mock_print_table.assert_called_once()
@patch("prompt_toolkit.prompt")
@patch("llama_stack.distribution.datatypes.BuildConfig")
@patch("llama_stack.distribution.build.build_image")
def test_run_stack_build_command_interactive(
mock_build_image, mock_build_config, mock_prompt, stack_build
):
args = Namespace(
config=None, template=None, list_templates=False, name=None, image_type=None
)
mock_prompt.side_effect = [
"test_name",
"conda",
"meta-reference",
"test description",
]
mock_build_config.return_value = MagicMock()
stack_build._run_stack_build_command(args)
assert mock_prompt.call_count == 4
mock_build_config.assert_called_once()
mock_build_image.assert_called_once()
@patch("llama_stack.distribution.datatypes.BuildConfig")
@patch("llama_stack.distribution.build.build_image")
def test_run_stack_build_command_with_template(
mock_build_image, mock_build_config, stack_build
):
args = Namespace(
config=None,
template="test_template",
list_templates=False,
name="test_name",
image_type="docker",
)
with patch("builtins.open", MagicMock()):
with patch("yaml.safe_load") as mock_yaml_load:
mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
mock_build_config.return_value = MagicMock()
stack_build._run_stack_build_command(args)
mock_build_config.assert_called_once()
mock_build_image.assert_called_once()

View file

@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
import pytest
import yaml
from llama_stack.distribution.configure import (
LLAMA_STACK_RUN_CONFIG_VERSION,
parse_and_maybe_upgrade_config,
)
@pytest.fixture
def up_to_date_config():
return yaml.safe_load(
"""
version: {version}
image_name: foo
apis_to_serve: []
built_at: {built_at}
models:
- identifier: model1
provider_id: provider1
llama_model: Llama3.1-8B-Instruct
shields:
- identifier: shield1
type: llama_guard
provider_id: provider1
memory_banks:
- identifier: memory1
type: vector
provider_id: provider1
embedding_model: all-MiniLM-L6-v2
chunk_size_in_tokens: 512
providers:
inference:
- provider_id: provider1
provider_type: meta-reference
config: {{}}
safety:
- provider_id: provider1
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
enable_prompt_guard: false
memory:
- provider_id: provider1
provider_type: meta-reference
config: {{}}
""".format(
version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
)
)
@pytest.fixture
def old_config():
return yaml.safe_load(
"""
image_name: foo
built_at: {built_at}
apis_to_serve: []
routing_table:
inference:
- provider_type: remote::ollama
config:
host: localhost
port: 11434
routing_key: Llama3.2-1B-Instruct
- provider_type: meta-reference
config:
model: Llama3.1-8B-Instruct
routing_key: Llama3.1-8B-Instruct
safety:
- routing_key: ["shield1", "shield2"]
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
enable_prompt_guard: false
memory:
- routing_key: vector
provider_type: meta-reference
config: {{}}
api_providers:
telemetry:
provider_type: noop
config: {{}}
""".format(
built_at=datetime.now().isoformat()
)
)
@pytest.fixture
def invalid_config():
return yaml.safe_load(
"""
routing_table: {}
api_providers: {}
"""
)
def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
result = parse_and_maybe_upgrade_config(up_to_date_config)
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
assert len(result.models) == 1
assert len(result.shields) == 1
assert len(result.memory_banks) == 1
assert "inference" in result.providers
def test_parse_and_maybe_upgrade_config_old_format(old_config):
result = parse_and_maybe_upgrade_config(old_config)
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
assert len(result.models) == 2
assert len(result.shields) == 2
assert len(result.memory_banks) == 1
assert all(
api in result.providers
for api in ["inference", "safety", "memory", "telemetry"]
)
safety_provider = result.providers["safety"][0]
assert safety_provider.provider_type == "meta-reference"
assert "llama_guard_shield" in safety_provider.config
inference_providers = result.providers["inference"]
assert len(inference_providers) == 2
assert set(x.provider_id for x in inference_providers) == {
"remote::ollama-00",
"meta-reference-01",
}
ollama = inference_providers[0]
assert ollama.provider_type == "remote::ollama"
assert ollama.config["port"] == 11434
def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
with pytest.raises(ValueError):
parse_and_maybe_upgrade_config(invalid_config)

View file

@ -8,15 +8,16 @@ from enum import Enum
from typing import List, Optional
import pkg_resources
from llama_stack.distribution.utils.exec import run_with_pty
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403
from pathlib import Path
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.distribution import get_provider_registry
@ -95,6 +96,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
build_config.name,
package_deps.docker_image,
str(build_file_path),
str(BUILDS_BASE_DIR / ImageType.docker.value),
" ".join(deps),
]
else:

View file

@ -23,7 +23,7 @@ if [ "$#" -lt 3 ]; then
exit 1
fi
special_pip_deps="$3"
special_pip_deps="$4"
set -euo pipefail

View file

@ -10,7 +10,7 @@ if [ "$#" -lt 4 ]; then
exit 1
fi
special_pip_deps="$5"
special_pip_deps="$6"
set -euo pipefail
@ -18,7 +18,8 @@ build_name="$1"
image_name="llamastack-$build_name"
docker_base=$2
build_file_path=$3
pip_dependencies=$4
host_build_dir=$4
pip_dependencies=$5
# Define color codes
RED='\033[0;31m'
@ -33,7 +34,8 @@ REPO_CONFIGS_DIR="$REPO_DIR/tmp/configs"
TEMP_DIR=$(mktemp -d)
llama stack configure $build_file_path --output-dir $REPO_CONFIGS_DIR
llama stack configure $build_file_path
cp $host_build_dir/$build_name-run.yaml $REPO_CONFIGS_DIR
add_to_docker() {
local input
@ -132,6 +134,9 @@ fi
set -x
$DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
# clean up tmp/configs
rm -rf $REPO_CONFIGS_DIR
set +x
echo "Success! You can run it with: $DOCKER_BINARY $DOCKER_OPTS run -p 5000:5000 $image_name"

View file

@ -3,171 +3,369 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import textwrap
from typing import Any
from pydantic import BaseModel
from llama_models.sku_list import (
llama3_1_family,
llama3_2_family,
llama3_family,
resolve_model,
safety_models,
)
from llama_stack.distribution.datatypes import * # noqa: F403
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
from termcolor import cprint
from llama_stack.apis.memory.memory import MemoryBankType
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
stack_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.impls.meta_reference.safety.config import (
MetaReferenceShieldType,
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
ALLOWED_MODELS = (
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
)
def make_routing_entry_type(config_class: Any):
class BaseModelWithConfig(BaseModel):
routing_key: str
config: config_class
def configure_single_provider(
registry: Dict[str, ProviderSpec], provider: Provider
) -> Provider:
provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class)
try:
if provider.config:
existing = config_type(**provider.config)
else:
existing = None
except Exception:
existing = None
return BaseModelWithConfig
cfg = prompt_for_config(config_type, existing)
return Provider(
provider_id=provider.provider_id,
provider_type=provider.provider_type,
config=cfg.dict(),
)
def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
"""Get corresponding builtin APIs given provider backed APIs"""
res = []
for inf in builtin_automatically_routed_apis():
if inf.router_api.value in provider_backed_apis:
res.append(inf.routing_table_api.value)
return res
# TODO: make sure we can deal with existing configuration values correctly
# instead of just overwriting them
def configure_api_providers(
config: StackRunConfig, spec: DistributionSpec
config: StackRunConfig, build_spec: DistributionSpec
) -> StackRunConfig:
apis = config.apis_to_serve or list(spec.providers.keys())
# append the bulitin routing APIs
apis += get_builtin_apis(apis)
is_nux = len(config.providers) == 0
router_api2builtin_api = {
inf.router_api.value: inf.routing_table_api.value
for inf in builtin_automatically_routed_apis()
}
if is_nux:
print(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. For each API served by the Stack,
we need to configure the providers (implementations) you want to use for these APIs.
"""
)
)
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
provider_registry = get_provider_registry()
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
apis = [v.value for v in stack_apis()]
all_providers = get_provider_registry()
if config.apis:
apis_to_serve = config.apis
else:
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
# configure simple case for with non-routing providers to api_providers
for api_str in spec.providers.keys():
if api_str not in apis:
for api_str in apis_to_serve:
api = Api(api_str)
if api in builtin_apis:
continue
if api not in provider_registry:
raise ValueError(f"Unknown API `{api_str}`")
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
api = Api(api_str)
p = spec.providers[api_str]
cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
if isinstance(p, list):
existing_providers = config.providers.get(api_str, [])
if existing_providers:
cprint(
f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
"yellow",
f"Re-configuring existing providers for API `{api_str}`...",
"green",
attrs=["bold"],
)
p = p[0]
provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class)
try:
provider_config = config.api_providers.get(api_str)
if provider_config:
existing = config_type(**provider_config.config)
else:
existing = None
except Exception:
existing = None
cfg = prompt_for_config(config_type, existing)
if api_str in router_api2builtin_api:
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
routing_entries = []
if api_str == "inference":
if hasattr(cfg, "model"):
routing_key = cfg.model
else:
routing_key = prompt(
"> Please enter the supported model your provider has for inference: ",
default="Meta-Llama3.1-8B-Instruct",
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
)
updated_providers = []
for p in existing_providers:
print(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(
configure_single_provider(provider_registry[api], p)
)
if api_str == "safety":
# TODO: add support for other safety providers, and simplify safety provider config
if p == "meta-reference":
routing_entries.append(
RoutableProviderConfig(
routing_key=[s.value for s in MetaReferenceShieldType],
provider_type=p,
config=cfg.dict(),
)
)
else:
cprint(
f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
"yellow",
attrs=["bold"],
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
)
)
if api_str == "memory":
bank_types = list([x.value for x in MemoryBankType])
routing_key = prompt(
"> Please enter the supported memory bank type your provider has for memory: ",
default="vector",
validator=Validator.from_callable(
lambda x: x in bank_types,
error_message="Invalid provider, please enter one of the following: {}".format(
bank_types
),
),
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
)
)
config.routing_table[api_str] = routing_entries
config.api_providers[api_str] = PlaceholderProviderConfig(
providers=p if isinstance(p, list) else [p]
)
print("")
else:
config.api_providers[api_str] = GenericProviderConfig(
provider_type=p,
config=cfg.dict(),
)
# we are newly configuring this API
plist = build_spec.providers.get(api_str, [])
plist = plist if isinstance(plist, list) else [plist]
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
updated_providers = []
for i, provider_type in enumerate(plist):
print(f"> Configuring provider `({provider_type})`")
updated_providers.append(
configure_single_provider(
provider_registry[api],
Provider(
provider_id=(
f"{provider_type}-{i:02d}"
if len(plist) > 1
else provider_type
),
provider_type=provider_type,
config={},
),
)
)
print("")
config.providers[api_str] = updated_providers
if is_nux:
print(
textwrap.dedent(
"""
=========================================================================================
Now let's configure the `objects` you will be serving via the stack. These are:
- Models: the Llama model SKUs you expect to inference (e.g., Llama3.2-1B-Instruct)
- Shields: the safety models you expect to use for safety (e.g., Llama-Guard-3-1B)
- Memory Banks: the memory banks you expect to use for memory (e.g., Vector stores)
This wizard will guide you through setting up one of each of these objects. You can
always add more later by editing the run.yaml file.
"""
)
)
object_types = {
"models": (ModelDef, configure_models, "inference"),
"shields": (ShieldDef, configure_shields, "safety"),
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
}
safety_providers = config.providers.get("safety", [])
for otype, (odef, config_method, api_str) in object_types.items():
existing_objects = getattr(config, otype)
if existing_objects:
cprint(
f"{len(existing_objects)} {otype} exist. Skipping...",
"blue",
attrs=["bold"],
)
updated_objects = existing_objects
else:
providers = config.providers.get(api_str, [])
if not providers:
updated_objects = []
else:
# we are newly configuring this API
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
updated_objects = config_method(
config.providers[api_str], safety_providers
)
setattr(config, otype, updated_objects)
print("")
return config
def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
if not safety_providers:
return None
provider = safety_providers[0]
assert provider.provider_type == "meta-reference"
cfg = provider.config["llama_guard_shield"]
if not cfg:
return None
return cfg["model"]
def configure_models(
providers: List[Provider], safety_providers: List[Provider]
) -> List[ModelDef]:
model = prompt(
"> Please enter the model you want to serve: ",
default="Llama3.2-1B-Instruct",
validator=Validator.from_callable(
lambda x: resolve_model(x) is not None,
error_message="Model must be: {}".format(
[x.descriptor() for x in ALLOWED_MODELS]
),
),
)
model = ModelDef(
identifier=model,
llama_model=model,
provider_id=providers[0].provider_id,
)
ret = [model]
if llama_guard := get_llama_guard_model(safety_providers):
ret.append(
ModelDef(
identifier=llama_guard,
llama_model=llama_guard,
provider_id=providers[0].provider_id,
)
)
return ret
def configure_shields(
providers: List[Provider], safety_providers: List[Provider]
) -> List[ShieldDef]:
if get_llama_guard_model(safety_providers):
return [
ShieldDef(
identifier="llama_guard",
type="llama_guard",
provider_id=providers[0].provider_id,
params={},
)
]
return []
def configure_memory_banks(
providers: List[Provider], safety_providers: List[Provider]
) -> List[MemoryBankDef]:
bank_name = prompt(
"> Please enter a name for your memory bank: ",
default="my-memory-bank",
)
return [
VectorMemoryBankDef(
identifier=bank_name,
provider_id=providers[0].provider_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
)
]
def upgrade_from_routing_table_to_registry(
config_dict: Dict[str, Any],
) -> Dict[str, Any]:
def get_providers(entries):
return [
Provider(
provider_id=(
f"{entry['provider_type']}-{i:02d}"
if len(entries) > 1
else entry["provider_type"]
),
provider_type=entry["provider_type"],
config=entry["config"],
)
for i, entry in enumerate(entries)
]
providers_by_api = {}
models = []
shields = []
memory_banks = []
routing_table = config_dict.get("routing_table", {})
for api_str, entries in routing_table.items():
providers = get_providers(entries)
providers_by_api[api_str] = providers
if api_str == "inference":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
models.append(
ModelDef(
identifier=key,
provider_id=provider.provider_id,
llama_model=key,
)
)
elif api_str == "safety":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
shields.append(
ShieldDef(
identifier=key,
type=ShieldType.llama_guard.value,
provider_id=provider.provider_id,
)
)
elif api_str == "memory":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
# we currently only support Vector memory banks so this is OK
memory_banks.append(
VectorMemoryBankDef(
identifier=key,
provider_id=provider.provider_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
)
)
config_dict["models"] = models
config_dict["shields"] = shields
config_dict["memory_banks"] = memory_banks
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
if provider_map:
for api_str, provider in provider_map.items():
if isinstance(provider, dict) and "provider_type" in provider:
providers_by_api[api_str] = [
Provider(
provider_id=f"{provider['provider_type']}",
provider_type=provider["provider_type"],
config=provider["config"],
)
]
config_dict["providers"] = providers_by_api
config_dict.pop("routing_table", None)
config_dict.pop("api_providers", None)
config_dict.pop("provider_map", None)
config_dict["apis"] = config_dict["apis_to_serve"]
config_dict.pop("apis_to_serve", None)
return config_dict
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**config_dict)
if "models" not in config_dict:
print("Upgrading config...")
config_dict = upgrade_from_routing_table_to_registry(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
config_dict["built_at"] = datetime.now().isoformat()
return StackRunConfig(**config_dict)

View file

@ -11,28 +11,32 @@ from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = Union[str, List[str]]
class GenericProviderConfig(BaseModel):
provider_type: str
config: Dict[str, Any]
RoutableObject = Union[
ModelDef,
ShieldDef,
MemoryBankDef,
]
class RoutableProviderConfig(GenericProviderConfig):
routing_key: RoutingKey
class PlaceholderProviderConfig(BaseModel):
"""Placeholder provider config for API whose provider are defined in routing_table"""
providers: List[str]
RoutedProtocol = Union[
Inference,
Safety,
Memory,
]
# Example: /inference, /safety
@ -53,18 +57,17 @@ class AutoRoutedProviderSpec(ProviderSpec):
# Example: /models, /shields
@json_schema_type
class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
router_api: Api
registry: List[RoutableObject]
module: str
pip_packages: List[str] = Field(default_factory=list)
@json_schema_type
class DistributionSpec(BaseModel):
description: Optional[str] = Field(
default="",
@ -80,7 +83,12 @@ in the runtime configuration to help route to the correct provider.""",
)
@json_schema_type
class Provider(BaseModel):
provider_id: str
provider_type: str
config: Dict[str, Any]
class StackRunConfig(BaseModel):
version: str = LLAMA_STACK_RUN_CONFIG_VERSION
built_at: datetime
@ -100,36 +108,39 @@ this could be just a hash
default=None,
description="Reference to the conda environment if this package refers to a conda environment",
)
apis_to_serve: List[str] = Field(
apis: List[str] = Field(
default_factory=list,
description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)
api_providers: Dict[
str, Union[GenericProviderConfig, PlaceholderProviderConfig]
] = Field(
providers: Dict[str, List[Provider]] = Field(
description="""
Provider configurations for each of the APIs provided by this package.
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary.
""",
)
routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
default_factory=dict,
description="""
E.g. The following is a ProviderRoutingEntry for models:
- routing_key: Meta-Llama3.1-8B-Instruct
provider_type: meta-reference
config:
model: Meta-Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
""",
models: List[ModelDef] = Field(
description="""
List of model definitions to serve. This list may get extended by
/models/register API calls at runtime.
""",
)
shields: List[ShieldDef] = Field(
description="""
List of shield definitions to serve. This list may get extended by
/shields/register API calls at runtime.
""",
)
memory_banks: List[MemoryBankDef] = Field(
description="""
List of memory bank definitions to serve. This list may get extended by
/memory_banks/register API calls at runtime.
""",
)
@json_schema_type
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
name: str

View file

@ -6,45 +6,58 @@
from typing import Dict, List
from llama_stack.apis.inspect import * # noqa: F403
from pydantic import BaseModel
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
def is_passthrough(spec: ProviderSpec) -> bool:
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
class DistributionInspectConfig(BaseModel):
run_config: StackRunConfig
async def get_provider_impl(config, deps):
impl = DistributionInspectImpl(config, deps)
await impl.initialize()
return impl
class DistributionInspectImpl(Inspect):
def __init__(self):
def __init__(self, config, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
run_config = self.config.run_config
ret = {}
all_providers = get_provider_registry()
for api, providers in all_providers.items():
ret[api.value] = [
for api, providers in run_config.providers.items():
ret[api] = [
ProviderInfo(
provider_id=p.provider_id,
provider_type=p.provider_type,
description="Passthrough" if is_passthrough(p) else "",
)
for p in providers.values()
for p in providers
]
return ret
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
run_config = self.config.run_config
ret = {}
all_endpoints = get_all_api_endpoints()
for api, endpoints in all_endpoints.items():
providers = run_config.providers.get(api.value, [])
ret[api.value] = [
RouteInfo(
route=e.route,
method=e.method,
providers=[],
provider_types=[p.provider_type for p in providers],
)
for e in endpoints
]

View file

@ -13,138 +13,207 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.inspect import DistributionInspectImpl
from llama_stack.distribution.utils.dynamic import instantiate_class_type
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
class ProviderWithSpec(Provider):
spec: ProviderSpec
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = get_provider_registry()
specs = {}
configs = {}
all_api_providers = get_provider_registry()
for api_str, config in run_config.api_providers.items():
api = Api(api_str)
# TODO: check that these APIs are not in the routing table part of the config
providers = all_providers[api]
# skip checks for API whose provider config is specified in routing_table
if isinstance(config, PlaceholderProviderConfig):
continue
if config.provider_type not in providers:
raise ValueError(
f"Provider `{config.provider_type}` is not available for API `{api}`"
)
specs[api] = providers[config.provider_type]
configs[api] = config
apis_to_serve = run_config.apis_to_serve or set(
list(specs.keys()) + list(run_config.routing_table.keys())
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
providers_with_specs = {}
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
)
specs = {}
for provider in providers:
if provider.provider_type not in all_api_providers[api]:
raise ValueError(
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
p = all_api_providers[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec(
spec=p,
**(provider.dict()),
)
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys())
+ [x.value for x in routing_table_apis]
+ [x.value for x in router_apis]
)
for info in builtin_automatically_routed_apis():
source_api = info.routing_table_api
assert (
source_api not in specs
), f"Routing table API {source_api} specified in wrong place?"
assert (
info.router_api not in specs
), f"Auto-routed API {info.router_api} specified in wrong place?"
if info.router_api.value not in apis_to_serve:
continue
if info.router_api.value not in run_config.routing_table:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
routing_table = run_config.routing_table[info.router_api.value]
providers = all_providers[info.router_api]
inner_specs = []
inner_deps = []
for rt_entry in routing_table:
if rt_entry.provider_type not in providers:
registry = getattr(run_config, info.routing_table_api.value)
for entry in registry:
if entry.provider_id not in available_providers:
raise ValueError(
f"Provider `{rt_entry.provider_type}` is not available for API `{api}`"
f"Provider `{entry.provider_id}` not found. Available providers: {list(available_providers.keys())}"
)
inner_specs.append(providers[rt_entry.provider_type])
inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
module="llama_stack.distribution.routers",
api_dependencies=inner_deps,
inner_specs=inner_specs,
provider = available_providers[entry.provider_id]
inner_deps.extend(provider.spec.api_dependencies)
providers_with_specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__builtin__",
provider_type="__routing_table__",
config={},
spec=RoutingTableProviderSpec(
api=info.routing_table_api,
router_api=info.router_api,
registry=registry,
module="llama_stack.distribution.routers",
api_dependencies=inner_deps,
deps__=(
[x.value for x in inner_deps]
+ [f"inner-{info.router_api.value}"]
),
),
)
}
providers_with_specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__builtin__",
provider_type="__autorouted__",
config={},
spec=AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=([info.routing_table_api.value]),
),
)
}
sorted_providers = topological_sort(
{k: v.values() for k, v in providers_with_specs.items()}
)
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
(
"inspect",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={
"run_config": run_config.dict(),
},
spec=InlineProviderSpec(
api=Api.inspect,
provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect",
api_dependencies=apis,
deps__=([x.value for x in apis]),
),
),
)
configs[source_api] = routing_table
specs[info.router_api] = AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=source_api,
api_dependencies=[source_api],
)
configs[info.router_api] = {}
sorted_specs = topological_sort(specs.values())
print(f"Resolved {len(sorted_specs)} providers in topological order")
for spec in sorted_specs:
print(f" {spec.api}: {spec.provider_type}")
print("")
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, configs[api])
impls[api] = impl
impls[Api.inspect] = DistributionInspectImpl()
specs[Api.inspect] = InlineProviderSpec(
api=Api.inspect,
provider_type="__distribution_builtin__",
config_class="",
module="",
)
return impls, specs
print(f"Resolved {len(sorted_providers)} providers in topological order")
for api_str, provider in sorted_providers:
print(f" {api_str}: ({provider.provider_id}) {provider.spec.provider_type}")
print("")
impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[
f"inner-{provider.spec.router_api.value}"
]
impl = await instantiate_provider(
provider,
deps,
inner_impls,
)
# TODO: ugh slightly redesign this shady looking code
if "inner-" in api_str:
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
else:
api = Api(api_str)
impls[api] = impl
return impls
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in providers}
def topological_sort(
providers_with_specs: Dict[str, List[ProviderWithSpec]],
) -> List[ProviderWithSpec]:
def dfs(kv, visited: Set[str], stack: List[str]):
api_str, providers = kv
visited.add(api_str)
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
deps = []
for provider in providers:
for dep in provider.spec.deps__:
deps.append(dep)
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
for dep in deps:
if dep not in visited:
dfs((dep, providers_with_specs[dep]), visited, stack)
stack.append(a.api)
stack.append(api_str)
visited = set()
stack = []
for a in providers:
if a.api not in visited:
dfs(a, visited, stack)
for api_str, providers in providers_with_specs.items():
if api_str not in visited:
dfs((api_str, providers), visited, stack)
return [by_id[x] for x in stack]
flattened = []
for api_str in stack:
for provider in providers_with_specs[api_str]:
flattened.append((api_str, provider))
return flattened
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
provider: ProviderWithSpec,
deps: Dict[str, Any],
provider_config: Union[GenericProviderConfig, RoutingTable],
inner_impls: Dict[str, Any],
):
provider_spec = provider.spec
module = importlib.import_module(provider_spec.module)
args = []
@ -154,9 +223,8 @@ async def instantiate_provider(
else:
method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
config = config_type(**provider.config)
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
@ -166,31 +234,18 @@ async def instantiate_provider(
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
assert isinstance(provider_config, List)
routing_table = provider_config
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_type],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [provider_spec.api, inner_impls, routing_table, deps]
args = [provider_spec.api, provider_spec.registry, inner_impls, deps]
else:
method = "get_provider_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
config = config_type(**provider.config)
args = [config, deps]
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_id__ = provider.provider_id
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl

View file

@ -4,23 +4,22 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Tuple
from typing import Any, List
from llama_stack.distribution.datatypes import * # noqa: F403
from .routing_tables import (
MemoryBanksRoutingTable,
ModelsRoutingTable,
ShieldsRoutingTable,
)
async def get_routing_table_impl(
api: Api,
inner_impls: List[Tuple[str, Any]],
routing_table_config: Dict[str, List[RoutableProviderConfig]],
registry: List[RoutableObject],
impls_by_provider_id: Dict[str, RoutedProtocol],
_deps,
) -> Any:
from .routing_tables import (
MemoryBanksRoutingTable,
ModelsRoutingTable,
ShieldsRoutingTable,
)
api_to_tables = {
"memory_banks": MemoryBanksRoutingTable,
"models": ModelsRoutingTable,
@ -29,7 +28,7 @@ async def get_routing_table_impl(
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](inner_impls, routing_table_config)
impl = api_to_tables[api.value](registry, impls_by_provider_id)
await impl.initialize()
return impl

View file

@ -14,14 +14,13 @@ from llama_stack.apis.safety import * # noqa: F403
class MemoryRouter(Memory):
"""Routes to an provider based on the memory bank type"""
"""Routes to an provider based on the memory bank identifier"""
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
self.bank_id_to_type = {}
async def initialize(self) -> None:
pass
@ -29,32 +28,8 @@ class MemoryRouter(Memory):
async def shutdown(self) -> None:
pass
def get_provider_from_bank_id(self, bank_id: str) -> Any:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.routing_table.get_provider_impl(bank_type)
if not provider:
raise ValueError(f"Could not find provider for {bank_type}")
return provider
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_type = config.type
bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
name, config, url
)
self.bank_id_to_type[bank.bank_id] = bank_type
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
provider = self.get_provider_from_bank_id(bank_id)
return await provider.get_memory_bank(bank_id)
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
await self.routing_table.register_memory_bank(memory_bank)
async def insert_documents(
self,
@ -62,7 +37,7 @@ class MemoryRouter(Memory):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
return await self.get_provider_from_bank_id(bank_id).insert_documents(
return await self.routing_table.get_provider_impl(bank_id).insert_documents(
bank_id, documents, ttl_seconds
)
@ -72,7 +47,7 @@ class MemoryRouter(Memory):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
return await self.get_provider_from_bank_id(bank_id).query_documents(
return await self.routing_table.get_provider_impl(bank_id).query_documents(
bank_id, query, params
)
@ -92,7 +67,10 @@ class InferenceRouter(Inference):
async def shutdown(self) -> None:
pass
async def chat_completion(
async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)
def chat_completion(
self,
model: str,
messages: List[Message],
@ -113,27 +91,32 @@ class InferenceRouter(Inference):
stream=stream,
logprobs=logprobs,
)
# TODO: we need to fix streaming response to align provider implementations with Protocol.
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
**params
):
yield chunk
provider = self.routing_table.get_provider_impl(model)
if stream:
return (chunk async for chunk in provider.chat_completion(**params))
else:
return provider.chat_completion(**params)
async def completion(
def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
return await self.routing_table.get_provider_impl(model).completion(
) -> AsyncGenerator:
provider = self.routing_table.get_provider_impl(model)
params = dict(
model=model,
content=content,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
if stream:
return (chunk async for chunk in provider.completion(**params))
else:
return provider.completion(**params)
async def embeddings(
self,
@ -159,6 +142,9 @@ class SafetyRouter(Safety):
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: ShieldDef) -> None:
await self.routing_table.register_shield(shield)
async def run_shield(
self,
shield_type: str,

View file

@ -4,9 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
@ -16,129 +15,129 @@ from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.inference:
await p.register_model(obj)
elif api == Api.safety:
await p.register_shield(obj)
elif api == Api.memory:
await p.register_memory_bank(obj)
# TODO: this routing table maintains state in memory purely. We need to
# add persistence to it when we add dynamic registration of objects.
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
inner_impls: List[Tuple[RoutingKey, Any]],
routing_table_config: Dict[str, List[RoutableProviderConfig]],
registry: List[RoutableObject],
impls_by_provider_id: Dict[str, RoutedProtocol],
) -> None:
self.unique_providers = []
self.providers = {}
self.routing_keys = []
for obj in registry:
if obj.provider_id not in impls_by_provider_id:
print(f"{impls_by_provider_id=}")
raise ValueError(
f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
)
for key, impl in inner_impls:
keys = key if isinstance(key, list) else [key]
self.unique_providers.append((keys, impl))
self.impls_by_provider_id = impls_by_provider_id
self.registry = registry
for k in keys:
if k in self.providers:
raise ValueError(f"Duplicate routing key {k}")
self.providers[k] = impl
self.routing_keys.append(k)
for p in self.impls_by_provider_id.values():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.memory:
p.memory_bank_store = self
self.routing_table_config = routing_table_config
self.routing_key_to_object = {}
for obj in self.registry:
self.routing_key_to_object[obj.identifier] = obj
async def initialize(self) -> None:
for keys, p in self.unique_providers:
spec = p.__provider_spec__
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
continue
await p.validate_routing_keys(keys)
for obj in self.registry:
p = self.impls_by_provider_id[obj.provider_id]
await register_object_with_provider(obj, p)
async def shutdown(self) -> None:
for _, p in self.unique_providers:
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str) -> Any:
if routing_key not in self.providers:
raise ValueError(f"Could not find provider for {routing_key}")
return self.providers[routing_key]
if routing_key not in self.routing_key_to_object:
raise ValueError(f"`{routing_key}` not registered")
def get_routing_keys(self) -> List[str]:
return self.routing_keys
obj = self.routing_key_to_object[routing_key]
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
for entry in self.routing_table_config:
if entry.routing_key == routing_key:
return entry
return self.impls_by_provider_id[obj.provider_id]
def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
for obj in self.registry:
if obj.identifier == identifier:
return obj
return None
async def register_object(self, obj: RoutableObject):
if obj.identifier in self.routing_key_to_object:
print(f"`{obj.identifier}` is already registered")
return
if not obj.provider_id:
provider_ids = list(self.impls_by_provider_id.keys())
if not provider_ids:
raise ValueError("No providers found")
print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
obj.provider_id = provider_ids[0]
else:
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
await register_object_with_provider(obj, p)
self.routing_key_to_object[obj.identifier] = obj
self.registry.append(obj)
# TODO: persist this to a store
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelDef]:
return self.registry
async def list_models(self) -> List[ModelServingSpec]:
specs = []
for entry in self.routing_table_config:
model_id = entry.routing_key
specs.append(
ModelServingSpec(
llama_model=resolve_model(model_id),
provider_config=entry,
)
)
return specs
async def get_model(self, identifier: str) -> Optional[ModelDef]:
return self.get_object_by_identifier(identifier)
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
for entry in self.routing_table_config:
if entry.routing_key == core_model_id:
return ModelServingSpec(
llama_model=resolve_model(core_model_id),
provider_config=entry,
)
return None
async def register_model(self, model: ModelDef) -> None:
await self.register_object(model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldDef]:
return self.registry
async def list_shields(self) -> List[ShieldSpec]:
specs = []
for entry in self.routing_table_config:
if isinstance(entry.routing_key, list):
for k in entry.routing_key:
specs.append(
ShieldSpec(
shield_type=k,
provider_config=entry,
)
)
else:
specs.append(
ShieldSpec(
shield_type=entry.routing_key,
provider_config=entry,
)
)
return specs
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
return self.get_object_by_identifier(shield_type)
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
for entry in self.routing_table_config:
if entry.routing_key == shield_type:
return ShieldSpec(
shield_type=entry.routing_key,
provider_config=entry,
)
return None
async def register_shield(self, shield: ShieldDef) -> None:
await self.register_object(shield)
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankDef]:
return self.registry
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
specs = []
for entry in self.routing_table_config:
specs.append(
MemoryBankSpec(
bank_type=entry.routing_key,
provider_config=entry,
)
)
return specs
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
return self.get_object_by_identifier(identifier)
async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
for entry in self.routing_table_config:
if entry.routing_key == bank_type:
return MemoryBankSpec(
bank_type=entry.routing_key,
provider_config=entry,
)
return None
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
await self.register_object(bank)

View file

@ -5,18 +5,15 @@
# the root directory of this source tree.
import asyncio
import functools
import inspect
import json
import signal
import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from ssl import SSLError
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
from typing import Any, Dict, Optional
import fire
import httpx
@ -43,20 +40,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
from .endpoints import get_all_api_endpoints
def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"):
origin = typ.__origin__
if isinstance(origin, type):
return issubclass(
origin,
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
)
return False
return isinstance(
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
)
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.json()
@ -169,11 +152,20 @@ async def passthrough(
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
def handle_sigint(*args, **kwargs):
def handle_sigint(app, *args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...")
async def run_shutdown():
for impl in app.__llama_stack_impls__.values():
print(f"Shutting down {impl}")
await impl.shutdown()
asyncio.run(run_shutdown())
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
@ -181,7 +173,10 @@ def handle_sigint(*args, **kwargs):
async def lifespan(app: FastAPI):
print("Starting up")
yield
print("Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
def create_dynamic_passthrough(
@ -193,65 +188,59 @@ def create_dynamic_passthrough(
return endpoint
def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly
return kwargs.get("stream", False)
async def maybe_await(value):
if inspect.iscoroutine(value):
return await value
return value
async def sse_generator(event_gen):
try:
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
finally:
await end_trace()
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
response_model = hints.get("return")
# NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False.
is_streaming = is_async_iterator_type(response_model)
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
if is_streaming:
set_request_provider_data(request.headers)
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers)
async def sse_generator(event_gen):
try:
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
finally:
await end_trace()
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers)
try:
return (
await func(**kwargs)
if asyncio.iscoroutinefunction(func)
else func(**kwargs)
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
finally:
await end_trace()
else:
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
finally:
await end_trace()
sig = inspect.signature(func)
new_params = [
@ -285,29 +274,25 @@ def main(
app = FastAPI()
impls, specs = asyncio.run(resolve_impls_with_routing(config))
impls = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
all_endpoints = get_all_api_endpoints()
if config.apis_to_serve:
apis_to_serve = set(config.apis_to_serve)
if config.apis:
apis_to_serve = set(config.apis)
else:
apis_to_serve = set(impls.keys())
apis_to_serve.add(Api.inspect)
apis_to_serve.add("inspect")
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]
impl = impls[api]
provider_spec = specs[api]
if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None
):
if is_passthrough(impl.__provider_spec__):
for endpoint in endpoints:
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
@ -337,7 +322,9 @@ def main(
print("")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)
signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
app.__llama_stack_impls__ = impls
import uvicorn

View file

@ -1,8 +1,9 @@
built_at: '2024-09-30T09:04:30.533391'
version: '2'
built_at: '2024-10-08T17:42:07.505267'
image_name: local-cpu
docker_image: local-cpu
conda_env: null
apis_to_serve:
apis:
- agents
- inference
- models
@ -10,40 +11,48 @@ apis_to_serve:
- safety
- shields
- memory_banks
api_providers:
providers:
inference:
providers:
- remote::ollama
- provider_id: remote::ollama
provider_type: remote::ollama
config:
host: localhost
port: 6000
safety:
providers:
- meta-reference
- provider_id: meta-reference
provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
memory:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
agents:
- provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: /home/xiyan/.llama/runtime/kvstore.db
memory:
providers:
- meta-reference
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
routing_table:
inference:
- provider_type: remote::ollama
config:
host: localhost
port: 6000
routing_key: Meta-Llama3.1-8B-Instruct
safety:
- provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- provider_type: meta-reference
config: {}
routing_key: vector
models:
- identifier: Llama3.1-8B-Instruct
llama_model: Llama3.1-8B-Instruct
provider_id: remote::ollama
shields:
- identifier: llama_guard
type: llama_guard
provider_id: meta-reference
params: {}
memory_banks:
- identifier: vector
provider_id: meta-reference
type: vector
embedding_model: all-MiniLM-L6-v2
chunk_size_in_tokens: 512
overlap_size_in_tokens: null

View file

@ -1,8 +1,9 @@
built_at: '2024-09-30T09:00:56.693751'
version: '2'
built_at: '2024-10-08T17:42:33.690666'
image_name: local-gpu
docker_image: local-gpu
conda_env: null
apis_to_serve:
apis:
- memory
- inference
- agents
@ -10,43 +11,51 @@ apis_to_serve:
- safety
- models
- memory_banks
api_providers:
providers:
inference:
providers:
- meta-reference
safety:
providers:
- meta-reference
agents:
- provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: /home/xiyan/.llama/runtime/kvstore.db
memory:
providers:
- meta-reference
telemetry:
provider_type: meta-reference
config: {}
routing_table:
inference:
- provider_type: meta-reference
config:
model: Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
routing_key: Llama3.1-8B-Instruct
safety:
- provider_type: meta-reference
- provider_id: meta-reference
provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- provider_type: meta-reference
- provider_id: meta-reference
provider_type: meta-reference
config: {}
routing_key: vector
agents:
- provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
models:
- identifier: Llama3.1-8B-Instruct
llama_model: Llama3.1-8B-Instruct
provider_id: meta-reference
shields:
- identifier: llama_guard
type: llama_guard
provider_id: meta-reference
params: {}
memory_banks:
- identifier: vector
provider_id: meta-reference
type: vector
embedding_model: all-MiniLM-L6-v2
chunk_size_in_tokens: 512
overlap_size_in_tokens: null

View file

@ -0,0 +1,10 @@
name: local-databricks
distribution_spec:
description: Use Databricks for running LLM inference
providers:
inference: remote::databricks
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -0,0 +1,10 @@
name: local-vllm
distribution_spec:
description: Like local, but use vLLM for running LLM inference
providers:
inference: vllm
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,445 +1,445 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import * # noqa: F403
import boto3
from botocore.client import BaseClient
from botocore.config import Config
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
BEDROCK_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
}
class BedrockInferenceAdapter(Inference, RoutableProviderForModels):
@staticmethod
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
retries_config = {
k: v
for k, v in dict(
total_max_attempts=config.total_max_attempts,
mode=config.retry_mode,
).items()
if v is not None
}
config_args = {
k: v
for k, v in dict(
region_name=config.region_name,
retries=retries_config if retries_config else None,
connect_timeout=config.connect_timeout,
read_timeout=config.read_timeout,
).items()
if v is not None
}
boto3_config = Config(**config_args)
session_args = {
k: v
for k, v in dict(
aws_access_key_id=config.aws_access_key_id,
aws_secret_access_key=config.aws_secret_access_key,
aws_session_token=config.aws_session_token,
region_name=config.region_name,
profile_name=config.profile_name,
).items()
if v is not None
}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client("bedrock-runtime", config=boto3_config)
def __init__(self, config: BedrockConfig) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
)
self._config = config
self._client = BedrockInferenceAdapter._create_bedrock_client(config)
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> BaseClient:
return self._client
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
self.client.close()
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
@staticmethod
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
if bedrock_stop_reason == "max_tokens":
return StopReason.out_of_tokens
return StopReason.end_of_turn
@staticmethod
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
for builtin_tool in BuiltinTool:
if builtin_tool.value == tool_name_str:
return builtin_tool
else:
return tool_name_str
@staticmethod
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
converse_api_res["stopReason"]
)
bedrock_message = converse_api_res["output"]["message"]
role = bedrock_message["role"]
contents = bedrock_message["content"]
tool_calls = []
text_content = []
for content in contents:
if "toolUse" in content:
tool_use = content["toolUse"]
tool_calls.append(
ToolCall(
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
tool_use["name"]
),
arguments=tool_use["input"] if "input" in tool_use else None,
call_id=tool_use["toolUseId"],
)
)
elif "text" in content:
text_content.append(content["text"])
return CompletionMessage(
role=role,
content=text_content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)
@staticmethod
def _messages_to_bedrock_messages(
messages: List[Message],
) -> Tuple[List[Dict], Optional[List[Dict]]]:
bedrock_messages = []
system_bedrock_messages = []
user_contents = []
assistant_contents = None
for message in messages:
role = message.role
content_list = (
message.content
if isinstance(message.content, list)
else [message.content]
)
if role == "ipython" or role == "user":
if not user_contents:
user_contents = []
if role == "ipython":
user_contents.extend(
[
{
"toolResult": {
"toolUseId": message.call_id,
"content": [
{"text": content} for content in content_list
],
}
}
]
)
else:
user_contents.extend(
[{"text": content} for content in content_list]
)
if assistant_contents:
bedrock_messages.append(
{"role": "assistant", "content": assistant_contents}
)
assistant_contents = None
elif role == "system":
system_bedrock_messages.extend(
[{"text": content} for content in content_list]
)
elif role == "assistant":
if not assistant_contents:
assistant_contents = []
assistant_contents.extend(
[
{
"text": content,
}
for content in content_list
]
+ [
{
"toolUse": {
"input": tool_call.arguments,
"name": (
tool_call.tool_name
if isinstance(tool_call.tool_name, str)
else tool_call.tool_name.value
),
"toolUseId": tool_call.call_id,
}
}
for tool_call in message.tool_calls
]
)
if user_contents:
bedrock_messages.append({"role": "user", "content": user_contents})
user_contents = None
else:
# Unknown role
pass
if user_contents:
bedrock_messages.append({"role": "user", "content": user_contents})
if assistant_contents:
bedrock_messages.append(
{"role": "assistant", "content": assistant_contents}
)
if system_bedrock_messages:
return bedrock_messages, system_bedrock_messages
return bedrock_messages, None
@staticmethod
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
inference_config = {}
if sampling_params:
param_mapping = {
"max_tokens": "maxTokens",
"temperature": "temperature",
"top_p": "topP",
}
for k, v in param_mapping.items():
if getattr(sampling_params, k):
inference_config[v] = getattr(sampling_params, k)
return inference_config
@staticmethod
def _tool_parameters_to_input_schema(
tool_parameters: Optional[Dict[str, ToolParamDefinition]]
) -> Dict:
input_schema = {"type": "object"}
if not tool_parameters:
return input_schema
json_properties = {}
required = []
for name, param in tool_parameters.items():
json_property = {
"type": param.param_type,
}
if param.description:
json_property["description"] = param.description
if param.required:
required.append(name)
json_properties[name] = json_property
input_schema["properties"] = json_properties
if required:
input_schema["required"] = required
return input_schema
@staticmethod
def _tools_to_tool_config(
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
) -> Optional[Dict]:
if not tools:
return None
bedrock_tools = []
for tool in tools:
tool_name = (
tool.tool_name
if isinstance(tool.tool_name, str)
else tool.tool_name.value
)
tool_spec = {
"toolSpec": {
"name": tool_name,
"inputSchema": {
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
tool.parameters
),
},
}
}
if tool.description:
tool_spec["toolSpec"]["description"] = tool.description
bedrock_tools.append(tool_spec)
tool_config = {
"tools": bedrock_tools,
}
if tool_choice:
tool_config["toolChoice"] = (
{"any": {}}
if tool_choice.value == ToolChoice.required
else {"auto": {}}
)
return tool_config
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> (
AsyncGenerator
): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
bedrock_model = self.map_to_provider_model(model)
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
sampling_params
)
tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
bedrock_messages, system_bedrock_messages = (
BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
)
converse_api_params = {
"modelId": bedrock_model,
"messages": bedrock_messages,
}
if inference_config:
converse_api_params["inferenceConfig"] = inference_config
# Tool use is not supported in streaming mode
if tool_config and not stream:
converse_api_params["toolConfig"] = tool_config
if system_bedrock_messages:
converse_api_params["system"] = system_bedrock_messages
if not stream:
converse_api_res = self.client.converse(**converse_api_params)
output_message = BedrockInferenceAdapter._bedrock_message_to_message(
converse_api_res
)
yield ChatCompletionResponse(
completion_message=output_message,
logprobs=None,
)
else:
converse_stream_api_res = self.client.converse_stream(**converse_api_params)
event_stream = converse_stream_api_res["stream"]
for chunk in event_stream:
if "messageStart" in chunk:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
elif "contentBlockStart" in chunk:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=ToolCall(
tool_name=chunk["contentBlockStart"]["toolUse"][
"name"
],
call_id=chunk["contentBlockStart"]["toolUse"][
"toolUseId"
],
),
parse_status=ToolCallParseStatus.started,
),
)
)
elif "contentBlockDelta" in chunk:
if "text" in chunk["contentBlockDelta"]["delta"]:
delta = chunk["contentBlockDelta"]["delta"]["text"]
else:
delta = ToolCallDelta(
content=ToolCall(
arguments=chunk["contentBlockDelta"]["delta"][
"toolUse"
]["input"]
),
parse_status=ToolCallParseStatus.success,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
)
)
elif "contentBlockStop" in chunk:
# Ignored
pass
elif "messageStop" in chunk:
stop_reason = (
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
chunk["messageStop"]["stopReason"]
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
elif "metadata" in chunk:
# Ignored
pass
else:
# Ignored
pass
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import * # noqa: F403
import boto3
from botocore.client import BaseClient
from botocore.config import Config
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
BEDROCK_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
}
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
@staticmethod
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
retries_config = {
k: v
for k, v in dict(
total_max_attempts=config.total_max_attempts,
mode=config.retry_mode,
).items()
if v is not None
}
config_args = {
k: v
for k, v in dict(
region_name=config.region_name,
retries=retries_config if retries_config else None,
connect_timeout=config.connect_timeout,
read_timeout=config.read_timeout,
).items()
if v is not None
}
boto3_config = Config(**config_args)
session_args = {
k: v
for k, v in dict(
aws_access_key_id=config.aws_access_key_id,
aws_secret_access_key=config.aws_secret_access_key,
aws_session_token=config.aws_session_token,
region_name=config.region_name,
profile_name=config.profile_name,
).items()
if v is not None
}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client("bedrock-runtime", config=boto3_config)
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
)
self._config = config
self._client = BedrockInferenceAdapter._create_bedrock_client(config)
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> BaseClient:
return self._client
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
self.client.close()
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
@staticmethod
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
if bedrock_stop_reason == "max_tokens":
return StopReason.out_of_tokens
return StopReason.end_of_turn
@staticmethod
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
for builtin_tool in BuiltinTool:
if builtin_tool.value == tool_name_str:
return builtin_tool
else:
return tool_name_str
@staticmethod
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
converse_api_res["stopReason"]
)
bedrock_message = converse_api_res["output"]["message"]
role = bedrock_message["role"]
contents = bedrock_message["content"]
tool_calls = []
text_content = []
for content in contents:
if "toolUse" in content:
tool_use = content["toolUse"]
tool_calls.append(
ToolCall(
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
tool_use["name"]
),
arguments=tool_use["input"] if "input" in tool_use else None,
call_id=tool_use["toolUseId"],
)
)
elif "text" in content:
text_content.append(content["text"])
return CompletionMessage(
role=role,
content=text_content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)
@staticmethod
def _messages_to_bedrock_messages(
messages: List[Message],
) -> Tuple[List[Dict], Optional[List[Dict]]]:
bedrock_messages = []
system_bedrock_messages = []
user_contents = []
assistant_contents = None
for message in messages:
role = message.role
content_list = (
message.content
if isinstance(message.content, list)
else [message.content]
)
if role == "ipython" or role == "user":
if not user_contents:
user_contents = []
if role == "ipython":
user_contents.extend(
[
{
"toolResult": {
"toolUseId": message.call_id,
"content": [
{"text": content} for content in content_list
],
}
}
]
)
else:
user_contents.extend(
[{"text": content} for content in content_list]
)
if assistant_contents:
bedrock_messages.append(
{"role": "assistant", "content": assistant_contents}
)
assistant_contents = None
elif role == "system":
system_bedrock_messages.extend(
[{"text": content} for content in content_list]
)
elif role == "assistant":
if not assistant_contents:
assistant_contents = []
assistant_contents.extend(
[
{
"text": content,
}
for content in content_list
]
+ [
{
"toolUse": {
"input": tool_call.arguments,
"name": (
tool_call.tool_name
if isinstance(tool_call.tool_name, str)
else tool_call.tool_name.value
),
"toolUseId": tool_call.call_id,
}
}
for tool_call in message.tool_calls
]
)
if user_contents:
bedrock_messages.append({"role": "user", "content": user_contents})
user_contents = None
else:
# Unknown role
pass
if user_contents:
bedrock_messages.append({"role": "user", "content": user_contents})
if assistant_contents:
bedrock_messages.append(
{"role": "assistant", "content": assistant_contents}
)
if system_bedrock_messages:
return bedrock_messages, system_bedrock_messages
return bedrock_messages, None
@staticmethod
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
inference_config = {}
if sampling_params:
param_mapping = {
"max_tokens": "maxTokens",
"temperature": "temperature",
"top_p": "topP",
}
for k, v in param_mapping.items():
if getattr(sampling_params, k):
inference_config[v] = getattr(sampling_params, k)
return inference_config
@staticmethod
def _tool_parameters_to_input_schema(
tool_parameters: Optional[Dict[str, ToolParamDefinition]]
) -> Dict:
input_schema = {"type": "object"}
if not tool_parameters:
return input_schema
json_properties = {}
required = []
for name, param in tool_parameters.items():
json_property = {
"type": param.param_type,
}
if param.description:
json_property["description"] = param.description
if param.required:
required.append(name)
json_properties[name] = json_property
input_schema["properties"] = json_properties
if required:
input_schema["required"] = required
return input_schema
@staticmethod
def _tools_to_tool_config(
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
) -> Optional[Dict]:
if not tools:
return None
bedrock_tools = []
for tool in tools:
tool_name = (
tool.tool_name
if isinstance(tool.tool_name, str)
else tool.tool_name.value
)
tool_spec = {
"toolSpec": {
"name": tool_name,
"inputSchema": {
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
tool.parameters
),
},
}
}
if tool.description:
tool_spec["toolSpec"]["description"] = tool.description
bedrock_tools.append(tool_spec)
tool_config = {
"tools": bedrock_tools,
}
if tool_choice:
tool_config["toolChoice"] = (
{"any": {}}
if tool_choice.value == ToolChoice.required
else {"auto": {}}
)
return tool_config
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> (
AsyncGenerator
): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
bedrock_model = self.map_to_provider_model(model)
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
sampling_params
)
tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
bedrock_messages, system_bedrock_messages = (
BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
)
converse_api_params = {
"modelId": bedrock_model,
"messages": bedrock_messages,
}
if inference_config:
converse_api_params["inferenceConfig"] = inference_config
# Tool use is not supported in streaming mode
if tool_config and not stream:
converse_api_params["toolConfig"] = tool_config
if system_bedrock_messages:
converse_api_params["system"] = system_bedrock_messages
if not stream:
converse_api_res = self.client.converse(**converse_api_params)
output_message = BedrockInferenceAdapter._bedrock_message_to_message(
converse_api_res
)
yield ChatCompletionResponse(
completion_message=output_message,
logprobs=None,
)
else:
converse_stream_api_res = self.client.converse_stream(**converse_api_params)
event_stream = converse_stream_api_res["stream"]
for chunk in event_stream:
if "messageStart" in chunk:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
elif "contentBlockStart" in chunk:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=ToolCall(
tool_name=chunk["contentBlockStart"]["toolUse"][
"name"
],
call_id=chunk["contentBlockStart"]["toolUse"][
"toolUseId"
],
),
parse_status=ToolCallParseStatus.started,
),
)
)
elif "contentBlockDelta" in chunk:
if "text" in chunk["contentBlockDelta"]["delta"]:
delta = chunk["contentBlockDelta"]["delta"]["text"]
else:
delta = ToolCallDelta(
content=ToolCall(
arguments=chunk["contentBlockDelta"]["delta"][
"toolUse"
]["input"]
),
parse_status=ToolCallParseStatus.success,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
)
)
elif "contentBlockStop" in chunk:
# Ignored
pass
elif "messageStop" in chunk:
stop_reason = (
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
chunk["messageStop"]["stopReason"]
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
elif "metadata" in chunk:
# Ignored
pass
else:
# Ignored
pass

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import DatabricksImplConfig
from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
assert isinstance(
config, DatabricksImplConfig
), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class DatabricksImplConfig(BaseModel):
url: str = Field(
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: str = Field(
default=None,
description="The Databricks API token",
)

View file

@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import DatabricksImplConfig
DATABRICKS_SUPPORTED_MODELS = {
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
}
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(request, r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream,
**get_sampling_options(request),
}

View file

@ -10,14 +10,19 @@ from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import FireworksImplConfig
@ -27,21 +32,18 @@ FIREWORKS_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
"Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
}
class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: FireworksImplConfig) -> None:
RoutableProviderForModels.__init__(
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
)
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> Fireworks:
return Fireworks(api_key=self.config.api_key)
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
@ -49,7 +51,7 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
async def shutdown(self) -> None:
pass
async def completion(
def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -59,27 +61,7 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
fireworks_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
fireworks_messages.append({"role": role, "content": message.content})
return fireworks_messages
def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@ -101,147 +83,41 @@ class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to fireworks
options = self.get_fireworks_chat_options(request)
fireworks_model = self.map_to_provider_model(request.model)
if not request.stream:
r = await self.client.chat.completions.acreate(
model=fireworks_model,
messages=self._messages_to_fireworks_messages(messages),
stream=False,
**options,
)
stop_reason = None
if r.choices[0].finish_reason:
if r.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif r.choices[0].finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r.choices[0].message.content, stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
client = Fireworks(api_key=self.config.api_key)
if stream:
return self._stream_chat_completion(request, client)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
return self._nonstream_chat_completion(request, client)
buffer = ""
ipython = False
stop_reason = None
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await client.completion.acreate(**params)
return process_chat_completion_response(request, r, self.formatter)
async for chunk in self.client.chat.completions.acreate(
model=fireworks_model,
messages=self._messages_to_fireworks_messages(messages),
stream=True,
**options,
):
if chunk.choices[0].finish_reason:
if stop_reason is None and chunk.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif (
stop_reason is None
and chunk.choices[0].finish_reason == "length"
):
stop_reason = StopReason.out_of_tokens
break
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
) -> AsyncGenerator:
params = self._get_params(request)
text = chunk.choices[0].delta.content
if text is None:
continue
stream = client.completion.acreate(**params)
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
):
yield chunk
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt = chat_completion_request_to_prompt(request, self.formatter)
# Fireworks always prepends with BOS
if prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
options = get_sampling_options(request)
options.setdefault("max_tokens", 512)
return {
"model": self.map_to_provider_model(request.model),
"prompt": prompt,
"stream": request.stream,
**options,
}

View file

@ -7,6 +7,10 @@
from llama_stack.distribution.datatypes import RemoteProviderConfig
class OllamaImplConfig(RemoteProviderConfig):
port: int = 11434
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
from .ollama import OllamaInferenceAdapter

View file

@ -9,34 +9,36 @@ from typing import AsyncGenerator
import httpx
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
# "Llama3.1-8B-Instruct": "llama3.1",
OLLAMA_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
"Llama-Guard-3-8B": "xe/llamaguard3:latest",
}
class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
class OllamaInferenceAdapter(Inference):
def __init__(self, url: str) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
)
self.url = url
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
self.formatter = ChatFormat(Tokenizer.get_instance())
@property
def client(self) -> AsyncClient:
@ -54,7 +56,29 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
async def shutdown(self) -> None:
pass
async def completion(
async def register_model(self, model: ModelDef) -> None:
if model.identifier not in OLLAMA_SUPPORTED_MODELS:
raise ValueError(
f"Unsupported model {model.identifier}. Supported models: {OLLAMA_SUPPORTED_MODELS.keys()}"
)
ollama_model = OLLAMA_SUPPORTED_MODELS[model.identifier]
res = await self.client.ps()
need_model_pull = True
for r in res["models"]:
if ollama_model == r["model"]:
need_model_pull = False
break
print(f"Ollama model `{ollama_model}` needs pull -> {need_model_pull}")
if need_model_pull:
print(f"Pulling model: {ollama_model}")
status = await self.client.pull(ollama_model)
assert (
status["status"] == "success"
), f"Failed to pull model {self.model} in ollama"
def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -64,32 +88,7 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
ollama_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
ollama_messages.append({"role": role, "content": message.content})
return ollama_messages
def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
if (
request.sampling_params.repetition_penalty is not None
and request.sampling_params.repetition_penalty != 1.0
):
options["repeat_penalty"] = request.sampling_params.repetition_penalty
return options
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@ -110,156 +109,54 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
stream=stream,
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
ollama_model = self.map_to_provider_model(request.model)
res = await self.client.ps()
need_model_pull = True
for r in res["models"]:
if ollama_model == r["model"]:
need_model_pull = False
break
if need_model_pull:
print(f"Pulling model: {ollama_model}")
status = await self.client.pull(ollama_model)
assert (
status["status"] == "success"
), f"Failed to pull model {self.model} in ollama"
if not request.stream:
r = await self.client.chat(
model=ollama_model,
messages=self._messages_to_ollama_messages(messages),
stream=False,
options=options,
)
stop_reason = None
if r["done"]:
if r["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif r["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r["message"]["content"], stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
if stream:
return self._stream_chat_completion(request)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
return self._nonstream_chat_completion(request)
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": OLLAMA_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"options": get_sampling_options(request),
"raw": True,
"stream": request.stream,
}
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await self.client.generate(**params)
assert isinstance(r, dict)
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["response"],
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(request, response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.generate(**params)
async for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["response"],
)
)
stream = await self.client.chat(
model=ollama_model,
messages=self._messages_to_ollama_messages(messages),
stream=True,
options=options,
)
buffer = ""
ipython = False
stop_reason = None
async for chunk in stream:
if chunk["done"]:
if stop_reason is None and chunk["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif stop_reason is None and chunk["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
break
text = chunk["message"]["content"]
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
):
yield chunk

View file

@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
class SampleInferenceImpl(Inference, RoutableProvider):
class SampleInferenceImpl(Inference):
def __init__(self, config: SampleConfig):
self.config = config
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
async def register_model(self, model: ModelDef) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass

View file

@ -34,7 +34,7 @@ class InferenceEndpointImplConfig(BaseModel):
@json_schema_type
class InferenceAPIImplConfig(BaseModel):
model_id: str = Field(
huggingface_repo: str = Field(
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
)
api_token: Optional[str] = Field(

View file

@ -10,14 +10,19 @@ from typing import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.distribution.datatypes import RoutableProvider
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info,
)
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
@ -25,24 +30,31 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
logger = logging.getLogger(__name__)
class _HfAdapter(Inference, RoutableProvider):
class _HfAdapter(Inference):
client: AsyncInferenceClient
max_tokens: int
model_id: str
def __init__(self) -> None:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
self.formatter = ChatFormat(Tokenizer.get_instance())
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def register_model(self, model: ModelDef) -> None:
resolved_model = resolve_model(model.identifier)
if resolved_model is None:
raise ValueError(f"Unknown model: {model.identifier}")
if not resolved_model.huggingface_repo:
raise ValueError(
f"Model {model.identifier} does not have a HuggingFace repo"
)
if self.model_id != resolved_model.huggingface_repo:
raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
async def shutdown(self) -> None:
pass
async def completion(
def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -52,16 +64,7 @@ class _HfAdapter(Inference, RoutableProvider):
) -> AsyncGenerator:
raise NotImplementedError()
def get_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@ -83,146 +86,64 @@ class _HfAdapter(Inference, RoutableProvider):
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens)
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
input_tokens = len(model_input.tokens)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await self.client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
text="".join(t.text for t in r.details.tokens),
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(request, response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
choice = OpenAICompatCompletionChoice(text=token_result.text)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = chat_completion_request_to_model_input_info(
request, self.formatter
)
max_new_tokens = min(
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1,
)
print(f"Calculated max_new_tokens: {max_new_tokens}")
options = self.get_chat_options(request)
if not request.stream:
response = await self.client.text_generation(
prompt=prompt,
stream=False,
details=True,
max_new_tokens=max_new_tokens,
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**options,
)
stop_reason = None
if response.details.finish_reason:
if response.details.finish_reason in ["stop", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif response.details.finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
response.generated_text,
stop_reason,
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
ipython = False
stop_reason = None
tokens = []
async for response in await self.client.text_generation(
prompt=prompt,
stream=True,
details=True,
max_new_tokens=max_new_tokens,
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**options,
):
token_result = response.token
buffer += token_result.text
tokens.append(token_result.id)
if not ipython and buffer.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
if stop_reason is None:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# parse tool calls and report errors
message = self.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
options = get_sampling_options(request)
return dict(
prompt=prompt,
stream=request.stream,
details=True,
max_new_tokens=max_new_tokens,
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**options,
)
class TGIAdapter(_HfAdapter):
@ -236,7 +157,7 @@ class TGIAdapter(_HfAdapter):
class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient(
model=config.model_id, token=config.api_token
model=config.huggingface_repo, token=config.api_token
)
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]

View file

@ -8,17 +8,22 @@ from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from .config import TogetherImplConfig
@ -34,19 +39,14 @@ TOGETHER_SUPPORTED_MODELS = {
class TogetherInferenceAdapter(
Inference, NeedsRequestProviderData, RoutableProviderForModels
ModelRegistryHelper, Inference, NeedsRequestProviderData
):
def __init__(self, config: TogetherImplConfig) -> None:
RoutableProviderForModels.__init__(
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
)
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> Together:
return Together(api_key=self.config.api_key)
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
@ -64,27 +64,7 @@ class TogetherInferenceAdapter(
) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_together_messages(self, messages: list[Message]) -> list:
together_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
together_messages.append({"role": role, "content": message.content})
return together_messages
def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@ -95,7 +75,6 @@ class TogetherInferenceAdapter(
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
together_api_key = None
if self.config.api_key is not None:
together_api_key = self.config.api_key
@ -108,7 +87,6 @@ class TogetherInferenceAdapter(
together_api_key = provider_data.together_api_key
client = Together(api_key=together_api_key)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
messages=messages,
@ -120,146 +98,39 @@ class TogetherInferenceAdapter(
logprobs=logprobs,
)
# accumulate sampling params and other options to pass to together
options = self.get_together_chat_options(request)
together_model = self.map_to_provider_model(request.model)
messages = augment_messages_for_tools(request)
if not request.stream:
# TODO: might need to add back an async here
r = client.chat.completions.create(
model=together_model,
messages=self._messages_to_together_messages(messages),
stream=False,
**options,
)
stop_reason = None
if r.choices[0].finish_reason:
if (
r.choices[0].finish_reason == "stop"
or r.choices[0].finish_reason == "eos"
):
stop_reason = StopReason.end_of_turn
elif r.choices[0].finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r.choices[0].message.content, stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
if stream:
return self._stream_chat_completion(request, client)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
return self._nonstream_chat_completion(request, client)
buffer = ""
ipython = False
stop_reason = None
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Together
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(request, r, self.formatter)
for chunk in client.chat.completions.create(
model=together_model,
messages=self._messages_to_together_messages(messages),
stream=True,
**options,
):
if finish_reason := chunk.choices[0].finish_reason:
if stop_reason is None and finish_reason in ["stop", "eos"]:
stop_reason = StopReason.end_of_turn
elif stop_reason is None and finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: Together
) -> AsyncGenerator:
params = self._get_params(request)
text = chunk.choices[0].delta.content
if text is None:
continue
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
):
yield chunk
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream,
**get_sampling_options(request),
}

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import json
import uuid
from typing import List
from urllib.parse import urlparse
@ -13,7 +12,6 @@ import chromadb
from numpy.typing import NDArray
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
@ -65,7 +63,7 @@ class ChromaIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class ChromaMemoryAdapter(Memory, RoutableProvider):
class ChromaMemoryAdapter(Memory):
def __init__(self, url: str) -> None:
print(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/")
@ -93,48 +91,33 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
pass
async def create_memory_bank(
async def register_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
collection = await self.client.create_collection(
name=bank_id,
metadata={"bank": bank.json()},
memory_bank: MemoryBankDef,
) -> None:
assert (
memory_bank.type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
collection = await self.client.get_or_create_collection(
name=memory_bank.identifier,
)
bank_index = BankWithIndex(
bank=bank, index=ChromaIndex(self.client, collection)
bank=memory_bank, index=ChromaIndex(self.client, collection)
)
self.cache[bank_id] = bank_index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
self.cache[memory_bank.identifier] = bank_index
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if bank is None:
raise ValueError(f"Bank {bank_id} not found")
collections = await self.client.list_collections()
for collection in collections:
if collection.name == bank_id:
print(collection.metadata)
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
index = BankWithIndex(
bank=bank,
index=ChromaIndex(self.client, collection),

View file

@ -4,18 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import List, Tuple
from typing import List
import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import execute_values, Json
from pydantic import BaseModel
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
@ -32,33 +28,6 @@ def check_extension_version(cur):
return result[0] if result else None
def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
query = sql.SQL(
"""
INSERT INTO metadata_store (key, data)
VALUES %s
ON CONFLICT (key) DO UPDATE
SET data = EXCLUDED.data
"""
)
values = [(key, Json(model.dict())) for key, model in keys_models]
execute_values(cur, query, values, template="(%s, %s)")
def load_models(cur, keys: List[str], cls):
query = "SELECT key, data FROM metadata_store"
if keys:
placeholders = ",".join(["%s"] * len(keys))
query += f" WHERE key IN ({placeholders})"
cur.execute(query, keys)
else:
cur.execute(query)
rows = cur.fetchall()
return [cls(**row["data"]) for row in rows]
class PGVectorIndex(EmbeddingIndex):
def __init__(self, bank: MemoryBank, dimension: int, cursor):
self.cursor = cursor
@ -119,7 +88,7 @@ class PGVectorIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class PGVectorMemoryAdapter(Memory, RoutableProvider):
class PGVectorMemoryAdapter(Memory):
def __init__(self, config: PGVectorConfig) -> None:
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
self.config = config
@ -144,14 +113,6 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
else:
raise RuntimeError("Vector extension is not installed.")
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS metadata_store (
key TEXT PRIMARY KEY,
data JSONB
)
"""
)
except Exception as e:
import traceback
@ -161,51 +122,28 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
pass
async def create_memory_bank(
async def register_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
upsert_models(
self.cursor,
[
(bank.bank_id, bank),
],
)
memory_bank: MemoryBankDef,
) -> None:
assert (
memory_bank.type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
bank=memory_bank,
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
banks = load_models(self.cursor, [bank_id], MemoryBank)
if not banks:
return None
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found")
bank = banks[0]
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),

View file

@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
class SampleMemoryImpl(Memory, RoutableProvider):
class SampleMemoryImpl(Memory):
def __init__(self, config: SampleConfig):
self.config = config
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
# these are the memory banks the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
async def get_adapter_impl(config: WeaviateConfig, _deps):
from .weaviate import WeaviateMemoryAdapter
impl = WeaviateMemoryAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class WeaviateRequestProviderData(BaseModel):
weaviate_api_key: str
weaviate_cluster_url: str
class WeaviateConfig(BaseModel):
pass

View file

@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
import weaviate
import weaviate.classes as wvc
from numpy.typing import NDArray
from weaviate.classes.init import Auth
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from .config import WeaviateConfig, WeaviateRequestProviderData
class WeaviateIndex(EmbeddingIndex):
def __init__(self, client: weaviate.Client, collection_name: str):
self.client = client
self.collection_name = collection_name
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
data_objects = []
for i, chunk in enumerate(chunks):
data_objects.append(
wvc.data.DataObject(
properties={
"chunk_content": chunk.json(),
},
vector=embeddings[i].tolist(),
)
)
# Inserting chunks into a prespecified Weaviate collection
collection = self.client.collections.get(self.collection_name)
# TODO: make this async friendly
collection.data.insert_many(data_objects)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
collection = self.client.collections.get(self.collection_name)
results = collection.query.near_vector(
near_vector=embedding.tolist(),
limit=k,
return_metadata=wvc.query.MetadataQuery(distance=True),
)
chunks = []
scores = []
for doc in results.objects:
chunk_json = doc.properties["chunk_content"]
try:
chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict)
except Exception:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {chunk_json}")
continue
chunks.append(chunk)
scores.append(1.0 / doc.metadata.distance)
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
def __init__(self, config: WeaviateConfig) -> None:
self.config = config
self.client_cache = {}
self.cache = {}
def _get_client(self) -> weaviate.Client:
provider_data = self.get_request_provider_data()
assert provider_data is not None, "Request provider data must be set"
assert isinstance(provider_data, WeaviateRequestProviderData)
key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}"
if key in self.client_cache:
return self.client_cache[key]
client = weaviate.connect_to_weaviate_cloud(
cluster_url=provider_data.weaviate_cluster_url,
auth_credentials=Auth.api_key(provider_data.weaviate_api_key),
)
self.client_cache[key] = client
return client
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
for client in self.client_cache.values():
client.close()
async def register_memory_bank(
self,
memory_bank: MemoryBankDef,
) -> None:
assert (
memory_bank.type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
client = self._get_client()
# Create collection if it doesn't exist
if not client.collections.exists(memory_bank.identifier):
client.collections.create(
name=memory_bank.identifier,
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
],
)
index = BankWithIndex(
bank=memory_bank,
index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
)
self.cache[memory_bank.identifier] = index
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found")
client = self._get_client()
if not client.collections.exists(bank_id):
raise ValueError(f"Collection with name `{bank_id}` not found")
index = BankWithIndex(
bank=bank,
index=WeaviateIndex(client=client, collection_name=bank_id),
)
self.cache[bank_id] = index
return index
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)

View file

@ -7,14 +7,12 @@
import json
import logging
import traceback
from typing import Any, Dict, List
import boto3
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from .config import BedrockSafetyConfig
@ -22,16 +20,17 @@ from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)
SUPPORTED_SHIELD_TYPES = [
"bedrock_guardrail",
BEDROCK_SUPPORTED_SHIELDS = [
ShieldType.generic_content_shield.value,
]
class BedrockSafetyAdapter(Safety, RoutableProvider):
class BedrockSafetyAdapter(Safety):
def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
self.config = config
self.registered_shields = []
async def initialize(self) -> None:
try:
@ -45,16 +44,27 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys:
if key not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}")
async def register_shield(self, shield: ShieldDef) -> None:
if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
shield_params = shield.params
if "guardrailIdentifier" not in shield_params:
raise ValueError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in shield_params:
raise ValueError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
@ -69,52 +79,38 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
"""
try:
logger.debug(f"run_shield::{params}::messages={messages}")
if "guardrailIdentifier" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
shield_params = shield_def.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
response = self.boto_client.apply_guardrail(
guardrailIdentifier=params.get("guardrailIdentifier"),
guardrailVersion=params.get("guardrailVersion"),
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
logger.debug(f"run_shield:: response: {response}::")
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
response = self.boto_client.apply_guardrail(
guardrailIdentifier=shield_params["guardrailIdentifier"],
guardrailVersion=shield_params["guardrailVersion"],
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
except Exception:
error_str = traceback.format_exc()
logger.error(
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
return None

View file

@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
class SampleSafetyImpl(Safety, RoutableProvider):
class SampleSafetyImpl(Safety):
def __init__(self, config: SampleConfig):
self.config = config
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
async def register_shield(self, shield: ShieldDef) -> None:
# these are the safety shields the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass

View file

@ -6,26 +6,20 @@
from together import Together
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from .config import TogetherSafetyConfig
SAFETY_SHIELD_TYPES = {
TOGETHER_SHIELD_MODEL_MAP = {
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
}
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config
@ -35,16 +29,20 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys:
if key not in SAFETY_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}")
async def register_shield(self, shield: ShieldDef) -> None:
if shield.type != ShieldType.llama_guard.value:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type not in SAFETY_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
model = shield_def.params.get("model", "llama_guard")
if model not in TOGETHER_SHIELD_MODEL_MAP:
raise ValueError(f"Unsupported safety model: {model}")
together_api_key = None
if self.config.api_key is not None:
@ -57,8 +55,6 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
)
together_api_key = provider_data.together_api_key
model_name = SAFETY_SHIELD_TYPES[shield_type]
# messages can have role assistant or user
api_messages = []
for message in messages:
@ -66,7 +62,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
api_messages.append({"role": message.role, "content": message.content})
violation = await get_safety_response(
together_api_key, model_name, api_messages
together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
)
return RunShieldResponse(violation=violation)

View file

@ -43,23 +43,14 @@ class ProviderSpec(BaseModel):
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
# used internally by the resolver; this is a hack for now
deps__: List[str] = Field(default_factory=list)
class RoutingTable(Protocol):
def get_routing_keys(self) -> List[str]: ...
def get_provider_impl(self, routing_key: str) -> Any: ...
class RoutableProvider(Protocol):
"""
A provider which sits behind the RoutingTable and can get routed to.
All Inference / Safety / Memory providers fall into this bucket.
"""
async def validate_routing_keys(self, keys: List[str]) -> None: ...
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
@ -156,6 +147,10 @@ as being "Llama Stack compatible"
return None
def is_passthrough(spec: ProviderSpec) -> bool:
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None

View file

@ -144,6 +144,8 @@ class ChatAgent(ShieldRunnerMixin):
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
@ -635,14 +637,13 @@ class ChatAgent(ShieldRunnerMixin):
raise ValueError(f"Session {session_id} not found")
if session_info.memory_bank_id is None:
memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session_id}",
config=VectorMemoryBankConfig(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
bank_id = f"memory_bank_{session_id}"
memory_bank = VectorMemoryBankDef(
identifier=bank_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
)
bank_id = memory_bank.bank_id
await self.memory_api.register_memory_bank(memory_bank)
await self.storage.add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id
@ -673,7 +674,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _retrieve_context(
self, session_id: str, messages: List[Message], attachments: List[Attachment]
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
) -> Tuple[Optional[List[str]], Optional[List[int]]]: # (rag_context, bank_ids)
bank_ids = []
memory = self._memory_tool_definition()
@ -722,12 +723,13 @@ class ChatAgent(ShieldRunnerMixin):
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
if not chunks:
return None, bank_ids
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
)
if not chunks:
return None, bank_ids
tokens = 0
picked = []

View file

@ -100,7 +100,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id=session_id,
)
async def create_agent_turn(
def create_agent_turn(
self,
agent_id: str,
session_id: str,
@ -113,16 +113,22 @@ class MetaReferenceAgentsImpl(Agents):
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
agent = await self.get_agent(agent_id)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = AgentTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=messages,
attachments=attachments,
stream=stream,
stream=True,
)
if stream:
return self._create_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _create_agent_turn_streaming(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
async for event in agent.create_and_execute_turn(request):
yield event

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import CodeShieldConfig
async def get_provider_impl(config: CodeShieldConfig, deps):
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from termcolor import cprint
from .config import CodeScannerConfig
from llama_stack.apis.safety import * # noqa: F403
class MetaReferenceCodeScannerSafetyImpl(Safety):
def __init__(self, config: CodeScannerConfig, deps) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: ShieldDef) -> None:
if shield.type != ShieldType.code_scanner.value:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
async def run_shield(
self,
shield_type: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
from codeshield.cs import CodeShield
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
result = await CodeShield.scan_code(text)
violation = None
if result.is_insecure:
violation = SafetyViolation(
violation_level=(ViolationLevel.ERROR),
user_message="Sorry, I found security concerns in the code.",
metadata={
"violation_type": ",".join(
[issue.pattern_id for issue in result.issues_found]
)
},
)
return RunShieldResponse(violation=violation)

View file

@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class CodeShieldConfig(BaseModel):
pass

View file

@ -43,13 +43,12 @@ class MetaReferenceEvalsImpl(Evals):
print("generation start")
for msg in x1[:5]:
print("generation for msg: ", msg)
response = self.inference_api.chat_completion(
response = await self.inference_api.chat_completion(
model=model,
messages=[msg],
stream=False,
)
async for x in response:
generation_outputs.append(x.completion_message.content)
generation_outputs.append(response.completion_message.content)
x2 = task_impl.postprocess(generation_outputs)
eval_results = task_impl.score(x2)

View file

@ -297,7 +297,7 @@ class Llama:
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
if logprobs
else None
),

View file

@ -6,15 +6,14 @@
import asyncio
from typing import AsyncIterator, List, Union
from typing import AsyncGenerator, List
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
from .config import MetaReferenceImplConfig
@ -25,7 +24,7 @@ from .model_parallel import LlamaModelParallelGenerator
SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, RoutableProvider):
class MetaReferenceInferenceImpl(Inference):
def __init__(self, config: MetaReferenceImplConfig) -> None:
self.config = config
model = resolve_model(config.model)
@ -35,21 +34,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
# verify that the checkpoint actually is for this model lol
async def initialize(self) -> None:
print(f"Loading model `{self.model.descriptor()}`")
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
assert (
len(routing_keys) == 1
), f"Only one routing key is supported {routing_keys}"
assert routing_keys[0] == self.config.model
async def register_model(self, model: ModelDef) -> None:
if model.identifier != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
)
async def shutdown(self) -> None:
self.generator.stop()
# hm, when stream=False, we should not be doing SSE :/ which is what the
# top-level server is going to do. make the typing more specific here
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@ -59,9 +57,10 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncIterator[
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
]:
) -> AsyncGenerator:
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
@ -74,7 +73,6 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
@ -88,21 +86,74 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if request.stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async with SEMAPHORE:
if request.stream:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
messages = chat_completion_request_to_messages(request)
tokens = []
logprobs = []
stop_reason = None
buffer = ""
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
return ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async with SEMAPHORE:
messages = chat_completion_request_to_messages(request)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
tokens = []
logprobs = []
stop_reason = None
ipython = False
for token_result in self.generator.chat_completion(
@ -113,10 +164,9 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
buffer += token_result.text
tokens.append(token_result.token)
if not ipython and buffer.startswith("<|python_tag|>"):
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -127,13 +177,6 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if not request.stream:
if request.logprobs:
logprobs.append(token_result.logprob)
continue
if token_result.text == "<|eot_id|>":
@ -154,59 +197,61 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
delta = text
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# TODO(ashwin): parse tool calls separately here and report errors?
# if someone breaks the iteration before coming here we are toast
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
if request.stream:
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)

View file

@ -13,15 +13,15 @@ from typing import Optional
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3.api.model import Transformer, TransformerBlock
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from termcolor import cprint
from torch import Tensor
from llama_stack.apis.inference import QuantizationType
from llama_stack.apis.inference.config import (
CheckpointQuantizationFormat,
from llama_stack.providers.impls.meta_reference.inference.config import (
MetaReferenceImplConfig,
)

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import logging
import uuid
from typing import Any, Dict, List, Optional
@ -14,7 +13,6 @@ import numpy as np
from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.utils.memory.vector_store import (
@ -63,7 +61,7 @@ class FaissIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class FaissMemoryImpl(Memory, RoutableProvider):
class FaissMemoryImpl(Memory):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.cache = {}
@ -72,37 +70,18 @@ class FaissMemoryImpl(Memory, RoutableProvider):
async def shutdown(self) -> None: ...
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
pass
async def create_memory_bank(
async def register_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
assert url is None, "URL is not supported for this implementation"
memory_bank: MemoryBankDef,
) -> None:
assert (
config.type == MemoryBankType.vector.value
), f"Only vector banks are supported {config.type}"
memory_bank.type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
index = BankWithIndex(
bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
)
index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
self.cache[bank_id] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
index = self.cache.get(bank_id)
if index is None:
return None
return index.bank
self.cache[memory_bank.identifier] = index
async def insert_documents(
self,

View file

@ -44,7 +44,6 @@ def message_content_as_str(message: Message) -> str:
return interleaved_text_media_as_str(message.content)
# For shields that operate on simple strings
class TextShield(ShieldBase):
def convert_messages_to_text(self, messages: List[Message]) -> str:
return "\n".join([message_content_as_str(m) for m in messages])
@ -56,9 +55,3 @@ class TextShield(ShieldBase):
@abstractmethod
async def run_impl(self, text: str) -> ShieldResponse:
raise NotImplementedError()
class DummyShield(TextShield):
async def run_impl(self, text: str) -> ShieldResponse:
# Dummy return LOW to test e2e
return ShieldResponse(is_violation=False)

View file

@ -9,23 +9,19 @@ from typing import List, Optional
from llama_models.sku_list import CoreModelId, safety_models
from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator
class MetaReferenceShieldType(Enum):
llama_guard = "llama_guard"
code_scanner_guard = "code_scanner_guard"
injection_shield = "injection_shield"
jailbreak_shield = "jailbreak_shield"
class PromptGuardType(Enum):
injection = "injection"
jailbreak = "jailbreak"
class LlamaGuardShieldConfig(BaseModel):
model: str = "Llama-Guard-3-1B"
excluded_categories: List[str] = []
disable_input_check: bool = False
disable_output_check: bool = False
@validator("model")
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = [
@ -47,10 +43,6 @@ class LlamaGuardShieldConfig(BaseModel):
return model
class PromptGuardShieldConfig(BaseModel):
model: str = "Prompt-Guard-86M"
class SafetyConfig(BaseModel):
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
prompt_guard_shield: Optional[PromptGuardShieldConfig] = None
enable_prompt_guard: Optional[bool] = False

View file

@ -113,8 +113,6 @@ class LlamaGuardShield(ShieldBase):
model: str,
inference_api: Inference,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
@ -132,8 +130,6 @@ class LlamaGuardShield(ShieldBase):
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
@ -180,12 +176,6 @@ class LlamaGuardShield(ShieldBase):
async def run(self, messages: List[Message]) -> ShieldResponse:
messages = self.validate_messages(messages)
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(is_violation=False)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
is_violation=False,
)
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
shield_input_message = self.build_vision_shield_input(messages)

View file

@ -6,56 +6,43 @@
from typing import Any, Dict, List
from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api, RoutableProvider
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
)
from .config import MetaReferenceShieldType, SafetyConfig
from .shields import (
CodeScannerShield,
InjectionShield,
JailbreakShield,
LlamaGuardShield,
PromptGuardShield,
ShieldBase,
)
from .base import OnViolationAction, ShieldBase
from .config import SafetyConfig
from .llama_guard import LlamaGuardShield
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
def resolve_and_get_path(model_name: str) -> str:
model = resolve_model(model_name)
assert model is not None, f"Could not resolve model {model_name}"
model_dir = model_local_dir(model.descriptor())
return model_dir
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
class MetaReferenceSafetyImpl(Safety, RoutableProvider):
class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
self.available_shields = []
if config.llama_guard_shield:
self.available_shields.append(ShieldType.llama_guard.value)
if config.enable_prompt_guard:
self.available_shields.append(ShieldType.prompt_guard.value)
async def initialize(self) -> None:
shield_cfg = self.config.prompt_guard_shield
if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model)
if self.config.enable_prompt_guard:
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
_ = PromptGuardShield.instance(model_dir)
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
available_shields = [v.value for v in MetaReferenceShieldType]
for key in routing_keys:
if key not in available_shields:
raise ValueError(f"Unknown safety shield type: {key}")
async def register_shield(self, shield: ShieldDef) -> None:
if shield.type not in self.available_shields:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
async def run_shield(
self,
@ -63,10 +50,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
available_shields = [v.value for v in MetaReferenceShieldType]
assert shield_type in available_shields, f"Unknown shield {shield_type}"
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))
shield = self.get_shield_impl(shield_def)
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
@ -92,34 +80,22 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
return RunShieldResponse(violation=violation)
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
cfg = self.config
if typ == MetaReferenceShieldType.llama_guard:
cfg = cfg.llama_guard_shield
assert (
cfg is not None
), "Cannot use LlamaGuardShield since not present in config"
def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
if shield.type == ShieldType.llama_guard.value:
cfg = self.config.llama_guard_shield
return LlamaGuardShield(
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
disable_input_check=cfg.disable_input_check,
disable_output_check=cfg.disable_output_check,
)
elif typ == MetaReferenceShieldType.jailbreak_shield:
assert (
cfg.prompt_guard_shield is not None
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
return JailbreakShield.instance(model_dir)
elif typ == MetaReferenceShieldType.injection_shield:
assert (
cfg.prompt_guard_shield is not None
), "Cannot use PromptGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
return InjectionShield.instance(model_dir)
elif typ == MetaReferenceShieldType.code_scanner_guard:
return CodeScannerShield.instance()
elif shield.type == ShieldType.prompt_guard.value:
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
subtype = shield.params.get("prompt_guard_type", "injection")
if subtype == "injection":
return InjectionShield.instance(model_dir)
elif subtype == "jailbreak":
return JailbreakShield.instance(model_dir)
else:
raise ValueError(f"Unknown prompt guard type: {subtype}")
else:
raise ValueError(f"Unknown shield type: {typ}")
raise ValueError(f"Unknown shield type: {shield.type}")

View file

@ -1,33 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# supress warnings and spew of logs from hugging face
import transformers
from .base import ( # noqa: F401
DummyShield,
OnViolationAction,
ShieldBase,
ShieldResponse,
TextShield,
)
from .code_scanner import CodeScannerShield # noqa: F401
from .llama_guard import LlamaGuardShield # noqa: F401
from .prompt_guard import ( # noqa: F401
InjectionShield,
JailbreakShield,
PromptGuardShield,
)
transformers.logging.set_verbosity_error()
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings("ignore")

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from .base import ShieldResponse, TextShield
class CodeScannerShield(TextShield):
async def run_impl(self, text: str) -> ShieldResponse:
from codeshield.cs import CodeShield
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
result = await CodeShield.scan_code(text)
if result.is_insecure:
return ShieldResponse(
is_violation=True,
violation_type=",".join(
[issue.pattern_id for issue in result.issues_found]
),
violation_return_message="Sorry, I found security concerns in the code.",
)
else:
return ShieldResponse(is_violation=False)

View file

@ -0,0 +1,11 @@
from typing import Any
from .config import VLLMConfig
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
from .vllm import VLLMInferenceImpl
impl = VLLMInferenceImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
@json_schema_type
class VLLMConfig(BaseModel):
"""Configuration for the vLLM inference provider."""
model: str = Field(
default="Llama3.1-8B-Instruct",
description="Model descriptor from `llama model list`",
)
tensor_parallel_size: int = Field(
default=1,
description="Number of tensor parallel replicas (number of GPUs to use).",
)
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
return model

View file

@ -0,0 +1,241 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import uuid
from typing import Any
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import VLLMConfig
log = logging.getLogger(__name__)
def _random_uuid() -> str:
return str(uuid.uuid4().hex)
def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
"""Convert sampling params to vLLM sampling params."""
if sampling_params is None:
return SamplingParams()
# TODO convert what I saw in my first test ... but surely there's more to do here
kwargs = {
"temperature": sampling_params.temperature,
}
if sampling_params.top_k >= 1:
kwargs["top_k"] = sampling_params.top_k
if sampling_params.top_p:
kwargs["top_p"] = sampling_params.top_p
if sampling_params.max_tokens >= 1:
kwargs["max_tokens"] = sampling_params.max_tokens
if sampling_params.repetition_penalty > 0:
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
return SamplingParams(**kwargs)
class VLLMInferenceImpl(ModelRegistryHelper, Inference):
"""Inference implementation for vLLM."""
HF_MODEL_MAPPINGS = {
# TODO: seems like we should be able to build this table dynamically ...
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
"Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
"Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
"Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
"Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
"Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
"Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
"Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
"Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
"Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
}
def __init__(self, config: VLLMConfig):
Inference.__init__(self)
ModelRegistryHelper.__init__(
self,
stack_to_provider_models_map=self.HF_MODEL_MAPPINGS,
)
self.config = config
self.engine = None
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
async def initialize(self):
"""Initialize the vLLM inference adapter."""
log.info("Initializing vLLM inference adapter")
# Disable usage stats reporting. This would be a surprising thing for most
# people to find out was on by default.
# https://docs.vllm.ai/en/latest/serving/usage_stats.html
if "VLLM_NO_USAGE_STATS" not in os.environ:
os.environ["VLLM_NO_USAGE_STATS"] = "1"
hf_model = self.HF_MODEL_MAPPINGS.get(self.config.model)
# TODO -- there are a ton of options supported here ...
engine_args = AsyncEngineArgs()
engine_args.model = hf_model
# We will need a new config item for this in the future if model support is more broad
# than it is today (llama only)
engine_args.tokenizer = hf_model
engine_args.tensor_parallel_size = self.config.tensor_parallel_size
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
async def shutdown(self):
"""Shutdown the vLLM inference adapter."""
log.info("Shutting down vLLM inference adapter")
if self.engine:
self.engine.shutdown_background_loop()
def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Any | None = ...,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
log.info("vLLM completion")
messages = [UserMessage(content=content)]
return self.chat_completion(
model=model,
messages=messages,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
def chat_completion(
self,
model: str,
messages: list[Message],
sampling_params: Any | None = ...,
tools: list[ToolDefinition] | None = ...,
tool_choice: ToolChoice | None = ...,
tool_prompt_format: ToolPromptFormat | None = ...,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
log.info("vLLM chat completion")
assert self.engine is not None
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = chat_completion_request_to_prompt(request, self.formatter)
vllm_sampling_params = _vllm_sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id
)
if stream:
return self._stream_chat_completion(request, results_generator)
else:
return self._nonstream_chat_completion(request, results_generator)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> ChatCompletionResponse:
outputs = [o async for o in results_generator]
final_output = outputs[-1]
assert final_output is not None
outputs = final_output.outputs
finish_reason = outputs[-1].stop_reason
choice = OpenAICompatCompletionChoice(
finish_reason=finish_reason,
text="".join([output.text for output in outputs]),
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(request, response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> AsyncGenerator:
async def _generate_and_convert_to_openai_compat():
async for chunk in results_generator:
if not chunk.outputs:
log.warning("Empty chunk received")
continue
text = "".join([output.text for output in chunk.outputs])
choice = OpenAICompatCompletionChoice(
finish_reason=chunk.outputs[-1].stop_reason,
text=text,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
):
yield chunk
async def embeddings(
self, model: str, contents: list[InterleavedTextMedia]
) -> EmbeddingsResponse:
log.info("vLLM embeddings")
# TODO
raise NotImplementedError()

View file

@ -41,6 +41,7 @@ def available_providers() -> List[ProviderSpec]:
adapter=AdapterSpec(
adapter_type="ollama",
pip_packages=["ollama"],
config_class="llama_stack.providers.adapters.inference.ollama.OllamaImplConfig",
module="llama_stack.providers.adapters.inference.ollama",
),
),
@ -103,4 +104,24 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="databricks",
pip_packages=[
"openai",
],
module="llama_stack.providers.adapters.inference.databricks",
config_class="llama_stack.providers.adapters.inference.databricks.DatabricksImplConfig",
),
),
InlineProviderSpec(
api=Api.inference,
provider_type="vllm",
pip_packages=[
"vllm",
],
module="llama_stack.providers.impls.vllm",
config_class="llama_stack.providers.impls.vllm.VLLMConfig",
),
]

View file

@ -56,6 +56,16 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
),
),
remote_provider_spec(
Api.memory,
AdapterSpec(
adapter_type="weaviate",
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
module="llama_stack.providers.adapters.memory.weaviate",
config_class="llama_stack.providers.adapters.memory.weaviate.WeaviateConfig",
provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",
),
),
remote_provider_spec(
api=Api.memory,
adapter=AdapterSpec(

View file

@ -21,7 +21,6 @@ def available_providers() -> List[ProviderSpec]:
api=Api.safety,
provider_type="meta-reference",
pip_packages=[
"codeshield",
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
@ -61,4 +60,14 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
),
),
InlineProviderSpec(
api=Api.safety,
provider_type="meta-reference/codeshield",
pip_packages=[
"codeshield",
],
module="llama_stack.providers.impls.meta_reference.codeshield",
config_class="llama_stack.providers.impls.meta_reference.codeshield.CodeShieldConfig",
api_dependencies=[],
),
]

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,25 @@
providers:
- provider_id: test-ollama
provider_type: remote::ollama
config:
host: localhost
port: 11434
- provider_id: test-tgi
provider_type: remote::tgi
config:
url: http://localhost:7001
- provider_id: test-remote
provider_type: remote
config:
host: localhost
port: 7002
- provider_id: test-together
provider_type: remote::together
config: {}
# if a provider needs private keys from the client, they use the
# "get_request_provider_data" function (see distribution/request_headers.py)
# this is a place to provide such data.
provider_data:
"test-together":
together_api_key:
0xdeadbeefputrealapikeyhere

View file

@ -0,0 +1,243 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import itertools
import pytest
import pytest_asyncio
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
#
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
# since it depends on the provider you are testing. On top of that you need
# `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/memory/test_inference.py \
# --tb=short --disable-warnings
# ```
def group_chunks(response):
return {
event_type: list(group)
for event_type, group in itertools.groupby(
response, key=lambda chunk: chunk.event.event_type
)
}
Llama_8B = "Llama3.1-8B-Instruct"
Llama_3B = "Llama3.2-3B-Instruct"
def get_expected_stop_reason(model: str):
return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
# This is going to create multiple Stack impls without tearing down the previous one
# Fix that!
@pytest_asyncio.fixture(
scope="session",
params=[
{"model": Llama_8B},
{"model": Llama_3B},
],
ids=lambda d: d["model"],
)
async def inference_settings(request):
model = request.param["model"]
impls = await resolve_impls_for_test(
Api.inference,
models=[
ModelDef(
identifier=model,
llama_model=model,
)
],
)
return {
"impl": impls[Api.inference],
"common_params": {
"model": model,
"tool_choice": ToolChoice.auto,
"tool_prompt_format": (
ToolPromptFormat.json
if "Llama3.1" in model
else ToolPromptFormat.python_list
),
},
}
@pytest.fixture
def sample_messages():
return [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="What's the weather like today?"),
]
@pytest.fixture
def sample_tool_definition():
return ToolDefinition(
tool_name="get_weather",
description="Get the current weather",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
),
},
)
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]
response = await inference_impl.chat_completion(
messages=sample_messages,
stream=False,
**inference_settings["common_params"],
)
assert isinstance(response, ChatCompletionResponse)
assert response.completion_message.role == "assistant"
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0
@pytest.mark.asyncio
async def test_chat_completion_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]
response = [
r
async for r in inference_impl.chat_completion(
messages=sample_messages,
stream=True,
**inference_settings["common_params"],
)
]
assert len(response) > 0
assert all(
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
)
grouped = group_chunks(response)
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
end = grouped[ChatCompletionResponseEventType.complete][0]
assert end.event.stop_reason == StopReason.end_of_turn
@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling(
inference_settings,
sample_messages,
sample_tool_definition,
):
inference_impl = inference_settings["impl"]
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
)
]
response = await inference_impl.chat_completion(
messages=messages,
tools=[sample_tool_definition],
stream=False,
**inference_settings["common_params"],
)
assert isinstance(response, ChatCompletionResponse)
message = response.completion_message
# This is not supported in most providers :/ they don't return eom_id / eot_id
# stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
# assert message.stop_reason == stop_reason
assert message.tool_calls is not None
assert len(message.tool_calls) > 0
call = message.tool_calls[0]
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling_streaming(
inference_settings,
sample_messages,
sample_tool_definition,
):
inference_impl = inference_settings["impl"]
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
)
]
response = [
r
async for r in inference_impl.chat_completion(
messages=messages,
tools=[sample_tool_definition],
stream=True,
**inference_settings["common_params"],
)
]
assert len(response) > 0
assert all(
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
)
grouped = group_chunks(response)
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
# This is not supported in most providers :/ they don't return eom_id / eot_id
# expected_stop_reason = get_expected_stop_reason(
# inference_settings["common_params"]["model"]
# )
# end = grouped[ChatCompletionResponseEventType.complete][0]
# assert end.event.stop_reason == expected_stop_reason
model = inference_settings["common_params"]["model"]
if "Llama3.1" in model:
assert all(
isinstance(chunk.event.delta, ToolCallDelta)
for chunk in grouped[ChatCompletionResponseEventType.progress]
)
first = grouped[ChatCompletionResponseEventType.progress][0]
assert first.event.delta.parse_status == ToolCallParseStatus.started
last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.success
assert isinstance(last.event.delta.content, ToolCall)
call = last.event.delta.content
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]

View file

@ -8,7 +8,7 @@ import unittest
from llama_models.llama3.api import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.augment_messages import augment_messages_for_tools
from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages
MODEL = "Llama3.1-8B-Instruct"
@ -22,7 +22,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
UserMessage(content=content),
],
)
messages = augment_messages_for_tools(request)
messages = chat_completion_request_to_messages(request)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@ -39,7 +39,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
ToolDefinition(tool_name=BuiltinTool.brave_search),
],
)
messages = augment_messages_for_tools(request)
messages = chat_completion_request_to_messages(request)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@ -67,7 +67,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
tool_prompt_format=ToolPromptFormat.json,
)
messages = augment_messages_for_tools(request)
messages = chat_completion_request_to_messages(request)
self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content)
@ -97,7 +97,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
),
],
)
messages = augment_messages_for_tools(request)
messages = chat_completion_request_to_messages(request)
self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content)
@ -119,7 +119,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
)
messages = augment_messages_for_tools(request)
messages = chat_completion_request_to_messages(request)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,24 @@
providers:
- provider_id: test-faiss
provider_type: meta-reference
config: {}
- provider_id: test-chroma
provider_type: remote::chroma
config:
host: localhost
port: 6001
- provider_id: test-remote
provider_type: remote
config:
host: localhost
port: 7002
- provider_id: test-weaviate
provider_type: remote::weaviate
config: {}
# if a provider needs private keys from the client, they use the
# "get_request_provider_data" function (see distribution/request_headers.py)
# this is a place to provide such data.
provider_data:
"test-weaviate":
weaviate_api_key: 0xdeadbeefputrealapikeyhere
weaviate_cluster_url: http://foobarbaz

View file

@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
#
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
# since it depends on the provider you are testing. On top of that you need
# `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/memory/test_memory.py \
# --tb=short --disable-warnings
# ```
@pytest_asyncio.fixture(scope="session")
async def memory_impl():
impls = await resolve_impls_for_test(
Api.memory,
memory_banks=[],
)
return impls[Api.memory]
@pytest.fixture
def sample_documents():
return [
MemoryBankDocument(
document_id="doc1",
content="Python is a high-level programming language.",
metadata={"category": "programming", "difficulty": "beginner"},
),
MemoryBankDocument(
document_id="doc2",
content="Machine learning is a subset of artificial intelligence.",
metadata={"category": "AI", "difficulty": "advanced"},
),
MemoryBankDocument(
document_id="doc3",
content="Data structures are fundamental to computer science.",
metadata={"category": "computer science", "difficulty": "intermediate"},
),
MemoryBankDocument(
document_id="doc4",
content="Neural networks are inspired by biological neural networks.",
metadata={"category": "AI", "difficulty": "advanced"},
),
]
async def register_memory_bank(memory_impl: Memory):
bank = VectorMemoryBankDef(
identifier="test_bank",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
await memory_impl.register_memory_bank(bank)
@pytest.mark.asyncio
async def test_query_documents(memory_impl, sample_documents):
with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents)
await register_memory_bank(memory_impl)
await memory_impl.insert_documents("test_bank", sample_documents)
query1 = "programming language"
response1 = await memory_impl.query_documents("test_bank", query1)
assert_valid_response(response1)
assert any("Python" in chunk.content for chunk in response1.chunks)
# Test case 3: Query with semantic similarity
query3 = "AI and brain-inspired computing"
response3 = await memory_impl.query_documents("test_bank", query3)
assert_valid_response(response3)
assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)
# Test case 4: Query with limit on number of results
query4 = "computer"
params4 = {"max_chunks": 2}
response4 = await memory_impl.query_documents("test_bank", query4, params4)
assert_valid_response(response4)
assert len(response4.chunks) <= 2
# Test case 5: Query with threshold on similarity score
query5 = "quantum computing" # Not directly related to any document
params5 = {"score_threshold": 0.5}
response5 = await memory_impl.query_documents("test_bank", query5, params5)
assert_valid_response(response5)
assert all(score >= 0.5 for score in response5.scores)
def assert_valid_response(response: QueryDocumentsResponse):
assert isinstance(response, QueryDocumentsResponse)
assert len(response.chunks) > 0
assert len(response.scores) > 0
assert len(response.chunks) == len(response.scores)
for chunk in response.chunks:
assert isinstance(chunk.content, str)
assert chunk.document_id is not None

View file

@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
from datetime import datetime
import yaml
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing
async def resolve_impls_for_test(
api: Api,
models: List[ModelDef] = None,
memory_banks: List[MemoryBankDef] = None,
shields: List[ShieldDef] = None,
):
if "PROVIDER_CONFIG" not in os.environ:
raise ValueError(
"You must set PROVIDER_CONFIG to a YAML file containing provider config"
)
with open(os.environ["PROVIDER_CONFIG"], "r") as f:
config_dict = yaml.safe_load(f)
if "providers" not in config_dict:
raise ValueError("Config file should contain a `providers` key")
providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
if len(providers_by_id) == 0:
raise ValueError("No providers found in config file")
if "PROVIDER_ID" in os.environ:
provider_id = os.environ["PROVIDER_ID"]
if provider_id not in providers_by_id:
raise ValueError(f"Provider ID {provider_id} not found in config file")
provider = providers_by_id[provider_id]
else:
provider = list(providers_by_id.values())[0]
provider_id = provider["provider_id"]
print(f"No provider ID specified, picking first `{provider_id}`")
models = models or []
shields = shields or []
memory_banks = memory_banks or []
models = [
ModelDef(
**{
**m.dict(),
"provider_id": provider_id,
}
)
for m in models
]
shields = [
ShieldDef(
**{
**s.dict(),
"provider_id": provider_id,
}
)
for s in shields
]
memory_banks = [
MemoryBankDef(
**{
**m.dict(),
"provider_id": provider_id,
}
)
for m in memory_banks
]
run_config = dict(
built_at=datetime.now(),
image_name="test-fixture",
apis=[api],
providers={api.value: [Provider(**provider)]},
models=models,
memory_banks=memory_banks,
shields=shields,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls_with_routing(run_config)
if "provider_data" in config_dict:
provider_data = config_dict["provider_data"].get(provider_id, {})
if provider_data:
set_request_provider_data(
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
)
return impls

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_models.sku_list import resolve_model
from llama_stack.apis.models import * # noqa: F403
class ModelRegistryHelper:
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
self.stack_to_provider_models_map = stack_to_provider_models_map
def map_to_provider_model(self, identifier: str) -> str:
model = resolve_model(identifier)
if not model:
raise ValueError(f"Unknown model: `{identifier}`")
if identifier not in self.stack_to_provider_models_map:
raise ValueError(
f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
)
return self.stack_to_provider_models_map[identifier]
async def register_model(self, model: ModelDef) -> None:
if model.identifier not in self.stack_to_provider_models_map:
raise ValueError(
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
)

View file

@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_stack.apis.inference import * # noqa: F403
from pydantic import BaseModel
class OpenAICompatCompletionChoiceDelta(BaseModel):
content: str
class OpenAICompatCompletionChoice(BaseModel):
finish_reason: Optional[str] = None
text: Optional[str] = None
delta: Optional[OpenAICompatCompletionChoiceDelta] = None
class OpenAICompatCompletionResponse(BaseModel):
choices: List[OpenAICompatCompletionChoice]
def get_sampling_options(request: ChatCompletionRequest) -> dict:
options = {}
if params := request.sampling_params:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(params, attr):
options[attr] = getattr(params, attr)
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
options["repeat_penalty"] = params.repetition_penalty
return options
def text_from_choice(choice) -> str:
if hasattr(choice, "delta") and choice.delta:
return choice.delta.content
return choice.text
def process_chat_completion_response(
request: ChatCompletionRequest,
response: OpenAICompatCompletionResponse,
formatter: ChatFormat,
) -> ChatCompletionResponse:
choice = response.choices[0]
stop_reason = None
if reason := choice.finish_reason:
if reason in ["stop", "eos"]:
stop_reason = StopReason.end_of_turn
elif reason == "eom":
stop_reason = StopReason.end_of_message
elif reason == "length":
stop_reason = StopReason.out_of_tokens
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
completion_message = formatter.decode_assistant_message_from_content(
text_from_choice(choice), stop_reason
)
return ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
async def process_chat_completion_stream_response(
request: ChatCompletionRequest,
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
formatter: ChatFormat,
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
ipython = False
stop_reason = None
async for chunk in stream:
choice = chunk.choices[0]
finish_reason = choice.finish_reason
if finish_reason:
if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif stop_reason is None and finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
text = text_from_choice(choice)
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
if ipython:
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)

Some files were not shown because too many files have changed in this diff Show more