Compare commits

111 commits

Author SHA1 Message Date
b174effe05
fix security update
All checks were successful
Build and Push playground container / build-playground (push) Successful in 1m57s
Build and Push container / build (push) Successful in 4m33s
2025-06-03 20:07:06 +02:00
8943b283e9
fix install
All checks were successful
Build and Push playground container / build-playground (push) Successful in 2m2s
Build and Push container / build (push) Successful in 4m7s
2025-06-02 03:02:25 +02:00
08905fc937
add requirements
Some checks failed
Build and Push playground container / build-playground (push) Failing after 35s
Build and Push container / build (push) Has been cancelled
2025-06-02 03:01:15 +02:00
8b5b1c937b
update ui command
All checks were successful
Build and Push playground container / build-playground (push) Successful in 1m47s
Build and Push container / build (push) Successful in 4m9s
2025-06-02 02:54:55 +02:00
205fc2cbd1
include all
All checks were successful
Build and Push playground container / build-playground (push) Successful in 1m42s
Build and Push container / build (push) Successful in 4m6s
2025-06-02 02:49:54 +02:00
4a122bbaca
use own code 2025-06-02 02:49:45 +02:00
a77b554bcf
update requirements
All checks were successful
Build and Push playground container / build-playground (push) Successful in 2m7s
Build and Push container / build (push) Successful in 4m32s
2025-06-02 02:34:19 +02:00
51816af52e
use env file
All checks were successful
Build and Push playground container / build-playground (push) Successful in 1m6s
Build and Push container / build (push) Successful in 4m22s
2025-06-02 01:39:17 +02:00
96003b55de
use auth for kvant
All checks were successful
Build and Push playground container / build-playground (push) Successful in 1m6s
Build and Push container / build (push) Successful in 4m27s
2025-06-02 01:23:33 +02:00
3bde47e562
add keycloak auth to playground ui
All checks were successful
Build and Push playground container / build-playground (push) Successful in 2m0s
Build and Push container / build (push) Successful in 4m11s
2025-06-01 22:23:49 +02:00
ed31462499
ci tag
All checks were successful
Build and Push playground container / build-playground (push) Successful in 1m5s
Build and Push container / build (push) Successful in 4m13s
2025-06-01 13:38:22 +02:00
43a7713140
use raw tag
Some checks failed
Build and Push playground container / build-playground (push) Failing after 27s
Build and Push container / build (push) Failing after 3m21s
2025-06-01 13:23:17 +02:00
ad9860c312
fix ci
Some checks failed
Build and Push playground container / build-playground (push) Failing after 24s
Build and Push container / build (push) Failing after 3m32s
2025-06-01 13:05:50 +02:00
9b70e01c99
add local registry
Some checks failed
Build and Push playground container / build-playground (push) Successful in 1m9s
Build and Push container / build (push) Failing after 1m9s
2025-06-01 13:01:23 +02:00
7bba685dee
add scripts
Some checks failed
Build and Push container / build (push) Failing after 1m4s
Build and Push playground container / build-playground (push) Successful in 1m4s
2025-06-01 12:43:43 +02:00
4603206065
ci
Some checks failed
Build and Push playground container / build-playground (push) Successful in 1m6s
Build and Push container / build (push) Failing after 1m41s
2025-06-01 12:13:57 +02:00
16abfaeb69
build playground 2025-06-01 12:13:57 +02:00
b2ac7f69cc
add responses_store 2025-06-01 12:13:57 +02:00
00fc43ae96
do not push twice 2025-06-01 12:13:57 +02:00
65936f7933
wip 2025-06-01 12:13:57 +02:00
226e443e03
wip 2025-06-01 12:13:57 +02:00
5b057d60ee
wip 2025-06-01 12:13:57 +02:00
95a56b62a0
wip 2025-06-01 12:13:57 +02:00
c642ea2dd5
wip 2025-06-01 12:13:57 +02:00
7e1725f72b
install uvx 2025-06-01 12:13:57 +02:00
b414fe5566
add kvant 2025-06-01 12:13:57 +02:00
cfa38bd13b
add kvant 2025-06-01 12:13:57 +02:00
Hardik Shah
b21050935e
feat: New OpenAI compat embeddings API (#2314)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 4s
Integration Tests / test-matrix (http, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, datasets) (push) Failing after 10s
Integration Tests / test-matrix (http, post_training) (push) Failing after 9s
Integration Tests / test-matrix (library, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 8s
Integration Tests / test-matrix (http, providers) (push) Failing after 9s
Integration Tests / test-matrix (library, datasets) (push) Failing after 8s
Integration Tests / test-matrix (library, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, scoring) (push) Failing after 10s
Test Llama Stack Build / generate-matrix (push) Successful in 6s
Integration Tests / test-matrix (library, providers) (push) Failing after 7s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 6s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Test Llama Stack Build / build-single-provider (push) Failing after 7s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 9s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 7s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 10s
Unit Tests / unit-tests (3.11) (push) Failing after 7s
Test Llama Stack Build / build (push) Failing after 5s
Unit Tests / unit-tests (3.10) (push) Failing after 7s
Update ReadTheDocs / update-readthedocs (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 8s
Unit Tests / unit-tests (3.13) (push) Failing after 7s
Test External Providers / test-external-providers (venv) (push) Failing after 26s
Pre-commit / pre-commit (push) Successful in 1m11s
# What does this PR do?
Adds a new OpenAI-compatible endpoint for the embeddings API:
`/openai/v1/embeddings`
Added providers for OpenAI, LiteLLM, and SentenceTransformer.
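Since the endpoint is OpenAI-compatible, a plain OpenAI client pointed at a running stack should be able to exercise it. A minimal sketch; the base URL and model name are assumptions taken from the test plan below:

```
# Minimal sketch: call the new OpenAI-compatible embeddings endpoint.
# Assumes a stack on localhost:8321 with all-MiniLM-L6-v2 registered.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
resp = client.embeddings.create(
    model="all-MiniLM-L6-v2",
    input=["Hello, world!", "Llama Stack"],
)
print(len(resp.data), len(resp.data[0].embedding))  # 2 vectors
```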


## Test Plan
```
LLAMA_STACK_CONFIG=http://localhost:8321 pytest -sv tests/integration/inference/test_openai_embeddings.py --embedding-model all-MiniLM-L6-v2,text-embedding-3-small,gemini/text-embedding-004
```
2025-05-31 22:11:47 -07:00
Ben Browning
277f8690ef
fix: Responses streaming tools don't concatenate None and str (#2326)
# What does this PR do?

This adds a check to ensure we don't attempt to concatenate `None + str`
or `str + None` when building up our arguments for streaming tool calls
in the Responses API.
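The failure mode is plain Python string building; a minimal sketch of the kind of guard described (names are illustrative, not the actual implementation):

```
# Hypothetical sketch of the guard: treat None as "no delta" instead of
# attempting `None + str` / `str + None`, which raises TypeError.
def accumulate_arguments(current: str | None, delta: str | None) -> str | None:
    if delta is None:
        return current
    if current is None:
        return delta
    return current + delta

assert accumulate_arguments(None, '{"city":') == '{"city":'
assert accumulate_arguments('{"city":', ' "Paris"}') == '{"city": "Paris"}'
```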

## Test Plan

All existing tests pass with this change.

Unit tests:

```
python -m pytest -s -v \
  tests/unit/providers/agents/meta_reference/test_openai_responses.py
```

Integration tests:

```
llama stack run llama_stack/templates/together/run.yaml

LLAMA_STACK_CONFIG=http://localhost:8321 \
python -m pytest -s -v \
  tests/integration/agents/test_openai_responses.py \
  --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```

Verification tests:

```
llama stack run llama_stack/templates/together/run.yaml

pytest -s -v 'tests/verifications/openai_api/test_responses.py' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model meta-llama/Llama-4-Scout-17B-16E-Instruct
```

Additionally, the manual example using Codex CLI from #2325 now succeeds
instead of throwing a 500 error.

Closes #2325

Signed-off-by: Ben Browning <bbrownin@redhat.com>
2025-05-31 18:24:04 -07:00
Francisco Arceo
f328436831
feat: Enable ingestion of precomputed embeddings (#2317) 2025-05-31 04:03:37 -06:00
Francisco Arceo
31ce208bda
fix: Fix requirements from broken github-actions[bot] (#2323) 2025-05-30 19:05:47 -07:00
github-actions[bot]
ad15276da1 build: Bump version to 0.2.9 2025-05-30 19:43:09 +00:00
ehhuang
2603f10f95
feat: support postgresql inference store (#2310)
# What does this PR do?
* Added support for a PostgreSQL inference store
* Added an 'oracle' template that demos how to configure PostgreSQL stores
(except for telemetry, which is not supported currently)


## Test Plan

```
llama stack build --template oracle --image-type conda --run
LLAMA_STACK_CONFIG=http://localhost:8321 pytest -s -v tests/integration/ --text-model accounts/fireworks/models/llama-v3p3-70b-instruct -k 'inference_store'
```
2025-05-29 14:33:09 -07:00
Jorge Piedrahita Ortiz
168c7113df
fix(providers): update sambanova json schema mode (#2306)
# What does this PR do?
Updates the SambaNova inference provider to set `strict` to false in
json_schema structured output.

## Test Plan
```
pytest -s -v tests/integration/inference/test_text_inference.py --stack-config=sambanova --text-model=sambanova/Meta-Llama-3.3-70B-Instruct
```
2025-05-29 09:54:23 -07:00
Mark Campbell
f0d8ceb242
chore: fix flaky distro_codegen script (#2305)
# What does this PR do?
Adds an import for all of the template modules before the executor runs to
prevent a deadlock.
Closes #2278
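A minimal sketch of the pattern, not the actual script: the template modules get imported in the parent process before the executor starts, so workers never trigger imports that can deadlock under fork (module names here are stand-ins):

```
# Sketch only: pre-import modules before spawning the executor so the
# workers never race on imports. "json" and "csv" stand in for the
# actual template modules.
import importlib
from concurrent.futures import ProcessPoolExecutor

TEMPLATE_MODULES = ["json", "csv"]  # stand-ins for the template modules

def codegen(name: str) -> str:
    return f"generated {name}"

for mod in TEMPLATE_MODULES:
    importlib.import_module(mod)  # import up front, before the executor

if __name__ == "__main__":
    with ProcessPoolExecutor() as pool:
        print(list(pool.map(codegen, TEMPLATE_MODULES)))
```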

## Test Plan
```
# Run the pre-commit multiple times and verify the deadlock doesn't occur
for i in {1..10}; do pre-commit run --all-files; done
```
2025-05-29 09:53:45 -07:00
Ashwin Bharambe
bfdd15d1fa
fix(responses): use input, not original_input when storing the Response (#2300)
We must store the full (re-hydrated) input, not just the original input,
in the Response object. Of course, this is not very space efficient and
we should likely find a better storage scheme so that we can only store
unique entries in the database and then re-hydrate them efficiently
later. But that can be done safely later.

Closes https://github.com/meta-llama/llama-stack/issues/2299

## Test Plan

Unit test
2025-05-28 13:17:48 -07:00
Michael Dawson
a654467552
feat: add cpu/cuda config for prompt guard (#2194)
# What does this PR do?
Previously, prompt guard was hard-coded to require CUDA, which prevented
it from being used on an instance without CUDA support.

This PR allows prompt guard to be configured to use either CPU or CUDA.
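A minimal sketch of the kind of device selection this enables, assuming torch; the config field name is illustrative, not the actual key:

```
# Sketch of configurable device selection for prompt guard. The
# `guard_device` field name is illustrative, not the actual config key.
import torch

def resolve_device(guard_device: str = "cuda") -> str:
    # Fall back to CPU when CUDA is requested but unavailable.
    if guard_device == "cuda" and not torch.cuda.is_available():
        return "cpu"
    return guard_device

model_device = resolve_device("cuda")  # "cpu" on a machine without a GPU
```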

Closes [#2133](https://github.com/meta-llama/llama-stack/issues/2133)

## Test Plan (Edited after incorporating suggestion)
1) Started a stack configured with prompt guard on a system without a GPU
and validated that prompt guard could be used through the APIs.

2) Validated on a system with a GPU (but without llama stack) that the
Python code selecting between CPU and CUDA support returned the right
value when a CUDA device was available.

3) Ran the unit tests as per
https://github.com/meta-llama/llama-stack/blob/main/tests/unit/README.md


---------

Signed-off-by: Michael Dawson <mdawson@devrus.com>
2025-05-28 12:23:15 -07:00
Sébastien Han
63a9f08c9e
chore: use starlette built-in Route class (#2267)
# What does this PR do?

Use a more common pattern and better-known terminology from the ecosystem,
where Route is a more widely accepted term than Endpoint.

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-28 09:53:33 -07:00
ehhuang
56e5ddb39f
feat(ui): add views for Responses (#2293)
# What does this PR do?
* Add responses list and detail views
* Refactored components to be shared as much as possible between chat
completions and responses

## Test Plan
<img width="2014" alt="image"
src="https://github.com/user-attachments/assets/6dee12ea-8876-4351-a6eb-2338058466ef"
/>
<img width="2021" alt="image"
src="https://github.com/user-attachments/assets/6c7c71b8-25b7-4199-9c57-6960be5580c8"
/>

added tests
2025-05-28 09:51:22 -07:00
Sébastien Han
6352078e4b
chore: use groups when running commands (#2298)
# What does this PR do?

Follow-up to https://github.com/meta-llama/llama-stack/pull/2287: we must
use `--group` when running commands with uv.

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-28 09:13:16 -07:00
Charlie Doern
a7ecc92be1
docs: add post training to providers list (#2280)
# What does this PR do?

The providers list is missing post_training. This adds that column and
`HuggingFace`, `TorchTune`, and `NVIDIA NEMO` as supported providers.

It also points to these providers in docs/source/providers/index.md and
describes their basic functionality.

There are other missing provider types here as well, but this is a start.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Co-authored-by: Francisco Arceo <arceofrancisco@gmail.com>
2025-05-28 09:32:00 -04:00
raghotham
9b7f9db05c
fix: build docs without requirements.txt (#2294)
Following the instructions here
https://docs.readthedocs.com/platform/stable/build-customization.html#install-dependencies-with-uv
as per
https://github.com/meta-llama/llama-stack/pull/2223#issuecomment-2914315408
2025-05-27 16:27:57 -07:00
ehhuang
0b695538af
fix: chat completion with more than one choice (#2288)
# What does this PR do?
Fix a bug in openai_compat where choices are not indexed correctly.
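In the OpenAI chat completion schema, each element of `choices` carries its own `index`; a sketch of the correct shape (illustrative, not the openai_compat code):

```
# Illustrative: each choice in an OpenAI-style chat completion must carry
# its own index; hardcoding index=0 breaks completions with n > 1.
texts = ["first completion", "second completion"]
choices = [
    {"index": i, "message": {"role": "assistant", "content": text}}
    for i, text in enumerate(texts)
]
assert [c["index"] for c in choices] == [0, 1]
```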

## Test Plan
Added a new test.

Rerun the failed inference_store tests:
```
llama stack run fireworks --image-type conda
pytest -s -v tests/integration/ --stack-config http://localhost:8321 -k 'test_inference_store' --text-model meta-llama/Llama-3.3-70B-Instruct --count 10
```
2025-05-27 15:39:15 -07:00
ehhuang
1d46f3102e
fix: enable test_responses_store (#2290)
# What does this PR do?
Changed the test to not require tool_call in the output, while still keeping
the tools params there as a smoke test.

## Test Plan
Used llama3.3 from fireworks (same as CI)
<img width="1433" alt="image"
src="https://github.com/user-attachments/assets/1e5fca98-9b4f-402e-a0bc-d9f910f2c207"
/>

Also ran with the ollama distro and a 3b model.
2025-05-27 15:37:28 -07:00
Sébastien Han
4f3f28f718
chore: use dependency-groups for dev (#2287)
# What does this PR do?

The previous `[project.optional-dependencies]` was misrepresenting what
the packages were. They were NOT optional dependencies to the project
but development dependencies. Unlike optional dependencies, development
dependencies are local-only and will not be included in the project
requirements when published to PyPI or other indexes. As such,
development dependencies are not included in the [project] table.
Additionally, the dev group is synced by default.

Source:

https://docs.astral.sh/uv/concepts/projects/dependencies/#development-dependencies

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-27 23:00:17 +02:00
Sébastien Han
484abe3116
chore: bump uv version (#2289)
# What does this PR do?

To match the one used by the release bot.

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-27 13:44:27 -07:00
github-actions[bot]
7105a25b0f build: Bump version to 0.2.8 2025-05-27 20:28:29 +00:00
Ashwin Bharambe
5cdb29758a
feat(responses): add output_text delta events to responses (#2265)
This adds initial streaming support to the Responses API. 

This PR makes sure that the _first_ inference call made to chat
completions streams out.

There's more to be done:
- tool call output tokens need to stream out when possible
- we need to loop through multiple rounds of inference and they all need
to stream out.

## Test Plan

Added a test. Executed as:

```
FIREWORKS_API_KEY=... \
  pytest -s -v 'tests/verifications/openai_api/test_responses.py' \
  --provider=stack:fireworks --model meta-llama/Llama-4-Scout-17B-16E-Instruct
```

Then, started a llama stack fireworks distro and tested against it like
this:

```
OPENAI_API_KEY=blah \
   pytest -s -v 'tests/verifications/openai_api/test_responses.py' \
   --base-url http://localhost:8321/v1/openai/v1 \
  --model meta-llama/Llama-4-Scout-17B-16E-Instruct 
```
2025-05-27 13:07:14 -07:00
Sébastien Han
6ee319ae08
fix: convert boolean string to boolean (#2284)
# What does this PR do?

Handles the case where the vllm config `tls_verify` is set to `false` or
`true`.

Closes: https://github.com/meta-llama/llama-stack/issues/2283
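The gist is coercing a YAML/env string into a real boolean before handing it to the client; a hedged sketch, not the actual implementation:

```
# Sketch of the coercion described: a config value like tls_verify may
# arrive as the string "true"/"false" and must become a real bool.
def parse_bool(value: bool | str) -> bool:
    if isinstance(value, bool):
        return value
    return value.strip().lower() == "true"

assert parse_bool("false") is False
assert parse_bool("true") is True
assert parse_bool(True) is True
```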

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-27 13:05:38 -07:00
Sébastien Han
a8f75d3897
chore: remove dependencies.json (#2281)
# What does this PR do?
It's not used anywhere in the build process. It's an ancient artifact from
an old attempt at using sub-packages to build distros.

## Test Plan

N/A

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-27 10:26:57 -07:00
Mark Campbell
e7e9ec0379
chore: fix visible comments in pr template (#2279)
# What does this PR do?
This PR updates the comments in the PR template, as the old comments were
showing up in PRs when they were not meant to.
2025-05-27 15:42:33 +02:00
Mark Campbell
b2adaa3f60
docs: fix evals notebook preview (#2277)
# What does this PR do?
Fixes the preview of the Evals Benchmark Notebook

## Explanation 
I took the original notebook, opened it in Google Colab, and downloaded it
again from Colab. I then replaced the original with the new fixed version.
cc: @leseb 

Closes #2142 

## Test Plan
You can view the nb preview from my fork
https://github.com/Bobbins228/llama-stack/blob/fix-evals-nb/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb
2025-05-27 15:18:20 +02:00
Sébastien Han
448f00903d
chore: mark blobpath as optional (#2271)
# What does this PR do?

This is not a core dependency of the distro server. It's only necessary
when using `inline::rag-runtime` or `inline::meta-reference` providers.

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-27 10:55:24 +02:00
Ignas Baranauskas
28930cdab6
fix: handle None external_providers_dir in build with run arg (#2269)
# What does this PR do?
Fixes an issue where running `llama stack build --template ollama
--image-type venv --run` fails with a TypeError when validating external
providers directory paths.

The error occurs because `os.path.exists()` is called with `Path(None)`
instead of converting it to a string first. This change ensures
consistent handling of `None` values for `external_providers_dir` across
both build and
[run](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/cli/stack/run.py#L134)
commands by using `str()` conversion before path validation.
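A sketch of the failure and the described fix (illustrative, not the actual CLI code):

```
# Illustrative: os.path.exists(Path(None)) raises TypeError because
# Path(None) itself raises; guarding and converting first avoids it.
import os

def providers_dir_exists(external_providers_dir: str | None) -> bool:
    if external_providers_dir is None:
        return False
    return os.path.exists(str(external_providers_dir))

assert providers_dir_exists(None) is False
```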


## Test Plan
```bash
INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template ollama --image-type venv --run
```
Command completes successfully without TypeError

2025-05-27 09:41:12 +02:00
Ashwin Bharambe
7504c2f430
test: disable test_inference_store test urrrggg (#2273) 2025-05-26 22:48:41 -07:00
Ashwin Bharambe
51e6f529f3
fix: index non-MCP toolgroups at registration time (#2272)
Two somewhat annoying fixes: 

- we are going to always index tools for non-MCP toolgroups (like we used
to do), because there are just random assumptions in our tests, etc., and
I don't want to fix them right now
- we need to handle the funny case of toolgroups like
`builtin::rag/knowledge_search` where we added the tool name to use in
the toolgroup itself.
2025-05-26 20:33:36 -07:00
Sébastien Han
39b33a3b01
chore: allow to pass CA cert to remote vllm (#2266)
# What does this PR do?

The `tls_verify` option can now receive a path to a certificate file if the
endpoint requires it.

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-26 20:59:03 +02:00
Sébastien Han
7710b2f43b
chore: removed unused class (#2268)
Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-26 08:41:37 -07:00
Ashwin Bharambe
9623d5d230
fix: match mcp headers in provider data to Responses API shape (#2263) 2025-05-25 14:33:10 -07:00
Ashwin Bharambe
ce33d02443
fix(tools): do not index tools, only index toolgroups (#2261)
When registering a MCP endpoint, we cannot list tools (like we used to)
since the MCP endpoint may be behind an auth wall. Registration can
happen much sooner (via run.yaml).

Instead, we do listing only when the _user_ actually calls listing.
Furthermore, we cache the list in-memory in the server. Currently, the
cache is not invalidated -- we may want to periodically re-list for MCP
servers. Note that clients must call `list_tools` before calling
`invoke_tool` -- we rely on this behavior.

This will enable us to list MCP servers in run.yaml

## Test Plan

Existing tests, updated tests accordingly.
2025-05-25 13:27:52 -07:00
raghotham
5a422e236c
chore: make cprint write to stderr (#2250)
Also do sys.exit(1) in case of errors
2025-05-24 23:39:57 -07:00
raghotham
c25bd0ad58
fix: use pypi browser agent (#2260)
Getting this error from pypi of late

```
'python-requests/2.32.3 User-Agents are currently blocked from accessing JSON release resources. A cluster is apparently crawling all project/release resources resulting in excess cache misses. Please contact admin@pypi.org if you have information regarding what this software may be.'
```
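The workaround amounts to fetching the JSON release metadata with a browser-style User-Agent instead of the default python-requests one; a sketch, with the header value illustrative:

```
# Sketch of the workaround: override the default python-requests
# User-Agent (which PyPI was blocking) with a browser-style one.
import requests

headers = {"User-Agent": "Mozilla/5.0 (compatible; release-script)"}
resp = requests.get("https://pypi.org/pypi/llama-stack/json", headers=headers)
resp.raise_for_status()
print(resp.json()["info"]["version"])
```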
2025-05-24 23:26:30 -07:00
Ashwin Bharambe
298721c238
chore: split routing_tables into individual files (#2259) 2025-05-24 23:15:05 -07:00
Ashwin Bharambe
eedf21f19c
chore: split routers into individual files (inference, tool, vector_io, eval_scoring) (#2258) 2025-05-24 22:59:07 -07:00
Ashwin Bharambe
ae7272d8ff
chore: split routers into individual files (datasets) (#2249) 2025-05-24 22:11:43 -07:00
Ashwin Bharambe
a2160dc0af
chore: split routers into individual files (safety)
Reviewers:
bbrowning, leseb, ehhuang, terrytangyuan, raghotham, yanxi0830, hardikjshah

Reviewed By: raghotham

Pull Request: https://github.com/meta-llama/llama-stack/pull/2248
2025-05-24 22:00:32 -07:00
Ashwin Bharambe
c290999c63
fix(telemetry): get rid of annoying sqlite span export error (#2245) 2025-05-24 20:24:34 -07:00
Ashwin Bharambe
3faf1e4a79
feat: enable MCP execution in Responses impl (#2240)
## Test Plan

```
pytest -s -v 'tests/verifications/openai_api/test_responses.py' \
  --provider=stack:together --model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
2025-05-24 14:20:42 -07:00
Ashwin Bharambe
66f09f24ed
fix: disable test_responses_store (#2244)
The test depends on llama's tool calling ability. In the CI, we run with
a small ollama model.

The fix might be to check for either message or function_call because
the model is flaky and we aren't really testing that behavior?
2025-05-24 08:18:06 -07:00
raghotham
84751f3e55
fix: skip failing tests (#2243)
As titled. Trying release 0.2.8.
2025-05-24 07:31:08 -07:00
Yuan Tang
a411029d7e
docs: Update CHANGELOG.md (#2241)
# What does this PR do?

This PR adds release notes for recent releases.

---------

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
2025-05-24 07:06:36 -07:00
ehhuang
15b0a67555
feat: add responses input items api (#2239)
# What does this PR do?
TSIA

## Test Plan
added integration and unit tests
2025-05-24 07:05:53 -07:00
Yuan Tang
055f48b6a2
fix(security): Upgrade setuptools to v80.8.0. Fixes CVE-2025-47273 (#2242)
# What does this PR do?

This fixes a high-severity CVE in `setuptools`:
https://github.com/advisories/GHSA-5rjg-fvgr-3xxf

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Co-authored-by: Francisco Arceo <arceofrancisco@gmail.com>
2025-05-24 06:57:24 -07:00
ehhuang
ca65617a71
feat: start ui server in llama stack run (#2170)
# What does this PR do?
TSIA
`--enable-ui` to enable


## Test Plan
`llama stack run dev --image-type conda --enable-ui`
`localhost:8322` shows UI


`llama stack run dev --image-type conda`
`localhost:8322` does not work
2025-05-23 20:00:09 -07:00
ehhuang
5844c2da68
feat: add list responses API (#2233)
# What does this PR do?
This is not part of the official OpenAI API, but we'll use this for the
logs UI.
In order to support more filtering options, I'm adopting the newly
introduced sql store in place of the kv store.

## Test Plan
Added integration/unit tests.
2025-05-23 13:16:48 -07:00
Ashwin Bharambe
6463ee7633
feat: allow using llama-stack-library-client from verifications (#2238)
Having to run (and re-run) a server while running verifications can be
annoying while you are iterating on code. This makes it so you can use
the library client -- and because it is OpenAI client compatible, it all
works.

## Test Plan

```
pytest -s -v tests/verifications/openai_api/test_responses.py \
   --provider=stack:together \
   --model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
2025-05-23 11:43:41 -07:00
Ashwin Bharambe
558d109ab7
fix: signature change to match OpenAI SDK (#2237) 2025-05-23 10:59:30 -07:00
ehhuang
b054023800
chore: add sqlalchemy to test dependencies (#2236)
# What does this PR do?


## Test Plan
2025-05-23 10:33:38 -07:00
Ashwin Bharambe
51945f1e57
feat: accept MCP authorization headers for MCP toolgroups (#2230)
The most interesting MCP servers are those with an authorization wall in
front of them. This PR uses the existing `provider_data` mechanism of
passing provider API keys for passing MCP access tokens (in fact,
arbitrary headers in the style of the OpenAI Responses API) from the
client through to the MCP server.

```
class MCPProviderDataValidator(BaseModel):
    # mcp_endpoint => list of headers to send
    mcp_headers: dict[str, list[str]] | None = None
```

Note how we must stuff the headers for all MCP endpoints into a single
"MCPProviderDataValidator". Unlike existing providers (e.g., Together
and Fireworks for inference) where we could name the provider api keys
clearly (`together_api_key`, `fireworks_api_key`), we cannot name these
keys for MCP. We have a single generic MCP provider which can serve
multiple "toolgroups". So we use a dict to combine all the headers for
all MCP endpoints you may want to use in an agentic call.
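Concretely, the client stuffs all MCP headers into that one dict keyed by endpoint; a hedged usage sketch where the endpoint URL and token are placeholders:

```
# Hedged sketch of the provider_data shape implied by the validator
# above; the endpoint URL and bearer token are placeholders.
provider_data = {
    "mcp_headers": {
        "https://mcp.example.com/sse": [
            "Authorization: Bearer <your-mcp-access-token>",
        ],
    },
}
# The client sends this dict (e.g. via the provider-data request header)
# so the stack forwards the headers when it contacts that MCP server.
```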


## Test Plan

See the added integration test for usage.
2025-05-23 08:52:18 -07:00
ehhuang
2708312168
feat(ui): implement chat completion views (#2201)
# What does this PR do?
 Implements table and detail views for chat completions

<img width="1548" alt="image"
src="https://github.com/user-attachments/assets/01061b7f-0d47-4b3b-b5ac-2df8f9035ef6"
/>
<img width="1549" alt="image"
src="https://github.com/user-attachments/assets/738d8612-8258-4c2c-858b-bee39030649f"
/>


## Test Plan
npm run test
2025-05-22 22:05:54 -07:00
Ashwin Bharambe
d8c6ab9bfc
feat: add MCP tool signature to Responses API (#2232) 2025-05-22 16:43:08 -07:00
ehhuang
8feb1827c8
fix: openai provider model id (#2229)
# What does this PR do?
Since https://github.com/meta-llama/llama-stack/pull/2193 switched to
the openai sdk, we need to strip 'openai/' from the model_id.
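The stripping itself is a one-liner; an illustrative sketch:

```
# Illustrative: the registry-qualified id "openai/gpt-4o" must become
# "gpt-4o" before being sent to the OpenAI SDK.
model_id = "openai/gpt-4o"
provider_model_id = model_id.removeprefix("openai/")
assert provider_model_id == "gpt-4o"
```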


## Test Plan
start server with openai provider and send a chat completion call
2025-05-22 14:51:01 -07:00
ehhuang
549812f51e
feat: implement get chat completions APIs (#2200)
# What does this PR do?
* Provides a sqlite implementation of the APIs introduced in
https://github.com/meta-llama/llama-stack/pull/2145.
* Introduces a SqlStore API (llama_stack/providers/utils/sqlstore/api.py)
and the first sqlite implementation.
* Pagination support will be added in a future PR.

## Test Plan
Unit test on sql store:
<img width="1005" alt="image"
src="https://github.com/user-attachments/assets/9b8b7ec8-632b-4667-8127-5583426b2e29"
/>


Integration test:
```
INFERENCE_MODEL="llama3.2:3b-instruct-fp16" llama stack build --template ollama --image-type conda --run
```
```
LLAMA_STACK_CONFIG=http://localhost:5001 INFERENCE_MODEL="llama3.2:3b-instruct-fp16" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-fp16" -k 'inference_store and openai'
```
2025-05-21 22:21:52 -07:00
Jorge Piedrahita Ortiz
633bb9c5b3
feat(providers): sambanova safety provider (#2221)
# What does this PR do?

Includes a SambaNova safety adapter that uses the SambaNova cloud-served
Meta-Llama-Guard-3-8B, plus minor updates to the SambaNova docs.

## Test Plan
```
pytest -s -v tests/integration/safety/test_safety.py --stack-config=sambanova --safety-shield=sambanova/Meta-Llama-Guard-3-8B
```
Sébastien Han
02e5e8a633
fix: only print routes that match the runtime config (#2226)
# What does this PR do?

We now only print the 'active' routes, not all the possible routes. This
is based on the distribution server config by looking at enabled APIs
and their respective providers.

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-21 15:30:29 -07:00
Sébastien Han
37f1e8a7f7
fix: use proper service account for kube auth (#2227)
# What does this PR do?

Not sure why it passed CI earlier...

Strangely, only 24 workflows ran on
https://github.com/meta-llama/llama-stack/pull/2216, so the test never
ran...

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-21 15:28:21 -07:00
Varsha
e92301f2d7
feat(sqlite-vec): enable keyword search for sqlite-vec (#1439)
# What does this PR do?
This PR introduces support for keyword-based FTS5 search with BM25
relevance scoring. It makes changes to the existing EmbeddingIndex base
class in order to support search_mode and query_str parameters, which can
be used for keyword-based search implementations.
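For reference, the SQLite primitives involved are the FTS5 virtual table and its built-in bm25() ranking function; a self-contained sketch whose schema is illustrative, not the provider's actual table layout:

```
# Self-contained sketch of FTS5 keyword search with BM25 ranking; the
# schema here is illustrative, not the provider's actual table layout.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE VIRTUAL TABLE chunks USING fts5(chunk_id, content)")
conn.executemany(
    "INSERT INTO chunks VALUES (?, ?)",
    [("c1", "Sentence 0 from document 0"), ("c2", "Sentence 5 from document 0")],
)
# bm25() returns lower scores for better matches, so ORDER BY ascending.
rows = conn.execute(
    "SELECT chunk_id, content FROM chunks "
    "WHERE chunks MATCH ? ORDER BY bm25(chunks)",
    ("Sentence 5",),
).fetchall()
print(rows[0][1])  # best keyword match
```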


## Test Plan
run 
```
pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto
```
Output:
```
pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto
/Users/vnarsing/miniconda3/envs/stack-client/lib/python3.10/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
====================================================== test session starts =======================================================
platform darwin -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python
cachedir: .pytest_cache
metadata: {'Python': '3.10.16', 'Platform': 'macOS-14.7.4-arm64-arm-64bit', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0'}}
rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack
configfile: pyproject.toml
plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0
asyncio: mode=auto, asyncio_default_fixture_loop_scope=None
collected 7 items                                                                                                                

llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_add_chunks PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_vector PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_fts PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_register_vector_db PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_unregister_vector_db PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED
```


For reference, with the implementation, the fts table looks like below:
```
Chunk ID: 9fbc39ce-c729-64a2-260f-c5ec9bb2a33e, Content: Sentence 0 from document 0
Chunk ID: 94062914-3e23-44cf-1e50-9e25821ba882, Content: Sentence 1 from document 0
Chunk ID: e6cfd559-4641-33ba-6ce1-7038226495eb, Content: Sentence 2 from document 0
Chunk ID: 1383af9b-f1f0-f417-4de5-65fe9456cc20, Content: Sentence 3 from document 0
Chunk ID: 2db19b1a-de14-353b-f4e1-085e8463361c, Content: Sentence 4 from document 0
Chunk ID: 9faf986a-f028-7714-068a-1c795e8f2598, Content: Sentence 5 from document 0
Chunk ID: ef593ead-5a4a-392f-7ad8-471a50f033e8, Content: Sentence 6 from document 0
Chunk ID: e161950f-021f-7300-4d05-3166738b94cf, Content: Sentence 7 from document 0
Chunk ID: 90610fc4-67c1-e740-f043-709c5978867a, Content: Sentence 8 from document 0
Chunk ID: 97712879-6fff-98ad-0558-e9f42e6b81d3, Content: Sentence 9 from document 0
Chunk ID: aea70411-51df-61ba-d2f0-cb2b5972c210, Content: Sentence 0 from document 1
Chunk ID: b678a463-7b84-92b8-abb2-27e9a1977e3c, Content: Sentence 1 from document 1
Chunk ID: 27bd63da-909c-1606-a109-75bdb9479882, Content: Sentence 2 from document 1
Chunk ID: a2ad49ad-f9be-5372-e0c7-7b0221d0b53e, Content: Sentence 3 from document 1
Chunk ID: cac53bcd-1965-082a-c0f4-ceee7323fc70, Content: Sentence 4 from document 1
```

Query results:
```
Result 1: Sentence 5 from document 0
Result 2: Sentence 5 from document 1
Result 3: Sentence 5 from document 2
```


---------

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
2025-05-21 15:24:24 -04:00
Sébastien Han
85b5f3172b
docs: misc cleanup (#2223)
# What does this PR do?

* remove requirements.txt to use pyproject.toml as the source of truth
* update relevant docs

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-21 17:35:27 +02:00
Sébastien Han
6a62e783b9
chore: refactor workflow writing (#2225)
# What does this PR do?

Use a composite action to avoid repeating similar steps and to centralize
the defaults.

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-21 17:31:14 +02:00
Sébastien Han
1862de4be5
chore: clarify cache_ttl to be key_recheck_period (#2220)
# What does this PR do?

The cache_ttl config value is not in fact tied to the lifetime of any of
the keys; it represents the interval at which our key cache refresher
runs.

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-21 17:30:23 +02:00
Sébastien Han
c25acedbcd
chore: remove k8s auth in favor of k8s jwks endpoint (#2216)
# What does this PR do?

Kubernetes since 1.20 exposes a JWKS endpoint that we can use with our
recent oauth2 implementation.
The CI test has been kept intact for validation.

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-21 16:23:54 +02:00
liangwen12year
2890243107
feat(quota): add server‑side per‑client request quotas (requires auth) (#2096)
# What does this PR do?
feat(quota): add server-side per-client request quotas (requires auth)

Unrestricted usage can lead to runaway costs and fragmented client-side
workarounds. This commit introduces a native quota mechanism to the
server, giving operators a unified, centrally managed throttle for
per-client requests, without needing extra proxies or custom client
logic. This helps contain cloud-compute expenses, enables fine-grained
usage control, and simplifies deployment and monitoring of Llama Stack
services. Quotas are fully opt-in and have no effect unless explicitly
configured.

Notice that quotas are fully opt-in and require authentication to be
enabled. 'sqlite' is the only supported quota `type` at this time; any
other `type` will be rejected. And the only supported `period` is 'day'.

Highlights:

- Adds `QuotaMiddleware` to enforce per-client request quotas:
  - Uses `Authorization: Bearer <client_id>` (from
    AuthenticationMiddleware)
  - Tracks usage via a SQLite-based KV store
  - Returns 429 when the quota is exceeded
- Extends `ServerConfig` with a `quota` section (type + config)
- Enforces strict coupling: quotas require authentication or the server
  will fail to start

Behavior changes:
- Quotas are disabled by default unless explicitly configured
- SQLite defaults to `./quotas.db` if no DB path is set
- The server requires authentication when quotas are enabled

To enable per-client request quotas in `run.yaml`, add:
```
server:
  port: 8321
  auth:
    provider_type: "custom"
    config:
      endpoint: "https://auth.example.com/validate"
  quota:
    type: sqlite
    config:
      db_path: ./quotas.db
      limit:
        max_requests: 1000
        period: day
```
Closes #2093

## Test Plan

Signed-off-by: Wen Liang <wenliang@redhat.com>
Co-authored-by: Wen Liang <wenliang@redhat.com>
2025-05-21 10:58:45 +02:00
Abhishek koserwal
5a3d777b20
feat: add llama stack rm command (#2127)
# What does this PR do?

```
llama stack rm llamastack-test
```

#225 

## Test Plan
2025-05-21 10:25:51 +02:00
grs
091d8c48f2
feat: add additional auth provider that uses oauth token introspection (#2187)
# What does this PR do?

This adds an alternative option to the oauth_token auth provider that
can be used with existing authorization services which support token
introspection as defined in RFC 7662. This could be useful where token
revocation needs to be handled or where opaque tokens (or other non-JWT
formatted tokens) are used.
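RFC 7662 introspection is a single POST to the introspection endpoint; a hedged sketch of what such a provider does, where the endpoint URL and client credentials are placeholders, not the actual implementation:

```
# Hedged sketch of RFC 7662 token introspection; the endpoint URL and
# client credentials are placeholders, not the provider's actual config.
import httpx

def introspect(token: str) -> dict:
    resp = httpx.post(
        "https://keycloak.example.com/realms/demo/protocol/openid-connect/token/introspect",
        data={"token": token},
        auth=("llama-stack", "client-secret"),  # client_id / client_secret
    )
    resp.raise_for_status()
    claims = resp.json()
    if not claims.get("active"):
        raise PermissionError("token is not active")
    return claims
```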

## Test Plan
Tested against keycloak

Signed-off-by: Gordon Sim <gsim@redhat.com>
2025-05-20 19:45:11 -07:00
grs
87a4b9cb28
fix: synchronize concurrent coroutines checking & updating key set (#2215)
# What does this PR do?

This PR adds a lock to coordinate concurrent coroutines passing through
the jwt verification. As _refresh_jwks() was setting _jwks to an empty
dict then repopulating it, having multiple coroutines doing this
concurrently risks losing keys. The PR also builds the updated dict as a
separate object and assigns it to _jwks once completed. This avoids
impacting any coroutines using the key set as it is being updated.
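The described pattern, building the new key set off to the side and swapping it in under a lock, looks roughly like this (a sketch, not the actual code):

```
# Sketch of the described pattern: refresh into a separate dict under an
# asyncio.Lock, then swap atomically so readers never see a half-empty set.
import asyncio

class JWKSCache:
    def __init__(self):
        self._jwks: dict[str, str] = {}
        self._lock = asyncio.Lock()

    async def refresh(self, fetch_keys) -> None:
        async with self._lock:
            new_jwks = dict(await fetch_keys())  # build separately...
            self._jwks = new_jwks                # ...then assign once
```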

Signed-off-by: Gordon Sim <gsim@redhat.com>
2025-05-20 10:00:44 -07:00
Derek Higgins
3339844fda
feat: Add "instructions" support to responses API (#2205)
# What does this PR do?
Add support for "instructions" to the responses API. Instructions
provide a way to swap out system (or developer) messages in new
responses.
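From the caller's side this mirrors the OpenAI Responses API parameter; a minimal sketch against a running stack, where the base URL and model are assumptions:

```
# Minimal sketch of the new parameter, assuming a stack on localhost:8321
# and a registered Llama model; mirrors the OpenAI Responses API shape.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
resp = client.responses.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    instructions="You are a terse assistant. Answer in one sentence.",
    input="What is Llama Stack?",
)
print(resp.output_text)
```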


## Test Plan
unit tests added

Signed-off-by: Derek Higgins <derekh@redhat.com>
2025-05-20 09:52:10 -07:00
Jash Gulabrai
1a770cf8ac
fix: Pass model parameter as config name to NeMo Customizer (#2218)
# What does this PR do?
When launching a fine-tuning job, an upcoming version of NeMo Customizer
will expect the `config` name to be formatted as
`namespace/name@version`. Here, `config` is a reference to a model +
additional metadata. There could be multiple `config`s that reference
the same base model.

This PR updates NVIDIA's `supervised_fine_tune` to simply pass the
`model` param as-is to NeMo Customizer. Currently, it expects a
specific, allowlisted llama model (i.e. `meta/Llama3.1-8B-Instruct`) and
converts it to the provider format (`meta/llama-3.1-8b-instruct`).


## Test Plan
From a notebook, I built an image with my changes: 
```
!llama stack build --template nvidia --image-type venv
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```
And could successfully launch a job:
```
response = client.post_training.supervised_fine_tune(
    job_uuid="",
    model="meta/llama-3.2-1b-instruct@v1.0.0+A100", # Model passed as-is to Customimzer
    ...
)

job_id = response.job_uuid
print(f"Created job with ID: {job_id}")

Output:
Created job with ID: cust-Jm4oGmbwcvoufaLU4XkrRU
```


---------

Co-authored-by: Jash Gulabrai <jgulabrai@nvidia.com>
2025-05-20 09:51:39 -07:00
Sébastien Han
2eae8568e1
chore: collapse all local hook under the same repo (#2217)
Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-20 09:51:09 -07:00
Sébastien Han
3f6368d56c
ci: enable ruff output format for github (#2214)
# What does this PR do?

Update output format to enable automatic inline annotations.

![Screenshot 2025-05-20 at 10 55 38](https://github.com/user-attachments/assets/f943aa00-9b60-4cdb-b434-67b2de8b79f2)

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-20 09:04:03 -07:00
Francisco Arceo
90d7612f5f
chore: Updated readme (#2219)
# What does this PR do?
chore: Updated readme

## Test Plan

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
2025-05-20 17:06:20 +02:00
Francisco Arceo
ed7b4731aa
fix: Setting default value for metadata_token_count in case the key is not found (#2199)
# What does this PR do?
If a user has previously serialized data into their vector store without
the `metadata_token_count` in the chunk, the `query` method will fail with
a server error. This fixes that edge case by returning 0 when the key is
not detected. This solution is suboptimal, but I think it's better to
understate the token size rather than recalculate it and add unnecessary
complexity to the retrieval code.
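The fix reduces to a defaulted dictionary lookup; a generic sketch, not the exact retrieval code:

```
# Generic sketch of the fix: default to 0 when the serialized chunk
# predates the metadata_token_count key, instead of raising KeyError.
old_chunk_metadata = {"document_id": "doc-1"}  # serialized before the key existed

metadata_token_count = old_chunk_metadata.get("metadata_token_count", 0)
assert metadata_token_count == 0  # understates tokens rather than erroring
```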

## Test Plan

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
2025-05-20 08:03:22 -04:00
Ben Browning
6d20b720b8
feat: Propagate W3C trace context headers from clients (#2153)
# What does this PR do?

This extracts the W3C trace context headers (traceparent and tracestate)
from incoming requests, stuffs them as attributes on the spans we
create, and uses them within the tracing provider implementation to
actually wrap our spans in the proper context.

What this means in practice is that when a client (such as an OpenAI
client) is instrumented to create these traces, we'll continue that
distributed trace within Llama Stack as opposed to creating our own root
span that breaks the distributed trace between client and server.

It's slightly awkward to do this in Llama Stack because our Tracing API
knows nothing about opentelemetry, W3C trace headers, etc -- that's only
knowledge the specific provider implementation has. So, that's why the
trace headers get extracted in the server code but are not actually used
until the provider implementation forms the proper context.

This also centralizes how we were adding the `__root__` and
`__root_span__` attributes, as those two were being added in different
parts of the code instead of from a single place.
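With opentelemetry, continuing an incoming trace looks roughly like this (a sketch of the mechanism, not the provider's code):

```
# Sketch of the mechanism: extract W3C context from incoming headers and
# start our span inside it, continuing the client's distributed trace.
from opentelemetry import trace
from opentelemetry.propagate import extract

tracer = trace.get_tracer("llama-stack")

def handle_request(headers: dict[str, str]):
    ctx = extract(headers)  # reads traceparent / tracestate if present
    with tracer.start_as_current_span("llama-stack-request", context=ctx):
        ...  # handle the request inside the joined trace
```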

Closes #2097

## Test Plan

This was tested manually using the helpful scripts from #2097. I
verified that Llama Stack properly joined the client's span when the
client was instrumented for distributed tracing, and that Llama Stack
properly started its own root span when the incoming request was not
part of an existing trace.

Here's an example of the joined spans:

![Screenshot 2025-05-13 at 8 46 09 AM](https://github.com/user-attachments/assets/dbefda28-9faa-4339-a08d-1441efefc149)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
2025-05-19 18:56:54 -07:00
Sébastien Han
82778ecbb0
fix: remove wrong deprecated warning (#2202)
# What does this PR do?

`--yaml-config` is gone now with
https://github.com/meta-llama/llama-stack/pull/2196.

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-05-19 13:02:23 -07:00
Michael Anstis
0cc0731189
fix: Pass external_config_dir to BuildConfig (#2190)
# What does this PR do?

The `external_config_dir` configuration parameter is not being passed to
the `BuildConfig` for `LlamaStackAsLibraryClient`.

This prevents _plugin_ providers from being loaded when `llama-stack` is
used as a library.


## Test Plan
I ran `LlamaStackAsLibraryClient` with a configuration file that
contained `external_config_dir` and related configuration.

It does not work without this change: _external_ providers are not
resolved.

It does work with this change 👍 

2025-05-19 14:01:28 +02:00
ehhuang
047303e339
feat: introduce APIs for retrieving chat completion requests (#2145)
# What does this PR do?
This PR introduces APIs to retrieve past chat completion requests, which
will be used in the LS UI.

Our current `Telemetry` is ill-suited for this purpose as it's untyped
so we'd need to filter by obscure attribute names, making it brittle.

Since these APIs are 'provided by stack' and don't need to be
implemented by inference providers, we introduce a new InferenceProvider
class, containing the existing inference protocol, which is implemented
by inference providers.

The APIs are OpenAI-compliant, with an additional `input_messages`
field.


## Test Plan
This PR just adds the APIs and marks them provided_by_stack. Start the
stack server -> it doesn't crash.
2025-05-18 21:43:19 -07:00
Ashwin Bharambe
c7015d3d60
feat: introduce OAuth2TokenAuthProvider and notion of "principal" (#2185)
This PR adds a notion of `principal` (aka some kind of persistent
identity) to the authentication infrastructure of the Stack. Until now
we only used access attributes ("claims" in the more standard OAuth /
OIDC setup) but we need the notion of a User fundamentally as well.
(Thanks @rhuss for bringing this up.)

This value is not yet _used_ anywhere downstream but will be used to
segregate access to resources.

In addition, the PR introduces a built-in JWT token validator so the
Stack does not need to contact an authentication provider to validate
the authorization and can merely check the signed token for the
represented claims. Public keys are refreshed via the configured JWKS server. This
Auth Provider should overwhelmingly be considered the default given the
seamless integration it offers with OAuth setups.
2025-05-18 17:54:19 -07:00
dependabot[bot]
1341916caf
chore(github-deps): bump astral-sh/setup-uv from 5.4.1 to 6.0.1 (#2197) 2025-05-18 02:09:56 -04:00
Matthew Farrellee
f40693e720
feat: --image-type argument overrides value in --config build.yaml (#2179)
closes #2162

# test plan

run `llama stack build --image-name ollama --image-type
<venv/conda/container> --config llama_stack/templates/ollama/build.yaml`
and verify venv | conda | container are built.
2025-05-16 14:45:41 -07:00
Charlie Doern
f02f7b28c1
feat: add huggingface post_training impl (#2132)
# What does this PR do?


Adds an inline HF SFTTrainer provider. Alongside torchtune, this is a
super popular option for running training jobs. The config allows a user
to specify some key fields such as a model, chat_template, device, etc.

the provider comes with one recipe `finetune_single_device` which works
both with and without LoRA.

any model that is a valid HF identifier can be given and the model will
be pulled.

this has been tested so far with CPU and MPS device types, but should be
compatible with CUDA out of the box

The provider processes the given dataset into the proper format,
establishes the various steps per epoch, steps per save, steps per eval,
sets a sane SFTConfig, and runs n_epochs of training

if checkpoint_dir is none, no model is saved. If there is a checkpoint
dir, a model is saved every `save_steps` and at the end of training.


## Test Plan

re-enabled post_training integration test suite with a singular test
that loads the simpleqa dataset:
https://huggingface.co/datasets/llamastack/simpleqa and a tiny granite
model: https://huggingface.co/ibm-granite/granite-3.3-2b-instruct. The
test now uses the llama stack client and the proper post_training API

runs one step with a batch_size of 1. This test runs on CPU on the
Ubuntu runner so it needs to be a small batch and a single step.

[//]: # (## Documentation)

---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
2025-05-16 14:41:28 -07:00
Matthew Farrellee
8f9964f46b
fix: update llama stack build --run to use new start_stack.sh signature (#2191)
# What does this PR do?
fixes #2188

## Test Plan
```
INFERENCE_MODEL=meta-llama/Llama-3.3-70B-Instruct llama stack build --image-name ollama --image-type conda --template ollama --run
```
completes without error
2025-05-16 14:32:02 -07:00
Charlie Doern
1ae61e8d5f
fix: replace all instances of --yaml-config with --config (#2196)
# What does this PR do?

start_stack.sh was using --yaml-config, which is deprecated.

A bunch of distro docs also mentioned --yaml-config. This replaces all
instances of, and logic for, --yaml-config with --config.

resolves #2189

Signed-off-by: Charlie Doern <cdoern@redhat.com>
2025-05-16 14:31:12 -07:00
351 changed files with 25356 additions and 8965 deletions


@@ -1,10 +1,8 @@
 # What does this PR do?
-[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.]
+<!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. -->
-[//]: # (If resolving an issue, uncomment and update the line below)
-[//]: # (Closes #[issue-number])
+<!-- If resolving an issue, uncomment and update the line below -->
+<!-- Closes #[issue-number] -->
 ## Test Plan
-[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]
-[//]: # (## Documentation)
+<!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->

.github/actions/setup-runner/action.yml (new file, 22 lines)

@@ -0,0 +1,22 @@
name: Setup runner
description: Prepare a runner for the tests (install uv, python, project dependencies, etc.)
runs:
  using: "composite"
  steps:
    - name: Install uv
      uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
      with:
        python-version: "3.10"
        activate-environment: true
        version: 0.7.6
    - name: Install dependencies
      shell: bash
      run: |
        uv sync --all-groups
        uv pip install ollama faiss-cpu
        # always test against the latest version of the client
        # TODO: this is not necessarily a good idea. we need to test against both published and latest
        # to find out backwards compatibility issues.
        uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
        uv pip install -e .

.github/workflows/Dockerfile (new file, 1 line)

@@ -0,0 +1 @@
FROM localhost:5000/distribution-kvant:dev

.github/workflows/ci-playground.yaml (new file, 73 lines)

@@ -0,0 +1,73 @@
name: Build and Push playground container
run-name: Build and Push playground container
on:
  workflow_dispatch:
  #schedule:
  #  - cron: "0 10 * * *"
  push:
    branches:
      - main
      - kvant
    tags:
      - 'v*'
  pull_request:
    branches:
      - main
      - kvant
env:
  IMAGE: git.kvant.cloud/${{github.repository}}-playground
jobs:
  build-playground:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Set current time
        uses: https://github.com/gerred/actions/current-time@master
        id: current_time
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to git.kvant.cloud registry
        uses: docker/login-action@v3
        with:
          registry: git.kvant.cloud
          username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }}
          password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }}
      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          # list of Docker images to use as base name for tags
          images: |
            ${{env.IMAGE}}
          # generate Docker tags based on the following events/attributes
          tags: |
            type=schedule
            type=ref,event=branch
            type=ref,event=pr
            type=ref,event=tag
            type=semver,pattern={{version}}
      - name: Build and push to gitea registry
        uses: docker/build-push-action@v6
        with:
          push: ${{ github.event_name != 'pull_request' }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          context: .
          file: llama_stack/distribution/ui/Containerfile
          provenance: mode=max
          sbom: true
          build-args: |
            BUILD_DATE=${{ steps.current_time.outputs.time }}
          cache-from: |
            type=registry,ref=${{ env.IMAGE }}:buildcache
            type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }}
            type=registry,ref=${{ env.IMAGE }}:main
          cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true

.github/workflows/ci.yaml (new file, 98 lines)

@@ -0,0 +1,98 @@
name: Build and Push container
run-name: Build and Push container
on:
  workflow_dispatch:
  #schedule:
  #  - cron: "0 10 * * *"
  push:
    branches:
      - main
      - kvant
    tags:
      - 'v*'
  pull_request:
    branches:
      - main
      - kvant
env:
  IMAGE: git.kvant.cloud/${{github.repository}}
jobs:
  build:
    runs-on: ubuntu-latest
    services:
      registry:
        image: registry:2
        ports:
          - 5000:5000
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Set current time
        uses: https://github.com/gerred/actions/current-time@master
        id: current_time
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
          driver-opts: network=host
      - name: Login to git.kvant.cloud registry
        uses: docker/login-action@v3
        with:
          registry: git.kvant.cloud
          username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }}
          password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }}
      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          # list of Docker images to use as base name for tags
          images: |
            ${{env.IMAGE}}
          # generate Docker tags based on the following events/attributes
          tags: |
            type=schedule
            type=ref,event=branch
            type=ref,event=pr
            type=ref,event=tag
            type=semver,pattern={{version}}
      - name: Install uv
        uses: https://github.com/astral-sh/setup-uv@v5
        with:
          # Install a specific version of uv.
          version: "0.7.8"
      - name: Build
        env:
          USE_COPY_NOT_MOUNT: true
          LLAMA_STACK_DIR: .
        run: |
          uvx --from . llama stack build --template kvant --image-type container
          # docker tag distribution-kvant:dev ${{env.IMAGE}}:kvant
          # docker push ${{env.IMAGE}}:kvant
          docker tag distribution-kvant:dev localhost:5000/distribution-kvant:dev
          docker push localhost:5000/distribution-kvant:dev
      - name: Build and push to gitea registry
        uses: docker/build-push-action@v6
        with:
          push: ${{ github.event_name != 'pull_request' }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          context: .github/workflows
          provenance: mode=max
          sbom: true
          build-args: |
            BUILD_DATE=${{ steps.current_time.outputs.time }}
          cache-from: |
            type=registry,ref=${{ env.IMAGE }}:buildcache
            type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }}
            type=registry,ref=${{ env.IMAGE }}:main
          cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true


@@ -23,23 +23,18 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        auth-provider: [kubernetes]
+        auth-provider: [oauth2_token]
       fail-fast: false # we want to run all tests regardless of failure
     steps:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-      - name: Install uv
-        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
-        with:
-          python-version: "3.10"
-          activate-environment: true
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner
-      - name: Set Up Environment and Install Dependencies
+      - name: Build Llama Stack
         run: |
-          uv sync --extra dev --extra test
-          uv pip install -e .
+          llama stack build --template ollama --image-type venv
       - name: Install minikube
@@ -47,29 +42,53 @@ jobs:
         uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19
       - name: Start minikube
-        if: ${{ matrix.auth-provider == 'kubernetes' }}
+        if: ${{ matrix.auth-provider == 'oauth2_token' }}
         run: |
          minikube start
          kubectl get pods -A
       - name: Configure Kube Auth
-        if: ${{ matrix.auth-provider == 'kubernetes' }}
+        if: ${{ matrix.auth-provider == 'oauth2_token' }}
         run: |
          kubectl create namespace llama-stack
          kubectl create serviceaccount llama-stack-auth -n llama-stack
          kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
          kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
+          cat <<EOF | kubectl apply -f -
+          apiVersion: rbac.authorization.k8s.io/v1
+          kind: ClusterRole
+          metadata:
+            name: allow-anonymous-openid
+          rules:
+          - nonResourceURLs: ["/openid/v1/jwks"]
+            verbs: ["get"]
+          ---
+          apiVersion: rbac.authorization.k8s.io/v1
+          kind: ClusterRoleBinding
+          metadata:
+            name: allow-anonymous-openid
+          roleRef:
+            apiGroup: rbac.authorization.k8s.io
+            kind: ClusterRole
+            name: allow-anonymous-openid
+          subjects:
+          - kind: User
+            name: system:anonymous
+            apiGroup: rbac.authorization.k8s.io
+          EOF
       - name: Set Kubernetes Config
-        if: ${{ matrix.auth-provider == 'kubernetes' }}
+        if: ${{ matrix.auth-provider == 'oauth2_token' }}
         run: |
-          echo "KUBERNETES_API_SERVER_URL=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.server}')" >> $GITHUB_ENV
+          echo "KUBERNETES_API_SERVER_URL=$(kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri)" >> $GITHUB_ENV
           echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV
+          echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV
+          echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV
       - name: Set Kube Auth Config and run server
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
-        if: ${{ matrix.auth-provider == 'kubernetes' }}
+        if: ${{ matrix.auth-provider == 'oauth2_token' }}
         run: |
          run_dir=$(mktemp -d)
          cat <<'EOF' > $run_dir/run.yaml
@@ -81,10 +100,10 @@ jobs:
             port: 8321
          EOF
          yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
-          yq eval '.server.auth.config = {"api_server_url": "${{ env.KUBERNETES_API_SERVER_URL }}", "ca_cert_path": "${{ env.KUBERNETES_CA_CERT_PATH }}"}' -i $run_dir/run.yaml
+          yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
+          yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml
          cat $run_dir/run.yaml
-          source .venv/bin/activate
          nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
       - name: Wait for Llama Stack server to be ready
View file

@ -24,7 +24,7 @@ jobs:
matrix:
# Listing tests manually since some of them currently fail
# TODO: generate matrix list from tests/integration when fixed
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers]
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
client-type: [library, http]
fail-fast: false # we want to run all tests regardless of failure
@ -32,24 +32,14 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install uv
uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
with:
python-version: "3.10"
activate-environment: true
- name: Install dependencies
uses: ./.github/actions/setup-runner
- name: Setup ollama
uses: ./.github/actions/setup-ollama
- name: Set Up Environment and Install Dependencies
- name: Build Llama Stack
run: |
uv sync --extra dev --extra test
uv pip install ollama faiss-cpu
# always test against the latest version of the client
# TODO: this is not necessarily a good idea. we need to test against both published and latest
# to find out backwards compatibility issues.
uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
uv pip install -e .
llama stack build --template ollama --image-type venv
- name: Start Llama Stack server in background
@ -57,8 +47,7 @@ jobs:
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
run: |
source .venv/bin/activate
nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 &
LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &
- name: Wait for Llama Stack server to be ready
if: matrix.client-type == 'http'
@ -86,6 +75,12 @@ jobs:
exit 1
fi
- name: Check Storage and Memory Available Before Tests
if: ${{ always() }}
run: |
free -h
df -h
- name: Run Integration Tests
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
@ -95,17 +90,24 @@ jobs:
else
stack_config="http://localhost:8321"
fi
uv run pytest -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
-k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
--text-model="meta-llama/Llama-3.2-3B-Instruct" \
--embedding-model=all-MiniLM-L6-v2
- name: Check Storage and Memory Available After Tests
if: ${{ always() }}
run: |
free -h
df -h
- name: Write ollama logs to file
if: ${{ always() }}
run: |
sudo journalctl -u ollama.service > ollama.log
- name: Upload all logs to artifacts
if: always()
if: ${{ always() }}
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}

View file

@ -29,6 +29,7 @@ jobs:
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
env:
SKIP: no-commit-to-branch
RUFF_OUTPUT_FORMAT: github
- name: Verify if there are any diff files after pre-commit
run: |

View file

@ -50,21 +50,8 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
with:
python-version: "3.10"
- name: Install LlamaStack
run: |
uv venv
source .venv/bin/activate
uv pip install -e .
- name: Install dependencies
uses: ./.github/actions/setup-runner
- name: Print build dependencies
run: |
@ -79,7 +66,6 @@ jobs:
- name: Print dependencies in the image
if: matrix.image-type == 'venv'
run: |
source test/bin/activate
uv pip list
build-single-provider:
@ -88,21 +74,8 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
with:
python-version: "3.10"
- name: Install LlamaStack
run: |
uv venv
source .venv/bin/activate
uv pip install -e .
- name: Install dependencies
uses: ./.github/actions/setup-runner
- name: Build a single provider
run: |
@ -114,21 +87,8 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
with:
python-version: "3.10"
- name: Install LlamaStack
run: |
uv venv
source .venv/bin/activate
uv pip install -e .
- name: Install dependencies
uses: ./.github/actions/setup-runner
- name: Build a single provider
run: |
@ -152,21 +112,8 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
with:
python-version: "3.10"
- name: Install LlamaStack
run: |
uv venv
source .venv/bin/activate
uv pip install -e .
- name: Install dependencies
uses: ./.github/actions/setup-runner
- name: Pin template to UBI9 base
run: |

View file

@ -25,15 +25,8 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install uv
uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
with:
python-version: "3.10"
- name: Set Up Environment and Install Dependencies
run: |
uv sync --extra dev --extra test
uv pip install -e .
- name: Install dependencies
uses: ./.github/actions/setup-runner
- name: Apply image type to config file
run: |
@ -59,7 +52,6 @@ jobs:
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
run: |
source ci-test/bin/activate
uv run pip list
nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &

View file

@ -30,17 +30,11 @@ jobs:
- "3.12"
- "3.13"
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: ${{ matrix.python }}
- uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
with:
python-version: ${{ matrix.python }}
enable-cache: false
- name: Install dependencies
uses: ./.github/actions/setup-runner
- name: Run unit tests
run: |

View file

@ -37,16 +37,8 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: '3.11'
- name: Install the latest version of uv
uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
- name: Sync with uv
run: uv sync --extra docs
- name: Install dependencies
uses: ./.github/actions/setup-runner
- name: Build HTML
run: |

2 .gitignore vendored
View file

@ -6,6 +6,7 @@ dev_requirements.txt
build
.DS_Store
llama_stack/configs/*
.cursor/
xcuserdata/
*.hmap
.DS_Store
@ -23,3 +24,4 @@ venv/
pytest-report.xml
.coverage
.python-version
data

View file

@ -53,7 +53,7 @@ repos:
- black==24.3.0
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.6.3
rev: 0.7.8
hooks:
- id: uv-lock
- id: uv-export
@ -61,6 +61,7 @@ repos:
"--frozen",
"--no-hashes",
"--no-emit-project",
"--no-default-groups",
"--output-file=requirements.txt"
]
@ -88,20 +89,17 @@ repos:
- id: distro-codegen
name: Distribution Template Codegen
additional_dependencies:
- uv==0.6.0
entry: uv run --extra codegen ./scripts/distro_codegen.py
- uv==0.7.8
entry: uv run --group codegen ./scripts/distro_codegen.py
language: python
pass_filenames: false
require_serial: true
files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$
- repo: local
hooks:
- id: openapi-codegen
name: API Spec Codegen
additional_dependencies:
- uv==0.6.2
entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
- uv==0.7.8
entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
language: python
pass_filenames: false
require_serial: true

View file

@ -5,28 +5,21 @@
# Required
version: 2
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: docs/source/conf.py
# Set the OS, Python version and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.12"
# You can also specify other tool versions:
# nodejs: "19"
# rust: "1.64"
# golang: "1.19"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: docs/source/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
# formats:
# - pdf
# - epub
# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: docs/requirements.txt
jobs:
pre_create_environment:
- asdf plugin add uv
- asdf install uv latest
- asdf global uv latest
create_environment:
- uv venv "${READTHEDOCS_VIRTUALENV_PATH}"
install:
- UV_PROJECT_ENVIRONMENT="${READTHEDOCS_VIRTUALENV_PATH}" uv sync --frozen --group docs

View file

@ -1,5 +1,26 @@
# Changelog
# v0.2.7
Published on: 2025-05-16T20:38:10Z
## Highlights
This is a small update. But a couple highlights:
* feat: function tools in OpenAI Responses by @bbrowning in https://github.com/meta-llama/llama-stack/pull/2094, getting closer to ready. Streaming is the next missing piece.
* feat: Adding support for customizing chunk context in RAG insertion and querying by @franciscojavierarceo in https://github.com/meta-llama/llama-stack/pull/2134
* feat: scaffolding for Llama Stack UI by @ehhuang in https://github.com/meta-llama/llama-stack/pull/2149, more to come in the coming releases.
---
# v0.2.6
Published on: 2025-05-12T18:06:52Z
---
# v0.2.5
Published on: 2025-05-04T20:16:49Z

View file

@ -167,14 +167,11 @@ If you have made changes to a provider's configuration in any form (introducing
If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.
```bash
cd docs
uv sync --extra docs
# This rebuilds the documentation pages.
uv run make html
uv run --group docs make -C docs/ html
# This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
uv run sphinx-autobuild source build/html --write-all
uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
```
### Update API Documentation
@ -182,7 +179,7 @@ uv run sphinx-autobuild source build/html --write-all
If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:
```bash
uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
uv run ./docs/openapi_generator/run_openapi_generator.sh
```
The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.

View file

@ -1,5 +1,4 @@
include pyproject.toml
include llama_stack/templates/dependencies.json
include llama_stack/models/llama/llama3/tokenizer.model
include llama_stack/models/llama/llama4/tokenizer.model
include llama_stack/distribution/*.sh

View file

@ -107,26 +107,29 @@ By reducing friction and complexity, Llama Stack empowers developers to focus on
### API Providers
Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ |
| SambaNova | Hosted | | ✅ | | | |
| Cerebras | Hosted | | ✅ | | | |
| Fireworks | Hosted | ✅ | ✅ | ✅ | | |
| AWS Bedrock | Hosted | | ✅ | | ✅ | |
| Together | Hosted | ✅ | ✅ | | ✅ | |
| Groq | Hosted | | ✅ | | | |
| Ollama | Single Node | | ✅ | | | |
| TGI | Hosted and Single Node | | ✅ | | | |
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
| Chroma | Single Node | | | ✅ | | |
| PG Vector | Single Node | | | ✅ | | |
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
| vLLM | Hosted and Single Node | | ✅ | | | |
| OpenAI | Hosted | | ✅ | | | |
| Anthropic | Hosted | | ✅ | | | |
| Gemini | Hosted | | ✅ | | | |
| watsonx | Hosted | | ✅ | | | |
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | **Post Training** |
|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-----------------:|
| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | |
| SambaNova | Hosted | | ✅ | | ✅ | | |
| Cerebras | Hosted | | ✅ | | | | |
| Fireworks | Hosted | ✅ | ✅ | ✅ | | | |
| AWS Bedrock | Hosted | | ✅ | | ✅ | | |
| Together | Hosted | ✅ | ✅ | | ✅ | | |
| Groq | Hosted | | ✅ | | | | |
| Ollama | Single Node | | ✅ | | | | |
| TGI | Hosted and Single Node | | ✅ | | | | |
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | | |
| Chroma | Single Node | | | ✅ | | | |
| PG Vector | Single Node | | | ✅ | | | |
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | |
| vLLM | Hosted and Single Node | | ✅ | | | | |
| OpenAI | Hosted | | ✅ | | | | |
| Anthropic | Hosted | | ✅ | | | | |
| Gemini | Hosted | | ✅ | | | | |
| watsonx | Hosted | | ✅ | | | | |
| HuggingFace | Single Node | | | | | | ✅ |
| TorchTune | Single Node | | | | | | ✅ |
| NVIDIA NEMO | Hosted | | | | | | ✅ |
### Distributions

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@ -759,7 +759,7 @@ class Generator:
)
return Operation(
tags=[op.defining_class.__name__],
tags=[getattr(op.defining_class, "API_NAMESPACE", op.defining_class.__name__)],
summary=None,
# summary=doc_string.short_description,
description=description,
@ -805,6 +805,8 @@ class Generator:
operation_tags: List[Tag] = []
for cls in endpoint_classes:
doc_string = parse_type(cls)
if hasattr(cls, "API_NAMESPACE") and cls.API_NAMESPACE != cls.__name__:
continue
operation_tags.append(
Tag(
name=cls.__name__,

View file

@ -3,10 +3,10 @@
Here's a collection of comprehensive guides, examples, and resources for building AI applications with Llama Stack. For the complete documentation, visit our [ReadTheDocs page](https://llama-stack.readthedocs.io/en/latest/index.html).
## Render locally
From the llama-stack root directory, run the following command to render the docs locally:
```bash
pip install -r requirements.txt
cd docs
python -m sphinx_autobuild source _build
uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
```
You can open up the docs in your browser at http://localhost:8000

View file

@ -1,16 +0,0 @@
linkify
myst-parser
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
sphinx==8.1.3
sphinx-copybutton
sphinx-design
sphinx-pdj-theme
sphinx-rtd-theme>=1.0.0
sphinx-tabs
sphinx_autobuild
sphinx_rtd_dark_mode
sphinxcontrib-mermaid
sphinxcontrib-openapi
sphinxcontrib-redoc
sphinxcontrib-video
tomli

View file

@ -57,6 +57,31 @@ chunks = [
]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
```
#### Using Precomputed Embeddings
If you decide to precompute embeddings for your documents, you can insert them directly into the vector database by
including the embedding vectors in the chunk data. This is useful if you have a separate embedding service or if you
want to customize the ingestion process.
```python
chunks_with_embeddings = [
{
"content": "First chunk of text",
"mime_type": "text/plain",
"embedding": [0.1, 0.2, 0.3, ...], # Your precomputed embedding vector
"metadata": {"document_id": "doc1", "section": "introduction"},
},
{
"content": "Second chunk of text",
"mime_type": "text/plain",
"embedding": [0.2, 0.3, 0.4, ...], # Your precomputed embedding vector
"metadata": {"document_id": "doc1", "section": "methodology"},
},
]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings)
```
When providing precomputed embeddings, ensure the embedding dimension matches the embedding_dimension specified when
registering the vector database.
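As a concrete guard, you might validate the embeddings client-side before inserting. A minimal sketch (the `check_embedding_dims` helper and the 384 default are illustrative assumptions, not part of the API; 384 matches the all-MiniLM-L6-v2 model used elsewhere in these docs):
```python
# Hypothetical pre-insert check: every precomputed embedding must match
# the embedding_dimension the vector database was registered with.
def check_embedding_dims(chunks: list[dict], expected_dim: int = 384) -> None:
    for i, chunk in enumerate(chunks):
        embedding = chunk.get("embedding")
        if embedding is not None and len(embedding) != expected_dim:
            raise ValueError(
                f"chunk {i}: got {len(embedding)} dimensions, expected {expected_dim}"
            )


check_embedding_dims(chunks_with_embeddings)
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings)
```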
### Retrieval
You can query the vector database to retrieve documents based on their embeddings.
```python

View file

@ -22,7 +22,11 @@ from docutils import nodes
# Read version from pyproject.toml
with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
pypi_url = "https://pypi.org/pypi/llama-stack/json"
version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"]
headers = {
'User-Agent': 'pip/23.0.1 (python 3.11)', # Mimic pip's user agent
'Accept': 'application/json'
}
version_tag = json.loads(requests.get(pypi_url, headers=headers).text)["info"]["version"]
print(f"{version_tag=}")
# generate the full link including text and url here
@ -53,14 +57,6 @@ myst_enable_extensions = ["colon_fence"]
html_theme = "sphinx_rtd_theme"
html_use_relative_paths = True
# html_theme = "sphinx_pdj_theme"
# html_theme_path = [sphinx_pdj_theme.get_html_theme_path()]
# html_theme = "pytorch_sphinx_theme"
# html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

View file

@ -338,6 +338,48 @@ INFO: Application startup complete.
INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
INFO: 2401:db00:35c:2d2b:face:0:c9:0:54678 - "GET /models/list HTTP/1.1" 200 OK
```
### Listing Distributions
Using the list command, you can view all existing Llama Stack distributions, including stacks built from templates, from scratch, or using custom configuration files.
```
llama stack list -h
usage: llama stack list [-h]
list the build stacks
options:
-h, --help show this help message and exit
```
Example Usage
```
llama stack list
```
### Removing a Distribution
Use the remove command to delete a distribution you've previously built.
```
llama stack rm -h
usage: llama stack rm [-h] [--all] [name]
Remove the build stack
positional arguments:
name Name of the stack to delete (default: None)
options:
-h, --help show this help message and exit
--all, -a Delete all stacks (use with caution) (default: False)
```
Example
```
llama stack rm llamastack-test
```
To keep your environment organized and avoid clutter, consider using `llama stack list` to review old or unused distributions and `llama stack rm <name>` to delete them when they're no longer needed.
### Troubleshooting

View file

@ -118,11 +118,6 @@ server:
port: 8321 # Port to listen on (default: 8321)
tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
auth: # Optional: Authentication configuration
provider_type: "kubernetes" # Type of auth provider
config: # Provider-specific configuration
api_server_url: "https://kubernetes.default.svc"
ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate
```
### Authentication Configuration
@ -135,7 +130,7 @@ Authorization: Bearer <token>
The server supports multiple authentication providers:
#### Kubernetes Provider
#### OAuth 2.0/OpenID Connect Provider with Kubernetes
The Kubernetes cluster must be configured to use a service account for authentication.
@ -146,14 +141,67 @@ kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --se
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
```
Validates tokens against the Kubernetes API server:
Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests,
and that a ClusterRoleBinding lets the anonymous user read the JWKS endpoint (`/openid/v1/jwks`).
If that is not the case, you can create one:
```yaml
# allow-anonymous-openid.yaml
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
name: allow-anonymous-openid
rules:
- nonResourceURLs: ["/openid/v1/jwks"]
verbs: ["get"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
name: allow-anonymous-openid
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: allow-anonymous-openid
subjects:
- kind: User
name: system:anonymous
apiGroup: rbac.authorization.k8s.io
```
And then apply the configuration:
```bash
kubectl apply -f allow-anonymous-openid.yaml
```
Validates tokens against the Kubernetes API server through the OIDC provider:
```yaml
server:
auth:
provider_type: "kubernetes"
provider_type: "oauth2_token"
config:
api_server_url: "https://kubernetes.default.svc" # URL of the Kubernetes API server
ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate
jwks:
uri: "https://kubernetes.default.svc"
key_recheck_period: 3600
tls_cafile: "/path/to/ca.crt"
issuer: "https://kubernetes.default.svc"
audience: "https://kubernetes.default.svc"
```
To find your cluster's audience, run:
```bash
kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud
```
For the issuer, you can use the OIDC provider's URL:
```bash
kubectl get --raw /.well-known/openid-configuration| jq .issuer
```
For the tls_cafile, you can use the CA certificate of the OIDC provider:
```bash
kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}'
```
The provider extracts user information from the JWT token:
@ -208,6 +256,80 @@ And must respond with:
If no access attributes are returned, the token is used as a namespace.
### Quota Configuration
The `quota` section allows you to enable server-side request throttling for both
authenticated and anonymous clients. This is useful for preventing abuse, enforcing
fairness across tenants, and controlling infrastructure costs without requiring
client-side rate limiting or external proxies.
Quotas are disabled by default. When enabled, each client is tracked using either:
* Their authenticated `client_id` (derived from the Bearer token), or
* Their IP address (fallback for anonymous requests)
Quota state is stored in a SQLite-backed key-value store, and rate limits are applied
within a configurable time window (currently only `day` is supported).
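Conceptually, the mechanism is just a per-client, per-day counter persisted in SQLite. A minimal sketch of that idea (the table name, schema, and function are assumptions for illustration, not the server's actual code):
```python
import sqlite3
from datetime import date

conn = sqlite3.connect("./quotas.db")
conn.execute("CREATE TABLE IF NOT EXISTS quota (key TEXT PRIMARY KEY, count INTEGER NOT NULL)")


def allow_request(client_id: str, max_requests: int) -> bool:
    # The key embeds the calendar day, so each client's count implicitly
    # resets when the day rolls over (the only supported period).
    key = f"{client_id}:{date.today().isoformat()}"
    conn.execute("INSERT OR IGNORE INTO quota (key, count) VALUES (?, 0)", (key,))
    conn.execute("UPDATE quota SET count = count + 1 WHERE key = ?", (key,))
    (count,) = conn.execute("SELECT count FROM quota WHERE key = ?", (key,)).fetchone()
    conn.commit()
    return count <= max_requests
```
An anonymous caller would be keyed by IP address instead of `client_id`, with `anonymous_max_requests` as the limit.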
#### Example
```yaml
server:
quota:
kvstore:
type: sqlite
db_path: ./quotas.db
anonymous_max_requests: 100
authenticated_max_requests: 1000
period: day
```
#### Configuration Options
| Field | Description |
| ---------------------------- | -------------------------------------------------------------------------- |
| `kvstore` | Required. Backend storage config for tracking request counts. |
| `kvstore.type` | Must be `"sqlite"` for now. Other backends may be supported in the future. |
| `kvstore.db_path` | File path to the SQLite database. |
| `anonymous_max_requests` | Max requests per period for unauthenticated clients. |
| `authenticated_max_requests` | Max requests per period for authenticated clients. |
| `period` | Time window for quota enforcement. Only `"day"` is supported. |
> Note: if `authenticated_max_requests` is set but no authentication provider is configured, the server will fall back to applying `anonymous_max_requests` to all clients.
#### Example with Authentication Enabled
```yaml
server:
port: 8321
auth:
provider_type: custom
config:
endpoint: https://auth.example.com/validate
quota:
kvstore:
type: sqlite
db_path: ./quotas.db
anonymous_max_requests: 100
authenticated_max_requests: 1000
period: day
```
If a client exceeds their limit, the server responds with:
```http
HTTP/1.1 429 Too Many Requests
Content-Type: application/json
{
"error": {
"message": "Quota exceeded"
}
}
```
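Because the window is a full day, a client that receives this status should surface the error rather than retry immediately. A sketch of the client side (the endpoint path comes from the OpenAI-compatible API above; the model name and payload are placeholders):
```python
import requests

payload = {
    "model": "meta-llama/Llama-3.2-3B-Instruct",  # placeholder model
    "messages": [{"role": "user", "content": "hello"}],
}
resp = requests.post("http://localhost:8321/openai/v1/chat/completions", json=payload)
if resp.status_code == 429:
    # Daily quota exhausted; the counter only resets at the next period,
    # so an immediate retry will fail the same way.
    raise RuntimeError(resp.json()["error"]["message"])  # "Quota exceeded"
```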
## Extending to handle Safety
Configuring Safety can be a little involved so it is instructive to go through an example.

View file

@ -172,7 +172,7 @@ spec:
- name: llama-stack
image: localhost/llama-stack-run-k8s:latest
imagePullPolicy: IfNotPresent
command: ["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]
command: ["python", "-m", "llama_stack.distribution.server.server", "--config", "/app/config.yaml"]
ports:
- containerPort: 5000
volumeMounts:

View file

@ -70,7 +70,7 @@ docker run \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-watsonx \
--yaml-config /root/my-run.yaml \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env WATSONX_API_KEY=$WATSONX_API_KEY \
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \

View file

@ -52,7 +52,7 @@ docker run \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-cerebras \
--yaml-config /root/my-run.yaml \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```

View file

@ -155,7 +155,7 @@ docker run \
-v $HOME/.llama:/root/.llama \
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
llamastack/distribution-dell \
--yaml-config /root/my-run.yaml \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env DEH_URL=$DEH_URL \

View file

@ -143,7 +143,7 @@ docker run \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-nvidia \
--yaml-config /root/my-run.yaml \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
```

View file

@ -19,6 +19,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::ollama` |
| post_training | `inline::huggingface` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
@ -97,7 +98,7 @@ docker run \
-v ~/.llama:/root/.llama \
-v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
llamastack/distribution-ollama \
--yaml-config /root/my-run.yaml \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env SAFETY_MODEL=$SAFETY_MODEL \

View file

@ -233,7 +233,7 @@ docker run \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./llama_stack/templates/remote-vllm/run.yaml:/root/my-run.yaml \
llamastack/distribution-remote-vllm \
--yaml-config /root/my-run.yaml \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1
@ -255,7 +255,7 @@ docker run \
-v ~/.llama:/root/.llama \
-v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \
llamastack/distribution-remote-vllm \
--yaml-config /root/my-run.yaml \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1 \

View file

@ -17,7 +17,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
|-----|-------------|
| agents | `inline::meta-reference` |
| inference | `remote::sambanova`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| safety | `remote::sambanova` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@ -48,33 +48,44 @@ The following models are available by default:
### Prerequisite: API Keys
Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/).
Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup).
## Running Llama Stack with SambaNova
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
### Via Docker
```bash
LLAMA_STACK_PORT=8321
llama stack build --template sambanova --image-type container
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-sambanova \
-v ~/.llama:/root/.llama \
distribution-sambanova \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```
### Via Venv
```bash
llama stack build --template sambanova --image-type venv
llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```
### Via Conda
```bash
llama stack build --template sambanova --image-type conda
llama stack run ./run.yaml \
llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```

View file

@ -117,7 +117,7 @@ docker run \
-v ~/.llama:/root/.llama \
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
llamastack/distribution-tgi \
--yaml-config /root/my-run.yaml \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env TGI_URL=http://host.docker.internal:$INFERENCE_PORT \

View file

@ -30,6 +30,18 @@ Runs inference with an LLM.
## Post Training
Fine-tunes a model.
#### Post Training Providers
The following providers are available for Post Training:
```{toctree}
:maxdepth: 1
external
post_training/huggingface
post_training/torchtune
post_training/nvidia_nemo
```
## Safety
Applies safety policies to the output at a Systems (not only model) level.

View file

@ -0,0 +1,122 @@
---
orphan: true
---
# HuggingFace SFTTrainer
[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post training provider for Llama Stack. It allows you to run supervised fine-tuning on a variety of models using many datasets.
## Features
- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support, CPU support, and MPS support (MacOS Metal Performance Shaders)
## Usage
To use the HF SFTTrainer in your Llama Stack project, follow these steps:
1. Configure your Llama Stack project to use this provider.
2. Kick off a SFT job using the Llama Stack post_training API.
## Setup
You can access the HuggingFace trainer via the `ollama` distribution:
```bash
llama stack build --template ollama --image-type venv
llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml
```
## Run Training
You can access the provider and the `supervised_fine_tune` method via the post_training API:
```python
import time
import uuid
from llama_stack_client.types import (
post_training_supervised_fine_tune_params,
algorithm_config_param,
)
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example Dataset
client.datasets.register(
purpose="post-training/messages",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
batch_size=32,
data_format="instruct",
dataset_id="simpleqa",
shuffle=True,
),
gradient_accumulation_steps=1,
max_steps_per_epoch=0,
max_validation_steps=1,
n_epochs=4,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig( # this config is also currently mandatory but should not be
alpha=1,
apply_lora_to_mlp=True,
apply_lora_to_output=False,
lora_attn_modules=["q_proj"],
rank=1,
type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example Model
training_model = "ibm-granite/granite-3.3-8b-instruct"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
logger_config={},
model=training_model,
hyperparam_search_config={},
training_config=training_config,
algorithm_config=algorithm_config,
checkpoint_dir="output",
)
print("Job: ", job_uuid)
# Wait for the job to complete!
while True:
status = client.post_training.job.status(job_uuid=job_uuid)
if not status:
print("Job not found")
break
print(status)
if status.status == "completed":
break
print("Waiting for job to complete...")
time.sleep(5)
end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```

View file

@ -0,0 +1,163 @@
---
orphan: true
---
# NVIDIA NEMO
[NVIDIA NEMO](https://developer.nvidia.com/nemo-framework) is a remote post training provider for Llama Stack. It provides enterprise-grade fine-tuning capabilities through NVIDIA's NeMo Customizer service.
## Features
- Enterprise-grade fine-tuning capabilities
- Support for LoRA and SFT fine-tuning
- Integration with NVIDIA's NeMo Customizer service
- Support for various NVIDIA-optimized models
- Efficient training with NVIDIA hardware acceleration
## Usage
To use NVIDIA NEMO in your Llama Stack project, follow these steps:
1. Configure your Llama Stack project to use this provider.
2. Set up your NVIDIA API credentials.
3. Kick off a fine-tuning job using the Llama Stack post_training API.
## Setup
You'll need to set the following environment variables:
```bash
export NVIDIA_API_KEY="your-api-key"
export NVIDIA_DATASET_NAMESPACE="default"
export NVIDIA_CUSTOMIZER_URL="your-customizer-url"
export NVIDIA_PROJECT_ID="your-project-id"
export NVIDIA_OUTPUT_MODEL_DIR="your-output-model-dir"
```
## Run Training
You can access the provider and the `supervised_fine_tune` method via the post_training API:
```python
import time
import uuid
from llama_stack_client.types import (
post_training_supervised_fine_tune_params,
algorithm_config_param,
)
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example Dataset
client.datasets.register(
purpose="post-training/messages",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
batch_size=8, # Default batch size for NEMO
data_format="instruct",
dataset_id="simpleqa",
shuffle=True,
),
n_epochs=50, # Default epochs for NEMO
optimizer_config=post_training_supervised_fine_tune_params.TrainingConfigOptimizerConfig(
lr=0.0001, # Default learning rate
weight_decay=0.01, # NEMO-specific parameter
),
# NEMO-specific parameters
log_every_n_steps=None,
val_check_interval=0.25,
sequence_packing_enabled=False,
hidden_dropout=None,
attention_dropout=None,
ffn_dropout=None,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
alpha=16, # Default alpha for NEMO
type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example Model - must be a supported NEMO model
training_model = "meta/llama-3.1-8b-instruct"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
logger_config={},
model=training_model,
hyperparam_search_config={},
training_config=training_config,
algorithm_config=algorithm_config,
checkpoint_dir="output",
)
print("Job: ", job_uuid)
# Wait for the job to complete!
while True:
status = client.post_training.job.status(job_uuid=job_uuid)
if not status:
print("Job not found")
break
print(status)
if status.status == "completed":
break
print("Waiting for job to complete...")
time.sleep(5)
end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
## Supported Models
Currently supports the following models:
- meta/llama-3.1-8b-instruct
- meta/llama-3.2-1b-instruct
## Supported Parameters
### TrainingConfig
- n_epochs (default: 50)
- data_config
- optimizer_config
- log_every_n_steps
- val_check_interval (default: 0.25)
- sequence_packing_enabled (default: False)
- hidden_dropout (0.0-1.0)
- attention_dropout (0.0-1.0)
- ffn_dropout (0.0-1.0)
### DataConfig
- dataset_id
- batch_size (default: 8)
### OptimizerConfig
- lr (default: 0.0001)
- weight_decay (default: 0.01)
### LoRA Config
- alpha (default: 16)
- type (must be "LoRA")
Note: Some parameters from the standard Llama Stack API are not supported and will be ignored with a warning.

View file

@ -0,0 +1,125 @@
---
orphan: true
---
# TorchTune
[TorchTune](https://github.com/pytorch/torchtune) is an inline post training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.
## Features
- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support and single device capabilities.
- Support for LoRA
## Usage
To use TorchTune in your Llama Stack project, follow these steps:
1. Configure your Llama Stack project to use this provider.
2. Kick off a fine-tuning job using the Llama Stack post_training API.
## Setup
You can access the TorchTune trainer by writing your own yaml pointing to the provider:
```yaml
post_training:
- provider_id: torchtune
provider_type: inline::torchtune
config: {}
```
You can then build and run your own stack with this provider.
## Run Training
You can access the provider and the `supervised_fine_tune` method via the post_training API:
```python
import time
import uuid
from llama_stack_client.types import (
post_training_supervised_fine_tune_params,
algorithm_config_param,
)
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example Dataset
client.datasets.register(
purpose="post-training/messages",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
batch_size=32,
data_format="instruct",
dataset_id="simpleqa",
shuffle=True,
),
gradient_accumulation_steps=1,
max_steps_per_epoch=0,
max_validation_steps=1,
n_epochs=4,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
alpha=1,
apply_lora_to_mlp=True,
apply_lora_to_output=False,
lora_attn_modules=["q_proj"],
rank=1,
type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example Model
training_model = "meta-llama/Llama-2-7b-hf"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
logger_config={},
model=training_model,
hyperparam_search_config={},
training_config=training_config,
algorithm_config=algorithm_config,
checkpoint_dir="output",
)
print("Job: ", job_uuid)
# Wait for the job to complete!
while True:
status = client.post_training.job.status(job_uuid=job_uuid)
if not status:
print("Job not found")
break
print(status)
if status.status == "completed":
break
print("Waiting for job to complete...")
time.sleep(5)
end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```

View file

@ -66,6 +66,25 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
2. Configure your Llama Stack project to use SQLite-Vec.
3. Start storing and querying vectors.
## Supported Search Modes
The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
`RAGQueryConfig`. For example:
```python
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
query_config = RAGQueryConfig(max_chunks=6, mode="vector")
results = client.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id],
content="what is torchtune",
query_config=query_config,
)
```
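Keyword (full-text) search uses the same call; only the `mode` value changes. A minimal sketch, reusing the `client` and `vector_db_id` from the example above:
```python
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig

# Same query routed through the full-text index instead of the vector
# index; assumes the `client` and `vector_db_id` defined earlier.
keyword_config = RAGQueryConfig(max_chunks=6, mode="keyword")
results = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="what is torchtune",
    query_config=keyword_config,
)
```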
## Installation
You can install SQLite-Vec using pip:

6 kvant_build_local.sh Executable file
View file

@ -0,0 +1,6 @@
#!/usr/bin/env bash
export USE_COPY_NOT_MOUNT=true
export LLAMA_STACK_DIR=.
uvx --from . llama stack build --template kvant --image-type container --image-name kvant

17 kvant_start_local.sh Executable file
View file

@ -0,0 +1,17 @@
#!/usr/bin/env bash
export LLAMA_STACK_PORT=8321
# VLLM_API_TOKEN= env file
# KEYCLOAK_CLIENT_SECRET= env file
docker run -it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $(pwd)/data:/root/.llama \
--mount type=bind,source="$(pwd)"/llama_stack/templates/kvant/run.yaml,target=/root/.llama/config.yaml,readonly \
--entrypoint python \
--env-file ./.env \
distribution-kvant:dev \
-m llama_stack.distribution.server.server --config /root/.llama/config.yaml \
--port $LLAMA_STACK_PORT

View file

@ -13,7 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
@ -31,6 +31,8 @@ from llama_stack.apis.tools import ToolDef
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from .openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
@ -579,14 +581,14 @@ class Agents(Protocol):
#
# Both of these APIs are inherently stateful.
@webmethod(route="/openai/v1/responses/{id}", method="GET")
@webmethod(route="/openai/v1/responses/{response_id}", method="GET")
async def get_openai_response(
self,
id: str,
response_id: str,
) -> OpenAIResponseObject:
"""Retrieve an OpenAI response by its ID.
:param id: The ID of the OpenAI response to retrieve.
:param response_id: The ID of the OpenAI response to retrieve.
:returns: An OpenAIResponseObject.
"""
...
@ -596,6 +598,7 @@ class Agents(Protocol):
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
@ -610,3 +613,43 @@ class Agents(Protocol):
:returns: An OpenAIResponseObject.
"""
...
@webmethod(route="/openai/v1/responses", method="GET")
async def list_openai_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
"""List all OpenAI responses.
:param after: The ID of the last response to return.
:param limit: The number of responses to return.
:param model: The model to filter responses by.
:param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
:returns: A ListOpenAIResponseObject.
"""
...
@webmethod(route="/openai/v1/responses/{response_id}/input_items", method="GET")
async def list_openai_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items for a given OpenAI response.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
:param before: An item ID to list items before, used for pagination.
:param include: Additional fields to include in the response.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
:param order: The order to return the input items in. Default is desc.
:returns: An ListOpenAIResponseInputItem.
"""
...

View file

@ -10,6 +10,9 @@ from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type, register_schema
# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
# take their YAML and generate this file automatically. Their YAML is available.
@json_schema_type
class OpenAIResponseError(BaseModel):
@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
@json_schema_type
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
arguments: str
call_id: str
name: str
arguments: str
type: Literal["function_call"] = "function_call"
id: str | None = None
status: str | None = None
@json_schema_type
class OpenAIResponseOutputMessageMCPCall(BaseModel):
id: str
status: str
type: Literal["mcp_call"] = "mcp_call"
arguments: str
name: str
server_label: str
error: str | None = None
output: str | None = None
class MCPListToolsTool(BaseModel):
input_schema: dict[str, Any]
name: str
description: str | None = None
@json_schema_type
class OpenAIResponseOutputMessageMCPListTools(BaseModel):
id: str
type: Literal["mcp_list_tools"] = "mcp_list_tools"
server_label: str
tools: list[MCPListToolsTool]
OpenAIResponseOutput = Annotated[
OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseMessage
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@ -117,6 +149,16 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
type: Literal["response.created"] = "response.created"
@json_schema_type
class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
content_index: int
delta: str
item_id: str
output_index: int
sequence_number: int
type: Literal["response.output_text.delta"] = "response.output_text.delta"
@json_schema_type
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
response: OpenAIResponseObject
@ -124,7 +166,9 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
OpenAIResponseObjectStream = Annotated[
OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated
| OpenAIResponseObjectStreamResponseOutputTextDelta
| OpenAIResponseObjectStreamResponseCompleted,
Field(discriminator="type"),
]
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
@ -186,13 +230,50 @@ class OpenAIResponseInputToolFileSearch(BaseModel):
# TODO: add filters
class ApprovalFilter(BaseModel):
always: list[str] | None = None
never: list[str] | None = None
class AllowedToolsFilter(BaseModel):
tool_names: list[str] | None = None
@json_schema_type
class OpenAIResponseInputToolMCP(BaseModel):
type: Literal["mcp"] = "mcp"
server_label: str
server_url: str
headers: dict[str, Any] | None = None
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
allowed_tools: list[str] | AllowedToolsFilter | None = None
OpenAIResponseInputTool = Annotated[
OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction,
OpenAIResponseInputToolWebSearch
| OpenAIResponseInputToolFileSearch
| OpenAIResponseInputToolFunction
| OpenAIResponseInputToolMCP,
Field(discriminator="type"),
]
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
class OpenAIResponseInputItemList(BaseModel):
class ListOpenAIResponseInputItem(BaseModel):
data: list[OpenAIResponseInput]
object: Literal["list"] = "list"
@json_schema_type
class OpenAIResponseObjectWithInput(OpenAIResponseObject):
input: list[OpenAIResponseInput]
@json_schema_type
class ListOpenAIResponseObject(BaseModel):
data: list[OpenAIResponseObjectWithInput]
has_more: bool
first_id: str
last_id: str
object: Literal["list"] = "list"

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.common.content_types import URL
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class RestAPIMethod(Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
@json_schema_type
class RestAPIExecutionConfig(BaseModel):
url: URL
method: RestAPIMethod
params: dict[str, Any] | None = None
headers: dict[str, Any] | None = None
body: dict[str, Any] | None = None

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from pydantic import BaseModel
@ -11,6 +12,11 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
class Order(Enum):
asc = "asc"
desc = "desc"
@json_schema_type
class PaginatedResponse(BaseModel):
"""A generic paginated response that follows a simple format.

View file

@ -19,6 +19,7 @@ from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.common.responses import Order
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
@ -782,6 +783,48 @@ class OpenAICompletion(BaseModel):
object: Literal["text_completion"] = "text_completion"
@json_schema_type
class OpenAIEmbeddingData(BaseModel):
"""A single embedding data object from an OpenAI-compatible embeddings response.
:param object: The object type, which will be "embedding"
:param embedding: The embedding vector as a list of floats (when encoding_format="float") or as a base64-encoded string (when encoding_format="base64")
:param index: The index of the embedding in the input list
"""
object: Literal["embedding"] = "embedding"
embedding: list[float] | str
index: int
@json_schema_type
class OpenAIEmbeddingUsage(BaseModel):
"""Usage information for an OpenAI-compatible embeddings response.
:param prompt_tokens: The number of tokens in the input
:param total_tokens: The total number of tokens used
"""
prompt_tokens: int
total_tokens: int
@json_schema_type
class OpenAIEmbeddingsResponse(BaseModel):
"""Response from an OpenAI-compatible embeddings request.
:param object: The object type, which will be "list"
:param data: List of embedding data objects
:param model: The model that was used to generate the embeddings
:param usage: Usage information
"""
object: Literal["list"] = "list"
data: list[OpenAIEmbeddingData]
model: str
usage: OpenAIEmbeddingUsage
class ModelStore(Protocol):
async def get_model(self, identifier: str) -> Model: ...
@ -820,15 +863,27 @@ class BatchChatCompletionResponse(BaseModel):
batch: list[ChatCompletionResponse]
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
input_messages: list[OpenAIMessageParam]
@json_schema_type
class ListOpenAIChatCompletionResponse(BaseModel):
data: list[OpenAICompletionWithInputMessages]
has_more: bool
first_id: str
last_id: str
object: Literal["list"] = "list"
@runtime_checkable
@trace_protocol
class Inference(Protocol):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
class InferenceProvider(Protocol):
"""
This protocol defines the interface that should be implemented by all inference providers.
"""
API_NAMESPACE: str = "Inference"
model_store: ModelStore | None = None
@ -1062,3 +1117,59 @@ class Inference(Protocol):
:returns: An OpenAIChatCompletion.
"""
...
@webmethod(route="/openai/v1/embeddings", method="POST")
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
"""Generate OpenAI-compatible embeddings for the given input using the specified model.
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
:param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float".
:param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
:param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
:returns: An OpenAIEmbeddingsResponse containing the embeddings.
"""
...
class Inference(InferenceProvider):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
"""
@webmethod(route="/openai/v1/chat/completions", method="GET")
async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 20,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
"""List all chat completions.
:param after: The ID of the last chat completion to return.
:param limit: The maximum number of chat completions to return.
:param model: The model to filter by.
:param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
:returns: A ListOpenAIChatCompletionResponse.
"""
raise NotImplementedError("List chat completions is not implemented")
@webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
"""Describe a chat completion by its ID.
:param completion_id: ID of the chat completion.
:returns: A OpenAICompletionWithInputMessages.
"""
raise NotImplementedError("Get chat completion is not implemented")

View file

@ -76,6 +76,7 @@ class RAGQueryConfig(BaseModel):
:param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
:param mode: Search mode for retrieval: either "vector" or "keyword". Default "vector".
"""
# This config defines how a query is generated using the messages
@ -84,6 +85,7 @@ class RAGQueryConfig(BaseModel):
max_tokens_in_context: int = 4096
max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
mode: str | None = None
@field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str:
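With the new field, keyword retrieval is opt-in per query config. A small sketch (values illustrative; `RAGQueryConfig` is importable from `llama_stack.apis.tools`, as a later hunk in this diff shows):

from llama_stack.apis.tools import RAGQueryConfig

# mode=None (the default) keeps the previous behavior, i.e. vector search.
# The template keeps all three documented placeholders in case the validator requires them.
config = RAGQueryConfig(
    mode="keyword",
    max_chunks=3,
    chunk_template="[{index}] {chunk.content} ({metadata})\n",
)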

View file

@ -27,18 +27,10 @@ class ToolParameter(BaseModel):
default: Any | None = None
@json_schema_type
class ToolHost(Enum):
distribution = "distribution"
client = "client"
model_context_protocol = "model_context_protocol"
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool] = ResourceType.tool
toolgroup_id: str
tool_host: ToolHost
description: str
parameters: list[ToolParameter]
metadata: dict[str, Any] | None = None
@ -76,8 +68,8 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ...
def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
async def get_tool(self, tool_name: str) -> Tool: ...
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
class ListToolGroupsResponse(BaseModel):

View file

@ -19,8 +19,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class Chunk(BaseModel):
"""
A chunk of content that can be inserted into a vector database.
:param content: The content of the chunk, which can be interleaved text, images, or other types.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
:param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
"""
content: InterleavedContent
metadata: dict[str, Any] = Field(default_factory=dict)
embedding: list[float] | None = None
@json_schema_type
@ -50,7 +58,10 @@ class VectorIO(Protocol):
"""Insert chunks into a vector database.
:param vector_db_id: The identifier of the vector database to insert the chunks into.
:param chunks: The chunks to insert.
:param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
`metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
If `metadata` is provided, it configures how Llama Stack formats the chunk during generation.
If `embedding` is not provided, it will be computed later.
:param ttl_seconds: The time to live of the chunks.
"""
...
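A hedged sketch of this contract; `vector_io` stands in for any VectorIO implementation, and the db id, contents, and 384-dim vector are illustrative:

from llama_stack.apis.vector_io import Chunk

chunks = [
    # embedding omitted: the stack computes it later
    Chunk(content="Llama Stack overview ...", metadata={"document_id": "doc-1"}),
    # embedding supplied up front (dimension must match the vector db)
    Chunk(content="Getting started ...", metadata={"document_id": "doc-2"}, embedding=[0.01] * 384),
]
await vector_io.insert_chunks(vector_db_id="my-db", chunks=chunks, ttl_seconds=3600)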

View file

@ -9,6 +9,7 @@ import asyncio
import json
import os
import shutil
import sys
from dataclasses import dataclass
from datetime import datetime, timezone
from functools import partial
@ -377,14 +378,15 @@ def _meta_download(
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks))
cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr)
cprint(
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
"white",
file=sys.stderr,
)
cprint(
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
"yellow",
color="yellow",
file=sys.stderr,
)

View file

@ -12,6 +12,7 @@ import shutil
import sys
import textwrap
from functools import lru_cache
from importlib.abc import Traversable
from pathlib import Path
import yaml
@ -78,6 +79,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint(
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
color="red",
file=sys.stderr,
)
sys.exit(1)
build_config = available_templates[args.template]
@ -87,6 +89,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint(
f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}",
color="red",
file=sys.stderr,
)
sys.exit(1)
elif args.providers:
@ -96,6 +99,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint(
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
color="red",
file=sys.stderr,
)
sys.exit(1)
api, provider = api_provider.split("=")
@ -104,6 +108,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint(
f"{api} is not a valid API.",
color="red",
file=sys.stderr,
)
sys.exit(1)
if provider in providers_for_api:
@ -112,6 +117,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint(
f"{provider} is not a valid provider for the {api} API.",
color="red",
file=sys.stderr,
)
sys.exit(1)
distribution_spec = DistributionSpec(
@ -122,6 +128,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint(
f"Please specify a image-type (container | conda | venv) for {args.template}",
color="red",
file=sys.stderr,
)
sys.exit(1)
@ -150,12 +157,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint(
f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`",
color="yellow",
file=sys.stderr,
)
image_name = f"llamastack-{name}"
else:
cprint(
f"Using conda environment {image_name}",
color="green",
file=sys.stderr,
)
else:
image_name = f"llamastack-{name}"
@ -168,9 +177,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
""",
),
color="green",
file=sys.stderr,
)
print("Tip: use <TAB> to see options for the providers.\n")
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
providers = dict()
for api, providers_for_api in get_provider_registry().items():
@ -206,10 +216,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
contents = yaml.safe_load(f)
contents = replace_env_vars(contents)
build_config = BuildConfig(**contents)
if args.image_type:
build_config.image_type = args.image_type
except Exception as e:
cprint(
f"Could not parse config file {args.config}: {e}",
color="red",
file=sys.stderr,
)
sys.exit(1)
@ -236,25 +249,27 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint(
f"Error building stack: {exc}",
color="red",
file=sys.stderr,
)
cprint("Stack trace:", color="red")
cprint("Stack trace:", color="red", file=sys.stderr)
traceback.print_exc()
sys.exit(1)
if run_config is None:
cprint(
"Run config path is empty",
color="red",
file=sys.stderr,
)
sys.exit(1)
if args.run:
run_config = Path(run_config)
config_dict = yaml.safe_load(run_config.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(str(config.external_providers_dir)):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
if config.external_providers_dir and not config.external_providers_dir.exists():
config.external_providers_dir.mkdir(exist_ok=True)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
run_command(run_args)
@ -262,7 +277,7 @@ def _generate_run_config(
build_config: BuildConfig,
build_dir: Path,
image_name: str,
) -> str:
) -> Path:
"""
Generate a run.yaml template file for user to edit from a build.yaml file
"""
@ -302,6 +317,7 @@ def _generate_run_config(
cprint(
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
color="yellow",
file=sys.stderr,
)
# Set config_type to None to avoid UnboundLocalError
config_type = None
@ -329,10 +345,7 @@ def _generate_run_config(
# For non-container builds, the run.yaml is generated at the very end of the build process so it
# makes sense to display this message
if build_config.image_type != LlamaStackImageType.CONTAINER.value:
cprint(
f"You can now run your stack with `llama stack run {run_config_file}`",
color="green",
)
cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr)
return run_config_file
@ -341,7 +354,7 @@ def _run_stack_build_command_from_build_config(
image_name: str | None = None,
template_name: str | None = None,
config_path: str | None = None,
) -> str:
) -> Path | Traversable:
image_name = image_name or build_config.image_name
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
if template_name:
@ -370,7 +383,7 @@ def _run_stack_build_command_from_build_config(
# Generate the run.yaml so it can be included in the container image with the proper entrypoint
# Only do this if we're building a container image and we're not using a template
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
cprint("Generating run.yaml file", color="green")
cprint("Generating run.yaml file", color="yellow", file=sys.stderr)
run_config_file = _generate_run_config(build_config, build_dir, image_name)
with open(build_file_path, "w") as f:
@ -394,11 +407,13 @@ def _run_stack_build_command_from_build_config(
run_config_file = build_dir / f"{template_name}-run.yaml"
shutil.copy(path, run_config_file)
cprint("Build Successful!", color="green")
cprint("You can find the newly-built template here: " + colored(template_path, "light_blue"))
cprint("Build Successful!", color="green", file=sys.stderr)
cprint(f"You can find the newly-built template here: {template_path}", color="light_blue", file=sys.stderr)
cprint(
"You can run the new Llama Stack distro via: "
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue")
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue"),
color="green",
file=sys.stderr,
)
return template_path
else:

View file

@ -49,7 +49,7 @@ class StackBuild(Subcommand):
type=str,
help="Image Type to use for the build. If not specified, will use the image type from the template config.",
choices=[e.value for e in ImageType],
default=ImageType.CONDA.value,
default=None, # no default so we can detect if a user specified --image-type and override image_type in the config
)
self.parser.add_argument(

View file

@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class StackListBuilds(Subcommand):
"""List built stacks in .llama/distributions directory"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list",
prog="llama stack list",
description="list the build stacks",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._list_stack_command)
def _get_distribution_dirs(self) -> dict[str, Path]:
"""Return a dictionary of distribution names and their paths"""
distributions = {}
dist_dir = Path.home() / ".llama" / "distributions"
if dist_dir.exists():
for stack_dir in dist_dir.iterdir():
if stack_dir.is_dir():
distributions[stack_dir.name] = stack_dir
return distributions
def _list_stack_command(self, args: argparse.Namespace) -> None:
distributions = self._get_distribution_dirs()
if not distributions:
print("No stacks found in ~/.llama/distributions")
return
headers = ["Stack Name", "Path"]
headers.extend(["Build Config", "Run Config"])
rows = []
for name, path in distributions.items():
row = [name, str(path)]
# Check for build and run config files
build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No"
run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No"
row.extend([build_config, run_config])
rows.append(row)
print_table(rows, headers, separate_rows=True)
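Once wired into the parser (see the `StackParser` hunk below), the new subcommand is invoked as:

llama stack list

It prints one row per directory under ~/.llama/distributions, flagging whether <name>-build.yaml and <name>-run.yaml exist for each stack.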

View file

@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import shutil
import sys
from pathlib import Path
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class StackRemove(Subcommand):
"""Remove the build stack"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"rm",
prog="llama stack rm",
description="Remove the build stack",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._remove_stack_build_command)
def _add_arguments(self) -> None:
self.parser.add_argument(
"name",
type=str,
nargs="?",
help="Name of the stack to delete",
)
self.parser.add_argument(
"--all",
"-a",
action="store_true",
help="Delete all stacks (use with caution)",
)
def _get_distribution_dirs(self) -> dict[str, Path]:
"""Return a dictionary of distribution names and their paths"""
distributions = {}
dist_dir = Path.home() / ".llama" / "distributions"
if dist_dir.exists():
for stack_dir in dist_dir.iterdir():
if stack_dir.is_dir():
distributions[stack_dir.name] = stack_dir
return distributions
def _list_stacks(self) -> None:
"""Display available stacks in a table"""
distributions = self._get_distribution_dirs()
if not distributions:
cprint("No stacks found in ~/.llama/distributions", color="red", file=sys.stderr)
sys.exit(1)
headers = ["Stack Name", "Path"]
rows = [[name, str(path)] for name, path in distributions.items()]
print_table(rows, headers, separate_rows=True)
def _remove_stack_build_command(self, args: argparse.Namespace) -> None:
distributions = self._get_distribution_dirs()
if args.all:
confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower()
if confirm != "yes-i-really-want":
cprint("Deletion cancelled.", color="green", file=sys.stderr)
return
for name, path in distributions.items():
try:
shutil.rmtree(path)
cprint(f"Deleted stack: {name}", color="green", file=sys.stderr)
except Exception as e:
cprint(
f"Failed to delete stack {name}: {e}",
color="red",
file=sys.stderr,
)
sys.exit(1)
if not args.name:
self._list_stacks()
return
if args.name not in distributions:
self._list_stacks()
cprint(
f"Stack not found: {args.name}",
color="red",
file=sys.stderr,
)
sys.exit(1)
stack_path = distributions[args.name]
confirm = input(f"Are you sure you want to delete stack '{args.name}'? [y/N] ").lower()
if confirm != "y":
cprint("Deletion cancelled.", color="green", file=sys.stderr)
return
try:
shutil.rmtree(stack_path)
cprint(f"Successfully deleted stack: {args.name}", color="green", file=sys.stderr)
except Exception as e:
cprint(f"Failed to delete stack {args.name}: {e}", color="red", file=sys.stderr)
sys.exit(1)
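Typical invocations of the new subcommand (`my-stack` is a placeholder name):

llama stack rm my-stack    # asks "Are you sure ...? [y/N]" before deleting
llama stack rm --all       # requires typing "yes-i-really-want" to proceed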

View file

@ -6,6 +6,7 @@
import argparse
import os
import subprocess
from pathlib import Path
from llama_stack.cli.stack.utils import ImageType
@ -60,6 +61,11 @@ class StackRun(Subcommand):
help="Image Type used during the build. This can be either conda or container or venv.",
choices=[e.value for e in ImageType],
)
self.parser.add_argument(
"--enable-ui",
action="store_true",
help="Start the UI server",
)
# If neither image type nor image name is provided, but at the same time
# the current environment has conda breadcrumbs, then assume what the user
@ -83,6 +89,8 @@ class StackRun(Subcommand):
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
if args.enable_ui:
self._start_ui_development_server(args.port)
image_type, image_name = self._get_image_type_and_name(args)
# Check if config is required based on image type
@ -170,3 +178,44 @@ class StackRun(Subcommand):
run_args.extend(["--env", f"{key}={value}"])
run_command(run_args)
def _start_ui_development_server(self, stack_server_port: int):
logger.info("Attempting to start UI development server...")
# Check if npm is available
npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False)
if npm_check.returncode != 0:
logger.warning(
f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}"
)
return
ui_dir = REPO_ROOT / "llama_stack" / "ui"
logs_dir = Path("~/.llama/ui/logs").expanduser()
try:
# Create logs directory if it doesn't exist
logs_dir.mkdir(parents=True, exist_ok=True)
ui_stdout_log_path = logs_dir / "stdout.log"
ui_stderr_log_path = logs_dir / "stderr.log"
# Open log files in append mode
stdout_log_file = open(ui_stdout_log_path, "a")
stderr_log_file = open(ui_stderr_log_path, "a")
process = subprocess.Popen(
["npm", "run", "dev"],
cwd=str(ui_dir),
stdout=stdout_log_file,
stderr=stderr_log_file,
env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"},
)
logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
except FileNotFoundError:
logger.error(
"Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH."
)
except Exception as e:
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")

View file

@ -7,12 +7,14 @@
import argparse
from importlib.metadata import version
from llama_stack.cli.stack.list_stacks import StackListBuilds
from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand
from .build import StackBuild
from .list_apis import StackListApis
from .list_providers import StackListProviders
from .remove import StackRemove
from .run import StackRun
@ -41,5 +43,6 @@ class StackParser(Subcommand):
StackListApis.create(subparsers)
StackListProviders.create(subparsers)
StackRun.create(subparsers)
StackRemove.create(subparsers)
StackListBuilds.create(subparsers)
print_subcommand_description(self.parser, subparsers)

View file

@ -6,6 +6,7 @@
import importlib.resources
import logging
import sys
from pathlib import Path
from pydantic import BaseModel
@ -43,8 +44,20 @@ def get_provider_dependencies(
# Extract providers based on config type
if isinstance(config, DistributionTemplate):
providers = config.providers
# TODO: This is a hack to get the dependencies for internal APIs into build
# We should have a better way to do this by formalizing the concept of "internal" APIs
# and providers, with a way to specify dependencies for them.
run_configs = config.run_configs
additional_pip_packages: list[str] = []
if run_configs:
for run_config in run_configs.values():
run_config_ = run_config.run_config(name="", providers={}, container_image=None)
if run_config_.inference_store:
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
elif isinstance(config, BuildConfig):
providers = config.distribution_spec.providers
additional_pip_packages = config.additional_pip_packages
deps = []
registry = get_provider_registry(config)
for api_str, provider_or_providers in providers.items():
@ -72,6 +85,9 @@ def get_provider_dependencies(
else:
normal_deps.append(package)
if additional_pip_packages:
normal_deps.extend(additional_pip_packages)
return list(set(normal_deps)), list(set(special_deps))
@ -80,10 +96,11 @@ def print_pip_install_help(config: BuildConfig):
cprint(
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
"yellow",
color="yellow",
file=sys.stderr,
)
for special_dep in special_deps:
cprint(f"uv pip install {special_dep}", "yellow")
cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
print()

View file

@ -6,6 +6,7 @@
import inspect
import json
import sys
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any, Union, get_args, get_origin
@ -96,13 +97,13 @@ def create_api_client_class(protocol) -> type:
try:
data = json.loads(data)
if "error" in data:
cprint(data, "red")
cprint(data, color="red", file=sys.stderr)
continue
yield parse_obj_as(return_type, data)
except Exception as e:
print(f"Error with parsing or validation: {e}")
print(data)
cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
cprint(data, color="red", file=sys.stderr)
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
webmethod, sig = self.routes[method_name]

View file

@ -25,7 +25,8 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
@ -220,21 +221,38 @@ class LoggingConfig(BaseModel):
class AuthProviderType(str, Enum):
"""Supported authentication provider types."""
KUBERNETES = "kubernetes"
OAUTH2_TOKEN = "oauth2_token"
CUSTOM = "custom"
class AuthenticationConfig(BaseModel):
provider_type: AuthProviderType = Field(
...,
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
description="Type of authentication provider",
)
config: dict[str, str] = Field(
config: dict[str, Any] = Field(
...,
description="Provider-specific configuration",
)
class AuthenticationRequiredError(Exception):
pass
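Loosening `config` to `dict[str, Any]` lets providers take structured settings. A hedged sketch for the new `oauth2_token` type; the keys inside `config` are provider-specific and purely illustrative here:

auth = AuthenticationConfig(
    provider_type=AuthProviderType.OAUTH2_TOKEN,
    config={
        "issuer": "https://idp.example.com/realms/main",  # illustrative
        "audience": "llama-stack",                        # illustrative
    },
)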
class QuotaPeriod(str, Enum):
DAY = "day"
class QuotaConfig(BaseModel):
kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
authenticated_max_requests: int = Field(
default=1000, description="Max requests for authenticated clients per period"
)
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@ -262,6 +280,10 @@ class ServerConfig(BaseModel):
default=None,
description="The host the server should listen on",
)
quota: QuotaConfig | None = Field(
default=None,
description="Per client quota request configuration",
)
class StackRunConfig(BaseModel):
@ -297,6 +319,13 @@ Configuration for the persistence store used by the distribution registry. If no
a default SQLite store will be used.""",
)
inference_store: SqlStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the inference API. If not specified,
a default SQLite store will be used.""",
)
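A hedged run.yaml fragment exercising both new fields; the exact store schemas are assumptions based on the `SqliteKVStoreConfig`/`SqlStoreConfig` imports above, and the paths are illustrative:

server:
  port: 8321
  quota:
    kvstore:
      type: sqlite
      db_path: ~/.llama/quota.db
    anonymous_max_requests: 100
    authenticated_max_requests: 1000
    period: day
inference_store:
  type: sqlite
  db_path: ~/.llama/inference_store.db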
# registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)
@ -340,8 +369,21 @@ class BuildConfig(BaseModel):
default=None,
description="Name of the distribution to build",
)
external_providers_dir: str | None = Field(
external_providers_dir: Path | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.",
)
additional_pip_packages: list[str] = Field(
default_factory=list,
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v
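The matching build.yaml additions, shown in isolation (the package name and path are illustrative; the validator above coerces a string path into a `Path`):

additional_pip_packages:
  - aiosqlite
external_providers_dir: ~/.llama/providers.d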

View file

@ -16,7 +16,7 @@ from llama_stack.apis.inspect import (
VersionInfo,
)
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus
@ -31,7 +31,7 @@ async def get_provider_impl(config, deps):
class DistributionInspectImpl(Inspect):
def __init__(self, config, deps):
def __init__(self, config: DistributionInspectConfig, deps):
self.config = config
self.deps = deps
@ -39,22 +39,36 @@ class DistributionInspectImpl(Inspect):
pass
async def list_routes(self) -> ListRoutesResponse:
run_config = self.config.run_config
run_config: StackRunConfig = self.config.run_config
ret = []
all_endpoints = get_all_api_endpoints()
all_endpoints = get_all_api_routes()
for api, endpoints in all_endpoints.items():
providers = run_config.providers.get(api.value, [])
ret.extend(
[
RouteInfo(
route=e.route,
method=e.method,
provider_types=[p.provider_type for p in providers],
# Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]:
ret.extend(
[
RouteInfo(
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
)
for e in endpoints
]
)
else:
providers = run_config.providers.get(api.value, [])
if providers: # Only process if there are providers for this API
ret.extend(
[
RouteInfo(
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers],
)
for e in endpoints
]
)
for e in endpoints
]
)
return ListRoutesResponse(data=ret)

View file

@ -9,6 +9,7 @@ import inspect
import json
import logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
@ -36,10 +37,7 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
from llama_stack.distribution.stack import (
construct_stack,
get_stack_run_config_from_template,
@ -207,13 +205,14 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
async def initialize(self) -> bool:
try:
self.endpoint_impls = None
self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)
except ModuleNotFoundError as _e:
cprint(_e.msg, "red")
cprint(_e.msg, color="red", file=sys.stderr)
cprint(
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
"yellow",
color="yellow",
file=sys.stderr,
)
if self.config_path_or_template_name.endswith(".yaml"):
# Convert Provider objects to their types
@ -226,6 +225,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
distribution_spec=DistributionSpec(
providers=provider_types,
),
external_providers_dir=self.config.external_providers_dir,
)
print_pip_install_help(build_config)
else:
@ -233,7 +233,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
cprint(
f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
"yellow",
file=sys.stderr,
)
cprint(
"Please check your internet connection and try again.",
"red",
file=sys.stderr,
)
raise _e
if Api.telemetry in self.impls:
@ -245,7 +251,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
self.endpoint_impls = initialize_endpoint_impls(self.impls)
self.route_impls = initialize_route_impls(self.impls)
return True
async def request(
@ -256,13 +262,15 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
stream=False,
stream_cls=None,
):
if not self.endpoint_impls:
if not self.route_impls:
raise ValueError("Client not initialized")
# Create headers with provider data if available
headers = {}
headers = options.headers or {}
if self.provider_data:
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
if all(key not in headers for key in keys):
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
# Use context manager for provider data
with request_provider_data_context(headers):
@ -285,11 +293,14 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
cast_to: Any,
options: Any,
):
if self.route_impls is None:
raise ValueError("Client not initialized")
path = options.url
body = options.params or {}
body |= options.json_data or {}
matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
await start_trace(route, {"__location__": "library_client"})
@ -331,10 +342,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
options: Any,
stream_cls: Any,
):
if self.route_impls is None:
raise ValueError("Client not initialized")
path = options.url
body = options.params or {}
body |= options.json_data or {}
func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
func, path_params, route = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
@ -386,7 +400,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not body:
return {}
func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
if self.route_impls is None:
raise ValueError("Client not initialized")
func, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature

View file

@ -13,7 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
RemoteProviderSpec,
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
ToolsProtocolPrivate,
ToolGroupsProtocolPrivate,
VectorDBsProtocolPrivate,
)
@ -83,10 +83,17 @@ def api_protocol_map() -> dict[Api, Any]:
}
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
return {
**api_protocol_map(),
Api.inference: InferenceProvider,
}
def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
@ -133,7 +140,7 @@ async def resolve_impls(
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@ -236,7 +243,10 @@ def sort_providers_by_deps(
async def instantiate_providers(
sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
sorted_providers: list[tuple[str, ProviderWithSpec]],
router_apis: set[Api],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
) -> dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {}
@ -251,7 +261,7 @@ async def instantiate_providers(
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
if api_str.startswith("inner-"):
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
@ -301,10 +311,8 @@ async def instantiate_provider(
deps: dict[Api, Any],
inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
):
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
@ -323,7 +331,7 @@ async def instantiate_provider(
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
@ -342,6 +350,8 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
protocols = api_protocol_map_for_compliance_check()
additional_protocols = additional_protocols_map()
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])

View file

@ -7,18 +7,10 @@
from typing import Any
from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.stack import StackRunConfig
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
from .routing_tables import (
BenchmarksRoutingTable,
DatasetsRoutingTable,
ModelsRoutingTable,
ScoringFunctionsRoutingTable,
ShieldsRoutingTable,
ToolGroupsRoutingTable,
VectorDBsRoutingTable,
)
from llama_stack.providers.utils.inference.inference_store import InferenceStore
async def get_routing_table_impl(
@ -27,6 +19,14 @@ async def get_routing_table_impl(
_deps,
dist_registry: DistributionRegistry,
) -> Any:
from ..routing_tables.benchmarks import BenchmarksRoutingTable
from ..routing_tables.datasets import DatasetsRoutingTable
from ..routing_tables.models import ModelsRoutingTable
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
api_to_tables = {
"vector_dbs": VectorDBsRoutingTable,
"models": ModelsRoutingTable,
@ -45,16 +45,15 @@ async def get_routing_table_impl(
return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,
InferenceRouter,
SafetyRouter,
ScoringRouter,
ToolRuntimeRouter,
VectorIORouter,
)
async def get_auto_router_impl(
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig
) -> Any:
from .datasets import DatasetIORouter
from .eval_scoring import EvalRouter, ScoringRouter
from .inference import InferenceRouter
from .safety import SafetyRouter
from .tool_runtime import ToolRuntimeRouter
from .vector_io import VectorIORouter
api_to_routers = {
"vector_io": VectorIORouter,
@ -76,6 +75,12 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict
if dep_api in deps:
api_to_dep_impl[dep_name] = deps[dep_api]
# TODO: move pass configs to routers instead
if api == Api.inference and run_config.inference_store:
inference_store = InferenceStore(run_config.inference_store)
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()
return impl

View file

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class DatasetIORouter(DatasetIO):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("DatasetIORouter.shutdown")
pass
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)
async def iterrows(
self,
dataset_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
dataset_id=dataset_id,
start_index=start_index,
limit=limit,
)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
rows=rows,
)

View file

@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringFnParams,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class ScoringRouter(Scoring):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ScoringRouter.shutdown")
pass
async def score_batch(
self,
dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
if save_results_dataset:
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res,
)
async def score(
self,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
return ScoreResponse(results=res)
class EvalRouter(Eval):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvalRouter.shutdown")
pass
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
async def job_status(
self,
benchmark_id: str,
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel(
self,
benchmark_id: str,
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
)
async def job_result(
self,
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
)

View file

@ -14,14 +14,9 @@ from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToo
from pydantic import Field, TypeAdapter
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
@ -32,8 +27,11 @@ from llama_stack.apis.inference import (
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
ListOpenAIChatCompletionResponse,
LogProbConfig,
Message,
OpenAICompletionWithInputMessages,
Order,
ResponseFormat,
SamplingParams,
StopReason,
@ -47,93 +45,23 @@ from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
class VectorIORouter(VectorIO):
"""Routes to an provider based on the vector db identifier"""
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("VectorIORouter.shutdown")
pass
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
embedding_dimension,
provider_id,
provider_vector_db_id,
)
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
class InferenceRouter(Inference):
"""Routes to an provider based on the model"""
@ -141,10 +69,12 @@ class InferenceRouter(Inference):
self,
routing_table: RoutingTable,
telemetry: Telemetry | None = None,
store: InferenceStore | None = None,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry = telemetry
self.store = store
if self.telemetry:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
@ -607,9 +537,59 @@ class InferenceRouter(Inference):
provider = self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
return await provider.openai_chat_completion(**params)
response_stream = await provider.openai_chat_completion(**params)
if self.store:
return stream_and_store_openai_completion(response_stream, model, self.store, messages)
return response_stream
else:
return await self._nonstream_openai_chat_completion(provider, params)
response = await self._nonstream_openai_chat_completion(provider, params)
if self.store:
await self.store.store_chat_completion(response, messages)
return response
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
logger.debug(
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ValueError(f"Model '{model}' not found")
if model_obj.model_type != ModelType.embedding:
raise ValueError(f"Model '{model}' is not an embedding model")
params = dict(
model=model_obj.identifier,
input=input,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_embeddings(**params)
async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 20,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
if self.store:
return await self.store.list_chat_completions(after, limit, model, order)
raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
if self.store:
return await self.store.get_chat_completion(completion_id)
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
@ -642,295 +622,3 @@ class InferenceRouter(Inference):
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
)
return health_statuses
class SafetyRouter(Safety):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("SafetyRouter.shutdown")
pass
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
params=params,
)
class DatasetIORouter(DatasetIO):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("DatasetIORouter.shutdown")
pass
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)
async def iterrows(
self,
dataset_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
dataset_id=dataset_id,
start_index=start_index,
limit=limit,
)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
rows=rows,
)
class ScoringRouter(Scoring):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ScoringRouter.shutdown")
pass
async def score_batch(
self,
dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
if save_results_dataset:
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res,
)
async def score(
self,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
return ScoreResponse(results=res)
class EvalRouter(Eval):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvalRouter.shutdown")
pass
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
async def job_status(
self,
benchmark_id: str,
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel(
self,
benchmark_id: str,
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
)
async def job_result(
self,
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
)
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
self.rag_tool = self.RagToolImpl(routing_table)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
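All of the routers above share one dispatch shape: resolve a provider implementation from the routing table by a routing key, then forward the call unchanged. A minimal, self-contained sketch of that shape (StubRoutingTable and EchoToolProvider are hypothetical stand-ins, not llama_stack classes):

import asyncio
from typing import Any


class EchoToolProvider:
    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
        # A real provider would execute the tool here.
        return {"tool": tool_name, "args": kwargs}


class StubRoutingTable:
    def __init__(self, impls: dict[str, Any]) -> None:
        self.impls = impls

    def get_provider_impl(self, routing_key: str) -> Any:
        # The routing key (tool name, model id, shield id, ...) picks the provider.
        return self.impls[routing_key]


async def main() -> None:
    table = StubRoutingTable({"web_search": EchoToolProvider()})
    # Mirrors ToolRuntimeRouter.invoke_tool: resolve, then delegate.
    print(await table.get_provider_impl("web_search").invoke_tool("web_search", {"q": "llama"}))


asyncio.run(main())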

View file

@ -1,634 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import time
import uuid
from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.apis.tools import (
ListToolGroupsResponse,
ListToolsResponse,
Tool,
ToolGroup,
ToolGroups,
ToolHost,
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
AccessAttributes,
BenchmarkWithACL,
DatasetWithACL,
ModelWithACL,
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
ScoringFnWithACL,
ShieldWithACL,
ToolGroupWithACL,
ToolWithACL,
VectorDBWithACL,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
logger = logging.getLogger(__name__)
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
return await p.register_shield(obj)
elif api == Api.vector_io:
return await p.register_vector_db(obj)
elif api == Api.datasetio:
return await p.register_dataset(obj)
elif api == Api.scoring:
return await p.register_scoring_function(obj)
elif api == Api.eval:
return await p.register_benchmark(obj)
elif api == Api.tool_runtime:
return await p.register_tool(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.vector_io:
return await p.unregister_vector_db(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_tool(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
Registry = dict[str, list[RoutableObjectWithProvider]]
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
async def initialize(self) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
else:
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.vector_io:
p.vector_db_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.benchmark_store = self
elif api == Api.tool_runtime:
p.tool_store = self
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, VectorDBsRoutingTable):
return ("VectorIO", "vector_db")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, BenchmarksRoutingTable):
return ("Eval", "benchmark")
elif isinstance(self, ToolGroupsRoutingTable):
return ("Tools", "tool")
else:
raise ValueError("Unknown routing table type")
apiname, objtype = apiname_object()
# Get objects from disk registry
obj = self.dist_registry.get_cached(objtype, routing_key)
if not obj:
provider_ids = list(self.impls_by_provider_id.keys())
if len(provider_ids) > 1:
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
else:
provider_ids_str = f"provider: `{provider_ids[0]}`"
raise ValueError(
f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
)
if not provider_id or provider_id == obj.provider_id:
return self.impls_by_provider_id[obj.provider_id]
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
return None
# Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
else:
await self.dist_registry.register(obj)
return obj
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
]
return filtered_objs
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
openai_models = [
OpenAIModel(
id=model.identifier,
object="model",
created=int(time.time()),
owned_by="llama_stack",
)
for model in models
]
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
model = await self.get_object_by_identifier("model", model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
return model
async def register_model(
self,
model_id: str,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
)
if metadata is None:
metadata = {}
if model_type is None:
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = ModelWithACL(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
)
registered_model = await self.register_object(model)
return registered_model
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier)
if shield is None:
raise ValueError(f"Shield '{identifier}' not found")
return shield
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
if provider_shield_id is None:
provider_shield_id = shield_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if params is None:
params = {}
shield = ShieldWithACL(
identifier=shield_id,
provider_resource_id=provider_shield_id,
provider_id=provider_id,
params=params,
)
await self.register_object(shield)
return shield
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise ValueError(f"Vector DB '{vector_db_id}' not found")
return vector_db
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await self.get_object_by_identifier("model", embedding_model)
if model is None:
raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
vector_db_data = {
"identifier": vector_db_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
}
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
async def unregister_vector_db(self, vector_db_id: str) -> None:
existing_vector_db = await self.get_vector_db(vector_db_id)
if existing_vector_db is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
await self.unregister_object(existing_vector_db)
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)
if dataset is None:
raise ValueError(f"Dataset '{dataset_id}' not found")
return dataset
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
source = URIDataSource.parse_obj(source)
elif source["type"] == "rows":
source = RowsDataSource.parse_obj(source)
if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"
provider_dataset_id = dataset_id
# infer provider from source
if metadata:
if metadata.get("provider_id"):
provider_id = metadata.get("provider_id") # pass through from nvidia datasetio
elif source.type == DatasetType.rows.value:
provider_id = "localfs"
elif source.type == DatasetType.uri.value:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
else:
provider_id = "localfs"
else:
raise ValueError(f"Unknown data source type: {source.type}")
if metadata is None:
metadata = {}
dataset = DatasetWithACL(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
purpose=purpose,
source=source,
metadata=metadata,
)
await self.register_object(dataset)
return dataset
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
if dataset is None:
raise ValueError(f"Dataset {dataset_id} not found")
await self.unregister_object(dataset)
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None:
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFnWithACL(
identifier=scoring_fn_id,
description=description,
return_type=return_type,
provider_resource_id=provider_scoring_fn_id,
provider_id=provider_id,
params=params,
)
scoring_fn.provider_id = provider_id
await self.register_object(scoring_fn)
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark '{benchmark_id}' not found")
return benchmark
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: list[str],
metadata: dict[str, Any] | None = None,
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
) -> None:
if metadata is None:
metadata = {}
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
benchmark = BenchmarkWithACL(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
metadata=metadata,
provider_id=provider_id,
provider_resource_id=provider_benchmark_id,
)
await self.register_object(benchmark)
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
tools = await self.get_all_with_type("tool")
if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
return ListToolsResponse(data=tools)
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group '{toolgroup_id}' not found")
return tool_group
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)
async def register_tool_group(
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
for tool_def in tool_defs.data:
tools.append(
ToolWithACL(
identifier=tool_def.name,
toolgroup_id=toolgroup_id,
description=tool_def.description or "",
parameters=tool_def.parameters or [],
provider_id=provider_id,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
tool_host=tool_host,
)
)
for tool in tools:
existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists
if existing_tool:
existing_dict = existing_tool.model_dump()
new_dict = tool.model_dump()
if existing_dict != new_dict:
raise ValueError(
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
)
await self.register_object(tool)
await self.dist_registry.register(
ToolGroupWithACL(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
)
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
tools = await self.list_tools(toolgroup_id)
for tool in getattr(tools, "data", []):
await self.unregister_object(tool)
await self.unregister_object(tool_group)
async def shutdown(self) -> None:
pass
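One subtlety in register_object above is worth flagging: when the caller omits provider_id, the first configured provider is chosen silently. A tiny sketch of that rule (the provider ids here are hypothetical):

# Dicts preserve insertion order (Python 3.7+), so "the first provider"
# is whichever one was configured first.
impls_by_provider_id = {"inline::faiss": object(), "remote::chromadb": object()}

provider_id = None  # caller did not specify one
if not provider_id and len(impls_by_provider_id) > 0:
    provider_id = list(impls_by_provider_id.keys())[0]
print(provider_id)  # -> inline::faiss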

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.inference import (
Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class SafetyRouter(Safety):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("SafetyRouter.shutdown")
pass
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any] | None = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
params=params,
)
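A hedged usage sketch of the run_shield flow above, with a stub provider and plain dicts standing in for the real Message type:

import asyncio


class StubShieldProvider:
    async def run_shield(self, shield_id, messages, params):
        # A real safety provider would run its guard model here.
        return {"shield_id": shield_id, "violations": [], "n_messages": len(messages)}


class StubRoutingTable:
    def get_provider_impl(self, routing_key):
        return StubShieldProvider()


async def main():
    table = StubRoutingTable()
    # Mirrors SafetyRouter.run_shield: resolve by shield_id, then delegate.
    resp = await table.get_provider_impl("llama-guard").run_shield(
        shield_id="llama-guard",
        messages=[{"role": "user", "content": "hello"}],
        params={},
    )
    print(resp)


asyncio.run(main())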

View file

@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
)
from llama_stack.apis.tools import (
ListToolsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolRuntime,
)
from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
self.rag_tool = self.RagToolImpl(routing_table)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.list_tools(tool_group_id)
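The setattr loop in __init__ above creates attributes literally named "rag_tool.query" and "rag_tool.insert"; such names are only reachable via getattr, which is presumably why the HACK ALERT comment ties them to endpoint enumeration. A small demonstration of that Python behavior:

class Holder:
    pass


h = Holder()
h.inner = lambda: "called"
# An attribute name containing a dot is legal via setattr ...
setattr(h, "inner.method", h.inner)
# ... but it cannot be reached with normal attribute syntax, only getattr:
print(getattr(h, "inner.method")())  # -> called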

View file

@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class VectorIORouter(VectorIO):
"""Routes to an provider based on the vector db identifier"""
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("VectorIORouter.shutdown")
pass
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
embedding_dimension,
provider_id,
provider_vector_db_id,
)
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
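Note that the insert_chunks log line indexes chunk.metadata["document_id"], so every chunk is expected to carry a document id. A sketch of that preview logic with hypothetical plain-dict chunks standing in for the Chunk model:

chunks = [
    {"content": f"passage {i}", "metadata": {"document_id": f"doc-{i}"}}
    for i in range(5)
]
preview = [c["metadata"]["document_id"] for c in chunks[:3]]
suffix = " and more..." if len(chunks) > 3 else ""
print(f"{len(chunks)} chunks, chunk_ids={preview}{suffix}")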

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.distribution.datatypes import (
BenchmarkWithACL,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark '{benchmark_id}' not found")
return benchmark
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: list[str],
metadata: dict[str, Any] | None = None,
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
) -> None:
if metadata is None:
metadata = {}
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
benchmark = BenchmarkWithACL(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
metadata=metadata,
provider_id=provider_id,
provider_resource_id=provider_benchmark_id,
)
await self.register_object(benchmark)

View file

@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
AccessAttributes,
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
logger = get_logger(name=__name__, category="core")
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
return await p.register_shield(obj)
elif api == Api.vector_io:
return await p.register_vector_db(obj)
elif api == Api.datasetio:
return await p.register_dataset(obj)
elif api == Api.scoring:
return await p.register_scoring_function(obj)
elif api == Api.eval:
return await p.register_benchmark(obj)
elif api == Api.tool_runtime:
return await p.register_toolgroup(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.vector_io:
return await p.unregister_vector_db(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_toolgroup(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
Registry = dict[str, list[RoutableObjectWithProvider]]
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
async def initialize(self) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
else:
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.vector_io:
p.vector_db_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.benchmark_store = self
elif api == Api.tool_runtime:
p.tool_store = self
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
from .benchmarks import BenchmarksRoutingTable
from .datasets import DatasetsRoutingTable
from .models import ModelsRoutingTable
from .scoring_functions import ScoringFunctionsRoutingTable
from .shields import ShieldsRoutingTable
from .toolgroups import ToolGroupsRoutingTable
from .vector_dbs import VectorDBsRoutingTable
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, VectorDBsRoutingTable):
return ("VectorIO", "vector_db")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, BenchmarksRoutingTable):
return ("Eval", "benchmark")
elif isinstance(self, ToolGroupsRoutingTable):
return ("ToolGroups", "tool_group")
else:
raise ValueError("Unknown routing table type")
apiname, objtype = apiname_object()
# Get objects from disk registry
obj = self.dist_registry.get_cached(objtype, routing_key)
if not obj:
provider_ids = list(self.impls_by_provider_id.keys())
if len(provider_ids) > 1:
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
else:
provider_ids_str = f"provider: `{provider_ids[0]}`"
raise ValueError(
f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
)
if not provider_id or provider_id == obj.provider_id:
return self.impls_by_provider_id[obj.provider_id]
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
return None
# Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
else:
await self.dist_registry.register(obj)
return obj
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
]
return filtered_objs
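get_all_with_type and get_object_by_identifier both gate results through check_access. A toy sketch of attribute-based filtering under an assumed semantics (the real rules live in llama_stack.distribution.access_control): an object is visible if it declares no attributes, or if the user matches at least one value for every attribute key the object declares.

# Toy stand-in for check_access; the matching semantics here are an assumption.
def toy_check_access(obj_attrs: dict | None, user_attrs: dict | None) -> bool:
    if not obj_attrs:
        return True  # unrestricted objects are visible to everyone
    if not user_attrs:
        return False
    return all(
        set(values) & set(user_attrs.get(key, []))
        for key, values in obj_attrs.items()
    )


objs = [
    {"id": "public-model", "attrs": None},
    {"id": "team-model", "attrs": {"teams": ["ml-platform"]}},
]
user = {"roles": ["user"], "teams": ["ml-platform"]}
print([o["id"] for o in objs if toy_check_access(o["attrs"], user)])
# -> ['public-model', 'team-model']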

View file

@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import Any
from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.distribution.datatypes import (
DatasetWithACL,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)
if dataset is None:
raise ValueError(f"Dataset '{dataset_id}' not found")
return dataset
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
source = URIDataSource.parse_obj(source)
elif source["type"] == "rows":
source = RowsDataSource.parse_obj(source)
if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"
provider_dataset_id = dataset_id
# infer provider from source
if metadata:
if metadata.get("provider_id"):
provider_id = metadata.get("provider_id") # pass through from nvidia datasetio
elif source.type == DatasetType.rows.value:
provider_id = "localfs"
elif source.type == DatasetType.uri.value:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
else:
provider_id = "localfs"
else:
raise ValueError(f"Unknown data source type: {source.type}")
if metadata is None:
metadata = {}
dataset = DatasetWithACL(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
purpose=purpose,
source=source,
metadata=metadata,
)
await self.register_object(dataset)
return dataset
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
if dataset is None:
raise ValueError(f"Dataset {dataset_id} not found")
await self.unregister_object(dataset)
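A condensed sketch of the provider-inference branch in register_dataset above: rows sources go to localfs, and URI sources go to huggingface only when the URI says so (the example URI below is hypothetical):

def infer_provider(source_type: str, uri: str | None = None) -> str:
    # Mirrors the branch above; "localfs" and "huggingface" are the ids it uses.
    if source_type == "rows":
        return "localfs"
    if source_type == "uri":
        return "huggingface" if uri and uri.startswith("huggingface") else "localfs"
    raise ValueError(f"Unknown data source type: {source_type}")


print(infer_provider("uri", "huggingface://datasets/tatsu-lab/alpaca"))  # huggingface
print(infer_provider("rows"))  # localfs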

View file

@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from typing import Any
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
ModelWithACL,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
openai_models = [
OpenAIModel(
id=model.identifier,
object="model",
created=int(time.time()),
owned_by="llama_stack",
)
for model in models
]
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
model = await self.get_object_by_identifier("model", model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
return model
async def register_model(
self,
model_id: str,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
)
if metadata is None:
metadata = {}
if model_type is None:
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = ModelWithACL(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
)
registered_model = await self.register_object(model)
return registered_model
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)
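register_model above enforces one metadata invariant: embedding models must declare their dimension. A quick sketch of registrations that pass and fail that check (ModelTypeStub is a local stand-in for llama_stack's ModelType):

from enum import Enum


class ModelTypeStub(str, Enum):
    llm = "llm"
    embedding = "embedding"


def check_embedding_metadata(metadata: dict, model_type: ModelTypeStub) -> None:
    # Same condition as in register_model above.
    if "embedding_dimension" not in metadata and model_type == ModelTypeStub.embedding:
        raise ValueError("Embedding model must have an embedding dimension in its metadata")


check_embedding_metadata({"embedding_dimension": 384}, ModelTypeStub.embedding)  # ok
check_embedding_metadata({}, ModelTypeStub.llm)  # ok: LLMs need no dimension
try:
    check_embedding_metadata({}, ModelTypeStub.embedding)
except ValueError as e:
    print(e)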

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.distribution.datatypes import (
ScoringFnWithACL,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None:
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFnWithACL(
identifier=scoring_fn_id,
description=description,
return_type=return_type,
provider_resource_id=provider_scoring_fn_id,
provider_id=provider_id,
params=params,
)
scoring_fn.provider_id = provider_id
await self.register_object(scoring_fn)

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.distribution.datatypes import (
ShieldWithACL,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier)
if shield is None:
raise ValueError(f"Shield '{identifier}' not found")
return shield
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
if provider_shield_id is None:
provider_shield_id = shield_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if params is None:
params = {}
shield = ShieldWithACL(
identifier=shield_id,
provider_resource_id=provider_shield_id,
provider_id=provider_id,
params=params,
)
await self.register_object(shield)
return shield

View file

@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ToolGroupWithACL
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
# handle the funny case like "builtin::rag/knowledge_search"
parts = toolgroup_name_with_maybe_tool_name.split("/")
if len(parts) == 2:
return parts[0]
else:
return None
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
toolgroups_to_tools: dict[str, list[Tool]] = {}
tool_to_toolgroup: dict[str, str] = {}
# overridden
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key)
if toolgroup_id:
routing_key = toolgroup_id
if routing_key in self.tool_to_toolgroup:
routing_key = self.tool_to_toolgroup[routing_key]
return super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id
toolgroups = [await self.get_tool_group(toolgroup_id)]
else:
toolgroups = await self.get_all_with_type("tool_group")
all_tools = []
for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools:
await self._index_tools(toolgroup)
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
return ListToolsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
# TODO: kill this Tool vs ToolDef distinction
tooldefs = tooldefs_response.data
tools = []
for t in tooldefs:
tools.append(
Tool(
identifier=t.name,
toolgroup_id=toolgroup.identifier,
description=t.description or "",
parameters=t.parameters or [],
metadata=t.metadata,
provider_id=toolgroup.provider_id,
)
)
self.toolgroups_to_tools[toolgroup.identifier] = tools
for tool in tools:
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group '{toolgroup_id}' not found")
return tool_group
async def get_tool(self, tool_name: str) -> Tool:
if tool_name in self.tool_to_toolgroup:
toolgroup_id = self.tool_to_toolgroup[tool_name]
tools = self.toolgroups_to_tools[toolgroup_id]
for tool in tools:
if tool.identifier == tool_name:
return tool
raise ValueError(f"Tool '{tool_name}' not found")
async def register_tool_group(
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> ToolGroup:
toolgroup = ToolGroupWithACL(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
await self.register_object(toolgroup)
# ideally, indexing of the tools should not be necessary because anyone using
# the tools should first list the tools and then use them. but there are assumptions
# baked into some of the code and tests right now.
if not toolgroup.mcp_endpoint:
await self._index_tools(toolgroup)
return toolgroup
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
await self.unregister_object(tool_group)
async def shutdown(self) -> None:
pass
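parse_toolgroup_from_toolgroup_name_pair handles routing keys of the form group/tool, e.g. the "builtin::rag/knowledge_search" case mentioned in its comment. Reproducing the function for illustration:

def parse_toolgroup_from_toolgroup_name_pair(name: str) -> str | None:
    # Same logic as the file above: keep the group half of a "group/tool" pair.
    parts = name.split("/")
    return parts[0] if len(parts) == 2 else None


print(parse_toolgroup_from_toolgroup_name_pair("builtin::rag/knowledge_search"))  # builtin::rag
print(parse_toolgroup_from_toolgroup_name_pair("knowledge_search"))  # None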

View file

@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import TypeAdapter
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
VectorDBWithACL,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise ValueError(f"Vector DB '{vector_db_id}' not found")
return vector_db
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await self.get_object_by_identifier("model", embedding_model)
if model is None:
raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
vector_db_data = {
"identifier": vector_db_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
}
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
async def unregister_vector_db(self, vector_db_id: str) -> None:
existing_vector_db = await self.get_vector_db(vector_db_id)
if existing_vector_db is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
await self.unregister_object(existing_vector_db)
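register_vector_db validates the embedding model in three steps before building the record: the model must exist, be an embedding model, and carry a dimension. A compact sketch with a hypothetical registry record:

def validate_embedding_model(model: dict | None, name: str) -> int:
    # Mirrors the three checks above; `model` is a hypothetical registry record.
    if model is None:
        raise ValueError(f"Model {name} not found")
    if model.get("model_type") != "embedding":
        raise ValueError(f"Model {name} is not an embedding model")
    if "embedding_dimension" not in model.get("metadata", {}):
        raise ValueError(f"Model {name} does not have an embedding dimension")
    return model["metadata"]["embedding_dimension"]


record = {"model_type": "embedding", "metadata": {"embedding_dimension": 384}}
print(validate_embedding_model(record, "all-MiniLM-L6-v2"))  # -> 384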

View file

@ -8,7 +8,8 @@ import json
import httpx
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.server.auth_providers import create_auth_provider
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -77,7 +78,7 @@ class AuthenticationMiddleware:
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_config: AuthProviderConfig):
def __init__(self, app, auth_config: AuthenticationConfig):
self.app = app
self.auth_provider = create_auth_provider(auth_config)
@ -93,7 +94,7 @@ class AuthenticationMiddleware:
# Validate token and get access attributes
try:
access_attributes = await self.auth_provider.validate_token(token, scope)
validation_result = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
@ -105,17 +106,24 @@ class AuthenticationMiddleware:
return await self._send_auth_error(send, "Authentication service error")
# Store attributes in request scope for access control
if access_attributes:
user_attributes = access_attributes.model_dump(exclude_none=True)
if validation_result.access_attributes:
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to token by default")
user_attributes = {
"namespaces": [token],
"roles": [token],
}
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
# can identify the requester and enforce per-client rate limits.
scope["authenticated_client_id"] = token
# Store attributes in request scope
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
scope["principal"] = validation_result.principal
logger.debug(
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
)
return await self.app(scope, receive, send)
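The fallback branch above is worth noting: when the provider returns no access attributes, the raw bearer token itself becomes both the namespace and the role. A sketch of the resulting scope fields (values hypothetical, and a real ASGI scope carries much more):

def build_scope_fields(token: str, access_attributes: dict | None, principal: str | None) -> dict:
    if access_attributes:
        user_attributes = access_attributes
    else:
        # Fallback from the middleware above: token doubles as namespace and role.
        user_attributes = {"namespaces": [token], "roles": [token]}
    return {
        "authenticated_client_id": token,
        "user_attributes": user_attributes,
        "principal": principal,
    }


print(build_scope_fields("tok-123", None, None))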

View file

@ -4,23 +4,29 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import ssl
import time
from abc import ABC, abstractmethod
from enum import Enum
from asyncio import Lock
from pathlib import Path
from urllib.parse import parse_qs
import httpx
from pydantic import BaseModel, Field
from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
class TokenValidationResult(BaseModel):
principal: str | None = Field(
default=None,
description="The principal (username or persistent identifier) of the authenticated user",
)
access_attributes: AccessAttributes | None = Field(
default=None,
description="""
@ -43,6 +49,10 @@ class AuthResponse(BaseModel):
""",
)
class AuthResponse(TokenValidationResult):
"""The format of the authentication response from the auth endpoint."""
message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
@ -64,25 +74,11 @@ class AuthRequest(BaseModel):
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthProviderType(str, Enum):
"""Supported authentication provider types."""
KUBERNETES = "kubernetes"
CUSTOM = "custom"
class AuthProviderConfig(BaseModel):
"""Base configuration for authentication providers."""
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
config: dict[str, str] = Field(..., description="Provider-specific configuration")
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token and return access attributes."""
pass
@@ -92,88 +88,219 @@ class AuthProvider(ABC):
pass
class KubernetesAuthProvider(AuthProvider):
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
attributes = AccessAttributes()
for claim_key, attribute_key in mapping.items():
if claim_key not in claims or not hasattr(attributes, attribute_key):
continue
claim = claims[claim_key]
if isinstance(claim, list):
values = claim
else:
values = claim.split()
def __init__(self, config: dict[str, str]):
self.api_server_url = config["api_server_url"]
self.ca_cert_path = config.get("ca_cert_path")
self._client = None
current = getattr(attributes, attribute_key)
if current:
current.extend(values)
else:
setattr(attributes, attribute_key, values)
return attributes
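To make the claims mapping concrete, here is how `get_attributes_from_claims` behaves with made-up claims and a slice of the default mapping: string claims are split on whitespace, list claims are taken as-is, and unmapped claims are ignored.

```python
claims = {"sub": "alice", "groups": ["ml-infra", "platform"], "scope": "read write"}
mapping = {"sub": "roles", "groups": "teams"}

attrs = get_attributes_from_claims(claims, mapping)
# attrs.roles == ["alice"]                  "sub" is a string -> split on whitespace
# attrs.teams == ["ml-infra", "platform"]   list claims are extended as-is
# "scope" is dropped: it has no entry in the mapping
```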
async def _get_client(self):
"""Get or create a Kubernetes client."""
if self._client is None:
# kubernetes-client has no async support, see:
# https://github.com/kubernetes-client/python/issues/323
from kubernetes import client
from kubernetes.client import ApiClient
# Configure the client
configuration = client.Configuration()
configuration.host = self.api_server_url
if self.ca_cert_path:
configuration.ssl_ca_cert = self.ca_cert_path
configuration.verify_ssl = bool(self.ca_cert_path)
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
uri: str
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
# Create API client
self._client = ApiClient(configuration)
return self._client
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
"""Validate a Kubernetes token and return access attributes."""
class OAuth2IntrospectionConfig(BaseModel):
url: str
client_id: str
client_secret: str
send_secret_in_body: bool = False
class OAuth2TokenAuthProviderConfig(BaseModel):
audience: str = "llama-stack"
verify_tls: bool = True
tls_cafile: Path | None = None
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)
jwks: OAuth2JWKSConfig | None
introspection: OAuth2IntrospectionConfig | None = None
@classmethod
@field_validator("claims_mapping")
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
if value not in AccessAttributes.model_fields:
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
return v
@model_validator(mode="after")
def validate_mode(self) -> Self:
if not self.jwks and not self.introspection:
raise ValueError("One of jwks or introspection must be configured")
if self.jwks and self.introspection:
raise ValueError("At present only one of jwks or introspection should be configured")
return self
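A minimal sketch of the two supported modes (all URLs and secrets are placeholders); configuring both, or neither, fails `validate_mode`:

```python
# JWKS mode: public keys are fetched from the issuer and tokens are verified locally.
jwks_cfg = OAuth2TokenAuthProviderConfig.model_validate({
    "issuer": "https://keycloak.example.com/realms/default",
    "jwks": {"uri": "https://keycloak.example.com/realms/default/protocol/openid-connect/certs"},
})

# Introspection mode: each token is posted to the authorization server (RFC 7662).
# `jwks` has no default, so it must be set to None explicitly here.
introspect_cfg = OAuth2TokenAuthProviderConfig.model_validate({
    "jwks": None,
    "introspection": {
        "url": "https://keycloak.example.com/realms/default/protocol/openid-connect/token/introspect",
        "client_id": "llama-stack",
        "client_secret": "<secret>",
    },
})
```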
class OAuth2TokenAuthProvider(AuthProvider):
"""
JWT token authentication provider that validates a JWT token and extracts access attributes.
This should be the standard authentication provider for most use cases.
"""
def __init__(self, config: OAuth2TokenAuthProviderConfig):
self.config = config
self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {}
self._jwks_lock = Lock()
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
if self.config.jwks:
return await self.validate_jwt_token(token, scope)
if self.config.introspection:
return await self.introspect_token(token, scope)
raise ValueError("One of jwks or introspection must be configured")
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the JWT token."""
await self._refresh_jwks()
try:
client = await self._get_client()
# Set the token in the client
client.set_default_header("Authorization", f"Bearer {token}")
# Make a request to validate the token
# We use the /api endpoint which requires authentication
from kubernetes.client import CoreV1Api
api = CoreV1Api(client)
api.get_api_resources(_request_timeout=3.0) # Set timeout for this specific request
# If we get here, the token is valid
# Extract user info from the token claims
import base64
# Decode the token (without verification since we've already validated it)
token_parts = token.split(".")
payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)))
# Extract user information from the token
username = payload.get("sub", "")
groups = payload.get("groups", [])
return AccessAttributes(
roles=[username], # Use username as a role
teams=groups, # Use Kubernetes groups as teams
header = jwt.get_unverified_header(token)
kid = header["kid"]
if kid not in self._jwks:
raise ValueError(f"Unknown key ID: {kid}")
key_data = self._jwks[kid]
algorithm = header.get("alg", "RS256")
claims = jwt.decode(
token,
key_data,
algorithms=[algorithm],
audience=self.config.audience,
issuer=self.config.issuer,
)
except Exception as exc:
raise ValueError(f"Invalid JWT token: {token}") from exc
# There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes.
principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
return TokenValidationResult(
principal=principal,
access_attributes=access_attributes,
)
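Once signature, audience, and issuer checks pass, `jwt.decode` returns the claim set as a plain dict, for example (values illustrative):

```python
claims = {
    "iss": "https://keycloak.example.com/realms/default",
    "aud": "llama-stack",
    "sub": "alice",
    "groups": ["ml-infra"],
    "exp": 1717430400,
}
# principal == "alice"; "groups" maps to teams via the default claims_mapping
```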
async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using token introspection as defined by RFC 7662."""
form = {
"token": token,
}
if self.config.introspection is None:
raise ValueError("Introspection is not configured")
if self.config.introspection.send_secret_in_body:
form["client_id"] = self.config.introspection.client_id
form["client_secret"] = self.config.introspection.client_secret
auth = None
else:
auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
ssl_ctxt = None
if self.config.tls_cafile:
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
try:
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
response = await client.post(
self.config.introspection.url,
data=form,
auth=auth,
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Token introspection failed with status code: {response.status_code}")
raise ValueError(f"Token introspection failed: {response.status_code}")
fields = response.json()
if not fields["active"]:
raise ValueError("Token not active")
principal = fields["sub"] or fields["username"]
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
return TokenValidationResult(
principal=principal,
access_attributes=access_attributes,
)
except httpx.TimeoutException:
logger.exception("Token introspection request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Failed to validate Kubernetes token")
raise ValueError("Invalid or expired token") from e
logger.exception("Error during token introspection")
raise ValueError("Token introspection error") from e
async def close(self):
"""Close the HTTP client."""
if self._client:
self._client.close()
self._client = None
pass
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
If the cache is expired, we refresh the JWKS from the JWKS URI.
Note for Kubernetes, which doesn't fully implement the OIDC protocol:
* It doesn't have user authentication flows
* It doesn't have refresh tokens
"""
async with self._jwks_lock:
if self.config.jwks is None:
raise ValueError("JWKS is not configured")
if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
async with httpx.AsyncClient(verify=verify) as client:
res = await client.get(self.config.jwks.uri, timeout=5)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}
for k in jwks_data:
kid = k["kid"]
# Store the entire key object as it may be needed for different algorithms
updated[kid] = k
self._jwks = updated
self._jwks_at = time.time()
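The JWKS document fetched here is a standard JSON object with a `keys` array; each entry is cached whole, keyed by its `kid` (sample abbreviated):

```python
jwks_response = {
    "keys": [
        {
            "kid": "abc123",
            "kty": "RSA",
            "alg": "RS256",
            "use": "sig",
            "n": "...",   # base64url-encoded modulus, truncated here
            "e": "AQAB",
        }
    ]
}
# After _refresh_jwks: self._jwks == {"abc123": jwks_response["keys"][0]}
```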
class CustomAuthProviderConfig(BaseModel):
endpoint: str
class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""
def __init__(self, config: dict[str, str]):
self.endpoint = config["endpoint"]
def __init__(self, config: CustomAuthProviderConfig):
self.config = config
self._client = None
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the custom authentication endpoint."""
if not self.endpoint:
raise ValueError("Authentication endpoint not configured")
if scope is None:
scope = {}
@@ -202,7 +329,7 @@ class CustomAuthProvider(AuthProvider):
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self.endpoint,
self.config.endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
@@ -214,19 +341,7 @@ class CustomAuthProvider(AuthProvider):
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
# Store attributes in request scope for access control
if auth_response.access_attributes:
return auth_response.access_attributes
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [token],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
return auth_response.access_attributes
return auth_response
except Exception as e:
logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e
@@ -248,14 +363,14 @@ class CustomAuthProvider(AuthProvider):
self._client = None
def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
"""Factory function to create the appropriate auth provider."""
provider_type = config.provider_type.lower()
if provider_type == "kubernetes":
return KubernetesAuthProvider(config.config)
elif provider_type == "custom":
return CustomAuthProvider(config.config)
if provider_type == "custom":
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "oauth2_token":
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
else:
supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")


@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import time
from datetime import datetime, timedelta, timezone
from starlette.types import ASGIApp, Receive, Scope, Send
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota")
class QuotaMiddleware:
"""
ASGI middleware that enforces separate quotas for authenticated and anonymous clients
within a configurable time window.
- For authenticated requests, it reads the client ID from the
`Authorization: Bearer <client_id>` header.
- For anonymous requests, it falls back to the IP address of the client.
Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
once a client exceeds its quota.
"""
def __init__(
self,
app: ASGIApp,
kv_config: KVStoreConfig,
anonymous_max_requests: int,
authenticated_max_requests: int,
window_seconds: int = 86400,
):
self.app = app
self.kv_config = kv_config
self.kv: KVStore | None = None
self.anonymous_max_requests = anonymous_max_requests
self.authenticated_max_requests = authenticated_max_requests
self.window_seconds = window_seconds
if isinstance(self.kv_config, SqliteKVStoreConfig):
logger.warning(
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
f"window_seconds={self.window_seconds}"
)
async def _get_kv(self) -> KVStore:
if self.kv is None:
self.kv = await kvstore_impl(self.kv_config)
return self.kv
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] == "http":
# pick key & limit based on auth
auth_id = scope.get("authenticated_client_id")
if auth_id:
key_id = auth_id
limit = self.authenticated_max_requests
else:
# fallback to IP
client = scope.get("client")
key_id = client[0] if client else "anonymous"
limit = self.anonymous_max_requests
current_window = int(time.time() // self.window_seconds)
key = f"quota:{key_id}:{current_window}"
try:
kv = await self._get_kv()
prev = await kv.get(key) or "0"
count = int(prev) + 1
if int(prev) == 0:
# Set with expiration datetime when it is the first request in the window.
expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds)
await kv.set(key, str(count), expiration=expiration)
else:
await kv.set(key, str(count))
except Exception:
logger.exception("Failed to access KV store for quota")
return await self._send_error(send, 500, "Quota service error")
if count > limit:
logger.warning(
"Quota exceeded for client %s: %d/%d",
key_id,
count,
limit,
)
return await self._send_error(send, 429, "Quota exceeded")
return await self.app(scope, receive, send)
async def _send_error(self, send: Send, status: int, message: str):
await send(
{
"type": "http.response.start",
"status": status,
"headers": [[b"content-type", b"application/json"]],
}
)
body = json.dumps({"error": {"message": message}}).encode()
await send({"type": "http.response.body", "body": body})


@@ -6,20 +6,23 @@
import inspect
import re
from collections.abc import Callable
from typing import Any
from pydantic import BaseModel
from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
descriptive_name: str | None = None
EndpointFunc = Callable[..., Any]
PathParams = dict[str, str]
RouteInfo = tuple[EndpointFunc, str]
PathImpl = dict[str, RouteInfo]
RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str]
def toolgroup_protocol_map():
@@ -28,13 +31,13 @@ def toolgroup_protocol_map():
}
def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
def get_all_api_routes() -> dict[Api, list[Route]]:
apis = {}
protocols = api_protocol_map()
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
endpoints = []
routes = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT
@@ -51,26 +54,28 @@ def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
method = "delete"
# The __webmethod__ attribute is dynamically added by the @webmethod decorator
# mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
webmethod = method.__webmethod__ # type: ignore[attr-defined]
path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
if webmethod.method == hdrs.METH_GET:
http_method = hdrs.METH_GET
elif webmethod.method == hdrs.METH_DELETE:
http_method = hdrs.METH_DELETE
else:
method = "post"
endpoints.append(
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
)
http_method = hdrs.METH_POST
routes.append(
Route(path=path, methods=[http_method], name=name, endpoint=None)
) # setting endpoint to None since we don't use a Router object
apis[api] = endpoints
apis[api] = routes
return apis
def initialize_endpoint_impls(impls):
endpoints = get_all_api_endpoints()
endpoint_impls = {}
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
routes = get_all_api_routes()
route_impls: RouteImpls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
@@ -83,29 +88,34 @@ def initialize_endpoint_impls(impls):
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
for api, api_routes in routes.items():
if api not in impls:
continue
for endpoint in api_endpoints:
for route in api_routes:
impl = impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
func = getattr(impl, route.name)
# Get the first (and typically only) method from the set, filtering out HEAD
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
continue # Skip if only HEAD method is available
method = available_methods[0].lower()
if method not in route_impls:
route_impls[method] = {}
route_impls[method][_convert_path_to_regex(route.path)] = (
func,
endpoint.descriptive_name or endpoint.route,
route.path,
)
return endpoint_impls
return route_impls
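To make the route table concrete: `_convert_path_to_regex` (body elided in this hunk) turns `{param}` segments into named capture groups, which is what lets `find_matching_route` return path params. Illustrative only; the exact character class is an assumption:

```python
import re

# Assumed output of _convert_path_to_regex("/v1/models/{model_id}")
pattern = r"^/v1/models/(?P<model_id>[^/]+)$"

match = re.match(pattern, "/v1/models/llama-3")
assert match is not None and match.groupdict() == {"model_id": "llama-3"}
```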
def find_matching_endpoint(method, path, endpoint_impls):
def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
endpoint_impls: A dictionary of endpoint implementations
route_impls: A dictionary of endpoint implementations
Returns:
A tuple of (endpoint_function, path_params, descriptive_name)
@@ -113,7 +123,7 @@ def find_matching_endpoint(method, path, endpoint_impls):
Raises:
ValueError: If no matching endpoint is found
"""
impls = endpoint_impls.get(method.lower())
impls = route_impls.get(method.lower())
if not impls:
raise ValueError(f"No endpoint found for {path}")


@@ -6,6 +6,7 @@
import argparse
import asyncio
import functools
import inspect
import json
import os
@@ -13,6 +14,7 @@ import ssl
import sys
import traceback
import warnings
from collections.abc import Callable
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
@@ -20,23 +22,26 @@ from typing import Annotated, Any
import rich.pretty
import yaml
from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
from llama_stack.distribution.server.routes import (
find_matching_route,
get_all_api_routes,
initialize_route_impls,
)
from llama_stack.distribution.stack import (
construct_stack,
@@ -59,7 +64,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from .auth import AuthenticationMiddleware
from .endpoints import get_all_api_endpoints
from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@@ -120,6 +125,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError):
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
elif isinstance(exc, AuthenticationRequiredError):
return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
else:
return HTTPException(
status_code=500,
@@ -205,8 +212,9 @@ async def log_request_pre_validation(request: Request):
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@functools.wraps(func)
async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
@@ -246,9 +254,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
for param in new_params[1:]
]
endpoint.__signature__ = sig.replace(parameters=new_params)
route_handler.__signature__ = sig.replace(parameters=new_params)
return endpoint
return route_handler
class TracingMiddleware:
@@ -270,17 +278,28 @@ class TracingMiddleware:
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send)
if not hasattr(self, "endpoint_impls"):
self.endpoint_impls = initialize_endpoint_impls(self.impls)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
try:
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
trace_attributes = {"__location__": "server", "raw_path": path}
# Extract W3C trace context headers and store as trace attributes
headers = dict(scope.get("headers", []))
traceparent = headers.get(b"traceparent", b"").decode()
if traceparent:
trace_attributes["traceparent"] = traceparent
tracestate = headers.get(b"tracestate", b"").decode()
if tracestate:
trace_attributes["tracestate"] = tracestate
trace_context = await start_trace(trace_path, trace_attributes)
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
@@ -370,14 +389,6 @@ def main(args: argparse.Namespace | None = None):
if args is None:
args = parser.parse_args()
# Check for deprecated argument usage
if "--yaml-config" in sys.argv:
warnings.warn(
"The '--yaml-config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
DeprecationWarning,
stacklevel=2,
)
log_line = ""
if args.config:
# if the user provided a config file, use it, even if template was specified
@@ -391,7 +402,7 @@ def main(args: argparse.Namespace | None = None):
raise ValueError(f"Template {args.template} does not exist")
log_line = f"Using template {args.template} config file: {config_file}"
else:
raise ValueError("Either --yaml-config or --template must be provided")
raise ValueError("Either --config or --template must be provided")
logger_config = None
with open(config_file) as fp:
@@ -431,6 +442,46 @@ def main(args: argparse.Namespace | None = None):
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
else:
if config.server.quota:
quota = config.server.quota
logger.warning(
"Configured authenticated_max_requests (%d) but no auth is enabled; "
"falling back to anonymous_max_requests (%d) for all the requests",
quota.authenticated_max_requests,
quota.anonymous_max_requests,
)
if config.server.quota:
logger.info("Enabling quota middleware for authenticated and anonymous clients")
quota = config.server.quota
anonymous_max_requests = quota.anonymous_max_requests
# if auth is disabled, use the anonymous max requests
authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests
kv_config = quota.kvstore
window_map = {"day": 86400}
window_seconds = window_map[quota.period.value]
app.add_middleware(
QuotaMiddleware,
kv_config=kv_config,
anonymous_max_requests=anonymous_max_requests,
authenticated_max_requests=authenticated_max_requests,
window_seconds=window_seconds,
)
# --- CORS middleware for local development ---
# TODO: move to reverse proxy
ui_port = os.environ.get("LLAMA_STACK_UI_PORT", 8322)
app.add_middleware(
CORSMiddleware,
allow_origins=[f"http://localhost:{ui_port}"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
try:
impls = asyncio.run(construct_stack(config))
@@ -443,7 +494,7 @@ def main(args: argparse.Namespace | None = None):
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_endpoints = get_all_api_endpoints()
all_routes = get_all_api_routes()
if config.apis:
apis_to_serve = set(config.apis)
@@ -461,24 +512,29 @@ def main(args: argparse.Namespace | None = None):
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]
routes = all_routes[api]
impl = impls[api]
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
for route in routes:
if not hasattr(impl, route.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
raise ValueError(f"Could not find method {route.name} on {impl}!")
impl_method = getattr(impl, endpoint.name)
logger.debug(f"{endpoint.method.upper()} {endpoint.route}")
impl_method = getattr(impl, route.name)
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0]
logger.debug(f"{method} {route.path}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
getattr(app, method.lower())(route.path, response_model=None)(
create_dynamic_typed_route(
impl_method,
endpoint.method,
endpoint.route,
method.lower(),
route.path,
)
)
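Tying the middleware wiring together, a hypothetical `server` fragment of run.yaml, shown as a Python dict (field names inferred from the code above; the authoritative schema is `StackRunConfig`):

```python
server_config = {
    "auth": {
        "provider_type": "oauth2_token",
        "config": {},  # OAuth2TokenAuthProviderConfig fields, as sketched earlier
    },
    "quota": {
        "kvstore": {"type": "sqlite", "db_path": "./quota.db"},  # assumed KVStore fields
        "anonymous_max_requests": 100,
        "authenticated_max_requests": 1000,
        "period": "day",  # the only key in window_map above
    },
}
```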


@@ -54,7 +54,7 @@ other_args=""
# Process remaining arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--config|--yaml-config)
--config)
if [[ -n "$2" ]]; then
yaml_config="$2"
shift 2
@@ -121,7 +121,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
if [ -n "$yaml_config" ]; then
yaml_config_arg="--yaml-config $yaml_config"
yaml_config_arg="--config $yaml_config"
else
yaml_config_arg=""
fi
@@ -181,9 +181,9 @@ elif [[ "$env_type" == "container" ]]; then
# Add yaml config if provided, otherwise use default
if [ -n "$yaml_config" ]; then
cmd="$cmd -v $yaml_config:/app/run.yaml --yaml-config /app/run.yaml"
cmd="$cmd -v $yaml_config:/app/run.yaml --config /app/run.yaml"
else
cmd="$cmd --yaml-config /app/run.yaml"
cmd="$cmd --config /app/run.yaml"
fi
# Add any other args


@@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v8"
KEY_VERSION = "v9"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


@@ -5,7 +5,8 @@ FROM python:3.12-slim
WORKDIR /app
COPY . /app/
RUN /usr/local/bin/python -m pip install --upgrade pip && \
/usr/local/bin/pip3 install -r requirements.txt
/usr/local/bin/pip3 install -r requirements.txt && \
/usr/local/bin/pip3 install -r llama_stack/distribution/ui/requirements.txt
EXPOSE 8501
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
ENTRYPOINT ["streamlit", "run", "llama_stack/distribution/ui/app.py", "--server.port=8501", "--server.address=0.0.0.0"]


@@ -48,3 +48,6 @@ uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
| TOGETHER_API_KEY | API key for Together provider | (empty string) |
| SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
| OPENAI_API_KEY | API key for OpenAI provider | (empty string) |
| KEYCLOAK_URL | URL for Keycloak authentication | (empty string) |
| KEYCLOAK_REALM | Keycloak realm | default |
| KEYCLOAK_CLIENT_ID | Client ID for Keycloak auth | (empty string) |


@@ -50,6 +50,42 @@ def main():
)
pg.run()
def main2():
from dataclasses import asdict
st.subheader(f"Welcome {keycloak.user_info['preferred_username']}!")
st.write(f"Here is your user information:")
st.write(asdict(keycloak))
def get_access_token() -> str | None:
return st.session_state.get('access_token')
if __name__ == "__main__":
main()
from streamlit_keycloak import login
import os
keycloak_url = os.environ.get("KEYCLOAK_URL")
keycloak_realm = os.environ.get("KEYCLOAK_REALM", "default")
keycloak_client_id = os.environ.get("KEYCLOAK_CLIENT_ID")
if keycloak_url and keycloak_client_id:
keycloak = login(
url=keycloak_url,
realm=keycloak_realm,
client_id=keycloak_client_id,
custom_labels={
"labelButton": "Sign in to kvant",
"labelLogin": "Please sign in to your kvant account.",
"errorNoPopup": "Unable to open the authentication popup. Allow popups and refresh the page to proceed.",
"errorPopupClosed": "Authentication popup was closed manually.",
"errorFatal": "Unable to connect to Keycloak using the current configuration."
},
auto_refresh=True,
)
if keycloak.authenticated:
st.session_state['access_token'] = keycloak.access_token
main()
# TBD - add other authentications
else:
main()


@@ -7,11 +7,13 @@
import os
from llama_stack_client import LlamaStackClient
from llama_stack.distribution.ui.app import get_access_token
class LlamaStackApi:
def __init__(self):
self.client = LlamaStackClient(
api_key=get_access_token(),
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
provider_data={
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
@@ -28,5 +30,3 @@ class LlamaStackApi:
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)
llama_stack_api = LlamaStackApi()

Some files were not shown because too many files have changed in this diff.