diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml
index 238fed683..f9c42ef8a 100644
--- a/.github/workflows/integration-auth-tests.yml
+++ b/.github/workflows/integration-auth-tests.yml
@@ -86,7 +86,7 @@ jobs:
# avoid line breaks in the server log, especially because we grep it below.
export COLUMNS=1984
- nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
+ nohup uv run llama stack run $run_dir/run.yaml > server.log 2>&1 &
- name: Wait for Llama Stack server to be ready
run: |
diff --git a/.github/workflows/stale_bot.yml b/.github/workflows/stale_bot.yml
index 502a78f8e..c5a1ba9e5 100644
--- a/.github/workflows/stale_bot.yml
+++ b/.github/workflows/stale_bot.yml
@@ -24,7 +24,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Stale Action
- uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
+ uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0
with:
stale-issue-label: 'stale'
stale-issue-message: >
diff --git a/.github/workflows/test-external-provider-module.yml b/.github/workflows/test-external-provider-module.yml
index 8a757b068..b43cefb27 100644
--- a/.github/workflows/test-external-provider-module.yml
+++ b/.github/workflows/test-external-provider-module.yml
@@ -59,7 +59,7 @@ jobs:
# Use the virtual environment created by the build step (name comes from build config)
source ramalama-stack-test/bin/activate
uv pip list
- nohup llama stack run tests/external/ramalama-stack/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
+ nohup llama stack run tests/external/ramalama-stack/run.yaml > server.log 2>&1 &
- name: Wait for Llama Stack server to be ready
run: |
diff --git a/.github/workflows/test-external.yml b/.github/workflows/test-external.yml
index 7ee467451..a008b17af 100644
--- a/.github/workflows/test-external.yml
+++ b/.github/workflows/test-external.yml
@@ -59,7 +59,7 @@ jobs:
# Use the virtual environment created by the build step (name comes from build config)
source ci-test/bin/activate
uv pip list
- nohup llama stack run tests/external/run-byoa.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
+ nohup llama stack run tests/external/run-byoa.yaml > server.log 2>&1 &
- name: Wait for Llama Stack server to be ready
run: |
diff --git a/docs/docs/advanced_apis/post_training.mdx b/docs/docs/advanced_apis/post_training.mdx
index 516ac07e1..43bfaea91 100644
--- a/docs/docs/advanced_apis/post_training.mdx
+++ b/docs/docs/advanced_apis/post_training.mdx
@@ -52,7 +52,7 @@ You can access the HuggingFace trainer via the `starter` distribution:
```bash
llama stack build --distro starter --image-type venv
-llama stack run --image-type venv ~/.llama/distributions/starter/starter-run.yaml
+llama stack run ~/.llama/distributions/starter/starter-run.yaml
```
### Usage Example
diff --git a/docs/docs/building_applications/tools.mdx b/docs/docs/building_applications/tools.mdx
index e5d9c46f9..3b78ec57b 100644
--- a/docs/docs/building_applications/tools.mdx
+++ b/docs/docs/building_applications/tools.mdx
@@ -219,13 +219,10 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
1. Start by registering a Tavily API key at [Tavily](https://tavily.com/).
-2. [Optional] Provide the API key directly to the Llama Stack server
+2. [Optional] Set the API key in your environment before starting the Llama Stack server
```bash
export TAVILY_SEARCH_API_KEY="your key"
```
-```bash
---env TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY}
-```
@@ -273,9 +270,9 @@ for log in EventLogger().log(response):
1. Start by registering for a WolframAlpha API key at [WolframAlpha Developer Portal](https://developer.wolframalpha.com/access).
-2. Provide the API key either when starting the Llama Stack server:
+2. Provide the API key either by setting it in your environment before starting the Llama Stack server:
```bash
- --env WOLFRAM_ALPHA_API_KEY=${WOLFRAM_ALPHA_API_KEY}
+ export WOLFRAM_ALPHA_API_KEY="your key"
```
or from the client side:
```python
diff --git a/docs/docs/contributing/new_api_provider.mdx b/docs/docs/contributing/new_api_provider.mdx
index 4ae6d5e72..6f9744771 100644
--- a/docs/docs/contributing/new_api_provider.mdx
+++ b/docs/docs/contributing/new_api_provider.mdx
@@ -76,7 +76,7 @@ Integration tests are located in [tests/integration](https://github.com/meta-lla
Consult [tests/integration/README.md](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md) for more details on how to run the tests.
Note that each provider's `sample_run_config()` method (in the configuration class for that provider)
- typically references some environment variables for specifying API keys and the like. You can set these in the environment or pass these via the `--env` flag to the test command.
+ typically references some environment variables for specifying API keys and the like. You can set these in the environment before running the test command.
### 2. Unit Testing
diff --git a/docs/docs/distributions/building_distro.mdx b/docs/docs/distributions/building_distro.mdx
index 5b65b7f16..a4f7e1f60 100644
--- a/docs/docs/distributions/building_distro.mdx
+++ b/docs/docs/distributions/building_distro.mdx
@@ -289,10 +289,10 @@ After this step is successful, you should be able to find the built container im
docker run -d \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
+ -e INFERENCE_MODEL=$INFERENCE_MODEL \
+ -e OLLAMA_URL=http://host.docker.internal:11434 \
localhost/distribution-ollama:dev \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env OLLAMA_URL=http://host.docker.internal:11434
+ --port $LLAMA_STACK_PORT
```
Here are the docker flags and their uses:
@@ -305,12 +305,12 @@ Here are the docker flags and their uses:
* `localhost/distribution-ollama:dev`: The name and tag of the container image to run
+* `-e INFERENCE_MODEL=$INFERENCE_MODEL`: Sets the INFERENCE_MODEL environment variable in the container
+
+* `-e OLLAMA_URL=http://host.docker.internal:11434`: Sets the OLLAMA_URL environment variable in the container
+
* `--port $LLAMA_STACK_PORT`: Port number for the server to listen on
-* `--env INFERENCE_MODEL=$INFERENCE_MODEL`: Sets the model to use for inference
-
-* `--env OLLAMA_URL=http://host.docker.internal:11434`: Configures the URL for the Ollama service
-
@@ -320,23 +320,22 @@ Now, let's start the Llama Stack Distribution Server. You will need the YAML con
```
llama stack run -h
-usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--env KEY=VALUE]
+usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME]
[--image-type {venv}] [--enable-ui]
- [config | template]
+ [config | distro]
Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
positional arguments:
- config | template Path to config file to use for the run or name of known template (`llama stack list` for a list). (default: None)
+ config | distro Path to config file to use for the run or name of known distro (`llama stack list` for a list). (default: None)
options:
-h, --help show this help message and exit
--port PORT Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
--image-name IMAGE_NAME
- Name of the image to run. Defaults to the current environment (default: None)
- --env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: None)
+ [DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running. (default: None)
--image-type {venv}
- Image Type used during the build. This should be venv. (default: None)
+ [DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running. (default: None)
--enable-ui Start the UI server (default: False)
```
@@ -348,9 +347,6 @@ llama stack run tgi
# Start using config file
llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
-
-# Start using a venv
-llama stack run --image-type venv ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
```
```
diff --git a/docs/docs/distributions/configuration.mdx b/docs/docs/distributions/configuration.mdx
index dbf879024..81243c97b 100644
--- a/docs/docs/distributions/configuration.mdx
+++ b/docs/docs/distributions/configuration.mdx
@@ -101,7 +101,7 @@ A few things to note:
- The id is a string you can choose freely.
- You can instantiate any number of provider instances of the same type.
- The configuration dictionary is provider-specific.
-- Notice that configuration can reference environment variables (with default values), which are expanded at runtime. When you run a stack server (via docker or via `llama stack run`), you can specify `--env OLLAMA_URL=http://my-server:11434` to override the default value.
+- Notice that configuration can reference environment variables (with default values), which are expanded at runtime. When you run a stack server, you can set environment variables in your shell before running `llama stack run` to override the default values.
### Environment Variable Substitution
@@ -173,13 +173,10 @@ optional_token: ${env.OPTIONAL_TOKEN:+}
#### Runtime Override
-You can override environment variables at runtime when starting the server:
+You can override environment variables at runtime by setting them in your shell before starting the server:
```bash
-# Override specific environment variables
-llama stack run --config run.yaml --env API_KEY=sk-123 --env BASE_URL=https://custom-api.com
-
-# Or set them in your shell
+# Set environment variables in your shell
export API_KEY=sk-123
export BASE_URL=https://custom-api.com
llama stack run --config run.yaml
diff --git a/docs/docs/distributions/remote_hosted_distro/watsonx.md b/docs/docs/distributions/remote_hosted_distro/watsonx.md
index 977af90dd..5add678f3 100644
--- a/docs/docs/distributions/remote_hosted_distro/watsonx.md
+++ b/docs/docs/distributions/remote_hosted_distro/watsonx.md
@@ -69,10 +69,10 @@ docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
+ -e WATSONX_API_KEY=$WATSONX_API_KEY \
+ -e WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
+ -e WATSONX_BASE_URL=$WATSONX_BASE_URL \
llamastack/distribution-watsonx \
--config /root/my-run.yaml \
- --port $LLAMA_STACK_PORT \
- --env WATSONX_API_KEY=$WATSONX_API_KEY \
- --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
- --env WATSONX_BASE_URL=$WATSONX_BASE_URL
+ --port $LLAMA_STACK_PORT
```
diff --git a/docs/docs/distributions/self_hosted_distro/dell.md b/docs/docs/distributions/self_hosted_distro/dell.md
index 52d40cf9d..851eac3bf 100644
--- a/docs/docs/distributions/self_hosted_distro/dell.md
+++ b/docs/docs/distributions/self_hosted_distro/dell.md
@@ -129,11 +129,11 @@ docker run -it \
# NOTE: mount the llama-stack / llama-model directories if testing local changes else not needed
-v $HOME/git/llama-stack:/app/llama-stack-source -v $HOME/git/llama-models:/app/llama-models-source \
# localhost/distribution-dell:dev if building / testing locally
- llamastack/distribution-dell\
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env DEH_URL=$DEH_URL \
- --env CHROMA_URL=$CHROMA_URL
+ -e INFERENCE_MODEL=$INFERENCE_MODEL \
+ -e DEH_URL=$DEH_URL \
+ -e CHROMA_URL=$CHROMA_URL \
+ llamastack/distribution-dell \
+ --port $LLAMA_STACK_PORT
```
@@ -154,14 +154,14 @@ docker run \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $HOME/.llama:/root/.llama \
-v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \
+ -e INFERENCE_MODEL=$INFERENCE_MODEL \
+ -e DEH_URL=$DEH_URL \
+ -e SAFETY_MODEL=$SAFETY_MODEL \
+ -e DEH_SAFETY_URL=$DEH_SAFETY_URL \
+ -e CHROMA_URL=$CHROMA_URL \
llamastack/distribution-dell \
--config /root/my-run.yaml \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env DEH_URL=$DEH_URL \
- --env SAFETY_MODEL=$SAFETY_MODEL \
- --env DEH_SAFETY_URL=$DEH_SAFETY_URL \
- --env CHROMA_URL=$CHROMA_URL
+ --port $LLAMA_STACK_PORT
```
### Via venv
@@ -170,21 +170,21 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a
```bash
llama stack build --distro dell --image-type venv
-llama stack run dell
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env DEH_URL=$DEH_URL \
- --env CHROMA_URL=$CHROMA_URL
+INFERENCE_MODEL=$INFERENCE_MODEL \
+DEH_URL=$DEH_URL \
+CHROMA_URL=$CHROMA_URL \
+llama stack run dell \
+ --port $LLAMA_STACK_PORT
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
+INFERENCE_MODEL=$INFERENCE_MODEL \
+DEH_URL=$DEH_URL \
+SAFETY_MODEL=$SAFETY_MODEL \
+DEH_SAFETY_URL=$DEH_SAFETY_URL \
+CHROMA_URL=$CHROMA_URL \
llama stack run ./run-with-safety.yaml \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env DEH_URL=$DEH_URL \
- --env SAFETY_MODEL=$SAFETY_MODEL \
- --env DEH_SAFETY_URL=$DEH_SAFETY_URL \
- --env CHROMA_URL=$CHROMA_URL
+ --port $LLAMA_STACK_PORT
```
diff --git a/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md b/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md
index 84b85b91c..1c0ef5f6e 100644
--- a/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md
+++ b/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md
@@ -84,9 +84,9 @@ docker run \
--gpu all \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
+ -e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
llamastack/distribution-meta-reference-gpu \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
+ --port $LLAMA_STACK_PORT
```
If you are using Llama Stack Safety / Shield APIs, use:
@@ -98,10 +98,10 @@ docker run \
--gpu all \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
+ -e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
+ -e SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
llamastack/distribution-meta-reference-gpu \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
- --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
+ --port $LLAMA_STACK_PORT
```
### Via venv
@@ -110,16 +110,16 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
```bash
llama stack build --distro meta-reference-gpu --image-type venv
+INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
llama stack run distributions/meta-reference-gpu/run.yaml \
- --port 8321 \
- --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
+ --port 8321
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
+INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
+SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
llama stack run distributions/meta-reference-gpu/run-with-safety.yaml \
- --port 8321 \
- --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
- --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
+ --port 8321
```
diff --git a/docs/docs/distributions/self_hosted_distro/nvidia.md b/docs/docs/distributions/self_hosted_distro/nvidia.md
index 1e52797db..a6e185442 100644
--- a/docs/docs/distributions/self_hosted_distro/nvidia.md
+++ b/docs/docs/distributions/self_hosted_distro/nvidia.md
@@ -129,10 +129,10 @@ docker run \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
+ -e NVIDIA_API_KEY=$NVIDIA_API_KEY \
llamastack/distribution-nvidia \
--config /root/my-run.yaml \
- --port $LLAMA_STACK_PORT \
- --env NVIDIA_API_KEY=$NVIDIA_API_KEY
+ --port $LLAMA_STACK_PORT
```
### Via venv
@@ -142,10 +142,10 @@ If you've set up your local development environment, you can also build the imag
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
llama stack build --distro nvidia --image-type venv
+NVIDIA_API_KEY=$NVIDIA_API_KEY \
+INFERENCE_MODEL=$INFERENCE_MODEL \
llama stack run ./run.yaml \
- --port 8321 \
- --env NVIDIA_API_KEY=$NVIDIA_API_KEY \
- --env INFERENCE_MODEL=$INFERENCE_MODEL
+ --port 8321
```
## Example Notebooks
diff --git a/docs/docs/getting_started/detailed_tutorial.mdx b/docs/docs/getting_started/detailed_tutorial.mdx
index 33786ac0e..e6c22224d 100644
--- a/docs/docs/getting_started/detailed_tutorial.mdx
+++ b/docs/docs/getting_started/detailed_tutorial.mdx
@@ -86,9 +86,9 @@ docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
+ -e OLLAMA_URL=http://host.docker.internal:11434 \
llamastack/distribution-starter \
- --port $LLAMA_STACK_PORT \
- --env OLLAMA_URL=http://host.docker.internal:11434
+ --port $LLAMA_STACK_PORT
```
Note to start the container with Podman, you can do the same but replace `docker` at the start of the command with
`podman`. If you are using `podman` older than `4.7.0`, please also replace `host.docker.internal` in the `OLLAMA_URL`
@@ -106,9 +106,9 @@ docker run -it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
--network=host \
+ -e OLLAMA_URL=http://localhost:11434 \
llamastack/distribution-starter \
- --port $LLAMA_STACK_PORT \
- --env OLLAMA_URL=http://localhost:11434
+ --port $LLAMA_STACK_PORT
```
:::
You will see output like below:
diff --git a/docs/docs/providers/inference/remote_anthropic.mdx b/docs/docs/providers/inference/remote_anthropic.mdx
index 96162d25c..44c1fcbb1 100644
--- a/docs/docs/providers/inference/remote_anthropic.mdx
+++ b/docs/docs/providers/inference/remote_anthropic.mdx
@@ -15,6 +15,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `str \| None` | No | | API key for Anthropic models |
## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_azure.mdx b/docs/docs/providers/inference/remote_azure.mdx
index 721fe429c..56a14c100 100644
--- a/docs/docs/providers/inference/remote_azure.mdx
+++ b/docs/docs/providers/inference/remote_azure.mdx
@@ -22,6 +22,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `` | No | | Azure API key for Azure |
| `api_base` | `` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
| `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |
diff --git a/docs/docs/providers/inference/remote_bedrock.mdx b/docs/docs/providers/inference/remote_bedrock.mdx
index 2a5d1b74d..683ec12f8 100644
--- a/docs/docs/providers/inference/remote_bedrock.mdx
+++ b/docs/docs/providers/inference/remote_bedrock.mdx
@@ -15,6 +15,7 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
diff --git a/docs/docs/providers/inference/remote_cerebras.mdx b/docs/docs/providers/inference/remote_cerebras.mdx
index 1a543389d..d364b9884 100644
--- a/docs/docs/providers/inference/remote_cerebras.mdx
+++ b/docs/docs/providers/inference/remote_cerebras.mdx
@@ -15,6 +15,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `base_url` | `` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
| `api_key` | `` | No | | Cerebras API Key |
diff --git a/docs/docs/providers/inference/remote_databricks.mdx b/docs/docs/providers/inference/remote_databricks.mdx
index 670f8a7f9..d7b0bd38d 100644
--- a/docs/docs/providers/inference/remote_databricks.mdx
+++ b/docs/docs/providers/inference/remote_databricks.mdx
@@ -15,6 +15,7 @@ Databricks inference provider for running models on Databricks' unified analytic
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
| `api_token` | `` | No | | The Databricks API token |
diff --git a/docs/docs/providers/inference/remote_fireworks.mdx b/docs/docs/providers/inference/remote_fireworks.mdx
index d2c3a664e..cfdfb993c 100644
--- a/docs/docs/providers/inference/remote_fireworks.mdx
+++ b/docs/docs/providers/inference/remote_fireworks.mdx
@@ -15,6 +15,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `url` | `` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |
diff --git a/docs/docs/providers/inference/remote_gemini.mdx b/docs/docs/providers/inference/remote_gemini.mdx
index 5222eaa89..a13d1c82d 100644
--- a/docs/docs/providers/inference/remote_gemini.mdx
+++ b/docs/docs/providers/inference/remote_gemini.mdx
@@ -15,6 +15,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `str \| None` | No | | API key for Gemini models |
## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_groq.mdx b/docs/docs/providers/inference/remote_groq.mdx
index 77516ed1f..1edb4f9ea 100644
--- a/docs/docs/providers/inference/remote_groq.mdx
+++ b/docs/docs/providers/inference/remote_groq.mdx
@@ -15,6 +15,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `str \| None` | No | | The Groq API key |
| `url` | `` | No | https://api.groq.com | The URL for the Groq AI server |
diff --git a/docs/docs/providers/inference/remote_llama-openai-compat.mdx b/docs/docs/providers/inference/remote_llama-openai-compat.mdx
index bcd50f772..ca5830b09 100644
--- a/docs/docs/providers/inference/remote_llama-openai-compat.mdx
+++ b/docs/docs/providers/inference/remote_llama-openai-compat.mdx
@@ -15,6 +15,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `str \| None` | No | | The Llama API key |
| `openai_compat_api_base` | `` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |
diff --git a/docs/docs/providers/inference/remote_nvidia.mdx b/docs/docs/providers/inference/remote_nvidia.mdx
index 348a42e59..6b5e36180 100644
--- a/docs/docs/providers/inference/remote_nvidia.mdx
+++ b/docs/docs/providers/inference/remote_nvidia.mdx
@@ -15,6 +15,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `url` | `` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed of using the hosted service |
| `timeout` | `` | No | 60 | Timeout for the HTTP requests |
diff --git a/docs/docs/providers/inference/remote_ollama.mdx b/docs/docs/providers/inference/remote_ollama.mdx
index f075607d8..e00e34e4a 100644
--- a/docs/docs/providers/inference/remote_ollama.mdx
+++ b/docs/docs/providers/inference/remote_ollama.mdx
@@ -15,8 +15,8 @@ Ollama inference provider for running local models through the Ollama runtime.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `url` | `` | No | http://localhost:11434 | |
-| `refresh_models` | `` | No | False | Whether to refresh models periodically |
## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_openai.mdx b/docs/docs/providers/inference/remote_openai.mdx
index b795d02b1..e0910c809 100644
--- a/docs/docs/providers/inference/remote_openai.mdx
+++ b/docs/docs/providers/inference/remote_openai.mdx
@@ -15,6 +15,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `str \| None` | No | | API key for OpenAI models |
| `base_url` | `` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
diff --git a/docs/docs/providers/inference/remote_passthrough.mdx b/docs/docs/providers/inference/remote_passthrough.mdx
index 58d5619b8..e356384ad 100644
--- a/docs/docs/providers/inference/remote_passthrough.mdx
+++ b/docs/docs/providers/inference/remote_passthrough.mdx
@@ -15,6 +15,7 @@ Passthrough inference provider for connecting to any external inference service
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `url` | `` | No | | The URL for the passthrough endpoint |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrouth endpoint |
diff --git a/docs/docs/providers/inference/remote_runpod.mdx b/docs/docs/providers/inference/remote_runpod.mdx
index 92cc66eb1..876532029 100644
--- a/docs/docs/providers/inference/remote_runpod.mdx
+++ b/docs/docs/providers/inference/remote_runpod.mdx
@@ -15,6 +15,7 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
| `api_token` | `str \| None` | No | | The API token |
diff --git a/docs/docs/providers/inference/remote_sambanova.mdx b/docs/docs/providers/inference/remote_sambanova.mdx
index b28471890..9bd7b7613 100644
--- a/docs/docs/providers/inference/remote_sambanova.mdx
+++ b/docs/docs/providers/inference/remote_sambanova.mdx
@@ -15,6 +15,7 @@ SambaNova inference provider for running models on SambaNova's dataflow architec
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `url` | `` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The SambaNova cloud API Key |
diff --git a/docs/docs/providers/inference/remote_tgi.mdx b/docs/docs/providers/inference/remote_tgi.mdx
index 6ff82cc2b..67fe6d237 100644
--- a/docs/docs/providers/inference/remote_tgi.mdx
+++ b/docs/docs/providers/inference/remote_tgi.mdx
@@ -15,6 +15,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `url` | `` | No | | The URL for the TGI serving endpoint |
## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_together.mdx b/docs/docs/providers/inference/remote_together.mdx
index da232a45b..6df2ca866 100644
--- a/docs/docs/providers/inference/remote_together.mdx
+++ b/docs/docs/providers/inference/remote_together.mdx
@@ -15,6 +15,7 @@ Together AI inference provider for open-source models and collaborative AI devel
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `url` | `` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |
diff --git a/docs/docs/providers/inference/remote_vertexai.mdx b/docs/docs/providers/inference/remote_vertexai.mdx
index 48da6be24..c182ed485 100644
--- a/docs/docs/providers/inference/remote_vertexai.mdx
+++ b/docs/docs/providers/inference/remote_vertexai.mdx
@@ -54,6 +54,7 @@ Available Models:
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `project` | `` | No | | Google Cloud project ID for Vertex AI |
| `location` | `` | No | us-central1 | Google Cloud location for Vertex AI |
diff --git a/docs/docs/providers/inference/remote_vllm.mdx b/docs/docs/providers/inference/remote_vllm.mdx
index 598f97b19..fbbd424a3 100644
--- a/docs/docs/providers/inference/remote_vllm.mdx
+++ b/docs/docs/providers/inference/remote_vllm.mdx
@@ -15,11 +15,11 @@ Remote vLLM inference provider for connecting to vLLM servers.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
| `max_tokens` | `` | No | 4096 | Maximum number of tokens to generate. |
| `api_token` | `str \| None` | No | fake | The API token |
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
-| `refresh_models` | `` | No | False | Whether to refresh models periodically |
## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_watsonx.mdx b/docs/docs/providers/inference/remote_watsonx.mdx
index 8cd3b2869..f081703ab 100644
--- a/docs/docs/providers/inference/remote_watsonx.mdx
+++ b/docs/docs/providers/inference/remote_watsonx.mdx
@@ -15,9 +15,10 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `url` | `` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key |
-| `project_id` | `str \| None` | No | | The Project ID key |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key |
+| `project_id` | `str \| None` | No | | The watsonx.ai project ID |
| `timeout` | `` | No | 60 | Timeout for the HTTP requests |
## Sample Configuration
diff --git a/docs/docs/providers/safety/remote_bedrock.mdx b/docs/docs/providers/safety/remote_bedrock.mdx
index 530a208b5..663a761f0 100644
--- a/docs/docs/providers/safety/remote_bedrock.mdx
+++ b/docs/docs/providers/safety/remote_bedrock.mdx
@@ -15,6 +15,7 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb
index d7d544ad5..3dcedfed6 100644
--- a/docs/getting_started.ipynb
+++ b/docs/getting_started.ipynb
@@ -123,12 +123,12 @@
" del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
"\n",
"# this command installs all the dependencies needed for the llama stack server with the together inference provider\n",
- "!uv run --with llama-stack llama stack build --distro together --image-type venv\n",
+ "!uv run --with llama-stack llama stack build --distro together\n",
"\n",
"def run_llama_stack_server_background():\n",
" log_file = open(\"llama_stack_server.log\", \"w\")\n",
" process = subprocess.Popen(\n",
- " \"uv run --with llama-stack llama stack run together --image-type venv\",\n",
+ " \"uv run --with llama-stack llama stack run together\",\n",
" shell=True,\n",
" stdout=log_file,\n",
" stderr=log_file,\n",
diff --git a/docs/getting_started_llama4.ipynb b/docs/getting_started_llama4.ipynb
index cd5f83517..bca505b5e 100644
--- a/docs/getting_started_llama4.ipynb
+++ b/docs/getting_started_llama4.ipynb
@@ -233,12 +233,12 @@
" del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
"\n",
"# this command installs all the dependencies needed for the llama stack server\n",
- "!uv run --with llama-stack llama stack build --distro meta-reference-gpu --image-type venv\n",
+ "!uv run --with llama-stack llama stack build --distro meta-reference-gpu\n",
"\n",
"def run_llama_stack_server_background():\n",
" log_file = open(\"llama_stack_server.log\", \"w\")\n",
" process = subprocess.Popen(\n",
- " f\"uv run --with llama-stack llama stack run meta-reference-gpu --image-type venv --env INFERENCE_MODEL={model_id}\",\n",
+ " f\"INFERENCE_MODEL={model_id} uv run --with llama-stack llama stack run meta-reference-gpu\",\n",
" shell=True,\n",
" stdout=log_file,\n",
" stderr=log_file,\n",
diff --git a/docs/getting_started_llama_api.ipynb b/docs/getting_started_llama_api.ipynb
index f65566205..7680c4a0c 100644
--- a/docs/getting_started_llama_api.ipynb
+++ b/docs/getting_started_llama_api.ipynb
@@ -223,12 +223,12 @@
" del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
"\n",
"# this command installs all the dependencies needed for the llama stack server\n",
- "!uv run --with llama-stack llama stack build --distro llama_api --image-type venv\n",
+ "!uv run --with llama-stack llama stack build --distro llama_api\n",
"\n",
"def run_llama_stack_server_background():\n",
" log_file = open(\"llama_stack_server.log\", \"w\")\n",
" process = subprocess.Popen(\n",
- " \"uv run --with llama-stack llama stack run llama_api --image-type venv\",\n",
+ " \"uv run --with llama-stack llama stack run llama_api\",\n",
" shell=True,\n",
" stdout=log_file,\n",
" stderr=log_file,\n",
diff --git a/docs/quick_start.ipynb b/docs/quick_start.ipynb
index c194a901d..eebfd6686 100644
--- a/docs/quick_start.ipynb
+++ b/docs/quick_start.ipynb
@@ -145,12 +145,12 @@
" del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
"\n",
"# this command installs all the dependencies needed for the llama stack server with the ollama inference provider\n",
- "!uv run --with llama-stack llama stack build --distro starter --image-type venv\n",
+ "!uv run --with llama-stack llama stack build --distro starter\n",
"\n",
"def run_llama_stack_server_background():\n",
" log_file = open(\"llama_stack_server.log\", \"w\")\n",
" process = subprocess.Popen(\n",
- " f\"OLLAMA_URL=http://localhost:11434 uv run --with llama-stack llama stack run starter --image-type venv\n",
+ " f\"OLLAMA_URL=http://localhost:11434 uv run --with llama-stack llama stack run starter\n",
" shell=True,\n",
" stdout=log_file,\n",
" stderr=log_file,\n",
diff --git a/docs/zero_to_hero_guide/README.md b/docs/zero_to_hero_guide/README.md
index 183038a88..1b643d692 100644
--- a/docs/zero_to_hero_guide/README.md
+++ b/docs/zero_to_hero_guide/README.md
@@ -88,7 +88,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
...
Build Successful!
You can find the newly-built template here: ~/.llama/distributions/starter/starter-run.yaml
- You can run the new Llama Stack Distro via: uv run --with llama-stack llama stack run starter --image-type venv
+ You can run the new Llama Stack Distro via: uv run --with llama-stack llama stack run starter
```
3. **Set the ENV variables by exporting them to the terminal**:
@@ -102,12 +102,11 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
3. **Run the Llama Stack**:
Run the stack using uv:
```bash
+ INFERENCE_MODEL=$INFERENCE_MODEL \
+ SAFETY_MODEL=$SAFETY_MODEL \
+ OLLAMA_URL=$OLLAMA_URL \
uv run --with llama-stack llama stack run starter \
- --image-type venv \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env SAFETY_MODEL=$SAFETY_MODEL \
- --env OLLAMA_URL=$OLLAMA_URL
+ --port $LLAMA_STACK_PORT
```
Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py
index b14e6fe55..471d5cb66 100644
--- a/llama_stack/cli/stack/_build.py
+++ b/llama_stack/cli/stack/_build.py
@@ -444,12 +444,24 @@ def _run_stack_build_command_from_build_config(
cprint("Build Successful!", color="green", file=sys.stderr)
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
- cprint(
- "You can run the new Llama Stack distro via: "
- + colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"),
- color="green",
- file=sys.stderr,
- )
+ if build_config.image_type == LlamaStackImageType.VENV:
+ cprint(
+ "You can run the new Llama Stack distro (after activating "
+ + colored(image_name, "cyan")
+ + ") via: "
+ + colored(f"llama stack run {run_config_file}", "blue"),
+ color="green",
+ file=sys.stderr,
+ )
+ elif build_config.image_type == LlamaStackImageType.CONTAINER:
+ cprint(
+ "You can run the container with: "
+ + colored(
+ f"docker run -p 8321:8321 -v ~/.llama:/root/.llama localhost/{image_name} --port 8321", "blue"
+ ),
+ color="green",
+ file=sys.stderr,
+ )
return distro_path
else:
return _generate_run_config(build_config, build_dir, image_name)
diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py
index cec101083..06dae7318 100644
--- a/llama_stack/cli/stack/run.py
+++ b/llama_stack/cli/stack/run.py
@@ -16,7 +16,7 @@ import yaml
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.subcommand import Subcommand
from llama_stack.core.datatypes import LoggingConfig, StackRunConfig
-from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars, validate_env_pair
+from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.log import get_logger
@@ -55,18 +55,12 @@ class StackRun(Subcommand):
"--image-name",
type=str,
default=None,
- help="Name of the image to run. Defaults to the current environment",
- )
- self.parser.add_argument(
- "--env",
- action="append",
- help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
- metavar="KEY=VALUE",
+ help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
)
self.parser.add_argument(
"--image-type",
type=str,
- help="Image Type used during the build. This can be only venv.",
+ help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value],
)
self.parser.add_argument(
@@ -75,48 +69,22 @@ class StackRun(Subcommand):
help="Start the UI server",
)
- def _resolve_config_and_distro(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
- """Resolve config file path and distribution name from args.config"""
- from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
-
- if not args.config:
- return None, None
-
- config_file = Path(args.config)
- has_yaml_suffix = args.config.endswith(".yaml")
- distro_name = None
-
- if not config_file.exists() and not has_yaml_suffix:
- # check if this is a distribution
- config_file = Path(REPO_ROOT) / "llama_stack" / "distributions" / args.config / "run.yaml"
- if config_file.exists():
- distro_name = args.config
-
- if not config_file.exists() and not has_yaml_suffix:
- # check if it's a build config saved to ~/.llama dir
- config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
-
- if not config_file.exists():
- self.parser.error(
- f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
- )
-
- if not config_file.is_file():
- self.parser.error(
- f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
- )
-
- return config_file, distro_name
-
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
from llama_stack.core.configure import parse_and_maybe_upgrade_config
- from llama_stack.core.utils.exec import formulate_run_args, run_command
+
+ if args.image_type or args.image_name:
+ self.parser.error(
+ "The --image-type and --image-name flags are no longer supported.\n\n"
+ "Please activate your virtual environment manually before running `llama stack run`.\n\n"
+ "For example:\n"
+ " source /path/to/venv/bin/activate\n"
+ " llama stack run \n"
+ )
if args.enable_ui:
self._start_ui_development_server(args.port)
- image_type, image_name = args.image_type, args.image_name
if args.config:
try:
@@ -128,10 +96,6 @@ class StackRun(Subcommand):
else:
config_file = None
- # Check if config is required based on image type
- if image_type == ImageType.VENV.value and not config_file:
- self.parser.error("Config file is required for venv environment")
-
if config_file:
logger.info(f"Using run configuration: {config_file}")
@@ -146,50 +110,13 @@ class StackRun(Subcommand):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
- else:
- config = None
- # If neither image type nor image name is provided, assume the server should be run directly
- # using the current environment packages.
- if not image_type and not image_name:
- logger.info("No image type or image name provided. Assuming environment packages.")
- self._uvicorn_run(config_file, args)
- else:
- run_args = formulate_run_args(image_type, image_name)
-
- run_args.extend([str(args.port)])
-
- if config_file:
- run_args.extend(["--config", str(config_file)])
-
- if args.env:
- for env_var in args.env:
- if "=" not in env_var:
- self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
- return
- key, value = env_var.split("=", 1) # split on first = only
- if not key:
- self.parser.error(f"Environment variable '{env_var}' has empty key")
- return
- run_args.extend(["--env", f"{key}={value}"])
-
- run_command(run_args)
+ self._uvicorn_run(config_file, args)
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
if not config_file:
self.parser.error("Config file is required")
- # Set environment variables if provided
- if args.env:
- for env_pair in args.env:
- try:
- key, value = validate_env_pair(env_pair)
- logger.info(f"Setting environment variable {key} => {value}")
- os.environ[key] = value
- except ValueError as e:
- logger.error(f"Error: {str(e)}")
- self.parser.error(f"Invalid environment variable format: {env_pair}")
-
config_file = resolve_config_or_distro(str(config_file), Mode.RUN)
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
diff --git a/llama_stack/core/conversations/conversations.py b/llama_stack/core/conversations/conversations.py
index bef138e69..612b2f68e 100644
--- a/llama_stack/core/conversations/conversations.py
+++ b/llama_stack/core/conversations/conversations.py
@@ -32,7 +32,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import (
sqlstore_impl,
)
-logger = get_logger(name=__name__, category="openai::conversations")
+logger = get_logger(name=__name__, category="openai_conversations")
class ConversationServiceConfig(BaseModel):
diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py
index c4338e614..847f6a2d2 100644
--- a/llama_stack/core/routers/inference.py
+++ b/llama_stack/core/routers/inference.py
@@ -611,7 +611,7 @@ class InferenceRouter(Inference):
completion_text += "".join(choice_data["content_parts"])
# Add metrics to the chunk
- if self.telemetry and chunk.usage:
+ if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
metrics = self._construct_metrics(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py
index 641c73c16..716be936a 100644
--- a/llama_stack/core/routing_tables/models.py
+++ b/llama_stack/core/routing_tables/models.py
@@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
try:
models = await provider.list_models()
except Exception as e:
- logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
+ logger.debug(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)
@@ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
return self.impls_by_provider_id[model.provider_id]
+ async def has_model(self, model_id: str) -> bool:
+ """
+ Check if a model exists in the routing table.
+
+ :param model_id: The model identifier to check
+ :return: True if the model exists, False otherwise
+ """
+ try:
+ await lookup_model(self, model_id)
+ return True
+ except ModelNotFoundError:
+ return False
+
async def register_model(
self,
model_id: str,
diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py
index d5d55319a..acc02eeff 100644
--- a/llama_stack/core/stack.py
+++ b/llama_stack/core/stack.py
@@ -274,22 +274,6 @@ def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
-def validate_env_pair(env_pair: str) -> tuple[str, str]:
- """Validate and split an environment variable key-value pair."""
- try:
- key, value = env_pair.split("=", 1)
- key = key.strip()
- if not key:
- raise ValueError(f"Empty key in environment variable pair: {env_pair}")
- if not all(c.isalnum() or c == "_" for c in key):
- raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
- return key, value
- except ValueError as e:
- raise ValueError(
- f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value"
- ) from e
-
-
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
"""Add internal implementations (inspect and providers) to the implementations dictionary.
diff --git a/llama_stack/core/start_stack.sh b/llama_stack/core/start_stack.sh
index 02b1cd408..cc0ae68d8 100755
--- a/llama_stack/core/start_stack.sh
+++ b/llama_stack/core/start_stack.sh
@@ -25,7 +25,7 @@ error_handler() {
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
- echo "Usage: $0 [--config ] [--env KEY=VALUE]..."
+ echo "Usage: $0 [--config ]"
exit 1
fi
@@ -43,7 +43,6 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
# Initialize variables
yaml_config=""
-env_vars=""
other_args=""
# Process remaining arguments
@@ -58,15 +57,6 @@ while [[ $# -gt 0 ]]; do
exit 1
fi
;;
- --env)
- if [[ -n "$2" ]]; then
- env_vars="$env_vars --env $2"
- shift 2
- else
- echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
- exit 1
- fi
- ;;
*)
other_args="$other_args $1"
shift
@@ -119,7 +109,6 @@ if [[ "$env_type" == "venv" ]]; then
llama stack run \
$yaml_config_arg \
--port "$port" \
- $env_vars \
$other_args
elif [[ "$env_type" == "container" ]]; then
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py
index 624dbd176..0486553d5 100644
--- a/llama_stack/core/store/registry.py
+++ b/llama_stack/core/store/registry.py
@@ -98,7 +98,10 @@ class DiskDistributionRegistry(DistributionRegistry):
existing_obj = await self.get(obj.type, obj.identifier)
# dont register if the object's providerid already exists
if existing_obj and existing_obj.provider_id == obj.provider_id:
- return False
+ raise ValueError(
+ f"Provider '{obj.provider_id}' is already registered."
+ f"Unregister the existing provider first before registering it again."
+ )
await self.kvstore.set(
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
diff --git a/llama_stack/distributions/dell/doc_template.md b/llama_stack/distributions/dell/doc_template.md
index fcec3ea14..852e78d0e 100644
--- a/llama_stack/distributions/dell/doc_template.md
+++ b/llama_stack/distributions/dell/doc_template.md
@@ -117,11 +117,11 @@ docker run -it \
# NOTE: mount the llama-stack directory if testing local changes else not needed
-v $HOME/git/llama-stack:/app/llama-stack-source \
# localhost/distribution-dell:dev if building / testing locally
+ -e INFERENCE_MODEL=$INFERENCE_MODEL \
+ -e DEH_URL=$DEH_URL \
+ -e CHROMA_URL=$CHROMA_URL \
llamastack/distribution-{{ name }}\
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env DEH_URL=$DEH_URL \
- --env CHROMA_URL=$CHROMA_URL
+ --port $LLAMA_STACK_PORT
```
@@ -142,14 +142,14 @@ docker run \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $HOME/.llama:/root/.llama \
-v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \
+ -e INFERENCE_MODEL=$INFERENCE_MODEL \
+ -e DEH_URL=$DEH_URL \
+ -e SAFETY_MODEL=$SAFETY_MODEL \
+ -e DEH_SAFETY_URL=$DEH_SAFETY_URL \
+ -e CHROMA_URL=$CHROMA_URL \
llamastack/distribution-{{ name }} \
--config /root/my-run.yaml \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env DEH_URL=$DEH_URL \
- --env SAFETY_MODEL=$SAFETY_MODEL \
- --env DEH_SAFETY_URL=$DEH_SAFETY_URL \
- --env CHROMA_URL=$CHROMA_URL
+ --port $LLAMA_STACK_PORT
```
### Via Conda
@@ -158,21 +158,21 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a
```bash
llama stack build --distro {{ name }} --image-type conda
-llama stack run {{ name }}
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env DEH_URL=$DEH_URL \
- --env CHROMA_URL=$CHROMA_URL
+INFERENCE_MODEL=$INFERENCE_MODEL \
+DEH_URL=$DEH_URL \
+CHROMA_URL=$CHROMA_URL \
+llama stack run {{ name }} \
+ --port $LLAMA_STACK_PORT
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
+INFERENCE_MODEL=$INFERENCE_MODEL \
+DEH_URL=$DEH_URL \
+SAFETY_MODEL=$SAFETY_MODEL \
+DEH_SAFETY_URL=$DEH_SAFETY_URL \
+CHROMA_URL=$CHROMA_URL \
llama stack run ./run-with-safety.yaml \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env DEH_URL=$DEH_URL \
- --env SAFETY_MODEL=$SAFETY_MODEL \
- --env DEH_SAFETY_URL=$DEH_SAFETY_URL \
- --env CHROMA_URL=$CHROMA_URL
+ --port $LLAMA_STACK_PORT
```
diff --git a/llama_stack/distributions/meta-reference-gpu/doc_template.md b/llama_stack/distributions/meta-reference-gpu/doc_template.md
index 602d053c4..92dcc6102 100644
--- a/llama_stack/distributions/meta-reference-gpu/doc_template.md
+++ b/llama_stack/distributions/meta-reference-gpu/doc_template.md
@@ -72,9 +72,9 @@ docker run \
--gpu all \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
+ -e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
llamastack/distribution-{{ name }} \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
+ --port $LLAMA_STACK_PORT
```
If you are using Llama Stack Safety / Shield APIs, use:
@@ -86,10 +86,10 @@ docker run \
--gpu all \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
+ -e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
+ -e SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
llamastack/distribution-{{ name }} \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
- --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
+ --port $LLAMA_STACK_PORT
```
### Via venv
@@ -98,16 +98,16 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
```bash
llama stack build --distro {{ name }} --image-type venv
+INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
llama stack run distributions/{{ name }}/run.yaml \
- --port 8321 \
- --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
+ --port 8321
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
+INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
+SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
llama stack run distributions/{{ name }}/run-with-safety.yaml \
- --port 8321 \
- --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
- --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
+ --port 8321
```
diff --git a/llama_stack/distributions/nvidia/doc_template.md b/llama_stack/distributions/nvidia/doc_template.md
index fbee17ef8..df2b68ef7 100644
--- a/llama_stack/distributions/nvidia/doc_template.md
+++ b/llama_stack/distributions/nvidia/doc_template.md
@@ -118,10 +118,10 @@ docker run \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
+ -e NVIDIA_API_KEY=$NVIDIA_API_KEY \
llamastack/distribution-{{ name }} \
--config /root/my-run.yaml \
- --port $LLAMA_STACK_PORT \
- --env NVIDIA_API_KEY=$NVIDIA_API_KEY
+ --port $LLAMA_STACK_PORT
```
### Via venv
@@ -131,10 +131,10 @@ If you've set up your local development environment, you can also build the imag
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
llama stack build --distro nvidia --image-type venv
+NVIDIA_API_KEY=$NVIDIA_API_KEY \
+INFERENCE_MODEL=$INFERENCE_MODEL \
llama stack run ./run.yaml \
- --port 8321 \
- --env NVIDIA_API_KEY=$NVIDIA_API_KEY \
- --env INFERENCE_MODEL=$INFERENCE_MODEL
+ --port 8321
```
## Example Notebooks
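With the `--env` flags removed, configuration reaches the server only through the process environment (docker `-e ...` or a `VAR=value` shell prefix). A minimal Python sketch of the same pattern, using a hypothetical `run_stack` helper around the documented CLI invocation:

```python
import os
import subprocess

# Hypothetical helper (not part of the repo): launch the server with settings
# passed purely through environment variables, mirroring the
# `INFERENCE_MODEL=... llama stack run <run.yaml> --port <port>` form above.
def run_stack(run_yaml: str, port: int = 8321, **env_vars: str) -> subprocess.Popen:
    env = {**os.environ, **env_vars}  # inherit the shell env, then apply overrides
    return subprocess.Popen(
        ["llama", "stack", "run", run_yaml, "--port", str(port)],
        env=env,
    )

# Model name and run config path are illustrative; adjust to your distribution.
proc = run_stack(
    "distributions/meta-reference-gpu/run.yaml",
    INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct",
)
```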
diff --git a/llama_stack/distributions/watsonx/__init__.py b/llama_stack/distributions/watsonx/__init__.py
index 756f351d8..078d86144 100644
--- a/llama_stack/distributions/watsonx/__init__.py
+++ b/llama_stack/distributions/watsonx/__init__.py
@@ -3,3 +3,5 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+
+from .watsonx import get_distribution_template # noqa: F401
diff --git a/llama_stack/distributions/watsonx/build.yaml b/llama_stack/distributions/watsonx/build.yaml
index bf4be7eaf..06349a741 100644
--- a/llama_stack/distributions/watsonx/build.yaml
+++ b/llama_stack/distributions/watsonx/build.yaml
@@ -3,44 +3,33 @@ distribution_spec:
description: Use watsonx for running LLM inference
providers:
inference:
- - provider_id: watsonx
- provider_type: remote::watsonx
- - provider_id: sentence-transformers
- provider_type: inline::sentence-transformers
+ - provider_type: remote::watsonx
+ - provider_type: inline::sentence-transformers
vector_io:
- - provider_id: faiss
- provider_type: inline::faiss
+ - provider_type: inline::faiss
safety:
- - provider_id: llama-guard
- provider_type: inline::llama-guard
+ - provider_type: inline::llama-guard
agents:
- - provider_id: meta-reference
- provider_type: inline::meta-reference
+ - provider_type: inline::meta-reference
telemetry:
- - provider_id: meta-reference
- provider_type: inline::meta-reference
+ - provider_type: inline::meta-reference
eval:
- - provider_id: meta-reference
- provider_type: inline::meta-reference
+ - provider_type: inline::meta-reference
datasetio:
- - provider_id: huggingface
- provider_type: remote::huggingface
- - provider_id: localfs
- provider_type: inline::localfs
+ - provider_type: remote::huggingface
+ - provider_type: inline::localfs
scoring:
- - provider_id: basic
- provider_type: inline::basic
- - provider_id: llm-as-judge
- provider_type: inline::llm-as-judge
- - provider_id: braintrust
- provider_type: inline::braintrust
+ - provider_type: inline::basic
+ - provider_type: inline::llm-as-judge
+ - provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
+ files:
+ - provider_type: inline::localfs
image_type: venv
additional_pip_packages:
+- aiosqlite
- sqlalchemy[asyncio]
-- aiosqlite
-- aiosqlite
diff --git a/llama_stack/distributions/watsonx/run.yaml b/llama_stack/distributions/watsonx/run.yaml
index 92f367910..e0c337f9d 100644
--- a/llama_stack/distributions/watsonx/run.yaml
+++ b/llama_stack/distributions/watsonx/run.yaml
@@ -4,13 +4,13 @@ apis:
- agents
- datasetio
- eval
+- files
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
-- files
providers:
inference:
- provider_id: watsonx
@@ -19,8 +19,6 @@ providers:
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
api_key: ${env.WATSONX_API_KEY:=}
project_id: ${env.WATSONX_PROJECT_ID:=}
- - provider_id: sentence-transformers
- provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
@@ -48,7 +46,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
- sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
+ sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:
@@ -109,102 +107,7 @@ metadata_store:
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
-models:
-- metadata: {}
- model_id: meta-llama/llama-3-3-70b-instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-3-70b-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/Llama-3.3-70B-Instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-3-70b-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/llama-2-13b-chat
- provider_id: watsonx
- provider_model_id: meta-llama/llama-2-13b-chat
- model_type: llm
-- metadata: {}
- model_id: meta-llama/Llama-2-13b
- provider_id: watsonx
- provider_model_id: meta-llama/llama-2-13b-chat
- model_type: llm
-- metadata: {}
- model_id: meta-llama/llama-3-1-70b-instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-1-70b-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/Llama-3.1-70B-Instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-1-70b-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/llama-3-1-8b-instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-1-8b-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/Llama-3.1-8B-Instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-1-8b-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/llama-3-2-11b-vision-instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/llama-3-2-1b-instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-2-1b-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/Llama-3.2-1B-Instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-2-1b-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/llama-3-2-3b-instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-2-3b-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/Llama-3.2-3B-Instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-2-3b-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/llama-3-2-90b-vision-instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
- provider_id: watsonx
- provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
- model_type: llm
-- metadata: {}
- model_id: meta-llama/llama-guard-3-11b-vision
- provider_id: watsonx
- provider_model_id: meta-llama/llama-guard-3-11b-vision
- model_type: llm
-- metadata: {}
- model_id: meta-llama/Llama-Guard-3-11B-Vision
- provider_id: watsonx
- provider_model_id: meta-llama/llama-guard-3-11b-vision
- model_type: llm
-- metadata:
- embedding_dimension: 384
- model_id: all-MiniLM-L6-v2
- provider_id: sentence-transformers
- model_type: embedding
+models: []
shields: []
vector_dbs: []
datasets: []
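The run config above relies on `${env.NAME:=default}` placeholders (e.g. `WATSONX_BASE_URL`, `TELEMETRY_SINKS`, `SQLITE_STORE_DIR`). An illustrative resolver, not the stack's actual implementation, showing the intended substitution semantics:

```python
import os
import re

# Matches ${env.NAME} and ${env.NAME:=default} placeholders as used in run.yaml.
_PLACEHOLDER = re.compile(r"\$\{env\.(?P<name>[A-Z0-9_]+)(?::=(?P<default>[^}]*))?\}")

def resolve_env_placeholders(value: str) -> str:
    def _sub(match: re.Match) -> str:
        return os.environ.get(match.group("name"), match.group("default") or "")
    return _PLACEHOLDER.sub(_sub, value)

print(resolve_env_placeholders("${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}"))
# -> the env var if set, otherwise the default URL
```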
diff --git a/llama_stack/distributions/watsonx/watsonx.py b/llama_stack/distributions/watsonx/watsonx.py
index c3cab5d1b..645770612 100644
--- a/llama_stack/distributions/watsonx/watsonx.py
+++ b/llama_stack/distributions/watsonx/watsonx.py
@@ -4,17 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from pathlib import Path
-from llama_stack.apis.models import ModelType
-from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
-from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
+from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
+from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
-from llama_stack.providers.inline.inference.sentence_transformers import (
- SentenceTransformersInferenceConfig,
-)
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
-from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
@@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
config=WatsonXConfig.sample_run_config(),
)
- embedding_provider = Provider(
- provider_id="sentence-transformers",
- provider_type="inline::sentence-transformers",
- config=SentenceTransformersInferenceConfig.sample_run_config(),
- )
-
- available_models = {
- "watsonx": MODEL_ENTRIES,
- }
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
@@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
),
]
- embedding_model = ModelInput(
- model_id="all-MiniLM-L6-v2",
- provider_id="sentence-transformers",
- model_type=ModelType.embedding,
- metadata={
- "embedding_dimension": 384,
- },
- )
-
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
- default_models, _ = get_model_registry(available_models)
return DistributionTemplate(
name=name,
distro_type="remote_hosted",
description="Use watsonx for running LLM inference",
container_image=None,
- template_path=Path(__file__).parent / "doc_template.md",
+ template_path=None,
providers=providers,
- available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
- "inference": [inference_provider, embedding_provider],
+ "inference": [inference_provider],
"files": [files_provider],
},
- default_models=default_models + [embedding_model],
+ default_models=[],
default_tool_groups=default_tool_groups,
),
},
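Because `default_models` is now empty, the watsonx distribution no longer pre-registers any models at startup. A sketch, assuming the llama-stack-client `models.register` endpoint, of registering one explicitly once the server is up (ids taken from the model list removed from run.yaml):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Register a watsonx-served model explicitly after startup.
client.models.register(
    model_id="meta-llama/llama-3-3-70b-instruct",
    provider_id="watsonx",
    provider_model_id="meta-llama/llama-3-3-70b-instruct",
    model_type="llm",
)
```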
diff --git a/llama_stack/log.py b/llama_stack/log.py
index 8aee4c9a9..ce92219f4 100644
--- a/llama_stack/log.py
+++ b/llama_stack/log.py
@@ -31,12 +31,17 @@ CATEGORIES = [
"client",
"telemetry",
"openai_responses",
+ "openai_conversations",
"testing",
"providers",
"models",
"files",
"vector_io",
"tool_runtime",
+ "cli",
+ "post_training",
+ "scoring",
+ "tests",
]
UNCATEGORIZED = "uncategorized"
@@ -264,11 +269,12 @@ def get_logger(
if root_category in _category_levels:
log_level = _category_levels[root_category]
else:
- log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
if category != UNCATEGORIZED:
- logging.warning(
- f"Unknown logging category: {category}. Falling back to default 'root' level: {log_level}"
+ raise ValueError(
+ f"Unknown logging category: {category}. To resolve, choose a valid category from the CATEGORIES list "
+ f"or add it to the CATEGORIES list. Available categories: {CATEGORIES}"
)
+ log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
logger.setLevel(log_level)
return logging.LoggerAdapter(logger, {"category": category})
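A short sketch of the new `get_logger` behavior after this change:

```python
from llama_stack.log import get_logger

# Known categories (including the newly added ones, e.g. "post_training")
# resolve to their configured level as before.
logger = get_logger(name=__name__, category="post_training")
logger.info("starting fine-tuning job")

# An unknown category now raises immediately instead of warning and silently
# falling back to the root level.
try:
    get_logger(name=__name__, category="not-a-category")
except ValueError as exc:
    print(exc)
```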
diff --git a/llama_stack/models/llama/prompt_format.py b/llama_stack/models/llama/prompt_format.py
index 6191df61a..16e4068d7 100644
--- a/llama_stack/models/llama/prompt_format.py
+++ b/llama_stack/models/llama/prompt_format.py
@@ -11,19 +11,13 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
-import json
import textwrap
-from pathlib import Path
from pydantic import BaseModel, Field
from llama_stack.models.llama.datatypes import (
RawContent,
- RawMediaItem,
RawMessage,
- RawTextItem,
- StopReason,
- ToolCall,
ToolPromptFormat,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
@@ -175,25 +169,6 @@ def llama3_1_builtin_code_interpreter_dialog(tool_prompt_format=ToolPromptFormat
return messages
-def llama3_1_builtin_tool_call_with_image_dialog(
- tool_prompt_format=ToolPromptFormat.json,
-):
- this_dir = Path(__file__).parent
- with open(this_dir / "llama3/dog.jpg", "rb") as f:
- img = f.read()
-
- interface = LLama31Interface(tool_prompt_format)
-
- messages = interface.system_messages(**system_message_builtin_tools_only())
- messages += interface.user_message(content=[RawMediaItem(data=img), RawTextItem(text="What is this dog breed?")])
- messages += interface.assistant_response_messages(
- "Based on the description of the dog in the image, it appears to be a small breed dog, possibly a terrier mix",
- StopReason.end_of_turn,
- )
- messages += interface.user_message("Search the web for some food recommendations for the indentified breed")
- return messages
-
-
def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
interface = LLama31Interface(tool_prompt_format)
@@ -202,35 +177,6 @@ def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
return messages
-def llama3_1_e2e_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
- tool_response = json.dumps(["great song1", "awesome song2", "cool song3"])
- interface = LLama31Interface(tool_prompt_format)
-
- messages = interface.system_messages(**system_message_custom_tools_only())
- messages += interface.user_message(content="Use tools to get latest trending songs")
- messages.append(
- RawMessage(
- role="assistant",
- content="",
- stop_reason=StopReason.end_of_message,
- tool_calls=[
- ToolCall(
- call_id="call_id",
- tool_name="trending_songs",
- arguments={"n": "10", "genre": "latest"},
- )
- ],
- ),
- )
- messages.append(
- RawMessage(
- role="assistant",
- content=tool_response,
- )
- )
- return messages
-
-
def llama3_2_user_assistant_conversation():
return UseCase(
title="User and assistant conversation",
diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py
index 334c32e15..37b0b50c8 100644
--- a/llama_stack/providers/inline/agents/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py
@@ -22,6 +22,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
deps[Api.tool_runtime],
deps[Api.tool_groups],
policy,
+ Api.telemetry in deps,
)
await impl.initialize()
return impl
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 207f0daec..b17c720e9 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -7,8 +7,6 @@
import copy
import json
import re
-import secrets
-import string
import uuid
import warnings
from collections.abc import AsyncGenerator
@@ -84,11 +82,6 @@ from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
from .safety import SafetyException, ShieldRunnerMixin
-
-def make_random_string(length: int = 8):
- return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
-
-
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
@@ -110,6 +103,7 @@ class ChatAgent(ShieldRunnerMixin):
persistence_store: KVStore,
created_at: str,
policy: list[AccessRule],
+ telemetry_enabled: bool = False,
):
self.agent_id = agent_id
self.agent_config = agent_config
@@ -120,6 +114,7 @@ class ChatAgent(ShieldRunnerMixin):
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.created_at = created_at
+ self.telemetry_enabled = telemetry_enabled
ShieldRunnerMixin.__init__(
self,
@@ -188,28 +183,30 @@ class ChatAgent(ShieldRunnerMixin):
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
turn_id = str(uuid.uuid4())
- span = tracing.get_current_span()
- if span:
- span.set_attribute("session_id", request.session_id)
- span.set_attribute("agent_id", self.agent_id)
- span.set_attribute("request", request.model_dump_json())
- span.set_attribute("turn_id", turn_id)
- if self.agent_config.name:
- span.set_attribute("agent_name", self.agent_config.name)
+ if self.telemetry_enabled:
+ span = tracing.get_current_span()
+ if span is not None:
+ span.set_attribute("session_id", request.session_id)
+ span.set_attribute("agent_id", self.agent_id)
+ span.set_attribute("request", request.model_dump_json())
+ span.set_attribute("turn_id", turn_id)
+ if self.agent_config.name:
+ span.set_attribute("agent_name", self.agent_config.name)
await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id):
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
- span = tracing.get_current_span()
- if span:
- span.set_attribute("agent_id", self.agent_id)
- span.set_attribute("session_id", request.session_id)
- span.set_attribute("request", request.model_dump_json())
- span.set_attribute("turn_id", request.turn_id)
- if self.agent_config.name:
- span.set_attribute("agent_name", self.agent_config.name)
+ if self.telemetry_enabled:
+ span = tracing.get_current_span()
+ if span is not None:
+ span.set_attribute("agent_id", self.agent_id)
+ span.set_attribute("session_id", request.session_id)
+ span.set_attribute("request", request.model_dump_json())
+ span.set_attribute("turn_id", request.turn_id)
+ if self.agent_config.name:
+ span.set_attribute("agent_name", self.agent_config.name)
await self._initialize_tools()
async for chunk in self._run_turn(request):
@@ -395,9 +392,12 @@ class ChatAgent(ShieldRunnerMixin):
touchpoint: str,
) -> AsyncGenerator:
async with tracing.span("run_shields") as span:
- span.set_attribute("input", [m.model_dump_json() for m in messages])
+ if self.telemetry_enabled and span is not None:
+ span.set_attribute("input", [m.model_dump_json() for m in messages])
+ if len(shields) == 0:
+ span.set_attribute("output", "no shields")
+
if len(shields) == 0:
- span.set_attribute("output", "no shields")
return
step_id = str(uuid.uuid4())
@@ -430,7 +430,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
- span.set_attribute("output", e.violation.model_dump_json())
+ if self.telemetry_enabled and span is not None:
+ span.set_attribute("output", e.violation.model_dump_json())
yield CompletionMessage(
content=str(e),
@@ -453,7 +454,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
- span.set_attribute("output", "no violations")
+ if self.telemetry_enabled and span is not None:
+ span.set_attribute("output", "no violations")
async def _run(
self,
@@ -518,8 +520,9 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason: StopReason | None = None
async with tracing.span("inference") as span:
- if self.agent_config.name:
- span.set_attribute("agent_name", self.agent_config.name)
+ if self.telemetry_enabled and span is not None:
+ if self.agent_config.name:
+ span.set_attribute("agent_name", self.agent_config.name)
def _serialize_nested(value):
"""Recursively serialize nested Pydantic models to dicts."""
@@ -637,18 +640,19 @@ class ChatAgent(ShieldRunnerMixin):
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
- span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
- span.set_attribute(
- "input",
- json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
- )
- output_attr = json.dumps(
- {
- "content": content,
- "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
- }
- )
- span.set_attribute("output", output_attr)
+ if self.telemetry_enabled and span is not None:
+ span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
+ span.set_attribute(
+ "input",
+ json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
+ )
+ output_attr = json.dumps(
+ {
+ "content": content,
+ "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
+ }
+ )
+ span.set_attribute("output", output_attr)
n_iter += 1
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
@@ -756,7 +760,9 @@ class ChatAgent(ShieldRunnerMixin):
{
"tool_name": tool_call.tool_name,
"input": message.model_dump_json(),
- },
+ }
+ if self.telemetry_enabled
+ else {},
) as span:
tool_execution_start_time = datetime.now(UTC).isoformat()
tool_result = await self.execute_tool_call_maybe(
@@ -771,7 +777,8 @@ class ChatAgent(ShieldRunnerMixin):
call_id=tool_call.call_id,
content=tool_result.content,
)
- span.set_attribute("output", result_message.model_dump_json())
+ if self.telemetry_enabled and span is not None:
+ span.set_attribute("output", result_message.model_dump_json())
# Store tool execution step
tool_execution_step = ToolExecutionStep(
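The edits above repeat one guard: span attributes are only recorded when telemetry was resolved as an (optional) dependency and a current span exists. A condensed sketch of that pattern; the helper name and `attrs` parameter are illustrative:

```python
from llama_stack.providers.utils.telemetry import tracing

def set_span_attributes(telemetry_enabled: bool, attrs: dict[str, str]) -> None:
    # Skip all tracing work when the telemetry API was not wired in.
    if not telemetry_enabled:
        return
    span = tracing.get_current_span()
    if span is None:
        return
    for key, value in attrs.items():
        span.set_attribute(key, value)
```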
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index 5431e8f28..cfaf56a34 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -64,6 +64,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
policy: list[AccessRule],
+ telemetry_enabled: bool = False,
):
self.config = config
self.inference_api = inference_api
@@ -71,6 +72,7 @@ class MetaReferenceAgentsImpl(Agents):
self.safety_api = safety_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
+ self.telemetry_enabled = telemetry_enabled
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None
@@ -135,6 +137,7 @@ class MetaReferenceAgentsImpl(Agents):
),
created_at=agent_info.created_at,
policy=self.policy,
+ telemetry_enabled=self.telemetry_enabled,
)
async def create_agent_session(
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
index 8ccdcb0e1..245203f10 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
@@ -269,7 +269,7 @@ class OpenAIResponsesImpl:
response_tools=tools,
temperature=temperature,
response_format=response_format,
- inputs=input,
+ inputs=all_input,
)
# Create orchestrator and delegate streaming logic
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index 0bb524f5c..895d13a7f 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -97,6 +97,8 @@ class StreamingResponseOrchestrator:
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
+        # Maps file_id -> filename, used to attach citation annotations to output text
+ self.citation_files: dict[str, str] = {}
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Initialize output messages
@@ -126,6 +128,7 @@ class StreamingResponseOrchestrator:
# Text is the default response format for chat completion so don't need to pass it
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
+ logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
completion_result = await self.inference_api.openai_chat_completion(
model=self.ctx.model,
messages=messages,
@@ -160,7 +163,7 @@ class StreamingResponseOrchestrator:
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
- output_messages.append(await convert_chat_choice_to_response_message(choice))
+ output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files))
# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
@@ -172,6 +175,8 @@ class StreamingResponseOrchestrator:
):
yield stream_event
+ messages = next_turn_messages
+
if not function_tool_calls and not non_function_tool_calls:
break
@@ -184,9 +189,7 @@ class StreamingResponseOrchestrator:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
break
- messages = next_turn_messages
-
- self.final_messages = messages.copy() + [current_response.choices[0].message]
+ self.final_messages = messages.copy()
# Create final response
final_response = OpenAIResponseObject(
@@ -211,6 +214,8 @@ class StreamingResponseOrchestrator:
for choice in current_response.choices:
next_turn_messages.append(choice.message)
+ logger.debug(f"Choice message content: {choice.message.content}")
+ logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")
if choice.message.tool_calls and self.ctx.response_tools:
for tool_call in choice.message.tool_calls:
@@ -227,9 +232,11 @@ class StreamingResponseOrchestrator:
non_function_tool_calls.append(tool_call)
else:
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
+ next_turn_messages.pop()
else:
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
approvals.append(tool_call)
+ next_turn_messages.pop()
else:
non_function_tool_calls.append(tool_call)
@@ -470,6 +477,8 @@ class StreamingResponseOrchestrator:
tool_call_log = result.final_output_message
tool_response_message = result.final_input_message
self.sequence_number = result.sequence_number
+ if result.citation_files:
+ self.citation_files.update(result.citation_files)
if tool_call_log:
output_messages.append(tool_call_log)
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
index b028c018b..b33b47454 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
@@ -94,7 +94,10 @@ class ToolExecutor:
# Yield the final result
yield ToolExecutionResult(
- sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
+ sequence_number=sequence_number,
+ final_output_message=output_message,
+ final_input_message=input_message,
+ citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
)
async def _execute_knowledge_search_via_vector_store(
@@ -129,8 +132,6 @@ class ToolExecutor:
for results in all_results:
search_results.extend(results)
- # Convert search results to tool result format matching memory.py
- # Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
@@ -138,27 +139,58 @@ class ToolExecutor:
)
)
+ unique_files = set()
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
- metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
+ # Get file_id from attributes if result_item.file_id is empty
+ file_id = result_item.file_id or (
+ result_item.attributes.get("document_id") if result_item.attributes else None
+ )
+ metadata_text = f"document_id: {file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
- text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
+
+ text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
+ unique_files.add(file_id)
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
+
+ citation_instruction = ""
+ if unique_files:
+ citation_instruction = (
+ " Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
+ "Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
+ )
+
content_items.append(
TextContentItem(
- text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
+ text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
)
)
+        # Handle results from older servers where file_id/filename are only present in attributes
+ citation_files = {}
+ for result in search_results:
+ file_id = result.file_id
+ if not file_id and result.attributes:
+ file_id = result.attributes.get("document_id")
+
+ filename = result.filename
+ if not filename and result.attributes:
+ filename = result.attributes.get("filename")
+ if not filename:
+ filename = "unknown"
+
+ citation_files[file_id] = filename
+
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
+ "citation_files": citation_files,
},
)
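A small, self-contained illustration of the fallback above for results produced by older servers, where `file_id`/`filename` only exist inside `attributes`; `SearchResult` is a stand-in for the real search result type:

```python
from dataclasses import dataclass

@dataclass
class SearchResult:
    file_id: str | None = None
    filename: str | None = None
    attributes: dict | None = None

def citation_files_for(results: list[SearchResult]) -> dict[str, str]:
    citation_files: dict[str, str] = {}
    for result in results:
        # Prefer the top-level fields, then fall back to attributes, then "unknown".
        file_id = result.file_id or (result.attributes or {}).get("document_id")
        filename = result.filename or (result.attributes or {}).get("filename") or "unknown"
        citation_files[file_id] = filename
    return citation_files

print(citation_files_for([SearchResult(attributes={"document_id": "file-abc123", "filename": "report.pdf"})]))
# -> {'file-abc123': 'report.pdf'}
```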
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/types.py b/llama_stack/providers/inline/agents/meta_reference/responses/types.py
index d3b5a16bd..fd5f44242 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/types.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/types.py
@@ -27,6 +27,7 @@ class ToolExecutionResult(BaseModel):
sequence_number: int
final_output_message: OpenAIResponseOutput | None = None
final_input_message: OpenAIMessageParam | None = None
+ citation_files: dict[str, str] | None = None
@dataclass
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index 310a88298..5b013b9c4 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+import re
import uuid
from llama_stack.apis.agents.openai_responses import (
+ OpenAIResponseAnnotationFileCitation,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
@@ -45,7 +47,9 @@ from llama_stack.apis.inference import (
)
-async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
+async def convert_chat_choice_to_response_message(
+ choice: OpenAIChoice, citation_files: dict[str, str] | None = None
+) -> OpenAIResponseMessage:
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
output_content = ""
if isinstance(choice.message.content, str):
@@ -57,9 +61,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
+ annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
+
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
- content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
+ content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
status="completed",
role="assistant",
)
@@ -200,6 +206,53 @@ async def get_message_type_by_role(role: str):
return role_to_type.get(role)
+def _extract_citations_from_text(
+ text: str, citation_files: dict[str, str]
+) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
+ """Extract citation markers from text and create annotations
+
+ Args:
+        text: The text containing citation markers like <|file-Cn3MSNn72ENTiiq11Qda4A|>
+ citation_files: Dictionary mapping file_id to filename
+
+ Returns:
+ Tuple of (annotations_list, clean_text_without_markers)
+ """
+    file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
+
+ annotations = []
+ parts = []
+ total_len = 0
+ last_end = 0
+
+ for m in file_id_regex.finditer(text):
+ # segment before the marker
+ prefix = text[last_end : m.start()]
+
+ # drop one space if it exists (since marker is at sentence end)
+ if prefix.endswith(" "):
+ prefix = prefix[:-1]
+
+ parts.append(prefix)
+ total_len += len(prefix)
+
+ fid = m.group(1)
+ if fid in citation_files:
+ annotations.append(
+ OpenAIResponseAnnotationFileCitation(
+ file_id=fid,
+ filename=citation_files[fid],
+ index=total_len, # index points to punctuation
+ )
+ )
+
+ last_end = m.end()
+
+ parts.append(text[last_end:])
+ cleaned_text = "".join(parts)
+ return annotations, cleaned_text
+
+
def is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
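A quick usage sketch of `_extract_citations_from_text`, using the marker format the knowledge_search instruction asks the model to emit (the filename is illustrative):

```python
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
    _extract_citations_from_text,
)

citation_files = {"file-Cn3MSNn72ENTiiq11Qda4A": "paris_wiki.txt"}
text = "The capital of France is Paris <|file-Cn3MSNn72ENTiiq11Qda4A|>."

annotations, clean_text = _extract_citations_from_text(text, citation_files)
# clean_text  == "The capital of France is Paris."
# annotations == [OpenAIResponseAnnotationFileCitation(file_id="file-Cn3MSNn72ENTiiq11Qda4A",
#                 filename="paris_wiki.txt", index=30)]  # index lands on the trailing "."
```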
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index c8499a9b8..3ccfd0bcb 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -8,8 +8,6 @@ import asyncio
import base64
import io
import mimetypes
-import secrets
-import string
from typing import Any
import httpx
@@ -52,10 +50,6 @@ from .context_retriever import generate_rag_query
log = get_logger(name=__name__, category="tool_runtime")
-def make_random_string(length: int = 8):
- return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
-
-
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
"""Get raw binary data and mime type from a RAGDocument for file upload."""
if isinstance(doc.content, URL):
@@ -331,5 +325,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
return ToolInvocationResult(
content=result.content or [],
- metadata=result.metadata,
+ metadata={
+ **(result.metadata or {}),
+ "citation_files": getattr(result, "citation_files", None),
+ },
)
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index 405c134e5..5a456c7c9 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -225,8 +225,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
- # Cleanup if needed
- pass
+ # Clean up mixin resources (file batch tasks)
+ await super().shutdown()
async def health(self) -> HealthResponse:
"""
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index 26231a9b7..a433257b2 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -434,8 +434,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
- # nothing to do since we don't maintain a persistent connection
- pass
+ # Clean up mixin resources (file batch tasks)
+ await super().shutdown()
async def list_vector_dbs(self) -> list[VectorDB]:
return [v.vector_db for v in self.cache.values()]
diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py
index 57110d129..bc46b4de2 100644
--- a/llama_stack/providers/registry/agents.py
+++ b/llama_stack/providers/registry/agents.py
@@ -36,6 +36,9 @@ def available_providers() -> list[ProviderSpec]:
Api.tool_runtime,
Api.tool_groups,
],
+ optional_api_dependencies=[
+ Api.telemetry,
+ ],
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
),
]
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index bf6a09b6c..f89565892 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -268,7 +268,7 @@ Available Models:
api=Api.inference,
adapter_type="watsonx",
provider_type="remote::watsonx",
- pip_packages=["ibm_watsonx_ai"],
+ pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py
index ad8c31dfd..39dc7fccd 100644
--- a/llama_stack/providers/registry/tool_runtime.py
+++ b/llama_stack/providers/registry/tool_runtime.py
@@ -11,6 +11,7 @@ from llama_stack.providers.datatypes import (
ProviderSpec,
RemoteProviderSpec,
)
+from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS
def available_providers() -> list[ProviderSpec]:
@@ -18,9 +19,8 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
- pip_packages=[
- "chardet",
- "pypdf",
+ pip_packages=DEFAULT_VECTOR_IO_DEPS
+ + [
"tqdm",
"numpy",
"scikit-learn",
diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py
index ebab7aaf9..da2a68535 100644
--- a/llama_stack/providers/registry/vector_io.py
+++ b/llama_stack/providers/registry/vector_io.py
@@ -12,13 +12,16 @@ from llama_stack.providers.datatypes import (
RemoteProviderSpec,
)
+# Common dependencies for all vector IO providers that support document processing
+DEFAULT_VECTOR_IO_DEPS = ["chardet", "pypdf"]
+
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::meta-reference",
- pip_packages=["faiss-cpu"],
+ pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.",
@@ -29,7 +32,7 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::faiss",
- pip_packages=["faiss-cpu"],
+ pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
api_dependencies=[Api.inference],
@@ -82,7 +85,7 @@ more details about Faiss in general.
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite-vec",
- pip_packages=["sqlite-vec"],
+ pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
api_dependencies=[Api.inference],
@@ -289,7 +292,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite_vec",
- pip_packages=["sqlite-vec"],
+ pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
@@ -303,7 +306,7 @@ Please refer to the sqlite-vec provider documentation.
api=Api.vector_io,
adapter_type="chromadb",
provider_type="remote::chromadb",
- pip_packages=["chromadb-client"],
+ pip_packages=["chromadb-client"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
@@ -345,7 +348,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::chromadb",
- pip_packages=["chromadb"],
+ pip_packages=["chromadb"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.chroma",
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
@@ -389,7 +392,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
api=Api.vector_io,
adapter_type="pgvector",
provider_type="remote::pgvector",
- pip_packages=["psycopg2-binary"],
+ pip_packages=["psycopg2-binary"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
api_dependencies=[Api.inference],
@@ -500,7 +503,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
api=Api.vector_io,
adapter_type="weaviate",
provider_type="remote::weaviate",
- pip_packages=["weaviate-client>=4.16.5"],
+ pip_packages=["weaviate-client>=4.16.5"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
@@ -541,7 +544,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::qdrant",
- pip_packages=["qdrant-client"],
+ pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.qdrant",
config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
@@ -594,7 +597,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
api=Api.vector_io,
adapter_type="qdrant",
provider_type="remote::qdrant",
- pip_packages=["qdrant-client"],
+ pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
@@ -607,7 +610,7 @@ Please refer to the inline provider documentation.
api=Api.vector_io,
adapter_type="milvus",
provider_type="remote::milvus",
- pip_packages=["pymilvus>=2.4.10"],
+ pip_packages=["pymilvus>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
@@ -813,7 +816,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::milvus",
- pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
+ pip_packages=["pymilvus[milvus-lite]>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
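With `DEFAULT_VECTOR_IO_DEPS` exported from the registry, other vector_io specs can reuse the shared document-processing deps. A sketch; the provider_type, module, and config_class values are illustrative:

```python
from llama_stack.providers.datatypes import Api, InlineProviderSpec
from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS

spec = InlineProviderSpec(
    api=Api.vector_io,
    provider_type="inline::my-store",  # illustrative provider type
    pip_packages=["my-vector-lib"] + DEFAULT_VECTOR_IO_DEPS,  # adds chardet, pypdf
    module="my_pkg.vector_io",  # illustrative module path
    config_class="my_pkg.vector_io.MyStoreConfig",
    api_dependencies=[Api.inference],
)
```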
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index f4ad1be94..200b36171 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -41,9 +41,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
).serving_endpoints.list() # TODO: this is not async
]
- async def should_refresh_models(self) -> bool:
- return False
-
async def openai_completion(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
deleted file mode 100644
index 0b0d7fcf3..000000000
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ /dev/null
@@ -1,217 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import warnings
-from collections.abc import AsyncGenerator
-from typing import Any
-
-from openai import AsyncStream
-from openai.types.chat.chat_completion import (
- Choice as OpenAIChoice,
-)
-from openai.types.completion import Completion as OpenAICompletion
-from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
-
-from llama_stack.apis.inference import (
- ChatCompletionRequest,
- CompletionRequest,
- CompletionResponse,
- CompletionResponseStreamChunk,
- GreedySamplingStrategy,
- JsonSchemaResponseFormat,
- TokenLogProbs,
- TopKSamplingStrategy,
- TopPSamplingStrategy,
-)
-from llama_stack.providers.utils.inference.openai_compat import (
- _convert_openai_finish_reason,
- convert_message_to_openai_dict_new,
- convert_tooldef_to_openai_tool,
-)
-
-
-async def convert_chat_completion_request(
- request: ChatCompletionRequest,
- n: int = 1,
-) -> dict:
- """
- Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
- """
- # model -> model
- # messages -> messages
- # sampling_params TODO(mattf): review strategy
- # strategy=greedy -> nvext.top_k = -1, temperature = temperature
- # strategy=top_p -> nvext.top_k = -1, top_p = top_p
- # strategy=top_k -> nvext.top_k = top_k
- # temperature -> temperature
- # top_p -> top_p
- # top_k -> nvext.top_k
- # max_tokens -> max_tokens
- # repetition_penalty -> nvext.repetition_penalty
- # response_format -> GrammarResponseFormat TODO(mf)
- # response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
- # tools -> tools
- # tool_choice ("auto", "required") -> tool_choice
- # tool_prompt_format -> TBD
- # stream -> stream
- # logprobs -> logprobs
-
- if request.response_format and not isinstance(request.response_format, JsonSchemaResponseFormat):
- raise ValueError(
- f"Unsupported response format: {request.response_format}. Only JsonSchemaResponseFormat is supported."
- )
-
- nvext = {}
- payload: dict[str, Any] = dict(
- model=request.model,
- messages=[await convert_message_to_openai_dict_new(message) for message in request.messages],
- stream=request.stream,
- n=n,
- extra_body=dict(nvext=nvext),
- extra_headers={
- b"User-Agent": b"llama-stack: nvidia-inference-adapter",
- },
- )
-
- if request.response_format:
- # server bug - setting guided_json changes the behavior of response_format resulting in an error
- # payload.update(response_format="json_object")
- nvext.update(guided_json=request.response_format.json_schema)
-
- if request.tools:
- payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools])
- if request.tool_config.tool_choice:
- payload.update(
- tool_choice=request.tool_config.tool_choice.value
- ) # we cannot include tool_choice w/o tools, server will complain
-
- if request.logprobs:
- payload.update(logprobs=True)
- payload.update(top_logprobs=request.logprobs.top_k)
-
- if request.sampling_params:
- nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
-
- if request.sampling_params.max_tokens:
- payload.update(max_tokens=request.sampling_params.max_tokens)
-
- strategy = request.sampling_params.strategy
- if isinstance(strategy, TopPSamplingStrategy):
- nvext.update(top_k=-1)
- payload.update(top_p=strategy.top_p)
- payload.update(temperature=strategy.temperature)
- elif isinstance(strategy, TopKSamplingStrategy):
- if strategy.top_k != -1 and strategy.top_k < 1:
- warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
- nvext.update(top_k=strategy.top_k)
- elif isinstance(strategy, GreedySamplingStrategy):
- nvext.update(top_k=-1)
- else:
- raise ValueError(f"Unsupported sampling strategy: {strategy}")
-
- return payload
-
-
-def convert_completion_request(
- request: CompletionRequest,
- n: int = 1,
-) -> dict:
- """
- Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
- """
- # model -> model
- # prompt -> prompt
- # sampling_params TODO(mattf): review strategy
- # strategy=greedy -> nvext.top_k = -1, temperature = temperature
- # strategy=top_p -> nvext.top_k = -1, top_p = top_p
- # strategy=top_k -> nvext.top_k = top_k
- # temperature -> temperature
- # top_p -> top_p
- # top_k -> nvext.top_k
- # max_tokens -> max_tokens
- # repetition_penalty -> nvext.repetition_penalty
- # response_format -> nvext.guided_json
- # stream -> stream
- # logprobs.top_k -> logprobs
-
- nvext = {}
- payload: dict[str, Any] = dict(
- model=request.model,
- prompt=request.content,
- stream=request.stream,
- extra_body=dict(nvext=nvext),
- extra_headers={
- b"User-Agent": b"llama-stack: nvidia-inference-adapter",
- },
- n=n,
- )
-
- if request.response_format:
- # this is not openai compliant, it is a nim extension
- nvext.update(guided_json=request.response_format.json_schema)
-
- if request.logprobs:
- payload.update(logprobs=request.logprobs.top_k)
-
- if request.sampling_params:
- nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
-
- if request.sampling_params.max_tokens:
- payload.update(max_tokens=request.sampling_params.max_tokens)
-
- if request.sampling_params.strategy == "top_p":
- nvext.update(top_k=-1)
- payload.update(top_p=request.sampling_params.top_p)
- elif request.sampling_params.strategy == "top_k":
- if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
- warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
- nvext.update(top_k=request.sampling_params.top_k)
- elif request.sampling_params.strategy == "greedy":
- nvext.update(top_k=-1)
- payload.update(temperature=request.sampling_params.temperature)
-
- return payload
-
-
-def _convert_openai_completion_logprobs(
- logprobs: OpenAICompletionLogprobs | None,
-) -> list[TokenLogProbs] | None:
- """
- Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
- """
- if not logprobs:
- return None
-
- return [TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs]
-
-
-def convert_openai_completion_choice(
- choice: OpenAIChoice,
-) -> CompletionResponse:
- """
- Convert an OpenAI Completion Choice into a CompletionResponse.
- """
- return CompletionResponse(
- content=choice.text,
- stop_reason=_convert_openai_finish_reason(choice.finish_reason),
- logprobs=_convert_openai_completion_logprobs(choice.logprobs),
- )
-
-
-async def convert_openai_completion_stream(
- stream: AsyncStream[OpenAICompletion],
-) -> AsyncGenerator[CompletionResponse, None]:
- """
- Convert a stream of OpenAI Completions into a stream
- of ChatCompletionResponseStreamChunks.
- """
- async for chunk in stream:
- choice = chunk.choices[0]
- yield CompletionResponseStreamChunk(
- delta=choice.text,
- stop_reason=_convert_openai_finish_reason(choice.finish_reason),
- logprobs=_convert_openai_completion_logprobs(choice.logprobs),
- )
diff --git a/llama_stack/providers/remote/inference/nvidia/utils.py b/llama_stack/providers/remote/inference/nvidia/utils.py
index b8431e859..46ee939d9 100644
--- a/llama_stack/providers/remote/inference/nvidia/utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/utils.py
@@ -4,53 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-import httpx
-
-from llama_stack.log import get_logger
-
from . import NVIDIAConfig
-logger = get_logger(name=__name__, category="inference::nvidia")
-
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
return "integrate.api.nvidia.com" in config.url
-
-
-async def _get_health(url: str) -> tuple[bool, bool]:
- """
- Query {url}/v1/health/{live,ready} to check if the server is running and ready
-
- Args:
- url (str): URL of the server
-
- Returns:
- Tuple[bool, bool]: (is_live, is_ready)
- """
- async with httpx.AsyncClient() as client:
- live = await client.get(f"{url}/v1/health/live")
- ready = await client.get(f"{url}/v1/health/ready")
- return live.status_code == 200, ready.status_code == 200
-
-
-async def check_health(config: NVIDIAConfig) -> None:
- """
- Check if the server is running and ready
-
- Args:
- url (str): URL of the server
-
- Raises:
- RuntimeError: If the server is not running or ready
- """
- if not _is_nvidia_hosted(config):
- logger.info("Checking NVIDIA NIM health...")
- try:
- is_live, is_ready = await _get_health(config.url)
- if not is_live:
- raise ConnectionError("NVIDIA NIM is not running")
- if not is_ready:
- raise ConnectionError("NVIDIA NIM is not ready")
- # TODO(mf): should we wait for the server to be ready?
- except httpx.ConnectError as e:
- raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py
index d2f104e1e..1e4ce9113 100644
--- a/llama_stack/providers/remote/inference/ollama/config.py
+++ b/llama_stack/providers/remote/inference/ollama/config.py
@@ -6,8 +6,6 @@
from typing import Any
-from pydantic import Field
-
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
DEFAULT_OLLAMA_URL = "http://localhost:11434"
@@ -15,10 +13,6 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(RemoteInferenceProviderConfig):
url: str = DEFAULT_OLLAMA_URL
- refresh_models: bool = Field(
- default=False,
- description="Whether to refresh models periodically",
- )
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index e5b08997c..67d0caa54 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -72,9 +72,6 @@ class OllamaInferenceAdapter(OpenAIMixin):
f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
)
- async def should_refresh_models(self) -> bool:
- return self.config.refresh_models
-
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the Ollama server.
diff --git a/llama_stack/providers/remote/inference/runpod/__init__.py b/llama_stack/providers/remote/inference/runpod/__init__.py
index 69bf95046..d1fd2b718 100644
--- a/llama_stack/providers/remote/inference/runpod/__init__.py
+++ b/llama_stack/providers/remote/inference/runpod/__init__.py
@@ -11,6 +11,6 @@ async def get_adapter_impl(config: RunpodImplConfig, _deps):
from .runpod import RunpodInferenceAdapter
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
- impl = RunpodInferenceAdapter(config)
+ impl = RunpodInferenceAdapter(config=config)
await impl.initialize()
return impl
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index 08652f8c0..f752740e5 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -4,69 +4,86 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+from typing import Any
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.apis.inference import OpenAIEmbeddingsResponse
-
-# from llama_stack.providers.datatypes import ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
-from llama_stack.providers.utils.inference.openai_compat import (
- get_sampling_options,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
- chat_completion_request_to_prompt,
+from llama_stack.apis.inference import (
+ OpenAIMessageParam,
+ OpenAIResponseFormatParam,
)
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import RunpodImplConfig
-# https://docs.runpod.io/serverless/vllm/overview#compatible-models
-# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
-RUNPOD_SUPPORTED_MODELS = {
- "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
- "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
- "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
- "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
- "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
- "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
- "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
- "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
- "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
- "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
- "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
- "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
-}
-SAFETY_MODELS_ENTRIES = []
+class RunpodInferenceAdapter(OpenAIMixin):
+ """
+ Adapter for RunPod's OpenAI-compatible API endpoints.
+ Works with vLLM-backed serverless endpoints, whether self-hosted or public,
+ and with any other RunPod endpoint that exposes an OpenAI-compatible API.
+ """
-# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
-MODEL_ENTRIES = [
- build_hf_repo_model_entry(provider_model_id, model_descriptor)
- for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
-] + SAFETY_MODELS_ENTRIES
+ config: RunpodImplConfig
+ def get_api_key(self) -> str:
+ """Get API key for OpenAI client."""
+ return self.config.api_token
-class RunpodInferenceAdapter(
- ModelRegistryHelper,
- Inference,
-):
- def __init__(self, config: RunpodImplConfig) -> None:
- ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
- self.config = config
+ def get_base_url(self) -> str:
+ """Get base URL for OpenAI client."""
+ return self.config.url
- def _get_params(self, request: ChatCompletionRequest) -> dict:
- return {
- "model": self.map_to_provider_model(request.model),
- "prompt": chat_completion_request_to_prompt(request),
- "stream": request.stream,
- **get_sampling_options(request.sampling_params),
- }
-
- async def openai_embeddings(
+ async def openai_chat_completion(
self,
model: str,
- input: str | list[str],
- encoding_format: str | None = "float",
- dimensions: int | None = None,
+ messages: list[OpenAIMessageParam],
+ frequency_penalty: float | None = None,
+ function_call: str | dict[str, Any] | None = None,
+ functions: list[dict[str, Any]] | None = None,
+ logit_bias: dict[str, float] | None = None,
+ logprobs: bool | None = None,
+ max_completion_tokens: int | None = None,
+ max_tokens: int | None = None,
+ n: int | None = None,
+ parallel_tool_calls: bool | None = None,
+ presence_penalty: float | None = None,
+ response_format: OpenAIResponseFormatParam | None = None,
+ seed: int | None = None,
+ stop: str | list[str] | None = None,
+ stream: bool | None = None,
+ stream_options: dict[str, Any] | None = None,
+ temperature: float | None = None,
+ tool_choice: str | dict[str, Any] | None = None,
+ tools: list[dict[str, Any]] | None = None,
+ top_logprobs: int | None = None,
+ top_p: float | None = None,
user: str | None = None,
- ) -> OpenAIEmbeddingsResponse:
- raise NotImplementedError()
+ ):
+ """Override to add RunPod-specific stream_options requirement."""
+ if stream and not stream_options:
+ stream_options = {"include_usage": True}
+
+ return await super().openai_chat_completion(
+ model=model,
+ messages=messages,
+ frequency_penalty=frequency_penalty,
+ function_call=function_call,
+ functions=functions,
+ logit_bias=logit_bias,
+ logprobs=logprobs,
+ max_completion_tokens=max_completion_tokens,
+ max_tokens=max_tokens,
+ n=n,
+ parallel_tool_calls=parallel_tool_calls,
+ presence_penalty=presence_penalty,
+ response_format=response_format,
+ seed=seed,
+ stop=stop,
+ stream=stream,
+ stream_options=stream_options,
+ temperature=temperature,
+ tool_choice=tool_choice,
+ tools=tools,
+ top_logprobs=top_logprobs,
+ top_p=top_p,
+ user=user,
+ )
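The only behavioral override in the new RunPod adapter is defaulting `stream_options` when streaming is requested. A minimal sketch of that defaulting pattern, independent of any live RunPod endpoint (the helper name here is illustrative, not part of the adapter):

```python
from typing import Any

def default_stream_options(
    stream: bool | None,
    stream_options: dict[str, Any] | None,
) -> dict[str, Any] | None:
    """Mirror the override's logic: request usage chunks when streaming
    and the caller did not supply their own stream_options."""
    if stream and not stream_options:
        return {"include_usage": True}
    return stream_options

# Usage sketch
assert default_stream_options(True, None) == {"include_usage": True}
assert default_stream_options(True, {"include_usage": False}) == {"include_usage": False}
assert default_stream_options(False, None) is None
```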
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index fbefe630f..224de6721 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -63,9 +63,6 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
return [m.id for m in await self._get_client().models.list()]
- async def should_refresh_models(self) -> bool:
- return True
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py
index 86ef3fe26..87c5408d3 100644
--- a/llama_stack/providers/remote/inference/vllm/config.py
+++ b/llama_stack/providers/remote/inference/vllm/config.py
@@ -30,10 +30,6 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
default=True,
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
)
- refresh_models: bool = Field(
- default=False,
- description="Whether to refresh models periodically",
- )
@field_validator("tls_verify")
@classmethod
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 4e7884cd2..310eaf7b6 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -53,10 +53,6 @@ class VLLMInferenceAdapter(OpenAIMixin):
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
)
- async def should_refresh_models(self) -> bool:
- # Strictly respecting the refresh_models directive
- return self.config.refresh_models
-
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the remote vLLM server.
diff --git a/llama_stack/providers/remote/inference/watsonx/__init__.py b/llama_stack/providers/remote/inference/watsonx/__init__.py
index e59e873b6..35e74a720 100644
--- a/llama_stack/providers/remote/inference/watsonx/__init__.py
+++ b/llama_stack/providers/remote/inference/watsonx/__init__.py
@@ -4,19 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from llama_stack.apis.inference import Inference
-
from .config import WatsonXConfig
-async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
- # import dynamically so `llama stack build` does not fail due to missing dependencies
+async def get_adapter_impl(config: WatsonXConfig, _deps):
+ # import the adapter lazily so its dependencies are only loaded when needed
from .watsonx import WatsonXInferenceAdapter
- if not isinstance(config, WatsonXConfig):
- raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = WatsonXInferenceAdapter(config)
return adapter
-
-
-__all__ = ["get_adapter_impl", "WatsonXConfig"]
diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py
index 4bc0173c4..9e98d4003 100644
--- a/llama_stack/providers/remote/inference/watsonx/config.py
+++ b/llama_stack/providers/remote/inference/watsonx/config.py
@@ -7,16 +7,18 @@
import os
from typing import Any
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, ConfigDict, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class WatsonXProviderDataValidator(BaseModel):
- url: str
- api_key: str
- project_id: str
+ model_config = ConfigDict(
+ from_attributes=True,
+ extra="forbid",
+ )
+ watsonx_api_key: str | None
@json_schema_type
@@ -25,13 +27,17 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",
)
+ # An API key seems like it should be required, but none of the other remote inference
+ # providers require one, so it is optional here too for consistency.
+ # Like OpenAIConfig, this uses default=None rather than reading an environment variable.
api_key: SecretStr | None = Field(
- default_factory=lambda: os.getenv("WATSONX_API_KEY"),
- description="The watsonx API key",
+ default=None,
+ description="The watsonx.ai API key",
)
+ # As above, this is optional here too for consistency.
project_id: str | None = Field(
- default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
- description="The Project ID key",
+ default=None,
+ description="The watsonx.ai project ID",
)
timeout: int = Field(
default=60,
diff --git a/llama_stack/providers/remote/inference/watsonx/models.py b/llama_stack/providers/remote/inference/watsonx/models.py
deleted file mode 100644
index d98f0510a..000000000
--- a/llama_stack/providers/remote/inference/watsonx/models.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
-
-MODEL_ENTRIES = [
- build_hf_repo_model_entry(
- "meta-llama/llama-3-3-70b-instruct",
- CoreModelId.llama3_3_70b_instruct.value,
- ),
- build_hf_repo_model_entry(
- "meta-llama/llama-2-13b-chat",
- CoreModelId.llama2_13b.value,
- ),
- build_hf_repo_model_entry(
- "meta-llama/llama-3-1-70b-instruct",
- CoreModelId.llama3_1_70b_instruct.value,
- ),
- build_hf_repo_model_entry(
- "meta-llama/llama-3-1-8b-instruct",
- CoreModelId.llama3_1_8b_instruct.value,
- ),
- build_hf_repo_model_entry(
- "meta-llama/llama-3-2-11b-vision-instruct",
- CoreModelId.llama3_2_11b_vision_instruct.value,
- ),
- build_hf_repo_model_entry(
- "meta-llama/llama-3-2-1b-instruct",
- CoreModelId.llama3_2_1b_instruct.value,
- ),
- build_hf_repo_model_entry(
- "meta-llama/llama-3-2-3b-instruct",
- CoreModelId.llama3_2_3b_instruct.value,
- ),
- build_hf_repo_model_entry(
- "meta-llama/llama-3-2-90b-vision-instruct",
- CoreModelId.llama3_2_90b_vision_instruct.value,
- ),
- build_hf_repo_model_entry(
- "meta-llama/llama-guard-3-11b-vision",
- CoreModelId.llama_guard_3_11b_vision.value,
- ),
-]
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index fc58691e2..d04472936 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -4,240 +4,120 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
-from ibm_watsonx_ai.foundation_models import Model
-from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
-from openai import AsyncOpenAI
+import requests
-from llama_stack.apis.inference import (
- ChatCompletionRequest,
- CompletionRequest,
- GreedySamplingStrategy,
- Inference,
- OpenAIChatCompletion,
- OpenAIChatCompletionChunk,
- OpenAICompletion,
- OpenAIEmbeddingsResponse,
- OpenAIMessageParam,
- OpenAIResponseFormatParam,
- TopKSamplingStrategy,
- TopPSamplingStrategy,
-)
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
-from llama_stack.providers.utils.inference.openai_compat import (
- prepare_openai_completion_params,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
- chat_completion_request_to_prompt,
- completion_request_to_prompt,
- request_has_media,
-)
-
-from . import WatsonXConfig
-from .models import MODEL_ENTRIES
-
-logger = get_logger(name=__name__, category="inference::watsonx")
+from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.models import Model
+from llama_stack.apis.models.models import ModelType
+from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
-# Note on structured output
-# WatsonX returns responses with a json embedded into a string.
-# Examples:
+class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
+ _model_cache: dict[str, Model] = {}
-# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
-# "first_name": "Michael",\n "last_name": "Jordan",\n'...)
-# Not even a valid JSON, but we can still extract the JSON from the content
+ def __init__(self, config: WatsonXConfig):
+ LiteLLMOpenAIMixin.__init__(
+ self,
+ litellm_provider_name="watsonx",
+ api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
+ provider_data_api_key_field="watsonx_api_key",
+ )
+ self.available_models = None
+ self.config = config
-# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
-# "year_born": "1963", "year_retired": "2003"\\}}$')
-# Find the start of the boxed content
+ def get_base_url(self) -> str:
+ return self.config.url
+ async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
+ # Get base parameters from parent
+ params = await super()._get_params(request)
-class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
- def __init__(self, config: WatsonXConfig) -> None:
- ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
-
- logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
- self._config = config
- self._openai_client: AsyncOpenAI | None = None
-
- self._project_id = self._config.project_id
-
- def _get_client(self, model_id) -> Model:
- config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
- config_url = self._config.url
- project_id = self._config.project_id
- credentials = {"url": config_url, "apikey": config_api_key}
-
- return Model(model_id=model_id, credentials=credentials, project_id=project_id)
-
- def _get_openai_client(self) -> AsyncOpenAI:
- if not self._openai_client:
- self._openai_client = AsyncOpenAI(
- base_url=f"{self._config.url}/openai/v1",
- api_key=self._config.api_key,
- )
- return self._openai_client
-
- async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
- input_dict = {"params": {}}
- media_present = request_has_media(request)
- llama_model = self.get_llama_model(request.model)
- if isinstance(request, ChatCompletionRequest):
- input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
- else:
- assert not media_present, "Together does not support media for Completion requests"
- input_dict["prompt"] = await completion_request_to_prompt(request)
- if request.sampling_params:
- if request.sampling_params.strategy:
- input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
- if request.sampling_params.max_tokens:
- input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
- if request.sampling_params.repetition_penalty:
- input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
-
- if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
- input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
- input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
- if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
- input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
- if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
- input_dict["params"][GenParams.TEMPERATURE] = 0.0
-
- input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
-
- params = {
- **input_dict,
- }
+ # Add watsonx.ai specific parameters
+ params["project_id"] = self.config.project_id
+ params["time_limit"] = self.config.timeout
return params
- async def openai_embeddings(
- self,
- model: str,
- input: str | list[str],
- encoding_format: str | None = "float",
- dimensions: int | None = None,
- user: str | None = None,
- ) -> OpenAIEmbeddingsResponse:
- raise NotImplementedError()
+ # Copied from OpenAIMixin
+ async def check_model_availability(self, model: str) -> bool:
+ """
+ Check if a specific model is available from the provider's /v1/models.
- async def openai_completion(
- self,
- model: str,
- prompt: str | list[str] | list[int] | list[list[int]],
- best_of: int | None = None,
- echo: bool | None = None,
- frequency_penalty: float | None = None,
- logit_bias: dict[str, float] | None = None,
- logprobs: bool | None = None,
- max_tokens: int | None = None,
- n: int | None = None,
- presence_penalty: float | None = None,
- seed: int | None = None,
- stop: str | list[str] | None = None,
- stream: bool | None = None,
- stream_options: dict[str, Any] | None = None,
- temperature: float | None = None,
- top_p: float | None = None,
- user: str | None = None,
- guided_choice: list[str] | None = None,
- prompt_logprobs: int | None = None,
- suffix: str | None = None,
- ) -> OpenAICompletion:
- model_obj = await self.model_store.get_model(model)
- params = await prepare_openai_completion_params(
- model=model_obj.provider_resource_id,
- prompt=prompt,
- best_of=best_of,
- echo=echo,
- frequency_penalty=frequency_penalty,
- logit_bias=logit_bias,
- logprobs=logprobs,
- max_tokens=max_tokens,
- n=n,
- presence_penalty=presence_penalty,
- seed=seed,
- stop=stop,
- stream=stream,
- stream_options=stream_options,
- temperature=temperature,
- top_p=top_p,
- user=user,
- )
- return await self._get_openai_client().completions.create(**params) # type: ignore
+ :param model: The model identifier to check.
+ :return: True if the model is available dynamically, False otherwise.
+ """
+ if not self._model_cache:
+ await self.list_models()
+ return model in self._model_cache
- async def openai_chat_completion(
- self,
- model: str,
- messages: list[OpenAIMessageParam],
- frequency_penalty: float | None = None,
- function_call: str | dict[str, Any] | None = None,
- functions: list[dict[str, Any]] | None = None,
- logit_bias: dict[str, float] | None = None,
- logprobs: bool | None = None,
- max_completion_tokens: int | None = None,
- max_tokens: int | None = None,
- n: int | None = None,
- parallel_tool_calls: bool | None = None,
- presence_penalty: float | None = None,
- response_format: OpenAIResponseFormatParam | None = None,
- seed: int | None = None,
- stop: str | list[str] | None = None,
- stream: bool | None = None,
- stream_options: dict[str, Any] | None = None,
- temperature: float | None = None,
- tool_choice: str | dict[str, Any] | None = None,
- tools: list[dict[str, Any]] | None = None,
- top_logprobs: int | None = None,
- top_p: float | None = None,
- user: str | None = None,
- ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
- model_obj = await self.model_store.get_model(model)
- params = await prepare_openai_completion_params(
- model=model_obj.provider_resource_id,
- messages=messages,
- frequency_penalty=frequency_penalty,
- function_call=function_call,
- functions=functions,
- logit_bias=logit_bias,
- logprobs=logprobs,
- max_completion_tokens=max_completion_tokens,
- max_tokens=max_tokens,
- n=n,
- parallel_tool_calls=parallel_tool_calls,
- presence_penalty=presence_penalty,
- response_format=response_format,
- seed=seed,
- stop=stop,
- stream=stream,
- stream_options=stream_options,
- temperature=temperature,
- tool_choice=tool_choice,
- tools=tools,
- top_logprobs=top_logprobs,
- top_p=top_p,
- user=user,
- )
- if params.get("stream", False):
- return self._stream_openai_chat_completion(params)
- return await self._get_openai_client().chat.completions.create(**params) # type: ignore
+ async def list_models(self) -> list[Model] | None:
+ self._model_cache = {}
+ models = []
+ for model_spec in self._get_model_specs():
+ functions = [f["id"] for f in model_spec.get("functions", [])]
+ # embedding_metadata built below has the form {"embedding_dimension": ..., "context_length": ...}
- async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
- # watsonx.ai sometimes adds usage data to the stream
- include_usage = False
- if params.get("stream_options", None):
- include_usage = params["stream_options"].get("include_usage", False)
- stream = await self._get_openai_client().chat.completions.create(**params)
+ # Example of an embedding model:
+ # {'model_id': 'ibm/granite-embedding-278m-multilingual',
+ # 'label': 'granite-embedding-278m-multilingual',
+ # 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
+ # ...
+ provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
+ if "embedding" in functions:
+ embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
+ context_length = model_spec["model_limits"]["max_sequence_length"]
+ embedding_metadata = {
+ "embedding_dimension": embedding_dimension,
+ "context_length": context_length,
+ }
+ model = Model(
+ identifier=model_spec["model_id"],
+ provider_resource_id=provider_resource_id,
+ provider_id=self.__provider_id__,
+ metadata=embedding_metadata,
+ model_type=ModelType.embedding,
+ )
+ self._model_cache[provider_resource_id] = model
+ models.append(model)
+ if "text_chat" in functions:
+ model = Model(
+ identifier=model_spec["model_id"],
+ provider_resource_id=provider_resource_id,
+ provider_id=self.__provider_id__,
+ metadata={},
+ model_type=ModelType.llm,
+ )
+ # In theory a model could be both an embedding model and a text chat model. In that case the
+ # cache ends up holding this text-chat Model object (it overwrites the embedding entry), while
+ # the returned list contains both the embedding and the text-chat Model objects. That's fine
+ # because the cache is only used for check_model_availability() anyway.
+ self._model_cache[provider_resource_id] = model
+ models.append(model)
+ return models
- seen_finish_reason = False
- async for chunk in stream:
- # Final usage chunk with no choices that the user didn't request, so discard
- if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
- break
- yield chunk
- for choice in chunk.choices:
- if choice.finish_reason:
- seen_finish_reason = True
- break
+ # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
+ # So we need to implement our own method to list models by calling the watsonx.ai API.
+ def _get_model_specs(self) -> list[dict[str, Any]]:
+ """
+ Retrieves foundation model specifications from the watsonx.ai API.
+ """
+ url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
+ headers = {
+ # Note that there is no authorization header. Listing models does not require authentication.
+ "Content-Type": "application/json",
+ }
+
+ response = requests.get(url, headers=headers)
+
+ # --- Process the Response ---
+ # Raise an exception for bad status codes (4xx or 5xx)
+ response.raise_for_status()
+
+ # If the request is successful, parse and return the JSON response.
+ # The response should contain a list of model specifications
+ response_data = response.json()
+ if "resources" not in response_data:
+ raise ValueError("Resources not found in response")
+ return response_data["resources"]
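A sketch of the mapping `list_models()` performs on one `foundation_model_specs` resource, using a sample dict shaped like the example in the comments above (the values are illustrative, not a live API response):

```python
sample_spec = {
    "model_id": "ibm/granite-embedding-278m-multilingual",
    "label": "granite-embedding-278m-multilingual",
    "functions": [{"id": "embedding"}],
    "model_limits": {"max_sequence_length": 512, "embedding_dimension": 768},
}

# Same extraction the adapter does: function ids decide model_type, model_limits
# feed the embedding metadata.
functions = [f["id"] for f in sample_spec.get("functions", [])]
if "embedding" in functions:
    limits = sample_spec["model_limits"]
    embedding_metadata = {
        "embedding_dimension": limits["embedding_dimension"],
        "context_length": limits["max_sequence_length"],
    }
    print(embedding_metadata)  # {'embedding_dimension': 768, 'context_length': 512}
```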
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 511123d6e..331e5432e 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -167,7 +167,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
- pass
+ # Clean up mixin resources (file batch tasks)
+ await super().shutdown()
async def register_vector_db(
self,
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index 0acc90595..029eacfe3 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -349,6 +349,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def shutdown(self) -> None:
self.client.close()
+ # Clean up mixin resources (file batch tasks)
+ await super().shutdown()
async def register_vector_db(
self,
diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
index dfdfef6eb..21c388b1d 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -390,6 +390,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
if self.conn is not None:
self.conn.close()
log.info("Connection to PGVector database server closed")
+ # Clean up mixin resources (file batch tasks)
+ await super().shutdown()
async def register_vector_db(self, vector_db: VectorDB) -> None:
# Persist vector DB metadata in the KV store
diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
index 6b386840c..021938afd 100644
--- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
+++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -191,6 +191,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def shutdown(self) -> None:
await self.client.close()
+ # Clean up mixin resources (file batch tasks)
+ await super().shutdown()
async def register_vector_db(
self,
diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
index 54ac6f8d3..21df3bc45 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -347,6 +347,8 @@ class WeaviateVectorIOAdapter(
async def shutdown(self) -> None:
for client in self.client_cache.values():
client.close()
+ # Clean up mixin resources (file batch tasks)
+ await super().shutdown()
async def register_vector_db(
self,
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 6c8f61c3b..6bef97dd5 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+import base64
+import struct
from collections.abc import AsyncIterator
from typing import Any
@@ -16,6 +18,7 @@ from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
+ OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
@@ -26,7 +29,6 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
from llama_stack.providers.utils.inference.openai_compat import (
- b64_encode_openai_embeddings_response,
convert_message_to_openai_dict_new,
convert_tooldef_to_openai_tool,
get_sampling_options,
@@ -349,3 +351,28 @@ class LiteLLMOpenAIMixin(
return False
return model in litellm.models_by_provider[self.litellm_provider_name]
+
+
+def b64_encode_openai_embeddings_response(
+ response_data: list[dict], encoding_format: str | None = "float"
+) -> list[OpenAIEmbeddingData]:
+ """
+ Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
+ """
+ data = []
+ for i, embedding_data in enumerate(response_data):
+ if encoding_format == "base64":
+ byte_array = bytearray()
+ for embedding_value in embedding_data["embedding"]:
+ byte_array.extend(struct.pack("f", float(embedding_value)))
+
+ response_embedding = base64.b64encode(byte_array).decode("utf-8")
+ else:
+ response_embedding = embedding_data["embedding"]
+ data.append(
+ OpenAIEmbeddingData(
+ embedding=response_embedding,
+ index=i,
+ )
+ )
+ return data
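The relocated `b64_encode_openai_embeddings_response` packs each embedding value as a 32-bit float and base64-encodes the resulting bytes. A self-contained round-trip sketch of that encoding (stdlib only, values chosen to be exactly representable in float32):

```python
import base64
import struct

embedding = [0.25, -1.5, 3.0]

# Encode: pack each value as a 32-bit float, then base64 the byte array.
byte_array = bytearray()
for value in embedding:
    byte_array.extend(struct.pack("f", float(value)))
encoded = base64.b64encode(byte_array).decode("utf-8")

# Decode: reverse the steps and verify we recover the original values.
decoded = base64.b64decode(encoded)
recovered = list(struct.unpack(f"{len(decoded) // 4}f", decoded))
assert recovered == embedding
```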
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 4913c2e1f..9d42d68c6 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -24,6 +24,10 @@ class RemoteInferenceProviderConfig(BaseModel):
default=None,
description="List of models that should be registered with the model registry. If None, all models are allowed.",
)
+ refresh_models: bool = Field(
+ default=False,
+ description="Whether to refresh models periodically from the provider",
+ )
# TODO: this class is more confusing than useful right now. We need to make it
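With `refresh_models` moved onto the shared base config, every remote inference provider config inherits it instead of redefining it. A simplified pydantic sketch of that inheritance (the real classes carry more fields, e.g. `allowed_models`):

```python
from pydantic import BaseModel, Field

class RemoteInferenceProviderConfig(BaseModel):
    refresh_models: bool = Field(
        default=False,
        description="Whether to refresh models periodically from the provider",
    )

class OllamaImplConfig(RemoteInferenceProviderConfig):
    url: str = "http://localhost:11434"

cfg = OllamaImplConfig(refresh_models=True)
print(cfg.refresh_models)  # True, inherited from the base config
```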
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index d863eb53a..7e465a14c 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -3,9 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-import base64
import json
-import struct
import time
import uuid
import warnings
@@ -103,7 +101,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
Message,
OpenAIChatCompletion,
- OpenAIEmbeddingData,
OpenAIMessageParam,
OpenAIResponseFormatParam,
SamplingParams,
@@ -1402,28 +1399,3 @@ def prepare_openai_embeddings_params(
params["user"] = user
return params
-
-
-def b64_encode_openai_embeddings_response(
- response_data: dict, encoding_format: str | None = "float"
-) -> list[OpenAIEmbeddingData]:
- """
- Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
- """
- data = []
- for i, embedding_data in enumerate(response_data):
- if encoding_format == "base64":
- byte_array = bytearray()
- for embedding_value in embedding_data.embedding:
- byte_array.extend(struct.pack("f", float(embedding_value)))
-
- response_embedding = base64.b64encode(byte_array).decode("utf-8")
- else:
- response_embedding = embedding_data.embedding
- data.append(
- OpenAIEmbeddingData(
- embedding=response_embedding,
- index=i,
- )
- )
- return data
diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py
index 9137013ee..cba7508a2 100644
--- a/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/llama_stack/providers/utils/inference/openai_mixin.py
@@ -474,17 +474,23 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
async def check_model_availability(self, model: str) -> bool:
"""
- Check if a specific model is available from the provider's /v1/models.
+ Check if a specific model is available from the provider's /v1/models or pre-registered.
:param model: The model identifier to check.
- :return: True if the model is available dynamically, False otherwise.
+ :return: True if the model is available dynamically or pre-registered, False otherwise.
"""
+ # First check if the model is pre-registered in the model store
+ if hasattr(self, "model_store") and self.model_store:
+ if await self.model_store.has_model(model):
+ return True
+
+ # Then check the provider's dynamic model cache
if not self._model_cache:
await self.list_models()
return model in self._model_cache
async def should_refresh_models(self) -> bool:
- return False
+ return self.config.refresh_models
#
# The model_dump implementations are to avoid serializing the extra fields,
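The updated `check_model_availability` consults pre-registered models in the model store before falling back to the provider's dynamic cache. A minimal sketch of that two-step lookup with stand-in objects (the real mixin also lazily fills its cache via `list_models()`; this sketch uses a pre-filled dict):

```python
import asyncio

class FakeModelStore:
    """Stand-in for the stack's model store; only has_model() matters here."""
    def __init__(self, registered: set[str]):
        self._registered = registered

    async def has_model(self, model: str) -> bool:
        return model in self._registered

async def check_model_availability(model: str, store: FakeModelStore | None, cache: dict) -> bool:
    # 1) pre-registered models win, no provider call needed
    if store is not None and await store.has_model(model):
        return True
    # 2) otherwise fall back to the provider's dynamic model cache
    return model in cache

async def main():
    store = FakeModelStore({"pre-registered-model"})
    cache = {"provider-model": object()}
    assert await check_model_availability("pre-registered-model", store, cache)
    assert await check_model_availability("provider-model", store, cache)
    assert not await check_model_availability("unknown-model", store, cache)

asyncio.run(main())
```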
diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
index 0d0aa25a4..c179eba6c 100644
--- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
+++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
@@ -293,6 +293,18 @@ class OpenAIVectorStoreMixin(ABC):
await self._resume_incomplete_batches()
self._last_file_batch_cleanup_time = 0
+ async def shutdown(self) -> None:
+ """Clean up mixin resources including background tasks."""
+ # Cancel any running file batch tasks gracefully
+ tasks_to_cancel = list(self._file_batch_tasks.items())
+ for _, task in tasks_to_cancel:
+ if not task.done():
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
@abstractmethod
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a vector store."""
@@ -587,7 +599,7 @@ class OpenAIVectorStoreMixin(ABC):
content = self._chunk_to_vector_store_content(chunk)
response_data_item = VectorStoreSearchResponse(
- file_id=chunk.metadata.get("file_id", ""),
+ file_id=chunk.metadata.get("document_id", ""),
filename=chunk.metadata.get("filename", ""),
score=score,
attributes=chunk.metadata,
@@ -746,12 +758,15 @@ class OpenAIVectorStoreMixin(ABC):
content = content_from_data_and_mime_type(content_response.body, mime_type)
+ chunk_attributes = attributes.copy()
+ chunk_attributes["filename"] = file_response.filename
+
chunks = make_overlapped_chunks(
file_id,
content,
max_chunk_size_tokens,
chunk_overlap_tokens,
- attributes,
+ chunk_attributes,
)
if not chunks:
vector_store_file_object.status = "failed"
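The `shutdown()` added to the mixin above uses the standard graceful-cancel pattern for background asyncio tasks: cancel, await, and swallow the expected `CancelledError`. A runnable sketch of that pattern in isolation:

```python
import asyncio

async def main():
    async def background_batch_job():
        await asyncio.sleep(3600)  # stands in for a long-running file batch task

    tasks = {"batch-1": asyncio.create_task(background_batch_job())}
    await asyncio.sleep(0)  # let the task start

    # Cancel any running tasks gracefully, as the mixin's shutdown() does.
    for _, task in list(tasks.items()):
        if not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

asyncio.run(main())
```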
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 857fbe910..c0534a875 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -20,7 +20,6 @@ from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
- TextContentItem,
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB
@@ -129,26 +128,6 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en
return ""
-def concat_interleaved_content(content: list[InterleavedContent]) -> InterleavedContent:
- """concatenate interleaved content into a single list. ensure that 'str's are converted to TextContentItem when in a list"""
-
- ret = []
-
- def _process(c):
- if isinstance(c, str):
- ret.append(TextContentItem(text=c))
- elif isinstance(c, list):
- for item in c:
- _process(item)
- else:
- ret.append(c)
-
- for c in content:
- _process(c)
-
- return ret
-
-
async def content_from_doc(doc: RAGDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
diff --git a/scripts/install.sh b/scripts/install.sh
index f6fbc259c..571468dc5 100755
--- a/scripts/install.sh
+++ b/scripts/install.sh
@@ -221,8 +221,8 @@ fi
cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
--network llama-net \
-p "${PORT}:${PORT}" \
- "${SERVER_IMAGE}" --port "${PORT}" \
- --env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}")
+ -e OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" \
+ "${SERVER_IMAGE}" --port "${PORT}")
log "🦙 Starting Llama Stack..."
if ! execute_with_log $ENGINE "${cmd[@]}"; then
diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh
index b009ad696..4ae73f170 100755
--- a/scripts/integration-tests.sh
+++ b/scripts/integration-tests.sh
@@ -191,9 +191,11 @@ if [[ "$STACK_CONFIG" == *"server:"* ]]; then
echo "Llama Stack Server is already running, skipping start"
else
echo "=== Starting Llama Stack Server ==="
- # Set a reasonable log width for better readability in server.log
export LLAMA_STACK_LOG_WIDTH=120
- nohup llama stack run ci-tests --image-type venv > server.log 2>&1 &
+
+ # remove "server:" from STACK_CONFIG
+ stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://')
+ nohup llama stack run $stack_config > server.log 2>&1 &
echo "Waiting for Llama Stack Server to start..."
for i in {1..30}; do
diff --git a/scripts/telemetry/setup_telemetry.sh b/scripts/telemetry/setup_telemetry.sh
index e0b57a354..ecdd56175 100755
--- a/scripts/telemetry/setup_telemetry.sh
+++ b/scripts/telemetry/setup_telemetry.sh
@@ -16,10 +16,19 @@
set -Eeuo pipefail
-CONTAINER_RUNTIME=${CONTAINER_RUNTIME:-docker}
-SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+if command -v podman &> /dev/null; then
+ CONTAINER_RUNTIME="podman"
+elif command -v docker &> /dev/null; then
+ CONTAINER_RUNTIME="docker"
+else
+ echo "🚨 Neither Podman nor Docker could be found"
+ echo "Install Docker: https://docs.docker.com/get-docker/ or Podman: https://podman.io/getting-started/installation"
+ exit 1
+fi
-echo "🚀 Setting up telemetry stack for Llama Stack using Podman..."
+echo "🚀 Setting up telemetry stack for Llama Stack using $CONTAINER_RUNTIME..."
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
if ! command -v "$CONTAINER_RUNTIME" &> /dev/null; then
echo "🚨 $CONTAINER_RUNTIME could not be found"
diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py
index 54a9dd72e..a1c3d1e95 100644
--- a/tests/unit/distribution/routers/test_routing_tables.py
+++ b/tests/unit/distribution/routers/test_routing_tables.py
@@ -201,6 +201,12 @@ async def test_models_routing_table(cached_disk_dist_registry):
non_existent = await table.get_object_by_identifier("model", "non-existent-model")
assert non_existent is None
+ # Test has_model
+ assert await table.has_model("test_provider/test-model")
+ assert await table.has_model("test_provider/test-model-2")
+ assert not await table.has_model("non-existent-model")
+ assert not await table.has_model("test_provider/non-existent-model")
+
await table.unregister_model(model_id="test_provider/test-model")
await table.unregister_model(model_id="test_provider/test-model-2")
diff --git a/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py b/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py
index 187540f82..2698b88c8 100644
--- a/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py
+++ b/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py
@@ -8,6 +8,7 @@
import pytest
from llama_stack.apis.agents.openai_responses import (
+ OpenAIResponseAnnotationFileCitation,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
@@ -35,6 +36,7 @@ from llama_stack.apis.inference import (
OpenAIUserMessageParam,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
+ _extract_citations_from_text,
convert_chat_choice_to_response_message,
convert_response_content_to_chat_content,
convert_response_input_to_chat_messages,
@@ -340,3 +342,26 @@ class TestIsFunctionToolCall:
result = is_function_tool_call(tool_call, tools)
assert result is False
+
+
+class TestExtractCitationsFromText:
+ def test_extract_citations_and_annotations(self):
+ text = "Start [not-a-file]. New source <|file-abc123|>. "
+ text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
+ file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}
+
+ annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)
+
+ expected_annotations = [
+ OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30),
+ OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44),
+ OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59),
+ ]
+ expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."
+
+ assert cleaned_text == expected_clean_text
+ assert annotations == expected_annotations
+ # OpenAI cites at the end of the sentence
+ assert cleaned_text[expected_annotations[0].index] == "."
+ assert cleaned_text[expected_annotations[1].index] == "?"
+ assert cleaned_text[expected_annotations[2].index] == "!"
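The new test pins down the citation-marker format (`<|file-...|>`) and the convention that each annotation's index points at the sentence-ending character left behind once the marker is stripped. A hypothetical, self-contained sketch that reproduces the behavior the test expects (not the actual `_extract_citations_from_text` implementation):

```python
import re

MARKER = re.compile(r"\s*<\|(file-[^|>]+)\|>")

def extract_citations(text: str, file_mapping: dict[str, str]):
    """Strip <|file-...|> markers, recording an annotation whose index points
    at the character that follows the removed marker in the cleaned text."""
    annotations, pieces, last = [], [], 0
    for match in MARKER.finditer(text):
        pieces.append(text[last:match.start()])
        file_id = match.group(1)
        if file_id in file_mapping:
            index = sum(len(p) for p in pieces)  # position in the cleaned text
            annotations.append({"file_id": file_id, "filename": file_mapping[file_id], "index": index})
        last = match.end()
    pieces.append(text[last:])
    return annotations, "".join(pieces)

text = "Start [not-a-file]. New source <|file-abc123|>. Other source <|file-def456|>?"
annotations, cleaned = extract_citations(text, {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"})
assert cleaned == "Start [not-a-file]. New source. Other source?"
assert [a["index"] for a in annotations] == [30, 44]  # the '.' and '?' positions
```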
diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py
index d30b5b12a..55a6793c2 100644
--- a/tests/unit/providers/inference/test_inference_client_caching.py
+++ b/tests/unit/providers/inference/test_inference_client_caching.py
@@ -18,6 +18,8 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
+from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
+from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
@pytest.mark.parametrize(
@@ -58,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_valida
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
assert inference_adapter.client.api_key == api_key
+
+
+@pytest.mark.parametrize(
+ "config_cls,adapter_cls,provider_data_validator",
+ [
+ (
+ WatsonXConfig,
+ WatsonXInferenceAdapter,
+ "llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
+ ),
+ ],
+)
+def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
+ """Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
+ assumption that there is an OpenAI-compatible client object."""
+
+ inference_adapter = adapter_cls(config=config_cls())
+
+ inference_adapter.__provider_spec__ = MagicMock()
+ inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
+
+ for api_key in ["test1", "test2"]:
+ with request_provider_data_context(
+ {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
+ ):
+ assert inference_adapter.get_api_key() == api_key
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 2806f618c..6d6bb20d5 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -186,43 +186,3 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
assert mock_create_client.call_count == 4 # no cheating
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
-
-
-async def test_should_refresh_models():
- """
- Test the should_refresh_models method with different refresh_models configurations.
-
- This test verifies that:
- 1. When refresh_models is True, should_refresh_models returns True regardless of api_token
- 2. When refresh_models is False, should_refresh_models returns False regardless of api_token
- """
-
- # Test case 1: refresh_models is True, api_token is None
- config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
- adapter1 = VLLMInferenceAdapter(config=config1)
- result1 = await adapter1.should_refresh_models()
- assert result1 is True, "should_refresh_models should return True when refresh_models is True"
-
- # Test case 2: refresh_models is True, api_token is empty string
- config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
- adapter2 = VLLMInferenceAdapter(config=config2)
- result2 = await adapter2.should_refresh_models()
- assert result2 is True, "should_refresh_models should return True when refresh_models is True"
-
- # Test case 3: refresh_models is True, api_token is "fake" (default)
- config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
- adapter3 = VLLMInferenceAdapter(config=config3)
- result3 = await adapter3.should_refresh_models()
- assert result3 is True, "should_refresh_models should return True when refresh_models is True"
-
- # Test case 4: refresh_models is True, api_token is real token
- config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
- adapter4 = VLLMInferenceAdapter(config=config4)
- result4 = await adapter4.should_refresh_models()
- assert result4 is True, "should_refresh_models should return True when refresh_models is True"
-
- # Test case 5: refresh_models is False, api_token is real token
- config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
- adapter5 = VLLMInferenceAdapter(config=config5)
- result5 = await adapter5.should_refresh_models()
- assert result5 is False, "should_refresh_models should return False when refresh_models is False"
diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py
index ac4c29fea..ad9406951 100644
--- a/tests/unit/providers/utils/inference/test_openai_mixin.py
+++ b/tests/unit/providers/utils/inference/test_openai_mixin.py
@@ -44,11 +44,12 @@ def mixin():
config = RemoteInferenceProviderConfig()
mixin_instance = OpenAIMixinImpl(config=config)
- # just enough to satisfy _get_provider_model_id calls
- mock_model_store = MagicMock()
+ # Mock model_store with async methods
+ mock_model_store = AsyncMock()
mock_model = MagicMock()
mock_model.provider_resource_id = "test-provider-resource-id"
mock_model_store.get_model = AsyncMock(return_value=mock_model)
+ mock_model_store.has_model = AsyncMock(return_value=False) # Default to False, tests can override
mixin_instance.model_store = mock_model_store
return mixin_instance
@@ -189,6 +190,40 @@ class TestOpenAIMixinCheckModelAvailability:
assert len(mixin._model_cache) == 3
+ async def test_check_model_availability_with_pre_registered_model(
+ self, mixin, mock_client_with_models, mock_client_context
+ ):
+ """Test that check_model_availability returns True for pre-registered models in model_store"""
+ # Mock model_store.has_model to return True for a specific model
+ mock_model_store = AsyncMock()
+ mock_model_store.has_model = AsyncMock(return_value=True)
+ mixin.model_store = mock_model_store
+
+ # Test that pre-registered model is found without calling the provider's API
+ with mock_client_context(mixin, mock_client_with_models):
+ mock_client_with_models.models.list.assert_not_called()
+ assert await mixin.check_model_availability("pre-registered-model")
+ # Should not call the provider's list_models since model was found in store
+ mock_client_with_models.models.list.assert_not_called()
+ mock_model_store.has_model.assert_called_once_with("pre-registered-model")
+
+ async def test_check_model_availability_fallback_to_provider_when_not_in_store(
+ self, mixin, mock_client_with_models, mock_client_context
+ ):
+ """Test that check_model_availability falls back to provider when model not in store"""
+ # Mock model_store.has_model to return False
+ mock_model_store = AsyncMock()
+ mock_model_store.has_model = AsyncMock(return_value=False)
+ mixin.model_store = mock_model_store
+
+ # Test that it falls back to provider's model cache
+ with mock_client_context(mixin, mock_client_with_models):
+ mock_client_with_models.models.list.assert_not_called()
+ assert await mixin.check_model_availability("some-mock-model-id")
+ # Should call the provider's list_models since model was not found in store
+ mock_client_with_models.models.list.assert_called_once()
+ mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
+
class TestOpenAIMixinCacheBehavior:
"""Test cases for cache behavior and edge cases"""
@@ -466,10 +501,16 @@ class TestOpenAIMixinModelRegistration:
assert result is None
async def test_should_refresh_models(self, mixin):
- """Test should_refresh_models method (should always return False)"""
+ """Test should_refresh_models method returns config value"""
+ # Default config has refresh_models=False
result = await mixin.should_refresh_models()
assert result is False
+ config_with_refresh = RemoteInferenceProviderConfig(refresh_models=True)
+ mixin_with_refresh = OpenAIMixinImpl(config=config_with_refresh)
+ result_with_refresh = await mixin_with_refresh.should_refresh_models()
+ assert result_with_refresh is True
+
async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context):
"""Test that errors from provider API are properly propagated during registration"""
model = Model(
diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py
index 70ace695e..d122f9323 100644
--- a/tests/unit/providers/vector_io/conftest.py
+++ b/tests/unit/providers/vector_io/conftest.py
@@ -145,10 +145,10 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
@pytest.fixture
-async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
+async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
config = SQLiteVectorIOConfig(
db_path=sqlite_vec_db_path,
- kvstore=SqliteKVStoreConfig(),
+ kvstore=unique_kvstore_config,
)
adapter = SQLiteVecVectorIOAdapter(
config=config,
@@ -187,10 +187,10 @@ async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
@pytest.fixture
-async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
+async def milvus_vec_adapter(milvus_vec_db_path, unique_kvstore_config, mock_inference_api):
config = MilvusVectorIOConfig(
db_path=milvus_vec_db_path,
- kvstore=SqliteKVStoreConfig(),
+ kvstore=unique_kvstore_config,
)
adapter = MilvusVectorIOAdapter(
config=config,
@@ -264,10 +264,10 @@ async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
@pytest.fixture
-async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
+async def chroma_vec_adapter(chroma_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
config = ChromaVectorIOConfig(
db_path=chroma_vec_db_path,
- kvstore=SqliteKVStoreConfig(),
+ kvstore=unique_kvstore_config,
)
adapter = ChromaVectorIOAdapter(
config=config,
@@ -296,12 +296,12 @@ def qdrant_vec_db_path(tmp_path_factory):
@pytest.fixture
-async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
+async def qdrant_vec_adapter(qdrant_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
import uuid
config = QdrantVectorIOConfig(
db_path=qdrant_vec_db_path,
- kvstore=SqliteKVStoreConfig(),
+ kvstore=unique_kvstore_config,
)
adapter = QdrantVectorIOAdapter(
config=config,
@@ -386,14 +386,14 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
@pytest.fixture
-async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
+async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
config = PGVectorVectorIOConfig(
host="localhost",
port=5432,
db="test_db",
user="test_user",
password="test_password",
- kvstore=SqliteKVStoreConfig(),
+ kvstore=unique_kvstore_config,
)
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
@@ -476,7 +476,7 @@ async def weaviate_vec_index(weaviate_vec_db_path):
@pytest.fixture
-async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
+async def weaviate_vec_adapter(weaviate_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
import pytest_socket
import weaviate
@@ -492,7 +492,7 @@ async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embeddi
config = WeaviateVectorIOConfig(
weaviate_cluster_url="localhost:8080",
weaviate_api_key=None,
- kvstore=SqliteKVStoreConfig(),
+ kvstore=unique_kvstore_config,
)
adapter = WeaviateVectorIOAdapter(
config=config,
diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py
index 4ea4a20b9..c1f834d5d 100644
--- a/tests/unit/registry/test_registry.py
+++ b/tests/unit/registry/test_registry.py
@@ -125,8 +125,15 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
provider_resource_id="test_vector_db_2",
provider_id="baz", # Same provider_id
)
- await cached_disk_dist_registry.register(duplicate_vector_db)
+ # Now we expect a ValueError to be raised for duplicate registration
+ with pytest.raises(
+ ValueError,
+ match=r"Provider 'baz' is already registered.*Unregister the existing provider first before registering it again.",
+ ):
+ await cached_disk_dist_registry.register(duplicate_vector_db)
+
+ # Verify the original registration is still intact
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result is not None
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved