diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 238fed683..f9c42ef8a 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -86,7 +86,7 @@ jobs: # avoid line breaks in the server log, especially because we grep it below. export COLUMNS=1984 - nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 & + nohup uv run llama stack run $run_dir/run.yaml > server.log 2>&1 & - name: Wait for Llama Stack server to be ready run: | diff --git a/.github/workflows/stale_bot.yml b/.github/workflows/stale_bot.yml index 502a78f8e..c5a1ba9e5 100644 --- a/.github/workflows/stale_bot.yml +++ b/.github/workflows/stale_bot.yml @@ -24,7 +24,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Stale Action - uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0 + uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 with: stale-issue-label: 'stale' stale-issue-message: > diff --git a/.github/workflows/test-external-provider-module.yml b/.github/workflows/test-external-provider-module.yml index 8a757b068..b43cefb27 100644 --- a/.github/workflows/test-external-provider-module.yml +++ b/.github/workflows/test-external-provider-module.yml @@ -59,7 +59,7 @@ jobs: # Use the virtual environment created by the build step (name comes from build config) source ramalama-stack-test/bin/activate uv pip list - nohup llama stack run tests/external/ramalama-stack/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & + nohup llama stack run tests/external/ramalama-stack/run.yaml > server.log 2>&1 & - name: Wait for Llama Stack server to be ready run: | diff --git a/.github/workflows/test-external.yml b/.github/workflows/test-external.yml index 7ee467451..a008b17af 100644 --- a/.github/workflows/test-external.yml +++ b/.github/workflows/test-external.yml @@ -59,7 +59,7 @@ jobs: # Use the virtual environment created by the build step (name comes from build config) source ci-test/bin/activate uv pip list - nohup llama stack run tests/external/run-byoa.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & + nohup llama stack run tests/external/run-byoa.yaml > server.log 2>&1 & - name: Wait for Llama Stack server to be ready run: | diff --git a/docs/docs/advanced_apis/post_training.mdx b/docs/docs/advanced_apis/post_training.mdx index 516ac07e1..43bfaea91 100644 --- a/docs/docs/advanced_apis/post_training.mdx +++ b/docs/docs/advanced_apis/post_training.mdx @@ -52,7 +52,7 @@ You can access the HuggingFace trainer via the `starter` distribution: ```bash llama stack build --distro starter --image-type venv -llama stack run --image-type venv ~/.llama/distributions/starter/starter-run.yaml +llama stack run ~/.llama/distributions/starter/starter-run.yaml ``` ### Usage Example diff --git a/docs/docs/building_applications/tools.mdx b/docs/docs/building_applications/tools.mdx index e5d9c46f9..3b78ec57b 100644 --- a/docs/docs/building_applications/tools.mdx +++ b/docs/docs/building_applications/tools.mdx @@ -219,13 +219,10 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools") 1. Start by registering a Tavily API key at [Tavily](https://tavily.com/). -2. [Optional] Provide the API key directly to the Llama Stack server +2. 
[Optional] Set the API key in your environment before starting the Llama Stack server ```bash export TAVILY_SEARCH_API_KEY="your key" ``` -```bash ---env TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY} -``` @@ -273,9 +270,9 @@ for log in EventLogger().log(response): 1. Start by registering for a WolframAlpha API key at [WolframAlpha Developer Portal](https://developer.wolframalpha.com/access). -2. Provide the API key either when starting the Llama Stack server: +2. Provide the API key either by setting it in your environment before starting the Llama Stack server: ```bash - --env WOLFRAM_ALPHA_API_KEY=${WOLFRAM_ALPHA_API_KEY} + export WOLFRAM_ALPHA_API_KEY="your key" ``` or from the client side: ```python diff --git a/docs/docs/contributing/new_api_provider.mdx b/docs/docs/contributing/new_api_provider.mdx index 4ae6d5e72..6f9744771 100644 --- a/docs/docs/contributing/new_api_provider.mdx +++ b/docs/docs/contributing/new_api_provider.mdx @@ -76,7 +76,7 @@ Integration tests are located in [tests/integration](https://github.com/meta-lla Consult [tests/integration/README.md](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md) for more details on how to run the tests. Note that each provider's `sample_run_config()` method (in the configuration class for that provider) - typically references some environment variables for specifying API keys and the like. You can set these in the environment or pass these via the `--env` flag to the test command. + typically references some environment variables for specifying API keys and the like. You can set these in the environment before running the test command. ### 2. Unit Testing diff --git a/docs/docs/distributions/building_distro.mdx b/docs/docs/distributions/building_distro.mdx index 5b65b7f16..a4f7e1f60 100644 --- a/docs/docs/distributions/building_distro.mdx +++ b/docs/docs/distributions/building_distro.mdx @@ -289,10 +289,10 @@ After this step is successful, you should be able to find the built container im docker run -d \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ + -e INFERENCE_MODEL=$INFERENCE_MODEL \ + -e OLLAMA_URL=http://host.docker.internal:11434 \ localhost/distribution-ollama:dev \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=$INFERENCE_MODEL \ - --env OLLAMA_URL=http://host.docker.internal:11434 + --port $LLAMA_STACK_PORT ``` Here are the docker flags and their uses: @@ -305,12 +305,12 @@ Here are the docker flags and their uses: * `localhost/distribution-ollama:dev`: The name and tag of the container image to run +* `-e INFERENCE_MODEL=$INFERENCE_MODEL`: Sets the INFERENCE_MODEL environment variable in the container + +* `-e OLLAMA_URL=http://host.docker.internal:11434`: Sets the OLLAMA_URL environment variable in the container + * `--port $LLAMA_STACK_PORT`: Port number for the server to listen on -* `--env INFERENCE_MODEL=$INFERENCE_MODEL`: Sets the model to use for inference - -* `--env OLLAMA_URL=http://host.docker.internal:11434`: Configures the URL for the Ollama service - @@ -320,23 +320,22 @@ Now, let's start the Llama Stack Distribution Server. You will need the YAML con ``` llama stack run -h -usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--env KEY=VALUE] +usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--image-type {venv}] [--enable-ui] - [config | template] + [config | distro] Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution. 
positional arguments: - config | template Path to config file to use for the run or name of known template (`llama stack list` for a list). (default: None) + config | distro Path to config file to use for the run or name of known distro (`llama stack list` for a list). (default: None) options: -h, --help show this help message and exit --port PORT Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321) --image-name IMAGE_NAME - Name of the image to run. Defaults to the current environment (default: None) - --env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: None) + [DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running. (default: None) --image-type {venv} - Image Type used during the build. This should be venv. (default: None) + [DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running. (default: None) --enable-ui Start the UI server (default: False) ``` @@ -348,9 +347,6 @@ llama stack run tgi # Start using config file llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml - -# Start using a venv -llama stack run --image-type venv ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml ``` ``` diff --git a/docs/docs/distributions/configuration.mdx b/docs/docs/distributions/configuration.mdx index dbf879024..81243c97b 100644 --- a/docs/docs/distributions/configuration.mdx +++ b/docs/docs/distributions/configuration.mdx @@ -101,7 +101,7 @@ A few things to note: - The id is a string you can choose freely. - You can instantiate any number of provider instances of the same type. - The configuration dictionary is provider-specific. -- Notice that configuration can reference environment variables (with default values), which are expanded at runtime. When you run a stack server (via docker or via `llama stack run`), you can specify `--env OLLAMA_URL=http://my-server:11434` to override the default value. +- Notice that configuration can reference environment variables (with default values), which are expanded at runtime. When you run a stack server, you can set environment variables in your shell before running `llama stack run` to override the default values. 
### Environment Variable Substitution @@ -173,13 +173,10 @@ optional_token: ${env.OPTIONAL_TOKEN:+} #### Runtime Override -You can override environment variables at runtime when starting the server: +You can override environment variables at runtime by setting them in your shell before starting the server: ```bash -# Override specific environment variables -llama stack run --config run.yaml --env API_KEY=sk-123 --env BASE_URL=https://custom-api.com - -# Or set them in your shell +# Set environment variables in your shell export API_KEY=sk-123 export BASE_URL=https://custom-api.com llama stack run --config run.yaml diff --git a/docs/docs/distributions/remote_hosted_distro/watsonx.md b/docs/docs/distributions/remote_hosted_distro/watsonx.md index 977af90dd..5add678f3 100644 --- a/docs/docs/distributions/remote_hosted_distro/watsonx.md +++ b/docs/docs/distributions/remote_hosted_distro/watsonx.md @@ -69,10 +69,10 @@ docker run \ -it \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ./run.yaml:/root/my-run.yaml \ + -e WATSONX_API_KEY=$WATSONX_API_KEY \ + -e WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \ + -e WATSONX_BASE_URL=$WATSONX_BASE_URL \ llamastack/distribution-watsonx \ --config /root/my-run.yaml \ - --port $LLAMA_STACK_PORT \ - --env WATSONX_API_KEY=$WATSONX_API_KEY \ - --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \ - --env WATSONX_BASE_URL=$WATSONX_BASE_URL + --port $LLAMA_STACK_PORT ``` diff --git a/docs/docs/distributions/self_hosted_distro/dell.md b/docs/docs/distributions/self_hosted_distro/dell.md index 52d40cf9d..851eac3bf 100644 --- a/docs/docs/distributions/self_hosted_distro/dell.md +++ b/docs/docs/distributions/self_hosted_distro/dell.md @@ -129,11 +129,11 @@ docker run -it \ # NOTE: mount the llama-stack / llama-model directories if testing local changes else not needed -v $HOME/git/llama-stack:/app/llama-stack-source -v $HOME/git/llama-models:/app/llama-models-source \ # localhost/distribution-dell:dev if building / testing locally - llamastack/distribution-dell\ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=$INFERENCE_MODEL \ - --env DEH_URL=$DEH_URL \ - --env CHROMA_URL=$CHROMA_URL + -e INFERENCE_MODEL=$INFERENCE_MODEL \ + -e DEH_URL=$DEH_URL \ + -e CHROMA_URL=$CHROMA_URL \ + llamastack/distribution-dell \ + --port $LLAMA_STACK_PORT ``` @@ -154,14 +154,14 @@ docker run \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v $HOME/.llama:/root/.llama \ -v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \ + -e INFERENCE_MODEL=$INFERENCE_MODEL \ + -e DEH_URL=$DEH_URL \ + -e SAFETY_MODEL=$SAFETY_MODEL \ + -e DEH_SAFETY_URL=$DEH_SAFETY_URL \ + -e CHROMA_URL=$CHROMA_URL \ llamastack/distribution-dell \ --config /root/my-run.yaml \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=$INFERENCE_MODEL \ - --env DEH_URL=$DEH_URL \ - --env SAFETY_MODEL=$SAFETY_MODEL \ - --env DEH_SAFETY_URL=$DEH_SAFETY_URL \ - --env CHROMA_URL=$CHROMA_URL + --port $LLAMA_STACK_PORT ``` ### Via venv @@ -170,21 +170,21 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a ```bash llama stack build --distro dell --image-type venv -llama stack run dell - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=$INFERENCE_MODEL \ - --env DEH_URL=$DEH_URL \ - --env CHROMA_URL=$CHROMA_URL +INFERENCE_MODEL=$INFERENCE_MODEL \ +DEH_URL=$DEH_URL \ +CHROMA_URL=$CHROMA_URL \ +llama stack run dell \ + --port $LLAMA_STACK_PORT ``` If you are using Llama Stack Safety / Shield APIs, use: ```bash +INFERENCE_MODEL=$INFERENCE_MODEL \ +DEH_URL=$DEH_URL \ +SAFETY_MODEL=$SAFETY_MODEL \ 
+DEH_SAFETY_URL=$DEH_SAFETY_URL \ +CHROMA_URL=$CHROMA_URL \ llama stack run ./run-with-safety.yaml \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=$INFERENCE_MODEL \ - --env DEH_URL=$DEH_URL \ - --env SAFETY_MODEL=$SAFETY_MODEL \ - --env DEH_SAFETY_URL=$DEH_SAFETY_URL \ - --env CHROMA_URL=$CHROMA_URL + --port $LLAMA_STACK_PORT ``` diff --git a/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md b/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md index 84b85b91c..1c0ef5f6e 100644 --- a/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md +++ b/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md @@ -84,9 +84,9 @@ docker run \ --gpu all \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ + -e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ llamastack/distribution-meta-reference-gpu \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct + --port $LLAMA_STACK_PORT ``` If you are using Llama Stack Safety / Shield APIs, use: @@ -98,10 +98,10 @@ docker run \ --gpu all \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ + -e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ + -e SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \ llamastack/distribution-meta-reference-gpu \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ - --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B + --port $LLAMA_STACK_PORT ``` ### Via venv @@ -110,16 +110,16 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL ```bash llama stack build --distro meta-reference-gpu --image-type venv +INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ llama stack run distributions/meta-reference-gpu/run.yaml \ - --port 8321 \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct + --port 8321 ``` If you are using Llama Stack Safety / Shield APIs, use: ```bash +INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ +SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \ llama stack run distributions/meta-reference-gpu/run-with-safety.yaml \ - --port 8321 \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ - --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B + --port 8321 ``` diff --git a/docs/docs/distributions/self_hosted_distro/nvidia.md b/docs/docs/distributions/self_hosted_distro/nvidia.md index 1e52797db..a6e185442 100644 --- a/docs/docs/distributions/self_hosted_distro/nvidia.md +++ b/docs/docs/distributions/self_hosted_distro/nvidia.md @@ -129,10 +129,10 @@ docker run \ --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ./run.yaml:/root/my-run.yaml \ + -e NVIDIA_API_KEY=$NVIDIA_API_KEY \ llamastack/distribution-nvidia \ --config /root/my-run.yaml \ - --port $LLAMA_STACK_PORT \ - --env NVIDIA_API_KEY=$NVIDIA_API_KEY + --port $LLAMA_STACK_PORT ``` ### Via venv @@ -142,10 +142,10 @@ If you've set up your local development environment, you can also build the imag ```bash INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct llama stack build --distro nvidia --image-type venv +NVIDIA_API_KEY=$NVIDIA_API_KEY \ +INFERENCE_MODEL=$INFERENCE_MODEL \ llama stack run ./run.yaml \ - --port 8321 \ - --env NVIDIA_API_KEY=$NVIDIA_API_KEY \ - --env INFERENCE_MODEL=$INFERENCE_MODEL + --port 8321 ``` ## Example Notebooks diff --git a/docs/docs/getting_started/detailed_tutorial.mdx b/docs/docs/getting_started/detailed_tutorial.mdx index 33786ac0e..e6c22224d 100644 --- a/docs/docs/getting_started/detailed_tutorial.mdx +++ b/docs/docs/getting_started/detailed_tutorial.mdx 
@@ -86,9 +86,9 @@ docker run -it \ --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ + -e OLLAMA_URL=http://host.docker.internal:11434 \ llamastack/distribution-starter \ - --port $LLAMA_STACK_PORT \ - --env OLLAMA_URL=http://host.docker.internal:11434 + --port $LLAMA_STACK_PORT ``` Note to start the container with Podman, you can do the same but replace `docker` at the start of the command with `podman`. If you are using `podman` older than `4.7.0`, please also replace `host.docker.internal` in the `OLLAMA_URL` @@ -106,9 +106,9 @@ docker run -it \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ --network=host \ + -e OLLAMA_URL=http://localhost:11434 \ llamastack/distribution-starter \ - --port $LLAMA_STACK_PORT \ - --env OLLAMA_URL=http://localhost:11434 + --port $LLAMA_STACK_PORT ``` ::: You will see output like below: diff --git a/docs/docs/providers/inference/remote_anthropic.mdx b/docs/docs/providers/inference/remote_anthropic.mdx index 96162d25c..44c1fcbb1 100644 --- a/docs/docs/providers/inference/remote_anthropic.mdx +++ b/docs/docs/providers/inference/remote_anthropic.mdx @@ -15,6 +15,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `api_key` | `str \| None` | No | | API key for Anthropic models | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_azure.mdx b/docs/docs/providers/inference/remote_azure.mdx index 721fe429c..56a14c100 100644 --- a/docs/docs/providers/inference/remote_azure.mdx +++ b/docs/docs/providers/inference/remote_azure.mdx @@ -22,6 +22,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `api_key` | `` | No | | Azure API key for Azure | | `api_base` | `` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) | | `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) | diff --git a/docs/docs/providers/inference/remote_bedrock.mdx b/docs/docs/providers/inference/remote_bedrock.mdx index 2a5d1b74d..683ec12f8 100644 --- a/docs/docs/providers/inference/remote_bedrock.mdx +++ b/docs/docs/providers/inference/remote_bedrock.mdx @@ -15,6 +15,7 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. 
Default use environment variable: AWS_ACCESS_KEY_ID | | `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY | | `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN | diff --git a/docs/docs/providers/inference/remote_cerebras.mdx b/docs/docs/providers/inference/remote_cerebras.mdx index 1a543389d..d364b9884 100644 --- a/docs/docs/providers/inference/remote_cerebras.mdx +++ b/docs/docs/providers/inference/remote_cerebras.mdx @@ -15,6 +15,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `base_url` | `` | No | https://api.cerebras.ai | Base URL for the Cerebras API | | `api_key` | `` | No | | Cerebras API Key | diff --git a/docs/docs/providers/inference/remote_databricks.mdx b/docs/docs/providers/inference/remote_databricks.mdx index 670f8a7f9..d7b0bd38d 100644 --- a/docs/docs/providers/inference/remote_databricks.mdx +++ b/docs/docs/providers/inference/remote_databricks.mdx @@ -15,6 +15,7 @@ Databricks inference provider for running models on Databricks' unified analytic | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint | | `api_token` | `` | No | | The Databricks API token | diff --git a/docs/docs/providers/inference/remote_fireworks.mdx b/docs/docs/providers/inference/remote_fireworks.mdx index d2c3a664e..cfdfb993c 100644 --- a/docs/docs/providers/inference/remote_fireworks.mdx +++ b/docs/docs/providers/inference/remote_fireworks.mdx @@ -15,6 +15,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `url` | `` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server | | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key | diff --git a/docs/docs/providers/inference/remote_gemini.mdx b/docs/docs/providers/inference/remote_gemini.mdx index 5222eaa89..a13d1c82d 100644 --- a/docs/docs/providers/inference/remote_gemini.mdx +++ b/docs/docs/providers/inference/remote_gemini.mdx @@ -15,6 +15,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. 
| +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `api_key` | `str \| None` | No | | API key for Gemini models | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_groq.mdx b/docs/docs/providers/inference/remote_groq.mdx index 77516ed1f..1edb4f9ea 100644 --- a/docs/docs/providers/inference/remote_groq.mdx +++ b/docs/docs/providers/inference/remote_groq.mdx @@ -15,6 +15,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `api_key` | `str \| None` | No | | The Groq API key | | `url` | `` | No | https://api.groq.com | The URL for the Groq AI server | diff --git a/docs/docs/providers/inference/remote_llama-openai-compat.mdx b/docs/docs/providers/inference/remote_llama-openai-compat.mdx index bcd50f772..ca5830b09 100644 --- a/docs/docs/providers/inference/remote_llama-openai-compat.mdx +++ b/docs/docs/providers/inference/remote_llama-openai-compat.mdx @@ -15,6 +15,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `api_key` | `str \| None` | No | | The Llama API key | | `openai_compat_api_base` | `` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server | diff --git a/docs/docs/providers/inference/remote_nvidia.mdx b/docs/docs/providers/inference/remote_nvidia.mdx index 348a42e59..6b5e36180 100644 --- a/docs/docs/providers/inference/remote_nvidia.mdx +++ b/docs/docs/providers/inference/remote_nvidia.mdx @@ -15,6 +15,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `url` | `` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM | | `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed of using the hosted service | | `timeout` | `` | No | 60 | Timeout for the HTTP requests | diff --git a/docs/docs/providers/inference/remote_ollama.mdx b/docs/docs/providers/inference/remote_ollama.mdx index f075607d8..e00e34e4a 100644 --- a/docs/docs/providers/inference/remote_ollama.mdx +++ b/docs/docs/providers/inference/remote_ollama.mdx @@ -15,8 +15,8 @@ Ollama inference provider for running local models through the Ollama runtime. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. 
| +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `url` | `` | No | http://localhost:11434 | | -| `refresh_models` | `` | No | False | Whether to refresh models periodically | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_openai.mdx b/docs/docs/providers/inference/remote_openai.mdx index b795d02b1..e0910c809 100644 --- a/docs/docs/providers/inference/remote_openai.mdx +++ b/docs/docs/providers/inference/remote_openai.mdx @@ -15,6 +15,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `api_key` | `str \| None` | No | | API key for OpenAI models | | `base_url` | `` | No | https://api.openai.com/v1 | Base URL for OpenAI API | diff --git a/docs/docs/providers/inference/remote_passthrough.mdx b/docs/docs/providers/inference/remote_passthrough.mdx index 58d5619b8..e356384ad 100644 --- a/docs/docs/providers/inference/remote_passthrough.mdx +++ b/docs/docs/providers/inference/remote_passthrough.mdx @@ -15,6 +15,7 @@ Passthrough inference provider for connecting to any external inference service | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `url` | `` | No | | The URL for the passthrough endpoint | | `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrouth endpoint | diff --git a/docs/docs/providers/inference/remote_runpod.mdx b/docs/docs/providers/inference/remote_runpod.mdx index 92cc66eb1..876532029 100644 --- a/docs/docs/providers/inference/remote_runpod.mdx +++ b/docs/docs/providers/inference/remote_runpod.mdx @@ -15,6 +15,7 @@ RunPod inference provider for running models on RunPod's cloud GPU platform. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint | | `api_token` | `str \| None` | No | | The API token | diff --git a/docs/docs/providers/inference/remote_sambanova.mdx b/docs/docs/providers/inference/remote_sambanova.mdx index b28471890..9bd7b7613 100644 --- a/docs/docs/providers/inference/remote_sambanova.mdx +++ b/docs/docs/providers/inference/remote_sambanova.mdx @@ -15,6 +15,7 @@ SambaNova inference provider for running models on SambaNova's dataflow architec | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. 
| +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `url` | `` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server | | `api_key` | `pydantic.types.SecretStr \| None` | No | | The SambaNova cloud API Key | diff --git a/docs/docs/providers/inference/remote_tgi.mdx b/docs/docs/providers/inference/remote_tgi.mdx index 6ff82cc2b..67fe6d237 100644 --- a/docs/docs/providers/inference/remote_tgi.mdx +++ b/docs/docs/providers/inference/remote_tgi.mdx @@ -15,6 +15,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `url` | `` | No | | The URL for the TGI serving endpoint | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_together.mdx b/docs/docs/providers/inference/remote_together.mdx index da232a45b..6df2ca866 100644 --- a/docs/docs/providers/inference/remote_together.mdx +++ b/docs/docs/providers/inference/remote_together.mdx @@ -15,6 +15,7 @@ Together AI inference provider for open-source models and collaborative AI devel | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `url` | `` | No | https://api.together.xyz/v1 | The URL for the Together AI server | | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key | diff --git a/docs/docs/providers/inference/remote_vertexai.mdx b/docs/docs/providers/inference/remote_vertexai.mdx index 48da6be24..c182ed485 100644 --- a/docs/docs/providers/inference/remote_vertexai.mdx +++ b/docs/docs/providers/inference/remote_vertexai.mdx @@ -54,6 +54,7 @@ Available Models: | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `project` | `` | No | | Google Cloud project ID for Vertex AI | | `location` | `` | No | us-central1 | Google Cloud location for Vertex AI | diff --git a/docs/docs/providers/inference/remote_vllm.mdx b/docs/docs/providers/inference/remote_vllm.mdx index 598f97b19..fbbd424a3 100644 --- a/docs/docs/providers/inference/remote_vllm.mdx +++ b/docs/docs/providers/inference/remote_vllm.mdx @@ -15,11 +15,11 @@ Remote vLLM inference provider for connecting to vLLM servers. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint | | `max_tokens` | `` | No | 4096 | Maximum number of tokens to generate. 
| | `api_token` | `str \| None` | No | fake | The API token | | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. | -| `refresh_models` | `` | No | False | Whether to refresh models periodically | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_watsonx.mdx b/docs/docs/providers/inference/remote_watsonx.mdx index 8cd3b2869..f081703ab 100644 --- a/docs/docs/providers/inference/remote_watsonx.mdx +++ b/docs/docs/providers/inference/remote_watsonx.mdx @@ -15,9 +15,10 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `url` | `` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai | -| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key | -| `project_id` | `str \| None` | No | | The Project ID key | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key | +| `project_id` | `str \| None` | No | | The watsonx.ai project ID | | `timeout` | `` | No | 60 | Timeout for the HTTP requests | ## Sample Configuration diff --git a/docs/docs/providers/safety/remote_bedrock.mdx b/docs/docs/providers/safety/remote_bedrock.mdx index 530a208b5..663a761f0 100644 --- a/docs/docs/providers/safety/remote_bedrock.mdx +++ b/docs/docs/providers/safety/remote_bedrock.mdx @@ -15,6 +15,7 @@ AWS Bedrock safety provider for content moderation using AWS's safety services. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID | | `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY | | `aws_session_token` | `str \| None` | No | | The AWS session token to use. 
Default use environment variable: AWS_SESSION_TOKEN | diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index d7d544ad5..3dcedfed6 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -123,12 +123,12 @@ " del os.environ[\"UV_SYSTEM_PYTHON\"]\n", "\n", "# this command installs all the dependencies needed for the llama stack server with the together inference provider\n", - "!uv run --with llama-stack llama stack build --distro together --image-type venv\n", + "!uv run --with llama-stack llama stack build --distro together\n", "\n", "def run_llama_stack_server_background():\n", " log_file = open(\"llama_stack_server.log\", \"w\")\n", " process = subprocess.Popen(\n", - " \"uv run --with llama-stack llama stack run together --image-type venv\",\n", + " \"uv run --with llama-stack llama stack run together\",\n", " shell=True,\n", " stdout=log_file,\n", " stderr=log_file,\n", diff --git a/docs/getting_started_llama4.ipynb b/docs/getting_started_llama4.ipynb index cd5f83517..bca505b5e 100644 --- a/docs/getting_started_llama4.ipynb +++ b/docs/getting_started_llama4.ipynb @@ -233,12 +233,12 @@ " del os.environ[\"UV_SYSTEM_PYTHON\"]\n", "\n", "# this command installs all the dependencies needed for the llama stack server\n", - "!uv run --with llama-stack llama stack build --distro meta-reference-gpu --image-type venv\n", + "!uv run --with llama-stack llama stack build --distro meta-reference-gpu\n", "\n", "def run_llama_stack_server_background():\n", " log_file = open(\"llama_stack_server.log\", \"w\")\n", " process = subprocess.Popen(\n", - " f\"uv run --with llama-stack llama stack run meta-reference-gpu --image-type venv --env INFERENCE_MODEL={model_id}\",\n", + " f\"INFERENCE_MODEL={model_id} uv run --with llama-stack llama stack run meta-reference-gpu\",\n", " shell=True,\n", " stdout=log_file,\n", " stderr=log_file,\n", diff --git a/docs/getting_started_llama_api.ipynb b/docs/getting_started_llama_api.ipynb index f65566205..7680c4a0c 100644 --- a/docs/getting_started_llama_api.ipynb +++ b/docs/getting_started_llama_api.ipynb @@ -223,12 +223,12 @@ " del os.environ[\"UV_SYSTEM_PYTHON\"]\n", "\n", "# this command installs all the dependencies needed for the llama stack server\n", - "!uv run --with llama-stack llama stack build --distro llama_api --image-type venv\n", + "!uv run --with llama-stack llama stack build --distro llama_api\n", "\n", "def run_llama_stack_server_background():\n", " log_file = open(\"llama_stack_server.log\", \"w\")\n", " process = subprocess.Popen(\n", - " \"uv run --with llama-stack llama stack run llama_api --image-type venv\",\n", + " \"uv run --with llama-stack llama stack run llama_api\",\n", " shell=True,\n", " stdout=log_file,\n", " stderr=log_file,\n", diff --git a/docs/quick_start.ipynb b/docs/quick_start.ipynb index c194a901d..eebfd6686 100644 --- a/docs/quick_start.ipynb +++ b/docs/quick_start.ipynb @@ -145,12 +145,12 @@ " del os.environ[\"UV_SYSTEM_PYTHON\"]\n", "\n", "# this command installs all the dependencies needed for the llama stack server with the ollama inference provider\n", - "!uv run --with llama-stack llama stack build --distro starter --image-type venv\n", + "!uv run --with llama-stack llama stack build --distro starter\n", "\n", "def run_llama_stack_server_background():\n", " log_file = open(\"llama_stack_server.log\", \"w\")\n", " process = subprocess.Popen(\n", - " f\"OLLAMA_URL=http://localhost:11434 uv run --with llama-stack llama stack run starter --image-type venv\n", + " 
f\"OLLAMA_URL=http://localhost:11434 uv run --with llama-stack llama stack run starter\n", " shell=True,\n", " stdout=log_file,\n", " stderr=log_file,\n", diff --git a/docs/zero_to_hero_guide/README.md b/docs/zero_to_hero_guide/README.md index 183038a88..1b643d692 100644 --- a/docs/zero_to_hero_guide/README.md +++ b/docs/zero_to_hero_guide/README.md @@ -88,7 +88,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next ... Build Successful! You can find the newly-built template here: ~/.llama/distributions/starter/starter-run.yaml - You can run the new Llama Stack Distro via: uv run --with llama-stack llama stack run starter --image-type venv + You can run the new Llama Stack Distro via: uv run --with llama-stack llama stack run starter ``` 3. **Set the ENV variables by exporting them to the terminal**: @@ -102,12 +102,11 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next 3. **Run the Llama Stack**: Run the stack using uv: ```bash + INFERENCE_MODEL=$INFERENCE_MODEL \ + SAFETY_MODEL=$SAFETY_MODEL \ + OLLAMA_URL=$OLLAMA_URL \ uv run --with llama-stack llama stack run starter \ - --image-type venv \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=$INFERENCE_MODEL \ - --env SAFETY_MODEL=$SAFETY_MODEL \ - --env OLLAMA_URL=$OLLAMA_URL + --port $LLAMA_STACK_PORT ``` Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model. diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index b14e6fe55..471d5cb66 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -444,12 +444,24 @@ def _run_stack_build_command_from_build_config( cprint("Build Successful!", color="green", file=sys.stderr) cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr) - cprint( - "You can run the new Llama Stack distro via: " - + colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"), - color="green", - file=sys.stderr, - ) + if build_config.image_type == LlamaStackImageType.VENV: + cprint( + "You can run the new Llama Stack distro (after activating " + + colored(image_name, "cyan") + + ") via: " + + colored(f"llama stack run {run_config_file}", "blue"), + color="green", + file=sys.stderr, + ) + elif build_config.image_type == LlamaStackImageType.CONTAINER: + cprint( + "You can run the container with: " + + colored( + f"docker run -p 8321:8321 -v ~/.llama:/root/.llama localhost/{image_name} --port 8321", "blue" + ), + color="green", + file=sys.stderr, + ) return distro_path else: return _generate_run_config(build_config, build_dir, image_name) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index cec101083..06dae7318 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -16,7 +16,7 @@ import yaml from llama_stack.cli.stack.utils import ImageType from llama_stack.cli.subcommand import Subcommand from llama_stack.core.datatypes import LoggingConfig, StackRunConfig -from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars, validate_env_pair +from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro from llama_stack.log import get_logger @@ -55,18 +55,12 @@ class StackRun(Subcommand): "--image-name", type=str, default=None, - help="Name of the image to run. 
Defaults to the current environment", - ) - self.parser.add_argument( - "--env", - action="append", - help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.", - metavar="KEY=VALUE", + help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.", ) self.parser.add_argument( "--image-type", type=str, - help="Image Type used during the build. This can be only venv.", + help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.", choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value], ) self.parser.add_argument( @@ -75,48 +69,22 @@ class StackRun(Subcommand): help="Start the UI server", ) - def _resolve_config_and_distro(self, args: argparse.Namespace) -> tuple[Path | None, str | None]: - """Resolve config file path and distribution name from args.config""" - from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR - - if not args.config: - return None, None - - config_file = Path(args.config) - has_yaml_suffix = args.config.endswith(".yaml") - distro_name = None - - if not config_file.exists() and not has_yaml_suffix: - # check if this is a distribution - config_file = Path(REPO_ROOT) / "llama_stack" / "distributions" / args.config / "run.yaml" - if config_file.exists(): - distro_name = args.config - - if not config_file.exists() and not has_yaml_suffix: - # check if it's a build config saved to ~/.llama dir - config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml") - - if not config_file.exists(): - self.parser.error( - f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file" - ) - - if not config_file.is_file(): - self.parser.error( - f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}" - ) - - return config_file, distro_name - def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: import yaml from llama_stack.core.configure import parse_and_maybe_upgrade_config - from llama_stack.core.utils.exec import formulate_run_args, run_command + + if args.image_type or args.image_name: + self.parser.error( + "The --image-type and --image-name flags are no longer supported.\n\n" + "Please activate your virtual environment manually before running `llama stack run`.\n\n" + "For example:\n" + " source /path/to/venv/bin/activate\n" + " llama stack run \n" + ) if args.enable_ui: self._start_ui_development_server(args.port) - image_type, image_name = args.image_type, args.image_name if args.config: try: @@ -128,10 +96,6 @@ class StackRun(Subcommand): else: config_file = None - # Check if config is required based on image type - if image_type == ImageType.VENV.value and not config_file: - self.parser.error("Config file is required for venv environment") - if config_file: logger.info(f"Using run configuration: {config_file}") @@ -146,50 +110,13 @@ class StackRun(Subcommand): os.makedirs(str(config.external_providers_dir), exist_ok=True) except AttributeError as e: self.parser.error(f"failed to parse config file '{config_file}':\n {e}") - else: - config = None - # If neither image type nor image name is provided, assume the server should be run directly - # using the current environment packages. - if not image_type and not image_name: - logger.info("No image type or image name provided. 
Assuming environment packages.") - self._uvicorn_run(config_file, args) - else: - run_args = formulate_run_args(image_type, image_name) - - run_args.extend([str(args.port)]) - - if config_file: - run_args.extend(["--config", str(config_file)]) - - if args.env: - for env_var in args.env: - if "=" not in env_var: - self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format") - return - key, value = env_var.split("=", 1) # split on first = only - if not key: - self.parser.error(f"Environment variable '{env_var}' has empty key") - return - run_args.extend(["--env", f"{key}={value}"]) - - run_command(run_args) + self._uvicorn_run(config_file, args) def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None: if not config_file: self.parser.error("Config file is required") - # Set environment variables if provided - if args.env: - for env_pair in args.env: - try: - key, value = validate_env_pair(env_pair) - logger.info(f"Setting environment variable {key} => {value}") - os.environ[key] = value - except ValueError as e: - logger.error(f"Error: {str(e)}") - self.parser.error(f"Invalid environment variable format: {env_pair}") - config_file = resolve_config_or_distro(str(config_file), Mode.RUN) with open(config_file) as fp: config_contents = yaml.safe_load(fp) diff --git a/llama_stack/core/conversations/conversations.py b/llama_stack/core/conversations/conversations.py index bef138e69..612b2f68e 100644 --- a/llama_stack/core/conversations/conversations.py +++ b/llama_stack/core/conversations/conversations.py @@ -32,7 +32,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import ( sqlstore_impl, ) -logger = get_logger(name=__name__, category="openai::conversations") +logger = get_logger(name=__name__, category="openai_conversations") class ConversationServiceConfig(BaseModel): diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index c4338e614..847f6a2d2 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -611,7 +611,7 @@ class InferenceRouter(Inference): completion_text += "".join(choice_data["content_parts"]) # Add metrics to the chunk - if self.telemetry and chunk.usage: + if self.telemetry and hasattr(chunk, "usage") and chunk.usage: metrics = self._construct_metrics( prompt_tokens=chunk.usage.prompt_tokens, completion_tokens=chunk.usage.completion_tokens, diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index 641c73c16..716be936a 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): try: models = await provider.list_models() except Exception as e: - logger.warning(f"Model refresh failed for provider {provider_id}: {e}") + logger.debug(f"Model refresh failed for provider {provider_id}: {e}") continue self.listed_providers.add(provider_id) @@ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): raise ValueError(f"Provider {model.provider_id} not found in the routing table") return self.impls_by_provider_id[model.provider_id] + async def has_model(self, model_id: str) -> bool: + """ + Check if a model exists in the routing table. 
+ + :param model_id: The model identifier to check + :return: True if the model exists, False otherwise + """ + try: + await lookup_model(self, model_id) + return True + except ModelNotFoundError: + return False + async def register_model( self, model_id: str, diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index d5d55319a..acc02eeff 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -274,22 +274,6 @@ def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]: return config_dict -def validate_env_pair(env_pair: str) -> tuple[str, str]: - """Validate and split an environment variable key-value pair.""" - try: - key, value = env_pair.split("=", 1) - key = key.strip() - if not key: - raise ValueError(f"Empty key in environment variable pair: {env_pair}") - if not all(c.isalnum() or c == "_" for c in key): - raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}") - return key, value - except ValueError as e: - raise ValueError( - f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value" - ) from e - - def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None: """Add internal implementations (inspect and providers) to the implementations dictionary. diff --git a/llama_stack/core/start_stack.sh b/llama_stack/core/start_stack.sh index 02b1cd408..cc0ae68d8 100755 --- a/llama_stack/core/start_stack.sh +++ b/llama_stack/core/start_stack.sh @@ -25,7 +25,7 @@ error_handler() { trap 'error_handler ${LINENO}' ERR if [ $# -lt 3 ]; then - echo "Usage: $0 [--config ] [--env KEY=VALUE]..." + echo "Usage: $0 [--config ]" exit 1 fi @@ -43,7 +43,6 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")") # Initialize variables yaml_config="" -env_vars="" other_args="" # Process remaining arguments @@ -58,15 +57,6 @@ while [[ $# -gt 0 ]]; do exit 1 fi ;; - --env) - if [[ -n "$2" ]]; then - env_vars="$env_vars --env $2" - shift 2 - else - echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2 - exit 1 - fi - ;; *) other_args="$other_args $1" shift @@ -119,7 +109,6 @@ if [[ "$env_type" == "venv" ]]; then llama stack run \ $yaml_config_arg \ --port "$port" \ - $env_vars \ $other_args elif [[ "$env_type" == "container" ]]; then echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}" diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 624dbd176..0486553d5 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -98,7 +98,10 @@ class DiskDistributionRegistry(DistributionRegistry): existing_obj = await self.get(obj.type, obj.identifier) # dont register if the object's providerid already exists if existing_obj and existing_obj.provider_id == obj.provider_id: - return False + raise ValueError( + f"Provider '{obj.provider_id}' is already registered." + f"Unregister the existing provider first before registering it again." 
+ ) await self.kvstore.set( KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), diff --git a/llama_stack/distributions/dell/doc_template.md b/llama_stack/distributions/dell/doc_template.md index fcec3ea14..852e78d0e 100644 --- a/llama_stack/distributions/dell/doc_template.md +++ b/llama_stack/distributions/dell/doc_template.md @@ -117,11 +117,11 @@ docker run -it \ # NOTE: mount the llama-stack directory if testing local changes else not needed -v $HOME/git/llama-stack:/app/llama-stack-source \ # localhost/distribution-dell:dev if building / testing locally + -e INFERENCE_MODEL=$INFERENCE_MODEL \ + -e DEH_URL=$DEH_URL \ + -e CHROMA_URL=$CHROMA_URL \ llamastack/distribution-{{ name }}\ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=$INFERENCE_MODEL \ - --env DEH_URL=$DEH_URL \ - --env CHROMA_URL=$CHROMA_URL + --port $LLAMA_STACK_PORT ``` @@ -142,14 +142,14 @@ docker run \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v $HOME/.llama:/root/.llama \ -v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \ + -e INFERENCE_MODEL=$INFERENCE_MODEL \ + -e DEH_URL=$DEH_URL \ + -e SAFETY_MODEL=$SAFETY_MODEL \ + -e DEH_SAFETY_URL=$DEH_SAFETY_URL \ + -e CHROMA_URL=$CHROMA_URL \ llamastack/distribution-{{ name }} \ --config /root/my-run.yaml \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=$INFERENCE_MODEL \ - --env DEH_URL=$DEH_URL \ - --env SAFETY_MODEL=$SAFETY_MODEL \ - --env DEH_SAFETY_URL=$DEH_SAFETY_URL \ - --env CHROMA_URL=$CHROMA_URL + --port $LLAMA_STACK_PORT ``` ### Via Conda @@ -158,21 +158,21 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a ```bash llama stack build --distro {{ name }} --image-type conda -llama stack run {{ name }} - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=$INFERENCE_MODEL \ - --env DEH_URL=$DEH_URL \ - --env CHROMA_URL=$CHROMA_URL +INFERENCE_MODEL=$INFERENCE_MODEL \ +DEH_URL=$DEH_URL \ +CHROMA_URL=$CHROMA_URL \ +llama stack run {{ name }} \ + --port $LLAMA_STACK_PORT ``` If you are using Llama Stack Safety / Shield APIs, use: ```bash +INFERENCE_MODEL=$INFERENCE_MODEL \ +DEH_URL=$DEH_URL \ +SAFETY_MODEL=$SAFETY_MODEL \ +DEH_SAFETY_URL=$DEH_SAFETY_URL \ +CHROMA_URL=$CHROMA_URL \ llama stack run ./run-with-safety.yaml \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=$INFERENCE_MODEL \ - --env DEH_URL=$DEH_URL \ - --env SAFETY_MODEL=$SAFETY_MODEL \ - --env DEH_SAFETY_URL=$DEH_SAFETY_URL \ - --env CHROMA_URL=$CHROMA_URL + --port $LLAMA_STACK_PORT ``` diff --git a/llama_stack/distributions/meta-reference-gpu/doc_template.md b/llama_stack/distributions/meta-reference-gpu/doc_template.md index 602d053c4..92dcc6102 100644 --- a/llama_stack/distributions/meta-reference-gpu/doc_template.md +++ b/llama_stack/distributions/meta-reference-gpu/doc_template.md @@ -72,9 +72,9 @@ docker run \ --gpu all \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ + -e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ llamastack/distribution-{{ name }} \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct + --port $LLAMA_STACK_PORT ``` If you are using Llama Stack Safety / Shield APIs, use: @@ -86,10 +86,10 @@ docker run \ --gpu all \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ + -e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ + -e SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \ llamastack/distribution-{{ name }} \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ - --env 
SAFETY_MODEL=meta-llama/Llama-Guard-3-1B + --port $LLAMA_STACK_PORT ``` ### Via venv @@ -98,16 +98,16 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL ```bash llama stack build --distro {{ name }} --image-type venv +INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ llama stack run distributions/{{ name }}/run.yaml \ - --port 8321 \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct + --port 8321 ``` If you are using Llama Stack Safety / Shield APIs, use: ```bash +INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ +SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \ llama stack run distributions/{{ name }}/run-with-safety.yaml \ - --port 8321 \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ - --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B + --port 8321 ``` diff --git a/llama_stack/distributions/nvidia/doc_template.md b/llama_stack/distributions/nvidia/doc_template.md index fbee17ef8..df2b68ef7 100644 --- a/llama_stack/distributions/nvidia/doc_template.md +++ b/llama_stack/distributions/nvidia/doc_template.md @@ -118,10 +118,10 @@ docker run \ --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ./run.yaml:/root/my-run.yaml \ + -e NVIDIA_API_KEY=$NVIDIA_API_KEY \ llamastack/distribution-{{ name }} \ --config /root/my-run.yaml \ - --port $LLAMA_STACK_PORT \ - --env NVIDIA_API_KEY=$NVIDIA_API_KEY + --port $LLAMA_STACK_PORT ``` ### Via venv @@ -131,10 +131,10 @@ If you've set up your local development environment, you can also build the imag ```bash INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct llama stack build --distro nvidia --image-type venv +NVIDIA_API_KEY=$NVIDIA_API_KEY \ +INFERENCE_MODEL=$INFERENCE_MODEL \ llama stack run ./run.yaml \ - --port 8321 \ - --env NVIDIA_API_KEY=$NVIDIA_API_KEY \ - --env INFERENCE_MODEL=$INFERENCE_MODEL + --port 8321 ``` ## Example Notebooks diff --git a/llama_stack/distributions/watsonx/__init__.py b/llama_stack/distributions/watsonx/__init__.py index 756f351d8..078d86144 100644 --- a/llama_stack/distributions/watsonx/__init__.py +++ b/llama_stack/distributions/watsonx/__init__.py @@ -3,3 +3,5 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+ +from .watsonx import get_distribution_template # noqa: F401 diff --git a/llama_stack/distributions/watsonx/build.yaml b/llama_stack/distributions/watsonx/build.yaml index bf4be7eaf..06349a741 100644 --- a/llama_stack/distributions/watsonx/build.yaml +++ b/llama_stack/distributions/watsonx/build.yaml @@ -3,44 +3,33 @@ distribution_spec: description: Use watsonx for running LLM inference providers: inference: - - provider_id: watsonx - provider_type: remote::watsonx - - provider_id: sentence-transformers - provider_type: inline::sentence-transformers + - provider_type: remote::watsonx + - provider_type: inline::sentence-transformers vector_io: - - provider_id: faiss - provider_type: inline::faiss + - provider_type: inline::faiss safety: - - provider_id: llama-guard - provider_type: inline::llama-guard + - provider_type: inline::llama-guard agents: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference telemetry: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference eval: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference datasetio: - - provider_id: huggingface - provider_type: remote::huggingface - - provider_id: localfs - provider_type: inline::localfs + - provider_type: remote::huggingface + - provider_type: inline::localfs scoring: - - provider_id: basic - provider_type: inline::basic - - provider_id: llm-as-judge - provider_type: inline::llm-as-judge - - provider_id: braintrust - provider_type: inline::braintrust + - provider_type: inline::basic + - provider_type: inline::llm-as-judge + - provider_type: inline::braintrust tool_runtime: - provider_type: remote::brave-search - provider_type: remote::tavily-search - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol + files: + - provider_type: inline::localfs image_type: venv additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] -- aiosqlite -- aiosqlite diff --git a/llama_stack/distributions/watsonx/run.yaml b/llama_stack/distributions/watsonx/run.yaml index 92f367910..e0c337f9d 100644 --- a/llama_stack/distributions/watsonx/run.yaml +++ b/llama_stack/distributions/watsonx/run.yaml @@ -4,13 +4,13 @@ apis: - agents - datasetio - eval +- files - inference - safety - scoring - telemetry - tool_runtime - vector_io -- files providers: inference: - provider_id: watsonx @@ -19,8 +19,6 @@ providers: url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com} api_key: ${env.WATSONX_API_KEY:=} project_id: ${env.WATSONX_PROJECT_ID:=} - - provider_id: sentence-transformers - provider_type: inline::sentence-transformers vector_io: - provider_id: faiss provider_type: inline::faiss @@ -48,7 +46,7 @@ providers: provider_type: inline::meta-reference config: service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" - sinks: ${env.TELEMETRY_SINKS:=console,sqlite} + sinks: ${env.TELEMETRY_SINKS:=sqlite} sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} eval: @@ -109,102 +107,7 @@ metadata_store: inference_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db -models: -- metadata: {} - model_id: meta-llama/llama-3-3-70b-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-3-70b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.3-70B-Instruct - 
provider_id: watsonx - provider_model_id: meta-llama/llama-3-3-70b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-2-13b-chat - provider_id: watsonx - provider_model_id: meta-llama/llama-2-13b-chat - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-2-13b - provider_id: watsonx - provider_model_id: meta-llama/llama-2-13b-chat - model_type: llm -- metadata: {} - model_id: meta-llama/llama-3-1-70b-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-1-70b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-70B-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-1-70b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-3-1-8b-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-1-8b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-8B-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-1-8b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-3-2-11b-vision-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-11b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-11B-Vision-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-11b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-3-2-1b-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-1b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-1B-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-1b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-3-2-3b-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-3b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-3B-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-3b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-3-2-90b-vision-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-90b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-90B-Vision-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-90b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-guard-3-11b-vision - provider_id: watsonx - provider_model_id: meta-llama/llama-guard-3-11b-vision - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-Guard-3-11B-Vision - provider_id: watsonx - provider_model_id: meta-llama/llama-guard-3-11b-vision - model_type: llm -- metadata: - embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 - provider_id: sentence-transformers - model_type: embedding +models: [] shields: [] vector_dbs: [] datasets: [] diff --git a/llama_stack/distributions/watsonx/watsonx.py b/llama_stack/distributions/watsonx/watsonx.py index c3cab5d1b..645770612 100644 --- a/llama_stack/distributions/watsonx/watsonx.py +++ b/llama_stack/distributions/watsonx/watsonx.py @@ -4,17 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from pathlib import Path -from llama_stack.apis.models import ModelType -from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput -from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry +from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput +from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig -from llama_stack.providers.inline.inference.sentence_transformers import ( - SentenceTransformersInferenceConfig, -) from llama_stack.providers.remote.inference.watsonx import WatsonXConfig -from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES def get_distribution_template(name: str = "watsonx") -> DistributionTemplate: @@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate: config=WatsonXConfig.sample_run_config(), ) - embedding_provider = Provider( - provider_id="sentence-transformers", - provider_type="inline::sentence-transformers", - config=SentenceTransformersInferenceConfig.sample_run_config(), - ) - - available_models = { - "watsonx": MODEL_ENTRIES, - } default_tool_groups = [ ToolGroupInput( toolgroup_id="builtin::websearch", @@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate: ), ] - embedding_model = ModelInput( - model_id="all-MiniLM-L6-v2", - provider_id="sentence-transformers", - model_type=ModelType.embedding, - metadata={ - "embedding_dimension": 384, - }, - ) - files_provider = Provider( provider_id="meta-reference-files", provider_type="inline::localfs", config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"), ) - default_models, _ = get_model_registry(available_models) return DistributionTemplate( name=name, distro_type="remote_hosted", description="Use watsonx for running LLM inference", container_image=None, - template_path=Path(__file__).parent / "doc_template.md", + template_path=None, providers=providers, - available_models_by_provider=available_models, run_configs={ "run.yaml": RunConfigSettings( provider_overrides={ - "inference": [inference_provider, embedding_provider], + "inference": [inference_provider], "files": [files_provider], }, - default_models=default_models + [embedding_model], + default_models=[], default_tool_groups=default_tool_groups, ), }, diff --git a/llama_stack/log.py b/llama_stack/log.py index 8aee4c9a9..ce92219f4 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -31,12 +31,17 @@ CATEGORIES = [ "client", "telemetry", "openai_responses", + "openai_conversations", "testing", "providers", "models", "files", "vector_io", "tool_runtime", + "cli", + "post_training", + "scoring", + "tests", ] UNCATEGORIZED = "uncategorized" @@ -264,11 +269,12 @@ def get_logger( if root_category in _category_levels: log_level = _category_levels[root_category] else: - log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL) if category != UNCATEGORIZED: - logging.warning( - f"Unknown logging category: {category}. Falling back to default 'root' level: {log_level}" + raise ValueError( + f"Unknown logging category: {category}. To resolve, choose a valid category from the CATEGORIES list " + f"or add it to the CATEGORIES list. 
Available categories: {CATEGORIES}" ) + log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL) logger.setLevel(log_level) return logging.LoggerAdapter(logger, {"category": category}) diff --git a/llama_stack/models/llama/prompt_format.py b/llama_stack/models/llama/prompt_format.py index 6191df61a..16e4068d7 100644 --- a/llama_stack/models/llama/prompt_format.py +++ b/llama_stack/models/llama/prompt_format.py @@ -11,19 +11,13 @@ # top-level folder for each specific model found within the models/ directory at # the top-level of this source tree. -import json import textwrap -from pathlib import Path from pydantic import BaseModel, Field from llama_stack.models.llama.datatypes import ( RawContent, - RawMediaItem, RawMessage, - RawTextItem, - StopReason, - ToolCall, ToolPromptFormat, ) from llama_stack.models.llama.llama4.tokenizer import Tokenizer @@ -175,25 +169,6 @@ def llama3_1_builtin_code_interpreter_dialog(tool_prompt_format=ToolPromptFormat return messages -def llama3_1_builtin_tool_call_with_image_dialog( - tool_prompt_format=ToolPromptFormat.json, -): - this_dir = Path(__file__).parent - with open(this_dir / "llama3/dog.jpg", "rb") as f: - img = f.read() - - interface = LLama31Interface(tool_prompt_format) - - messages = interface.system_messages(**system_message_builtin_tools_only()) - messages += interface.user_message(content=[RawMediaItem(data=img), RawTextItem(text="What is this dog breed?")]) - messages += interface.assistant_response_messages( - "Based on the description of the dog in the image, it appears to be a small breed dog, possibly a terrier mix", - StopReason.end_of_turn, - ) - messages += interface.user_message("Search the web for some food recommendations for the indentified breed") - return messages - - def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json): interface = LLama31Interface(tool_prompt_format) @@ -202,35 +177,6 @@ def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json): return messages -def llama3_1_e2e_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json): - tool_response = json.dumps(["great song1", "awesome song2", "cool song3"]) - interface = LLama31Interface(tool_prompt_format) - - messages = interface.system_messages(**system_message_custom_tools_only()) - messages += interface.user_message(content="Use tools to get latest trending songs") - messages.append( - RawMessage( - role="assistant", - content="", - stop_reason=StopReason.end_of_message, - tool_calls=[ - ToolCall( - call_id="call_id", - tool_name="trending_songs", - arguments={"n": "10", "genre": "latest"}, - ) - ], - ), - ) - messages.append( - RawMessage( - role="assistant", - content=tool_response, - ) - ) - return messages - - def llama3_2_user_assistant_conversation(): return UseCase( title="User and assistant conversation", diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py index 334c32e15..37b0b50c8 100644 --- a/llama_stack/providers/inline/agents/meta_reference/__init__.py +++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py @@ -22,6 +22,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap deps[Api.tool_runtime], deps[Api.tool_groups], policy, + Api.telemetry in deps, ) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 207f0daec..b17c720e9 100644 
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -7,8 +7,6 @@ import copy import json import re -import secrets -import string import uuid import warnings from collections.abc import AsyncGenerator @@ -84,11 +82,6 @@ from llama_stack.providers.utils.telemetry import tracing from .persistence import AgentPersistence from .safety import SafetyException, ShieldRunnerMixin - -def make_random_string(length: int = 8): - return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) - - TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") MEMORY_QUERY_TOOL = "knowledge_search" WEB_SEARCH_TOOL = "web_search" @@ -110,6 +103,7 @@ class ChatAgent(ShieldRunnerMixin): persistence_store: KVStore, created_at: str, policy: list[AccessRule], + telemetry_enabled: bool = False, ): self.agent_id = agent_id self.agent_config = agent_config @@ -120,6 +114,7 @@ class ChatAgent(ShieldRunnerMixin): self.tool_runtime_api = tool_runtime_api self.tool_groups_api = tool_groups_api self.created_at = created_at + self.telemetry_enabled = telemetry_enabled ShieldRunnerMixin.__init__( self, @@ -188,28 +183,30 @@ class ChatAgent(ShieldRunnerMixin): async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: turn_id = str(uuid.uuid4()) - span = tracing.get_current_span() - if span: - span.set_attribute("session_id", request.session_id) - span.set_attribute("agent_id", self.agent_id) - span.set_attribute("request", request.model_dump_json()) - span.set_attribute("turn_id", turn_id) - if self.agent_config.name: - span.set_attribute("agent_name", self.agent_config.name) + if self.telemetry_enabled: + span = tracing.get_current_span() + if span is not None: + span.set_attribute("session_id", request.session_id) + span.set_attribute("agent_id", self.agent_id) + span.set_attribute("request", request.model_dump_json()) + span.set_attribute("turn_id", turn_id) + if self.agent_config.name: + span.set_attribute("agent_name", self.agent_config.name) await self._initialize_tools(request.toolgroups) async for chunk in self._run_turn(request, turn_id): yield chunk async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: - span = tracing.get_current_span() - if span: - span.set_attribute("agent_id", self.agent_id) - span.set_attribute("session_id", request.session_id) - span.set_attribute("request", request.model_dump_json()) - span.set_attribute("turn_id", request.turn_id) - if self.agent_config.name: - span.set_attribute("agent_name", self.agent_config.name) + if self.telemetry_enabled: + span = tracing.get_current_span() + if span is not None: + span.set_attribute("agent_id", self.agent_id) + span.set_attribute("session_id", request.session_id) + span.set_attribute("request", request.model_dump_json()) + span.set_attribute("turn_id", request.turn_id) + if self.agent_config.name: + span.set_attribute("agent_name", self.agent_config.name) await self._initialize_tools() async for chunk in self._run_turn(request): @@ -395,9 +392,12 @@ class ChatAgent(ShieldRunnerMixin): touchpoint: str, ) -> AsyncGenerator: async with tracing.span("run_shields") as span: - span.set_attribute("input", [m.model_dump_json() for m in messages]) + if self.telemetry_enabled and span is not None: + span.set_attribute("input", [m.model_dump_json() for m in messages]) + if len(shields) == 0: + span.set_attribute("output", "no shields") + if len(shields) == 0: 
- span.set_attribute("output", "no shields") return step_id = str(uuid.uuid4()) @@ -430,7 +430,8 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) - span.set_attribute("output", e.violation.model_dump_json()) + if self.telemetry_enabled and span is not None: + span.set_attribute("output", e.violation.model_dump_json()) yield CompletionMessage( content=str(e), @@ -453,7 +454,8 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) - span.set_attribute("output", "no violations") + if self.telemetry_enabled and span is not None: + span.set_attribute("output", "no violations") async def _run( self, @@ -518,8 +520,9 @@ class ChatAgent(ShieldRunnerMixin): stop_reason: StopReason | None = None async with tracing.span("inference") as span: - if self.agent_config.name: - span.set_attribute("agent_name", self.agent_config.name) + if self.telemetry_enabled and span is not None: + if self.agent_config.name: + span.set_attribute("agent_name", self.agent_config.name) def _serialize_nested(value): """Recursively serialize nested Pydantic models to dicts.""" @@ -637,18 +640,19 @@ class ChatAgent(ShieldRunnerMixin): else: raise ValueError(f"Unexpected delta type {type(delta)}") - span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn) - span.set_attribute( - "input", - json.dumps([json.loads(m.model_dump_json()) for m in input_messages]), - ) - output_attr = json.dumps( - { - "content": content, - "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls], - } - ) - span.set_attribute("output", output_attr) + if self.telemetry_enabled and span is not None: + span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn) + span.set_attribute( + "input", + json.dumps([json.loads(m.model_dump_json()) for m in input_messages]), + ) + output_attr = json.dumps( + { + "content": content, + "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls], + } + ) + span.set_attribute("output", output_attr) n_iter += 1 await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter) @@ -756,7 +760,9 @@ class ChatAgent(ShieldRunnerMixin): { "tool_name": tool_call.tool_name, "input": message.model_dump_json(), - }, + } + if self.telemetry_enabled + else {}, ) as span: tool_execution_start_time = datetime.now(UTC).isoformat() tool_result = await self.execute_tool_call_maybe( @@ -771,7 +777,8 @@ class ChatAgent(ShieldRunnerMixin): call_id=tool_call.call_id, content=tool_result.content, ) - span.set_attribute("output", result_message.model_dump_json()) + if self.telemetry_enabled and span is not None: + span.set_attribute("output", result_message.model_dump_json()) # Store tool execution step tool_execution_step = ToolExecutionStep( diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 5431e8f28..cfaf56a34 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -64,6 +64,7 @@ class MetaReferenceAgentsImpl(Agents): tool_runtime_api: ToolRuntime, tool_groups_api: ToolGroups, policy: list[AccessRule], + telemetry_enabled: bool = False, ): self.config = config self.inference_api = inference_api @@ -71,6 +72,7 @@ class MetaReferenceAgentsImpl(Agents): self.safety_api = safety_api self.tool_runtime_api = tool_runtime_api self.tool_groups_api = tool_groups_api + self.telemetry_enabled = telemetry_enabled self.in_memory_store = InmemoryKVStoreImpl() self.openai_responses_impl: OpenAIResponsesImpl | None = None @@ -135,6 
+137,7 @@ class MetaReferenceAgentsImpl(Agents): ), created_at=agent_info.created_at, policy=self.policy, + telemetry_enabled=self.telemetry_enabled, ) async def create_agent_session( diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index 8ccdcb0e1..245203f10 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -269,7 +269,7 @@ class OpenAIResponsesImpl: response_tools=tools, temperature=temperature, response_format=response_format, - inputs=input, + inputs=all_input, ) # Create orchestrator and delegate streaming logic diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 0bb524f5c..895d13a7f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -97,6 +97,8 @@ class StreamingResponseOrchestrator: self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {} # Track final messages after all tool executions self.final_messages: list[OpenAIMessageParam] = [] + # mapping for annotations + self.citation_files: dict[str, str] = {} async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]: # Initialize output messages @@ -126,6 +128,7 @@ class StreamingResponseOrchestrator: # Text is the default response format for chat completion so don't need to pass it # (some providers don't support non-empty response_format when tools are present) response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format + logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}") completion_result = await self.inference_api.openai_chat_completion( model=self.ctx.model, messages=messages, @@ -160,7 +163,7 @@ class StreamingResponseOrchestrator: # Handle choices with no tool calls for choice in current_response.choices: if not (choice.message.tool_calls and self.ctx.response_tools): - output_messages.append(await convert_chat_choice_to_response_message(choice)) + output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files)) # Execute tool calls and coordinate results async for stream_event in self._coordinate_tool_execution( @@ -172,6 +175,8 @@ class StreamingResponseOrchestrator: ): yield stream_event + messages = next_turn_messages + if not function_tool_calls and not non_function_tool_calls: break @@ -184,9 +189,7 @@ class StreamingResponseOrchestrator: logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}") break - messages = next_turn_messages - - self.final_messages = messages.copy() + [current_response.choices[0].message] + self.final_messages = messages.copy() # Create final response final_response = OpenAIResponseObject( @@ -211,6 +214,8 @@ class StreamingResponseOrchestrator: for choice in current_response.choices: next_turn_messages.append(choice.message) + logger.debug(f"Choice message content: {choice.message.content}") + logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}") if choice.message.tool_calls and self.ctx.response_tools: for tool_call in choice.message.tool_calls: @@ -227,9 +232,11 @@ class StreamingResponseOrchestrator: 
non_function_tool_calls.append(tool_call) else: logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}") + next_turn_messages.pop() else: logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}") approvals.append(tool_call) + next_turn_messages.pop() else: non_function_tool_calls.append(tool_call) @@ -470,6 +477,8 @@ class StreamingResponseOrchestrator: tool_call_log = result.final_output_message tool_response_message = result.final_input_message self.sequence_number = result.sequence_number + if result.citation_files: + self.citation_files.update(result.citation_files) if tool_call_log: output_messages.append(tool_call_log) diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py index b028c018b..b33b47454 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -94,7 +94,10 @@ class ToolExecutor: # Yield the final result yield ToolExecutionResult( - sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message + sequence_number=sequence_number, + final_output_message=output_message, + final_input_message=input_message, + citation_files=result.metadata.get("citation_files") if result and result.metadata else None, ) async def _execute_knowledge_search_via_vector_store( @@ -129,8 +132,6 @@ class ToolExecutor: for results in all_results: search_results.extend(results) - # Convert search results to tool result format matching memory.py - # Format the results as interleaved content similar to memory.py content_items = [] content_items.append( TextContentItem( @@ -138,27 +139,58 @@ class ToolExecutor: ) ) + unique_files = set() for i, result_item in enumerate(search_results): chunk_text = result_item.content[0].text if result_item.content else "" - metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}" + # Get file_id from attributes if result_item.file_id is empty + file_id = result_item.file_id or ( + result_item.attributes.get("document_id") if result_item.attributes else None + ) + metadata_text = f"document_id: {file_id}, score: {result_item.score}" if result_item.attributes: metadata_text += f", attributes: {result_item.attributes}" - text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n" + + text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n" content_items.append(TextContentItem(text=text_content)) + unique_files.add(file_id) content_items.append(TextContentItem(text="END of knowledge_search tool results.\n")) + + citation_instruction = "" + if unique_files: + citation_instruction = ( + " Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). " + "Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)." + ) + content_items.append( TextContentItem( - text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n', + text=f'The above results were retrieved to help answer the user\'s query: "{query}". 
Use them as supporting information only in answering this query.{citation_instruction}\n', ) ) + # handling missing attributes for old versions + citation_files = {} + for result in search_results: + file_id = result.file_id + if not file_id and result.attributes: + file_id = result.attributes.get("document_id") + + filename = result.filename + if not filename and result.attributes: + filename = result.attributes.get("filename") + if not filename: + filename = "unknown" + + citation_files[file_id] = filename + return ToolInvocationResult( content=content_items, metadata={ "document_ids": [r.file_id for r in search_results], "chunks": [r.content[0].text if r.content else "" for r in search_results], "scores": [r.score for r in search_results], + "citation_files": citation_files, }, ) diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/types.py b/llama_stack/providers/inline/agents/meta_reference/responses/types.py index d3b5a16bd..fd5f44242 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/types.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/types.py @@ -27,6 +27,7 @@ class ToolExecutionResult(BaseModel): sequence_number: int final_output_message: OpenAIResponseOutput | None = None final_input_message: OpenAIMessageParam | None = None + citation_files: dict[str, str] | None = None @dataclass diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py index 310a88298..5b013b9c4 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py @@ -4,9 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import re import uuid from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseAnnotationFileCitation, OpenAIResponseInput, OpenAIResponseInputFunctionToolCallOutput, OpenAIResponseInputMessageContent, @@ -45,7 +47,9 @@ from llama_stack.apis.inference import ( ) -async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage: +async def convert_chat_choice_to_response_message( + choice: OpenAIChoice, citation_files: dict[str, str] | None = None +) -> OpenAIResponseMessage: """Convert an OpenAI Chat Completion choice into an OpenAI Response output message.""" output_content = "" if isinstance(choice.message.content, str): @@ -57,9 +61,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}" ) + annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {}) + return OpenAIResponseMessage( id=f"msg_{uuid.uuid4()}", - content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)], + content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)], status="completed", role="assistant", ) @@ -200,6 +206,53 @@ async def get_message_type_by_role(role: str): return role_to_type.get(role) +def _extract_citations_from_text( + text: str, citation_files: dict[str, str] +) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]: + """Extract citation markers from text and create annotations + + Args: + text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A] + citation_files: Dictionary mapping file_id to filename + + Returns: + Tuple of (annotations_list, clean_text_without_markers) + """ + file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>") + + annotations = [] + parts = [] + total_len = 0 + last_end = 0 + + for m in file_id_regex.finditer(text): + # segment before the marker + prefix = text[last_end : m.start()] + + # drop one space if it exists (since marker is at sentence end) + if prefix.endswith(" "): + prefix = prefix[:-1] + + parts.append(prefix) + total_len += len(prefix) + + fid = m.group(1) + if fid in citation_files: + annotations.append( + OpenAIResponseAnnotationFileCitation( + file_id=fid, + filename=citation_files[fid], + index=total_len, # index points to punctuation + ) + ) + + last_end = m.end() + + parts.append(text[last_end:]) + cleaned_text = "".join(parts) + return annotations, cleaned_text + + def is_function_tool_call( tool_call: OpenAIChatCompletionToolCall, tools: list[OpenAIResponseInputTool], diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index c8499a9b8..3ccfd0bcb 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -8,8 +8,6 @@ import asyncio import base64 import io import mimetypes -import secrets -import string from typing import Any import httpx @@ -52,10 +50,6 @@ from .context_retriever import generate_rag_query log = get_logger(name=__name__, category="tool_runtime") -def make_random_string(length: int = 8): - return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) - - async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]: """Get raw binary data and mime type from a RAGDocument for file upload.""" if isinstance(doc.content, URL): @@ -331,5 +325,8 @@ class
MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti return ToolInvocationResult( content=result.content or [], - metadata=result.metadata, + metadata={ + **(result.metadata or {}), + "citation_files": getattr(result, "citation_files", None), + }, ) diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 405c134e5..5a456c7c9 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -225,8 +225,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr await self.initialize_openai_vector_stores() async def shutdown(self) -> None: - # Cleanup if needed - pass + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def health(self) -> HealthResponse: """ diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 26231a9b7..a433257b2 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -434,8 +434,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc await self.initialize_openai_vector_stores() async def shutdown(self) -> None: - # nothing to do since we don't maintain a persistent connection - pass + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def list_vector_dbs(self) -> list[VectorDB]: return [v.vector_db for v in self.cache.values()] diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py index 57110d129..bc46b4de2 100644 --- a/llama_stack/providers/registry/agents.py +++ b/llama_stack/providers/registry/agents.py @@ -36,6 +36,9 @@ def available_providers() -> list[ProviderSpec]: Api.tool_runtime, Api.tool_groups, ], + optional_api_dependencies=[ + Api.telemetry, + ], description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.", ), ] diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index bf6a09b6c..f89565892 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -268,7 +268,7 @@ Available Models: api=Api.inference, adapter_type="watsonx", provider_type="remote::watsonx", - pip_packages=["ibm_watsonx_ai"], + pip_packages=["litellm"], module="llama_stack.providers.remote.inference.watsonx", config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index ad8c31dfd..39dc7fccd 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -11,6 +11,7 @@ from llama_stack.providers.datatypes import ( ProviderSpec, RemoteProviderSpec, ) +from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS def available_providers() -> list[ProviderSpec]: @@ -18,9 +19,8 @@ def available_providers() -> list[ProviderSpec]: InlineProviderSpec( api=Api.tool_runtime, provider_type="inline::rag-runtime", - pip_packages=[ - "chardet", - "pypdf", + pip_packages=DEFAULT_VECTOR_IO_DEPS + + [ "tqdm", "numpy", "scikit-learn", diff 
--git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index ebab7aaf9..da2a68535 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -12,13 +12,16 @@ from llama_stack.providers.datatypes import ( RemoteProviderSpec, ) +# Common dependencies for all vector IO providers that support document processing +DEFAULT_VECTOR_IO_DEPS = ["chardet", "pypdf"] + def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( api=Api.vector_io, provider_type="inline::meta-reference", - pip_packages=["faiss-cpu"], + pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS, module="llama_stack.providers.inline.vector_io.faiss", config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig", deprecation_warning="Please use the `inline::faiss` provider instead.", @@ -29,7 +32,7 @@ def available_providers() -> list[ProviderSpec]: InlineProviderSpec( api=Api.vector_io, provider_type="inline::faiss", - pip_packages=["faiss-cpu"], + pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS, module="llama_stack.providers.inline.vector_io.faiss", config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig", api_dependencies=[Api.inference], @@ -82,7 +85,7 @@ more details about Faiss in general. InlineProviderSpec( api=Api.vector_io, provider_type="inline::sqlite-vec", - pip_packages=["sqlite-vec"], + pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS, module="llama_stack.providers.inline.vector_io.sqlite_vec", config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig", api_dependencies=[Api.inference], @@ -289,7 +292,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f InlineProviderSpec( api=Api.vector_io, provider_type="inline::sqlite_vec", - pip_packages=["sqlite-vec"], + pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS, module="llama_stack.providers.inline.vector_io.sqlite_vec", config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig", deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.", @@ -303,7 +306,7 @@ Please refer to the sqlite-vec provider documentation. 
api=Api.vector_io, adapter_type="chromadb", provider_type="remote::chromadb", - pip_packages=["chromadb-client"], + pip_packages=["chromadb-client"] + DEFAULT_VECTOR_IO_DEPS, module="llama_stack.providers.remote.vector_io.chroma", config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig", api_dependencies=[Api.inference], @@ -345,7 +348,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti InlineProviderSpec( api=Api.vector_io, provider_type="inline::chromadb", - pip_packages=["chromadb"], + pip_packages=["chromadb"] + DEFAULT_VECTOR_IO_DEPS, module="llama_stack.providers.inline.vector_io.chroma", config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig", api_dependencies=[Api.inference], @@ -389,7 +392,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti api=Api.vector_io, adapter_type="pgvector", provider_type="remote::pgvector", - pip_packages=["psycopg2-binary"], + pip_packages=["psycopg2-binary"] + DEFAULT_VECTOR_IO_DEPS, module="llama_stack.providers.remote.vector_io.pgvector", config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig", api_dependencies=[Api.inference], @@ -500,7 +503,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de api=Api.vector_io, adapter_type="weaviate", provider_type="remote::weaviate", - pip_packages=["weaviate-client>=4.16.5"], + pip_packages=["weaviate-client>=4.16.5"] + DEFAULT_VECTOR_IO_DEPS, module="llama_stack.providers.remote.vector_io.weaviate", config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig", provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData", @@ -541,7 +544,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more InlineProviderSpec( api=Api.vector_io, provider_type="inline::qdrant", - pip_packages=["qdrant-client"], + pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS, module="llama_stack.providers.inline.vector_io.qdrant", config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig", api_dependencies=[Api.inference], @@ -594,7 +597,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta api=Api.vector_io, adapter_type="qdrant", provider_type="remote::qdrant", - pip_packages=["qdrant-client"], + pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS, module="llama_stack.providers.remote.vector_io.qdrant", config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig", api_dependencies=[Api.inference], @@ -607,7 +610,7 @@ Please refer to the inline provider documentation. 
api=Api.vector_io, adapter_type="milvus", provider_type="remote::milvus", - pip_packages=["pymilvus>=2.4.10"], + pip_packages=["pymilvus>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS, module="llama_stack.providers.remote.vector_io.milvus", config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig", api_dependencies=[Api.inference], @@ -813,7 +816,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi InlineProviderSpec( api=Api.vector_io, provider_type="inline::milvus", - pip_packages=["pymilvus[milvus-lite]>=2.4.10"], + pip_packages=["pymilvus[milvus-lite]>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS, module="llama_stack.providers.inline.vector_io.milvus", config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig", api_dependencies=[Api.inference], diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index f4ad1be94..200b36171 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -41,9 +41,6 @@ class DatabricksInferenceAdapter(OpenAIMixin): ).serving_endpoints.list() # TODO: this is not async ] - async def should_refresh_models(self) -> bool: - return False - async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py deleted file mode 100644 index 0b0d7fcf3..000000000 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import warnings -from collections.abc import AsyncGenerator -from typing import Any - -from openai import AsyncStream -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, -) -from openai.types.completion import Completion as OpenAICompletion -from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs - -from llama_stack.apis.inference import ( - ChatCompletionRequest, - CompletionRequest, - CompletionResponse, - CompletionResponseStreamChunk, - GreedySamplingStrategy, - JsonSchemaResponseFormat, - TokenLogProbs, - TopKSamplingStrategy, - TopPSamplingStrategy, -) -from llama_stack.providers.utils.inference.openai_compat import ( - _convert_openai_finish_reason, - convert_message_to_openai_dict_new, - convert_tooldef_to_openai_tool, -) - - -async def convert_chat_completion_request( - request: ChatCompletionRequest, - n: int = 1, -) -> dict: - """ - Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. 
- """ - # model -> model - # messages -> messages - # sampling_params TODO(mattf): review strategy - # strategy=greedy -> nvext.top_k = -1, temperature = temperature - # strategy=top_p -> nvext.top_k = -1, top_p = top_p - # strategy=top_k -> nvext.top_k = top_k - # temperature -> temperature - # top_p -> top_p - # top_k -> nvext.top_k - # max_tokens -> max_tokens - # repetition_penalty -> nvext.repetition_penalty - # response_format -> GrammarResponseFormat TODO(mf) - # response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema - # tools -> tools - # tool_choice ("auto", "required") -> tool_choice - # tool_prompt_format -> TBD - # stream -> stream - # logprobs -> logprobs - - if request.response_format and not isinstance(request.response_format, JsonSchemaResponseFormat): - raise ValueError( - f"Unsupported response format: {request.response_format}. Only JsonSchemaResponseFormat is supported." - ) - - nvext = {} - payload: dict[str, Any] = dict( - model=request.model, - messages=[await convert_message_to_openai_dict_new(message) for message in request.messages], - stream=request.stream, - n=n, - extra_body=dict(nvext=nvext), - extra_headers={ - b"User-Agent": b"llama-stack: nvidia-inference-adapter", - }, - ) - - if request.response_format: - # server bug - setting guided_json changes the behavior of response_format resulting in an error - # payload.update(response_format="json_object") - nvext.update(guided_json=request.response_format.json_schema) - - if request.tools: - payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools]) - if request.tool_config.tool_choice: - payload.update( - tool_choice=request.tool_config.tool_choice.value - ) # we cannot include tool_choice w/o tools, server will complain - - if request.logprobs: - payload.update(logprobs=True) - payload.update(top_logprobs=request.logprobs.top_k) - - if request.sampling_params: - nvext.update(repetition_penalty=request.sampling_params.repetition_penalty) - - if request.sampling_params.max_tokens: - payload.update(max_tokens=request.sampling_params.max_tokens) - - strategy = request.sampling_params.strategy - if isinstance(strategy, TopPSamplingStrategy): - nvext.update(top_k=-1) - payload.update(top_p=strategy.top_p) - payload.update(temperature=strategy.temperature) - elif isinstance(strategy, TopKSamplingStrategy): - if strategy.top_k != -1 and strategy.top_k < 1: - warnings.warn("top_k must be -1 or >= 1", stacklevel=2) - nvext.update(top_k=strategy.top_k) - elif isinstance(strategy, GreedySamplingStrategy): - nvext.update(top_k=-1) - else: - raise ValueError(f"Unsupported sampling strategy: {strategy}") - - return payload - - -def convert_completion_request( - request: CompletionRequest, - n: int = 1, -) -> dict: - """ - Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. 
- """ - # model -> model - # prompt -> prompt - # sampling_params TODO(mattf): review strategy - # strategy=greedy -> nvext.top_k = -1, temperature = temperature - # strategy=top_p -> nvext.top_k = -1, top_p = top_p - # strategy=top_k -> nvext.top_k = top_k - # temperature -> temperature - # top_p -> top_p - # top_k -> nvext.top_k - # max_tokens -> max_tokens - # repetition_penalty -> nvext.repetition_penalty - # response_format -> nvext.guided_json - # stream -> stream - # logprobs.top_k -> logprobs - - nvext = {} - payload: dict[str, Any] = dict( - model=request.model, - prompt=request.content, - stream=request.stream, - extra_body=dict(nvext=nvext), - extra_headers={ - b"User-Agent": b"llama-stack: nvidia-inference-adapter", - }, - n=n, - ) - - if request.response_format: - # this is not openai compliant, it is a nim extension - nvext.update(guided_json=request.response_format.json_schema) - - if request.logprobs: - payload.update(logprobs=request.logprobs.top_k) - - if request.sampling_params: - nvext.update(repetition_penalty=request.sampling_params.repetition_penalty) - - if request.sampling_params.max_tokens: - payload.update(max_tokens=request.sampling_params.max_tokens) - - if request.sampling_params.strategy == "top_p": - nvext.update(top_k=-1) - payload.update(top_p=request.sampling_params.top_p) - elif request.sampling_params.strategy == "top_k": - if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1: - warnings.warn("top_k must be -1 or >= 1", stacklevel=2) - nvext.update(top_k=request.sampling_params.top_k) - elif request.sampling_params.strategy == "greedy": - nvext.update(top_k=-1) - payload.update(temperature=request.sampling_params.temperature) - - return payload - - -def _convert_openai_completion_logprobs( - logprobs: OpenAICompletionLogprobs | None, -) -> list[TokenLogProbs] | None: - """ - Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs. - """ - if not logprobs: - return None - - return [TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs] - - -def convert_openai_completion_choice( - choice: OpenAIChoice, -) -> CompletionResponse: - """ - Convert an OpenAI Completion Choice into a CompletionResponse. - """ - return CompletionResponse( - content=choice.text, - stop_reason=_convert_openai_finish_reason(choice.finish_reason), - logprobs=_convert_openai_completion_logprobs(choice.logprobs), - ) - - -async def convert_openai_completion_stream( - stream: AsyncStream[OpenAICompletion], -) -> AsyncGenerator[CompletionResponse, None]: - """ - Convert a stream of OpenAI Completions into a stream - of ChatCompletionResponseStreamChunks. - """ - async for chunk in stream: - choice = chunk.choices[0] - yield CompletionResponseStreamChunk( - delta=choice.text, - stop_reason=_convert_openai_finish_reason(choice.finish_reason), - logprobs=_convert_openai_completion_logprobs(choice.logprobs), - ) diff --git a/llama_stack/providers/remote/inference/nvidia/utils.py b/llama_stack/providers/remote/inference/nvidia/utils.py index b8431e859..46ee939d9 100644 --- a/llama_stack/providers/remote/inference/nvidia/utils.py +++ b/llama_stack/providers/remote/inference/nvidia/utils.py @@ -4,53 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import httpx - -from llama_stack.log import get_logger - from . 
import NVIDIAConfig -logger = get_logger(name=__name__, category="inference::nvidia") - def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: return "integrate.api.nvidia.com" in config.url - - -async def _get_health(url: str) -> tuple[bool, bool]: - """ - Query {url}/v1/health/{live,ready} to check if the server is running and ready - - Args: - url (str): URL of the server - - Returns: - Tuple[bool, bool]: (is_live, is_ready) - """ - async with httpx.AsyncClient() as client: - live = await client.get(f"{url}/v1/health/live") - ready = await client.get(f"{url}/v1/health/ready") - return live.status_code == 200, ready.status_code == 200 - - -async def check_health(config: NVIDIAConfig) -> None: - """ - Check if the server is running and ready - - Args: - url (str): URL of the server - - Raises: - RuntimeError: If the server is not running or ready - """ - if not _is_nvidia_hosted(config): - logger.info("Checking NVIDIA NIM health...") - try: - is_live, is_ready = await _get_health(config.url) - if not is_live: - raise ConnectionError("NVIDIA NIM is not running") - if not is_ready: - raise ConnectionError("NVIDIA NIM is not ready") - # TODO(mf): should we wait for the server to be ready? - except httpx.ConnectError as e: - raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py index d2f104e1e..1e4ce9113 100644 --- a/llama_stack/providers/remote/inference/ollama/config.py +++ b/llama_stack/providers/remote/inference/ollama/config.py @@ -6,8 +6,6 @@ from typing import Any -from pydantic import Field - from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig DEFAULT_OLLAMA_URL = "http://localhost:11434" @@ -15,10 +13,6 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434" class OllamaImplConfig(RemoteInferenceProviderConfig): url: str = DEFAULT_OLLAMA_URL - refresh_models: bool = Field( - default=False, - description="Whether to refresh models periodically", - ) @classmethod def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index e5b08997c..67d0caa54 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -72,9 +72,6 @@ class OllamaInferenceAdapter(OpenAIMixin): f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal" ) - async def should_refresh_models(self) -> bool: - return self.config.refresh_models - async def health(self) -> HealthResponse: """ Performs a health check by verifying connectivity to the Ollama server. 
diff --git a/llama_stack/providers/remote/inference/runpod/__init__.py b/llama_stack/providers/remote/inference/runpod/__init__.py index 69bf95046..d1fd2b718 100644 --- a/llama_stack/providers/remote/inference/runpod/__init__.py +++ b/llama_stack/providers/remote/inference/runpod/__init__.py @@ -11,6 +11,6 @@ async def get_adapter_impl(config: RunpodImplConfig, _deps): from .runpod import RunpodInferenceAdapter assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}" - impl = RunpodInferenceAdapter(config) + impl = RunpodInferenceAdapter(config=config) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 08652f8c0..f752740e5 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -4,69 +4,86 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Any -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.inference import OpenAIEmbeddingsResponse - -# from llama_stack.providers.datatypes import ModelsProtocolPrivate -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry -from llama_stack.providers.utils.inference.openai_compat import ( - get_sampling_options, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_prompt, +from llama_stack.apis.inference import ( + OpenAIMessageParam, + OpenAIResponseFormatParam, ) +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import RunpodImplConfig -# https://docs.runpod.io/serverless/vllm/overview#compatible-models -# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures -RUNPOD_SUPPORTED_MODELS = { - "Llama3.1-8B": "meta-llama/Llama-3.1-8B", - "Llama3.1-70B": "meta-llama/Llama-3.1-70B", - "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B", - "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8", - "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B", - "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct", - "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct", - "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct", - "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8", - "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct", - "Llama3.2-1B": "meta-llama/Llama-3.2-1B", - "Llama3.2-3B": "meta-llama/Llama-3.2-3B", -} -SAFETY_MODELS_ENTRIES = [] +class RunpodInferenceAdapter(OpenAIMixin): + """ + Adapter for RunPod's OpenAI-compatible API endpoints. + Supports VLLM for serverless endpoint self-hosted or public endpoints. 
+ Can work with any runpod endpoints that support OpenAI-compatible API + """ -# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template -MODEL_ENTRIES = [ - build_hf_repo_model_entry(provider_model_id, model_descriptor) - for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items() -] + SAFETY_MODELS_ENTRIES + config: RunpodImplConfig + def get_api_key(self) -> str: + """Get API key for OpenAI client.""" + return self.config.api_token -class RunpodInferenceAdapter( - ModelRegistryHelper, - Inference, -): - def __init__(self, config: RunpodImplConfig) -> None: - ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS) - self.config = config + def get_base_url(self) -> str: + """Get base URL for OpenAI client.""" + return self.config.url - def _get_params(self, request: ChatCompletionRequest) -> dict: - return { - "model": self.map_to_provider_model(request.model), - "prompt": chat_completion_request_to_prompt(request), - "stream": request.stream, - **get_sampling_options(request.sampling_params), - } - - async def openai_embeddings( + async def openai_chat_completion( self, model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() + ): + """Override to add RunPod-specific stream_options requirement.""" + if stream and not stream_options: + stream_options = {"include_usage": True} + + return await super().openai_chat_completion( + model=model, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index fbefe630f..224de6721 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -63,9 +63,6 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData): # Together's /v1/models is not compatible with OpenAI's /v1/models. 
Together support ticket #13355 -> will not fix, use Together's own client return [m.id for m in await self._get_client().models.list()] - async def should_refresh_models(self) -> bool: - return True - async def openai_embeddings( self, model: str, diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py index 86ef3fe26..87c5408d3 100644 --- a/llama_stack/providers/remote/inference/vllm/config.py +++ b/llama_stack/providers/remote/inference/vllm/config.py @@ -30,10 +30,6 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig): default=True, description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.", ) - refresh_models: bool = Field( - default=False, - description="Whether to refresh models periodically", - ) @field_validator("tls_verify") @classmethod diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 4e7884cd2..310eaf7b6 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -53,10 +53,6 @@ class VLLMInferenceAdapter(OpenAIMixin): "You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM." ) - async def should_refresh_models(self) -> bool: - # Strictly respecting the refresh_models directive - return self.config.refresh_models - async def health(self) -> HealthResponse: """ Performs a health check by verifying connectivity to the remote vLLM server. diff --git a/llama_stack/providers/remote/inference/watsonx/__init__.py b/llama_stack/providers/remote/inference/watsonx/__init__.py index e59e873b6..35e74a720 100644 --- a/llama_stack/providers/remote/inference/watsonx/__init__.py +++ b/llama_stack/providers/remote/inference/watsonx/__init__.py @@ -4,19 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from llama_stack.apis.inference import Inference - from .config import WatsonXConfig -async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference: - # import dynamically so `llama stack build` does not fail due to missing dependencies +async def get_adapter_impl(config: WatsonXConfig, _deps): + # import dynamically so the import is used only when it is needed from .watsonx import WatsonXInferenceAdapter - if not isinstance(config, WatsonXConfig): - raise RuntimeError(f"Unexpected config type: {type(config)}") adapter = WatsonXInferenceAdapter(config) return adapter - - -__all__ = ["get_adapter_impl", "WatsonXConfig"] diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py index 4bc0173c4..9e98d4003 100644 --- a/llama_stack/providers/remote/inference/watsonx/config.py +++ b/llama_stack/providers/remote/inference/watsonx/config.py @@ -7,16 +7,18 @@ import os from typing import Any -from pydantic import BaseModel, Field, SecretStr +from pydantic import BaseModel, ConfigDict, Field, SecretStr from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type class WatsonXProviderDataValidator(BaseModel): - url: str - api_key: str - project_id: str + model_config = ConfigDict( + from_attributes=True, + extra="forbid", + ) + watsonx_api_key: str | None @json_schema_type @@ -25,13 +27,17 @@ class WatsonXConfig(RemoteInferenceProviderConfig): default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"), description="A base url for accessing the watsonx.ai", ) + # This seems like it should be required, but none of the other remote inference + # providers require it, so this is optional here too for consistency. + # The OpenAIConfig uses default=None instead, so this is following that precedent. api_key: SecretStr | None = Field( - default_factory=lambda: os.getenv("WATSONX_API_KEY"), - description="The watsonx API key", + default=None, + description="The watsonx.ai API key", ) + # As above, this is optional here too for consistency. project_id: str | None = Field( - default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"), - description="The Project ID key", + default=None, + description="The watsonx.ai project ID", ) timeout: int = Field( default=60, diff --git a/llama_stack/providers/remote/inference/watsonx/models.py b/llama_stack/providers/remote/inference/watsonx/models.py deleted file mode 100644 index d98f0510a..000000000 --- a/llama_stack/providers/remote/inference/watsonx/models.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from llama_stack.models.llama.sku_types import CoreModelId -from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry - -MODEL_ENTRIES = [ - build_hf_repo_model_entry( - "meta-llama/llama-3-3-70b-instruct", - CoreModelId.llama3_3_70b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-2-13b-chat", - CoreModelId.llama2_13b.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-3-1-70b-instruct", - CoreModelId.llama3_1_70b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-3-1-8b-instruct", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-3-2-11b-vision-instruct", - CoreModelId.llama3_2_11b_vision_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-3-2-1b-instruct", - CoreModelId.llama3_2_1b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-3-2-3b-instruct", - CoreModelId.llama3_2_3b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-3-2-90b-vision-instruct", - CoreModelId.llama3_2_90b_vision_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-guard-3-11b-vision", - CoreModelId.llama_guard_3_11b_vision.value, - ), -] diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index fc58691e2..d04472936 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -4,240 +4,120 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import AsyncGenerator, AsyncIterator from typing import Any -from ibm_watsonx_ai.foundation_models import Model -from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams -from openai import AsyncOpenAI +import requests -from llama_stack.apis.inference import ( - ChatCompletionRequest, - CompletionRequest, - GreedySamplingStrategy, - Inference, - OpenAIChatCompletion, - OpenAIChatCompletionChunk, - OpenAICompletion, - OpenAIEmbeddingsResponse, - OpenAIMessageParam, - OpenAIResponseFormatParam, - TopKSamplingStrategy, - TopPSamplingStrategy, -) -from llama_stack.log import get_logger -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper -from llama_stack.providers.utils.inference.openai_compat import ( - prepare_openai_completion_params, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_prompt, - completion_request_to_prompt, - request_has_media, -) - -from . import WatsonXConfig -from .models import MODEL_ENTRIES - -logger = get_logger(name=__name__, category="inference::watsonx") +from llama_stack.apis.inference import ChatCompletionRequest +from llama_stack.apis.models import Model +from llama_stack.apis.models.models import ModelType +from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin -# Note on structured output -# WatsonX returns responses with a json embedded into a string. -# Examples: +class WatsonXInferenceAdapter(LiteLLMOpenAIMixin): + _model_cache: dict[str, Model] = {} -# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n -# "first_name": "Michael",\n "last_name": "Jordan",\n'...) 
-# Not even a valid JSON, but we can still extract the JSON from the content + def __init__(self, config: WatsonXConfig): + LiteLLMOpenAIMixin.__init__( + self, + litellm_provider_name="watsonx", + api_key_from_config=config.api_key.get_secret_value() if config.api_key else None, + provider_data_api_key_field="watsonx_api_key", + ) + self.available_models = None + self.config = config -# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan", -# "year_born": "1963", "year_retired": "2003"\\}}$') -# Find the start of the boxed content + def get_base_url(self) -> str: + return self.config.url + async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]: + # Get base parameters from parent + params = await super()._get_params(request) -class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): - def __init__(self, config: WatsonXConfig) -> None: - ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) - - logger.info(f"Initializing watsonx InferenceAdapter({config.url})...") - self._config = config - self._openai_client: AsyncOpenAI | None = None - - self._project_id = self._config.project_id - - def _get_client(self, model_id) -> Model: - config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None - config_url = self._config.url - project_id = self._config.project_id - credentials = {"url": config_url, "apikey": config_api_key} - - return Model(model_id=model_id, credentials=credentials, project_id=project_id) - - def _get_openai_client(self) -> AsyncOpenAI: - if not self._openai_client: - self._openai_client = AsyncOpenAI( - base_url=f"{self._config.url}/openai/v1", - api_key=self._config.api_key, - ) - return self._openai_client - - async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict: - input_dict = {"params": {}} - media_present = request_has_media(request) - llama_model = self.get_llama_model(request.model) - if isinstance(request, ChatCompletionRequest): - input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model) - else: - assert not media_present, "Together does not support media for Completion requests" - input_dict["prompt"] = await completion_request_to_prompt(request) - if request.sampling_params: - if request.sampling_params.strategy: - input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type - if request.sampling_params.max_tokens: - input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens - if request.sampling_params.repetition_penalty: - input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty - - if isinstance(request.sampling_params.strategy, TopPSamplingStrategy): - input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p - input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature - if isinstance(request.sampling_params.strategy, TopKSamplingStrategy): - input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k - if isinstance(request.sampling_params.strategy, GreedySamplingStrategy): - input_dict["params"][GenParams.TEMPERATURE] = 0.0 - - input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"] - - params = { - **input_dict, - } + # Add watsonx.ai specific parameters + params["project_id"] = self.config.project_id + params["time_limit"] = self.config.timeout return params - async def openai_embeddings( - self, - model: str, - input: str | 
list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() + # Copied from OpenAIMixin + async def check_model_availability(self, model: str) -> bool: + """ + Check if a specific model is available from the provider's /v1/models. - async def openai_completion( - self, - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - suffix: str | None = None, - ) -> OpenAICompletion: - model_obj = await self.model_store.get_model(model) - params = await prepare_openai_completion_params( - model=model_obj.provider_resource_id, - prompt=prompt, - best_of=best_of, - echo=echo, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - logprobs=logprobs, - max_tokens=max_tokens, - n=n, - presence_penalty=presence_penalty, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - top_p=top_p, - user=user, - ) - return await self._get_openai_client().completions.create(**params) # type: ignore + :param model: The model identifier to check. + :return: True if the model is available dynamically, False otherwise. 
+ """ + if not self._model_cache: + await self.list_models() + return model in self._model_cache - async def openai_chat_completion( - self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - model_obj = await self.model_store.get_model(model) - params = await prepare_openai_completion_params( - model=model_obj.provider_resource_id, - messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, - ) - if params.get("stream", False): - return self._stream_openai_chat_completion(params) - return await self._get_openai_client().chat.completions.create(**params) # type: ignore + async def list_models(self) -> list[Model] | None: + self._model_cache = {} + models = [] + for model_spec in self._get_model_specs(): + functions = [f["id"] for f in model_spec.get("functions", [])] + # Format: {"embedding_dimension": 1536, "context_length": 8192} - async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator: - # watsonx.ai sometimes adds usage data to the stream - include_usage = False - if params.get("stream_options", None): - include_usage = params["stream_options"].get("include_usage", False) - stream = await self._get_openai_client().chat.completions.create(**params) + # Example of an embedding model: + # {'model_id': 'ibm/granite-embedding-278m-multilingual', + # 'label': 'granite-embedding-278m-multilingual', + # 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768}, + # ... 
+            provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
+            if "embedding" in functions:
+                embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
+                context_length = model_spec["model_limits"]["max_sequence_length"]
+                embedding_metadata = {
+                    "embedding_dimension": embedding_dimension,
+                    "context_length": context_length,
+                }
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata=embedding_metadata,
+                    model_type=ModelType.embedding,
+                )
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
+            if "text_chat" in functions:
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=ModelType.llm,
+                )
+                # In theory, a model could be both an embedding model and a text chat model. In that case the
+                # cache records this text chat Model object (registered last), while the returned list contains
+                # both the embedding Model object and the text chat Model object. That's fine because the cache
+                # is only used for check_model_availability() anyway.
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
+        return models
-        seen_finish_reason = False
-        async for chunk in stream:
-            # Final usage chunk with no choices that the user didn't request, so discard
-            if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
-                break
-            yield chunk
-            for choice in chunk.choices:
-                if choice.finish_reason:
-                    seen_finish_reason = True
-                    break
+    # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
+    # So we need to implement our own method to list models by calling the watsonx.ai API.
+    def _get_model_specs(self) -> list[dict[str, Any]]:
+        """
+        Retrieves foundation model specifications from the watsonx.ai API.
+        """
+        url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
+        headers = {
+            # Note that there is no authorization header. Listing models does not require authentication.
+            "Content-Type": "application/json",
+        }
+
+        response = requests.get(url, headers=headers)
+
+        # --- Process the Response ---
+        # Raise an exception for bad status codes (4xx or 5xx)
+        response.raise_for_status()
+
+        # If the request is successful, parse and return the JSON response.
+ # The response should contain a list of model specifications + response_data = response.json() + if "resources" not in response_data: + raise ValueError("Resources not found in response") + return response_data["resources"] diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 511123d6e..331e5432e 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -167,7 +167,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.openai_vector_stores = await self._load_openai_vector_stores() async def shutdown(self) -> None: - pass + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def register_vector_db( self, diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 0acc90595..029eacfe3 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -349,6 +349,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP async def shutdown(self) -> None: self.client.close() + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def register_vector_db( self, diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index dfdfef6eb..21c388b1d 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -390,6 +390,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco if self.conn is not None: self.conn.close() log.info("Connection to PGVector database server closed") + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def register_vector_db(self, vector_db: VectorDB) -> None: # Persist vector DB metadata in the KV store diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 6b386840c..021938afd 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -191,6 +191,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP async def shutdown(self) -> None: await self.client.close() + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def register_vector_db( self, diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 54ac6f8d3..21df3bc45 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -347,6 +347,8 @@ class WeaviateVectorIOAdapter( async def shutdown(self) -> None: for client in self.client_cache.values(): client.close() + # Clean up mixin resources (file batch tasks) + await super().shutdown() async def register_vector_db( self, diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 6c8f61c3b..6bef97dd5 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms 
described in the LICENSE file in # the root directory of this source tree. +import base64 +import struct from collections.abc import AsyncIterator from typing import Any @@ -16,6 +18,7 @@ from llama_stack.apis.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, + OpenAIEmbeddingData, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage, OpenAIMessageParam, @@ -26,7 +29,6 @@ from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry from llama_stack.providers.utils.inference.openai_compat import ( - b64_encode_openai_embeddings_response, convert_message_to_openai_dict_new, convert_tooldef_to_openai_tool, get_sampling_options, @@ -349,3 +351,28 @@ class LiteLLMOpenAIMixin( return False return model in litellm.models_by_provider[self.litellm_provider_name] + + +def b64_encode_openai_embeddings_response( + response_data: list[dict], encoding_format: str | None = "float" +) -> list[OpenAIEmbeddingData]: + """ + Process the OpenAI embeddings response to encode the embeddings in base64 format if specified. + """ + data = [] + for i, embedding_data in enumerate(response_data): + if encoding_format == "base64": + byte_array = bytearray() + for embedding_value in embedding_data["embedding"]: + byte_array.extend(struct.pack("f", float(embedding_value))) + + response_embedding = base64.b64encode(byte_array).decode("utf-8") + else: + response_embedding = embedding_data["embedding"] + data.append( + OpenAIEmbeddingData( + embedding=response_embedding, + index=i, + ) + ) + return data diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 4913c2e1f..9d42d68c6 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -24,6 +24,10 @@ class RemoteInferenceProviderConfig(BaseModel): default=None, description="List of models that should be registered with the model registry. If None, all models are allowed.", ) + refresh_models: bool = Field( + default=False, + description="Whether to refresh models periodically from the provider", + ) # TODO: this class is more confusing than useful right now. We need to make it diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index d863eb53a..7e465a14c 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -3,9 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import base64 import json -import struct import time import uuid import warnings @@ -103,7 +101,6 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, Message, OpenAIChatCompletion, - OpenAIEmbeddingData, OpenAIMessageParam, OpenAIResponseFormatParam, SamplingParams, @@ -1402,28 +1399,3 @@ def prepare_openai_embeddings_params( params["user"] = user return params - - -def b64_encode_openai_embeddings_response( - response_data: dict, encoding_format: str | None = "float" -) -> list[OpenAIEmbeddingData]: - """ - Process the OpenAI embeddings response to encode the embeddings in base64 format if specified. 
- """ - data = [] - for i, embedding_data in enumerate(response_data): - if encoding_format == "base64": - byte_array = bytearray() - for embedding_value in embedding_data.embedding: - byte_array.extend(struct.pack("f", float(embedding_value))) - - response_embedding = base64.b64encode(byte_array).decode("utf-8") - else: - response_embedding = embedding_data.embedding - data.append( - OpenAIEmbeddingData( - embedding=response_embedding, - index=i, - ) - ) - return data diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 9137013ee..cba7508a2 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -474,17 +474,23 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): async def check_model_availability(self, model: str) -> bool: """ - Check if a specific model is available from the provider's /v1/models. + Check if a specific model is available from the provider's /v1/models or pre-registered. :param model: The model identifier to check. - :return: True if the model is available dynamically, False otherwise. + :return: True if the model is available dynamically or pre-registered, False otherwise. """ + # First check if the model is pre-registered in the model store + if hasattr(self, "model_store") and self.model_store: + if await self.model_store.has_model(model): + return True + + # Then check the provider's dynamic model cache if not self._model_cache: await self.list_models() return model in self._model_cache async def should_refresh_models(self) -> bool: - return False + return self.config.refresh_models # # The model_dump implementations are to avoid serializing the extra fields, diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 0d0aa25a4..c179eba6c 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -293,6 +293,18 @@ class OpenAIVectorStoreMixin(ABC): await self._resume_incomplete_batches() self._last_file_batch_cleanup_time = 0 + async def shutdown(self) -> None: + """Clean up mixin resources including background tasks.""" + # Cancel any running file batch tasks gracefully + tasks_to_cancel = list(self._file_batch_tasks.items()) + for _, task in tasks_to_cancel: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + @abstractmethod async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Delete chunks from a vector store.""" @@ -587,7 +599,7 @@ class OpenAIVectorStoreMixin(ABC): content = self._chunk_to_vector_store_content(chunk) response_data_item = VectorStoreSearchResponse( - file_id=chunk.metadata.get("file_id", ""), + file_id=chunk.metadata.get("document_id", ""), filename=chunk.metadata.get("filename", ""), score=score, attributes=chunk.metadata, @@ -746,12 +758,15 @@ class OpenAIVectorStoreMixin(ABC): content = content_from_data_and_mime_type(content_response.body, mime_type) + chunk_attributes = attributes.copy() + chunk_attributes["filename"] = file_response.filename + chunks = make_overlapped_chunks( file_id, content, max_chunk_size_tokens, chunk_overlap_tokens, - attributes, + chunk_attributes, ) if not chunks: vector_store_file_object.status = "failed" diff --git a/llama_stack/providers/utils/memory/vector_store.py 
b/llama_stack/providers/utils/memory/vector_store.py index 857fbe910..c0534a875 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -20,7 +20,6 @@ from pydantic import BaseModel from llama_stack.apis.common.content_types import ( URL, InterleavedContent, - TextContentItem, ) from llama_stack.apis.tools import RAGDocument from llama_stack.apis.vector_dbs import VectorDB @@ -129,26 +128,6 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en return "" -def concat_interleaved_content(content: list[InterleavedContent]) -> InterleavedContent: - """concatenate interleaved content into a single list. ensure that 'str's are converted to TextContentItem when in a list""" - - ret = [] - - def _process(c): - if isinstance(c, str): - ret.append(TextContentItem(text=c)) - elif isinstance(c, list): - for item in c: - _process(item) - else: - ret.append(c) - - for c in content: - _process(c) - - return ret - - async def content_from_doc(doc: RAGDocument) -> str: if isinstance(doc.content, URL): if doc.content.uri.startswith("data:"): diff --git a/scripts/install.sh b/scripts/install.sh index f6fbc259c..571468dc5 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -221,8 +221,8 @@ fi cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \ --network llama-net \ -p "${PORT}:${PORT}" \ - "${SERVER_IMAGE}" --port "${PORT}" \ - --env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}") + -e OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" \ + "${SERVER_IMAGE}" --port "${PORT}") log "🦙 Starting Llama Stack..." if ! execute_with_log $ENGINE "${cmd[@]}"; then diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh index b009ad696..4ae73f170 100755 --- a/scripts/integration-tests.sh +++ b/scripts/integration-tests.sh @@ -191,9 +191,11 @@ if [[ "$STACK_CONFIG" == *"server:"* ]]; then echo "Llama Stack Server is already running, skipping start" else echo "=== Starting Llama Stack Server ===" - # Set a reasonable log width for better readability in server.log export LLAMA_STACK_LOG_WIDTH=120 - nohup llama stack run ci-tests --image-type venv > server.log 2>&1 & + + # remove "server:" from STACK_CONFIG + stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://') + nohup llama stack run $stack_config > server.log 2>&1 & echo "Waiting for Llama Stack Server to start..." for i in {1..30}; do diff --git a/scripts/telemetry/setup_telemetry.sh b/scripts/telemetry/setup_telemetry.sh index e0b57a354..ecdd56175 100755 --- a/scripts/telemetry/setup_telemetry.sh +++ b/scripts/telemetry/setup_telemetry.sh @@ -16,10 +16,19 @@ set -Eeuo pipefail -CONTAINER_RUNTIME=${CONTAINER_RUNTIME:-docker} -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +if command -v podman &> /dev/null; then + CONTAINER_RUNTIME="podman" +elif command -v docker &> /dev/null; then + CONTAINER_RUNTIME="docker" +else + echo "🚨 Neither Podman nor Docker could be found" + echo "Install Docker: https://docs.docker.com/get-docker/ or Podman: https://podman.io/getting-started/installation" + exit 1 +fi -echo "🚀 Setting up telemetry stack for Llama Stack using Podman..." +echo "🚀 Setting up telemetry stack for Llama Stack using $CONTAINER_RUNTIME..." + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" if ! 
command -v "$CONTAINER_RUNTIME" &> /dev/null; then echo "🚨 $CONTAINER_RUNTIME could not be found" diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 54a9dd72e..a1c3d1e95 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -201,6 +201,12 @@ async def test_models_routing_table(cached_disk_dist_registry): non_existent = await table.get_object_by_identifier("model", "non-existent-model") assert non_existent is None + # Test has_model + assert await table.has_model("test_provider/test-model") + assert await table.has_model("test_provider/test-model-2") + assert not await table.has_model("non-existent-model") + assert not await table.has_model("test_provider/non-existent-model") + await table.unregister_model(model_id="test_provider/test-model") await table.unregister_model(model_id="test_provider/test-model-2") diff --git a/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py b/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py index 187540f82..2698b88c8 100644 --- a/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py +++ b/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py @@ -8,6 +8,7 @@ import pytest from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseAnnotationFileCitation, OpenAIResponseInputFunctionToolCallOutput, OpenAIResponseInputMessageContentImage, OpenAIResponseInputMessageContentText, @@ -35,6 +36,7 @@ from llama_stack.apis.inference import ( OpenAIUserMessageParam, ) from llama_stack.providers.inline.agents.meta_reference.responses.utils import ( + _extract_citations_from_text, convert_chat_choice_to_response_message, convert_response_content_to_chat_content, convert_response_input_to_chat_messages, @@ -340,3 +342,26 @@ class TestIsFunctionToolCall: result = is_function_tool_call(tool_call, tools) assert result is False + + +class TestExtractCitationsFromText: + def test_extract_citations_and_annotations(self): + text = "Start [not-a-file]. New source <|file-abc123|>. " + text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation." + file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"} + + annotations, cleaned_text = _extract_citations_from_text(text, file_mapping) + + expected_annotations = [ + OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30), + OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44), + OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59), + ] + expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation." + + assert cleaned_text == expected_clean_text + assert annotations == expected_annotations + # OpenAI cites at the end of the sentence + assert cleaned_text[expected_annotations[0].index] == "." + assert cleaned_text[expected_annotations[1].index] == "?" + assert cleaned_text[expected_annotations[2].index] == "!" 
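The new `TestExtractCitationsFromText` case above pins down the contract of `_extract_citations_from_text` without showing its body. The sketch below is a minimal illustration consistent with that test only: it assumes citations arrive as inline `<|file-id|>` markers, that bracketed text such as `[not-a-file]` is left untouched, and that markers whose id is missing from the file mapping are silently dropped; the `FileCitation` dataclass and the `extract_citations` name are illustrative stand-ins, not the actual helper in `llama_stack.providers.inline.agents.meta_reference.responses.utils`.

```python
import re
from dataclasses import dataclass


@dataclass
class FileCitation:
    file_id: str
    filename: str
    index: int  # position (in the cleaned text) of the character that follows the removed marker


# Matches an inline citation marker such as "<|file-abc123|>", together with any
# whitespace immediately before it, so that removal leaves clean punctuation behind.
_CITATION_RE = re.compile(r"\s*<\|(file-[^|>]+)\|>")


def extract_citations(text: str, file_mapping: dict[str, str]) -> tuple[list[FileCitation], str]:
    """Strip citation markers from `text` and return (annotations, cleaned_text)."""
    annotations: list[FileCitation] = []
    cleaned_parts: list[str] = []
    cleaned_len = 0
    last_end = 0
    for match in _CITATION_RE.finditer(text):
        # Keep the text before the marker verbatim.
        cleaned_parts.append(text[last_end : match.start()])
        cleaned_len += match.start() - last_end
        file_id = match.group(1)
        if file_id in file_mapping:
            # The annotation index points at the character that now follows the removed
            # marker, typically the punctuation that closes the citing sentence.
            annotations.append(FileCitation(file_id, file_mapping[file_id], cleaned_len))
        last_end = match.end()
    cleaned_parts.append(text[last_end:])
    return annotations, "".join(cleaned_parts)


# Mirrors the unit test: the marker is removed and the annotation lands on the
# sentence-ending punctuation that follows it.
annotations, cleaned = extract_citations("New source <|file-abc123|>.", {"file-abc123": "doc1.pdf"})
assert cleaned == "New source."
assert annotations[0].index == 10 and cleaned[10] == "."
```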
diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py index d30b5b12a..55a6793c2 100644 --- a/tests/unit/providers/inference/test_inference_client_caching.py +++ b/tests/unit/providers/inference/test_inference_client_caching.py @@ -18,6 +18,8 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter from llama_stack.providers.remote.inference.together.config import TogetherImplConfig from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter +from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig +from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter @pytest.mark.parametrize( @@ -58,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_valida {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} ): assert inference_adapter.client.api_key == api_key + + +@pytest.mark.parametrize( + "config_cls,adapter_cls,provider_data_validator", + [ + ( + WatsonXConfig, + WatsonXInferenceAdapter, + "llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator", + ), + ], +) +def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str): + """Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the + assumption that there is an OpenAI-compatible client object.""" + + inference_adapter = adapter_cls(config=config_cls()) + + inference_adapter.__provider_spec__ = MagicMock() + inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator + + for api_key in ["test1", "test2"]: + with request_provider_data_context( + {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} + ): + assert inference_adapter.get_api_key() == api_key diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index 2806f618c..6d6bb20d5 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -186,43 +186,3 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter): assert mock_create_client.call_count == 4 # no cheating assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max" - - -async def test_should_refresh_models(): - """ - Test the should_refresh_models method with different refresh_models configurations. - - This test verifies that: - 1. When refresh_models is True, should_refresh_models returns True regardless of api_token - 2. 
When refresh_models is False, should_refresh_models returns False regardless of api_token - """ - - # Test case 1: refresh_models is True, api_token is None - config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True) - adapter1 = VLLMInferenceAdapter(config=config1) - result1 = await adapter1.should_refresh_models() - assert result1 is True, "should_refresh_models should return True when refresh_models is True" - - # Test case 2: refresh_models is True, api_token is empty string - config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True) - adapter2 = VLLMInferenceAdapter(config=config2) - result2 = await adapter2.should_refresh_models() - assert result2 is True, "should_refresh_models should return True when refresh_models is True" - - # Test case 3: refresh_models is True, api_token is "fake" (default) - config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True) - adapter3 = VLLMInferenceAdapter(config=config3) - result3 = await adapter3.should_refresh_models() - assert result3 is True, "should_refresh_models should return True when refresh_models is True" - - # Test case 4: refresh_models is True, api_token is real token - config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True) - adapter4 = VLLMInferenceAdapter(config=config4) - result4 = await adapter4.should_refresh_models() - assert result4 is True, "should_refresh_models should return True when refresh_models is True" - - # Test case 5: refresh_models is False, api_token is real token - config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False) - adapter5 = VLLMInferenceAdapter(config=config5) - result5 = await adapter5.should_refresh_models() - assert result5 is False, "should_refresh_models should return False when refresh_models is False" diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index ac4c29fea..ad9406951 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -44,11 +44,12 @@ def mixin(): config = RemoteInferenceProviderConfig() mixin_instance = OpenAIMixinImpl(config=config) - # just enough to satisfy _get_provider_model_id calls - mock_model_store = MagicMock() + # Mock model_store with async methods + mock_model_store = AsyncMock() mock_model = MagicMock() mock_model.provider_resource_id = "test-provider-resource-id" mock_model_store.get_model = AsyncMock(return_value=mock_model) + mock_model_store.has_model = AsyncMock(return_value=False) # Default to False, tests can override mixin_instance.model_store = mock_model_store return mixin_instance @@ -189,6 +190,40 @@ class TestOpenAIMixinCheckModelAvailability: assert len(mixin._model_cache) == 3 + async def test_check_model_availability_with_pre_registered_model( + self, mixin, mock_client_with_models, mock_client_context + ): + """Test that check_model_availability returns True for pre-registered models in model_store""" + # Mock model_store.has_model to return True for a specific model + mock_model_store = AsyncMock() + mock_model_store.has_model = AsyncMock(return_value=True) + mixin.model_store = mock_model_store + + # Test that pre-registered model is found without calling the provider's API + with mock_client_context(mixin, mock_client_with_models): + 
mock_client_with_models.models.list.assert_not_called() + assert await mixin.check_model_availability("pre-registered-model") + # Should not call the provider's list_models since model was found in store + mock_client_with_models.models.list.assert_not_called() + mock_model_store.has_model.assert_called_once_with("pre-registered-model") + + async def test_check_model_availability_fallback_to_provider_when_not_in_store( + self, mixin, mock_client_with_models, mock_client_context + ): + """Test that check_model_availability falls back to provider when model not in store""" + # Mock model_store.has_model to return False + mock_model_store = AsyncMock() + mock_model_store.has_model = AsyncMock(return_value=False) + mixin.model_store = mock_model_store + + # Test that it falls back to provider's model cache + with mock_client_context(mixin, mock_client_with_models): + mock_client_with_models.models.list.assert_not_called() + assert await mixin.check_model_availability("some-mock-model-id") + # Should call the provider's list_models since model was not found in store + mock_client_with_models.models.list.assert_called_once() + mock_model_store.has_model.assert_called_once_with("some-mock-model-id") + class TestOpenAIMixinCacheBehavior: """Test cases for cache behavior and edge cases""" @@ -466,10 +501,16 @@ class TestOpenAIMixinModelRegistration: assert result is None async def test_should_refresh_models(self, mixin): - """Test should_refresh_models method (should always return False)""" + """Test should_refresh_models method returns config value""" + # Default config has refresh_models=False result = await mixin.should_refresh_models() assert result is False + config_with_refresh = RemoteInferenceProviderConfig(refresh_models=True) + mixin_with_refresh = OpenAIMixinImpl(config=config_with_refresh) + result_with_refresh = await mixin_with_refresh.should_refresh_models() + assert result_with_refresh is True + async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context): """Test that errors from provider API are properly propagated during registration""" model = Model( diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 70ace695e..d122f9323 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -145,10 +145,10 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory): @pytest.fixture -async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension): +async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension): config = SQLiteVectorIOConfig( db_path=sqlite_vec_db_path, - kvstore=SqliteKVStoreConfig(), + kvstore=unique_kvstore_config, ) adapter = SQLiteVecVectorIOAdapter( config=config, @@ -187,10 +187,10 @@ async def milvus_vec_index(milvus_vec_db_path, embedding_dimension): @pytest.fixture -async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api): +async def milvus_vec_adapter(milvus_vec_db_path, unique_kvstore_config, mock_inference_api): config = MilvusVectorIOConfig( db_path=milvus_vec_db_path, - kvstore=SqliteKVStoreConfig(), + kvstore=unique_kvstore_config, ) adapter = MilvusVectorIOAdapter( config=config, @@ -264,10 +264,10 @@ async def chroma_vec_index(chroma_vec_db_path, embedding_dimension): @pytest.fixture -async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension): +async def 
chroma_vec_adapter(chroma_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension): config = ChromaVectorIOConfig( db_path=chroma_vec_db_path, - kvstore=SqliteKVStoreConfig(), + kvstore=unique_kvstore_config, ) adapter = ChromaVectorIOAdapter( config=config, @@ -296,12 +296,12 @@ def qdrant_vec_db_path(tmp_path_factory): @pytest.fixture -async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension): +async def qdrant_vec_adapter(qdrant_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension): import uuid config = QdrantVectorIOConfig( db_path=qdrant_vec_db_path, - kvstore=SqliteKVStoreConfig(), + kvstore=unique_kvstore_config, ) adapter = QdrantVectorIOAdapter( config=config, @@ -386,14 +386,14 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection): @pytest.fixture -async def pgvector_vec_adapter(mock_inference_api, embedding_dimension): +async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension): config = PGVectorVectorIOConfig( host="localhost", port=5432, db="test_db", user="test_user", password="test_password", - kvstore=SqliteKVStoreConfig(), + kvstore=unique_kvstore_config, ) adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None) @@ -476,7 +476,7 @@ async def weaviate_vec_index(weaviate_vec_db_path): @pytest.fixture -async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension): +async def weaviate_vec_adapter(weaviate_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension): import pytest_socket import weaviate @@ -492,7 +492,7 @@ async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embeddi config = WeaviateVectorIOConfig( weaviate_cluster_url="localhost:8080", weaviate_api_key=None, - kvstore=SqliteKVStoreConfig(), + kvstore=unique_kvstore_config, ) adapter = WeaviateVectorIOAdapter( config=config, diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 4ea4a20b9..c1f834d5d 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -125,8 +125,15 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry): provider_resource_id="test_vector_db_2", provider_id="baz", # Same provider_id ) - await cached_disk_dist_registry.register(duplicate_vector_db) + # Now we expect a ValueError to be raised for duplicate registration + with pytest.raises( + ValueError, + match=r"Provider 'baz' is already registered.*Unregister the existing provider first before registering it again.", + ): + await cached_disk_dist_registry.register(duplicate_vector_db) + + # Verify the original registration is still intact result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result is not None assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
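The tightened `test_duplicate_provider_registration` above asserts that a second registration by the same provider now raises instead of silently overwriting, and that the original entry survives. The registry change itself is not part of this excerpt, so the following is only a rough sketch of a guard that would satisfy the test's `pytest.raises(..., match=...)` pattern; the `SketchRegistry` class, its storage layout, and the point at which the check runs are assumptions, not the actual `CachedDiskDistributionRegistry` implementation.

```python
# Rough sketch only: a duplicate-registration guard shaped to satisfy the updated test.
# Class name, storage layout, and method signatures are hypothetical.
from typing import Any, Protocol


class RoutableObject(Protocol):
    type: str
    identifier: str
    provider_id: str


class SketchRegistry:
    def __init__(self) -> None:
        self._objects: dict[tuple[str, str], Any] = {}

    async def get(self, type: str, identifier: str) -> Any | None:
        return self._objects.get((type, identifier))

    async def register(self, obj: RoutableObject) -> None:
        existing = await self.get(obj.type, obj.identifier)
        if existing is not None and existing.provider_id == obj.provider_id:
            # Refuse to overwrite, so the original registration stays intact,
            # which is what the test verifies after catching the ValueError.
            raise ValueError(
                f"Provider '{obj.provider_id}' is already registered for {obj.type} "
                f"'{obj.identifier}'. Unregister the existing provider first before "
                "registering it again."
            )
        self._objects[(obj.type, obj.identifier)] = obj
```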