Merge remote-tracking branch 'origin/main' into if_eval

Botao Chen 2025-03-18 23:23:13 -07:00
commit a690c7b230
123 changed files with 4482 additions and 3161 deletions


@@ -1,9 +0,0 @@
---
description: General rules always applicable across the project
globs:
alwaysApply: true
---
# Style
- Comments must add value to code. Don't write filler comments explaining what you are doing next; they just add noise.
- Add a comment to clarify surprising behavior that would not otherwise be obvious. Good variable naming and clear code organization are more important.
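A minimal Python sketch (not from the repo; the API behavior and names are hypothetical) illustrating the guideline above: skip comments that narrate the obvious next step, and keep the ones that explain behavior a reader could not guess from the code.

```python
import time


class ThrottledError(RuntimeError):
    """Raised when the upstream service keeps throttling us."""


def fetch_with_retry(fetch, max_attempts: int = 3):
    for attempt in range(max_attempts):
        response = fetch()
        # The (hypothetical) upstream API signals throttling with HTTP 200 and an
        # empty body, so an empty payload is a retryable condition, not a success.
        if response.status_code == 200 and not response.content:
            time.sleep(2**attempt)
            continue
        return response
    raise ThrottledError(f"no usable response after {max_attempts} attempts")
```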


@@ -5,4 +5,19 @@ updates:
   - package-ecosystem: "github-actions"
     directory: "/" # Will use the default workflow location of `.github/workflows`
     schedule:
-      interval: "daily"
+      interval: "weekly"
+      day: "saturday"
+    commit-message:
+      prefix: chore(github-deps)
+  - package-ecosystem: "uv"
+    directory: "/"
+    schedule:
+      interval: "weekly"
+      day: "saturday"
+    # ignore all non-security updates: https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#open-pull-requests-limit
+    open-pull-requests-limit: 0
+    labels:
+      - type/dependencies
+      - python
+    commit-message:
+      prefix: chore(python-deps)

.github/workflows/changelog.yml

@@ -0,0 +1,29 @@
name: Update Changelog

on:
  release:
    types: [published, unpublished, created, edited, deleted, released]

permissions:
  contents: read

jobs:
  generate_changelog:
    name: Generate changelog
    permissions:
      contents: write # for peter-evans/create-pull-request to create branch
      pull-requests: write # for peter-evans/create-pull-request to create a PR
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          ref: main
          fetch-depth: 0
      - run: |
          python ./scripts/gen-changelog.py
      - uses: peter-evans/create-pull-request@v7
        with:
          title: 'docs: update CHANGELOG.md for ${{ github.ref_name }}'
          commit-message: 'docs: update CHANGELOG.md for ${{ github.ref_name }}'
          branch: create-pull-request/changelog
          signoff: true


@@ -1,13 +1,28 @@
-name: Integration tests
+name: Integration Tests

 on:
-  pull_request:
   push:
-    branches: [main]
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+    paths:
+      - 'distributions/**'
+      - 'llama_stack/**'
+      - 'tests/integration/**'
+      - 'uv.lock'
+      - 'pyproject.toml'
+      - 'requirements.txt'
+      - '.github/workflows/integration-tests.yml' # This workflow

 jobs:
-  ollama:
+  test-matrix:
     runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        # Listing tests manually since some of them currently fail
+        # TODO: generate matrix list from tests/integration when fixed
+        test-type: [inference, datasets, inspect, scoring, post_training, providers]
+      fail-fast: false # we want to run all tests regardless of failure

     steps:
       - name: Checkout repository
@@ -34,6 +49,8 @@ jobs:
         run: |
           uv sync --extra dev --extra test
           uv pip install ollama faiss-cpu
+          # always test against the latest version of the client
+          uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
           uv pip install -e .

       - name: Wait for Ollama to start
@@ -56,25 +73,24 @@ jobs:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
           source .venv/bin/activate
-          # TODO: use "llama stack run"
-          nohup uv run python -m llama_stack.distribution.server.server --yaml-config ./llama_stack/templates/ollama/run.yaml > server.log 2>&1 &
+          nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 &

       - name: Wait for Llama Stack server to be ready
         run: |
           echo "Waiting for Llama Stack server..."
           for i in {1..30}; do
             if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
-              echo " Llama Stack server is up!"
+              echo "Llama Stack server is up!"
               exit 0
             fi
             sleep 1
           done
-          echo " Llama Stack server failed to start"
+          echo "Llama Stack server failed to start"
           cat server.log
           exit 1

-      - name: Run Inference Integration Tests
+      - name: Run Integration Tests
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
-          uv run pytest -v tests/integration/inference --stack-config=ollama --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2
+          uv run pytest -v tests/integration/${{ matrix.test-type }} --stack-config=ollama --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2


@@ -40,6 +40,7 @@ jobs:
       matrix:
         template: ${{ fromJson(needs.generate-matrix.outputs.templates) }}
         image-type: [venv, container]
+      fail-fast: false # We want to run all jobs even if some fail

     steps:
       - name: Checkout repository
@@ -67,7 +68,9 @@ jobs:
       - name: Run Llama Stack Build
         run: |
-          uv run llama stack build --template ${{ matrix.template }} --image-type ${{ matrix.image-type }} --image-name test
+          # USE_COPY_NOT_MOUNT is set to true since mounting is not supported by docker buildx, we use COPY instead
+          # LLAMA_STACK_DIR is set to the current directory so we are building from the source
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --template ${{ matrix.template }} --image-type ${{ matrix.image-type }} --image-name test

       - name: Print dependencies in the image
         if: matrix.image-type == 'venv'


@@ -5,6 +5,14 @@ on:
     branches: [ main ]
   pull_request:
     branches: [ main ]
+    paths:
+      - 'distributions/**'
+      - 'llama_stack/**'
+      - 'tests/unit/**'
+      - 'uv.lock'
+      - 'pyproject.toml'
+      - 'requirements.txt'
+      - '.github/workflows/unit-tests.yml' # This workflow
   workflow_dispatch:

 jobs:


@@ -77,7 +77,7 @@ repos:
         name: Distribution Template Codegen
         additional_dependencies:
           - uv==0.6.0
-        entry: uv run --extra codegen python -m llama_stack.scripts.distro_codegen
+        entry: uv run --extra codegen ./scripts/distro_codegen.py
         language: python
         pass_filenames: false
         require_serial: true


@@ -86,7 +86,7 @@ LLAMA_STACK_CONFIG=
 And then use this dotenv file when running client SDK tests via the following:
 ```bash
-uv run --env-file .env -- pytest -v tests/api/inference/test_text_inference.py
+uv run --env-file .env -- pytest -v tests/integration/inference/test_text_inference.py
 ```

 ## Pre-commit Hooks
@@ -159,7 +159,7 @@ LLAMA_STACK_DIR=$(pwd) LLAMA_STACK_CLIENT_DIR=../llama-stack-client-python llama
 ### Updating Provider Configurations

-If you have made changes to a provider's configuration in any form (introducing a new config key, or changing models, etc.), you should run `python llama_stack/scripts/distro_codegen.py` to re-generate various YAML files as well as the documentation. You should not change `docs/source/.../distributions/` files manually as they are auto-generated.
+If you have made changes to a provider's configuration in any form (introducing a new config key, or changing models, etc.), you should run `./scripts/distro_codegen.py` to re-generate various YAML files as well as the documentation. You should not change `docs/source/.../distributions/` files manually as they are auto-generated.

 ### Building the Documentation


@@ -4,7 +4,8 @@
 [![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/)
 [![License](https://img.shields.io/pypi/l/llama_stack.svg)](https://github.com/meta-llama/llama-stack/blob/main/LICENSE)
 [![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/llama-stack)
-![Unit](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)
+[![Unit Tests](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
+[![Integration Tests](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)

 [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
@@ -72,26 +73,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
 | Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) |
 | vLLM | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) |

-### Installation
-
-You have two ways to install this repository:
-
-* **Install as a package**:
-  You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
-  ```bash
-  pip install llama-stack
-  ```
-
-* **Install from source**:
-  If you prefer to install from the source code, we recommend using [uv](https://github.com/astral-sh/uv).
-  Then, run the following commands:
-  ```bash
-  git clone git@github.com:meta-llama/llama-stack.git
-  cd llama-stack
-  uv sync
-  uv pip install -e .
-  ```
-
 ### Documentation


@@ -401,16 +401,13 @@
   ],
   "nvidia": [
     "aiosqlite",
-    "autoevals",
     "blobfile",
     "chardet",
-    "datasets",
     "faiss-cpu",
     "fastapi",
     "fire",
     "httpx",
     "matplotlib",
-    "mcp",
     "nltk",
     "numpy",
     "openai",

File diff suppressed because it is too large.


@ -10,56 +10,7 @@ info:
servers: servers:
- url: http://any-hosted-llama-stack.com - url: http://any-hosted-llama-stack.com
paths: paths:
/v1/datasetio/rows: /v1/datasetio/append-rows/{dataset_id}:
get:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/PaginatedRowsResult'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- DatasetIO
description: >-
Get a paginated list of rows from a dataset.
parameters:
- name: dataset_id
in: query
description: >-
The ID of the dataset to get the rows from.
required: true
schema:
type: string
- name: rows_in_page
in: query
description: The number of rows to get per page.
required: true
schema:
type: integer
- name: page_token
in: query
description: The token to get the next page of rows.
required: false
schema:
type: string
- name: filter_condition
in: query
description: >-
(Optional) A condition to filter the rows by.
required: false
schema:
type: string
post: post:
responses: responses:
'200': '200':
@ -77,7 +28,12 @@ paths:
tags: tags:
- DatasetIO - DatasetIO
description: '' description: ''
parameters: [] parameters:
- name: dataset_id
in: path
required: true
schema:
type: string
requestBody: requestBody:
content: content:
application/json: application/json:
@ -394,7 +350,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Files (Coming Soon) - Files
description: List all buckets. description: List all buckets.
parameters: parameters:
- name: bucket - name: bucket
@ -421,7 +377,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Files (Coming Soon) - Files
description: >- description: >-
Create a new upload session for a file identified by a bucket and key. Create a new upload session for a file identified by a bucket and key.
parameters: [] parameters: []
@ -580,7 +536,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Files (Coming Soon) - Files
description: >- description: >-
Get a file info identified by a bucket and key. Get a file info identified by a bucket and key.
parameters: parameters:
@ -616,7 +572,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Files (Coming Soon) - Files
description: >- description: >-
Delete a file identified by a bucket and key. Delete a file identified by a bucket and key.
parameters: parameters:
@ -801,9 +757,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/Benchmark'
- $ref: '#/components/schemas/Benchmark'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -831,9 +785,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/Dataset'
- $ref: '#/components/schemas/Dataset'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -884,9 +836,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/Model'
- $ref: '#/components/schemas/Model'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -937,9 +887,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/ScoringFn'
- $ref: '#/components/schemas/ScoringFn'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -967,9 +915,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/Shield'
- $ref: '#/components/schemas/Shield'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1171,9 +1117,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/PostTrainingJobArtifactsResponse'
- $ref: '#/components/schemas/PostTrainingJobArtifactsResponse'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1201,9 +1145,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/PostTrainingJobStatusResponse'
- $ref: '#/components/schemas/PostTrainingJobStatusResponse'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1254,9 +1196,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/FileUploadResponse'
- $ref: '#/components/schemas/FileUploadResponse'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1268,7 +1208,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Files (Coming Soon) - Files
description: >- description: >-
Returns information about an existsing upload session Returns information about an existsing upload session
parameters: parameters:
@ -1299,7 +1239,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Files (Coming Soon) - Files
description: >- description: >-
Upload file content to an existing upload session. On the server, request Upload file content to an existing upload session. On the server, request
body will have the raw bytes that are uploaded. body will have the raw bytes that are uploaded.
@ -1325,9 +1265,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/VectorDB'
- $ref: '#/components/schemas/VectorDB'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1501,6 +1439,50 @@ paths:
schema: schema:
$ref: '#/components/schemas/InvokeToolRequest' $ref: '#/components/schemas/InvokeToolRequest'
required: true required: true
/v1/datasetio/iterrows/{dataset_id}:
get:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/IterrowsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- DatasetIO
description: >-
Get a paginated list of rows from a dataset. Uses cursor-based pagination.
parameters:
- name: dataset_id
in: path
description: >-
The ID of the dataset to get the rows from.
required: true
schema:
type: string
- name: start_index
in: query
description: >-
Index into dataset for the first row to get. Get all rows if None.
required: false
schema:
type: integer
- name: limit
in: query
description: The number of rows to get.
required: false
schema:
type: integer
/v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}: /v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}:
get: get:
responses: responses:
@ -1509,9 +1491,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/JobStatus'
- $ref: '#/components/schemas/JobStatus'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1710,6 +1690,10 @@ paths:
responses: responses:
'200': '200':
description: OK description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/Dataset'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1722,7 +1706,7 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Datasets - Datasets
description: '' description: Register a new dataset.
parameters: [] parameters: []
requestBody: requestBody:
content: content:
@ -1750,7 +1734,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Files (Coming Soon) - Files
description: List all files in a bucket. description: List all files in a bucket.
parameters: parameters:
- name: bucket - name: bucket
@ -2607,8 +2591,6 @@ components:
AppendRowsRequest: AppendRowsRequest:
type: object type: object
properties: properties:
dataset_id:
type: string
rows: rows:
type: array type: array
items: items:
@ -2623,7 +2605,6 @@ components:
- type: object - type: object
additionalProperties: false additionalProperties: false
required: required:
- dataset_id
- rows - rows
title: AppendRowsRequest title: AppendRowsRequest
CompletionMessage: CompletionMessage:
@ -4727,6 +4708,148 @@ components:
- scoring_functions - scoring_functions
- metadata - metadata
title: Benchmark title: Benchmark
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
Dataset:
type: object
properties:
identifier:
type: string
provider_resource_id:
type: string
provider_id:
type: string
type:
type: string
const: dataset
default: dataset
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
title: DatasetPurpose
description: >-
Purpose of the dataset. Each purpose has a required input data schema.
source:
$ref: '#/components/schemas/DataSource'
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- identifier
- provider_resource_id
- provider_id
- type
- purpose
- source
- metadata
title: Dataset
RowsDataSource:
type: object
properties:
type:
type: string
const: rows
default: rows
rows:
type: array
items:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The dataset is stored in rows. E.g. - [ {"messages": [{"role": "user",
"content": "Hello, world!"}, {"role": "assistant", "content": "Hello,
world!"}]} ]
additionalProperties: false
required:
- type
- rows
title: RowsDataSource
description: A dataset stored in rows.
URIDataSource:
type: object
properties:
type:
type: string
const: uri
default: uri
uri:
type: string
description: >-
The dataset can be obtained from a URI. E.g. - "https://mywebsite.com/mydata.jsonl"
- "lsfs://mydata.jsonl" - "data:csv;base64,{base64_content}"
additionalProperties: false
required:
- type
- uri
title: URIDataSource
description: >-
A dataset that can be obtained from a URI.
Model:
type: object
properties:
identifier:
type: string
provider_resource_id:
type: string
provider_id:
type: string
type:
type: string
const: model
default: model
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
model_type:
$ref: '#/components/schemas/ModelType'
default: llm
additionalProperties: false
required:
- identifier
- provider_resource_id
- provider_id
- type
- metadata
- model_type
title: Model
ModelType:
type: string
enum:
- llm
- embedding
title: ModelType
AgentTurnInputType: AgentTurnInputType:
type: object type: object
properties: properties:
@ -4782,45 +4905,6 @@ components:
required: required:
- type - type
title: CompletionInputType title: CompletionInputType
Dataset:
type: object
properties:
identifier:
type: string
provider_resource_id:
type: string
provider_id:
type: string
type:
type: string
const: dataset
default: dataset
dataset_schema:
type: object
additionalProperties:
$ref: '#/components/schemas/ParamType'
url:
$ref: '#/components/schemas/URL'
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- identifier
- provider_resource_id
- provider_id
- type
- dataset_schema
- url
- metadata
title: Dataset
JsonType: JsonType:
type: object type: object
properties: properties:
@ -4879,97 +4963,6 @@ components:
chat_completion_input: '#/components/schemas/ChatCompletionInputType' chat_completion_input: '#/components/schemas/ChatCompletionInputType'
completion_input: '#/components/schemas/CompletionInputType' completion_input: '#/components/schemas/CompletionInputType'
agent_turn_input: '#/components/schemas/AgentTurnInputType' agent_turn_input: '#/components/schemas/AgentTurnInputType'
StringType:
type: object
properties:
type:
type: string
const: string
default: string
additionalProperties: false
required:
- type
title: StringType
UnionType:
type: object
properties:
type:
type: string
const: union
default: union
additionalProperties: false
required:
- type
title: UnionType
Model:
type: object
properties:
identifier:
type: string
provider_resource_id:
type: string
provider_id:
type: string
type:
type: string
const: model
default: model
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
model_type:
$ref: '#/components/schemas/ModelType'
default: llm
additionalProperties: false
required:
- identifier
- provider_resource_id
- provider_id
- type
- metadata
- model_type
title: Model
ModelType:
type: string
enum:
- llm
- embedding
title: ModelType
PaginatedRowsResult:
type: object
properties:
rows:
type: array
items:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The rows in the current page.
total_count:
type: integer
description: The total number of rows in the dataset.
next_page_token:
type: string
description: The token to get the next page of rows.
additionalProperties: false
required:
- rows
- total_count
title: PaginatedRowsResult
description: A paginated list of rows from a dataset.
ScoringFn: ScoringFn:
type: object type: object
properties: properties:
@ -5008,6 +5001,28 @@ components:
- metadata - metadata
- return_type - return_type
title: ScoringFn title: ScoringFn
StringType:
type: object
properties:
type:
type: string
const: string
default: string
additionalProperties: false
required:
- type
title: StringType
UnionType:
type: object
properties:
type:
type: string
const: union
default: union
additionalProperties: false
required:
- type
title: UnionType
Shield: Shield:
type: object type: object
properties: properties:
@ -5507,6 +5522,32 @@ components:
required: required:
- content - content
title: ToolInvocationResult title: ToolInvocationResult
IterrowsResponse:
type: object
properties:
data:
type: array
items:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The rows in the current page.
next_start_index:
type: integer
description: >-
Index into dataset for the first row in the next page. None if there are
no more rows.
additionalProperties: false
required:
- data
title: IterrowsResponse
description: A paginated list of rows from a dataset.
ListAgentSessionsResponse: ListAgentSessionsResponse:
type: object type: object
properties: properties:
@ -6314,18 +6355,35 @@ components:
RegisterDatasetRequest: RegisterDatasetRequest:
type: object type: object
properties: properties:
dataset_id: purpose:
type: string
dataset_schema:
type: object
additionalProperties:
$ref: '#/components/schemas/ParamType'
url:
$ref: '#/components/schemas/URL'
provider_dataset_id:
type: string
provider_id:
type: string type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata: metadata:
type: object type: object
additionalProperties: additionalProperties:
@ -6336,11 +6394,16 @@ components:
- type: string - type: string
- type: array - type: array
- type: object - type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false additionalProperties: false
required: required:
- dataset_id - purpose
- dataset_schema - source
- url
title: RegisterDatasetRequest title: RegisterDatasetRequest
RegisterModelRequest: RegisterModelRequest:
type: object type: object
@ -6616,15 +6679,6 @@ components:
required: required:
- results - results
title: ScoreBatchResponse title: ScoreBatchResponse
AlgorithmConfig:
oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
discriminator:
propertyName: type
mapping:
LoRA: '#/components/schemas/LoraFinetuningConfig'
QAT: '#/components/schemas/QATFinetuningConfig'
LoraFinetuningConfig: LoraFinetuningConfig:
type: object type: object
properties: properties:
@ -6708,7 +6762,9 @@ components:
checkpoint_dir: checkpoint_dir:
type: string type: string
algorithm_config: algorithm_config:
$ref: '#/components/schemas/AlgorithmConfig' oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
additionalProperties: false additionalProperties: false
required: required:
- job_uuid - job_uuid
@ -6856,7 +6912,7 @@ tags:
- name: Eval - name: Eval
x-displayName: >- x-displayName: >-
Llama Stack Evaluation API for running evaluations on model and agent candidates. Llama Stack Evaluation API for running evaluations on model and agent candidates.
- name: Files (Coming Soon) - name: Files
- name: Inference - name: Inference
description: >- description: >-
This API provides the raw interface to the underlying models. Two kinds of models This API provides the raw interface to the underlying models. Two kinds of models
@ -6894,7 +6950,7 @@ x-tagGroups:
- DatasetIO - DatasetIO
- Datasets - Datasets
- Eval - Eval
- Files (Coming Soon) - Files
- Inference - Inference
- Inspect - Inspect
- Models - Models
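To make the dataset API changes in the spec above concrete, here is a minimal sketch (not part of the diff) that exercises the new endpoints over plain HTTP. It assumes a Llama Stack server on localhost:8321 (the default port shown elsewhere in this diff) and that dataset registration is exposed at POST /v1/datasets, a path that falls outside the hunks shown here; everything else follows the schemas above.

```python
import requests

BASE_URL = "http://localhost:8321"

# Register an inline ("rows") dataset; only purpose and source are required,
# dataset_id is optional and generated when omitted.
dataset = requests.post(
    f"{BASE_URL}/v1/datasets",  # assumed registration path, not shown in this diff
    json={
        "purpose": "eval/question-answer",
        "source": {
            "type": "rows",
            "rows": [{"question": "What is the capital of France?", "answer": "Paris"}],
        },
        "metadata": {"description": "My dataset"},
    },
).json()
dataset_id = dataset["identifier"]

# Appending rows: dataset_id now travels in the path rather than the request body.
requests.post(
    f"{BASE_URL}/v1/datasetio/append-rows/{dataset_id}",
    json={"rows": [{"question": "What is 2 + 2?", "answer": "4"}]},
)

# Cursor-based pagination: follow next_start_index until the server stops returning one.
start_index = 0
while start_index is not None:
    page = requests.get(
        f"{BASE_URL}/v1/datasetio/iterrows/{dataset_id}",
        params={"start_index": start_index, "limit": 2},
    ).json()
    for row in page["data"]:
        print(row)
    start_index = page.get("next_start_index")
```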

File diff suppressed because one or more lines are too long.


@ -84,16 +84,14 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Not in Google Colab environment\n", "Not in Google Colab environment\n"
"\u001b[33mWarning: `bwrap` is not available. Code interpreter tool will not work correctly.\u001b[0m\n"
] ]
}, },
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"/opt/anaconda3/envs/master/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", "Warning: `bwrap` is not available. Code interpreter tool will not work correctly.\n"
" from .autonotebook import tqdm as notebook_tqdm\n"
] ]
}, },
{ {
@ -117,76 +115,146 @@
"- datasetio\n", "- datasetio\n",
"- eval\n", "- eval\n",
"- inference\n", "- inference\n",
"- memory\n",
"- safety\n", "- safety\n",
"- scoring\n", "- scoring\n",
"- telemetry\n", "- telemetry\n",
"- tool_runtime\n", "- tool_runtime\n",
"datasets: <span style=\"font-weight: bold\">[]</span>\n", "- vector_io\n",
"container_image: null\n",
"benchmarks: <span style=\"font-weight: bold\">[]</span>\n", "benchmarks: <span style=\"font-weight: bold\">[]</span>\n",
"container_image: null\n",
"datasets: <span style=\"font-weight: bold\">[]</span>\n",
"image_name: together\n", "image_name: together\n",
"memory_banks: <span style=\"font-weight: bold\">[]</span>\n", "logging: null\n",
"metadata_store:\n", "metadata_store:\n",
" db_path: <span style=\"color: #800080; text-decoration-color: #800080\">/Users/xiyan/.llama/distributions/together/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">registry.db</span>\n", " db_path: <span style=\"color: #800080; text-decoration-color: #800080\">/Users/xiyan/.llama/distributions/together/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">registry.db</span>\n",
" namespace: null\n", " namespace: null\n",
" type: sqlite\n", " type: sqlite\n",
"models:\n", "models:\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-8B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-8B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-8B-Instruct\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-8B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-8B-Instruct-Turbo\n", " provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-8B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-70B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-70B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-70B-Instruct\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-70B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-70B-Instruct-Turbo\n", " provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-70B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-405B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-405B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-405B-Instruct-FP8\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-405B-Instruct-FP8\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-405B-Instruct-Turbo\n", " provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-405B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-3B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-3B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-3B-Instruct\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-3B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-3B-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-3B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-11B-Vision-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-11B-Vision-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-11B-Vision-Instruct\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-11B-Vision-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-11B-Vision-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-11B-Vision-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-90B-Vision-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-90B-Vision-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-90B-Vision-Instruct\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-90B-Vision-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-90B-Vision-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-90B-Vision-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.3</span>-70B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.3</span>-70B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.3</span>-70B-Instruct\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.3</span>-70B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.3</span>-70B-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.3</span>-70B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Meta-Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-8B\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-8B\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-8B\n", " model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-8B\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-8B\n", " provider_model_id: meta-llama/Meta-Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-8B\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-11B-Vision-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-11B-Vision-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-11B-Vision\n", " model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-11B-Vision\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-11B-Vision-Turbo\n", " provider_model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-11B-Vision-Turbo\n",
"- metadata:\n", "- metadata:\n",
" context_length: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8192</span>\n",
" embedding_dimension: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">768</span>\n",
" model_id: togethercomputer/m2-bert-80M-8k-retrieval\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - embedding\n",
" provider_id: together\n",
" provider_model_id: togethercomputer/m2-bert-80M-8k-retrieval\n",
"- metadata:\n",
" context_length: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">32768</span>\n",
" embedding_dimension: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">768</span>\n",
" model_id: togethercomputer/m2-bert-80M-32k-retrieval\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - embedding\n",
" provider_id: together\n",
" provider_model_id: togethercomputer/m2-bert-80M-32k-retrieval\n",
"- metadata:\n",
" embedding_dimension: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">384</span>\n", " embedding_dimension: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">384</span>\n",
" model_id: all-MiniLM-L6-v2\n", " model_id: all-MiniLM-L6-v2\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
@ -203,14 +271,26 @@
" provider_id: meta-reference\n", " provider_id: meta-reference\n",
" provider_type: inline::meta-reference\n", " provider_type: inline::meta-reference\n",
" datasetio:\n", " datasetio:\n",
" - config: <span style=\"font-weight: bold\">{}</span>\n", " - config:\n",
" kvstore:\n",
" db_path: <span style=\"color: #800080; text-decoration-color: #800080\">/Users/xiyan/.llama/distributions/together/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">huggingface_datasetio.db</span>\n",
" namespace: null\n",
" type: sqlite\n",
" provider_id: huggingface\n", " provider_id: huggingface\n",
" provider_type: remote::huggingface\n", " provider_type: remote::huggingface\n",
" - config: <span style=\"font-weight: bold\">{}</span>\n", " - config:\n",
" kvstore:\n",
" db_path: <span style=\"color: #800080; text-decoration-color: #800080\">/Users/xiyan/.llama/distributions/together/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">localfs_datasetio.db</span>\n",
" namespace: null\n",
" type: sqlite\n",
" provider_id: localfs\n", " provider_id: localfs\n",
" provider_type: inline::localfs\n", " provider_type: inline::localfs\n",
" eval:\n", " eval:\n",
" - config: <span style=\"font-weight: bold\">{}</span>\n", " - config:\n",
" kvstore:\n",
" db_path: <span style=\"color: #800080; text-decoration-color: #800080\">/Users/xiyan/.llama/distributions/together/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">meta_reference_eval.db</span>\n",
" namespace: null\n",
" type: sqlite\n",
" provider_id: meta-reference\n", " provider_id: meta-reference\n",
" provider_type: inline::meta-reference\n", " provider_type: inline::meta-reference\n",
" inference:\n", " inference:\n",
@ -222,16 +302,9 @@
" - config: <span style=\"font-weight: bold\">{}</span>\n", " - config: <span style=\"font-weight: bold\">{}</span>\n",
" provider_id: sentence-transformers\n", " provider_id: sentence-transformers\n",
" provider_type: inline::sentence-transformers\n", " provider_type: inline::sentence-transformers\n",
" memory:\n",
" - config:\n",
" kvstore:\n",
" db_path: <span style=\"color: #800080; text-decoration-color: #800080\">/Users/xiyan/.llama/distributions/together/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">faiss_store.db</span>\n",
" namespace: null\n",
" type: sqlite\n",
" provider_id: faiss\n",
" provider_type: inlin<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e::fa</span>iss\n",
" safety:\n", " safety:\n",
" - config: <span style=\"font-weight: bold\">{}</span>\n", " - config:\n",
" excluded_categories: <span style=\"font-weight: bold\">[]</span>\n",
" provider_id: llama-guard\n", " provider_id: llama-guard\n",
" provider_type: inline::llama-guard\n", " provider_type: inline::llama-guard\n",
" scoring:\n", " scoring:\n",
@ -269,7 +342,26 @@
" - config: <span style=\"font-weight: bold\">{}</span>\n", " - config: <span style=\"font-weight: bold\">{}</span>\n",
" provider_id: rag-runtime\n", " provider_id: rag-runtime\n",
" provider_type: inline::rag-runtime\n", " provider_type: inline::rag-runtime\n",
" - config: <span style=\"font-weight: bold\">{}</span>\n",
" provider_id: model-context-protocol\n",
" provider_type: remote::model-context-protocol\n",
" - config:\n",
" api_key: <span style=\"color: #008000; text-decoration-color: #008000\">'********'</span>\n",
" provider_id: wolfram-alpha\n",
" provider_type: remote::wolfram-alpha\n",
" vector_io:\n",
" - config:\n",
" kvstore:\n",
" db_path: <span style=\"color: #800080; text-decoration-color: #800080\">/Users/xiyan/.llama/distributions/together/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">faiss_store.db</span>\n",
" namespace: null\n",
" type: sqlite\n",
" provider_id: faiss\n",
" provider_type: inlin<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e::fa</span>iss\n",
"scoring_fns: <span style=\"font-weight: bold\">[]</span>\n", "scoring_fns: <span style=\"font-weight: bold\">[]</span>\n",
"server:\n",
" port: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8321</span>\n",
" tls_certfile: null\n",
" tls_keyfile: null\n",
"shields:\n", "shields:\n",
"- params: null\n", "- params: null\n",
" provider_id: null\n", " provider_id: null\n",
@ -288,6 +380,11 @@
" mcp_endpoint: null\n", " mcp_endpoint: null\n",
" provider_id: code-interpreter\n", " provider_id: code-interpreter\n",
" toolgroup_id: builtin::code_interpreter\n", " toolgroup_id: builtin::code_interpreter\n",
"- args: null\n",
" mcp_endpoint: null\n",
" provider_id: wolfram-alpha\n",
" toolgroup_id: builtin::wolfram_alpha\n",
"vector_dbs: <span style=\"font-weight: bold\">[]</span>\n",
"version: <span style=\"color: #008000; text-decoration-color: #008000\">'2'</span>\n", "version: <span style=\"color: #008000; text-decoration-color: #008000\">'2'</span>\n",
"\n", "\n",
"</pre>\n" "</pre>\n"
@ -298,76 +395,146 @@
"- datasetio\n", "- datasetio\n",
"- eval\n", "- eval\n",
"- inference\n", "- inference\n",
"- memory\n",
"- safety\n", "- safety\n",
"- scoring\n", "- scoring\n",
"- telemetry\n", "- telemetry\n",
"- tool_runtime\n", "- tool_runtime\n",
"datasets: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", "- vector_io\n",
"container_image: null\n",
"benchmarks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", "benchmarks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
"container_image: null\n",
"datasets: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
"image_name: together\n", "image_name: together\n",
"memory_banks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", "logging: null\n",
"metadata_store:\n", "metadata_store:\n",
" db_path: \u001b[35m/Users/xiyan/.llama/distributions/together/\u001b[0m\u001b[95mregistry.db\u001b[0m\n", " db_path: \u001b[35m/Users/xiyan/.llama/distributions/together/\u001b[0m\u001b[95mregistry.db\u001b[0m\n",
" namespace: null\n", " namespace: null\n",
" type: sqlite\n", " type: sqlite\n",
"models:\n", "models:\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct\n", " model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct-Turbo\n", " provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct\n", " model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct-Turbo\n", " provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-FP8\n", " model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-FP8\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-Turbo\n", " provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct\n", " model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct\n", " model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct\n", " model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct\n", " model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Meta-Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n", " model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n", " provider_model_id: meta-llama/Meta-Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision\n", " model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision-Turbo\n", " provider_model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision-Turbo\n",
"- metadata:\n", "- metadata:\n",
" context_length: \u001b[1;36m8192\u001b[0m\n",
" embedding_dimension: \u001b[1;36m768\u001b[0m\n",
" model_id: togethercomputer/m2-bert-80M-8k-retrieval\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - embedding\n",
" provider_id: together\n",
" provider_model_id: togethercomputer/m2-bert-80M-8k-retrieval\n",
"- metadata:\n",
" context_length: \u001b[1;36m32768\u001b[0m\n",
" embedding_dimension: \u001b[1;36m768\u001b[0m\n",
" model_id: togethercomputer/m2-bert-80M-32k-retrieval\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - embedding\n",
" provider_id: together\n",
" provider_model_id: togethercomputer/m2-bert-80M-32k-retrieval\n",
"- metadata:\n",
" embedding_dimension: \u001b[1;36m384\u001b[0m\n", " embedding_dimension: \u001b[1;36m384\u001b[0m\n",
" model_id: all-MiniLM-L6-v2\n", " model_id: all-MiniLM-L6-v2\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
@ -384,14 +551,26 @@
" provider_id: meta-reference\n", " provider_id: meta-reference\n",
" provider_type: inline::meta-reference\n", " provider_type: inline::meta-reference\n",
" datasetio:\n", " datasetio:\n",
" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " - config:\n",
" kvstore:\n",
" db_path: \u001b[35m/Users/xiyan/.llama/distributions/together/\u001b[0m\u001b[95mhuggingface_datasetio.db\u001b[0m\n",
" namespace: null\n",
" type: sqlite\n",
" provider_id: huggingface\n", " provider_id: huggingface\n",
" provider_type: remote::huggingface\n", " provider_type: remote::huggingface\n",
" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " - config:\n",
" kvstore:\n",
" db_path: \u001b[35m/Users/xiyan/.llama/distributions/together/\u001b[0m\u001b[95mlocalfs_datasetio.db\u001b[0m\n",
" namespace: null\n",
" type: sqlite\n",
" provider_id: localfs\n", " provider_id: localfs\n",
" provider_type: inline::localfs\n", " provider_type: inline::localfs\n",
" eval:\n", " eval:\n",
" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " - config:\n",
" kvstore:\n",
" db_path: \u001b[35m/Users/xiyan/.llama/distributions/together/\u001b[0m\u001b[95mmeta_reference_eval.db\u001b[0m\n",
" namespace: null\n",
" type: sqlite\n",
" provider_id: meta-reference\n", " provider_id: meta-reference\n",
" provider_type: inline::meta-reference\n", " provider_type: inline::meta-reference\n",
" inference:\n", " inference:\n",
@ -403,16 +582,9 @@
" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" provider_id: sentence-transformers\n", " provider_id: sentence-transformers\n",
" provider_type: inline::sentence-transformers\n", " provider_type: inline::sentence-transformers\n",
" memory:\n",
" - config:\n",
" kvstore:\n",
" db_path: \u001b[35m/Users/xiyan/.llama/distributions/together/\u001b[0m\u001b[95mfaiss_store.db\u001b[0m\n",
" namespace: null\n",
" type: sqlite\n",
" provider_id: faiss\n",
" provider_type: inlin\u001b[1;92me::fa\u001b[0miss\n",
" safety:\n", " safety:\n",
" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " - config:\n",
" excluded_categories: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
" provider_id: llama-guard\n", " provider_id: llama-guard\n",
" provider_type: inline::llama-guard\n", " provider_type: inline::llama-guard\n",
" scoring:\n", " scoring:\n",
@ -450,7 +622,26 @@
" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" provider_id: rag-runtime\n", " provider_id: rag-runtime\n",
" provider_type: inline::rag-runtime\n", " provider_type: inline::rag-runtime\n",
" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" provider_id: model-context-protocol\n",
" provider_type: remote::model-context-protocol\n",
" - config:\n",
" api_key: \u001b[32m'********'\u001b[0m\n",
" provider_id: wolfram-alpha\n",
" provider_type: remote::wolfram-alpha\n",
" vector_io:\n",
" - config:\n",
" kvstore:\n",
" db_path: \u001b[35m/Users/xiyan/.llama/distributions/together/\u001b[0m\u001b[95mfaiss_store.db\u001b[0m\n",
" namespace: null\n",
" type: sqlite\n",
" provider_id: faiss\n",
" provider_type: inlin\u001b[1;92me::fa\u001b[0miss\n",
"scoring_fns: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", "scoring_fns: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
"server:\n",
" port: \u001b[1;36m8321\u001b[0m\n",
" tls_certfile: null\n",
" tls_keyfile: null\n",
"shields:\n", "shields:\n",
"- params: null\n", "- params: null\n",
" provider_id: null\n", " provider_id: null\n",
@ -469,6 +660,11 @@
" mcp_endpoint: null\n", " mcp_endpoint: null\n",
" provider_id: code-interpreter\n", " provider_id: code-interpreter\n",
" toolgroup_id: builtin::code_interpreter\n", " toolgroup_id: builtin::code_interpreter\n",
"- args: null\n",
" mcp_endpoint: null\n",
" provider_id: wolfram-alpha\n",
" toolgroup_id: builtin::wolfram_alpha\n",
"vector_dbs: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
"version: \u001b[32m'2'\u001b[0m\n", "version: \u001b[32m'2'\u001b[0m\n",
"\n" "\n"
] ]
@ -532,7 +728,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 3,
"metadata": { "metadata": {
"colab": { "colab": {
"base_uri": "https://localhost:8080/", "base_uri": "https://localhost:8080/",
@ -643,17 +839,7 @@
"id": "DJkmoG2kq1_P", "id": "DJkmoG2kq1_P",
"outputId": "8493ee59-c6ff-4bb6-d787-f295944db1cf" "outputId": "8493ee59-c6ff-4bb6-d787-f295944db1cf"
}, },
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"Generating dev split: 100%|██████████| 5/5 [00:00<00:00, 139.81 examples/s]\n",
"Generating validation split: 100%|██████████| 30/30 [00:00<00:00, 258.29 examples/s]\n",
"Generating test split: 100%|██████████| 287/287 [00:01<00:00, 197.69 examples/s]\n"
]
}
],
"source": [ "source": [
"import datasets\n", "import datasets\n",
"\n", "\n",
@ -676,7 +862,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 4,
"metadata": { "metadata": {
"colab": { "colab": {
"base_uri": "https://localhost:8080/", "base_uri": "https://localhost:8080/",
@ -691,7 +877,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 5/5 [00:42<00:00, 8.60s/it]\n" "100%|██████████| 5/5 [00:33<00:00, 6.71s/it]\n"
] ]
}, },
{ {
@ -699,16 +885,18 @@
"text/html": [ "text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">EvaluateResponse</span><span style=\"font-weight: bold\">(</span>\n", "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">EvaluateResponse</span><span style=\"font-weight: bold\">(</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">generations</span>=<span style=\"font-weight: bold\">[</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">generations</span>=<span style=\"font-weight: bold\">[</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'Answer: D'</span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'The image shows a sunflower leaf with small, dark spots and white powdery patches. The dark spots are likely caused by a fungal pathogen, such as rust or septoria leaf spot, while the white powdery patches are likely caused by a fungal pathogen, such as powdery mildew.\\n\\nSince there are two distinct types of lesions on the leaf, it is likely that there are two different pathogens infecting the leaf.\\n\\n**Answer:** B) Two pathogens'</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'**Potato Pests**\\n\\nThe two insects depicted are:\\n\\n* **Colorado Potato Beetle (Leptinotarsa decemlineata)**: Characterized by black and yellow stripes, this beetle is a significant pest of potatoes. It feeds on the leaves and can cause substantial damage to the crop.\\n* **False Potato Beetle (Leptinotarsa juncta)**: Also known as the false Colorado beetle, this species has similar coloring but is not as harmful to potatoes as the Colorado potato beetle.'</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">}</span>,\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">\"The question requires the identification of the reason behind the massive gum production on the trunks of grapefruit trees in Cyprus, despite appearing healthy from a distance. The correct answer can be deduced by analyzing the symptoms and considering the possible causes.\\n\\nTo determine the correct answer, let's evaluate each option:\\n\\nA) Don't know or not sure: This option is incorrect because it does not provide a specific reason for the gum production.\\n\\nB) Physiological stress: This option is also incorrect because it is too broad and does not specifically explain the gum production.\\n\\nC) Bacterial disease: This option is incorrect because bacterial diseases typically cause different symptoms such as leaf spots, blights, or wilting.\\n\\nD) Harvesting damage when cutting with knives: This option is incorrect because harvesting damage would likely cause wounds or scars on the tree, but it would not lead to massive gum production.\\n\\nE) Fungal gummosis: This option is the most likely cause of the gum production. Fungal gummosis is a common disease in citrus trees, including grapefruit, that causes the production of gum or sap on the trunks and branches. The disease is typically caused by fungi such as Phytophthora or Diplodia, which infect the tree through wounds or natural openings. The gum production is a defense mechanism by the tree to try to seal off the infection and prevent further damage.\\n\\nTherefore, the correct answer is:\\n\\nAnswer: E\"</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">\"The image shows a sunflower leaf with a powdery mildew, which is a fungal disease caused by various species of fungi. The white powdery coating on the leaves is a characteristic symptom of this disease. The leaf also has some black spots, which could be indicative of a secondary infection or another type of disease. However, without more information or a closer examination, it's difficult to determine the exact cause of the black spots.\\n\\nBased on the image alone, we can see at least two types of symptoms: the powdery mildew and the black spots. This suggests that there may be more than one pathogen involved, but it's also possible that the black spots are a result of the same fungal infection causing the powdery mildew.\\n\\nAnswer: B) Two pathogens\"</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'The symptoms observed, characterized by the massive gum production on the trunks of the grapefruit trees in Cyprus, suggest a physiological or pathological response. Given the absence of visible signs of damage or pests from a higher point on a hillside, and considering the specific nature of the symptom (gum production), we can infer that the cause is more likely related to an internal process within the tree rather than external damage from harvesting. While physiological stress (B) could lead to such symptoms, the primary reason for gum production in trees, especially in citrus species, is typically linked to disease. Among the options provided, fungal gummosis (E) is a condition known to cause gumming in citrus trees, which aligns with the observed symptoms. Therefore, without direct evidence of external damage (harvesting) or confirmation of physiological stress being the primary cause, the most appropriate answer based on the information given is:\\n\\nAnswer: E'</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">}</span>,\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'Answer: D'</span><span style=\"font-weight: bold\">}</span>,\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'Answer: D'</span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'**Causes of Splitting Petioles in Rhubarb**\\n\\nThe following factors can cause the petioles of rhubarb to split:\\n\\n* **Physiological Problems**: Issues such as water stress, nutrient deficiencies, or extreme temperatures can lead to splitting.\\n* **Phytoplasma Infection**: A bacterial infection caused by phytoplasma can lead to splitting of the petioles.\\n* **Animal Damage**: Pests like slugs, snails, or rodents can damage the plant and cause splitting.\\n* **Bacterial Infection**: Bacterial infections can also cause splitting.\\n\\nAs a result, the correct answer is:\\n\\n*Answer*: A) Physiological problems'</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">\"**Analysis of the Image**\\n\\nThe image provided shows a rhubarb plant with split petioles. To determine the cause of this issue, we need to consider various factors that could lead to such damage.\\n\\n**Possible Causes of Petiole Splitting**\\n\\n* **Physiological Problems**: Rhubarb plants can experience physiological stress due to environmental factors like extreme temperatures, waterlogging, or nutrient deficiencies. This stress can cause the petioles to split.\\n* **Phytoplasma Infection**: Phytoplasma is a type of bacteria that can infect plants, including rhubarb. It can cause symptoms such as yellowing leaves, stunted growth, and splitting of petioles.\\n* **Animal Damage**: Animals like rabbits, deer, or insects can damage rhubarb plants by eating the leaves or stems, which can lead to splitting of the petioles.\\n* **Bacteria**: Bacterial infections can also cause damage to rhubarb plants, including splitting of the petioles.\\n\\n**Conclusion**\\n\\nBased on the analysis, it is clear that all the options listed (A) Physiological problems, B) Phytoplasma infection, D) Animal damage, and E) Bacteria) could potentially cause the petioles of the rhubarb plant to split. Therefore, there is no single option that would not be a cause for the petioles splitting.\\n\\n**Answer**: C) I don't know and don't want to guess.\"</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">}</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">]</span>,\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">]</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">scores</span>=<span style=\"font-weight: bold\">{</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">scores</span>=<span style=\"font-weight: bold\">{</span>\n",
@ -723,16 +911,18 @@
"text/plain": [ "text/plain": [
"\u001b[1;35mEvaluateResponse\u001b[0m\u001b[1m(\u001b[0m\n", "\u001b[1;35mEvaluateResponse\u001b[0m\u001b[1m(\u001b[0m\n",
"\u001b[2;32m│ \u001b[0m\u001b[33mgenerations\u001b[0m=\u001b[1m[\u001b[0m\n", "\u001b[2;32m│ \u001b[0m\u001b[33mgenerations\u001b[0m=\u001b[1m[\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m'Answer: D'\u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m'The image shows a sunflower leaf with small, dark spots and white powdery patches. The dark spots are likely caused by a fungal pathogen, such as rust or septoria leaf spot, while the white powdery patches are likely caused by a fungal pathogen, such as powdery mildew.\\n\\nSince there are two distinct types of lesions on the leaf, it is likely that there are two different pathogens infecting the leaf.\\n\\n**Answer:** B\u001b[0m\u001b[32m)\u001b[0m\u001b[32m Two pathogens'\u001b[0m\n", "\u001b[2;32m│ │ │ \u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m'**Potato Pests**\\n\\nThe two insects depicted are:\\n\\n* **Colorado Potato Beetle \u001b[0m\u001b[32m(\u001b[0m\u001b[32mLeptinotarsa decemlineata\u001b[0m\u001b[32m)\u001b[0m\u001b[32m**: Characterized by black and yellow stripes, this beetle is a significant pest of potatoes. It feeds on the leaves and can cause substantial damage to the crop.\\n* **False Potato Beetle \u001b[0m\u001b[32m(\u001b[0m\u001b[32mLeptinotarsa juncta\u001b[0m\u001b[32m)\u001b[0m\u001b[32m**: Also known as the false Colorado beetle, this species has similar coloring but is not as harmful to potatoes as the Colorado potato beetle.'\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m}\u001b[0m,\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m\"The question requires the identification of the reason behind the massive gum production on the trunks of grapefruit trees in Cyprus, despite appearing healthy from a distance. The correct answer can be deduced by analyzing the symptoms and considering the possible causes.\\n\\nTo determine the correct answer, let's evaluate each option:\\n\\nA\u001b[0m\u001b[32m)\u001b[0m\u001b[32m Don't know or not sure: This option is incorrect because it does not provide a specific reason for the gum production.\\n\\nB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m Physiological stress: This option is also incorrect because it is too broad and does not specifically explain the gum production.\\n\\nC\u001b[0m\u001b[32m)\u001b[0m\u001b[32m Bacterial disease: This option is incorrect because bacterial diseases typically cause different symptoms such as leaf spots, blights, or wilting.\\n\\nD\u001b[0m\u001b[32m)\u001b[0m\u001b[32m Harvesting damage when cutting with knives: This option is incorrect because harvesting damage would likely cause wounds or scars on the tree, but it would not lead to massive gum production.\\n\\nE\u001b[0m\u001b[32m)\u001b[0m\u001b[32m Fungal gummosis: This option is the most likely cause of the gum production. Fungal gummosis is a common disease in citrus trees, including grapefruit, that causes the production of gum or sap on the trunks and branches. The disease is typically caused by fungi such as Phytophthora or Diplodia, which infect the tree through wounds or natural openings. The gum production is a defense mechanism by the tree to try to seal off the infection and prevent further damage.\\n\\nTherefore, the correct answer is:\\n\\nAnswer: E\"\u001b[0m\n", "\u001b[2;32m│ │ │ \u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m\"The image shows a sunflower leaf with a powdery mildew, which is a fungal disease caused by various species of fungi. The white powdery coating on the leaves is a characteristic symptom of this disease. The leaf also has some black spots, which could be indicative of a secondary infection or another type of disease. However, without more information or a closer examination, it's difficult to determine the exact cause of the black spots.\\n\\nBased on the image alone, we can see at least two types of symptoms: the powdery mildew and the black spots. This suggests that there may be more than one pathogen involved, but it's also possible that the black spots are a result of the same fungal infection causing the powdery mildew.\\n\\nAnswer: B\u001b[0m\u001b[32m)\u001b[0m\u001b[32m Two pathogens\"\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m'The symptoms observed, characterized by the massive gum production on the trunks of the grapefruit trees in Cyprus, suggest a physiological or pathological response. Given the absence of visible signs of damage or pests from a higher point on a hillside, and considering the specific nature of the symptom \u001b[0m\u001b[32m(\u001b[0m\u001b[32mgum production\u001b[0m\u001b[32m)\u001b[0m\u001b[32m, we can infer that the cause is more likely related to an internal process within the tree rather than external damage from harvesting. While physiological stress \u001b[0m\u001b[32m(\u001b[0m\u001b[32mB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m could lead to such symptoms, the primary reason for gum production in trees, especially in citrus species, is typically linked to disease. Among the options provided, fungal gummosis \u001b[0m\u001b[32m(\u001b[0m\u001b[32mE\u001b[0m\u001b[32m)\u001b[0m\u001b[32m is a condition known to cause gumming in citrus trees, which aligns with the observed symptoms. Therefore, without direct evidence of external damage \u001b[0m\u001b[32m(\u001b[0m\u001b[32mharvesting\u001b[0m\u001b[32m)\u001b[0m\u001b[32m or confirmation of physiological stress being the primary cause, the most appropriate answer based on the information given is:\\n\\nAnswer: E'\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m}\u001b[0m,\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m'Answer: D'\u001b[0m\u001b[1m}\u001b[0m,\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m'Answer: D'\u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m'**Causes of Splitting Petioles in Rhubarb**\\n\\nThe following factors can cause the petioles of rhubarb to split:\\n\\n* **Physiological Problems**: Issues such as water stress, nutrient deficiencies, or extreme temperatures can lead to splitting.\\n* **Phytoplasma Infection**: A bacterial infection caused by phytoplasma can lead to splitting of the petioles.\\n* **Animal Damage**: Pests like slugs, snails, or rodents can damage the plant and cause splitting.\\n* **Bacterial Infection**: Bacterial infections can also cause splitting.\\n\\nAs a result, the correct answer is:\\n\\n*Answer*: A\u001b[0m\u001b[32m)\u001b[0m\u001b[32m Physiological problems'\u001b[0m\n", "\u001b[2;32m│ │ │ \u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m\"**Analysis of the Image**\\n\\nThe image provided shows a rhubarb plant with split petioles. To determine the cause of this issue, we need to consider various factors that could lead to such damage.\\n\\n**Possible Causes of Petiole Splitting**\\n\\n* **Physiological Problems**: Rhubarb plants can experience physiological stress due to environmental factors like extreme temperatures, waterlogging, or nutrient deficiencies. This stress can cause the petioles to split.\\n* **Phytoplasma Infection**: Phytoplasma is a type of bacteria that can infect plants, including rhubarb. It can cause symptoms such as yellowing leaves, stunted growth, and splitting of petioles.\\n* **Animal Damage**: Animals like rabbits, deer, or insects can damage rhubarb plants by eating the leaves or stems, which can lead to splitting of the petioles.\\n* **Bacteria**: Bacterial infections can also cause damage to rhubarb plants, including splitting of the petioles.\\n\\n**Conclusion**\\n\\nBased on the analysis, it is clear that all the options listed \u001b[0m\u001b[32m(\u001b[0m\u001b[32mA\u001b[0m\u001b[32m)\u001b[0m\u001b[32m Physiological problems, B\u001b[0m\u001b[32m)\u001b[0m\u001b[32m Phytoplasma infection, D\u001b[0m\u001b[32m)\u001b[0m\u001b[32m Animal damage, and E\u001b[0m\u001b[32m)\u001b[0m\u001b[32m Bacteria\u001b[0m\u001b[32m)\u001b[0m\u001b[32m could potentially cause the petioles of the rhubarb plant to split. Therefore, there is no single option that would not be a cause for the petioles splitting.\\n\\n**Answer**: C\u001b[0m\u001b[32m)\u001b[0m\u001b[32m I don't know and don't want to guess.\"\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m}\u001b[0m\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ \u001b[0m\u001b[1m]\u001b[0m,\n", "\u001b[2;32m│ \u001b[0m\u001b[1m]\u001b[0m,\n",
"\u001b[2;32m│ \u001b[0m\u001b[33mscores\u001b[0m=\u001b[1m{\u001b[0m\n", "\u001b[2;32m│ \u001b[0m\u001b[33mscores\u001b[0m=\u001b[1m{\u001b[0m\n",
@ -815,7 +1005,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 5,
"metadata": { "metadata": {
"id": "HXmZf3Ymw-aX" "id": "HXmZf3Ymw-aX"
}, },
@ -823,39 +1013,33 @@
"source": [ "source": [
"simpleqa_dataset_id = \"huggingface::simpleqa\"\n", "simpleqa_dataset_id = \"huggingface::simpleqa\"\n",
"\n", "\n",
"_ = client.datasets.register(\n", "register_dataset_response = client.datasets.register(\n",
" purpose=\"eval/messages-answer\",\n",
" source={\n",
" \"type\": \"uri\",\n",
" \"uri\": \"huggingface://datasets/llamastack/simpleqa?split=train\",\n",
" },\n",
" dataset_id=simpleqa_dataset_id,\n", " dataset_id=simpleqa_dataset_id,\n",
" provider_id=\"huggingface\",\n", ")"
" url={\"uri\": \"https://huggingface.co/datasets/llamastack/simpleqa\"},\n",
" metadata={\n",
" \"path\": \"llamastack/simpleqa\",\n",
" \"split\": \"train\",\n",
" },\n",
" dataset_schema={\n",
" \"input_query\": {\"type\": \"string\"},\n",
" \"expected_answer\": {\"type\": \"string\"},\n",
" \"chat_completion_input\": {\"type\": \"chat_completion_input\"},\n",
" },\n",
")\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 6,
"metadata": { "metadata": {
"id": "Gc8azb4Rxr5J" "id": "Gc8azb4Rxr5J"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"eval_rows = client.datasetio.get_rows_paginated(\n", "eval_rows = client.datasets.iterrows(\n",
" dataset_id=simpleqa_dataset_id,\n", " dataset_id=simpleqa_dataset_id,\n",
" rows_in_page=5,\n", " limit=5,\n",
")\n" ")"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 7,
"metadata": { "metadata": {
"colab": { "colab": {
"base_uri": "https://localhost:8080/", "base_uri": "https://localhost:8080/",
@ -876,7 +1060,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 5/5 [00:31<00:00, 6.38s/it]\n" "100%|██████████| 5/5 [00:13<00:00, 2.71s/it]\n"
] ]
}, },
{ {
@ -889,14 +1073,14 @@
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">\"Radcliffe College was a women's liberal arts college in Cambridge, Massachusetts. However, it merged with Harvard University in 1977 and is now known as the Radcliffe Institute for Advanced Study at Harvard University.\"</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">\"Radcliffe College was a women's liberal arts college in Cambridge, Massachusetts. However, it merged with Harvard University in 1977 and is now known as the Radcliffe Institute for Advanced Study at Harvard University.\"</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">}</span>,\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'I do not have information on the Leipzig 1877 tournament.'</span><span style=\"font-weight: bold\">}</span>,\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'I am unable to verify in whose honor the Leipzig 1877 tournament was organized.'</span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">\"I am unable to verify what Empress Elizabeth of Austria's favorite sculpture depicted at her villa Achilleion at Corfu, according to Karl Küchler.\"</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'generated_answer'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">\"I am unable to verify what Empress Elizabeth of Austria's favorite sculpture depicted at her villa Achilleion at Corfu, according to Karl Küchler.\"</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">}</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">]</span>,\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">]</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">scores</span>=<span style=\"font-weight: bold\">{</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">scores</span>=<span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'llm-as-judge::405b-simpleqa'</span>: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ScoringResult</span><span style=\"font-weight: bold\">(</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'llm-as-judge::405b-simpleqa'</span>: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ScoringResult</span><span style=\"font-weight: bold\">(</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">aggregated_results</span>=<span style=\"font-weight: bold\">{}</span>,\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">aggregated_results</span>=<span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'categorical_count'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'categorical_count'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'A'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'C'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span><span style=\"font-weight: bold\">}}}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">score_rows</span>=<span style=\"font-weight: bold\">[</span>\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">score_rows</span>=<span style=\"font-weight: bold\">[</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'C'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'judge_feedback'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'C'</span><span style=\"font-weight: bold\">}</span>,\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'C'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'judge_feedback'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'C'</span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'C'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'judge_feedback'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'C'</span><span style=\"font-weight: bold\">}</span>,\n", "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'C'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'judge_feedback'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'C'</span><span style=\"font-weight: bold\">}</span>,\n",
@ -917,14 +1101,14 @@
"\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m\"Radcliffe College was a women's liberal arts college in Cambridge, Massachusetts. However, it merged with Harvard University in 1977 and is now known as the Radcliffe Institute for Advanced Study at Harvard University.\"\u001b[0m\n", "\u001b[2;32m│ │ │ \u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m\"Radcliffe College was a women's liberal arts college in Cambridge, Massachusetts. However, it merged with Harvard University in 1977 and is now known as the Radcliffe Institute for Advanced Study at Harvard University.\"\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m}\u001b[0m,\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m'I do not have information on the Leipzig 1877 tournament.'\u001b[0m\u001b[1m}\u001b[0m,\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m'I am unable to verify in whose honor the Leipzig 1877 tournament was organized.'\u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m\"I am unable to verify what Empress Elizabeth of Austria's favorite sculpture depicted at her villa Achilleion at Corfu, according to Karl Küchler.\"\u001b[0m\n", "\u001b[2;32m│ │ │ \u001b[0m\u001b[32m'generated_answer'\u001b[0m: \u001b[32m\"I am unable to verify what Empress Elizabeth of Austria's favorite sculpture depicted at her villa Achilleion at Corfu, according to Karl Küchler.\"\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m}\u001b[0m\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ \u001b[0m\u001b[1m]\u001b[0m,\n", "\u001b[2;32m│ \u001b[0m\u001b[1m]\u001b[0m,\n",
"\u001b[2;32m│ \u001b[0m\u001b[33mscores\u001b[0m=\u001b[1m{\u001b[0m\n", "\u001b[2;32m│ \u001b[0m\u001b[33mscores\u001b[0m=\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[32m'llm-as-judge::405b-simpleqa'\u001b[0m: \u001b[1;35mScoringResult\u001b[0m\u001b[1m(\u001b[0m\n", "\u001b[2;32m│ │ \u001b[0m\u001b[32m'llm-as-judge::405b-simpleqa'\u001b[0m: \u001b[1;35mScoringResult\u001b[0m\u001b[1m(\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33maggregated_results\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", "\u001b[2;32m│ │ │ \u001b[0m\u001b[33maggregated_results\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'categorical_count'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'categorical_count'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'A'\u001b[0m: \u001b[1;36m1\u001b[0m, \u001b[32m'C'\u001b[0m: \u001b[1;36m4\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mscore_rows\u001b[0m=\u001b[1m[\u001b[0m\n", "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mscore_rows\u001b[0m=\u001b[1m[\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\u001b[32m'score'\u001b[0m: \u001b[32m'C'\u001b[0m, \u001b[32m'judge_feedback'\u001b[0m: \u001b[32m'C'\u001b[0m\u001b[1m}\u001b[0m,\n", "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\u001b[32m'score'\u001b[0m: \u001b[32m'C'\u001b[0m, \u001b[32m'judge_feedback'\u001b[0m: \u001b[32m'C'\u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\u001b[32m'score'\u001b[0m: \u001b[32m'C'\u001b[0m, \u001b[32m'judge_feedback'\u001b[0m: \u001b[32m'C'\u001b[0m\u001b[1m}\u001b[0m,\n", "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\u001b[32m'score'\u001b[0m: \u001b[32m'C'\u001b[0m, \u001b[32m'judge_feedback'\u001b[0m: \u001b[32m'C'\u001b[0m\u001b[1m}\u001b[0m,\n",
@ -957,7 +1141,7 @@
"\n", "\n",
"response = client.eval.evaluate_rows_alpha(\n", "response = client.eval.evaluate_rows_alpha(\n",
" benchmark_id=\"meta-reference::simpleqa\",\n", " benchmark_id=\"meta-reference::simpleqa\",\n",
" input_rows=eval_rows.rows,\n", " input_rows=eval_rows.data,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n", " scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
" benchmark_config={\n", " benchmark_config={\n",
" \"type\": \"benchmark\",\n", " \"type\": \"benchmark\",\n",
@ -1106,7 +1290,7 @@
"\n", "\n",
"response = client.eval.evaluate_rows_alpha(\n", "response = client.eval.evaluate_rows_alpha(\n",
" benchmark_id=\"meta-reference::simpleqa\",\n", " benchmark_id=\"meta-reference::simpleqa\",\n",
" input_rows=eval_rows.rows,\n", " input_rows=eval_rows.data,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n", " scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
" benchmark_config={\n", " benchmark_config={\n",
" \"type\": \"benchmark\",\n", " \"type\": \"benchmark\",\n",

View file

@ -12,7 +12,7 @@
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
import sys
import fire import fire
import ruamel.yaml as yaml import ruamel.yaml as yaml
@ -21,7 +21,7 @@ from llama_stack.distribution.stack import LlamaStack # noqa: E402
from .pyopenapi.options import Options # noqa: E402 from .pyopenapi.options import Options # noqa: E402
from .pyopenapi.specification import Info, Server # noqa: E402 from .pyopenapi.specification import Info, Server # noqa: E402
from .pyopenapi.utility import Specification # noqa: E402 from .pyopenapi.utility import Specification, validate_api_method_return_types # noqa: E402
def str_presenter(dumper, data): def str_presenter(dumper, data):
@ -39,6 +39,14 @@ def main(output_dir: str):
if not output_dir.exists(): if not output_dir.exists():
raise ValueError(f"Directory {output_dir} does not exist") raise ValueError(f"Directory {output_dir} does not exist")
# Validate API protocols before generating spec
print("Validating API method return types...")
return_type_errors = validate_api_method_return_types()
if return_type_errors:
print("\nAPI Method Return Type Validation Errors:\n")
for error in return_type_errors:
print(error)
sys.exit(1)
now = str(datetime.now()) now = str(datetime.now())
print( print(
"Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now "Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now

View file

@ -435,7 +435,7 @@ class Generator:
) )
self.schema_builder = SchemaBuilder(schema_generator) self.schema_builder = SchemaBuilder(schema_generator)
self.responses = {} self.responses = {}
# Create standard error responses # Create standard error responses
self._create_standard_error_responses() self._create_standard_error_responses()
@ -446,7 +446,7 @@ class Generator:
""" """
# Get the Error schema # Get the Error schema
error_schema = self.schema_builder.classdef_to_ref(Error) error_schema = self.schema_builder.classdef_to_ref(Error)
# Create standard error responses # Create standard error responses
self.responses["BadRequest400"] = Response( self.responses["BadRequest400"] = Response(
description="The request was invalid or malformed", description="The request was invalid or malformed",
@ -457,11 +457,11 @@ class Generator:
"status": 400, "status": 400,
"title": "Bad Request", "title": "Bad Request",
"detail": "The request was invalid or malformed", "detail": "The request was invalid or malformed",
} },
) )
} },
) )
self.responses["TooManyRequests429"] = Response( self.responses["TooManyRequests429"] = Response(
description="The client has sent too many requests in a given amount of time", description="The client has sent too many requests in a given amount of time",
content={ content={
@ -471,11 +471,11 @@ class Generator:
"status": 429, "status": 429,
"title": "Too Many Requests", "title": "Too Many Requests",
"detail": "You have exceeded the rate limit. Please try again later.", "detail": "You have exceeded the rate limit. Please try again later.",
} },
) )
} },
) )
self.responses["InternalServerError500"] = Response( self.responses["InternalServerError500"] = Response(
description="The server encountered an unexpected error", description="The server encountered an unexpected error",
content={ content={
@ -485,11 +485,11 @@ class Generator:
"status": 500, "status": 500,
"title": "Internal Server Error", "title": "Internal Server Error",
"detail": "An unexpected error occurred. Our team has been notified.", "detail": "An unexpected error occurred. Our team has been notified.",
} },
) )
} },
) )
# Add a default error response for any unhandled error cases # Add a default error response for any unhandled error cases
self.responses["DefaultError"] = Response( self.responses["DefaultError"] = Response(
description="An unexpected error occurred", description="An unexpected error occurred",
@ -500,9 +500,9 @@ class Generator:
"status": 0, "status": 0,
"title": "Error", "title": "Error",
"detail": "An unexpected error occurred", "detail": "An unexpected error occurred",
} },
) )
} },
) )
def _build_type_tag(self, ref: str, schema: Schema) -> Tag: def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
@ -547,11 +547,14 @@ class Generator:
"SyntheticDataGeneration", "SyntheticDataGeneration",
"PostTraining", "PostTraining",
"BatchInference", "BatchInference",
"Files",
]: ]:
op.defining_class.__name__ = f"{op.defining_class.__name__} (Coming Soon)" op.defining_class.__name__ = f"{op.defining_class.__name__} (Coming Soon)"
print(op.defining_class.__name__) print(op.defining_class.__name__)
# TODO (xiyan): temporary fix for datasetio inner impl + datasets api
# if op.defining_class.__name__ in ["DatasetIO"]:
# op.defining_class.__name__ = "Datasets"
doc_string = parse_type(op.func_ref) doc_string = parse_type(op.func_ref)
doc_params = dict( doc_params = dict(
(param.name, param.description) for param in doc_string.params.values() (param.name, param.description) for param in doc_string.params.values()
@ -598,7 +601,9 @@ class Generator:
# data passed in request body as raw bytes cannot have request parameters # data passed in request body as raw bytes cannot have request parameters
if raw_bytes_request_body and op.request_params: if raw_bytes_request_body and op.request_params:
raise ValueError("Cannot have both raw bytes request body and request parameters") raise ValueError(
"Cannot have both raw bytes request body and request parameters"
)
# data passed in request body as raw bytes # data passed in request body as raw bytes
if raw_bytes_request_body: if raw_bytes_request_body:
@ -719,7 +724,7 @@ class Generator:
responses.update(response_builder.build_response(response_options)) responses.update(response_builder.build_response(response_options))
assert len(responses.keys()) > 0, f"No responses found for {op.name}" assert len(responses.keys()) > 0, f"No responses found for {op.name}"
# Add standard error response references # Add standard error response references
if self.options.include_standard_error_responses: if self.options.include_standard_error_responses:
if "400" not in responses: if "400" not in responses:
@ -730,7 +735,7 @@ class Generator:
responses["500"] = ResponseRef("InternalServerError500") responses["500"] = ResponseRef("InternalServerError500")
if "default" not in responses: if "default" not in responses:
responses["default"] = ResponseRef("DefaultError") responses["default"] = ResponseRef("DefaultError")
if op.event_type is not None: if op.event_type is not None:
builder = ContentBuilder(self.schema_builder) builder = ContentBuilder(self.schema_builder)
callbacks = { callbacks = {

View file

@ -6,16 +6,19 @@
import json import json
import typing import typing
import inspect
import os
from pathlib import Path from pathlib import Path
from typing import TextIO from typing import TextIO
from typing import Any, Dict, List, Optional, Protocol, Type, Union, get_type_hints, get_origin, get_args
from llama_stack.strong_typing.schema import object_to_json, StrictJsonType from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
from llama_stack.distribution.resolver import api_protocol_map
from .generator import Generator from .generator import Generator
from .options import Options from .options import Options
from .specification import Document from .specification import Document
THIS_DIR = Path(__file__).parent THIS_DIR = Path(__file__).parent
@ -114,3 +117,37 @@ class Specification:
) )
f.write(html) f.write(html)
def is_optional_type(type_: Any) -> bool:
"""Check if a type is Optional."""
origin = get_origin(type_)
args = get_args(type_)
return origin is Optional or (origin is Union and type(None) in args)
def validate_api_method_return_types() -> List[str]:
"""Validate that all API methods have proper return types."""
errors = []
protocols = api_protocol_map()
for protocol_name, protocol in protocols.items():
methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
for method_name, method in methods:
if not hasattr(method, '__webmethod__'):
continue
# Only check GET methods
if method.__webmethod__.method != "GET":
continue
hints = get_type_hints(method)
if 'return' not in hints:
errors.append(f"Method {protocol_name}.{method_name} has no return type annotation")
else:
return_type = hints['return']
if is_optional_type(return_type):
errors.append(f"Method {protocol_name}.{method_name} returns Optional type")
return errors
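For context, here is one way this validator could be wired into a CI or pre-commit step (a minimal sketch; the entry point and exit handling are illustrative and not part of this change):

```python
# Minimal driver for the validator above: print any offending methods and
# exit non-zero so a CI job fails when a GET endpoint is mis-annotated.
if __name__ == "__main__":
    import sys

    errors = validate_api_method_return_types()
    if errors:
        for error in errors:
            print(error, file=sys.stderr)
        sys.exit(1)
    print("All GET API methods have concrete return types.")
```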

View file

@ -6,7 +6,7 @@ This guide will walk you through the process of adding a new API provider to Lla
- Begin by reviewing the [core concepts](../concepts/index.md) of Llama Stack and choose the API your provider belongs to (Inference, Safety, VectorIO, etc.) - Begin by reviewing the [core concepts](../concepts/index.md) of Llama Stack and choose the API your provider belongs to (Inference, Safety, VectorIO, etc.)
- Determine the provider type ({repopath}`Remote::llama_stack/providers/remote` or {repopath}`Inline::llama_stack/providers/inline`). Remote providers make requests to external services, while inline providers execute implementation locally. - Determine the provider type ({repopath}`Remote::llama_stack/providers/remote` or {repopath}`Inline::llama_stack/providers/inline`). Remote providers make requests to external services, while inline providers execute implementation locally.
- Add your provider to the appropriate {repopath}`Registry::llama_stack/providers/registry/`. Specify pip dependencies necessary. - Add your provider to the appropriate {repopath}`Registry::llama_stack/providers/registry/`. Specify pip dependencies necessary.
- Update any distribution {repopath}`Templates::llama_stack/templates/` build.yaml and run.yaml files if they should include your provider by default. Run {repopath}`llama_stack/scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation. - Update any distribution {repopath}`Templates::llama_stack/templates/` build.yaml and run.yaml files if they should include your provider by default. Run {repopath}`./scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.
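To make the registry step concrete, here is a hedged sketch of what a remote-provider registry entry can look like. The helper and field names below are assumptions modeled on existing entries and may differ from the current registry API; when in doubt, copy a neighbouring entry from the same registry file.

```python
# Hypothetical entry appended to the list returned by available_providers()
# in llama_stack/providers/registry/inference.py (all names are illustrative).
from llama_stack.providers.datatypes import AdapterSpec, Api, remote_provider_spec

my_provider = remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(
        adapter_type="my-inference",
        pip_packages=["my-inference-sdk"],  # extra pip dependencies for this provider
        module="llama_stack.providers.remote.inference.my_inference",
        config_class="llama_stack.providers.remote.inference.my_inference.MyInferenceConfig",
    ),
)
```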
Here are some example PRs to help you get started: Here are some example PRs to help you get started:

View file

@ -185,8 +185,12 @@ llama stack build --config llama_stack/templates/ollama/build.yaml
::: :::
:::{tab-item} Building Container :::{tab-item} Building Container
> [!TIP]
> Podman is supported as an alternative to Docker. Set `CONTAINER_BINARY` to `podman` in your environment to use Podman. ```{admonition} Podman Alternative
:class: tip
Podman is supported as an alternative to Docker. Set `CONTAINER_BINARY` to `podman` in your environment to use Podman.
```
To build a container image, you may start off from a template and use the `--image-type container` flag to specify `container` as the build image type. To build a container image, you may start off from a template and use the `--image-type container` flag to specify `container` as the build image type.

View file

@ -58,7 +58,7 @@ Breaking down the demo app, this section will show the core pieces that are used
### Setup Remote Inferencing ### Setup Remote Inferencing
Start a Llama Stack server on localhost. Here is an example of how you can do this using the fireworks.ai distribution: Start a Llama Stack server on localhost. Here is an example of how you can do this using the fireworks.ai distribution:
``` ```
conda create -n stack-fireworks python=3.10 conda create -n stack-fireworks python=3.10
conda activate stack-fireworks conda activate stack-fireworks
pip install --no-cache llama-stack==0.1.4 pip install --no-cache llama-stack==0.1.4
llama stack build --template fireworks --image-type conda llama stack build --template fireworks --image-type conda

View file

@ -6,13 +6,13 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
| API | Provider(s) | | API | Provider(s) |
|-----|-------------| |-----|-------------|
| agents | `inline::meta-reference` | | agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` | | datasetio | `inline::localfs` |
| eval | `inline::meta-reference` | | eval | `inline::meta-reference` |
| inference | `remote::nvidia` | | inference | `remote::nvidia` |
| safety | `inline::llama-guard` | | safety | `remote::nvidia` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | scoring | `inline::basic` |
| telemetry | `inline::meta-reference` | | telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` | | tool_runtime | `inline::rag-runtime` |
| vector_io | `inline::faiss` | | vector_io | `inline::faiss` |
@ -20,8 +20,10 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``) - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
### Models ### Models

View file

@ -15,8 +15,6 @@ Llama Stack defines and standardizes the core building blocks needed to bring ge
- **Multiple developer interfaces** like CLI and SDKs for Python, Node, iOS, and Android - **Multiple developer interfaces** like CLI and SDKs for Python, Node, iOS, and Android
- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack - **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack
We focus on making it easy to build production applications with the Llama model family - from the latest Llama 3.3 to specialized models like Llama Guard for safety.
```{image} ../_static/llama-stack.png ```{image} ../_static/llama-stack.png
:alt: Llama Stack :alt: Llama Stack
:width: 400px :width: 400px

View file

@ -48,7 +48,7 @@ Llama Stack addresses these challenges through a service-oriented, API-first app
**Robust Ecosystem** **Robust Ecosystem**
- Llama Stack is already integrated with distribution partners (cloud providers, hardware vendors, and AI-focused companies). - Llama Stack is already integrated with distribution partners (cloud providers, hardware vendors, and AI-focused companies).
- Ecosystem offers tailored infrastructure, software, and services for deploying Llama models. - Ecosystem offers tailored infrastructure, software, and services for deploying a variety of models.
### Our Philosophy ### Our Philosophy
@ -57,7 +57,6 @@ Llama Stack addresses these challenges through a service-oriented, API-first app
- **Composability**: Every component is independent but works together seamlessly - **Composability**: Every component is independent but works together seamlessly
- **Production Ready**: Built for real-world applications, not just demos - **Production Ready**: Built for real-world applications, not just demos
- **Turnkey Solutions**: Easy-to-deploy built-in solutions for popular deployment scenarios - **Turnkey Solutions**: Easy-to-deploy built-in solutions for popular deployment scenarios
- **Llama First**: Explicit focus on Meta's Llama models and partnering ecosystem
With Llama Stack, you can focus on building your application while we handle the infrastructure complexity, essential capabilities, and provider integrations. With Llama Stack, you can focus on building your application while we handle the infrastructure complexity, essential capabilities, and provider integrations.

View file

@ -92,6 +92,8 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
## Starting the Llama Stack Playground ## Starting the Llama Stack Playground
### Llama CLI
To start the Llama Stack Playground, run the following commands: To start the Llama Stack Playground, run the following commands:
1. Start up the Llama Stack API server 1. Start up the Llama Stack API server
@ -107,3 +109,28 @@ cd llama_stack/distribution/ui
pip install -r requirements.txt pip install -r requirements.txt
streamlit run app.py streamlit run app.py
``` ```
### Docker
The Playground can also be started in a Docker container:
```sh
export LLAMA_STACK_URL=http://localhost:8321
docker run \
-p 8501:8501 \
-e LLAMA_STACK_ENDPOINT=$LLAMA_STACK_URL \
quay.io/jland/llama-stack-playground
```
## Environment Variables
| Environment Variable | Description | Default Value |
|----------------------------|------------------------------------|---------------------------|
| LLAMA_STACK_ENDPOINT | The endpoint for the Llama Stack | http://localhost:8321 |
| FIREWORKS_API_KEY | API key for Fireworks provider | (empty string) |
| TOGETHER_API_KEY | API key for Together provider | (empty string) |
| SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
| OPENAI_API_KEY | API key for OpenAI provider | (empty string) |

View file

@ -3,21 +3,36 @@ orphan: true
--- ---
# Qdrant # Qdrant
[Qdrant](https://qdrant.tech/documentation/) is a remote vector database provider for Llama Stack. It [Qdrant](https://qdrant.tech/documentation/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory. allows you to store and query vectors directly in memory.
That means you'll get fast and efficient vector retrieval. That means you'll get fast and efficient vector retrieval.
> By default, Qdrant stores vectors in RAM, delivering incredibly fast access for datasets that fit comfortably in
> memory. But when your dataset exceeds RAM capacity, Qdrant offers Memmap as an alternative.
>
> \[[An Introduction to Vector Databases](https://qdrant.tech/articles/what-is-a-vector-database/)\]
## Features ## Features
- Easy to use - Lightweight and easy to use
- Fully integrated with Llama Stack - Fully integrated with Llama Stack
- Apache 2.0 license terms
- Store embeddings and their metadata
- Supports search by
[Keyword](https://qdrant.tech/articles/qdrant-introduces-full-text-filters-and-indexes/)
and [Hybrid](https://qdrant.tech/articles/hybrid-search/#building-a-hybrid-search-system-in-qdrant) search
- [Multilingual and Multimodal retrieval](https://qdrant.tech/documentation/multimodal-search/)
- [Metadata filtering](https://qdrant.tech/articles/vector-search-filtering/)
- [GPU support](https://qdrant.tech/documentation/guides/running-with-gpu/)
## Usage ## Usage
To use Qdrant in your Llama Stack project, follow these steps: To use Qdrant in your Llama Stack project, follow these steps:
1. Install the necessary dependencies. 1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Faiss. 2. Configure your Llama Stack project to use Qdrant.
3. Start storing and querying vectors. 3. Start storing and querying vectors.
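A minimal Python sketch of step 2 and the start of step 3, assuming the `llama_stack_client` SDK and a Qdrant provider configured under the `qdrant` provider ID (the database ID, embedding model, and dimension below mirror the CLI defaults and are illustrative):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Step 2: register a vector database backed by the Qdrant provider.
client.vector_dbs.register(
    vector_db_id="my-qdrant-db",
    provider_id="qdrant",
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)

# Step 3: insertion and querying then go through the usual vector_io / RAG APIs,
# e.g. the builtin::rag toolgroup shown in the CLI reference elsewhere in this change.
```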
## Installation ## Installation

View file

@ -114,23 +114,17 @@ pprint(response)
simpleqa_dataset_id = "huggingface::simpleqa" simpleqa_dataset_id = "huggingface::simpleqa"
_ = client.datasets.register( _ = client.datasets.register(
purpose="eval/messages-answer",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id=simpleqa_dataset_id, dataset_id=simpleqa_dataset_id,
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
metadata={
"path": "llamastack/simpleqa",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "chat_completion_input"},
},
) )
eval_rows = client.datasetio.get_rows_paginated( eval_rows = client.datasets.iterrows(
dataset_id=simpleqa_dataset_id, dataset_id=simpleqa_dataset_id,
rows_in_page=5, limit=5,
) )
``` ```
@ -143,7 +137,7 @@ client.benchmarks.register(
response = client.eval.evaluate_rows( response = client.eval.evaluate_rows(
benchmark_id="meta-reference::simpleqa", benchmark_id="meta-reference::simpleqa",
input_rows=eval_rows.rows, input_rows=eval_rows.data,
scoring_functions=["llm-as-judge::405b-simpleqa"], scoring_functions=["llm-as-judge::405b-simpleqa"],
benchmark_config={ benchmark_config={
"eval_candidate": { "eval_candidate": {
@ -191,7 +185,7 @@ agent_config = {
response = client.eval.evaluate_rows( response = client.eval.evaluate_rows(
benchmark_id="meta-reference::simpleqa", benchmark_id="meta-reference::simpleqa",
input_rows=eval_rows.rows, input_rows=eval_rows.data,
scoring_functions=["llm-as-judge::405b-simpleqa"], scoring_functions=["llm-as-judge::405b-simpleqa"],
benchmark_config={ benchmark_config={
"eval_candidate": { "eval_candidate": {

View file

@ -6,17 +6,32 @@ The `llama-stack-client` CLI allows you to query information about the distribut
### `llama-stack-client` ### `llama-stack-client`
```bash ```bash
llama-stack-client -h llama-stack-client
Usage: llama-stack-client [OPTIONS] COMMAND [ARGS]...
usage: llama-stack-client [-h] {models,memory_banks,shields} ... Welcome to the LlamaStackClient CLI
Welcome to the LlamaStackClient CLI Options:
--version Show the version and exit.
--endpoint TEXT Llama Stack distribution endpoint
--api-key TEXT Llama Stack distribution API key
--config TEXT Path to config file
--help Show this message and exit.
options: Commands:
-h, --help show this help message and exit configure Configure Llama Stack Client CLI.
datasets Manage datasets.
subcommands: eval Run evaluation tasks.
{models,memory_banks,shields} eval_tasks Manage evaluation tasks.
inference Inference (chat).
inspect Inspect server configuration.
models Manage GenAI models.
post_training Post-training.
providers Manage API providers.
scoring_functions Manage scoring functions.
shields Manage safety shield services.
toolgroups Manage available tool groups.
vector_dbs Manage vector databases.
``` ```
### `llama-stack-client configure` ### `llama-stack-client configure`
@ -127,11 +142,11 @@ llama-stack-client vector_dbs list
llama-stack-client vector_dbs register <vector-db-id> [--provider-id <provider-id>] [--provider-vector-db-id <provider-vector-db-id>] [--embedding-model <embedding-model>] [--embedding-dimension <embedding-dimension>] llama-stack-client vector_dbs register <vector-db-id> [--provider-id <provider-id>] [--provider-vector-db-id <provider-vector-db-id>] [--embedding-model <embedding-model>] [--embedding-dimension <embedding-dimension>]
``` ```
Options: Optional arguments:
- `--provider-id`: Optional. Provider ID for the vector db - `--provider-id`: Provider ID for the vector db
- `--provider-vector-db-id`: Optional. Provider's vector db ID - `--provider-vector-db-id`: Provider's vector db ID
- `--embedding-model`: Optional. Embedding model to use. Default: "all-MiniLM-L6-v2" - `--embedding-model`: Embedding model to use. Default: "all-MiniLM-L6-v2"
- `--embedding-dimension`: Optional. Dimension of embeddings. Default: 384 - `--embedding-dimension`: Dimension of embeddings. Default: 384
### `llama-stack-client vector_dbs unregister` ### `llama-stack-client vector_dbs unregister`
```bash ```bash
@ -157,11 +172,13 @@ llama-stack-client shields list
llama-stack-client shields register --shield-id <shield-id> [--provider-id <provider-id>] [--provider-shield-id <provider-shield-id>] [--params <params>] llama-stack-client shields register --shield-id <shield-id> [--provider-id <provider-id>] [--provider-shield-id <provider-shield-id>] [--params <params>]
``` ```
Options: Required arguments:
- `--shield-id`: Required. ID of the shield - `--shield-id`: ID of the shield
- `--provider-id`: Optional. Provider ID for the shield
- `--provider-shield-id`: Optional. Provider's shield ID Optional arguments:
- `--params`: Optional. JSON configuration parameters for the shield - `--provider-id`: Provider ID for the shield
- `--provider-shield-id`: Provider's shield ID
- `--params`: JSON configuration parameters for the shield
## Eval Task Management ## Eval Task Management
@ -175,13 +192,15 @@ llama-stack-client benchmarks list
llama-stack-client benchmarks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>] llama-stack-client benchmarks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
``` ```
Options: Required arguments:
- `--eval-task-id`: Required. ID of the eval task - `--eval-task-id`: ID of the eval task
- `--dataset-id`: Required. ID of the dataset to evaluate - `--dataset-id`: ID of the dataset to evaluate
- `--scoring-functions`: Required. One or more scoring functions to use for evaluation - `--scoring-functions`: One or more scoring functions to use for evaluation
- `--provider-id`: Optional. Provider ID for the eval task
- `--provider-eval-task-id`: Optional. Provider's eval task ID Optional arguments:
- `--metadata`: Optional. Metadata for the eval task in JSON format - `--provider-id`: Provider ID for the eval task
- `--provider-eval-task-id`: Provider's eval task ID
- `--metadata`: Metadata for the eval task in JSON format
## Eval execution ## Eval execution
### `llama-stack-client eval run-benchmark` ### `llama-stack-client eval run-benchmark`
@ -189,11 +208,13 @@ Options:
llama-stack-client eval run-benchmark <eval-task-id1> [<eval-task-id2> ...] --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize] llama-stack-client eval run-benchmark <eval-task-id1> [<eval-task-id2> ...] --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
``` ```
Options: Required arguments:
- `--eval-task-config`: Required. Path to the eval task config file in JSON format - `--eval-task-config`: Path to the eval task config file in JSON format
- `--output-dir`: Required. Path to the directory where evaluation results will be saved - `--output-dir`: Path to the directory where evaluation results will be saved
- `--num-examples`: Optional. Number of examples to evaluate (useful for debugging)
- `--visualize`: Optional flag. If set, visualizes evaluation results after completion Optional arguments:
- `--num-examples`: Number of examples to evaluate (useful for debugging)
- `--visualize`: If set, visualizes evaluation results after completion
Example benchmark_config.json: Example benchmark_config.json:
```json ```json
@ -214,11 +235,13 @@ Example benchmark_config.json:
llama-stack-client eval run-scoring <eval-task-id> --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize] llama-stack-client eval run-scoring <eval-task-id> --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
``` ```
Options: Required arguments:
- `--eval-task-config`: Required. Path to the eval task config file in JSON format - `--eval-task-config`: Path to the eval task config file in JSON format
- `--output-dir`: Required. Path to the directory where scoring results will be saved - `--output-dir`: Path to the directory where scoring results will be saved
- `--num-examples`: Optional. Number of examples to evaluate (useful for debugging)
- `--visualize`: Optional flag. If set, visualizes scoring results after completion Optional arguments:
- `--num-examples`: Number of examples to evaluate (useful for debugging)
- `--visualize`: If set, visualizes scoring results after completion
## Tool Group Management ## Tool Group Management
@ -230,11 +253,11 @@ llama-stack-client toolgroups list
+---------------------------+------------------+------+---------------+ +---------------------------+------------------+------+---------------+
| identifier | provider_id | args | mcp_endpoint | | identifier | provider_id | args | mcp_endpoint |
+===========================+==================+======+===============+ +===========================+==================+======+===============+
| builtin::code_interpreter | code-interpreter | None | None | | builtin::code_interpreter | code-interpreter | None | None |
+---------------------------+------------------+------+---------------+ +---------------------------+------------------+------+---------------+
| builtin::rag | rag-runtime | None | None | | builtin::rag | rag-runtime | None | None |
+---------------------------+------------------+------+---------------+ +---------------------------+------------------+------+---------------+
| builtin::websearch | tavily-search | None | None | | builtin::websearch | tavily-search | None | None |
+---------------------------+------------------+------+---------------+ +---------------------------+------------------+------+---------------+
``` ```
@ -250,11 +273,11 @@ Shows detailed information about a specific toolgroup. If the toolgroup is not f
llama-stack-client toolgroups register <toolgroup_id> [--provider-id <provider-id>] [--provider-toolgroup-id <provider-toolgroup-id>] [--mcp-config <mcp-config>] [--args <args>] llama-stack-client toolgroups register <toolgroup_id> [--provider-id <provider-id>] [--provider-toolgroup-id <provider-toolgroup-id>] [--mcp-config <mcp-config>] [--args <args>]
``` ```
Options: Optional arguments:
- `--provider-id`: Optional. Provider ID for the toolgroup - `--provider-id`: Provider ID for the toolgroup
- `--provider-toolgroup-id`: Optional. Provider's toolgroup ID - `--provider-toolgroup-id`: Provider's toolgroup ID
- `--mcp-config`: Optional. JSON configuration for the MCP endpoint - `--mcp-config`: JSON configuration for the MCP endpoint
- `--args`: Optional. JSON arguments for the toolgroup - `--args`: JSON arguments for the toolgroup
### `llama-stack-client toolgroups unregister` ### `llama-stack-client toolgroups unregister`
```bash ```bash

View file

@ -52,7 +52,7 @@ class Benchmarks(Protocol):
async def get_benchmark( async def get_benchmark(
self, self,
benchmark_id: str, benchmark_id: str,
) -> Optional[Benchmark]: ... ) -> Benchmark: ...
@webmethod(route="/eval/benchmarks", method="POST") @webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark( async def register_benchmark(

View file

@ -13,19 +13,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type @json_schema_type
class PaginatedRowsResult(BaseModel): class IterrowsResponse(BaseModel):
""" """
A paginated list of rows from a dataset. A paginated list of rows from a dataset.
:param rows: The rows in the current page. :param data: The rows in the current page.
:param total_count: The total number of rows in the dataset. :param next_start_index: Index into dataset for the first row in the next page. None if there are no more rows.
:param next_page_token: The token to get the next page of rows.
""" """
# the rows obey the DatasetSchema for the given dataset data: List[Dict[str, Any]]
rows: List[Dict[str, Any]] next_start_index: Optional[int] = None
total_count: int
next_page_token: Optional[str] = None
class DatasetStore(Protocol): class DatasetStore(Protocol):
@ -37,22 +34,21 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used # keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore dataset_store: DatasetStore
@webmethod(route="/datasetio/rows", method="GET") # TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
async def get_rows_paginated( @webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
async def iterrows(
self, self,
dataset_id: str, dataset_id: str,
rows_in_page: int, start_index: Optional[int] = None,
page_token: Optional[str] = None, limit: Optional[int] = None,
filter_condition: Optional[str] = None, ) -> IterrowsResponse:
) -> PaginatedRowsResult: """Get a paginated list of rows from a dataset. Uses cursor-based pagination.
"""Get a paginated list of rows from a dataset.
:param dataset_id: The ID of the dataset to get the rows from. :param dataset_id: The ID of the dataset to get the rows from.
:param rows_in_page: The number of rows to get per page. :param start_index: Index into dataset for the first row to get. Get all rows if None.
:param page_token: The token to get the next page of rows. :param limit: The number of rows to get.
:param filter_condition: (Optional) A condition to filter the rows by.
""" """
... ...
@webmethod(route="/datasetio/rows", method="POST") @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ... async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
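To illustrate the new cursor semantics, a minimal consumer-side sketch (`dataset_io` stands in for any `DatasetIO` implementation and `handle_row` is a hypothetical per-row callback):

```python
# Page through a dataset using the cursor returned in next_start_index.
async def consume_dataset(dataset_io: DatasetIO, dataset_id: str) -> None:
    start_index = None
    while True:
        page = await dataset_io.iterrows(dataset_id, start_index=start_index, limit=100)
        for row in page.data:
            handle_row(row)  # hypothetical per-row callback
        if page.next_start_index is None:
            break
        start_index = page.next_start_index
```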

View file

@ -4,19 +4,102 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class DatasetPurpose(str, Enum):
"""
Purpose of the dataset. Each purpose has a required input data schema.
:cvar post-training/messages: The dataset contains messages used for post-training.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
:cvar eval/question-answer: The dataset contains a question column and an answer column.
{
"question": "What is the capital of France?",
"answer": "Paris"
}
:cvar eval/messages-answer: The dataset contains a messages column with list of messages and an answer column.
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
"""
post_training_messages = "post-training/messages"
eval_question_answer = "eval/question-answer"
eval_messages_answer = "eval/messages-answer"
# TODO: add more schemas here
class DatasetType(Enum):
"""
Type of the dataset source.
:cvar uri: The dataset can be obtained from a URI.
:cvar rows: The dataset is stored in rows.
"""
uri = "uri"
rows = "rows"
@json_schema_type
class URIDataSource(BaseModel):
"""A dataset that can be obtained from a URI.
:param uri: The dataset can be obtained from a URI. E.g.
- "https://mywebsite.com/mydata.jsonl"
- "lsfs://mydata.jsonl"
- "data:csv;base64,{base64_content}"
"""
type: Literal["uri"] = "uri"
uri: str
@json_schema_type
class RowsDataSource(BaseModel):
"""A dataset stored in rows.
:param rows: The dataset is stored in rows. E.g.
- [
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
]
"""
type: Literal["rows"] = "rows"
rows: List[Dict[str, Any]]
DataSource = register_schema(
Annotated[
Union[URIDataSource, RowsDataSource],
Field(discriminator="type"),
],
name="DataSource",
)
class CommonDatasetFields(BaseModel): class CommonDatasetFields(BaseModel):
dataset_schema: Dict[str, ParamType] """
url: URL Common fields for a dataset.
"""
purpose: DatasetPurpose
source: DataSource
metadata: Dict[str, Any] = Field( metadata: Dict[str, Any] = Field(
default_factory=dict, default_factory=dict,
description="Any additional metadata for this dataset", description="Any additional metadata for this dataset",
@ -50,19 +133,75 @@ class Datasets(Protocol):
@webmethod(route="/datasets", method="POST") @webmethod(route="/datasets", method="POST")
async def register_dataset( async def register_dataset(
self, self,
dataset_id: str, purpose: DatasetPurpose,
dataset_schema: Dict[str, ParamType], source: DataSource,
url: URL,
provider_dataset_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
) -> None: ... dataset_id: Optional[str] = None,
) -> Dataset:
"""
Register a new dataset.
:param purpose: The purpose of the dataset. One of
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
- "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
{
"question": "What is the capital of France?",
"answer": "Paris"
}
- "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
:param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples:
- {
"type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl"
}
- {
"type": "uri",
"uri": "lsfs://mydata.jsonl"
}
- {
"type": "uri",
"uri": "data:csv;base64,{base64_content}"
}
- {
"type": "uri",
"uri": "huggingface://llamastack/simpleqa?split=train"
}
- {
"type": "rows",
"rows": [
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
]
}
:param metadata: The metadata for the dataset.
- E.g. {"description": "My dataset"}
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="GET") @webmethod(route="/datasets/{dataset_id:path}", method="GET")
async def get_dataset( async def get_dataset(
self, self,
dataset_id: str, dataset_id: str,
) -> Optional[Dataset]: ... ) -> Dataset: ...
@webmethod(route="/datasets", method="GET") @webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ... async def list_datasets(self) -> ListDatasetsResponse: ...
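As a sanity check on the new discriminated union defined earlier in this file, a minimal sketch of how raw payloads resolve to the right source model (this uses pydantic v2's `TypeAdapter` for illustration and assumes `DataSource` behaves as a plain annotated union when validating):

```python
from pydantic import TypeAdapter

adapter = TypeAdapter(DataSource)

uri_source = adapter.validate_python(
    {"type": "uri", "uri": "huggingface://datasets/llamastack/simpleqa?split=train"}
)
assert isinstance(uri_source, URIDataSource)

rows_source = adapter.validate_python(
    {"type": "rows", "rows": [{"question": "What is the capital of France?", "answer": "Paris"}]}
)
assert isinstance(rows_source, RowsDataSource)
```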

View file

@ -117,7 +117,7 @@ class Eval(Protocol):
""" """
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET") @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
"""Get the status of a job. """Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on. :param benchmark_id: The ID of the benchmark to run the evaluation on.

View file

@ -115,7 +115,7 @@ class Files(Protocol):
async def get_upload_session_info( async def get_upload_session_info(
self, self,
upload_id: str, upload_id: str,
) -> Optional[FileUploadResponse]: ) -> FileUploadResponse:
""" """
Returns information about an existing upload session Returns information about an existing upload session

View file

@ -66,7 +66,7 @@ class Models(Protocol):
async def get_model( async def get_model(
self, self,
model_id: str, model_id: str,
) -> Optional[Model]: ... ) -> Model: ...
@webmethod(route="/models", method="POST") @webmethod(route="/models", method="POST")
async def register_model( async def register_model(

View file

@ -6,7 +6,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union from typing import Any, Dict, List, Literal, Optional, Protocol
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
@ -89,7 +89,7 @@ class QATFinetuningConfig(BaseModel):
AlgorithmConfig = register_schema( AlgorithmConfig = register_schema(
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")], Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
name="AlgorithmConfig", name="AlgorithmConfig",
) )
@ -184,7 +184,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`", description="Model descriptor from `llama model list`",
), ),
checkpoint_dir: Optional[str] = None, checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None, algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST") @webmethod(route="/post-training/preference-optimize", method="POST")
@ -202,10 +202,10 @@ class PostTraining(Protocol):
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ... async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET") @webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ... async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel", method="POST") @webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ... async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts", method="GET") @webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ... async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...

View file

@ -136,7 +136,7 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ... async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
@webmethod(route="/scoring-functions", method="POST") @webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function( async def register_scoring_function(

View file

@ -49,7 +49,7 @@ class Shields(Protocol):
async def list_shields(self) -> ListShieldsResponse: ... async def list_shields(self) -> ListShieldsResponse: ...
@webmethod(route="/shields/{identifier:path}", method="GET") @webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Optional[Shield]: ... async def get_shield(self, identifier: str) -> Shield: ...
@webmethod(route="/shields", method="POST") @webmethod(route="/shields", method="POST")
async def register_shield( async def register_shield(

View file

@ -50,7 +50,7 @@ class VectorDBs(Protocol):
async def get_vector_db( async def get_vector_db(
self, self,
vector_db_id: str, vector_db_id: str,
) -> Optional[VectorDB]: ... ) -> VectorDB: ...
@webmethod(route="/vector-dbs", method="POST") @webmethod(route="/vector-dbs", method="POST")
async def register_vector_db( async def register_vector_db(

View file

@ -38,7 +38,7 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty from llama_stack.distribution.utils.exec import formulate_run_args, run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
@ -213,7 +213,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
config = parse_and_maybe_upgrade_config(config_dict) config = parse_and_maybe_upgrade_config(config_dict)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template) run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))]) run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
run_with_pty(run_args) run_command(run_args)
def _generate_run_config( def _generate_run_config(

View file

@ -82,7 +82,7 @@ class StackRun(Subcommand):
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty from llama_stack.distribution.utils.exec import formulate_run_args, run_command
config_file = Path(args.config) config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml") has_yaml_suffix = args.config.endswith(".yaml")
@ -136,4 +136,4 @@ class StackRun(Subcommand):
if args.tls_keyfile and args.tls_certfile: if args.tls_keyfile and args.tls_certfile:
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile]) run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
run_with_pty(run_args) run_command(run_args)

View file

@ -6,7 +6,6 @@
import importlib.resources import importlib.resources
import logging import logging
import sys
from pathlib import Path from pathlib import Path
from typing import Dict, List from typing import Dict, List
@ -15,7 +14,7 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.exec import run_command, run_with_pty from llama_stack.distribution.utils.exec import run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
@ -123,11 +122,7 @@ def build_image(
if special_deps: if special_deps:
args.append("#".join(special_deps)) args.append("#".join(special_deps))
is_terminal = sys.stdin.isatty() return_code = run_command(args)
if is_terminal:
return_code = run_with_pty(args)
else:
return_code = run_command(args)
if return_code != 0: if return_code != 0:
log.error( log.error(

View file

@ -43,7 +43,7 @@ RED='\033[0;31m'
NC='\033[0m' # No Color NC='\033[0m' # No Color
CONTAINER_BINARY=${CONTAINER_BINARY:-docker} CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:-} CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}
TEMP_DIR=$(mktemp -d) TEMP_DIR=$(mktemp -d)
@ -253,8 +253,7 @@ $CONTAINER_BINARY build \
"${CLI_ARGS[@]}" \ "${CLI_ARGS[@]}" \
-t "$image_tag" \ -t "$image_tag" \
-f "$TEMP_DIR/Containerfile" \ -f "$TEMP_DIR/Containerfile" \
"." \ "."
--progress=plain
# clean up tmp/configs # clean up tmp/configs
set +x set +x

View file

@ -125,6 +125,13 @@ class LoggingConfig(BaseModel):
) )
class AuthenticationConfig(BaseModel):
endpoint: str = Field(
...,
description="Endpoint URL to validate authentication tokens",
)
class ServerConfig(BaseModel): class ServerConfig(BaseModel):
port: int = Field( port: int = Field(
default=8321, default=8321,
@ -140,6 +147,10 @@ class ServerConfig(BaseModel):
default=None, default=None,
description="Path to TLS key file for HTTPS", description="Path to TLS key file for HTTPS",
) )
auth: Optional[AuthenticationConfig] = Field(
default=None,
description="Authentication configuration for the server",
)
class StackRunConfig(BaseModel): class StackRunConfig(BaseModel):
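A minimal sketch of how the new field is populated (the endpoint URL is illustrative; in practice these values come from the `server` section of a run.yaml):

```python
# Enable token validation by pointing the server at an external auth service.
server_config = ServerConfig(
    port=8321,
    auth=AuthenticationConfig(endpoint="https://auth.example.com/validate"),
)
```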

View file

@ -8,10 +8,13 @@
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from .datatypes import StackRunConfig from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields from .stack import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
class ProviderImplConfig(BaseModel): class ProviderImplConfig(BaseModel):
run_config: StackRunConfig run_config: StackRunConfig
@ -31,6 +34,10 @@ class ProviderImpl(Providers):
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
async def shutdown(self) -> None:
logger.debug("ProviderImpl.shutdown")
pass
async def list_providers(self) -> ListProvidersResponse: async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump())) safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))

View file

@ -12,7 +12,8 @@ from llama_stack.apis.common.content_types import (
InterleavedContent, InterleavedContent,
InterleavedContentItem, InterleavedContentItem,
) )
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import ( from llama_stack.apis.eval import (
BenchmarkConfig, BenchmarkConfig,
Eval, Eval,
@ -160,7 +161,11 @@ class InferenceRouter(Inference):
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics( def _construct_metrics(
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricEvent]: ) -> List[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics. """Constructs a list of MetricEvent objects containing token usage metrics.
@ -298,7 +303,12 @@ class InferenceRouter(Inference):
completion_text += chunk.event.delta.text completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete: if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens( completion_tokens = await self._count_tokens(
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)], [
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_config.tool_prompt_format, tool_config.tool_prompt_format,
) )
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
@ -471,21 +481,36 @@ class DatasetIORouter(DatasetIO):
logger.debug("DatasetIORouter.shutdown") logger.debug("DatasetIORouter.shutdown")
pass pass
async def get_rows_paginated( async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)
async def iterrows(
self, self,
dataset_id: str, dataset_id: str,
rows_in_page: int, start_index: Optional[int] = None,
page_token: Optional[str] = None, limit: Optional[int] = None,
filter_condition: Optional[str] = None, ) -> IterrowsResponse:
) -> PaginatedRowsResult:
logger.debug( logger.debug(
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}", f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
) )
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated( return await self.routing_table.get_provider_impl(dataset_id).iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=rows_in_page, start_index=start_index,
page_token=page_token, limit=limit,
filter_condition=filter_condition,
) )
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging
import uuid
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import TypeAdapter from pydantic import TypeAdapter
@ -12,7 +13,14 @@ from pydantic import TypeAdapter
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ( from llama_stack.apis.scoring_functions import (
@ -211,8 +219,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse: async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model")) return ListModelsResponse(data=await self.get_all_with_type("model"))
async def get_model(self, model_id: str) -> Optional[Model]: async def get_model(self, model_id: str) -> Model:
return await self.get_object_by_identifier("model", model_id) model = await self.get_object_by_identifier("model", model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
return model
async def register_model( async def register_model(
self, self,
@ -259,8 +270,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse: async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Optional[Shield]: async def get_shield(self, identifier: str) -> Shield:
return await self.get_object_by_identifier("shield", identifier) shield = await self.get_object_by_identifier("shield", identifier)
if shield is None:
raise ValueError(f"Shield '{identifier}' not found")
return shield
async def register_shield( async def register_shield(
self, self,
@ -295,8 +309,11 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse: async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: async def get_vector_db(self, vector_db_id: str) -> VectorDB:
return await self.get_object_by_identifier("vector_db", vector_db_id) vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise ValueError(f"Vector DB '{vector_db_id}' not found")
return vector_db
async def register_vector_db( async def register_vector_db(
self, self,
@ -347,39 +364,50 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse: async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: async def get_dataset(self, dataset_id: str) -> Dataset:
return await self.get_object_by_identifier("dataset", dataset_id) dataset = await self.get_object_by_identifier("dataset", dataset_id)
if dataset is None:
raise ValueError(f"Dataset '{dataset_id}' not found")
return dataset
async def register_dataset( async def register_dataset(
self, self,
dataset_id: str, purpose: DatasetPurpose,
dataset_schema: Dict[str, ParamType], source: DataSource,
url: URL,
provider_dataset_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
) -> None: dataset_id: Optional[str] = None,
if provider_dataset_id is None: ) -> Dataset:
provider_dataset_id = dataset_id if not dataset_id:
if provider_id is None: dataset_id = f"dataset-{str(uuid.uuid4())}"
# If provider_id not specified, use the only provider if it supports this dataset
if len(self.impls_by_provider_id) == 1: provider_dataset_id = dataset_id
provider_id = list(self.impls_by_provider_id.keys())[0]
# infer provider from source
if source.type == DatasetType.rows.value:
provider_id = "localfs"
elif source.type == DatasetType.uri.value:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
else: else:
raise ValueError( provider_id = "localfs"
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" else:
) raise ValueError(f"Unknown data source type: {source.type}")
if metadata is None: if metadata is None:
metadata = {} metadata = {}
dataset = Dataset( dataset = Dataset(
identifier=dataset_id, identifier=dataset_id,
provider_resource_id=provider_dataset_id, provider_resource_id=provider_dataset_id,
provider_id=provider_id, provider_id=provider_id,
dataset_schema=dataset_schema, purpose=purpose,
url=url, source=source,
metadata=metadata, metadata=metadata,
) )
await self.register_object(dataset) await self.register_object(dataset)
return dataset
async def unregister_dataset(self, dataset_id: str) -> None: async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id) dataset = await self.get_dataset(dataset_id)
@ -392,8 +420,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id) scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
async def register_scoring_function( async def register_scoring_function(
self, self,
@ -429,8 +460,11 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse: async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]: async def get_benchmark(self, benchmark_id: str) -> Benchmark:
return await self.get_object_by_identifier("benchmark", benchmark_id) benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark '{benchmark_id}' not found")
return benchmark
async def register_benchmark( async def register_benchmark(
self, self,
@ -474,7 +508,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id) tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group '{toolgroup_id}' not found")
return tool_group
async def get_tool(self, tool_name: str) -> Tool: async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name) return await self.get_object_by_identifier("tool", tool_name)
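With `Optional[...]` removed from these getters, callers switch from None-checks to exception handling; a minimal sketch of the new pattern (`routing_table` stands in for any of the routing tables above):

```python
async def lookup_model(routing_table: ModelsRoutingTable, model_id: str):
    try:
        return await routing_table.get_model(model_id)
    except ValueError:
        # Not registered: fall back, register the model, or surface a 404 to the caller.
        return None
```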

View file

@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from urllib.parse import parse_qs
import httpx
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class AuthenticationMiddleware:
def __init__(self, app, auth_endpoint):
self.app = app
self.auth_endpoint = auth_endpoint
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
if not auth_header or not auth_header.startswith("Bearer "):
return await self._send_auth_error(send, "Missing or invalid Authorization header")
api_key = auth_header.split("Bearer ", 1)[1]
path = scope.get("path", "")
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
query_string = scope.get("query_string", b"").decode()
params = parse_qs(query_string)
auth_data = {
"api_key": api_key,
"request": {
"path": path,
"headers": request_headers,
"params": params,
},
}
# Validate with authentication endpoint
try:
async with httpx.AsyncClient() as client:
response = await client.post(self.auth_endpoint, json=auth_data)
if response.status_code != 200:
logger.warning(f"Authentication failed: {response.status_code}")
return await self._send_auth_error(send, "Authentication failed")
except Exception:
logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")
return await self.app(scope, receive, send)
async def _send_auth_error(self, send, message):
await send(
{
"type": "http.response.start",
"status": 401,
"headers": [[b"content-type", b"application/json"]],
}
)
error_msg = json.dumps({"error": {"message": message}}).encode()
await send({"type": "http.response.body", "body": error_msg})

View file

@ -52,6 +52,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
start_trace, start_trace,
) )
from .auth import AuthenticationMiddleware
from .endpoints import get_all_api_endpoints from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -351,6 +352,11 @@ def main():
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware) app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured
if config.server.auth and config.server.auth.endpoint:
logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
try: try:
impls = asyncio.run(construct_stack(config)) impls = asyncio.run(construct_stack(config))
except InvalidProviderError as e: except InvalidProviderError as e:

View file

@ -0,0 +1,11 @@
# More info on playground configuration can be found here:
# https://llama-stack.readthedocs.io/en/latest/playground
FROM python:3.9-slim
WORKDIR /app
COPY . /app/
RUN /usr/local/bin/python -m pip install --upgrade pip && \
/usr/local/bin/pip3 install -r requirements.txt
EXPOSE 8501
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]

View file

@ -40,3 +40,13 @@ cd llama_stack/distribution/ui
pip install -r requirements.txt pip install -r requirements.txt
streamlit run app.py streamlit run app.py
``` ```
## Environment Variables
| Environment Variable | Description | Default Value |
|----------------------------|------------------------------------|---------------------------|
| LLAMA_STACK_ENDPOINT | The endpoint for the Llama Stack | http://localhost:8321 |
| FIREWORKS_API_KEY | API key for Fireworks provider | (empty string) |
| TOGETHER_API_KEY | API key for Together provider | (empty string) |
| SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
| OPENAI_API_KEY | API key for OpenAI provider | (empty string) |
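As a rough illustration only (not the playground's actual code), a process might pick these variables up with the documented defaults like so:

```python
import os

# Defaults mirror the table above; API keys fall back to empty strings when unset.
LLAMA_STACK_ENDPOINT = os.getenv("LLAMA_STACK_ENDPOINT", "http://localhost:8321")
PROVIDER_API_KEYS = {
    "fireworks": os.getenv("FIREWORKS_API_KEY", ""),
    "together": os.getenv("TOGETHER_API_KEY", ""),
    "sambanova": os.getenv("SAMBANOVA_API_KEY", ""),
    "openai": os.getenv("OPENAI_API_KEY", ""),
}
```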

View file

@ -5,7 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import streamlit as st import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def datasets(): def datasets():

View file

@ -5,7 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import streamlit as st import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def benchmarks(): def benchmarks():

View file

@ -5,7 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import streamlit as st import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def models(): def models():

View file

@ -5,7 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import streamlit as st import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def providers(): def providers():

View file

@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from page.distribution.benchmarks import benchmarks
from page.distribution.datasets import datasets
from page.distribution.models import models
from page.distribution.scoring_functions import scoring_functions
from page.distribution.shields import shields
from page.distribution.vector_dbs import vector_dbs
from streamlit_option_menu import option_menu from streamlit_option_menu import option_menu
from llama_stack.distribution.ui.page.distribution.datasets import datasets
from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.distribution.ui.page.distribution.models import models
from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.distribution.ui.page.distribution.shields import shields
from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
def resources_page(): def resources_page():
options = [ options = [

View file

@ -5,7 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import streamlit as st import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def scoring_functions(): def scoring_functions():

View file

@ -5,7 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import streamlit as st import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def shields(): def shields():

View file

@ -5,7 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import streamlit as st import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def vector_dbs(): def vector_dbs():

View file

@ -8,8 +8,9 @@ import json
import pandas as pd import pandas as pd
import streamlit as st import streamlit as st
from modules.api import llama_stack_api
from modules.utils import process_dataset from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import process_dataset
def application_evaluation_page(): def application_evaluation_page():

View file

@ -8,7 +8,8 @@ import json
import pandas as pd import pandas as pd
import streamlit as st import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
def select_benchmark_1(): def select_benchmark_1():
@ -166,11 +167,10 @@ def run_evaluation_3():
eval_candidate = st.session_state["eval_candidate"] eval_candidate = st.session_state["eval_candidate"]
dataset_id = benchmarks[selected_benchmark].dataset_id dataset_id = benchmarks[selected_benchmark].dataset_id
rows = llama_stack_api.client.datasetio.get_rows_paginated( rows = llama_stack_api.client.datasets.iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1,
) )
total_rows = len(rows.rows) total_rows = len(rows.data)
# Add number of examples control # Add number of examples control
num_rows = st.number_input( num_rows = st.number_input(
"Number of Examples to Evaluate", "Number of Examples to Evaluate",
@ -195,7 +195,7 @@ def run_evaluation_3():
if st.button("Run Evaluation"): if st.button("Run Evaluation"):
progress_text = "Running evaluation..." progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text) progress_bar = st.progress(0, text=progress_text)
rows = rows.rows rows = rows.data
if num_rows < total_rows: if num_rows < total_rows:
rows = rows[:num_rows] rows = rows[:num_rows]

View file

@ -5,7 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import streamlit as st import streamlit as st
from modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.api import llama_stack_api
# Sidebar configurations # Sidebar configurations
with st.sidebar: with st.sidebar:

View file

@ -7,9 +7,10 @@
import streamlit as st import streamlit as st
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.shared.document import Document
from modules.api import llama_stack_api
from modules.utils import data_url_from_file from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import data_url_from_file
def rag_chat_page(): def rag_chat_page():

View file

@ -4,13 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import errno
import logging import logging
import os import os
import select
import signal import signal
import subprocess import subprocess
import sys
from termcolor import cprint from termcolor import cprint
@ -88,13 +85,6 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
return run_args return run_args
def run_with_pty(command):
if sys.platform.startswith("win"):
return _run_with_pty_win(command)
else:
return _run_with_pty_unix(command)
def in_notebook(): def in_notebook():
try: try:
from IPython import get_ipython from IPython import get_ipython
@ -108,19 +98,19 @@ def in_notebook():
return True return True
# run a command in a pseudo-terminal, with interrupt handling, def run_command(command: list[str]) -> int:
# useful when you want to run interactive things """
def _run_with_pty_unix(command): Run a command with interrupt handling and output capture.
import pty Uses subprocess.run with direct stream piping for better performance.
import termios
master, slave = pty.openpty() Args:
command (list): The command to run.
old_settings = termios.tcgetattr(sys.stdin) Returns:
int: The return code of the command.
"""
original_sigint = signal.getsignal(signal.SIGINT) original_sigint = signal.getsignal(signal.SIGINT)
ctrl_c_pressed = False ctrl_c_pressed = False
process = None
def sigint_handler(signum, frame): def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed nonlocal ctrl_c_pressed
@ -131,106 +121,19 @@ def _run_with_pty_unix(command):
# Set up the signal handler # Set up the signal handler
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
new_settings = termios.tcgetattr(sys.stdin) # Run the command with stdout/stderr piped directly to system streams
new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo result = subprocess.run(
new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
process = subprocess.Popen(
command, command,
stdin=slave, text=True,
stdout=slave, check=False,
stderr=slave,
universal_newlines=True,
preexec_fn=os.setsid,
) )
return result.returncode
# Close the slave file descriptor as it's now owned by the subprocess except subprocess.SubprocessError as e:
os.close(slave) log.error(f"Subprocess error: {e}")
return 1
def handle_io():
while not ctrl_c_pressed:
try:
rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
if sys.stdin in rlist:
data = os.read(sys.stdin.fileno(), 1024)
if not data:
break
os.write(master, data)
if master in rlist:
data = os.read(master, 1024)
if not data:
break
sys.stdout.buffer.write(data)
sys.stdout.flush()
except KeyboardInterrupt:
# This will be raised when Ctrl+C is pressed
break
if process.poll() is not None:
break
handle_io()
except (EOFError, KeyboardInterrupt):
pass
except OSError as e:
if e.errno != errno.EIO:
raise
finally:
# Clean up
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
signal.signal(signal.SIGINT, original_sigint)
os.close(master)
if process and process.poll() is None:
process.terminate()
process.wait()
return process.returncode
# run a command in a pseudo-terminal in windows, with interrupt handling,
def _run_with_pty_win(command):
"""
Runs a command with interactive support using subprocess directly.
"""
try:
# For shell scripts on Windows, use appropriate shell
if isinstance(command, (list, tuple)):
if command[0].endswith(".sh"):
if os.path.exists("/usr/bin/bash"): # WSL
command = ["bash"] + command
else:
# Use cmd.exe with bash while preserving all arguments
command = ["cmd.exe", "/c", "bash"] + command
process = subprocess.Popen(
command,
shell=True,
universal_newlines=True,
)
process.wait()
except Exception as e: except Exception as e:
print(f"Error: {str(e)}") log.exception(f"Unexpected error: {e}")
return 1 return 1
finally: finally:
if process and process.poll() is None: # Restore the original signal handler
process.terminate() signal.signal(signal.SIGINT, original_sigint)
process.wait()
return process.returncode
def run_command(command):
try:
result = subprocess.run(command, capture_output=True, text=True, check=True)
print("Script Output\n", result.stdout)
return result.returncode
except subprocess.CalledProcessError as e:
print("Error running script:", e)
print("Error output:", e.stderr)
return e.returncode

View file

@ -614,118 +614,133 @@ class ChatAgent(ShieldRunnerMixin):
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message] input_messages = input_messages + [message]
else: else:
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}") input_messages = input_messages + [message]
# 1. Start the tool execution step and progress
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
tool_call = message.tool_calls[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call=tool_call,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
tool_call=tool_call,
),
)
)
)
# If tool is a client tool, yield CompletionMessage and return # Process tool calls in the message
if tool_call.tool_name in client_tools: client_tool_calls = []
# NOTE: mark end_of_message to indicate to client that it may non_client_tool_calls = []
# call the tool and continue the conversation with the tool's response.
message.stop_reason = StopReason.end_of_message # Separate client and non-client tool calls
for tool_call in message.tool_calls:
if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call)
else:
non_client_tool_calls.append(tool_call)
# Process non-client tool calls first
for tool_call in non_client_tool_calls:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
tool_call=tool_call,
),
)
)
)
# Execute the tool call
async with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
)
if tool_result.content is None:
raise ValueError(
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
)
result_message = ToolResponseMessage(
call_id=tool_call.call_id,
content=tool_result.content,
)
span.set_attribute("output", result_message.model_dump_json())
# Store tool execution step
tool_execution_step = ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
metadata=tool_result.metadata,
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
)
# Yield the step completion event
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=tool_execution_step,
)
)
)
# Add the result message to input_messages for the next iteration
input_messages.append(result_message)
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
output_attachments.append(out_attachment)
# If there are client tool calls, yield a message with only those tool calls
if client_tool_calls:
await self.storage.set_in_progress_tool_call_step( await self.storage.set_in_progress_tool_call_step(
session_id, session_id,
turn_id, turn_id,
ToolExecutionStep( ToolExecutionStep(
step_id=step_id, step_id=step_id,
turn_id=turn_id, turn_id=turn_id,
tool_calls=[tool_call], tool_calls=client_tool_calls,
tool_responses=[], tool_responses=[],
started_at=datetime.now(timezone.utc).isoformat(), started_at=datetime.now(timezone.utc).isoformat(),
), ),
) )
yield message
# Create a copy of the message with only client tool calls
client_message = message.model_copy(deep=True)
client_message.tool_calls = client_tool_calls
# NOTE: mark end_of_message to indicate to client that it may
# call the tool and continue the conversation with the tool's response.
client_message.stop_reason = StopReason.end_of_message
# Yield the message with client tool calls
yield client_message
return return
# If tool is a builtin server tool, execute it
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
async with tracing.span(
"tool_execution",
{
"tool_name": tool_name,
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_call = message.tool_calls[0]
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
)
if tool_result.content is None:
raise ValueError(
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
)
result_messages = [
ToolResponseMessage(
call_id=tool_call.call_id,
content=tool_result.content,
)
]
assert len(result_messages) == 1, "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=result_message.call_id,
tool_name=tool_call.tool_name,
content=result_message.content,
metadata=tool_result.metadata,
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
),
)
)
)
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
output_attachments.append(out_attachment)
input_messages = input_messages + [message, result_message]
async def _initialize_tools( async def _initialize_tools(
self, self,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
@ -891,16 +906,14 @@ class ChatAgent(ShieldRunnerMixin):
if memory_tool and code_interpreter_tool: if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs # if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message. # and attach the data to the last message.
msg = await attachment_message(self.tempdir, url_items) await attachment_message(self.tempdir, url_items, input_messages[-1])
input_messages.append(msg)
# Since memory is present, add all the data to the memory bank # Since memory is present, add all the data to the memory bank
await self.add_to_session_vector_db(session_id, documents) await self.add_to_session_vector_db(session_id, documents)
elif code_interpreter_tool: elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir # if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the # and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path # assumption that the model invokes the code_interpreter tool with the path
msg = await attachment_message(self.tempdir, url_items) await attachment_message(self.tempdir, url_items, input_messages[-1])
input_messages.append(msg)
elif memory_tool: elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank # if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_vector_db(session_id, documents) await self.add_to_session_vector_db(session_id, documents)
@ -967,8 +980,8 @@ async def load_data_from_urls(urls: List[URL]) -> List[str]:
return data return data
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage: async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
content = [] contents = []
for url in urls: for url in urls:
uri = url.uri uri = url.uri
@ -988,16 +1001,19 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
else: else:
raise ValueError(f"Unsupported URL {url}") raise ValueError(f"Unsupported URL {url}")
content.append( contents.append(
TextContentItem( TextContentItem(
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.' text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
) )
) )
return ToolResponseMessage( if isinstance(message.content, list):
call_id="", message.content.extend(contents)
content=content, else:
) if isinstance(message.content, str):
message.content = [TextContentItem(text=message.content)] + contents
else:
message.content = [message.content] + contents
def _interpret_content_as_attachment( def _interpret_content_as_attachment(

View file

@ -3,20 +3,14 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import base64
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import pandas import pandas
from llama_stack.apis.common.content_types import URL from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasets import Dataset from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import LocalFSDatasetIOConfig from .config import LocalFSDatasetIOConfig
@ -24,30 +18,7 @@ from .config import LocalFSDatasetIOConfig
DATASETS_PREFIX = "localfs_datasets:" DATASETS_PREFIX = "localfs_datasets:"
class BaseDataset(ABC): class PandasDataframeDataset:
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@abstractmethod
def __len__(self) -> int:
raise NotImplementedError()
@abstractmethod
def __getitem__(self, idx):
raise NotImplementedError()
@abstractmethod
def load(self):
raise NotImplementedError()
@dataclass
class DatasetInfo:
dataset_def: Dataset
dataset_impl: BaseDataset
class PandasDataframeDataset(BaseDataset):
def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None: def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.dataset_def = dataset_def self.dataset_def = dataset_def
@ -64,23 +35,19 @@ class PandasDataframeDataset(BaseDataset):
else: else:
return self.df.iloc[idx].to_dict() return self.df.iloc[idx].to_dict()
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
# note that we will drop any columns in dataset that are not in the schema
df = df[self.dataset_def.dataset_schema.keys()]
# check all columns in dataset schema are present
assert len(df.columns) == len(self.dataset_def.dataset_schema)
# TODO: type checking against column types in dataset schema
return df
def load(self) -> None: def load(self) -> None:
if self.df is not None: if self.df is not None:
return return
df = get_dataframe_from_url(self.dataset_def.url) if self.dataset_def.source.type == "uri":
if df is None: self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}") elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
self.df = self._validate_dataset_schema(df) if self.df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@ -99,95 +66,55 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
for dataset in stored_datasets: for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset) dataset = Dataset.model_validate_json(dataset)
dataset_impl = PandasDataframeDataset(dataset) self.dataset_infos[dataset.identifier] = dataset
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
)
async def shutdown(self) -> None: ... async def shutdown(self) -> None: ...
async def register_dataset( async def register_dataset(
self, self,
dataset: Dataset, dataset_def: Dataset,
) -> None: ) -> None:
# Store in kvstore # Store in kvstore
key = f"{DATASETS_PREFIX}{dataset.identifier}" key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
await self.kvstore.set( await self.kvstore.set(
key=key, key=key,
value=dataset.json(), value=dataset_def.model_dump_json(),
)
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
) )
self.dataset_infos[dataset_def.identifier] = dataset_def
async def unregister_dataset(self, dataset_id: str) -> None: async def unregister_dataset(self, dataset_id: str) -> None:
key = f"{DATASETS_PREFIX}{dataset_id}" key = f"{DATASETS_PREFIX}{dataset_id}"
await self.kvstore.delete(key=key) await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id] del self.dataset_infos[dataset_id]
async def get_rows_paginated( async def iterrows(
self, self,
dataset_id: str, dataset_id: str,
rows_in_page: int, start_index: Optional[int] = None,
page_token: Optional[str] = None, limit: Optional[int] = None,
filter_condition: Optional[str] = None, ) -> IterrowsResponse:
) -> PaginatedRowsResult: dataset_def = self.dataset_infos[dataset_id]
dataset_info = self.dataset_infos.get(dataset_id) dataset_impl = PandasDataframeDataset(dataset_def)
dataset_info.dataset_impl.load() dataset_impl.load()
if page_token and not page_token.isnumeric(): start_index = start_index or 0
raise ValueError("Invalid page_token")
if page_token is None or len(page_token) == 0: if limit is None or limit == -1:
next_page_token = 0 end = len(dataset_impl)
else: else:
next_page_token = int(page_token) end = min(start_index + limit, len(dataset_impl))
start = next_page_token rows = dataset_impl[start_index:end]
if rows_in_page == -1:
end = len(dataset_info.dataset_impl)
else:
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
rows = dataset_info.dataset_impl[start:end] return IterrowsResponse(
data=rows,
return PaginatedRowsResult( next_start_index=end if end < len(dataset_impl) else None,
rows=rows,
total_count=len(rows),
next_page_token=str(end),
) )
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_info = self.dataset_infos.get(dataset_id) dataset_def = self.dataset_infos[dataset_id]
if dataset_info is None: dataset_impl = PandasDataframeDataset(dataset_def)
raise ValueError(f"Dataset with id {dataset_id} not found")
dataset_impl = dataset_info.dataset_impl
dataset_impl.load() dataset_impl.load()
new_rows_df = pandas.DataFrame(rows) new_rows_df = pandas.DataFrame(rows)
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True) dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
url = str(dataset_info.dataset_def.url.uri)
parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme:
file_path = parsed_url.path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
dataset_impl.df.to_csv(file_path, index=False)
elif parsed_url.scheme == "data":
# For data URLs, we need to update the base64-encoded content
if not parsed_url.path.startswith("text/csv;base64,"):
raise ValueError("Data URL must be a base64-encoded CSV")
csv_buffer = dataset_impl.df.to_csv(index=False)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
else:
raise ValueError(
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
)
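The hunk above also captures the new pagination contract: callers pass `start_index`/`limit` and follow `next_start_index` instead of page tokens. A caller-side sketch under that assumption (the `datasetio` handle and `dataset_id` are placeholders):

```python
from typing import Any, Dict, List


async def read_all_rows(datasetio, dataset_id: str, page_size: int = 100) -> List[Dict[str, Any]]:
    # Keep requesting pages until the implementation signals the end with next_start_index=None.
    rows: List[Dict[str, Any]] = []
    start_index = 0
    while True:
        page = await datasetio.iterrows(dataset_id=dataset_id, start_index=start_index, limit=page_size)
        rows.extend(page.data)
        if page.next_start_index is None:
            return rows
        start_index = page.next_start_index
```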

View file

@ -14,16 +14,11 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL, MEMORY_QUERY_TOOL,
) )
from llama_stack.providers.utils.common.data_schema_validator import ( from llama_stack.providers.utils.common.data_schema_validator import ColumnName
ColumnName,
get_valid_schemas,
validate_dataset_schema,
)
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job from .....apis.common.job_types import Job
@ -88,15 +83,17 @@ class MetaReferenceEvalImpl(
task_def = self.benchmarks[benchmark_id] task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id dataset_id = task_def.dataset_id
scoring_functions = task_def.scoring_functions scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)) # TODO (xiyan): validate dataset schema
all_rows = await self.datasetio_api.get_rows_paginated( # dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples), limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
) )
res = await self.evaluate_rows( res = await self.evaluate_rows(
benchmark_id=benchmark_id, benchmark_id=benchmark_id,
input_rows=all_rows.rows, input_rows=all_rows.data,
scoring_functions=scoring_functions, scoring_functions=scoring_functions,
benchmark_config=benchmark_config, benchmark_config=benchmark_config,
) )

View file

@ -10,6 +10,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import copy
import json import json
import logging import logging
import multiprocessing import multiprocessing
@ -213,7 +214,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
def parse_message(json_str: str) -> ProcessingMessage: def parse_message(json_str: str) -> ProcessingMessage:
data = json.loads(json_str) data = json.loads(json_str)
return ProcessingMessageWrapper(**data).payload return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
def worker_process_entrypoint( def worker_process_entrypoint(

View file

@ -9,6 +9,9 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.type_system import ( from llama_stack.apis.common.type_system import (
ChatCompletionInputType, ChatCompletionInputType,
DialogType, DialogType,
@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
validate_dataset_schema, validate_dataset_schema,
) )
EXPECTED_DATASET_SCHEMA = { EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
"instruct": [ "instruct": [
{ {
ColumnName.chat_completion_input.value: ChatCompletionInputType(), ColumnName.chat_completion_input.value: ChatCompletionInputType(),
@ -41,6 +44,9 @@ async def validate_input_dataset_schema(
dataset_type: str, dataset_type: str,
) -> None: ) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def:
raise ValueError(f"Dataset {dataset_id} does not exist.")
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.") raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")

View file

@ -37,7 +37,7 @@ class TorchtuneCheckpointer:
checkpoint_files: List[str], checkpoint_files: List[str],
output_dir: str, output_dir: str,
model_type: str, model_type: str,
) -> None: ):
# Fail fast if ``checkpoint_files`` is invalid # Fail fast if ``checkpoint_files`` is invalid
# TODO: support loading more than one file # TODO: support loading more than one file
if len(checkpoint_files) != 1: if len(checkpoint_files) != 1:
@ -58,7 +58,7 @@ class TorchtuneCheckpointer:
""" """
Load Meta checkpoint from file. Currently only loading from a single file is supported. Load Meta checkpoint from file. Currently only loading from a single file is supported.
""" """
state_dict: Dict[str:Any] = {} state_dict: Dict[str, Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path) model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION: if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import ( from torchtune.models.llama3_2_vision._convert_weights import (
@ -85,10 +85,10 @@ class TorchtuneCheckpointer:
state_dict: Dict[str, Any], state_dict: Dict[str, Any],
epoch: int, epoch: int,
adapter_only: bool = False, adapter_only: bool = False,
checkpoint_format: str = "meta", checkpoint_format: str | None = None,
) -> str: ) -> str:
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}" model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
if checkpoint_format == "meta": if checkpoint_format == "meta" or checkpoint_format is None:
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only) self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
elif checkpoint_format == "huggingface": elif checkpoint_format == "huggingface":
# Note: for saving hugging face format checkpoints, we only support saving adapter weights now

View file

@ -10,7 +10,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Callable, Dict from typing import Callable, Dict
import torch import torch
from pydantic import BaseModel from pydantic import BaseModel
@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat
from llama_stack.models.llama.datatypes import Model from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_list import resolve_model
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
class ModelConfig(BaseModel): class ModelConfig(BaseModel):
model_definition: Any model_definition: BuildLoraModelCallable
tokenizer_type: Any tokenizer_type: BuildTokenizerCallable
checkpoint_type: str checkpoint_type: str
@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = {
} }
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
def _validate_model_id(model_id: str) -> Model: def _validate_model_id(model_id: str) -> Model:
model = resolve_model(model_id) model = resolve_model(model_id)
if model is None or model.core_model_id.value not in MODEL_CONFIGS: if model is None or model.core_model_id.value not in MODEL_CONFIGS:

View file

@ -55,7 +55,7 @@ class SFTDataset(Dataset):
if "messages" in transformed_sample: if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"]) validate_messages(transformed_sample["messages"])
tokenized_dict = self._model_transform(transformed_sample) tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample)
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys()) keys_str = ", ".join(tokenized_dict.keys())

View file

@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import ( from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint, Checkpoint,
LoraFinetuningConfig, LoraFinetuningConfig,
OptimizerConfig, OptimizerConfig,
QATFinetuningConfig,
TrainingConfig, TrainingConfig,
) )
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice:
# Currently logging only logs limited training metrics to local disk # Currently logging only logs limited training metrics to local disk
# will figure out more loggings and how it works with telemetry in future PRs # will figure out more loggings and how it works with telemetry in future PRs
_checkpointer: TorchtuneCheckpointer
def __init__( def __init__(
self, self,
config: TorchtunePostTrainingConfig, config: TorchtunePostTrainingConfig,
@ -82,7 +85,7 @@ class LoraFinetuningSingleDevice:
logger_config: Dict[str, Any], logger_config: Dict[str, Any],
model: str, model: str,
checkpoint_dir: Optional[str], checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig], algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
datasetio_api: DatasetIO, datasetio_api: DatasetIO,
datasets_api: Datasets, datasets_api: Datasets,
) -> None: ) -> None:
@ -109,12 +112,12 @@ class LoraFinetuningSingleDevice:
return str(checkpoint_dir) return str(checkpoint_dir)
if checkpoint_dir and checkpoint_dir != "null": if checkpoint_dir and checkpoint_dir != "null":
self.checkpoint_dir = config.checkpoint_dir self.checkpoint_dir = checkpoint_dir
else: else:
model = resolve_model(self.model_id) model_obj = resolve_model(self.model_id)
if model is None: if model_obj is None:
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list") raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
self.checkpoint_dir = model_checkpoint_dir(model) self.checkpoint_dir = model_checkpoint_dir(model_obj)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR) self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self._checkpoint_format = config.checkpoint_format self._checkpoint_format = config.checkpoint_format
@ -135,16 +138,16 @@ class LoraFinetuningSingleDevice:
self.max_validation_steps = training_config.max_validation_steps self.max_validation_steps = training_config.max_validation_steps
self._clip_grad_norm = 1.0 self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = (
(training_config.efficiency_config.enable_activation_checkpointing) self._enable_activation_checkpointing = False
if training_config.efficiency_config self._enable_activation_offloading = False
else False if training_config.efficiency_config:
) if training_config.efficiency_config.enable_activation_checkpointing:
self._enable_activation_offloading = ( self._enable_activation_checkpointing = (
(training_config.efficiency_config.enable_activation_offloading) training_config.efficiency_config.enable_activation_checkpointing
if training_config.efficiency_config )
else False if training_config.efficiency_config.enable_activation_offloading:
) self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading
self.datasetio_api = datasetio_api self.datasetio_api = datasetio_api
self.datasets_api = datasets_api self.datasets_api = datasets_api
@ -328,13 +331,13 @@ class LoraFinetuningSingleDevice:
batch_size: int, batch_size: int,
) -> Tuple[DistributedSampler, DataLoader]: ) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows(dataset_id: str): async def fetch_rows(dataset_id: str):
return await self.datasetio_api.get_rows_paginated( return await self.datasetio_api.iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, limit=-1,
) )
all_rows = await fetch_rows(dataset_id) all_rows = await fetch_rows(dataset_id)
rows = all_rows.rows rows = all_rows.data
await validate_input_dataset_schema( await validate_input_dataset_schema(
datasets_api=self.datasets_api, datasets_api=self.datasets_api,
@ -451,12 +454,12 @@ class LoraFinetuningSingleDevice:
""" """
# Initialize tokens count and running loss (for grad accumulation) # Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter() t0 = time.perf_counter()
running_loss = 0 running_loss: float = 0.0
num_tokens = 0 num_tokens = 0
# training artifacts # training artifacts
checkpoints = [] checkpoints = []
memory_stats = {} memory_stats: Dict[str, Any] = {}
# self.epochs_run should be non-zero when we're resuming from a checkpoint # self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs): for curr_epoch in range(self.epochs_run, self.total_epochs):
@ -484,7 +487,7 @@ class LoraFinetuningSingleDevice:
# Loss is normalized by default so we multiply by the number of tokens # Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients # This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = await self._loss_step(batch) * current_num_tokens current_loss = await self._loss_step(batch) * current_num_tokens
running_loss += current_loss running_loss += current_loss.detach().item()
current_loss.backward() current_loss.backward()
# Step with optimizer # Step with optimizer
@ -500,7 +503,7 @@ class LoraFinetuningSingleDevice:
# Update the number of steps when the weights are updated # Update the number of steps when the weights are updated
self.global_step += 1 self.global_step += 1
loss_to_log = running_loss.item() / num_tokens loss_to_log = running_loss / num_tokens
pbar.update(1) pbar.update(1)
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}") pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
@ -523,7 +526,7 @@ class LoraFinetuningSingleDevice:
) )
# Reset running stats for the next step # Reset running stats for the next step
running_loss = 0 running_loss = 0.0
num_tokens = 0 num_tokens = 0
t0 = time.perf_counter() t0 = time.perf_counter()

View file

@ -227,13 +227,6 @@ class LlamaGuardShield:
if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value): if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
messages = messages[1:] messages = messages[1:]
for i in range(1, len(messages)):
if messages[i].role == messages[i - 1].role:
for i, m in enumerate(messages):
print(f"{i}: {m.role}: {m.content}")
raise ValueError(
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
)
return messages return messages
async def run(self, messages: List[Message]) -> RunShieldResponse: async def run(self, messages: List[Message]) -> RunShieldResponse:

View file

@ -24,7 +24,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import BasicScoringConfig from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn,
)
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
@ -82,12 +84,12 @@ class BasicScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, limit=-1,
) )
res = await self.score( res = await self.score(
input_rows=all_rows.rows, input_rows=all_rows.data,
scoring_functions=scoring_functions, scoring_functions=scoring_functions,
) )
if save_results_dataset: if save_results_dataset:

View file

@ -167,11 +167,11 @@ class BraintrustScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, limit=-1,
) )
res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions) res = await self.score(input_rows=all_rows.data, scoring_functions=scoring_functions)
if save_results_dataset: if save_results_dataset:
# TODO: persist and register dataset on to server for reading # TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset() # self.datasets_api.register_dataset()

View file

@ -72,12 +72,12 @@ class LlmAsJudgeScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, limit=-1,
) )
res = await self.score( res = await self.score(
input_rows=all_rows.rows, input_rows=all_rows.data,
scoring_functions=scoring_functions, scoring_functions=scoring_functions,
) )
if save_results_dataset: if save_results_dataset:

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import logging import logging
import os import os
import tempfile import tempfile
@ -37,7 +38,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def initialize(self): async def initialize(self):
pass pass
async def register_tool(self, tool: Tool): async def register_tool(self, tool: Tool) -> None:
pass pass
async def unregister_tool(self, tool_id: str) -> None: async def unregister_tool(self, tool_id: str) -> None:
@ -65,7 +66,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
# Use environment variable to control bwrap usage # Use environment variable to control bwrap usage
force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes") force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes")
req = CodeExecutionRequest(scripts=[script], use_bwrap=not force_disable_bwrap) req = CodeExecutionRequest(scripts=[script], use_bwrap=not force_disable_bwrap)
res = self.code_executor.execute(req) res = await asyncio.to_thread(self.code_executor.execute, req)
pieces = [res["process_status"]] pieces = [res["process_status"]]
for out_type in ["stdout", "stderr"]: for out_type in ["stdout", "stderr"]:
res_out = res[out_type] res_out = res[out_type]
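The change above moves the blocking `execute()` call onto a worker thread so the event loop stays responsive. A standalone sketch of the same `asyncio.to_thread` pattern (`blocking_work` is a made-up stand-in for the code executor):

```python
import asyncio
import time


def blocking_work(x: int) -> int:
    # Stands in for a blocking call; running it inline would stall the event loop.
    time.sleep(1)
    return x * 2


async def main() -> None:
    # Offload to a thread so other coroutines keep making progress meanwhile.
    result = await asyncio.to_thread(blocking_work, 21)
    print(result)


asyncio.run(main())
```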

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
impl = QdrantVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class QdrantVectorIOConfig(BaseModel):
path: str
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"path": "${env.QDRANT_PATH:~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
}

View file

@ -55,4 +55,13 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
), ),
), ),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
),
),
] ]

View file

@ -92,6 +92,14 @@ def available_providers() -> List[ProviderSpec]:
), ),
api_dependencies=[Api.inference], api_dependencies=[Api.inference],
), ),
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::qdrant",
pip_packages=["qdrant-client"],
module="llama_stack.providers.inline.vector_io.qdrant",
config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
),
remote_provider_spec( remote_provider_spec(
Api.vector_io, Api.vector_io,
AdapterSpec( AdapterSpec(

View file

@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
import datasets as hf_datasets import datasets as hf_datasets
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import HuggingfaceDatasetIOConfig from .config import HuggingfaceDatasetIOConfig
@ -18,22 +18,14 @@ from .config import HuggingfaceDatasetIOConfig
DATASETS_PREFIX = "datasets:" DATASETS_PREFIX = "datasets:"
def load_hf_dataset(dataset_def: Dataset): def parse_hf_params(dataset_def: Dataset):
if dataset_def.metadata.get("path", None): uri = dataset_def.source.uri
dataset = hf_datasets.load_dataset(**dataset_def.metadata) parsed_uri = urlparse(uri)
else: params = parse_qs(parsed_uri.query)
df = get_dataframe_from_url(dataset_def.url) params = {k: v[0] for k, v in params.items()}
path = parsed_uri.path.lstrip("/")
if df is None: return path, params
raise ValueError(f"Failed to load dataset from {dataset_def.url}")
dataset = hf_datasets.Dataset.from_pandas(df)
# drop columns not specified by schema
if dataset_def.dataset_schema:
dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
return dataset
class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@ -64,7 +56,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
key = f"{DATASETS_PREFIX}{dataset_def.identifier}" key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
await self.kvstore.set( await self.kvstore.set(
key=key, key=key,
value=dataset_def.json(), value=dataset_def.model_dump_json(),
) )
self.dataset_infos[dataset_def.identifier] = dataset_def self.dataset_infos[dataset_def.identifier] = dataset_def
@ -73,41 +65,34 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
await self.kvstore.delete(key=key) await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id] del self.dataset_infos[dataset_id]
async def get_rows_paginated( async def iterrows(
self, self,
dataset_id: str, dataset_id: str,
rows_in_page: int, start_index: Optional[int] = None,
page_token: Optional[str] = None, limit: Optional[int] = None,
filter_condition: Optional[str] = None, ) -> IterrowsResponse:
) -> PaginatedRowsResult:
dataset_def = self.dataset_infos[dataset_id] dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def) path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
if page_token and not page_token.isnumeric(): start_index = start_index or 0
raise ValueError("Invalid page_token")
if page_token is None or len(page_token) == 0: if limit is None or limit == -1:
next_page_token = 0
else:
next_page_token = int(page_token)
start = next_page_token
if rows_in_page == -1:
end = len(loaded_dataset) end = len(loaded_dataset)
else: else:
end = min(start + rows_in_page, len(loaded_dataset)) end = min(start_index + limit, len(loaded_dataset))
rows = [loaded_dataset[i] for i in range(start, end)] rows = [loaded_dataset[i] for i in range(start_index, end)]
return PaginatedRowsResult( return IterrowsResponse(
rows=rows, data=rows,
total_count=len(rows), next_start_index=end if end < len(loaded_dataset) else None,
next_page_token=str(end),
) )
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id] dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def) path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
# Convert rows to HF Dataset format # Convert rows to HF Dataset format
new_dataset = hf_datasets.Dataset.from_list(rows) new_dataset = hf_datasets.Dataset.from_list(rows)
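As a sanity check on the new URI handling, here is a small standard-library-only sketch of what `parse_hf_params` produces for one of the dataset URIs defined later in this diff:

```python
# Rough illustration of the URI parsing above (urllib only).
from urllib.parse import parse_qs, urlparse

uri = "huggingface://datasets/llamastack/mmlu_cot?split=test&name=all"
parsed = urlparse(uri)
params = {k: v[0] for k, v in parse_qs(parsed.query).items()}
path = parsed.path.lstrip("/")

print(path)    # llamastack/mmlu_cot
print(params)  # {'split': 'test', 'name': 'all'}
# The provider then calls hf_datasets.load_dataset(path, **params).
```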

View file

@ -6,6 +6,7 @@
import logging import logging
import warnings import warnings
from functools import lru_cache
from typing import AsyncIterator, List, Optional, Union from typing import AsyncIterator, List, Optional, Union
from openai import APIConnectionError, AsyncOpenAI, BadRequestError from openai import APIConnectionError, AsyncOpenAI, BadRequestError
@ -82,12 +83,42 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# ) # )
self._config = config self._config = config
# make sure the client lives longer than any async calls
self._client = AsyncOpenAI( @lru_cache # noqa: B019
base_url=f"{self._config.url}/v1", def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"), """
timeout=self._config.timeout, For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
) some models are hosted on different URLs. This function returns the appropriate client
for the given provider_model_id.
This relies on lru_cache and self._default_client to avoid creating a new client for each request
or for each model that is hosted on https://integrate.api.nvidia.com/v1.
:param provider_model_id: The provider model ID
:return: An OpenAI client
"""
@lru_cache # noqa: B019
def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
"""
Maintain a single OpenAI client per base_url.
"""
return AsyncOpenAI(
base_url=base_url,
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
special_model_urls = {
"meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
}
base_url = f"{self._config.url}/v1"
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
base_url = special_model_urls[provider_model_id]
return _get_client_for_base_url(base_url)
async def completion( async def completion(
self, self,
@ -105,9 +136,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
await check_health(self._config) # this raises errors await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
request = convert_completion_request( request = convert_completion_request(
request=CompletionRequest( request=CompletionRequest(
model=self.get_provider_model_id(model_id), model=provider_model_id,
content=content, content=content,
sampling_params=sampling_params, sampling_params=sampling_params,
response_format=response_format, response_format=response_format,
@ -118,7 +150,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
) )
try: try:
response = await self._client.completions.create(**request) response = await self._get_client(provider_model_id).completions.create(**request)
except APIConnectionError as e: except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -206,6 +238,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
await check_health(self._config) # this raises errors await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
request = await convert_chat_completion_request( request = await convert_chat_completion_request(
request=ChatCompletionRequest( request=ChatCompletionRequest(
model=self.get_provider_model_id(model_id), model=self.get_provider_model_id(model_id),
@ -221,7 +254,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
) )
try: try:
response = await self._client.chat.completions.create(**request) response = await self._get_client(provider_model_id).chat.completions.create(**request)
except APIConnectionError as e: except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
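The per-base_url client caching above can be hard to read inside the adapter, so here is a simplified, standalone sketch of the same idea. The function name and the default timeout are illustrative, not part of the adapter's API; the URLs are the ones listed in the diff:

```python
# Sketch: keep exactly one AsyncOpenAI client per base_url.
from functools import lru_cache

from openai import AsyncOpenAI


@lru_cache
def client_for(base_url: str, api_key: str = "NO KEY", timeout: float = 60.0) -> AsyncOpenAI:
    # Repeated calls with the same base_url return the same client instance.
    return AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=timeout)


default = client_for("https://integrate.api.nvidia.com/v1")
vision = client_for("https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct")
assert client_for("https://integrate.api.nvidia.com/v1") is default
```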

View file

@ -12,6 +12,7 @@ from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionResponseStreamChunk, ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse, EmbeddingsResponse,
EmbeddingTaskType, EmbeddingTaskType,
Inference, Inference,
@ -160,12 +161,14 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client() client = self._get_client()
response = await client.inference.chat_completion(**json_params) response = await client.inference.chat_completion(**json_params)
response = response.to_dict() return ChatCompletionResponse(
completion_message=CompletionMessage(
# temporary hack to remove the metrics from the response content=response.completion_message.content.text,
response["metrics"] = [] stop_reason=response.completion_message.stop_reason,
tool_calls=response.completion_message.tool_calls,
return convert_to_pydantic(ChatCompletionResponse, response) ),
logprobs=response.logprobs,
)
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator: async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
client = self._get_client() client = self._get_client()

View file

@ -25,6 +25,10 @@ class VLLMInferenceAdapterConfig(BaseModel):
default="fake", default="fake",
description="The API token", description="The API token",
) )
tls_verify: bool = Field(
default=True,
description="Whether to verify TLS certificates",
)
@classmethod @classmethod
def sample_run_config( def sample_run_config(
@ -36,4 +40,5 @@ class VLLMInferenceAdapterConfig(BaseModel):
"url": url, "url": url,
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}", "max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
"api_token": "${env.VLLM_API_TOKEN:fake}", "api_token": "${env.VLLM_API_TOKEN:fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:true}",
} }

View file

@ -7,6 +7,7 @@ import json
import logging import logging
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
import httpx
from openai import AsyncOpenAI from openai import AsyncOpenAI
from openai.types.chat.chat_completion_chunk import ( from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk, ChatCompletionChunk as OpenAIChatCompletionChunk,
@ -229,7 +230,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def initialize(self) -> None: async def initialize(self) -> None:
log.info(f"Initializing VLLM client with base_url={self.config.url}") log.info(f"Initializing VLLM client with base_url={self.config.url}")
self.client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token) self.client = AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
)
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
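A minimal sketch of the new `tls_verify` behavior, assuming a self-signed vLLM endpoint; the URL and token below are placeholders:

```python
# Sketch: when tls_verify is False, hand the OpenAI client an httpx client
# that skips certificate verification, as in initialize() above.
import httpx
from openai import AsyncOpenAI

tls_verify = False  # e.g. resolved from ${env.VLLM_TLS_VERIFY:true}
client = AsyncOpenAI(
    base_url="https://my-vllm.example/v1",  # placeholder endpoint
    api_key="fake",
    http_client=None if tls_verify else httpx.AsyncClient(verify=False),
)
```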

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import NVIDIASafetyConfig
async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
from .nvidia import NVIDIASafetyAdapter
impl = NVIDIASafetyAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class NVIDIASafetyConfig(BaseModel):
"""
Configuration for the NVIDIA Guardrail microservice endpoint.
Attributes:
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331
config_id (str): The ID of the guardrails configuration to use from the configuration store
(https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html)
"""
guardrails_service_url: str = Field(
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
description="The url for accessing the guardrails service",
)
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
"config_id": "self-check",
}

View file

@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, List, Optional
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None:
"""
Initialize the NVIDIASafetyAdapter with a given safety configuration.
Args:
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
"""
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if not shield.provider_resource_id:
raise ValueError("Shield model not provided.")
async def run_shield(
self, shield_id: str, messages: List[Message], params: Optional[dict[str, Any]] = None
) -> RunShieldResponse:
"""
Run a safety shield check against the provided messages.
Args:
shield_id (str): The unique identifier for the shield to be used.
messages (List[Message]): A list of Message objects representing the conversation history.
params (Optional[dict[str, Any]]): Additional parameters for the shield check.
Returns:
RunShieldResponse: The response containing safety violation details if any.
Raises:
ValueError: If the shield with the provided shield_id is not found.
"""
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
class NeMoGuardrails:
"""
A class that encapsulates NVIDIA's guardrails safety logic.
Sends messages to the guardrails service and interprets the response to determine
if a safety violation has occurred.
"""
def __init__(
self,
config: NVIDIASafetyConfig,
model: str,
threshold: float = 0.9,
temperature: float = 1.0,
):
"""
Initialize a NeMoGuardrails instance with the provided parameters.
Args:
config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL.
model (str): The identifier or name of the model to be used for safety checks.
threshold (float, optional): The threshold for flagging violations. Defaults to 0.9.
temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0.
Raises:
ValueError: If temperature is less than or equal to 0.
AssertionError: If config_id is not provided in the configuration.
"""
self.config_id = config.config_id
self.model = model
assert self.config_id is not None, "Must provide config id"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.temperature = temperature
self.threshold = threshold
self.guardrails_service_url = config.guardrails_service_url
async def run(self, messages: List[Message]) -> RunShieldResponse:
"""
        Queries the /v1/guardrail/checks endpoint of the deployed NeMo Guardrails API.
Args:
messages (List[Message]): A list of Message objects to be checked for safety violations.
Returns:
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
Raises:
requests.HTTPError: If the POST request fails.
"""
headers = {
"Accept": "application/json",
}
request_data = {
"model": self.model,
"messages": convert_pydantic_to_json_value(messages),
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": self.config_id,
},
}
response = requests.post(
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
)
response.raise_for_status()
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
response_json = response.json()
if response_json["status"] == "blocked":
user_message = "Sorry I cannot do this."
metadata = response_json["rails_status"]
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return RunShieldResponse(violation=None)

View file

@ -23,7 +23,6 @@ class QdrantVectorIOConfig(BaseModel):
prefix: Optional[str] = None prefix: Optional[str] = None
timeout: Optional[int] = None timeout: Optional[int] = None
host: Optional[str] = None host: Optional[str] = None
path: Optional[str] = None
@classmethod @classmethod
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]: def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:

View file

@ -6,7 +6,7 @@
import logging import logging
import uuid import uuid
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Union
from numpy.typing import NDArray from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
@ -16,12 +16,13 @@ from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorDBWithIndex,
) )
from .config import QdrantVectorIOConfig from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id" CHUNK_ID_KEY = "_chunk_id"
@ -99,17 +100,19 @@ class QdrantIndex(EmbeddingIndex):
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: QdrantVectorIOConfig, inference_api: Api.inference) -> None: def __init__(
self, config: Union[RemoteQdrantVectorIOConfig, InlineQdrantVectorIOConfig], inference_api: Api.inference
) -> None:
self.config = config self.config = config
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) self.client: AsyncQdrantClient = None
self.cache = {} self.cache = {}
self.inference_api = inference_api self.inference_api = inference_api
async def initialize(self) -> None: async def initialize(self) -> None:
pass self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
async def shutdown(self) -> None: async def shutdown(self) -> None:
self.client.close() await self.client.close()
async def register_vector_db( async def register_vector_db(
self, self,
@ -123,6 +126,11 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self.cache[vector_db.identifier] = index self.cache[vector_db.identifier] = index
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id in self.cache:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]: async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
if vector_db_id in self.cache: if vector_db_id in self.cache:
return self.cache[vector_db_id] return self.cache[vector_db_id]

View file

@ -10,18 +10,17 @@ from urllib.parse import unquote
import pandas import pandas
from llama_stack.apis.common.content_types import URL
from llama_stack.providers.utils.memory.vector_store import parse_data_url from llama_stack.providers.utils.memory.vector_store import parse_data_url
def get_dataframe_from_url(url: URL): def get_dataframe_from_uri(uri: str):
df = None df = None
if url.uri.endswith(".csv"): if uri.endswith(".csv"):
df = pandas.read_csv(url.uri) df = pandas.read_csv(uri)
elif url.uri.endswith(".xlsx"): elif uri.endswith(".xlsx"):
df = pandas.read_excel(url.uri) df = pandas.read_excel(uri)
elif url.uri.startswith("data:"): elif uri.startswith("data:"):
parts = parse_data_url(url.uri) parts = parse_data_url(uri)
data = parts["data"] data = parts["data"]
if parts["is_base64"]: if parts["is_base64"]:
data = base64.b64decode(data) data = base64.b64decode(data)
@ -39,6 +38,6 @@ def get_dataframe_from_url(url: URL):
else: else:
df = pandas.read_excel(data_bytes) df = pandas.read_excel(data_bytes)
else: else:
raise ValueError(f"Unsupported file type: {url}") raise ValueError(f"Unsupported file type: {uri}")
return df return df

View file

@ -192,7 +192,11 @@ class LiteLLMOpenAIMixin(
if request.tools: if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config.tool_choice: if request.tool_config.tool_choice:
input_dict["tool_choice"] = request.tool_config.tool_choice.value input_dict["tool_choice"] = (
request.tool_config.tool_choice.value
if isinstance(request.tool_config.tool_choice, ToolChoice)
else request.tool_config.tool_choice
)
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field key_field = self.provider_data_api_key_field
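The `tool_choice` change above accepts either the `ToolChoice` enum or a plain string (for example, a specific tool name). A small self-contained sketch with an illustrative stand-in enum:

```python
# Sketch only: the enum below is a stand-in, not the real llama_stack ToolChoice.
from enum import Enum


class ToolChoice(Enum):
    auto = "auto"
    required = "required"


def normalize(tool_choice):
    # Enum members are passed through as their string value; raw strings
    # (such as a specific tool name) are left untouched, matching the diff above.
    return tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice


assert normalize(ToolChoice.auto) == "auto"
assert normalize("get_weather") == "get_weather"
```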

View file

@ -527,26 +527,30 @@ async def convert_message_to_openai_dict_new(
async def _convert_message_content( async def _convert_message_content(
content: InterleavedContent, content: InterleavedContent,
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]: ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
async def impl(): async def impl(
content_: InterleavedContent,
) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
# Llama Stack and OpenAI spec match for str and text input # Llama Stack and OpenAI spec match for str and text input
if isinstance(content, str): if isinstance(content_, str):
return content return content_
elif isinstance(content, TextContentItem): elif isinstance(content_, TextContentItem):
return OpenAIChatCompletionContentPartTextParam( return OpenAIChatCompletionContentPartTextParam(
type="text", type="text",
text=content.text, text=content_.text,
) )
elif isinstance(content, ImageContentItem): elif isinstance(content_, ImageContentItem):
return OpenAIChatCompletionContentPartImageParam( return OpenAIChatCompletionContentPartImageParam(
type="image_url", type="image_url",
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)), image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)),
) )
elif isinstance(content, list): elif isinstance(content_, list):
return [await _convert_message_content(item) for item in content] return [await impl(item) for item in content_]
else: else:
raise ValueError(f"Unsupported content type: {type(content)}") raise ValueError(f"Unsupported content type: {type(content_)}")
ret = await impl() ret = await impl(content)
# OpenAI*Message expects a str or list
if isinstance(ret, str) or isinstance(ret, list): if isinstance(ret, str) or isinstance(ret, list):
return ret return ret
else: else:
@ -566,13 +570,14 @@ async def convert_message_to_openai_dict_new(
OpenAIChatCompletionMessageToolCall( OpenAIChatCompletionMessageToolCall(
id=tool.call_id, id=tool.call_id,
function=OpenAIFunction( function=OpenAIFunction(
name=tool.tool_name, name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
arguments=json.dumps(tool.arguments), arguments=json.dumps(tool.arguments),
), ),
type="function", type="function",
) )
for tool in message.tool_calls for tool in message.tool_calls
], ]
or None,
) )
elif isinstance(message, ToolResponseMessage): elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage( out = OpenAIChatCompletionToolMessage(
@ -858,7 +863,8 @@ async def convert_openai_chat_completion_stream(
event_type = ChatCompletionResponseEventType.progress event_type = ChatCompletionResponseEventType.progress
stop_reason = None stop_reason = None
toolcall_buffer = {} tool_call_idx_to_buffer = {}
async for chunk in stream: async for chunk in stream:
choice = chunk.choices[0] # assuming only one choice per chunk choice = chunk.choices[0] # assuming only one choice per chunk
@ -868,7 +874,6 @@ async def convert_openai_chat_completion_stream(
# if there's a tool call, emit an event for each tool in the list # if there's a tool call, emit an event for each tool in the list
# if tool call and content, emit both separately # if tool call and content, emit both separately
if choice.delta.tool_calls: if choice.delta.tool_calls:
# the call may have content and a tool call. ChatCompletionResponseEvent # the call may have content and a tool call. ChatCompletionResponseEvent
# does not support both, so we emit the content first # does not support both, so we emit the content first
@ -889,44 +894,53 @@ async def convert_openai_chat_completion_stream(
) )
if not enable_incremental_tool_calls: if not enable_incremental_tool_calls:
yield ChatCompletionResponseStreamChunk( for tool_call in choice.delta.tool_calls:
event=ChatCompletionResponseEvent( yield ChatCompletionResponseStreamChunk(
event_type=next(event_type), event=ChatCompletionResponseEvent(
delta=ToolCallDelta( event_type=event_type,
tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0], delta=ToolCallDelta(
parse_status=ToolCallParseStatus.succeeded, tool_call=_convert_openai_tool_calls([tool_call])[0],
), parse_status=ToolCallParseStatus.succeeded,
logprobs=_convert_openai_logprobs(logprobs), ),
logprobs=_convert_openai_logprobs(logprobs),
)
) )
)
else: else:
tool_call = choice.delta.tool_calls[0] for tool_call in choice.delta.tool_calls:
if "name" not in toolcall_buffer: idx = tool_call.index if hasattr(tool_call, "index") else 0
toolcall_buffer["call_id"] = tool_call.id
toolcall_buffer["name"] = None
toolcall_buffer["content"] = ""
if "arguments" not in toolcall_buffer:
toolcall_buffer["arguments"] = ""
if tool_call.function.name: if idx not in tool_call_idx_to_buffer:
toolcall_buffer["name"] = tool_call.function.name tool_call_idx_to_buffer[idx] = {
delta = f"{toolcall_buffer['name']}(" "call_id": tool_call.id,
if tool_call.function.arguments: "name": None,
toolcall_buffer["arguments"] += tool_call.function.arguments "arguments": "",
delta = toolcall_buffer["arguments"] "content": "",
}
toolcall_buffer["content"] += delta buffer = tool_call_idx_to_buffer[idx]
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( if tool_call.function:
event_type=event_type, if tool_call.function.name:
delta=ToolCallDelta( buffer["name"] = tool_call.function.name
tool_call=delta, delta = f"{buffer['name']}("
parse_status=ToolCallParseStatus.in_progress, buffer["content"] += delta
),
logprobs=_convert_openai_logprobs(logprobs), if tool_call.function.arguments:
) delta = tool_call.function.arguments
) buffer["arguments"] += delta
else: buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
elif choice.delta.content:
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=event_type, event_type=event_type,
@ -935,47 +949,51 @@ async def convert_openai_chat_completion_stream(
) )
) )
if toolcall_buffer: for idx, buffer in tool_call_idx_to_buffer.items():
delta = ")" logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
toolcall_buffer["content"] += delta if buffer["name"]:
yield ChatCompletionResponseStreamChunk( delta = ")"
event=ChatCompletionResponseEvent( buffer["content"] += delta
event_type=event_type,
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
try:
arguments = json.loads(toolcall_buffer["arguments"])
tool_call = ToolCall(
call_id=toolcall_buffer["call_id"],
tool_name=toolcall_buffer["name"],
arguments=arguments,
)
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress, event_type=event_type,
delta=ToolCallDelta( delta=ToolCallDelta(
tool_call=tool_call, tool_call=delta,
parse_status=ToolCallParseStatus.succeeded, parse_status=ToolCallParseStatus.in_progress,
), ),
stop_reason=stop_reason, logprobs=None,
) )
) )
except json.JSONDecodeError:
yield ChatCompletionResponseStreamChunk( try:
event=ChatCompletionResponseEvent( arguments = json.loads(buffer["arguments"])
event_type=ChatCompletionResponseEventType.complete, tool_call = ToolCall(
delta=ToolCallDelta( call_id=buffer["call_id"],
tool_call=toolcall_buffer["content"], tool_name=buffer["name"],
parse_status=ToolCallParseStatus.failed, arguments=arguments,
), )
stop_reason=stop_reason, yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
except json.JSONDecodeError as e:
print(f"Failed to parse arguments: {e}")
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=buffer["content"],
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
) )
)
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
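The incremental tool-call handling above buffers argument fragments per `tool_call.index` and parses them once the stream ends. A standalone sketch of that buffering idea, using illustrative chunk dictionaries rather than the real OpenAI delta objects:

```python
# Sketch: accumulate streamed tool-call fragments per index, then json-parse
# the assembled arguments, mirroring tool_call_idx_to_buffer above.
import json

chunks = [
    {"index": 0, "name": "get_weather", "arguments": '{"city": '},
    {"index": 0, "name": None, "arguments": '"Paris"}'},
]

buffers: dict[int, dict] = {}
for delta in chunks:
    buf = buffers.setdefault(delta["index"], {"name": None, "arguments": ""})
    if delta["name"]:
        buf["name"] = delta["name"]
    buf["arguments"] += delta["arguments"]

for idx, buf in buffers.items():
    print(buf["name"], json.loads(buf["arguments"]))  # get_weather {'city': 'Paris'}
```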

View file

@ -1,15 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
THIS_DIR="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)"
set -euo pipefail
set -x
stack_dir=$(dirname $(dirname $THIS_DIR))
PYTHONPATH=$stack_dir pytest -p no:warnings --asyncio-mode auto --tb=short

View file

@ -1,13 +1,13 @@
version: '2' version: '2'
distribution_spec: distribution_spec:
description: Use NVIDIA NIM for running LLM inference description: Use NVIDIA NIM for running LLM inference and safety
providers: providers:
inference: inference:
- remote::nvidia - remote::nvidia
vector_io: vector_io:
- inline::faiss - inline::faiss
safety: safety:
- inline::llama-guard - remote::nvidia
agents: agents:
- inline::meta-reference - inline::meta-reference
telemetry: telemetry:
@ -15,16 +15,9 @@ distribution_spec:
eval: eval:
- inline::meta-reference - inline::meta-reference
datasetio: datasetio:
- remote::huggingface
- inline::localfs - inline::localfs
scoring: scoring:
- inline::basic - inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime: tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime - inline::rag-runtime
- remote::model-context-protocol
image_type: conda image_type: conda

View file

@ -6,9 +6,10 @@
from pathlib import Path from pathlib import Path
from llama_stack.distribution.datatypes import Provider, ToolGroupInput from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
@ -16,19 +17,13 @@ def get_distribution_template() -> DistributionTemplate:
providers = { providers = {
"inference": ["remote::nvidia"], "inference": ["remote::nvidia"],
"vector_io": ["inline::faiss"], "vector_io": ["inline::faiss"],
"safety": ["inline::llama-guard"], "safety": ["remote::nvidia"],
"agents": ["inline::meta-reference"], "agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"], "eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"], "datasetio": ["inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], "scoring": ["inline::basic"],
"tool_runtime": [ "tool_runtime": ["inline::rag-runtime"],
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
} }
inference_provider = Provider( inference_provider = Provider(
@ -36,30 +31,35 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::nvidia", provider_type="remote::nvidia",
config=NVIDIAConfig.sample_run_config(), config=NVIDIAConfig.sample_run_config(),
) )
safety_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIASafetyConfig.sample_run_config(),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="nvidia",
)
safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}",
provider_id="nvidia",
)
available_models = { available_models = {
"nvidia": MODEL_ENTRIES, "nvidia": MODEL_ENTRIES,
} }
default_tool_groups = [ default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput( ToolGroupInput(
toolgroup_id="builtin::rag", toolgroup_id="builtin::rag",
provider_id="rag-runtime", provider_id="rag-runtime",
), ),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
] ]
default_models = get_model_registry(available_models) default_models = get_model_registry(available_models)
return DistributionTemplate( return DistributionTemplate(
name="nvidia", name="nvidia",
distro_type="remote_hosted", distro_type="remote_hosted",
description="Use NVIDIA NIM for running LLM inference", description="Use NVIDIA NIM for running LLM inference and safety",
container_image=None, container_image=None,
template_path=Path(__file__).parent / "doc_template.md", template_path=Path(__file__).parent / "doc_template.md",
providers=providers, providers=providers,
@ -72,15 +72,34 @@ def get_distribution_template() -> DistributionTemplate:
default_models=default_models, default_models=default_models,
default_tool_groups=default_tool_groups, default_tool_groups=default_tool_groups,
), ),
"run-with-safety.yaml": RunConfigSettings(
provider_overrides={
"inference": [
inference_provider,
safety_provider,
]
},
default_models=[inference_model, safety_model],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
default_tool_groups=default_tool_groups,
),
}, },
run_config_env_vars={ run_config_env_vars={
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"NVIDIA_API_KEY": ( "NVIDIA_API_KEY": (
"", "",
"NVIDIA API Key", "NVIDIA API Key",
), ),
"GUARDRAILS_SERVICE_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Guardrails Service",
),
"INFERENCE_MODEL": (
"Llama3.1-8B-Instruct",
"Inference model",
),
"SAFETY_MODEL": (
"meta/llama-3.1-8b-instruct",
"Name of the model to use for safety",
),
}, },
) )

View file

@ -0,0 +1,101 @@
version: '2'
image_name: nvidia
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:}
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
safety:
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
datasetio:
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: nvidia
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: nvidia
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL}
provider_id: nvidia
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -26,10 +26,11 @@ providers:
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
safety: safety:
- provider_id: llama-guard - provider_id: nvidia
provider_type: inline::llama-guard provider_type: remote::nvidia
config: config:
excluded_categories: [] guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
agents: agents:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@ -54,13 +55,6 @@ providers:
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
datasetio: datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/huggingface_datasetio.db
- provider_id: localfs - provider_id: localfs
provider_type: inline::localfs provider_type: inline::localfs
config: config:
@ -72,33 +66,10 @@ providers:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
config: {} config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime: tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
config: {} config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
@ -227,11 +198,7 @@ datasets: []
scoring_fns: [] scoring_fns: []
benchmarks: [] benchmarks: []
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag - toolgroup_id: builtin::rag
provider_id: rag-runtime provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server: server:
port: 8321 port: 8321

View file

@ -6,7 +6,7 @@
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from llama_stack.apis.common.content_types import URL from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
from llama_stack.apis.models.models import ModelType from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
BenchmarkInput, BenchmarkInput,
@ -171,76 +171,42 @@ def get_distribution_template() -> DistributionTemplate:
DatasetInput( DatasetInput(
dataset_id="simpleqa", dataset_id="simpleqa",
provider_id="huggingface", provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"), purpose=DatasetPurpose.eval_messages_answer,
metadata={ source=URIDataSource(
"path": "llamastack/simpleqa", uri="huggingface://datasets/llamastack/simpleqa?split=train",
"split": "train", ),
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
), ),
DatasetInput( DatasetInput(
dataset_id="mmlu_cot", dataset_id="mmlu_cot",
provider_id="huggingface", provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"), purpose=DatasetPurpose.eval_messages_answer,
metadata={ source=URIDataSource(
"path": "llamastack/mmlu_cot", uri="huggingface://datasets/llamastack/mmlu_cot?split=test&name=all",
"name": "all", ),
"split": "test",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
), ),
DatasetInput( DatasetInput(
dataset_id="gpqa_cot", dataset_id="gpqa_cot",
provider_id="huggingface", provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"), purpose=DatasetPurpose.eval_messages_answer,
metadata={ source=URIDataSource(
"path": "llamastack/gpqa_0shot_cot", uri="huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main",
"name": "gpqa_main", ),
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
), ),
DatasetInput( DatasetInput(
dataset_id="math_500", dataset_id="math_500",
provider_id="huggingface", provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"), purpose=DatasetPurpose.eval_messages_answer,
metadata={ source=URIDataSource(
"path": "llamastack/math_500", uri="huggingface://datasets/llamastack/math_500?split=test",
"split": "test", ),
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
), ),
DatasetInput( DatasetInput(
dataset_id="bfcl", dataset_id="bfcl",
provider_id="huggingface", provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/bfcl_v3"), purpose=DatasetPurpose.eval_messages_answer,
metadata={ source=URIDataSource(
"path": "llamastack/bfcl_v3", uri="huggingface://datasets/llamastack/bfcl_v3?split=train",
"split": "train", ),
},
dataset_schema={
"function": {"type": "string"},
"language": {"type": "string"},
"ground_truth": {"type": "string"},
"id": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
), ),
] ]

View file

@ -158,80 +158,39 @@ shields:
- shield_id: meta-llama/Llama-Guard-3-8B - shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: [] vector_dbs: []
datasets: datasets:
- dataset_schema: - purpose: eval/messages-answer
input_query: source:
type: string type: uri
expected_answer: uri: huggingface://datasets/llamastack/simpleqa?split=train
type: string metadata: {}
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/simpleqa
metadata:
path: llamastack/simpleqa
split: train
dataset_id: simpleqa dataset_id: simpleqa
provider_id: huggingface provider_id: huggingface
- dataset_schema: - purpose: eval/messages-answer
input_query: source:
type: string type: uri
expected_answer: uri: huggingface://datasets/llamastack/mmlu_cot?split=test&name=all
type: string metadata: {}
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/mmlu_cot
metadata:
path: llamastack/mmlu_cot
name: all
split: test
dataset_id: mmlu_cot dataset_id: mmlu_cot
provider_id: huggingface provider_id: huggingface
- dataset_schema: - purpose: eval/messages-answer
input_query: source:
type: string type: uri
expected_answer: uri: huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main
type: string metadata: {}
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
metadata:
path: llamastack/gpqa_0shot_cot
name: gpqa_main
split: train
dataset_id: gpqa_cot dataset_id: gpqa_cot
provider_id: huggingface provider_id: huggingface
- dataset_schema: - purpose: eval/messages-answer
input_query: source:
type: string type: uri
expected_answer: uri: huggingface://datasets/llamastack/math_500?split=test
type: string metadata: {}
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/math_500
metadata:
path: llamastack/math_500
split: test
dataset_id: math_500 dataset_id: math_500
provider_id: huggingface provider_id: huggingface
- dataset_schema: - purpose: eval/messages-answer
function: source:
type: string type: uri
language: uri: huggingface://datasets/llamastack/bfcl_v3?split=train
type: string metadata: {}
ground_truth:
type: string
id:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/bfcl_v3
metadata:
path: llamastack/bfcl_v3
split: train
dataset_id: bfcl dataset_id: bfcl
provider_id: huggingface provider_id: huggingface
scoring_fns: [] scoring_fns: []

Some files were not shown because too many files have changed in this diff