Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 11:22:35 +00:00)

Merge branch 'main' into implement-search-for-PGVector

Commit 4c03cddf6f: 176 changed files with 8344 additions and 734 deletions
.github/workflows/integration-auth-tests.yml (vendored, 2 changed lines)

@@ -18,7 +18,7 @@ on:
       - '.github/workflows/integration-auth-tests.yml' # This workflow

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:
.github/workflows/integration-sql-store-tests.yml (vendored)

@@ -16,7 +16,7 @@ on:
      - '.github/workflows/integration-sql-store-tests.yml' # This workflow

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:
.github/workflows/pre-commit.yml (vendored, 24 changed lines)

@@ -8,7 +8,7 @@ on:
     branches: [main]

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:

@@ -36,20 +36,16 @@ jobs:
            **/requirements*.txt
            .pre-commit-config.yaml

-      # npm ci may fail -
-      # npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
-      # npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
-      # - name: Set up Node.js
-      #   uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
-      #   with:
-      #     node-version: '20'
-      #     cache: 'npm'
-      #     cache-dependency-path: 'llama_stack/ui/'
-
-      # - name: Install npm dependencies
-      #   run: npm ci
-      #   working-directory: llama_stack/ui
+      - name: Set up Node.js
+        uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
+        with:
+          node-version: '20'
+          cache: 'npm'
+          cache-dependency-path: 'llama_stack/ui/'
+
+      - name: Install npm dependencies
+        run: npm ci
+        working-directory: llama_stack/ui

       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
         continue-on-error: true
.github/workflows/providers-build.yml (vendored, 10 changed lines)

@@ -26,7 +26,7 @@ on:
       - 'pyproject.toml'

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:

@@ -106,6 +106,10 @@ jobs:
       - name: Inspect the container image entrypoint
         run: |
           IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
+          if [ -z "$IMAGE_ID" ]; then
+            echo "No image found"
+            exit 1
+          fi
           entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
           echo "Entrypoint: $entrypoint"
           if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then

@@ -140,6 +144,10 @@ jobs:
       - name: Inspect UBI9 image
         run: |
           IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
+          if [ -z "$IMAGE_ID" ]; then
+            echo "No image found"
+            exit 1
+          fi
           entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
           echo "Entrypoint: $entrypoint"
           if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
.github/workflows/python-build-test.yml (vendored, 2 changed lines)

@@ -24,7 +24,7 @@ jobs:
         uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install uv
-        uses: astral-sh/setup-uv@d9e0f98d3fc6adb07d1e3d37f3043649ddad06a1 # v6.5.0
+        uses: astral-sh/setup-uv@4959332f0f014c5280e7eac8b70c90cb574c9f9b # v6.6.0
         with:
           python-version: ${{ matrix.python-version }}
           activate-environment: true
.github/workflows/semantic-pr.yml (vendored, 2 changed lines)

@@ -22,6 +22,6 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Check PR Title's semantic conformance
-        uses: amannn/action-semantic-pull-request@7f33ba792281b034f64e96f4c0b5496782dd3b37 # v6.1.0
+        uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/ui-unit-tests.yml (vendored, 2 changed lines)

@@ -13,7 +13,7 @@ on:
   workflow_dispatch:

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:
.github/workflows/unit-tests.yml (vendored, 2 changed lines)

@@ -18,7 +18,7 @@ on:
   workflow_dispatch:

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:
.github/workflows/update-readthedocs.yml (vendored, 2 changed lines)

@@ -27,7 +27,7 @@ on:
       - '.github/workflows/update-readthedocs.yml'

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:
.pre-commit-config.yaml

@@ -146,31 +146,13 @@ repos:
         pass_filenames: false
         require_serial: true
         files: ^.github/workflows/.*$
-      # ui-prettier and ui-eslint are disabled until we can avoid `npm ci`, which is slow and may fail -
-      # npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
-      # npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
-      # and until we have infra for installing prettier and next via npm -
-      # Lint UI code with ESLint.....................................................Failed
-      # - hook id: ui-eslint
-      # - exit code: 127
-      # > ui@0.1.0 lint
-      # > next lint --fix --quiet
-      # sh: line 1: next: command not found
-      #
-      # - id: ui-prettier
-      #   name: Format UI code with Prettier
-      #   entry: bash -c 'cd llama_stack/ui && npm ci && npm run format'
-      #   language: system
-      #   files: ^llama_stack/ui/.*\.(ts|tsx)$
-      #   pass_filenames: false
-      #   require_serial: true
-      # - id: ui-eslint
-      #   name: Lint UI code with ESLint
-      #   entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
-      #   language: system
-      #   files: ^llama_stack/ui/.*\.(ts|tsx)$
-      #   pass_filenames: false
-      #   require_serial: true
+      - id: ui-linter
+        name: Format & Lint UI
+        entry: bash ./scripts/run-ui-linter.sh
+        language: system
+        files: ^llama_stack/ui/.*\.(ts|tsx)$
+        pass_filenames: false
+        require_serial: true

       - id: check-log-usage
         name: Ensure 'llama_stack.log' usage for logging
docs/_static/llama-stack-spec.html (vendored, 138 changed lines)

@@ -4605,6 +4605,49 @@
                 }
             }
         },
+        "/v1/inference/rerank": {
+            "post": {
+                "responses": {
+                    "200": {
+                        "description": "RerankResponse with indices sorted by relevance score (descending).",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/RerankResponse"
+                                }
+                            }
+                        }
+                    },
+                    "400": {
+                        "$ref": "#/components/responses/BadRequest400"
+                    },
+                    "429": {
+                        "$ref": "#/components/responses/TooManyRequests429"
+                    },
+                    "500": {
+                        "$ref": "#/components/responses/InternalServerError500"
+                    },
+                    "default": {
+                        "$ref": "#/components/responses/DefaultError"
+                    }
+                },
+                "tags": [
+                    "Inference"
+                ],
+                "description": "Rerank a list of documents based on their relevance to a query.",
+                "parameters": [],
+                "requestBody": {
+                    "content": {
+                        "application/json": {
+                            "schema": {
+                                "$ref": "#/components/schemas/RerankRequest"
+                            }
+                        }
+                    },
+                    "required": true
+                }
+            }
+        },
         "/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
             "post": {
                 "responses": {

@@ -16024,12 +16067,16 @@
                     "value": {
                         "type": "number",
                         "description": "The numeric value of the metric at this timestamp"
+                    },
+                    "unit": {
+                        "type": "string"
                     }
                 },
                 "additionalProperties": false,
                 "required": [
                     "timestamp",
-                    "value"
+                    "value",
+                    "unit"
                 ],
                 "title": "MetricDataPoint",
                 "description": "A single data point in a metric time series."

@@ -16587,6 +16634,95 @@
                 ],
                 "title": "RegisterVectorDbRequest"
             },
+            "RerankRequest": {
+                "type": "object",
+                "properties": {
+                    "model": {
+                        "type": "string",
+                        "description": "The identifier of the reranking model to use."
+                    },
+                    "query": {
+                        "oneOf": [
+                            {
+                                "type": "string"
+                            },
+                            {
+                                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
+                            },
+                            {
+                                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
+                            }
+                        ],
+                        "description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length."
+                    },
+                    "items": {
+                        "type": "array",
+                        "items": {
+                            "oneOf": [
+                                {
+                                    "type": "string"
+                                },
+                                {
+                                    "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
+                                },
+                                {
+                                    "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
+                                }
+                            ]
+                        },
+                        "description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length."
+                    },
+                    "max_num_results": {
+                        "type": "integer",
+                        "description": "(Optional) Maximum number of results to return. Default: returns all."
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "model",
+                    "query",
+                    "items"
+                ],
+                "title": "RerankRequest"
+            },
+            "RerankData": {
+                "type": "object",
+                "properties": {
+                    "index": {
+                        "type": "integer",
+                        "description": "The original index of the document in the input list"
+                    },
+                    "relevance_score": {
+                        "type": "number",
+                        "description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance."
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "index",
+                    "relevance_score"
+                ],
+                "title": "RerankData",
+                "description": "A single rerank result from a reranking response."
+            },
+            "RerankResponse": {
+                "type": "object",
+                "properties": {
+                    "data": {
+                        "type": "array",
+                        "items": {
+                            "$ref": "#/components/schemas/RerankData"
+                        },
+                        "description": "List of rerank result objects, sorted by relevance score (descending)"
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "data"
+                ],
+                "title": "RerankResponse",
+                "description": "Response from a reranking request."
+            },
             "ResumeAgentTurnRequest": {
                 "type": "object",
                 "properties": {
docs/_static/llama-stack-spec.yaml (vendored, 104 changed lines)

@@ -3264,6 +3264,37 @@ paths:
            schema:
              $ref: '#/components/schemas/QueryTracesRequest'
        required: true
+  /v1/inference/rerank:
+    post:
+      responses:
+        '200':
+          description: >-
+            RerankResponse with indices sorted by relevance score (descending).
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/RerankResponse'
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: >-
+            #/components/responses/TooManyRequests429
+        '500':
+          $ref: >-
+            #/components/responses/InternalServerError500
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - Inference
+      description: >-
+        Rerank a list of documents based on their relevance to a query.
+      parameters: []
+      requestBody:
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/RerankRequest'
+        required: true
  /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
    post:
      responses:

@@ -11923,10 +11954,13 @@ components:
          type: number
          description: >-
            The numeric value of the metric at this timestamp
+        unit:
+          type: string
      additionalProperties: false
      required:
        - timestamp
        - value
+        - unit
      title: MetricDataPoint
      description: >-
        A single data point in a metric time series.

@@ -12337,6 +12371,76 @@ components:
        - vector_db_id
        - embedding_model
      title: RegisterVectorDbRequest
+    RerankRequest:
+      type: object
+      properties:
+        model:
+          type: string
+          description: >-
+            The identifier of the reranking model to use.
+        query:
+          oneOf:
+            - type: string
+            - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
+            - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
+          description: >-
+            The search query to rank items against. Can be a string, text content
+            part, or image content part. The input must not exceed the model's max
+            input token length.
+        items:
+          type: array
+          items:
+            oneOf:
+              - type: string
+              - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
+              - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
+          description: >-
+            List of items to rerank. Each item can be a string, text content part,
+            or image content part. Each input must not exceed the model's max input
+            token length.
+        max_num_results:
+          type: integer
+          description: >-
+            (Optional) Maximum number of results to return. Default: returns all.
+      additionalProperties: false
+      required:
+        - model
+        - query
+        - items
+      title: RerankRequest
+    RerankData:
+      type: object
+      properties:
+        index:
+          type: integer
+          description: >-
+            The original index of the document in the input list
+        relevance_score:
+          type: number
+          description: >-
+            The relevance score from the model output. Values are inverted when applicable
+            so that higher scores indicate greater relevance.
+      additionalProperties: false
+      required:
+        - index
+        - relevance_score
+      title: RerankData
+      description: >-
+        A single rerank result from a reranking response.
+    RerankResponse:
+      type: object
+      properties:
+        data:
+          type: array
+          items:
+            $ref: '#/components/schemas/RerankData'
+          description: >-
+            List of rerank result objects, sorted by relevance score (descending)
+      additionalProperties: false
+      required:
+        - data
+      title: RerankResponse
+      description: Response from a reranking request.
    ResumeAgentTurnRequest:
      type: object
      properties:
@@ -33,7 +33,7 @@ The list of open-benchmarks we currently support:
 - [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI)]: Benchmark designed to evaluate multimodal models.


-You can follow this [contributing guide](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) to add more open-benchmarks to Llama Stack
+You can follow this [contributing guide](../references/evals_reference/index.md#open-benchmark-contributing-guide) to add more open-benchmarks to Llama Stack

 #### Run evaluation on open-benchmarks via CLI
@@ -35,3 +35,6 @@ device: cpu

 ```

+[Find more detailed information here!](huggingface.md)
+
@@ -22,3 +22,4 @@ checkpoint_format: meta

 ```

+[Find more detailed information here!](torchtune.md)
@@ -88,7 +88,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
 - **API Resources**: Inspect Llama Stack API resources
   - This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `benchmarks`, `shields`).
   - Under the hood, it uses Llama Stack's `/<resources>/list` API to get information about each resources.
-  - Please visit [Core Concepts](https://llama-stack.readthedocs.io/en/latest/concepts/index.html) for more details about the resources.
+  - Please visit [Core Concepts](../../concepts/index.md) for more details about the resources.

 ### Starting the Llama Stack Playground
@@ -3,7 +3,7 @@
 Llama Stack (LLS) provides two different APIs for building AI applications with tool calling capabilities: the **Agents API** and the **OpenAI Responses API**. While both enable AI systems to use tools, and maintain full conversation history, they serve different use cases and have distinct characteristics.

 ```{note}
-For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API.
+**Note:** For simple and basic inferencing, you may want to use the [Chat Completions API](../providers/openai.md#chat-completions) directly, before progressing to Agents or Responses API.
 ```

 ## Overview

@@ -173,7 +173,7 @@ Both APIs demonstrate distinct strengths that make them valuable on their own fo

 ## For More Information

-- **LLS Agents API**: For detailed information on creating and managing agents, see the [Agents documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent.html)
+- **LLS Agents API**: For detailed information on creating and managing agents, see the [Agents documentation](agent.md)
 - **OpenAI Responses API**: For information on using the OpenAI-compatible responses API, see the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/responses)
-- **Chat Completions API**: For the default backend API used by Agents, see the [Chat Completions providers documentation](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions)
+- **Chat Completions API**: For the default backend API used by Agents, see the [Chat Completions providers documentation](../providers/openai.md#chat-completions)
-- **Agent Execution Loop**: For understanding how agents process turns and steps in their execution, see the [Agent Execution Loop documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent_execution_loop.html)
+- **Agent Execution Loop**: For understanding how agents process turns and steps in their execution, see the [Agent Execution Loop documentation](agent_execution_loop.md)
@@ -6,4 +6,4 @@ While there is a lot of flexibility to mix-and-match providers, often users will

 **Locally Hosted Distro**: You may want to run Llama Stack on your own hardware. Typically though, you still need to use Inference via an external service. You can use providers like HuggingFace TGI, Fireworks, Together, etc. for this purpose. Or you may have access to GPUs and can run a [vLLM](https://github.com/vllm-project/vllm) or [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) instance. If you "just" have a regular desktop machine, you can use [Ollama](https://ollama.com/) for inference. To provide convenient quick access to these options, we provide a number of such pre-configured locally-hosted Distros.

-**On-device Distro**: To run Llama Stack directly on an edge device (mobile phone or a tablet), we provide Distros for [iOS](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/ios_sdk.html) and [Android](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/android_sdk.html)
+**On-device Distro**: To run Llama Stack directly on an edge device (mobile phone or a tablet), we provide Distros for [iOS](../distributions/ondevice_distro/ios_sdk.md) and [Android](../distributions/ondevice_distro/android_sdk.md)
@@ -14,6 +14,13 @@ Here are some example PRs to help you get started:
 - [Nvidia Inference Implementation](https://github.com/meta-llama/llama-stack/pull/355)
 - [Model context protocol Tool Runtime](https://github.com/meta-llama/llama-stack/pull/665)

+## Guidelines for creating Internal or External Providers
+
+|**Type** |Internal (In-tree) |External (out-of-tree)
+|---------|-------------------|---------------------|
+|**Description** |A provider that is directly in the Llama Stack code|A provider that is outside of the Llama stack core codebase but is still accessible and usable by Llama Stack.
+|**Benefits** |Ability to interact with the provider with minimal additional configurations or installations| Contributors do not have to add directly to the code to create providers accessible on Llama Stack. Keep provider-specific code separate from the core Llama Stack code.
+
 ## Inference Provider Patterns

 When implementing Inference providers for OpenAI-compatible APIs, Llama Stack provides several mixin classes to simplify development and ensure consistent behavior across providers.
@@ -225,8 +225,32 @@ server:
   port: 8321 # Port to listen on (default: 8321)
   tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
   tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
+  cors: true # Optional: Enable CORS (dev mode) or full config object
 ```
+
+### CORS Configuration
+
+CORS (Cross-Origin Resource Sharing) can be configured in two ways:
+
+**Local development** (allows localhost origins only):
+```yaml
+server:
+  cors: true
+```
+
+**Explicit configuration** (custom origins and settings):
+```yaml
+server:
+  cors:
+    allow_origins: ["https://myapp.com", "https://app.example.com"]
+    allow_methods: ["GET", "POST", "PUT", "DELETE"]
+    allow_headers: ["Content-Type", "Authorization"]
+    allow_credentials: true
+    max_age: 3600
+```
+
+When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security.

 ### Authentication Configuration

 > **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.

@@ -618,6 +642,54 @@ Content-Type: application/json
 }
 ```

+### CORS Configuration
+
+Configure CORS to allow web browsers to make requests from different domains. Disabled by default.
+
+#### Quick Setup
+
+For development, use the simple boolean flag:
+
+```yaml
+server:
+  cors: true # Auto-enables localhost with any port
+```
+
+This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults.
+
+#### Custom Configuration
+
+For specific origins and full control:
+
+```yaml
+server:
+  cors:
+    allow_origins: ["https://myapp.com", "https://staging.myapp.com"]
+    allow_credentials: true
+    allow_methods: ["GET", "POST", "PUT", "DELETE"]
+    allow_headers: ["Content-Type", "Authorization"]
+    allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern
+    expose_headers: ["X-Total-Count"]
+    max_age: 86400
+```
+
+#### Configuration Options
+
+| Field | Description | Default |
+| -------------------- | ---------------------------------------------- | ------- |
+| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` |
+| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` |
+| `allow_methods` | Allowed HTTP methods. | `["*"]` |
+| `allow_headers` | Allowed headers. | `["*"]` |
+| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` |
+| `expose_headers` | Headers exposed to browser. | `[]` |
+| `max_age` | Preflight cache time (seconds). | `600` |
+
+**Security Notes**:
+- `allow_credentials: true` requires explicit origins (no wildcards)
+- `cors: true` enables localhost access only (secure for development)
+- For public APIs, always specify exact allowed origins
+
 ## Extending to handle Safety

 Configuring Safety can be a little involved so it is instructive to go through an example.
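A note on the CORS options documented above: the fields of the new `CORSConfig` line up with the arguments of Starlette's standard `CORSMiddleware`, so the documented settings can be read as ordinary CORS middleware options. The sketch below is illustrative only; the `app` object and concrete values are assumptions, not code from this change:

```python
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

app = FastAPI()  # stand-in for the Llama Stack server application

# Mirrors the "Custom Configuration" example above.
cors_config = {
    "allow_origins": ["https://myapp.com", "https://staging.myapp.com"],
    "allow_credentials": True,
    "allow_methods": ["GET", "POST", "PUT", "DELETE"],
    "allow_headers": ["Content-Type", "Authorization"],
    "allow_origin_regex": None,
    "expose_headers": ["X-Total-Count"],
    "max_age": 86400,
}

# Starlette accepts these keyword arguments directly.
app.add_middleware(CORSMiddleware, **cors_config)
```

Whether the server wires the config up exactly this way is an implementation detail; the point is that the documented fields (including the `max_age: 600` default, the preflight cache time in seconds) map onto standard CORS middleware options.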
@@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
     # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
     provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
 )
-client.initialize()
 ```

 This will parse your config and set up any inline implementations and remote clients needed for your implementation.

@@ -28,9 +27,8 @@ Then, you can access the APIs like `models` and `inference` on the client and ca
 response = client.models.list()
 ```

-If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/latest/distributions/building_distro.html), you can also use the run.yaml configuration file directly:
+If you've created a [custom distribution](building_distro.md), you can also use the run.yaml configuration file directly:

 ```python
 client = LlamaStackAsLibraryClient(config_path)
-client.initialize()
 ```
@@ -22,17 +22,17 @@ else
 fi

 if [ -z "${GITHUB_CLIENT_ID:-}" ]; then
-  echo "ERROR: GITHUB_CLIENT_ID not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide"
+  echo "ERROR: GITHUB_CLIENT_ID not set. You need it for Github login to work. See the Kubernetes Deployment Guide in the Llama Stack documentation."
   exit 1
 fi

 if [ -z "${GITHUB_CLIENT_SECRET:-}" ]; then
-  echo "ERROR: GITHUB_CLIENT_SECRET not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide"
+  echo "ERROR: GITHUB_CLIENT_SECRET not set. You need it for Github login to work. See the Kubernetes Deployment Guide in the Llama Stack documentation."
   exit 1
 fi

 if [ -z "${LLAMA_STACK_UI_URL:-}" ]; then
-  echo "ERROR: LLAMA_STACK_UI_URL not set. Should be set to the external URL of the UI (excluding port). You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide"
+  echo "ERROR: LLAMA_STACK_UI_URL not set. Should be set to the external URL of the UI (excluding port). You need it for Github login to work. See the Kubernetes Deployment Guide in the Llama Stack documentation."
   exit 1
 fi
@@ -66,7 +66,7 @@ llama stack run starter --port 5050

 Ensure the Llama Stack server version is the same as the Kotlin SDK Library for maximum compatibility.

-Other inference providers: [Table](https://llama-stack.readthedocs.io/en/latest/index.html#supported-llama-stack-implementations)
+Other inference providers: [Table](../../index.md#supported-llama-stack-implementations)

 How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#settings)
@@ -2,7 +2,7 @@
 orphan: true
 ---
 <!-- This file was auto-generated by distro_codegen.py, please edit source -->
-# Meta Reference Distribution
+# Meta Reference GPU Distribution

 ```{toctree}
 :maxdepth: 2

@@ -41,7 +41,7 @@ The following environment variables can be configured:

 ## Prerequisite: Downloading Models

-Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
+Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.

 ```
 $ llama model list --downloaded
@@ -2,12 +2,15 @@

 ## Overview

-Protocol for batch processing API operations.
-
 The Batches API enables efficient processing of multiple requests in a single operation,
 particularly useful for processing large datasets, batch evaluation workflows, and
 cost-effective inference at scale.

+The API is designed to allow use of openai client libraries for seamless integration.
+
+This API provides the following extensions:
+ - idempotent batch creation
+
 Note: This API is currently under active development and may undergo changes.

 This section contains documentation for all available providers for the **batches** API.
@@ -10,4 +10,5 @@ This section contains documentation for all available providers for the **files*
 :maxdepth: 1

 inline_localfs
+remote_s3
 ```
docs/source/providers/files/remote_s3.md (new file, 33 lines)

@@ -0,0 +1,33 @@
+# remote::s3
+
+## Description
+
+AWS S3-based file storage provider for scalable cloud file management with metadata persistence.
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `bucket_name` | `<class 'str'>` | No | | S3 bucket name to store files |
+| `region` | `<class 'str'>` | No | us-east-1 | AWS region where the bucket is located |
+| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) |
+| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) |
+| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) |
+| `auto_create_bucket` | `<class 'bool'>` | No | False | Automatically create the S3 bucket if it doesn't exist |
+| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata |
+
+## Sample Configuration
+
+```yaml
+bucket_name: ${env.S3_BUCKET_NAME}
+region: ${env.AWS_REGION:=us-east-1}
+aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:=}
+aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:=}
+endpoint_url: ${env.S3_ENDPOINT_URL:=}
+auto_create_bucket: ${env.S3_AUTO_CREATE_BUCKET:=false}
+metadata_store:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/s3_files_metadata.db
+
+```
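Since the new S3 files provider accepts a custom `endpoint_url`, the same settings can be smoke-tested against a local MinIO or LocalStack instance with a plain boto3 client. This is an illustrative sketch only; the endpoint, credentials, and bucket name are assumptions, not values shipped with the provider:

```python
import boto3

# Mirrors the provider's sample configuration above; endpoint and credentials
# are assumed local test values (e.g. a default MinIO instance).
s3 = boto3.client(
    "s3",
    region_name="us-east-1",
    endpoint_url="http://localhost:9000",
    aws_access_key_id="minioadmin",
    aws_secret_access_key="minioadmin",
)

# auto_create_bucket: true would do this step for you (errors if it already exists).
s3.create_bucket(Bucket="llama-stack-files")
s3.put_object(Bucket="llama-stack-files", Key="hello.txt", Body=b"hello")
print(s3.list_objects_v2(Bucket="llama-stack-files")["KeyCount"])
```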
@@ -9,7 +9,8 @@ This section contains documentation for all available providers for the **post_t
 ```{toctree}
 :maxdepth: 1

-inline_huggingface
-inline_torchtune
+inline_huggingface-gpu
+inline_torchtune-cpu
+inline_torchtune-gpu
 remote_nvidia
 ```
@@ -0,0 +1,41 @@
+# inline::huggingface-cpu
+
+## Description
+
+HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `device` | `<class 'str'>` | No | cuda | |
+| `distributed_backend` | `Literal['fsdp', 'deepspeed'` | No | | |
+| `checkpoint_format` | `Literal['full_state', 'huggingface'` | No | huggingface | |
+| `chat_template` | `<class 'str'>` | No | <|user|>
+{input}
+<|assistant|>
+{output} | |
+| `model_specific_config` | `<class 'dict'>` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
+| `max_seq_length` | `<class 'int'>` | No | 2048 | |
+| `gradient_checkpointing` | `<class 'bool'>` | No | False | |
+| `save_total_limit` | `<class 'int'>` | No | 3 | |
+| `logging_steps` | `<class 'int'>` | No | 10 | |
+| `warmup_ratio` | `<class 'float'>` | No | 0.1 | |
+| `weight_decay` | `<class 'float'>` | No | 0.01 | |
+| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
+| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
+| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
+| `use_reference_model` | `<class 'bool'>` | No | True | |
+| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair'` | No | sigmoid | |
+| `dpo_output_dir` | `<class 'str'>` | No | | |
+
+## Sample Configuration
+
+```yaml
+checkpoint_format: huggingface
+distributed_backend: null
+device: cpu
+dpo_output_dir: ~/.llama/dummy/dpo_output
+
+```
@@ -0,0 +1,41 @@
+# inline::huggingface-gpu
+
+## Description
+
+HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `device` | `<class 'str'>` | No | cuda | |
+| `distributed_backend` | `Literal['fsdp', 'deepspeed'` | No | | |
+| `checkpoint_format` | `Literal['full_state', 'huggingface'` | No | huggingface | |
+| `chat_template` | `<class 'str'>` | No | <|user|>
+{input}
+<|assistant|>
+{output} | |
+| `model_specific_config` | `<class 'dict'>` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
+| `max_seq_length` | `<class 'int'>` | No | 2048 | |
+| `gradient_checkpointing` | `<class 'bool'>` | No | False | |
+| `save_total_limit` | `<class 'int'>` | No | 3 | |
+| `logging_steps` | `<class 'int'>` | No | 10 | |
+| `warmup_ratio` | `<class 'float'>` | No | 0.1 | |
+| `weight_decay` | `<class 'float'>` | No | 0.01 | |
+| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
+| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
+| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
+| `use_reference_model` | `<class 'bool'>` | No | True | |
+| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair'` | No | sigmoid | |
+| `dpo_output_dir` | `<class 'str'>` | No | | |
+
+## Sample Configuration
+
+```yaml
+checkpoint_format: huggingface
+distributed_backend: null
+device: cpu
+dpo_output_dir: ~/.llama/dummy/dpo_output
+
+```
docs/source/providers/post_training/inline_torchtune-cpu.md (new file, 20 lines)

@@ -0,0 +1,20 @@
+# inline::torchtune-cpu
+
+## Description
+
+TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `torch_seed` | `int \| None` | No | | |
+| `checkpoint_format` | `Literal['meta', 'huggingface'` | No | meta | |
+
+## Sample Configuration
+
+```yaml
+checkpoint_format: meta
+
+```
docs/source/providers/post_training/inline_torchtune-gpu.md (new file, 20 lines)

@@ -0,0 +1,20 @@
+# inline::torchtune-gpu
+
+## Description
+
+TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `torch_seed` | `int \| None` | No | | |
+| `checkpoint_format` | `Literal['meta', 'huggingface'` | No | meta | |
+
+## Sample Configuration
+
+```yaml
+checkpoint_format: meta
+
+```
@@ -202,7 +202,7 @@ pprint(response)

 Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets.

-In this example, we will work with an example RAG dataset you have built previously, label with an annotation, and use LLM-As-Judge with custom judge prompt for scoring. Please checkout our [Llama Stack Playground](https://llama-stack.readthedocs.io/en/latest/playground/index.html) for an interactive interface to upload datasets and run scorings.
+In this example, we will work with an example RAG dataset you have built previously, label with an annotation, and use LLM-As-Judge with custom judge prompt for scoring. Please checkout our [Llama Stack Playground](../../building_applications/playground/index.md) for an interactive interface to upload datasets and run scorings.

 ```python
 judge_model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8"
@@ -29,12 +29,16 @@ class ListBatchesResponse(BaseModel):

 @runtime_checkable
 class Batches(Protocol):
-    """Protocol for batch processing API operations.
+    """

     The Batches API enables efficient processing of multiple requests in a single operation,
     particularly useful for processing large datasets, batch evaluation workflows, and
     cost-effective inference at scale.

+    The API is designed to allow use of openai client libraries for seamless integration.
+
+    This API provides the following extensions:
+     - idempotent batch creation
+
     Note: This API is currently under active development and may undergo changes.
     """

@@ -45,6 +49,7 @@ class Batches(Protocol):
         endpoint: str,
         completion_window: Literal["24h"],
         metadata: dict[str, str] | None = None,
+        idempotency_key: str | None = None,
     ) -> BatchObject:
         """Create a new batch for processing multiple API requests.

@@ -52,6 +57,7 @@ class Batches(Protocol):
         :param endpoint: The endpoint to be used for all requests in the batch.
         :param completion_window: The time window within which the batch should be processed.
         :param metadata: Optional metadata for the batch.
+        :param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
         :returns: The created batch object.
         """
         ...
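To see the new `idempotency_key` extension from the client side, here is a hedged sketch using the OpenAI Python client against the stack's OpenAI-compatible routes. The base URL, the use of `extra_body` to carry the extension, and the file id are assumptions about how the parameter is surfaced, not confirmed by this change:

```python
from openai import OpenAI

# Assumed base URL for the server's OpenAI-compatible routes; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

batch = client.batches.create(
    input_file_id="file-abc123",          # placeholder file id
    endpoint="/v1/chat/completions",
    completion_window="24h",              # matches the Literal["24h"] in the protocol above
    metadata={"job": "nightly-eval"},
    # The idempotency_key is a Llama Stack extension; passing it via extra_body is an
    # assumption about how the extra parameter reaches the server.
    extra_body={"idempotency_key": "nightly-eval-2025-09-01"},
)
print(batch.id)
```

Re-running the same call with the same idempotency key should, per the docstring above, return the previously created batch rather than a new one.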
@@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel):
     embeddings: list[list[float]]


+@json_schema_type
+class RerankData(BaseModel):
+    """A single rerank result from a reranking response.
+
+    :param index: The original index of the document in the input list
+    :param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance.
+    """
+
+    index: int
+    relevance_score: float
+
+
+@json_schema_type
+class RerankResponse(BaseModel):
+    """Response from a reranking request.
+
+    :param data: List of rerank result objects, sorted by relevance score (descending)
+    """
+
+    data: list[RerankData]
+
+
 @json_schema_type
 class OpenAIChatCompletionContentPartTextParam(BaseModel):
     """Text content part for OpenAI-compatible chat completion messages.

@@ -1046,6 +1068,7 @@ class InferenceProvider(Protocol):
         :returns: A BatchCompletionResponse with the full completions.
         """
         raise NotImplementedError("Batch completion is not implemented")
+        return  # this is so mypy's safe-super rule will consider the method concrete

     @webmethod(route="/inference/chat-completion", method="POST")
     async def chat_completion(

@@ -1110,6 +1133,7 @@ class InferenceProvider(Protocol):
         :returns: A BatchChatCompletionResponse with the full completions.
         """
         raise NotImplementedError("Batch chat completion is not implemented")
+        return  # this is so mypy's safe-super rule will consider the method concrete

     @webmethod(route="/inference/embeddings", method="POST")
     async def embeddings(

@@ -1131,6 +1155,25 @@ class InferenceProvider(Protocol):
         """
         ...

+    @webmethod(route="/inference/rerank", method="POST", experimental=True)
+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        """Rerank a list of documents based on their relevance to a query.
+
+        :param model: The identifier of the reranking model to use.
+        :param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
+        :param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
+        :param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
+        :returns: RerankResponse with indices sorted by relevance score (descending).
+        """
+        raise NotImplementedError("Reranking is not implemented")
+        return  # this is so mypy's safe-super rule will consider the method concrete
+
     @webmethod(route="/openai/v1/completions", method="POST")
     async def openai_completion(
         self,
|
|
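A hedged client-side sketch of exercising the new experimental route. The `client.inference.rerank` attribute and the client package surface are assumptions based on how the other `/inference/*` webmethods are usually exposed, not something this patch guarantees; the model id and documents are placeholders.

```python
from llama_stack_client import LlamaStackClient  # client package assumed

client = LlamaStackClient(base_url="http://localhost:8321")

result = client.inference.rerank(
    model="example/reranker-model",  # hypothetical registered reranking model
    query="What is the capital of France?",
    items=[
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
        "France is famous for its wine.",
    ],
    max_num_results=2,
)
for item in result.data:
    print(item.index, item.relevance_score)
```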
@@ -386,6 +386,7 @@ class MetricDataPoint(BaseModel):
 
     timestamp: int
     value: float
+    unit: str
 
 
 @json_schema_type

@@ -518,7 +519,7 @@ class Telemetry(Protocol):
         metric_name: str,
         start_time: int,
         end_time: int | None = None,
-        granularity: str | None = "1d",
+        granularity: str | None = None,
         query_type: MetricQueryType = MetricQueryType.RANGE,
         label_matchers: list[MetricLabelMatcher] | None = None,
     ) -> QueryMetricsResponse:
@@ -15,7 +15,7 @@ from llama_stack.log import get_logger
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
 
-logger = get_logger(name=__name__, category="server")
+logger = get_logger(name=__name__, category="cli")
 
 
 class StackRun(Subcommand):
@@ -80,7 +80,7 @@ def get_provider_dependencies(
     normal_deps = []
     special_deps = []
     for package in deps:
-        if "--no-deps" in package or "--index-url" in package:
+        if any(f in package for f in ["--no-deps", "--index-url", "--extra-index-url"]):
             special_deps.append(package)
         else:
             normal_deps.append(package)
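To illustrate the widened check, a small standalone sketch (package strings are made up): any requirement carrying a pip flag, now including `--extra-index-url`, is routed into `special_deps`, while plain requirements stay in `normal_deps`.

```python
deps = [
    "torch>=2.3",
    "fbgemm-gpu --extra-index-url https://download.pytorch.org/whl/cu121",  # hypothetical
    "some-wheel --no-deps",
]

normal_deps, special_deps = [], []
for package in deps:
    # Entries carrying pip flags need special handling at install time.
    if any(f in package for f in ["--no-deps", "--index-url", "--extra-index-url"]):
        special_deps.append(package)
    else:
        normal_deps.append(package)

print(normal_deps)   # ['torch>=2.3']
print(special_deps)  # the two flagged entries
```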
@@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
     period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
 
 
+class CORSConfig(BaseModel):
+    allow_origins: list[str] = Field(default_factory=list)
+    allow_origin_regex: str | None = Field(default=None)
+    allow_methods: list[str] = Field(default=["OPTIONS"])
+    allow_headers: list[str] = Field(default_factory=list)
+    allow_credentials: bool = Field(default=False)
+    expose_headers: list[str] = Field(default_factory=list)
+    max_age: int = Field(default=600, ge=0)
+
+    @model_validator(mode="after")
+    def validate_credentials_config(self) -> Self:
+        if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
+            raise ValueError("Cannot use wildcard origins with credentials enabled")
+        return self
+
+
+def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
+    if cors_config is False or cors_config is None:
+        return None
+
+    if cors_config is True:
+        # dev mode: allow localhost on any port
+        return CORSConfig(
+            allow_origins=[],
+            allow_origin_regex=r"https?://localhost:\d+",
+            allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+            allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
+        )
+
+    if isinstance(cors_config, CORSConfig):
+        return cors_config
+
+    raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
+
+
 class ServerConfig(BaseModel):
     port: int = Field(
         default=8321,
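A quick sketch of how this helper behaves for the three accepted input shapes. The import path matches the one used by the server hunk later in this diff; the origins and methods are placeholder values.

```python
from llama_stack.core.datatypes import CORSConfig, process_cors_config

assert process_cors_config(None) is None   # CORS disabled
assert process_cors_config(False) is None  # explicitly disabled

dev = process_cors_config(True)            # dev mode: any localhost port
print(dev.allow_origin_regex)              # https?://localhost:\d+

explicit = process_cors_config(
    CORSConfig(
        allow_origins=["https://app.example.com"],  # hypothetical origin
        allow_methods=["GET", "POST"],
        allow_credentials=True,  # wildcard origins + credentials would raise
    )
)
print(explicit.model_dump())  # kwargs later passed to CORSMiddleware
```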
@@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
         default=None,
         description="Per client quota request configuration",
     )
+    cors: bool | CORSConfig | None = Field(
+        default=None,
+        description="CORS configuration for cross-origin requests. Can be:\n"
+        "- true: Enable localhost CORS for development\n"
+        "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
+    )
 
 
 class StackRunConfig(BaseModel):
@@ -146,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(
-            config_path_or_distro_name, custom_provider_registry, provider_data
+            config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
         )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
-        self.skip_logger_removal = skip_logger_removal
         self.provider_data = provider_data
 
         self.loop = asyncio.new_event_loop()
 
-    def initialize(self):
-        if in_notebook():
-            import nest_asyncio
-
-            nest_asyncio.apply()
-        if not self.skip_logger_removal:
-            self._remove_root_logger_handlers()
-
         # use a new event loop to avoid interfering with the main event loop
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
         try:
-            return loop.run_until_complete(self.async_client.initialize())
+            loop.run_until_complete(self.async_client.initialize())
         finally:
             asyncio.set_event_loop(None)
 
-    def _remove_root_logger_handlers(self):
+    def initialize(self):
         """
-        Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+        Deprecated method for backward compatibility.
         """
-        root_logger = logging.getLogger()
-
-        for handler in root_logger.handlers[:]:
-            root_logger.removeHandler(handler)
-            logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
+        pass
 
     def request(self, *args, **kwargs):
         loop = self.loop
@@ -216,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         config_path_or_distro_name: str,
         custom_provider_registry: ProviderRegistry | None = None,
         provider_data: dict[str, Any] | None = None,
+        skip_logger_removal: bool = False,
     ):
         super().__init__()
         # when using the library client, we should not log to console since many

@@ -223,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
         os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
 
+        if in_notebook():
+            import nest_asyncio
+
+            nest_asyncio.apply()
+        if not skip_logger_removal:
+            self._remove_root_logger_handlers()
+
         if config_path_or_distro_name.endswith(".yaml"):
             config_path = Path(config_path_or_distro_name)
             if not config_path.exists():

@@ -239,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.provider_data = provider_data
         self.route_impls: RouteImpls | None = None  # Initialize to None to prevent AttributeError
 
+    def _remove_root_logger_handlers(self):
+        """
+        Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+        """
+        root_logger = logging.getLogger()
+
+        for handler in root_logger.handlers[:]:
+            root_logger.removeHandler(handler)
+            logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
+
     async def initialize(self) -> bool:
+        """
+        Initialize the async client.
+
+        Returns:
+            bool: True if initialization was successful
+        """
+
         try:
             self.route_impls = None
             self.impls = await construct_stack(self.config, self.custom_provider_registry)
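A usage sketch of the relocated flag (the module path and distro name are assumptions): `skip_logger_removal=True` now takes effect at construction time, and the synchronous `initialize()` has become a deprecated no-op.

```python
from llama_stack.core.library_client import LlamaStackAsLibraryClient  # module path assumed

# Keep the application's own root-logger handlers intact while embedding the stack.
client = LlamaStackAsLibraryClient(
    "starter",  # distro name or path to a run.yaml
    skip_logger_removal=True,
)
client.initialize()  # deprecated no-op; construction already initialized the stack
```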
@@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")
 
 
 class DatasetIORouter(DatasetIO):

@@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")
 
 
 class ScoringRouter(Scoring):

@@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
 from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="core::routers")
 
 
 class InferenceRouter(Inference):

@@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")
 
 
 class SafetyRouter(Safety):

@@ -22,7 +22,7 @@ from llama_stack.log import get_logger
 
 from ..routing_tables.toolgroups import ToolGroupsRoutingTable
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")
 
 
 class ToolRuntimeRouter(ToolRuntime):

@@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")
 
 
 class VectorIORouter(VectorIO):

@@ -14,7 +14,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

@@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, RoutingTable
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 def get_impl_api(p: Any) -> Api:

@@ -26,7 +26,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

@@ -17,7 +17,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl, lookup_model
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class ModelsRoutingTable(CommonRoutingTableImpl, Models):

@@ -19,7 +19,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

@@ -15,7 +15,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

@@ -14,7 +14,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

@@ -30,7 +30,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl, lookup_model
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

@@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
 from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
 from llama_stack.log import get_logger
 
-logger = get_logger(name=__name__, category="auth")
+logger = get_logger(name=__name__, category="core::auth")
 
 
 class AuthenticationMiddleware:

@@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
 )
 from llama_stack.log import get_logger
 
-logger = get_logger(name=__name__, category="auth")
+logger = get_logger(name=__name__, category="core::auth")
 
 
 class AuthResponse(BaseModel):

@@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
 
-logger = get_logger(name=__name__, category="quota")
+logger = get_logger(name=__name__, category="core::server")
 
 
 class QuotaMiddleware:
@@ -28,6 +28,7 @@ from aiohttp import hdrs
 from fastapi import Body, FastAPI, HTTPException, Request, Response
 from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError

@@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
     AuthenticationRequiredError,
     LoggingConfig,
     StackRunConfig,
+    process_cors_config,
 )
 from llama_stack.core.distribution import builtin_automatically_routed_apis
 from llama_stack.core.external import ExternalApiSpec, load_external_apis

@@ -82,7 +84,7 @@ from .quota import QuotaMiddleware
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
 
-logger = get_logger(name=__name__, category="server")
+logger = get_logger(name=__name__, category="core::server")
 
 
 def warn_with_traceback(message, category, filename, lineno, file=None, line=None):

@@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
-    logger = get_logger(name=__name__, category="server", config=logger_config)
+    logger = get_logger(name=__name__, category="core::server", config=logger_config)
     if args.env:
         for env_pair in args.env:
             try:

@@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
             window_seconds=window_seconds,
         )
 
+    if config.server.cors:
+        logger.info("Enabling CORS")
+        cors_config = process_cors_config(config.server.cors)
+        if cors_config:
+            app.add_middleware(CORSMiddleware, **cors_config.model_dump())
+
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
     else:
@@ -225,7 +225,10 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
 
         try:
             result = re.sub(pattern, get_env_var, config)
-            return _convert_string_to_proper_type(result)
+            # Only apply type conversion if substitution actually happened
+            if result != config:
+                return _convert_string_to_proper_type(result)
+            return result
         except EnvVarError as e:
             raise EnvVarError(e.var_name, e.path) from None
 
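A small standalone sketch of the behavioral difference, with a simplified pattern and converter standing in for the real helpers: only strings that actually contained an `${env....}` reference get coerced, so a literal `"true"` written directly in the config now stays a string.

```python
import re

pattern = r"\$\{env\.(\w+)(?::=([^}]*))?\}"  # simplified stand-in for the real pattern


def _convert(value: str):
    # simplified stand-in for _convert_string_to_proper_type
    if value == "true":
        return True
    if value == "false":
        return False
    return value


def substitute(config: str, env: dict[str, str]):
    result = re.sub(pattern, lambda m: env.get(m.group(1), m.group(2) or ""), config)
    # Only apply type conversion if substitution actually happened
    if result != config:
        return _convert(result)
    return result


print(substitute("${env.ENABLE_FEATURE:=false}", {"ENABLE_FEATURE": "true"}))  # True (bool)
print(substitute("true", {}))  # 'true' stays a string: no substitution happened
```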
@@ -16,7 +16,7 @@ from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 
-logger = get_logger(__name__, category="core")
+logger = get_logger(__name__, category="core::registry")
 
 
 class DistributionRegistry(Protocol):

@@ -10,7 +10,7 @@ from pathlib import Path
 from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.log import get_logger
 
-logger = get_logger(name=__name__, category="config_resolution")
+logger = get_logger(name=__name__, category="core")
 
 
 DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"
@@ -34,7 +34,7 @@ distribution_spec:
     telemetry:
     - provider_type: inline::meta-reference
     post_training:
-    - provider_type: inline::huggingface
+    - provider_type: inline::torchtune-cpu
     eval:
     - provider_type: inline::meta-reference
     datasetio:

@@ -156,13 +156,10 @@ providers:
       sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
       otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
   post_training:
-  - provider_id: huggingface
-    provider_type: inline::huggingface
+  - provider_id: torchtune-cpu
+    provider_type: inline::torchtune-cpu
     config:
-      checkpoint_format: huggingface
-      distributed_backend: null
-      device: cpu
-      dpo_output_dir: ~/.llama/distributions/ci-tests/dpo_output
+      checkpoint_format: meta
   eval:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -1,7 +1,7 @@
 ---
 orphan: true
 ---
-# Meta Reference Distribution
+# Meta Reference GPU Distribution
 
 ```{toctree}
 :maxdepth: 2

@@ -29,7 +29,7 @@ The following environment variables can be configured:
 
 ## Prerequisite: Downloading Models
 
-Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
+Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
 
 ```
 $ llama model list --downloaded
7 llama_stack/distributions/starter-gpu/__init__.py Normal file
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .starter_gpu import get_distribution_template  # noqa: F401
59 llama_stack/distributions/starter-gpu/build.yaml Normal file
@@ -0,0 +1,59 @@
+version: 2
+distribution_spec:
+  description: Quick start template for running Llama Stack with several popular providers.
+    This distribution is intended for GPU-enabled environments.
+  providers:
+    inference:
+    - provider_type: remote::cerebras
+    - provider_type: remote::ollama
+    - provider_type: remote::vllm
+    - provider_type: remote::tgi
+    - provider_type: remote::fireworks
+    - provider_type: remote::together
+    - provider_type: remote::bedrock
+    - provider_type: remote::nvidia
+    - provider_type: remote::openai
+    - provider_type: remote::anthropic
+    - provider_type: remote::gemini
+    - provider_type: remote::vertexai
+    - provider_type: remote::groq
+    - provider_type: remote::sambanova
+    - provider_type: inline::sentence-transformers
+    vector_io:
+    - provider_type: inline::faiss
+    - provider_type: inline::sqlite-vec
+    - provider_type: inline::milvus
+    - provider_type: remote::chromadb
+    - provider_type: remote::pgvector
+    files:
+    - provider_type: inline::localfs
+    safety:
+    - provider_type: inline::llama-guard
+    - provider_type: inline::code-scanner
+    agents:
+    - provider_type: inline::meta-reference
+    telemetry:
+    - provider_type: inline::meta-reference
+    post_training:
+    - provider_type: inline::huggingface-gpu
+    eval:
+    - provider_type: inline::meta-reference
+    datasetio:
+    - provider_type: remote::huggingface
+    - provider_type: inline::localfs
+    scoring:
+    - provider_type: inline::basic
+    - provider_type: inline::llm-as-judge
+    - provider_type: inline::braintrust
+    tool_runtime:
+    - provider_type: remote::brave-search
+    - provider_type: remote::tavily-search
+    - provider_type: inline::rag-runtime
+    - provider_type: remote::model-context-protocol
+    batches:
+    - provider_type: inline::reference
+image_type: venv
+additional_pip_packages:
+- aiosqlite
+- asyncpg
+- sqlalchemy[asyncio]
241 llama_stack/distributions/starter-gpu/run.yaml Normal file
@@ -0,0 +1,241 @@
+version: 2
+image_name: starter-gpu
+apis:
+- agents
+- batches
+- datasetio
+- eval
+- files
+- inference
+- post_training
+- safety
+- scoring
+- telemetry
+- tool_runtime
+- vector_io
+providers:
+  inference:
+  - provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
+    provider_type: remote::cerebras
+    config:
+      base_url: https://api.cerebras.ai
+      api_key: ${env.CEREBRAS_API_KEY:=}
+  - provider_id: ${env.OLLAMA_URL:+ollama}
+    provider_type: remote::ollama
+    config:
+      url: ${env.OLLAMA_URL:=http://localhost:11434}
+  - provider_id: ${env.VLLM_URL:+vllm}
+    provider_type: remote::vllm
+    config:
+      url: ${env.VLLM_URL:=}
+      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
+      api_token: ${env.VLLM_API_TOKEN:=fake}
+      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+  - provider_id: ${env.TGI_URL:+tgi}
+    provider_type: remote::tgi
+    config:
+      url: ${env.TGI_URL:=}
+  - provider_id: fireworks
+    provider_type: remote::fireworks
+    config:
+      url: https://api.fireworks.ai/inference/v1
+      api_key: ${env.FIREWORKS_API_KEY:=}
+  - provider_id: together
+    provider_type: remote::together
+    config:
+      url: https://api.together.xyz/v1
+      api_key: ${env.TOGETHER_API_KEY:=}
+  - provider_id: bedrock
+    provider_type: remote::bedrock
+  - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
+    provider_type: remote::nvidia
+    config:
+      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      api_key: ${env.NVIDIA_API_KEY:=}
+      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
+  - provider_id: openai
+    provider_type: remote::openai
+    config:
+      api_key: ${env.OPENAI_API_KEY:=}
+      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
+  - provider_id: anthropic
+    provider_type: remote::anthropic
+    config:
+      api_key: ${env.ANTHROPIC_API_KEY:=}
+  - provider_id: gemini
+    provider_type: remote::gemini
+    config:
+      api_key: ${env.GEMINI_API_KEY:=}
+  - provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
+    provider_type: remote::vertexai
+    config:
+      project: ${env.VERTEX_AI_PROJECT:=}
+      location: ${env.VERTEX_AI_LOCATION:=us-central1}
+  - provider_id: groq
+    provider_type: remote::groq
+    config:
+      url: https://api.groq.com
+      api_key: ${env.GROQ_API_KEY:=}
+  - provider_id: sambanova
+    provider_type: remote::sambanova
+    config:
+      url: https://api.sambanova.ai/v1
+      api_key: ${env.SAMBANOVA_API_KEY:=}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+  vector_io:
+  - provider_id: faiss
+    provider_type: inline::faiss
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
+  - provider_id: sqlite-vec
+    provider_type: inline::sqlite-vec
+    config:
+      db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
+  - provider_id: ${env.MILVUS_URL:+milvus}
+    provider_type: inline::milvus
+    config:
+      db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
+  - provider_id: ${env.CHROMADB_URL:+chromadb}
+    provider_type: remote::chromadb
+    config:
+      url: ${env.CHROMADB_URL:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
+  - provider_id: ${env.PGVECTOR_DB:+pgvector}
+    provider_type: remote::pgvector
+    config:
+      host: ${env.PGVECTOR_HOST:=localhost}
+      port: ${env.PGVECTOR_PORT:=5432}
+      db: ${env.PGVECTOR_DB:=}
+      user: ${env.PGVECTOR_USER:=}
+      password: ${env.PGVECTOR_PASSWORD:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config:
+      excluded_categories: []
+  - provider_id: code-scanner
+    provider_type: inline::code-scanner
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/agents_store.db
+      responses_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/responses_store.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
+      sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
+      sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
+      otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
+  post_training:
+  - provider_id: huggingface-gpu
+    provider_type: inline::huggingface-gpu
+    config:
+      checkpoint_format: huggingface
+      distributed_backend: null
+      device: cpu
+      dpo_output_dir: ~/.llama/distributions/starter-gpu/dpo_output
+  eval:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/meta_reference_eval.db
+  datasetio:
+  - provider_id: huggingface
+    provider_type: remote::huggingface
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/huggingface_datasetio.db
+  - provider_id: localfs
+    provider_type: inline::localfs
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/localfs_datasetio.db
+  scoring:
+  - provider_id: basic
+    provider_type: inline::basic
+  - provider_id: llm-as-judge
+    provider_type: inline::llm-as-judge
+  - provider_id: braintrust
+    provider_type: inline::braintrust
+    config:
+      openai_api_key: ${env.OPENAI_API_KEY:=}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:=}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:=}
+      max_results: 3
+  - provider_id: rag-runtime
+    provider_type: inline::rag-runtime
+  - provider_id: model-context-protocol
+    provider_type: remote::model-context-protocol
+  batches:
+  - provider_id: reference
+    provider_type: inline::reference
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/batches.db
+metadata_store:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/registry.db
+inference_store:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db
+models: []
+shields:
+- shield_id: llama-guard
+  provider_id: ${env.SAFETY_MODEL:+llama-guard}
+  provider_shield_id: ${env.SAFETY_MODEL:=}
+- shield_id: code-scanner
+  provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
+  provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
+vector_dbs: []
+datasets: []
+scoring_fns: []
+benchmarks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::rag
+  provider_id: rag-runtime
+server:
+  port: 8321
22 llama_stack/distributions/starter-gpu/starter_gpu.py Normal file
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+from llama_stack.distributions.template import BuildProvider, DistributionTemplate
+
+from ..starter.starter import get_distribution_template as get_starter_distribution_template
+
+
+def get_distribution_template() -> DistributionTemplate:
+    template = get_starter_distribution_template()
+    name = "starter-gpu"
+    template.name = name
+    template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
+
+    template.providers["post_training"] = [
+        BuildProvider(provider_type="inline::huggingface-gpu"),
+    ]
+    return template
@@ -1,6 +1,7 @@
 version: 2
 distribution_spec:
-  description: Quick start template for running Llama Stack with several popular providers
+  description: Quick start template for running Llama Stack with several popular providers.
+    This distribution is intended for CPU-only environments.
   providers:
     inference:
     - provider_type: remote::cerebras

@@ -34,7 +35,7 @@ distribution_spec:
     telemetry:
     - provider_type: inline::meta-reference
     post_training:
-    - provider_type: inline::huggingface
+    - provider_type: inline::torchtune-cpu
     eval:
     - provider_type: inline::meta-reference
     datasetio:

@@ -156,13 +156,10 @@ providers:
       sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
      otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
   post_training:
-  - provider_id: huggingface
-    provider_type: inline::huggingface
+  - provider_id: torchtune-cpu
+    provider_type: inline::torchtune-cpu
     config:
-      checkpoint_format: huggingface
-      distributed_backend: null
-      device: cpu
-      dpo_output_dir: ~/.llama/distributions/starter/dpo_output
+      checkpoint_format: meta
   eval:
   - provider_id: meta-reference
     provider_type: inline::meta-reference

@@ -120,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
         ],
         "agents": [BuildProvider(provider_type="inline::meta-reference")],
         "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
-        "post_training": [BuildProvider(provider_type="inline::huggingface")],
+        "post_training": [BuildProvider(provider_type="inline::torchtune-cpu")],
         "eval": [BuildProvider(provider_type="inline::meta-reference")],
         "datasetio": [
             BuildProvider(provider_type="remote::huggingface"),

@@ -178,7 +178,7 @@ def get_distribution_template() -> DistributionTemplate:
     return DistributionTemplate(
         name=name,
         distro_type="self_hosted",
-        description="Quick start template for running Llama Stack with several popular providers",
+        description="Quick start template for running Llama Stack with several popular providers. This distribution is intended for CPU-only environments.",
         container_image=None,
         template_path=None,
         providers=providers,
@@ -36,7 +36,7 @@ from .utils import get_negative_inf_value, to_2tuple
 
 MP_SCALE = 8
 
-logger = get_logger(name=__name__, category="models")
+logger = get_logger(name=__name__, category="models::llama")
 
 
 def reduce_from_tensor_model_parallel_region(input_):

@@ -11,7 +11,7 @@ from llama_stack.log import get_logger
 
 from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="models::llama")
 
 BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
 CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")

@@ -18,7 +18,7 @@ from ...datatypes import QuantizationMode
 from ..model import Transformer, TransformerBlock
 from ..moe import MoE
 
-log = get_logger(name=__name__, category="models")
+log = get_logger(name=__name__, category="models::llama")
 
 
 def swiglu_wrapper_no_reduce(

@@ -9,7 +9,7 @@ import collections
 
 from llama_stack.log import get_logger
 
-log = get_logger(name=__name__, category="llama")
+log = get_logger(name=__name__, category="models::llama")
 
 try:
     import fbgemm_gpu.experimental.gen_ai  # noqa: F401
@@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
 WEB_SEARCH_TOOL = "web_search"
 RAG_TOOL_GROUP = "builtin::rag"
 
-logger = get_logger(name=__name__, category="agents")
+logger = get_logger(name=__name__, category="agents::meta_reference")
 
 
 class ChatAgent(ShieldRunnerMixin):

@@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
 from .persistence import AgentInfo
 from .responses.openai_responses import OpenAIResponsesImpl
 
-logger = get_logger(name=__name__, category="agents")
+logger = get_logger(name=__name__, category="agents::meta_reference")
 
 
 class MetaReferenceAgentsImpl(Agents):

@@ -17,7 +17,7 @@ from llama_stack.core.request_headers import get_authenticated_user
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore
 
-log = get_logger(name=__name__, category="agents")
+log = get_logger(name=__name__, category="agents::meta_reference")
 
 
 class AgentSessionInfo(Session):

@@ -41,7 +41,7 @@ from .utils import (
     convert_response_text_to_chat_response_format,
 )
 
-logger = get_logger(name=__name__, category="responses")
+logger = get_logger(name=__name__, category="openai::responses")
 
 
 class OpenAIResponsePreviousResponseWithInputItems(BaseModel):

@@ -47,7 +47,7 @@ from llama_stack.log import get_logger
 from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import convert_chat_choice_to_response_message, is_function_tool_call
 
-logger = get_logger(name=__name__, category="responses")
+logger = get_logger(name=__name__, category="agents::meta_reference")
 
 
 class StreamingResponseOrchestrator:

@@ -38,7 +38,7 @@ from llama_stack.log import get_logger
 
 from .types import ChatCompletionContext, ToolExecutionResult
 
-logger = get_logger(name=__name__, category="responses")
+logger = get_logger(name=__name__, category="agents::meta_reference")
 
 
 class ToolExecutor:
@@ -101,14 +101,22 @@ async def convert_response_input_to_chat_messages(
     """
     messages: list[OpenAIMessageParam] = []
     if isinstance(input, list):
+        # extract all OpenAIResponseInputFunctionToolCallOutput items
+        # so their corresponding OpenAIToolMessageParam instances can
+        # be added immediately following the corresponding
+        # OpenAIAssistantMessageParam
+        tool_call_results = {}
         for input_item in input:
             if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
-                messages.append(
-                    OpenAIToolMessageParam(
-                        content=input_item.output,
-                        tool_call_id=input_item.call_id,
-                    )
+                tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
+                    content=input_item.output,
+                    tool_call_id=input_item.call_id,
                 )
+
+        for input_item in input:
+            if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
+                # skip as these have been extracted and inserted in order
+                pass
             elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
                 tool_call = OpenAIChatCompletionToolCall(
                     index=0,

@@ -119,6 +127,9 @@ async def convert_response_input_to_chat_messages(
                     ),
                 )
                 messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
+                if input_item.call_id in tool_call_results:
+                    messages.append(tool_call_results[input_item.call_id])
+                    del tool_call_results[input_item.call_id]
             elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
                 tool_call = OpenAIChatCompletionToolCall(
                     index=0,

@@ -146,6 +157,10 @@ async def convert_response_input_to_chat_messages(
                     f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
                 )
                 messages.append(message_type(content=content))
+        if len(tool_call_results):
+            raise ValueError(
+                f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
+            )
     else:
         messages.append(OpenAIUserMessageParam(content=input))
     return messages
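To make the ordering guarantee concrete, a sketch with heavily simplified stand-in message dicts (the call ids, function names, and outputs are invented): each tool output is emitted immediately after the assistant message carrying the matching `tool_call_id`, and an output without a matching call is rejected.

```python
# Simplified stand-ins for the real input/message types, to show the ordering only.
calls = [("call_1", "get_weather"), ("call_2", "get_time")]
outputs = {"call_2": "14:05", "call_1": "sunny"}  # arrives in arbitrary order

messages = []
tool_call_results = {
    cid: {"role": "tool", "tool_call_id": cid, "content": out} for cid, out in outputs.items()
}

for call_id, fn_name in calls:
    messages.append({"role": "assistant", "tool_calls": [{"id": call_id, "function": fn_name}]})
    if call_id in tool_call_results:
        messages.append(tool_call_results.pop(call_id))  # output directly follows its call

if tool_call_results:
    raise ValueError(
        f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
    )

print([m["role"] for m in messages])  # ['assistant', 'tool', 'assistant', 'tool']
```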
@@ -11,7 +11,7 @@ from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.telemetry import tracing
 
-log = get_logger(name=__name__, category="agents")
+log = get_logger(name=__name__, category="agents::meta_reference")
 
 
 class SafetyException(Exception):  # noqa: N818
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import hashlib
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
@@ -136,28 +137,45 @@ class ReferenceBatchesImpl(Batches):
         endpoint: str,
         completion_window: Literal["24h"],
         metadata: dict[str, str] | None = None,
+        idempotency_key: str | None = None,
     ) -> BatchObject:
         """
         Create a new batch for processing multiple API requests.

-        Error handling by levels -
-         0. Input param handling, results in 40x errors before processing, e.g.
-          - Wrong completion_window
-          - Invalid metadata types
-          - Unknown endpoint
-          -> no batch created
-         1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
-          - input_file_id missing
-          - invalid json in file
-          - missing custom_id, method, url, body
-          - invalid model
-          - streaming
-          -> batch created, validation sends to failed status
-         2. Processing errors, result in error_file_id entries, e.g.
-          - Any error returned from inference endpoint
-          -> batch created, goes to completed status
+        This implementation provides optional idempotency: when an idempotency key
+        (idempotency_key) is provided, a deterministic ID is generated based on the input
+        parameters. If a batch with the same parameters already exists, it will be
+        returned instead of creating a duplicate. Without an idempotency key,
+        each request creates a new batch with a unique ID.
+
+        Args:
+            input_file_id: The ID of an uploaded file containing requests for the batch.
+            endpoint: The endpoint to be used for all requests in the batch.
+            completion_window: The time window within which the batch should be processed.
+            metadata: Optional metadata for the batch.
+            idempotency_key: Optional idempotency key for enabling idempotent behavior.
+
+        Returns:
+            The created or existing batch object.
         """

+        # Error handling by levels -
+        #  0. Input param handling, results in 40x errors before processing, e.g.
+        #   - Wrong completion_window
+        #   - Invalid metadata types
+        #   - Unknown endpoint
+        #   -> no batch created
+        #  1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
+        #   - input_file_id missing
+        #   - invalid json in file
+        #   - missing custom_id, method, url, body
+        #   - invalid model
+        #   - streaming
+        #   -> batch created, validation sends to failed status
+        #  2. Processing errors, result in error_file_id entries, e.g.
+        #   - Any error returned from inference endpoint
+        #   -> batch created, goes to completed status
+
         # TODO: set expiration time for garbage collection

         if endpoint not in ["/v1/chat/completions"]:
@@ -171,6 +189,35 @@ class ReferenceBatchesImpl(Batches):
             )

         batch_id = f"batch_{uuid.uuid4().hex[:16]}"

+        # For idempotent requests, use the idempotency key for the batch ID
+        # This ensures the same key always maps to the same batch ID,
+        # allowing us to detect parameter conflicts
+        if idempotency_key is not None:
+            hash_input = idempotency_key.encode("utf-8")
+            hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
+            batch_id = f"batch_{hash_digest}"
+
+            try:
+                existing_batch = await self.retrieve_batch(batch_id)
+
+                if (
+                    existing_batch.input_file_id != input_file_id
+                    or existing_batch.endpoint != endpoint
+                    or existing_batch.completion_window != completion_window
+                    or existing_batch.metadata != metadata
+                ):
+                    raise ConflictError(
+                        f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
+                        "Either use a new idempotency key or ensure all parameters match the original request."
+                    )
+
+                logger.info(f"Returning existing batch with ID: {batch_id}")
+                return existing_batch
+            except ResourceNotFoundError:
+                # Batch doesn't exist, continue with creation
+                pass
+
         current_time = int(time.time())

         batch = BatchObject(

@@ -185,6 +232,7 @@ class ReferenceBatchesImpl(Batches):
         )

         await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
+        logger.info(f"Created new batch with ID: {batch_id}")

         if self.process_batches:
             task = asyncio.create_task(self._process_batch(batch_id))
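Below is a minimal usage sketch of the idempotent path added above. The `batches` handle and the `create_batch` method name are assumptions for illustration (the diff shows only the method body), and the file ID, metadata, and key values are placeholders:

```python
async def demo_idempotent_create(batches) -> None:
    # `batches` stands in for an initialized reference Batches provider (assumption).
    kwargs = dict(
        input_file_id="file-abc123",           # placeholder uploaded-file ID
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={"run": "nightly"},
        idempotency_key="nightly-2025-08-01",
    )
    first = await batches.create_batch(**kwargs)
    second = await batches.create_batch(**kwargs)  # same key, identical parameters
    assert first.id == second.id                   # the existing batch is returned

    # Reusing the key with any parameter changed raises ConflictError instead.
```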
@@ -9,7 +9,6 @@ from collections.abc import AsyncGenerator
 from llama_stack.apis.inference import (
     CompletionResponse,
     InferenceProvider,
-    InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,

@@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl(
         tool_config: ToolConfig | None = None,
     ) -> AsyncGenerator:
         raise ValueError("Sentence transformers don't support chat completion")
-
-    async def batch_completion(
-        self,
-        model_id: str,
-        content_batch: list[InterleavedContent],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
-
-    async def batch_chat_completion(
-        self,
-        model_id: str,
-        messages_batch: list[list[Message]],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_config: ToolConfig | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import datetime
 import threading
 from typing import Any
@@ -145,11 +146,41 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         metric_name: str,
         start_time: int,
         end_time: int | None = None,
-        granularity: str | None = "1d",
+        granularity: str | None = None,
         query_type: MetricQueryType = MetricQueryType.RANGE,
         label_matchers: list[MetricLabelMatcher] | None = None,
     ) -> QueryMetricsResponse:
-        raise NotImplementedError("Querying metrics is not implemented")
+        """Query metrics from the telemetry store.
+
+        Args:
+            metric_name: The name of the metric to query (e.g., "prompt_tokens")
+            start_time: Start time as Unix timestamp
+            end_time: End time as Unix timestamp (defaults to now if None)
+            granularity: Time granularity for aggregation
+            query_type: Type of query (RANGE or INSTANT)
+            label_matchers: Label filters to apply
+
+        Returns:
+            QueryMetricsResponse with metric time series data
+        """
+        # Convert timestamps to datetime objects
+        start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC)
+        end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC) if end_time else None
+
+        # Use SQLite trace store if available
+        if hasattr(self, "trace_store") and self.trace_store:
+            return await self.trace_store.query_metrics(
+                metric_name=metric_name,
+                start_time=start_dt,
+                end_time=end_dt,
+                granularity=granularity,
+                query_type=query_type,
+                label_matchers=label_matchers,
+            )
+        else:
+            raise ValueError(
+                f"In order to query_metrics, you must have {TelemetrySink.SQLITE} set in your telemetry sinks"
+            )

     def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
         with self._lock:
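A rough driver for the newly implemented `query_metrics`, mirroring the signature above. The `telemetry` argument stands in for an initialized `TelemetryAdapter`, and the sketch assumes the SQLite sink is configured:

```python
import time


async def show_prompt_tokens(telemetry) -> None:
    now = int(time.time())
    response = await telemetry.query_metrics(
        metric_name="prompt_tokens",  # example metric name from the docstring above
        start_time=now - 3600,        # Unix timestamp: one hour ago
        end_time=None,                # None means "up to now"
    )
    # QueryMetricsResponse carrying the metric time series
    print(response)
```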
@@ -5,9 +5,11 @@
 # the root directory of this source tree.

 from llama_stack.providers.datatypes import (
+    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
+    remote_provider_spec,
 )
 from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages

@@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
             description="Local filesystem-based file storage provider for managing files and documents locally.",
         ),
+        remote_provider_spec(
+            api=Api.files,
+            adapter=AdapterSpec(
+                adapter_type="s3",
+                pip_packages=["boto3"] + sql_store_pip_packages,
+                module="llama_stack.providers.remote.files.s3",
+                config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
+                description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
+            ),
+        ),
     ]
@@ -40,8 +40,9 @@ def available_providers() -> list[ProviderSpec]:
         InlineProviderSpec(
             api=Api.inference,
             provider_type="inline::sentence-transformers",
+            # CrossEncoder depends on torchao.quantization
             pip_packages=[
-                "torch torchvision --index-url https://download.pytorch.org/whl/cpu",
+                "torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu",
                 "sentence-transformers --no-deps",
             ],
             module="llama_stack.providers.inline.inference.sentence_transformers",
@@ -5,27 +5,50 @@
 # the root directory of this source tree.

+from typing import cast
+
 from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec

+# We provide two versions of these providers so that distributions can package the appropriate version of torch.
+# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
+torchtune_def = dict(
+    api=Api.post_training,
+    pip_packages=["numpy"],
+    module="llama_stack.providers.inline.post_training.torchtune",
+    config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
+    api_dependencies=[
+        Api.datasetio,
+        Api.datasets,
+    ],
+    description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
+)
+
+
 def available_providers() -> list[ProviderSpec]:
     return [
         InlineProviderSpec(
-            api=Api.post_training,
-            provider_type="inline::torchtune",
-            pip_packages=["torch", "torchtune==0.5.0", "torchao==0.8.0", "numpy"],
-            module="llama_stack.providers.inline.post_training.torchtune",
-            config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
-            api_dependencies=[
-                Api.datasetio,
-                Api.datasets,
-            ],
-            description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
+            **{  # type: ignore
+                **torchtune_def,
+                "provider_type": "inline::torchtune-cpu",
+                "pip_packages": (
+                    cast(list[str], torchtune_def["pip_packages"])
+                    + ["torch torchtune>=0.5.0 torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu"]
+                ),
+            },
+        ),
+        InlineProviderSpec(
+            **{  # type: ignore
+                **torchtune_def,
+                "provider_type": "inline::torchtune-gpu",
+                "pip_packages": (
+                    cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune>=0.5.0 torchao>=0.12.0"]
+                ),
+            },
         ),
         InlineProviderSpec(
             api=Api.post_training,
-            provider_type="inline::huggingface",
-            pip_packages=["torch", "trl", "transformers", "peft", "datasets"],
+            provider_type="inline::huggingface-gpu",
+            pip_packages=["trl", "transformers", "peft", "datasets", "torch"],
             module="llama_stack.providers.inline.post_training.huggingface",
             config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
             api_dependencies=[
new file: llama_stack/providers/remote/files/s3/README.md (237 lines)
@@ -0,0 +1,237 @@
# S3 Files Provider

A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence.

## Features

- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage
- **Metadata Management**: Uses SQL database for efficient file metadata queries
- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints
- **Flexible Authentication**: Support for IAM roles and access keys
- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services

## Configuration

### Basic Configuration

```yaml
api: files
provider_type: remote::s3
config:
  bucket_name: my-llama-stack-files
  region: us-east-1
  metadata_store:
    type: sqlite
    db_path: ./s3_files_metadata.db
```

### Advanced Configuration

```yaml
api: files
provider_type: remote::s3
config:
  bucket_name: my-llama-stack-files
  region: us-east-1
  aws_access_key_id: YOUR_ACCESS_KEY
  aws_secret_access_key: YOUR_SECRET_KEY
  endpoint_url: https://s3.amazonaws.com  # Optional for custom endpoints
  metadata_store:
    type: sqlite
    db_path: ./s3_files_metadata.db
```

### Environment Variables

The configuration supports environment variable substitution:

```yaml
config:
  bucket_name: "${env.S3_BUCKET_NAME}"
  region: "${env.AWS_REGION:=us-east-1}"
  aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}"
  aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}"
  endpoint_url: "${env.S3_ENDPOINT_URL:=}"
```

Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique.

## Authentication

### IAM Roles (Recommended)

For production deployments, use IAM roles:

```yaml
config:
  bucket_name: my-bucket
  region: us-east-1
  # No credentials needed - will use IAM role
```

### Access Keys

For development or specific use cases:

```yaml
config:
  bucket_name: my-bucket
  region: us-east-1
  aws_access_key_id: AKIAIOSFODNN7EXAMPLE
  aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
```

## S3 Bucket Setup

### Required Permissions

The S3 provider requires the following permissions:

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:PutObject",
        "s3:DeleteObject",
        "s3:ListBucket"
      ],
      "Resource": [
        "arn:aws:s3:::your-bucket-name",
        "arn:aws:s3:::your-bucket-name/*"
      ]
    }
  ]
}
```

### Automatic Bucket Creation

By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:

```yaml
config:
  bucket_name: my-bucket
  auto_create_bucket: true  # Will create bucket if it doesn't exist
  region: us-east-1
```

**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions:

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:PutObject",
        "s3:DeleteObject",
        "s3:ListBucket",
        "s3:CreateBucket"
      ],
      "Resource": [
        "arn:aws:s3:::your-bucket-name",
        "arn:aws:s3:::your-bucket-name/*"
      ]
    }
  ]
}
```

### Bucket Policy (Optional)

For additional security, you can add a bucket policy:

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "LlamaStackAccess",
      "Effect": "Allow",
      "Principal": {
        "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
      },
      "Action": [
        "s3:GetObject",
        "s3:PutObject",
        "s3:DeleteObject"
      ],
      "Resource": "arn:aws:s3:::your-bucket-name/*"
    },
    {
      "Sid": "LlamaStackBucketAccess",
      "Effect": "Allow",
      "Principal": {
        "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
      },
      "Action": [
        "s3:ListBucket"
      ],
      "Resource": "arn:aws:s3:::your-bucket-name"
    }
  ]
}
```

## Features

### Metadata Persistence

File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes:

- File ID
- Original filename
- Purpose (assistants, batch, etc.)
- File size in bytes
- Created and expiration timestamps

### TTL and Cleanup

Files currently have a fixed long expiration time (100 years).

## Development and Testing

### Using MinIO

For self-hosted S3-compatible storage:

```yaml
config:
  bucket_name: test-bucket
  region: us-east-1
  endpoint_url: http://localhost:9000
  aws_access_key_id: minioadmin
  aws_secret_access_key: minioadmin
```

## Monitoring and Logging

The provider logs important operations and errors. For production deployments, consider:

- CloudWatch monitoring for S3 operations
- Custom metrics for file upload/download rates
- Error rate monitoring
- Performance metrics tracking

## Error Handling

The provider handles various error scenarios:

- S3 connectivity issues
- Bucket access permissions
- File not found errors
- Metadata consistency checks

## Known Limitations

- Fixed long TTL (100 years) instead of configurable expiration
- No server-side encryption enabled by default
- No support for AWS session tokens
- No S3 key prefix organization support
- No multipart upload support (all files uploaded as single objects)
new file: llama_stack/providers/remote/files/s3/__init__.py (20 lines)
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.core.datatypes import Api

from .config import S3FilesImplConfig


async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
    from .files import S3FilesImpl

    # TODO: authorization policies and user separation
    impl = S3FilesImpl(config)
    await impl.initialize()
    return impl
new file: llama_stack/providers/remote/files/s3/config.py (42 lines)
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig


class S3FilesImplConfig(BaseModel):
    """Configuration for S3-based files provider."""

    bucket_name: str = Field(description="S3 bucket name to store files")
    region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
    aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
    aws_secret_access_key: str | None = Field(
        default=None, description="AWS secret access key (optional if using IAM roles)"
    )
    endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
    auto_create_bucket: bool = Field(
        default=False, description="Automatically create the S3 bucket if it doesn't exist"
    )
    metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")

    @classmethod
    def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
        return {
            "bucket_name": "${env.S3_BUCKET_NAME}",  # no default, buckets must be globally unique
            "region": "${env.AWS_REGION:=us-east-1}",
            "aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}",
            "aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
            "endpoint_url": "${env.S3_ENDPOINT_URL:=}",
            "auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
            "metadata_store": SqliteSqlStoreConfig.sample_run_config(
                __distro_dir__=__distro_dir__,
                db_name="s3_files_metadata.db",
            ),
        }
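For local tests the same settings can presumably be built in Python from the model above; the MinIO values are placeholders and the `db_path` argument to `SqliteSqlStoreConfig` is an assumption:

```python
from llama_stack.providers.remote.files.s3.config import S3FilesImplConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

config = S3FilesImplConfig(
    bucket_name="test-bucket",
    region="us-east-1",
    endpoint_url="http://localhost:9000",  # MinIO
    aws_access_key_id="minioadmin",
    aws_secret_access_key="minioadmin",
    auto_create_bucket=True,
    metadata_store=SqliteSqlStoreConfig(db_path="./s3_files_metadata.db"),  # assumed field name
)
```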
new file: llama_stack/providers/remote/files/s3/files.py (272 lines)
@@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
import uuid
from typing import Annotated

import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile

from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
    Files,
    ListOpenAIFileResponse,
    OpenAIFileDeleteResponse,
    OpenAIFileObject,
    OpenAIFilePurpose,
)
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl

from .config import S3FilesImplConfig

# TODO: provider data for S3 credentials


def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
    try:
        s3_config = {
            "region_name": config.region,
        }

        # endpoint URL if specified (for MinIO, LocalStack, etc.)
        if config.endpoint_url:
            s3_config["endpoint_url"] = config.endpoint_url

        if config.aws_access_key_id and config.aws_secret_access_key:
            s3_config.update(
                {
                    "aws_access_key_id": config.aws_access_key_id,
                    "aws_secret_access_key": config.aws_secret_access_key,
                }
            )

        return boto3.client("s3", **s3_config)

    except (BotoCoreError, NoCredentialsError) as e:
        raise RuntimeError(f"Failed to initialize S3 client: {e}") from e


async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
    try:
        client.head_bucket(Bucket=config.bucket_name)
    except ClientError as e:
        error_code = e.response["Error"]["Code"]
        if error_code == "404":
            if not config.auto_create_bucket:
                raise RuntimeError(
                    f"S3 bucket '{config.bucket_name}' does not exist. "
                    f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
                ) from e
            try:
                # For us-east-1, we can't specify LocationConstraint
                if config.region == "us-east-1":
                    client.create_bucket(Bucket=config.bucket_name)
                else:
                    client.create_bucket(
                        Bucket=config.bucket_name,
                        CreateBucketConfiguration={"LocationConstraint": config.region},
                    )
            except ClientError as create_error:
                raise RuntimeError(
                    f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
                ) from create_error
        elif error_code == "403":
            raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
        else:
            raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e


class S3FilesImpl(Files):
    """S3-based implementation of the Files API."""

    # TODO: implement expiration, for now a silly offset
    _SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60

    def __init__(self, config: S3FilesImplConfig) -> None:
        self._config = config
        self._client: boto3.client | None = None
        self._sql_store: SqlStore | None = None

    async def initialize(self) -> None:
        self._client = _create_s3_client(self._config)
        await _create_bucket_if_not_exists(self._client, self._config)

        self._sql_store = sqlstore_impl(self._config.metadata_store)
        await self._sql_store.create_table(
            "openai_files",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "filename": ColumnType.STRING,
                "purpose": ColumnType.STRING,
                "bytes": ColumnType.INTEGER,
                "created_at": ColumnType.INTEGER,
                "expires_at": ColumnType.INTEGER,
                # TODO: add s3_etag field for integrity checking
            },
        )

    async def shutdown(self) -> None:
        pass

    @property
    def client(self) -> boto3.client:
        assert self._client is not None, "Provider not initialized"
        return self._client

    @property
    def sql_store(self) -> SqlStore:
        assert self._sql_store is not None, "Provider not initialized"
        return self._sql_store

    async def openai_upload_file(
        self,
        file: Annotated[UploadFile, File()],
        purpose: Annotated[OpenAIFilePurpose, Form()],
    ) -> OpenAIFileObject:
        file_id = f"file-{uuid.uuid4().hex}"

        filename = getattr(file, "filename", None) or "uploaded_file"

        created_at = int(time.time())
        expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
        content = await file.read()
        file_size = len(content)

        await self.sql_store.insert(
            "openai_files",
            {
                "id": file_id,
                "filename": filename,
                "purpose": purpose.value,
                "bytes": file_size,
                "created_at": created_at,
                "expires_at": expires_at,
            },
        )

        try:
            self.client.put_object(
                Bucket=self._config.bucket_name,
                Key=file_id,
                Body=content,
                # TODO: enable server-side encryption
            )
        except ClientError as e:
            await self.sql_store.delete("openai_files", where={"id": file_id})

            raise RuntimeError(f"Failed to upload file to S3: {e}") from e

        return OpenAIFileObject(
            id=file_id,
            filename=filename,
            purpose=purpose,
            bytes=file_size,
            created_at=created_at,
            expires_at=expires_at,
        )

    async def openai_list_files(
        self,
        after: str | None = None,
        limit: int | None = 10000,
        order: Order | None = Order.desc,
        purpose: OpenAIFilePurpose | None = None,
    ) -> ListOpenAIFileResponse:
        # this purely defensive. it should not happen because the router also default to Order.desc.
        if not order:
            order = Order.desc

        where_conditions = {}
        if purpose:
            where_conditions["purpose"] = purpose.value

        paginated_result = await self.sql_store.fetch_all(
            table="openai_files",
            where=where_conditions if where_conditions else None,
            order_by=[("created_at", order.value)],
            cursor=("id", after) if after else None,
            limit=limit,
        )

        files = [
            OpenAIFileObject(
                id=row["id"],
                filename=row["filename"],
                purpose=OpenAIFilePurpose(row["purpose"]),
                bytes=row["bytes"],
                created_at=row["created_at"],
                expires_at=row["expires_at"],
            )
            for row in paginated_result.data
        ]

        return ListOpenAIFileResponse(
            data=files,
            has_more=paginated_result.has_more,
            # empty string or None? spec says str, ref impl returns str | None, we go with spec
            first_id=files[0].id if files else "",
            last_id=files[-1].id if files else "",
        )

    async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
        if not row:
            raise ResourceNotFoundError(file_id, "File", "files.list()")

        return OpenAIFileObject(
            id=row["id"],
            filename=row["filename"],
            purpose=OpenAIFilePurpose(row["purpose"]),
            bytes=row["bytes"],
            created_at=row["created_at"],
            expires_at=row["expires_at"],
        )

    async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
        if not row:
            raise ResourceNotFoundError(file_id, "File", "files.list()")

        try:
            self.client.delete_object(
                Bucket=self._config.bucket_name,
                Key=row["id"],
            )
        except ClientError as e:
            if e.response["Error"]["Code"] != "NoSuchKey":
                raise RuntimeError(f"Failed to delete file from S3: {e}") from e

        await self.sql_store.delete("openai_files", where={"id": file_id})

        return OpenAIFileDeleteResponse(id=file_id, deleted=True)

    async def openai_retrieve_file_content(self, file_id: str) -> Response:
        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
        if not row:
            raise ResourceNotFoundError(file_id, "File", "files.list()")

        try:
            response = self.client.get_object(
                Bucket=self._config.bucket_name,
                Key=row["id"],
            )
            # TODO: can we stream this instead of loading it into memory
            content = response["Body"].read()
        except ClientError as e:
            if e.response["Error"]["Code"] == "NoSuchKey":
                await self.sql_store.delete("openai_files", where={"id": file_id})
                raise ResourceNotFoundError(file_id, "File", "files.list()") from e
            raise RuntimeError(f"Failed to download file from S3: {e}") from e

        return Response(
            content=content,
            media_type="application/octet-stream",
            headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
        )
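Because the provider implements the OpenAI-compatible Files API, a client-side sketch could look like the following. The base URL, API key, and file name are placeholders, and the exact OpenAI-compatible route prefix depends on how the stack is served:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="unused")  # placeholder endpoint

with open("report.pdf", "rb") as f:
    uploaded = client.files.create(file=f, purpose="assistants")

print(uploaded.id, uploaded.bytes)                   # the ID doubles as the S3 object key
print([item.id for item in client.files.list().data])

content = client.files.content(uploaded.id)          # file bytes served from the S3 object
client.files.delete(uploaded.id)                     # removes the object and its metadata row
```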
@@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import FireworksImplConfig
 from .models import MODEL_ENTRIES

-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::fireworks")


 class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
@@ -10,7 +10,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

 from .models import MODEL_ENTRIES

-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::llama_openai_compat")


 class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
@@ -41,6 +41,11 @@ client.initialize()

 ### Create Completion

+> Note on Completion API
+>
+> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` do not support the ```completion``` method, while the locally deployed NIM does.
+
 ```python
 response = client.inference.completion(
     model_id="meta-llama/Llama-3.1-8B-Instruct",
@@ -76,6 +81,73 @@ response = client.inference.chat_completion(
 print(f"Response: {response.completion_message.content}")
 ```

+### Tool Calling Example ###
+```python
+from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
+
+tool_definition = ToolDefinition(
+    tool_name="get_weather",
+    description="Get current weather information for a location",
+    parameters={
+        "location": ToolParamDefinition(
+            param_type="string",
+            description="The city and state, e.g. San Francisco, CA",
+            required=True,
+        ),
+        "unit": ToolParamDefinition(
+            param_type="string",
+            description="Temperature unit (celsius or fahrenheit)",
+            required=False,
+            default="celsius",
+        ),
+    },
+)
+
+tool_response = client.inference.chat_completion(
+    model_id="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
+    tools=[tool_definition],
+)
+
+print(f"Tool Response: {tool_response.completion_message.content}")
+if tool_response.completion_message.tool_calls:
+    for tool_call in tool_response.completion_message.tool_calls:
+        print(f"Tool Called: {tool_call.tool_name}")
+        print(f"Arguments: {tool_call.arguments}")
+```
+
+### Structured Output Example
+```python
+from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
+
+person_schema = {
+    "type": "object",
+    "properties": {
+        "name": {"type": "string"},
+        "age": {"type": "integer"},
+        "occupation": {"type": "string"},
+    },
+    "required": ["name", "age", "occupation"],
+}
+
+response_format = JsonSchemaResponseFormat(
+    type=ResponseFormatType.json_schema, json_schema=person_schema
+)
+
+structured_response = client.inference.chat_completion(
+    model_id="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[
+        {
+            "role": "user",
+            "content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
+        }
+    ],
+    response_format=response_format,
+)
+
+print(f"Structured Response: {structured_response.completion_message.content}")
+```
+
 ### Create Embeddings
 > Note on OpenAI embeddings compatibility
 >
@@ -7,7 +7,7 @@
 import warnings
 from collections.abc import AsyncIterator

-from openai import NOT_GIVEN, APIConnectionError, BadRequestError
+from openai import NOT_GIVEN, APIConnectionError

 from llama_stack.apis.common.content_types import (
     InterleavedContent,

@@ -57,7 +57,7 @@ from .openai_utils import (
 )
 from .utils import _is_nvidia_hosted

-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::nvidia")


 class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
@@ -197,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
             }
             extra_body["input_type"] = task_type_options[task_type]

-        try:
-            response = await self.client.embeddings.create(
-                model=provider_model_id,
-                input=input,
-                extra_body=extra_body,
-            )
-        except BadRequestError as e:
-            raise ValueError(f"Failed to get embeddings: {e}") from e
+        response = await self.client.embeddings.create(
+            model=provider_model_id,
+            input=input,
+            extra_body=extra_body,
+        )

         #
         # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
         # ->
@@ -10,7 +10,7 @@ from llama_stack.log import get_logger

 from . import NVIDIAConfig

-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::nvidia")


 def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
@@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .models import MODEL_ENTRIES

-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::ollama")


 class OllamaInferenceAdapter(
@@ -619,28 +619,6 @@ class OllamaInferenceAdapter(
         response.id = id
         return response

-    async def batch_completion(
-        self,
-        model_id: str,
-        content_batch: list[InterleavedContent],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch completion is not supported for Ollama")
-
-    async def batch_chat_completion(
-        self,
-        model_id: str,
-        messages_batch: list[list[Message]],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_config: ToolConfig | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch chat completion is not supported for Ollama")
-
-
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
     async def _convert_content(content) -> dict:
@@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from .config import OpenAIConfig
 from .models import MODEL_ENTRIES

-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::openai")


 #
Some files were not shown because too many files have changed in this diff.