Forked from phoenix-oss/llama-stack-mirror

Commit 0e2a13da9c: Merge branch 'main' into pr1573
39 changed files with 5311 additions and 407 deletions
.github/workflows/integration-tests.yml (new file, vendored, +80)

```yaml
name: Integration tests

on:
  pull_request:
  push:
    branches: [main]

jobs:
  ollama:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          python-version: "3.10"

      - name: Install Ollama
        run: |
          curl -fsSL https://ollama.com/install.sh | sh

      - name: Pull Ollama image
        run: |
          ollama pull llama3.2:3b-instruct-fp16

      - name: Start Ollama in background
        run: |
          nohup ollama run llama3.2:3b-instruct-fp16 > ollama.log 2>&1 &

      - name: Set Up Environment and Install Dependencies
        run: |
          uv sync --extra dev --extra test
          uv pip install ollama faiss-cpu
          uv pip install -e .

      - name: Wait for Ollama to start
        run: |
          echo "Waiting for Ollama..."
          for i in {1..30}; do
            if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
              echo "Ollama is running!"
              exit 0
            fi
            sleep 1
          done
          echo "Ollama failed to start"
          ollama ps
          cat ollama.log
          exit 1

      - name: Start Llama Stack server in background
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
          source .venv/bin/activate
          # TODO: use "llama stack run"
          nohup uv run python -m llama_stack.distribution.server.server --yaml-config ./llama_stack/templates/ollama/run.yaml > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
        run: |
          echo "Waiting for Llama Stack server..."
          for i in {1..30}; do
            if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
              echo "Llama Stack server is up!"
              exit 0
            fi
            sleep 1
          done
          echo "Llama Stack server failed to start"
          cat server.log
          exit 1

      - name: Run Inference Integration Tests
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
          uv run pytest -v tests/integration/inference --stack-config=ollama --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2
```
.github/workflows/providers-build.yml (vendored, 2 changed lines)

```diff
@@ -51,7 +51,7 @@ jobs:
           python-version: '3.10'

       - name: Install uv
-        uses: astral-sh/setup-uv@v4
+        uses: astral-sh/setup-uv@v5
         with:
           python-version: "3.10"
```
.github/workflows/stale_bot.yml (new file, vendored, +45)

```yaml
name: Close stale issues and PRs

on:
  schedule:
    - cron: '0 0 * * *' # every day at midnight

env:
  LC_ALL: en_US.UTF-8

defaults:
  run:
    shell: bash

permissions:
  contents: read

jobs:
  stale:
    permissions:
      issues: write
      pull-requests: write
    runs-on: ubuntu-latest
    steps:
      - name: Stale Action
        uses: actions/stale@v9
        with:
          stale-issue-label: 'stale'
          stale-issue-message: >
            This issue has been automatically marked as stale because it has not had activity within 60 days.
            It will be automatically closed if no further activity occurs within 30 days.
          close-issue-message: >
            This issue has been automatically closed due to inactivity.
            Please feel free to reopen if you feel it is still relevant!
          days-before-issue-stale: 60
          days-before-issue-close: 30
          stale-pr-label: 'stale'
          stale-pr-message: >
            This pull request has been automatically marked as stale because it has not had activity within 60 days.
            It will be automatically closed if no further activity occurs within 30 days.
          close-pr-message: >
            This pull request has been automatically closed due to inactivity.
            Please feel free to reopen if you intend to continue working on it!
          days-before-pr-stale: 60
          days-before-pr-close: 30
          operations-per-run: 300
```
.github/workflows/unit-tests.yml (vendored, 2 changed lines)

```diff
@@ -33,7 +33,7 @@ jobs:

       - name: Run unit tests
         run: |
-          uv run --python ${{ matrix.python }} --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --cov=llama_stack -s -v tests/unit/ --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }}
+          PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --cov=llama_stack --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }}

      - name: Upload test results
        if: always()
```
.pre-commit-config.yaml

```diff
@@ -8,6 +8,7 @@ repos:
     rev: v5.0.0  # Latest stable version
     hooks:
       - id: check-merge-conflict
+        args: ['--assume-in-merge']
       - id: trailing-whitespace
         exclude: '\.py$'  # Exclude Python files as Ruff already handles them
       - id: check-added-large-files
@@ -82,6 +83,17 @@ repos:
         require_serial: true
         files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$

+  - repo: local
+    hooks:
+      - id: openapi-codegen
+        name: API Spec Codegen
+        additional_dependencies:
+          - uv==0.6.2
+        entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null 2>&1'
+        language: python
+        pass_filenames: false
+        require_serial: true
+
 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
   autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate
```
CONTRIBUTING.md

````diff
@@ -108,6 +108,22 @@ uv run pre-commit run --all-files
 > [!CAUTION]
 > Before pushing your changes, make sure that the pre-commit hooks have passed successfully.

+## Running unit tests
+
+You can run the unit tests by running:
+
+```bash
+source .venv/bin/activate
+./scripts/unit-tests.sh
+```
+
+If you'd like to run for a non-default version of Python (currently 3.10), pass `PYTHON_VERSION` variable as follows:
+
+```
+source .venv/bin/activate
+PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
+```
+
 ## Adding a new dependency to the project

 To add a new dependency to the project, you can use the `uv` command. For example, to add `foo` to the project, you can run:
````
README.md

```diff
@@ -51,6 +51,10 @@ Here is a list of the various API providers and available distributions that can
 | PG Vector | Single Node | | | ✅ | | |
 | PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
 | vLLM | Hosted and Single Node | | ✅ | | | |
+| OpenAI | Hosted | | ✅ | | | |
+| Anthropic | Hosted | | ✅ | | | |
+| Gemini | Hosted | | ✅ | | | |


 ### Distributions
```
docs/_static/llama-stack-spec.html (vendored, 117 changed lines)

```diff
@@ -2092,6 +2092,48 @@
             }
           }
         },
+        "/v1/providers/{provider_id}": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "OK",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/ProviderInfo"
+                                }
+                            }
+                        }
+                    },
+                    "400": {
+                        "$ref": "#/components/responses/BadRequest400"
+                    },
+                    "429": {
+                        "$ref": "#/components/responses/TooManyRequests429"
+                    },
+                    "500": {
+                        "$ref": "#/components/responses/InternalServerError500"
+                    },
+                    "default": {
+                        "$ref": "#/components/responses/DefaultError"
+                    }
+                },
+                "tags": [
+                    "Providers"
+                ],
+                "description": "",
+                "parameters": [
+                    {
+                        "name": "provider_id",
+                        "in": "path",
+                        "required": true,
+                        "schema": {
+                            "type": "string"
+                        }
+                    }
+                ]
+            }
+        },
         "/v1/tool-runtime/invoke": {
             "post": {
                 "responses": {
@@ -2660,7 +2702,7 @@
             }
           }
         },
-        "/v1/inspect/providers": {
+        "/v1/providers": {
             "get": {
                 "responses": {
                     "200": {
@@ -7965,6 +8007,53 @@
             ],
             "title": "InsertChunksRequest"
         },
+        "ProviderInfo": {
+            "type": "object",
+            "properties": {
+                "api": {
+                    "type": "string"
+                },
+                "provider_id": {
+                    "type": "string"
+                },
+                "provider_type": {
+                    "type": "string"
+                },
+                "config": {
+                    "type": "object",
+                    "additionalProperties": {
+                        "oneOf": [
+                            {
+                                "type": "null"
+                            },
+                            {
+                                "type": "boolean"
+                            },
+                            {
+                                "type": "number"
+                            },
+                            {
+                                "type": "string"
+                            },
+                            {
+                                "type": "array"
+                            },
+                            {
+                                "type": "object"
+                            }
+                        ]
+                    }
+                }
+            },
+            "additionalProperties": false,
+            "required": [
+                "api",
+                "provider_id",
+                "provider_type",
+                "config"
+            ],
+            "title": "ProviderInfo"
+        },
         "InvokeToolRequest": {
             "type": "object",
             "properties": {
@@ -8226,27 +8315,6 @@
             ],
             "title": "ListModelsResponse"
         },
-        "ProviderInfo": {
-            "type": "object",
-            "properties": {
-                "api": {
-                    "type": "string"
-                },
-                "provider_id": {
-                    "type": "string"
-                },
-                "provider_type": {
-                    "type": "string"
-                }
-            },
-            "additionalProperties": false,
-            "required": [
-                "api",
-                "provider_id",
-                "provider_type"
-            ],
-            "title": "ProviderInfo"
-        },
         "ListProvidersResponse": {
             "type": "object",
             "properties": {
@@ -10246,6 +10314,10 @@
         {
             "name": "PostTraining (Coming Soon)"
         },
+        {
+            "name": "Providers",
+            "x-displayName": "Providers API for inspecting, listing, and modifying providers and their configurations."
+        },
         {
             "name": "Safety"
         },
@@ -10292,6 +10364,7 @@
                 "Inspect",
                 "Models",
                 "PostTraining (Coming Soon)",
+                "Providers",
                 "Safety",
                 "Scoring",
                 "ScoringFunctions",
```
docs/_static/llama-stack-spec.yaml (vendored, 75 changed lines)

```diff
@@ -1400,6 +1400,34 @@ paths:
             schema:
               $ref: '#/components/schemas/InsertChunksRequest'
         required: true
+  /v1/providers/{provider_id}:
+    get:
+      responses:
+        '200':
+          description: OK
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ProviderInfo'
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: >-
+            #/components/responses/TooManyRequests429
+        '500':
+          $ref: >-
+            #/components/responses/InternalServerError500
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - Providers
+      description: ''
+      parameters:
+        - name: provider_id
+          in: path
+          required: true
+          schema:
+            type: string
   /v1/tool-runtime/invoke:
     post:
       responses:
@@ -1792,7 +1820,7 @@ paths:
             schema:
               $ref: '#/components/schemas/RegisterModelRequest'
         required: true
-  /v1/inspect/providers:
+  /v1/providers:
     get:
       responses:
         '200':
@@ -5450,6 +5478,32 @@ components:
         - vector_db_id
         - chunks
       title: InsertChunksRequest
+    ProviderInfo:
+      type: object
+      properties:
+        api:
+          type: string
+        provider_id:
+          type: string
+        provider_type:
+          type: string
+        config:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
+      additionalProperties: false
+      required:
+        - api
+        - provider_id
+        - provider_type
+        - config
+      title: ProviderInfo
     InvokeToolRequest:
       type: object
       properties:
@@ -5613,21 +5667,6 @@ components:
       required:
         - data
       title: ListModelsResponse
-    ProviderInfo:
-      type: object
-      properties:
-        api:
-          type: string
-        provider_id:
-          type: string
-        provider_type:
-          type: string
-      additionalProperties: false
-      required:
-        - api
-        - provider_id
-        - provider_type
-      title: ProviderInfo
     ListProvidersResponse:
       type: object
       properties:
@@ -6921,6 +6960,9 @@ tags:
   - name: Inspect
   - name: Models
   - name: PostTraining (Coming Soon)
+  - name: Providers
+    x-displayName: >-
+      Providers API for inspecting, listing, and modifying providers and their configurations.
   - name: Safety
   - name: Scoring
   - name: ScoringFunctions
@@ -6945,6 +6987,7 @@ x-tagGroups:
       - Inspect
      - Models
      - PostTraining (Coming Soon)
+      - Providers
      - Safety
      - Scoring
      - ScoringFunctions
```
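Together, the two spec changes move provider listing from `/v1/inspect/providers` to `/v1/providers` and add a per-provider inspection route at `/v1/providers/{provider_id}`. A minimal smoke test of the new routes, as a sketch: it assumes a stack running locally on port 8321 (the port the integration-test workflow above uses) and uses `httpx`, which is not part of this diff, as the HTTP client.

```python
# Sketch only: poke the new Providers routes on a locally running stack.
# Assumptions: server on localhost:8321 (as in the CI workflow above); httpx installed.
import httpx

BASE = "http://localhost:8321"

# GET /v1/providers -> ListProvidersResponse: {"data": [ProviderInfo, ...]}
resp = httpx.get(f"{BASE}/v1/providers")
resp.raise_for_status()
providers = resp.json()["data"]
for p in providers:
    print(p["api"], p["provider_id"], p["provider_type"])

# GET /v1/providers/{provider_id} -> a single ProviderInfo, config included
if providers:
    detail = httpx.get(f"{BASE}/v1/providers/{providers[0]['provider_id']}")
    detail.raise_for_status()
    print(detail.json()["config"])
```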
```diff
@@ -61,6 +61,10 @@ A number of "adapters" are available for some popular Inference and Vector Store
 | Groq | Hosted |
 | SambaNova | Hosted |
 | PyTorch ExecuTorch | On-device iOS, Android |
+| OpenAI | Hosted |
+| Anthropic | Hosted |
+| Gemini | Hosted |


 **Vector IO API**
 | **Provider** | **Environments** |
```
```diff
@@ -14,6 +14,7 @@ from llama_stack.schema_utils import json_schema_type

 @json_schema_type
 class Api(Enum):
+    providers = "providers"
     inference = "inference"
     safety = "safety"
     agents = "agents"
```
```diff
@@ -11,13 +11,6 @@ from pydantic import BaseModel
 from llama_stack.schema_utils import json_schema_type, webmethod


-@json_schema_type
-class ProviderInfo(BaseModel):
-    api: str
-    provider_id: str
-    provider_type: str
-
-
 @json_schema_type
 class RouteInfo(BaseModel):
     route: str
@@ -32,14 +25,21 @@ class HealthInfo(BaseModel):


 @json_schema_type
-class VersionInfo(BaseModel):
-    version: str
+class ProviderInfo(BaseModel):
+    api: str
+    provider_id: str
+    provider_type: str


 class ListProvidersResponse(BaseModel):
     data: List[ProviderInfo]


+@json_schema_type
+class VersionInfo(BaseModel):
+    version: str
+
+
 class ListRoutesResponse(BaseModel):
     data: List[RouteInfo]
```
llama_stack/apis/providers/__init__.py (new file, +7)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .providers import *  # noqa: F401 F403
```
llama_stack/apis/providers/providers.py (new file, +36)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, List, Protocol, runtime_checkable

from pydantic import BaseModel

from llama_stack.schema_utils import json_schema_type, webmethod


@json_schema_type
class ProviderInfo(BaseModel):
    api: str
    provider_id: str
    provider_type: str
    config: Dict[str, Any]


class ListProvidersResponse(BaseModel):
    data: List[ProviderInfo]


@runtime_checkable
class Providers(Protocol):
    """
    Providers API for inspecting, listing, and modifying providers and their configurations.
    """

    @webmethod(route="/providers", method="GET")
    async def list_providers(self) -> ListProvidersResponse: ...

    @webmethod(route="/providers/{provider_id}", method="GET")
    async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
```
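The `Providers` protocol above is what both the HTTP server and the library client expose. Below is a sketch of driving it through the client SDK; only the `providers.list()` accessor is confirmed by `tests/integration/providers/test_providers.py` at the end of this diff, while the `"ollama"` template name and the `initialize()` step are assumptions from standard library-client usage, not part of this change.

```python
# Sketch: list providers via the library client. providers.list() is the call
# exercised by the integration test added in this diff; the template name and
# initialize() step are assumptions.
from llama_stack import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("ollama")  # assumed distribution template
client.initialize()

for provider in client.providers.list():
    print(provider.api, provider.provider_id, provider.provider_type)
```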
```diff
@@ -10,7 +10,7 @@ import json
 import os
 import shutil
 from dataclasses import dataclass
-from datetime import datetime
+from datetime import datetime, timezone
 from functools import partial
 from pathlib import Path
 from typing import Dict, List, Optional
@@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
         d = json.load(f)
         manifest = Manifest(**d)

-    if datetime.now() > manifest.expires_on:
+    if datetime.now(timezone.utc) > manifest.expires_on:
         raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")

     console = Console()
```
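This is the first of many `timezone.utc` changes in this diff (the agents, telemetry, and post-training files below get the same treatment). Beyond consistency, there is a correctness angle: Python refuses to compare naive and aware datetimes, so if `manifest.expires_on` parses as timezone-aware, the old naive `datetime.now()` comparison raises instead of checking expiry. A self-contained illustration:

```python
from datetime import datetime, timezone

# An aware timestamp, like a manifest expiry parsed from ISO-8601 with an offset
expires_on = datetime(2030, 1, 1, tzinfo=timezone.utc)

try:
    datetime.now() > expires_on  # naive vs aware: not allowed
except TypeError as e:
    print(e)  # "can't compare offset-naive and offset-aware datetimes"

print(datetime.now(timezone.utc) > expires_on)  # aware vs aware: works (False)
```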
```diff
@@ -41,8 +41,8 @@ class ModelPromptFormat(Subcommand):
             "-m",
             "--model-name",
             type=str,
-            default="llama3_1",
-            help="Model Family (llama3_1, llama3_X, etc.)",
+            help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
+            "(Run `llama model list` to see a list of valid model names)",
         )
         self.parser.add_argument(
             "-l",
@@ -60,7 +60,6 @@ class ModelPromptFormat(Subcommand):
         ]

         model_list = [m.value for m in supported_model_ids]
-        model_str = "\n".join(model_list)

         if args.list:
             headers = ["Model(s)"]
@@ -81,10 +80,16 @@ class ModelPromptFormat(Subcommand):
         try:
             model_id = CoreModelId(args.model_name)
         except ValueError:
-            self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
+            self.parser.error(
+                f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
+                f"Run `llama model list` to see the valid model names."
+            )

         if model_id not in supported_model_ids:
-            self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
+            self.parser.error(
+                f"{model_id} is not a valid Model. Choose one from the list of valid models. "
+                f"Run `llama model list` to see the valid model names."
+            )

         llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
         llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
```
```diff
@@ -62,7 +62,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
     if config.apis:
         apis_to_serve = config.apis
     else:
-        apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
+        apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect, Api.providers)]

     for api_str in apis_to_serve:
         api = Api(api_str)
```
```diff
@@ -56,7 +56,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:

 def providable_apis() -> List[Api]:
     routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
-    return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
+    return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]


 def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
```
llama_stack/distribution/providers.py (new file, +59)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from pydantic import BaseModel

from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers

from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields


class ProviderImplConfig(BaseModel):
    run_config: StackRunConfig


async def get_provider_impl(config, deps):
    impl = ProviderImpl(config, deps)
    await impl.initialize()
    return impl


class ProviderImpl(Providers):
    def __init__(self, config, deps):
        self.config = config
        self.deps = deps

    async def initialize(self) -> None:
        pass

    async def list_providers(self) -> ListProvidersResponse:
        run_config = self.config.run_config
        safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
        ret = []
        for api, providers in safe_config.providers.items():
            ret.extend(
                [
                    ProviderInfo(
                        api=api,
                        provider_id=p.provider_id,
                        provider_type=p.provider_type,
                        config=p.config,
                    )
                    for p in providers
                ]
            )

        return ListProvidersResponse(data=ret)

    async def inspect_provider(self, provider_id: str) -> ProviderInfo:
        all_providers = await self.list_providers()
        for p in all_providers.data:
            if p.provider_id == provider_id:
                return p

        raise ValueError(f"Provider {provider_id} not found")
```
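Note that `list_providers` passes the run config through `redact_sensitive_fields` before building the response, so provider configs returned over the new API should not leak secrets. That helper lives in `llama_stack.distribution.stack` and is not part of this diff; the sketch below only illustrates the intended behavior, with the key heuristics and mask string assumed rather than taken from the real implementation:

```python
# Illustrative sketch of config redaction -- NOT the actual
# redact_sensitive_fields from llama_stack.distribution.stack.
from typing import Any, Dict

SENSITIVE_MARKERS = ("key", "token", "password", "secret")  # assumed heuristics


def redact_sketch(data: Dict[str, Any]) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for k, v in data.items():
        if isinstance(v, dict):
            out[k] = redact_sketch(v)  # recurse into nested config sections
        elif any(marker in k.lower() for marker in SENSITIVE_MARKERS):
            out[k] = "********"  # mask values under secret-looking keys
        else:
            out[k] = v
    return out


print(redact_sketch({"url": "http://localhost:11434", "api_key": "sk-123"}))
# -> {'url': 'http://localhost:11434', 'api_key': '********'}
```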
```diff
@@ -16,6 +16,7 @@ from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
+from llama_stack.apis.providers import Providers as ProvidersAPI
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
@@ -59,6 +60,7 @@ class InvalidProviderError(Exception):

 def api_protocol_map() -> Dict[Api, Any]:
     return {
+        Api.providers: ProvidersAPI,
         Api.agents: Agents,
         Api.inference: Inference,
         Api.inspect: Inspect,
@@ -247,6 +249,25 @@ def sort_providers_by_deps(
             )
         )

+    sorted_providers.append(
+        (
+            "providers",
+            ProviderWithSpec(
+                provider_id="__builtin__",
+                provider_type="__builtin__",
+                config={"run_config": run_config.model_dump()},
+                spec=InlineProviderSpec(
+                    api=Api.providers,
+                    provider_type="__builtin__",
+                    config_class="llama_stack.distribution.providers.ProviderImplConfig",
+                    module="llama_stack.distribution.providers",
+                    api_dependencies=apis,
+                    deps__=[x.value for x in apis],
+                ),
+            ),
+        )
+    )
+
     logger.debug(f"Resolved {len(sorted_providers)} providers")
     for api_str, provider in sorted_providers:
         logger.debug(f"  {api_str} => {provider.provider_id}")
```
```diff
@@ -368,6 +368,7 @@ def main():
             apis_to_serve.add(inf.routing_table_api.value)

     apis_to_serve.add("inspect")
+    apis_to_serve.add("providers")
     for api_str in apis_to_serve:
         api = Api(api_str)
```
```diff
@@ -23,6 +23,7 @@ from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
+from llama_stack.apis.providers import Providers
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
@@ -44,6 +45,7 @@ logger = get_logger(name=__name__, category="core")


 class LlamaStack(
+    Providers,
     VectorDBs,
     Inference,
     BatchInference,
```
```diff
@@ -34,7 +34,9 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
         )
         return PromptTemplate(
             template_str.lstrip("\n"),
-            {"today": datetime.now().strftime("%d %B %Y")},
+            {
+                "today": datetime.now().strftime("%d %B %Y")  # noqa: DTZ005 - we don't care about timezones here since we are displaying the date
+            },
         )

     def data_examples(self) -> List[Any]:
```
```diff
@@ -11,7 +11,7 @@ import re
 import secrets
 import string
 import uuid
-from datetime import datetime
+from datetime import datetime, timezone
 from typing import AsyncGenerator, List, Optional, Union
 from urllib.parse import urlparse

@@ -239,7 +239,7 @@ class ChatAgent(ShieldRunnerMixin):
             in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
                 request.session_id, request.turn_id
             )
-            now = datetime.now().astimezone().isoformat()
+            now = datetime.now(timezone.utc).isoformat()
             tool_execution_step = ToolExecutionStep(
                 step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
                 turn_id=request.turn_id,
@@ -264,7 +264,7 @@ class ChatAgent(ShieldRunnerMixin):
             start_time = last_turn.started_at
         else:
             messages.extend(request.messages)
-            start_time = datetime.now().astimezone().isoformat()
+            start_time = datetime.now(timezone.utc).isoformat()
             input_messages = request.messages

         output_message = None
@@ -295,7 +295,7 @@ class ChatAgent(ShieldRunnerMixin):
             input_messages=input_messages,
             output_message=output_message,
             started_at=start_time,
-            completed_at=datetime.now().astimezone().isoformat(),
+            completed_at=datetime.now(timezone.utc).isoformat(),
             steps=steps,
         )
         await self.storage.add_turn_to_session(request.session_id, turn)
@@ -386,7 +386,7 @@ class ChatAgent(ShieldRunnerMixin):
             return

         step_id = str(uuid.uuid4())
-        shield_call_start_time = datetime.now().astimezone().isoformat()
+        shield_call_start_time = datetime.now(timezone.utc).isoformat()
         try:
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
@@ -410,7 +410,7 @@ class ChatAgent(ShieldRunnerMixin):
                         turn_id=turn_id,
                         violation=e.violation,
                         started_at=shield_call_start_time,
-                        completed_at=datetime.now().astimezone().isoformat(),
+                        completed_at=datetime.now(timezone.utc).isoformat(),
                     ),
                 )
             )
@@ -433,7 +433,7 @@ class ChatAgent(ShieldRunnerMixin):
                         turn_id=turn_id,
                         violation=None,
                         started_at=shield_call_start_time,
-                        completed_at=datetime.now().astimezone().isoformat(),
+                        completed_at=datetime.now(timezone.utc).isoformat(),
                     ),
                 )
             )
@@ -472,7 +472,7 @@ class ChatAgent(ShieldRunnerMixin):
                 client_tools[tool.name] = tool
         while True:
             step_id = str(uuid.uuid4())
-            inference_start_time = datetime.now().astimezone().isoformat()
+            inference_start_time = datetime.now(timezone.utc).isoformat()
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepStartPayload(
@@ -582,7 +582,7 @@ class ChatAgent(ShieldRunnerMixin):
                             turn_id=turn_id,
                             model_response=copy.deepcopy(message),
                             started_at=inference_start_time,
-                            completed_at=datetime.now().astimezone().isoformat(),
+                            completed_at=datetime.now(timezone.utc).isoformat(),
                         ),
                     )
                 )
@@ -653,7 +653,7 @@ class ChatAgent(ShieldRunnerMixin):
                             turn_id=turn_id,
                             tool_calls=[tool_call],
                             tool_responses=[],
-                            started_at=datetime.now().astimezone().isoformat(),
+                            started_at=datetime.now(timezone.utc).isoformat(),
                         ),
                     )
                 )
                 yield message
@@ -670,7 +670,7 @@ class ChatAgent(ShieldRunnerMixin):
                     "input": message.model_dump_json(),
                 },
             ) as span:
-                tool_execution_start_time = datetime.now().astimezone().isoformat()
+                tool_execution_start_time = datetime.now(timezone.utc).isoformat()
                 tool_call = message.tool_calls[0]
                 tool_result = await self.execute_tool_call_maybe(
                     session_id,
@@ -708,7 +708,7 @@ class ChatAgent(ShieldRunnerMixin):
                             )
                         ],
                         started_at=tool_execution_start_time,
-                        completed_at=datetime.now().astimezone().isoformat(),
+                        completed_at=datetime.now(timezone.utc).isoformat(),
                     ),
                 )
             )
```
```diff
@@ -7,7 +7,7 @@
 import json
 import logging
 import uuid
-from datetime import datetime
+from datetime import datetime, timezone
 from typing import List, Optional

 from pydantic import BaseModel
@@ -36,7 +36,7 @@ class AgentPersistence:
         session_info = AgentSessionInfo(
             session_id=session_id,
             session_name=name,
-            started_at=datetime.now(),
+            started_at=datetime.now(timezone.utc),
         )
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
```
```diff
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from datetime import datetime
+from datetime import datetime, timezone
 from typing import Any, Dict, Optional

 from llama_stack.apis.datasetio import DatasetIO
@@ -64,7 +64,7 @@ class TorchtunePostTrainingImpl:
         job_status_response = PostTrainingJobStatusResponse(
             job_uuid=job_uuid,
             status=JobStatus.scheduled,
-            scheduled_at=datetime.now(),
+            scheduled_at=datetime.now(timezone.utc),
         )
         self.jobs[job_uuid] = job_status_response
@@ -84,7 +84,7 @@ class TorchtunePostTrainingImpl:
                 )

                 job_status_response.status = JobStatus.in_progress
-                job_status_response.started_at = datetime.now()
+                job_status_response.started_at = datetime.now(timezone.utc)

                 await recipe.setup()
                 resources_allocated, checkpoints = await recipe.train()
@@ -93,7 +93,7 @@ class TorchtunePostTrainingImpl:
                 job_status_response.resources_allocated = resources_allocated
                 job_status_response.checkpoints = checkpoints
                 job_status_response.status = JobStatus.completed
-                job_status_response.completed_at = datetime.now()
+                job_status_response.completed_at = datetime.now(timezone.utc)

             except Exception:
                 job_status_response.status = JobStatus.failed
```
```diff
@@ -8,7 +8,7 @@ import gc
 import logging
 import os
 import time
-from datetime import datetime
+from datetime import datetime, timezone
 from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
@@ -532,7 +532,7 @@ class LoraFinetuningSingleDevice:
                 checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
                 checkpoint = Checkpoint(
                     identifier=f"{self.model_id}-sft-{curr_epoch}",
-                    created_at=datetime.now(),
+                    created_at=datetime.now(timezone.utc),
                     epoch=curr_epoch,
                     post_training_job_id=self.job_uuid,
                     path=checkpoint_path,
```
```diff
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import json
-from datetime import datetime
+from datetime import datetime, timezone

 from opentelemetry.sdk.trace import ReadableSpan
 from opentelemetry.sdk.trace.export import SpanProcessor
@@ -34,7 +34,7 @@ class ConsoleSpanProcessor(SpanProcessor):
         if span.attributes and span.attributes.get("__autotraced__"):
             return

-        timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
+        timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]

         print(
             f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
@@ -46,7 +46,7 @@ class ConsoleSpanProcessor(SpanProcessor):
         if span.attributes and span.attributes.get("__autotraced__"):
             return

-        timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
+        timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]

         span_context = (
             f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
@@ -74,7 +74,7 @@ class ConsoleSpanProcessor(SpanProcessor):
             print(f"    {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")

         for event in span.events:
-            event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime("%H:%M:%S.%f")[:-3]
+            event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]

             severity = event.attributes.get("severity", "info")
             message = event.attributes.get("message", event.name)
```
```diff
@@ -8,7 +8,7 @@ import json
 import os
 import sqlite3
 import threading
-from datetime import datetime
+from datetime import datetime, timezone

 from opentelemetry.sdk.trace import SpanProcessor
 from opentelemetry.trace import Span
@@ -124,8 +124,8 @@ class SQLiteSpanProcessor(SpanProcessor):
                     trace_id,
                     service_name,
                     (span_id if not parent_span_id else None),
-                    datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
-                    datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
+                    datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
+                    datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
                 ),
             )

@@ -143,8 +143,8 @@ class SQLiteSpanProcessor(SpanProcessor):
                     trace_id,
                     parent_span_id,
                     span.name,
-                    datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
-                    datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
+                    datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
+                    datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
                     json.dumps(dict(span.attributes)),
                     span.status.status_code.name,
                     span.kind.name,
@@ -161,7 +161,7 @@ class SQLiteSpanProcessor(SpanProcessor):
                 (
                     span_id,
                     event.name,
-                    datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
+                    datetime.fromtimestamp(event.timestamp / 1e9, timezone.utc).isoformat(),
                     json.dumps(dict(event.attributes)),
                 ),
             )
```
```diff
@@ -168,7 +168,7 @@ def process_matplotlib_response(response, matplotlib_dump_dir: str):
     image_paths = []
     for i, img in enumerate(images):
         # create new directory for each day to better organize data:
-        dump_dname = datetime.today().strftime("%Y-%m-%d")
+        dump_dname = datetime.today().strftime("%Y-%m-%d")  # noqa: DTZ002 - we don't care about timezones here since we are displaying the date
         dump_dpath = Path(matplotlib_dump_dir, dump_dname)
         dump_dpath.mkdir(parents=True, exist_ok=True)
         # save image into a file
```
```diff
@@ -11,7 +11,7 @@ import logging
 import queue
 import threading
 import uuid
-from datetime import datetime
+from datetime import datetime, timezone
 from functools import wraps
 from typing import Any, Callable, Dict, List, Optional

@@ -86,7 +86,7 @@ class TraceContext:
             span_id=generate_short_uuid(),
             trace_id=self.trace_id,
             name=name,
-            start_time=datetime.now(),
+            start_time=datetime.now(timezone.utc),
             parent_span_id=current_span.span_id if current_span else None,
             attributes=attributes,
         )
@@ -203,7 +203,7 @@ class TelemetryHandler(logging.Handler):
             UnstructuredLogEvent(
                 trace_id=span.trace_id,
                 span_id=span.span_id,
-                timestamp=datetime.now(),
+                timestamp=datetime.now(timezone.utc),
                 message=self.format(record),
                 severity=severity(record.levelname),
             )
```
pyproject.toml

```diff
@@ -124,14 +124,15 @@ exclude = [

 [tool.ruff.lint]
 select = [
     "B",  # flake8-bugbear
     "B9",  # flake8-bugbear subset
     "C",  # comprehensions
     "E",  # pycodestyle
     "F",  # Pyflakes
     "N",  # Naming
     "W",  # Warnings
     "I",  # isort
+    "DTZ",  # datetime rules
 ]
 ignore = [
     # The following ignores are desired by the project maintainers.
@@ -145,6 +146,10 @@ ignore = [
     "C901",  # Complexity of the function is too high
 ]

+# Ignore the following errors for the following files
+[tool.ruff.lint.per-file-ignores]
+"tests/**/*.py" = ["DTZ"]  # Ignore datetime rules for tests
+
 [tool.mypy]
 mypy_path = ["llama_stack"]
 packages = ["llama_stack"]
@@ -170,6 +175,7 @@ exclude = [
     "^llama_stack/apis/inspect/inspect\\.py$",
     "^llama_stack/apis/models/models\\.py$",
     "^llama_stack/apis/post_training/post_training\\.py$",
+    "^llama_stack/apis/providers/providers\\.py$",
     "^llama_stack/apis/resource\\.py$",
     "^llama_stack/apis/safety/safety\\.py$",
     "^llama_stack/apis/scoring/scoring\\.py$",
```
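Enabling ruff's `DTZ` (flake8-datetimez) family is what motivates the `datetime` edits scattered through this diff, and the new per-file ignore exempts tests from it. The pattern the linter enforces, in brief:

```python
from datetime import datetime, timezone

# DTZ005: datetime.now() without a tz argument yields a naive local time
bad_now = datetime.now()

# Accepted: an aware UTC timestamp, the form adopted across this diff
good_now = datetime.now(timezone.utc)

# DTZ006 covers fromtimestamp() without tz; the telemetry processors switch to this form
good_ts = datetime.fromtimestamp(1_700_000_000, tz=timezone.utc)
```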
scripts/unit-tests.sh (new executable file, +19)

```sh
#!/bin/sh

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

PYTHON_VERSION=${PYTHON_VERSION:-3.10}

command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; }

uv python find $PYTHON_VERSION
FOUND_PYTHON=$?
if [ $FOUND_PYTHON -ne 0 ]; then
    uv python install $PYTHON_VERSION
fi

uv run --python $PYTHON_VERSION --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest -s -v tests/unit/ $@
```
```diff
@@ -52,6 +52,8 @@ def llama_stack_client_with_mocked_inference(llama_stack_client, request):

     If --record-responses is passed, it will call the real APIs and record the responses.
     """
+    # TODO: will rework this to be more stable
+    return llama_stack_client
     if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
         logging.warning(
             "llama_stack_client_with_mocked_inference is not supported for this client, returning original client without mocking"
```
File diff suppressed because one or more lines are too long
```diff
@@ -36,7 +36,7 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
                     "type": "image",
                     "image": {
                         "url": {
-                            "uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/api/inference/dog.png"
+                            "uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
                         },
                     },
                 },
@@ -65,7 +65,7 @@ def test_image_chat_completion_streaming(client_with_models, vision_model_id):
                     "type": "image",
                     "image": {
                         "url": {
-                            "uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/api/inference/dog.png"
+                            "uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
                         },
                     },
                 },
```
tests/integration/providers/__init__.py (new file, +5)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
```
tests/integration/providers/test_providers.py (new file, +17)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from llama_stack_client import LlamaStackClient

from llama_stack import LlamaStackAsLibraryClient


class TestProviders:
    @pytest.mark.asyncio
    def test_list(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
        provider_list = llama_stack_client.providers.list()
        assert provider_list is not None
```