Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-12 20:12:33 +00:00)

Commit 889b2716ef: Merge remote-tracking branch 'origin/main' into test_isolation_server

107 changed files with 817 additions and 1298 deletions
.github/workflows/integration-auth-tests.yml (vendored, 2 lines changed)

@@ -86,7 +86,7 @@ jobs:
  # avoid line breaks in the server log, especially because we grep it below.
  export COLUMNS=1984
- nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
+ nohup uv run llama stack run $run_dir/run.yaml > server.log 2>&1 &

  - name: Wait for Llama Stack server to be ready
  run: |
.github/workflows/stale_bot.yml (vendored, 2 lines changed)

@@ -24,7 +24,7 @@ jobs:
  runs-on: ubuntu-latest
  steps:
  - name: Stale Action
- uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
+ uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0
  with:
  stale-issue-label: 'stale'
  stale-issue-message: >
@@ -59,7 +59,7 @@ jobs:
  # Use the virtual environment created by the build step (name comes from build config)
  source ramalama-stack-test/bin/activate
  uv pip list
- nohup llama stack run tests/external/ramalama-stack/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
+ nohup llama stack run tests/external/ramalama-stack/run.yaml > server.log 2>&1 &

  - name: Wait for Llama Stack server to be ready
  run: |
.github/workflows/test-external.yml (vendored, 2 lines changed)

@@ -59,7 +59,7 @@ jobs:
  # Use the virtual environment created by the build step (name comes from build config)
  source ci-test/bin/activate
  uv pip list
- nohup llama stack run tests/external/run-byoa.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
+ nohup llama stack run tests/external/run-byoa.yaml > server.log 2>&1 &

  - name: Wait for Llama Stack server to be ready
  run: |
@@ -52,7 +52,7 @@ You can access the HuggingFace trainer via the `starter` distribution:

  ```bash
  llama stack build --distro starter --image-type venv
- llama stack run --image-type venv ~/.llama/distributions/starter/starter-run.yaml
+ llama stack run ~/.llama/distributions/starter/starter-run.yaml
  ```

  ### Usage Example
@@ -219,13 +219,10 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
  <TabItem value="setup" label="Setup & Configuration">

  1. Start by registering a Tavily API key at [Tavily](https://tavily.com/).
- 2. [Optional] Provide the API key directly to the Llama Stack server
+ 2. [Optional] Set the API key in your environment before starting the Llama Stack server
  ```bash
  export TAVILY_SEARCH_API_KEY="your key"
  ```
- ```bash
- --env TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY}
- ```

  </TabItem>
  <TabItem value="implementation" label="Implementation">
@@ -273,9 +270,9 @@ for log in EventLogger().log(response):
  <TabItem value="setup" label="Setup & Configuration">

  1. Start by registering for a WolframAlpha API key at [WolframAlpha Developer Portal](https://developer.wolframalpha.com/access).
- 2. Provide the API key either when starting the Llama Stack server:
+ 2. Provide the API key either by setting it in your environment before starting the Llama Stack server:
  ```bash
- --env WOLFRAM_ALPHA_API_KEY=${WOLFRAM_ALPHA_API_KEY}
+ export WOLFRAM_ALPHA_API_KEY="your key"
  ```
  or from the client side:
  ```python
@@ -76,7 +76,7 @@ Integration tests are located in [tests/integration](https://github.com/meta-lla
  Consult [tests/integration/README.md](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md) for more details on how to run the tests.

  Note that each provider's `sample_run_config()` method (in the configuration class for that provider)
- typically references some environment variables for specifying API keys and the like. You can set these in the environment or pass these via the `--env` flag to the test command.
+ typically references some environment variables for specifying API keys and the like. You can set these in the environment before running the test command.

  ### 2. Unit Testing
@@ -289,10 +289,10 @@ After this step is successful, you should be able to find the built container im
  docker run -d \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
+ -e INFERENCE_MODEL=$INFERENCE_MODEL \
+ -e OLLAMA_URL=http://host.docker.internal:11434 \
  localhost/distribution-ollama:dev \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env OLLAMA_URL=http://host.docker.internal:11434
+ --port $LLAMA_STACK_PORT
  ```

  Here are the docker flags and their uses:
@@ -305,12 +305,12 @@ Here are the docker flags and their uses:

  * `localhost/distribution-ollama:dev`: The name and tag of the container image to run

+ * `-e INFERENCE_MODEL=$INFERENCE_MODEL`: Sets the INFERENCE_MODEL environment variable in the container

+ * `-e OLLAMA_URL=http://host.docker.internal:11434`: Sets the OLLAMA_URL environment variable in the container

  * `--port $LLAMA_STACK_PORT`: Port number for the server to listen on

- * `--env INFERENCE_MODEL=$INFERENCE_MODEL`: Sets the model to use for inference

- * `--env OLLAMA_URL=http://host.docker.internal:11434`: Configures the URL for the Ollama service

  </TabItem>
  </Tabs>
@@ -320,23 +320,22 @@ Now, let's start the Llama Stack Distribution Server. You will need the YAML con

  ```
  llama stack run -h
- usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--env KEY=VALUE]
+ usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME]
                         [--image-type {venv}] [--enable-ui]
-                        [config | template]
+                        [config | distro]

  Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.

  positional arguments:
-   config | template    Path to config file to use for the run or name of known template (`llama stack list` for a list). (default: None)
+   config | distro      Path to config file to use for the run or name of known distro (`llama stack list` for a list). (default: None)

  options:
    -h, --help           show this help message and exit
    --port PORT          Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
    --image-name IMAGE_NAME
-                        Name of the image to run. Defaults to the current environment (default: None)
+                        [DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running. (default: None)
-   --env KEY=VALUE      Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: None)
    --image-type {venv}
-                        Image Type used during the build. This should be venv. (default: None)
+                        [DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running. (default: None)
    --enable-ui          Start the UI server (default: False)
  ```

@@ -348,9 +347,6 @@ llama stack run tgi

  # Start using config file
  llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
-
- # Start using a venv
- llama stack run --image-type venv ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
  ```

  ```
@@ -101,7 +101,7 @@ A few things to note:
  - The id is a string you can choose freely.
  - You can instantiate any number of provider instances of the same type.
  - The configuration dictionary is provider-specific.
- - Notice that configuration can reference environment variables (with default values), which are expanded at runtime. When you run a stack server (via docker or via `llama stack run`), you can specify `--env OLLAMA_URL=http://my-server:11434` to override the default value.
+ - Notice that configuration can reference environment variables (with default values), which are expanded at runtime. When you run a stack server, you can set environment variables in your shell before running `llama stack run` to override the default values.

  ### Environment Variable Substitution
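A small sketch of how an overridden variable flows into the configuration; the `${env.VAR:=default}` default form is assumed from the substitution section that follows, and the URL is illustrative.

```bash
# run.yaml is assumed to contain a reference such as:
#   url: ${env.OLLAMA_URL:=http://localhost:11434}

# Exporting the variable before starting the server overrides the default.
export OLLAMA_URL=http://my-server:11434
llama stack run starter
```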
@@ -173,13 +173,10 @@ optional_token: ${env.OPTIONAL_TOKEN:+}

  #### Runtime Override

- You can override environment variables at runtime when starting the server:
+ You can override environment variables at runtime by setting them in your shell before starting the server:

  ```bash
- # Override specific environment variables
- llama stack run --config run.yaml --env API_KEY=sk-123 --env BASE_URL=https://custom-api.com
-
- # Or set them in your shell
+ # Set environment variables in your shell
  export API_KEY=sk-123
  export BASE_URL=https://custom-api.com
  llama stack run --config run.yaml
@@ -69,10 +69,10 @@ docker run \
  -it \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ./run.yaml:/root/my-run.yaml \
+ -e WATSONX_API_KEY=$WATSONX_API_KEY \
+ -e WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
+ -e WATSONX_BASE_URL=$WATSONX_BASE_URL \
  llamastack/distribution-watsonx \
  --config /root/my-run.yaml \
- --port $LLAMA_STACK_PORT \
- --env WATSONX_API_KEY=$WATSONX_API_KEY \
- --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
- --env WATSONX_BASE_URL=$WATSONX_BASE_URL
+ --port $LLAMA_STACK_PORT
  ```
@@ -129,11 +129,11 @@ docker run -it \
  # NOTE: mount the llama-stack / llama-model directories if testing local changes else not needed
  -v $HOME/git/llama-stack:/app/llama-stack-source -v $HOME/git/llama-models:/app/llama-models-source \
  # localhost/distribution-dell:dev if building / testing locally
- llamastack/distribution-dell\
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env DEH_URL=$DEH_URL \
- --env CHROMA_URL=$CHROMA_URL
+ -e INFERENCE_MODEL=$INFERENCE_MODEL \
+ -e DEH_URL=$DEH_URL \
+ -e CHROMA_URL=$CHROMA_URL \
+ llamastack/distribution-dell \
+ --port $LLAMA_STACK_PORT

  ```
@@ -154,14 +154,14 @@ docker run \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v $HOME/.llama:/root/.llama \
  -v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \
+ -e INFERENCE_MODEL=$INFERENCE_MODEL \
+ -e DEH_URL=$DEH_URL \
+ -e SAFETY_MODEL=$SAFETY_MODEL \
+ -e DEH_SAFETY_URL=$DEH_SAFETY_URL \
+ -e CHROMA_URL=$CHROMA_URL \
  llamastack/distribution-dell \
  --config /root/my-run.yaml \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env DEH_URL=$DEH_URL \
- --env SAFETY_MODEL=$SAFETY_MODEL \
- --env DEH_SAFETY_URL=$DEH_SAFETY_URL \
- --env CHROMA_URL=$CHROMA_URL
+ --port $LLAMA_STACK_PORT
  ```

  ### Via venv
@@ -170,21 +170,21 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a

  ```bash
  llama stack build --distro dell --image-type venv
- llama stack run dell
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env DEH_URL=$DEH_URL \
- --env CHROMA_URL=$CHROMA_URL
+ INFERENCE_MODEL=$INFERENCE_MODEL \
+ DEH_URL=$DEH_URL \
+ CHROMA_URL=$CHROMA_URL \
+ llama stack run dell \
+ --port $LLAMA_STACK_PORT
  ```

  If you are using Llama Stack Safety / Shield APIs, use:

  ```bash
+ INFERENCE_MODEL=$INFERENCE_MODEL \
+ DEH_URL=$DEH_URL \
+ SAFETY_MODEL=$SAFETY_MODEL \
+ DEH_SAFETY_URL=$DEH_SAFETY_URL \
+ CHROMA_URL=$CHROMA_URL \
  llama stack run ./run-with-safety.yaml \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env DEH_URL=$DEH_URL \
- --env SAFETY_MODEL=$SAFETY_MODEL \
- --env DEH_SAFETY_URL=$DEH_SAFETY_URL \
- --env CHROMA_URL=$CHROMA_URL
+ --port $LLAMA_STACK_PORT
  ```
@@ -84,9 +84,9 @@ docker run \
  --gpu all \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
+ -e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
  llamastack/distribution-meta-reference-gpu \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
+ --port $LLAMA_STACK_PORT
  ```

  If you are using Llama Stack Safety / Shield APIs, use:
@@ -98,10 +98,10 @@ docker run \
  --gpu all \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
+ -e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
+ -e SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
  llamastack/distribution-meta-reference-gpu \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
- --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
+ --port $LLAMA_STACK_PORT
  ```

  ### Via venv
@@ -110,16 +110,16 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL

  ```bash
  llama stack build --distro meta-reference-gpu --image-type venv
+ INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
  llama stack run distributions/meta-reference-gpu/run.yaml \
- --port 8321 \
- --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
+ --port 8321
  ```

  If you are using Llama Stack Safety / Shield APIs, use:

  ```bash
+ INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
+ SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
  llama stack run distributions/meta-reference-gpu/run-with-safety.yaml \
- --port 8321 \
- --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
- --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
+ --port 8321
  ```
@@ -129,10 +129,10 @@ docker run \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ./run.yaml:/root/my-run.yaml \
+ -e NVIDIA_API_KEY=$NVIDIA_API_KEY \
  llamastack/distribution-nvidia \
  --config /root/my-run.yaml \
- --port $LLAMA_STACK_PORT \
- --env NVIDIA_API_KEY=$NVIDIA_API_KEY
+ --port $LLAMA_STACK_PORT
  ```

  ### Via venv
@@ -142,10 +142,10 @@ If you've set up your local development environment, you can also build the imag

  ```bash
  INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
  llama stack build --distro nvidia --image-type venv
+ NVIDIA_API_KEY=$NVIDIA_API_KEY \
+ INFERENCE_MODEL=$INFERENCE_MODEL \
  llama stack run ./run.yaml \
- --port 8321 \
- --env NVIDIA_API_KEY=$NVIDIA_API_KEY \
- --env INFERENCE_MODEL=$INFERENCE_MODEL
+ --port 8321
  ```

  ## Example Notebooks
@@ -86,9 +86,9 @@ docker run -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
+ -e OLLAMA_URL=http://host.docker.internal:11434 \
  llamastack/distribution-starter \
- --port $LLAMA_STACK_PORT \
- --env OLLAMA_URL=http://host.docker.internal:11434
+ --port $LLAMA_STACK_PORT
  ```
  Note to start the container with Podman, you can do the same but replace `docker` at the start of the command with
  `podman`. If you are using `podman` older than `4.7.0`, please also replace `host.docker.internal` in the `OLLAMA_URL`
@@ -106,9 +106,9 @@ docker run -it \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  --network=host \
+ -e OLLAMA_URL=http://localhost:11434 \
  llamastack/distribution-starter \
- --port $LLAMA_STACK_PORT \
- --env OLLAMA_URL=http://localhost:11434
+ --port $LLAMA_STACK_PORT
  ```
  :::
  You will see output like below:
@@ -15,6 +15,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `api_key` | `str \| None` | No | | API key for Anthropic models |

  ## Sample Configuration
@@ -22,6 +22,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Azure API key for Azure |
  | `api_base` | `<class 'pydantic.networks.HttpUrl'>` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
  | `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |
@@ -15,6 +15,7 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
  | `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
  | `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
@@ -15,6 +15,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform.

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `base_url` | `<class 'str'>` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
  | `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Cerebras API Key |
@@ -15,6 +15,7 @@ Databricks inference provider for running models on Databricks' unified analytic

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
  | `api_token` | `<class 'pydantic.types.SecretStr'>` | No | | The Databricks API token |
@@ -15,6 +15,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `url` | `<class 'str'>` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
  | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |
@@ -15,6 +15,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `api_key` | `str \| None` | No | | API key for Gemini models |

  ## Sample Configuration
@@ -15,6 +15,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `api_key` | `str \| None` | No | | The Groq API key |
  | `url` | `<class 'str'>` | No | https://api.groq.com | The URL for the Groq AI server |
@@ -15,6 +15,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `api_key` | `str \| None` | No | | The Llama API key |
  | `openai_compat_api_base` | `<class 'str'>` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |
@@ -15,6 +15,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
  | `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed of using the hosted service |
  | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
@@ -15,8 +15,8 @@ Ollama inference provider for running local models through the Ollama runtime.

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `url` | `<class 'str'>` | No | http://localhost:11434 | |
- | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |

  ## Sample Configuration
@@ -15,6 +15,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `api_key` | `str \| None` | No | | API key for OpenAI models |
  | `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
@@ -15,6 +15,7 @@ Passthrough inference provider for connecting to any external inference service

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `url` | `<class 'str'>` | No | | The URL for the passthrough endpoint |
  | `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrouth endpoint |
@@ -15,6 +15,7 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
  | `api_token` | `str \| None` | No | | The API token |
@@ -15,6 +15,7 @@ SambaNova inference provider for running models on SambaNova's dataflow architec

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `url` | `<class 'str'>` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
  | `api_key` | `pydantic.types.SecretStr \| None` | No | | The SambaNova cloud API Key |
@@ -15,6 +15,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `url` | `<class 'str'>` | No | | The URL for the TGI serving endpoint |

  ## Sample Configuration
@@ -15,6 +15,7 @@ Together AI inference provider for open-source models and collaborative AI devel

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `url` | `<class 'str'>` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
  | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |
@@ -54,6 +54,7 @@ Available Models:

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `project` | `<class 'str'>` | No | | Google Cloud project ID for Vertex AI |
  | `location` | `<class 'str'>` | No | us-central1 | Google Cloud location for Vertex AI |
@@ -15,11 +15,11 @@ Remote vLLM inference provider for connecting to vLLM servers.

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
  | `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
  | `api_token` | `str \| None` | No | fake | The API token |
  | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
- | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |

  ## Sample Configuration
@@ -15,9 +15,10 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
- | `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key |
- | `project_id` | `str \| None` | No | | The Project ID key |
+ | `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key |
+ | `project_id` | `str \| None` | No | | The watsonx.ai project ID |
  | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |

  ## Sample Configuration
@@ -15,6 +15,7 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.

  | Field | Type | Required | Default | Description |
  |-------|------|----------|---------|-------------|
  | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+ | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
  | `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
  | `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
  | `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
@@ -123,12 +123,12 @@
  " del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
  "\n",
  "# this command installs all the dependencies needed for the llama stack server with the together inference provider\n",
- "!uv run --with llama-stack llama stack build --distro together --image-type venv\n",
+ "!uv run --with llama-stack llama stack build --distro together\n",
  "\n",
  "def run_llama_stack_server_background():\n",
  " log_file = open(\"llama_stack_server.log\", \"w\")\n",
  " process = subprocess.Popen(\n",
- " \"uv run --with llama-stack llama stack run together --image-type venv\",\n",
+ " \"uv run --with llama-stack llama stack run together\",\n",
  " shell=True,\n",
  " stdout=log_file,\n",
  " stderr=log_file,\n",
@@ -233,12 +233,12 @@
  " del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
  "\n",
  "# this command installs all the dependencies needed for the llama stack server\n",
- "!uv run --with llama-stack llama stack build --distro meta-reference-gpu --image-type venv\n",
+ "!uv run --with llama-stack llama stack build --distro meta-reference-gpu\n",
  "\n",
  "def run_llama_stack_server_background():\n",
  " log_file = open(\"llama_stack_server.log\", \"w\")\n",
  " process = subprocess.Popen(\n",
- " f\"uv run --with llama-stack llama stack run meta-reference-gpu --image-type venv --env INFERENCE_MODEL={model_id}\",\n",
+ " f\"INFERENCE_MODEL={model_id} uv run --with llama-stack llama stack run meta-reference-gpu\",\n",
  " shell=True,\n",
  " stdout=log_file,\n",
  " stderr=log_file,\n",
@@ -223,12 +223,12 @@
  " del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
  "\n",
  "# this command installs all the dependencies needed for the llama stack server\n",
- "!uv run --with llama-stack llama stack build --distro llama_api --image-type venv\n",
+ "!uv run --with llama-stack llama stack build --distro llama_api\n",
  "\n",
  "def run_llama_stack_server_background():\n",
  " log_file = open(\"llama_stack_server.log\", \"w\")\n",
  " process = subprocess.Popen(\n",
- " \"uv run --with llama-stack llama stack run llama_api --image-type venv\",\n",
+ " \"uv run --with llama-stack llama stack run llama_api\",\n",
  " shell=True,\n",
  " stdout=log_file,\n",
  " stderr=log_file,\n",
@@ -145,12 +145,12 @@
  " del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
  "\n",
  "# this command installs all the dependencies needed for the llama stack server with the ollama inference provider\n",
- "!uv run --with llama-stack llama stack build --distro starter --image-type venv\n",
+ "!uv run --with llama-stack llama stack build --distro starter\n",
  "\n",
  "def run_llama_stack_server_background():\n",
  " log_file = open(\"llama_stack_server.log\", \"w\")\n",
  " process = subprocess.Popen(\n",
- " f\"OLLAMA_URL=http://localhost:11434 uv run --with llama-stack llama stack run starter --image-type venv\n",
+ " f\"OLLAMA_URL=http://localhost:11434 uv run --with llama-stack llama stack run starter\n",
  " shell=True,\n",
  " stdout=log_file,\n",
  " stderr=log_file,\n",
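Outside a notebook, the same background-launch pattern the cells above build up can be sketched directly in the shell; the distro name, URL, and log file below mirror the notebook and are illustrative.

```bash
# Start the server in the background, passing provider settings via the
# environment rather than --env flags.
OLLAMA_URL=http://localhost:11434 \
  nohup uv run --with llama-stack llama stack run starter > llama_stack_server.log 2>&1 &

# Watch the log until the server reports it is ready.
tail -f llama_stack_server.log
```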
@@ -88,7 +88,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
  ...
  Build Successful!
  You can find the newly-built template here: ~/.llama/distributions/starter/starter-run.yaml
- You can run the new Llama Stack Distro via: uv run --with llama-stack llama stack run starter --image-type venv
+ You can run the new Llama Stack Distro via: uv run --with llama-stack llama stack run starter
  ```

  3. **Set the ENV variables by exporting them to the terminal**:
@@ -102,12 +102,11 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
  3. **Run the Llama Stack**:
  Run the stack using uv:
  ```bash
+ INFERENCE_MODEL=$INFERENCE_MODEL \
+ SAFETY_MODEL=$SAFETY_MODEL \
+ OLLAMA_URL=$OLLAMA_URL \
  uv run --with llama-stack llama stack run starter \
- --image-type venv \
- --port $LLAMA_STACK_PORT \
- --env INFERENCE_MODEL=$INFERENCE_MODEL \
- --env SAFETY_MODEL=$SAFETY_MODEL \
- --env OLLAMA_URL=$OLLAMA_URL
+ --port $LLAMA_STACK_PORT
  ```
  Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
@@ -444,12 +444,24 @@ def _run_stack_build_command_from_build_config(

          cprint("Build Successful!", color="green", file=sys.stderr)
          cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
-         cprint(
-             "You can run the new Llama Stack distro via: "
-             + colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"),
-             color="green",
-             file=sys.stderr,
-         )
+         if build_config.image_type == LlamaStackImageType.VENV:
+             cprint(
+                 "You can run the new Llama Stack distro (after activating "
+                 + colored(image_name, "cyan")
+                 + ") via: "
+                 + colored(f"llama stack run {run_config_file}", "blue"),
+                 color="green",
+                 file=sys.stderr,
+             )
+         elif build_config.image_type == LlamaStackImageType.CONTAINER:
+             cprint(
+                 "You can run the container with: "
+                 + colored(
+                     f"docker run -p 8321:8321 -v ~/.llama:/root/.llama localhost/{image_name} --port 8321", "blue"
+                 ),
+                 color="green",
+                 file=sys.stderr,
+             )
          return distro_path
      else:
          return _generate_run_config(build_config, build_dir, image_name)
@ -16,7 +16,7 @@ import yaml
|
||||||
from llama_stack.cli.stack.utils import ImageType
|
from llama_stack.cli.stack.utils import ImageType
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
from llama_stack.core.datatypes import LoggingConfig, StackRunConfig
|
from llama_stack.core.datatypes import LoggingConfig, StackRunConfig
|
||||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars, validate_env_pair
|
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
|
@ -55,18 +55,12 @@ class StackRun(Subcommand):
|
||||||
"--image-name",
|
"--image-name",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Name of the image to run. Defaults to the current environment",
|
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
|
||||||
)
|
|
||||||
self.parser.add_argument(
|
|
||||||
"--env",
|
|
||||||
action="append",
|
|
||||||
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
|
|
||||||
metavar="KEY=VALUE",
|
|
||||||
)
|
)
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--image-type",
|
"--image-type",
|
||||||
type=str,
|
type=str,
|
||||||
help="Image Type used during the build. This can be only venv.",
|
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
|
||||||
choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value],
|
choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value],
|
||||||
)
|
)
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
|
|
@ -75,48 +69,22 @@ class StackRun(Subcommand):
|
||||||
help="Start the UI server",
|
help="Start the UI server",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _resolve_config_and_distro(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
|
|
||||||
"""Resolve config file path and distribution name from args.config"""
|
|
||||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
|
||||||
|
|
||||||
if not args.config:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
config_file = Path(args.config)
|
|
||||||
has_yaml_suffix = args.config.endswith(".yaml")
|
|
||||||
distro_name = None
|
|
||||||
|
|
||||||
if not config_file.exists() and not has_yaml_suffix:
|
|
||||||
# check if this is a distribution
|
|
||||||
config_file = Path(REPO_ROOT) / "llama_stack" / "distributions" / args.config / "run.yaml"
|
|
||||||
if config_file.exists():
|
|
||||||
distro_name = args.config
|
|
||||||
|
|
||||||
if not config_file.exists() and not has_yaml_suffix:
|
|
||||||
# check if it's a build config saved to ~/.llama dir
|
|
||||||
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
|
|
||||||
|
|
||||||
if not config_file.exists():
|
|
||||||
self.parser.error(
|
|
||||||
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not config_file.is_file():
|
|
||||||
self.parser.error(
|
|
||||||
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return config_file, distro_name
|
|
||||||
|
|
||||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from llama_stack.core.configure import parse_and_maybe_upgrade_config
|
from llama_stack.core.configure import parse_and_maybe_upgrade_config
|
||||||
from llama_stack.core.utils.exec import formulate_run_args, run_command
|
|
||||||
|
if args.image_type or args.image_name:
|
||||||
|
self.parser.error(
|
||||||
|
"The --image-type and --image-name flags are no longer supported.\n\n"
|
||||||
|
"Please activate your virtual environment manually before running `llama stack run`.\n\n"
|
||||||
|
"For example:\n"
|
||||||
|
" source /path/to/venv/bin/activate\n"
|
||||||
|
" llama stack run <config>\n"
|
||||||
|
)
|
||||||
|
|
||||||
if args.enable_ui:
|
if args.enable_ui:
|
||||||
self._start_ui_development_server(args.port)
|
self._start_ui_development_server(args.port)
|
||||||
image_type, image_name = args.image_type, args.image_name
|
|
||||||
|
|
||||||
if args.config:
|
if args.config:
|
||||||
try:
|
try:
|
||||||
|
|
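The deprecated flags stay registered so that older invocations fail with an explanatory message rather than an argparse "unrecognized arguments" error. An illustrative sketch of that pattern (not a copy of the real command class), reusing the same flag names:

```python
# Minimal stand-alone illustration of the deprecation pattern used above.
import argparse

parser = argparse.ArgumentParser(prog="llama stack run")
parser.add_argument("config", nargs="?")
parser.add_argument("--image-type", help="[DEPRECATED] Activate your venv instead.")
parser.add_argument("--image-name", help="[DEPRECATED] Activate your venv instead.")

args = parser.parse_args()
if args.image_type or args.image_name:
    # Abort with a clear migration hint before any server work starts.
    parser.error(
        "The --image-type and --image-name flags are no longer supported.\n"
        "Activate your virtual environment, then run: llama stack run <config>"
    )
```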
@ -128,10 +96,6 @@ class StackRun(Subcommand):
|
||||||
else:
|
else:
|
||||||
config_file = None
|
config_file = None
|
||||||
|
|
||||||
# Check if config is required based on image type
|
|
||||||
if image_type == ImageType.VENV.value and not config_file:
|
|
||||||
self.parser.error("Config file is required for venv environment")
|
|
||||||
|
|
||||||
if config_file:
|
if config_file:
|
||||||
logger.info(f"Using run configuration: {config_file}")
|
logger.info(f"Using run configuration: {config_file}")
|
||||||
|
|
||||||
|
|
@ -146,50 +110,13 @@ class StackRun(Subcommand):
|
||||||
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
|
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
|
||||||
else:
|
|
||||||
config = None
|
|
||||||
|
|
||||||
# If neither image type nor image name is provided, assume the server should be run directly
|
self._uvicorn_run(config_file, args)
|
||||||
# using the current environment packages.
|
|
||||||
if not image_type and not image_name:
|
|
||||||
logger.info("No image type or image name provided. Assuming environment packages.")
|
|
||||||
self._uvicorn_run(config_file, args)
|
|
||||||
else:
|
|
||||||
run_args = formulate_run_args(image_type, image_name)
|
|
||||||
|
|
||||||
run_args.extend([str(args.port)])
|
|
||||||
|
|
||||||
if config_file:
|
|
||||||
run_args.extend(["--config", str(config_file)])
|
|
||||||
|
|
||||||
if args.env:
|
|
||||||
for env_var in args.env:
|
|
||||||
if "=" not in env_var:
|
|
||||||
self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
|
|
||||||
return
|
|
||||||
key, value = env_var.split("=", 1) # split on first = only
|
|
||||||
if not key:
|
|
||||||
self.parser.error(f"Environment variable '{env_var}' has empty key")
|
|
||||||
return
|
|
||||||
run_args.extend(["--env", f"{key}={value}"])
|
|
||||||
|
|
||||||
run_command(run_args)
|
|
||||||
|
|
||||||
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
|
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
|
||||||
if not config_file:
|
if not config_file:
|
||||||
self.parser.error("Config file is required")
|
self.parser.error("Config file is required")
|
||||||
|
|
||||||
# Set environment variables if provided
|
|
||||||
if args.env:
|
|
||||||
for env_pair in args.env:
|
|
||||||
try:
|
|
||||||
key, value = validate_env_pair(env_pair)
|
|
||||||
logger.info(f"Setting environment variable {key} => {value}")
|
|
||||||
os.environ[key] = value
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"Error: {str(e)}")
|
|
||||||
self.parser.error(f"Invalid environment variable format: {env_pair}")
|
|
||||||
|
|
||||||
config_file = resolve_config_or_distro(str(config_file), Mode.RUN)
|
config_file = resolve_config_or_distro(str(config_file), Mode.RUN)
|
||||||
with open(config_file) as fp:
|
with open(config_file) as fp:
|
||||||
config_contents = yaml.safe_load(fp)
|
config_contents = yaml.safe_load(fp)
|
||||||
|
|
|
||||||
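With the venv/container plumbing gone, `llama stack run` reduces to resolving the config and handing it to uvicorn. A rough sketch of the resolution step shown above, assuming the `llama_stack` modules imported in this hunk; error handling is omitted:

```python
import yaml

from llama_stack.core.configure import parse_and_maybe_upgrade_config
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro

# Accepts either a path to a run.yaml or a distro name such as "starter".
config_path = resolve_config_or_distro("starter", Mode.RUN)
with open(config_path) as fp:
    config = parse_and_maybe_upgrade_config(yaml.safe_load(fp))
print(config.apis)  # the APIs this run configuration enables
```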
|
|
@ -32,7 +32,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||||
sqlstore_impl,
|
sqlstore_impl,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="openai::conversations")
|
logger = get_logger(name=__name__, category="openai_conversations")
|
||||||
|
|
||||||
|
|
||||||
class ConversationServiceConfig(BaseModel):
|
class ConversationServiceConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -611,7 +611,7 @@ class InferenceRouter(Inference):
|
||||||
completion_text += "".join(choice_data["content_parts"])
|
completion_text += "".join(choice_data["content_parts"])
|
||||||
|
|
||||||
# Add metrics to the chunk
|
# Add metrics to the chunk
|
||||||
if self.telemetry and chunk.usage:
|
if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
|
||||||
metrics = self._construct_metrics(
|
metrics = self._construct_metrics(
|
||||||
prompt_tokens=chunk.usage.prompt_tokens,
|
prompt_tokens=chunk.usage.prompt_tokens,
|
||||||
completion_tokens=chunk.usage.completion_tokens,
|
completion_tokens=chunk.usage.completion_tokens,
|
||||||
|
|
|
||||||
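The extra `hasattr` check above guards against streamed chunk types that simply carry no `usage` field. A generic illustration of the same defensive read (not the router code itself):

```python
# Return (prompt_tokens, completion_tokens) only when usage data is present.
def usage_tokens(chunk) -> tuple[int, int] | None:
    usage = getattr(chunk, "usage", None)
    if usage is None:
        return None
    return usage.prompt_tokens, usage.completion_tokens
```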
|
|
@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
try:
|
try:
|
||||||
models = await provider.list_models()
|
models = await provider.list_models()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
|
logger.debug(f"Model refresh failed for provider {provider_id}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.listed_providers.add(provider_id)
|
self.listed_providers.add(provider_id)
|
||||||
|
|
@ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
|
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
|
||||||
return self.impls_by_provider_id[model.provider_id]
|
return self.impls_by_provider_id[model.provider_id]
|
||||||
|
|
||||||
|
async def has_model(self, model_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a model exists in the routing table.
|
||||||
|
|
||||||
|
:param model_id: The model identifier to check
|
||||||
|
:return: True if the model exists, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await lookup_model(self, model_id)
|
||||||
|
return True
|
||||||
|
except ModelNotFoundError:
|
||||||
|
return False
|
||||||
|
|
||||||
async def register_model(
|
async def register_model(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
|
|
||||||
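A hypothetical caller of the new `has_model()` helper (not part of the diff), showing the intended existence check without having to catch `ModelNotFoundError` yourself:

```python
# require_model() and its error message are illustrative; only has_model()
# comes from the hunk above.
async def require_model(table, model_id: str) -> None:
    if not await table.has_model(model_id):
        raise RuntimeError(f"{model_id} is not registered with any provider")
```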
|
|
@ -274,22 +274,6 @@ def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
return config_dict
|
return config_dict
|
||||||
|
|
||||||
|
|
||||||
def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
|
||||||
"""Validate and split an environment variable key-value pair."""
|
|
||||||
try:
|
|
||||||
key, value = env_pair.split("=", 1)
|
|
||||||
key = key.strip()
|
|
||||||
if not key:
|
|
||||||
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
|
|
||||||
if not all(c.isalnum() or c == "_" for c in key):
|
|
||||||
raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
|
|
||||||
return key, value
|
|
||||||
except ValueError as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
|
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||||
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ error_handler() {
|
||||||
trap 'error_handler ${LINENO}' ERR
|
trap 'error_handler ${LINENO}' ERR
|
||||||
|
|
||||||
if [ $# -lt 3 ]; then
|
if [ $# -lt 3 ]; then
|
||||||
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
|
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>]"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
@ -43,7 +43,6 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||||
|
|
||||||
# Initialize variables
|
# Initialize variables
|
||||||
yaml_config=""
|
yaml_config=""
|
||||||
env_vars=""
|
|
||||||
other_args=""
|
other_args=""
|
||||||
|
|
||||||
# Process remaining arguments
|
# Process remaining arguments
|
||||||
|
|
@ -58,15 +57,6 @@ while [[ $# -gt 0 ]]; do
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
;;
|
;;
|
||||||
--env)
|
|
||||||
if [[ -n "$2" ]]; then
|
|
||||||
env_vars="$env_vars --env $2"
|
|
||||||
shift 2
|
|
||||||
else
|
|
||||||
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
;;
|
|
||||||
*)
|
*)
|
||||||
other_args="$other_args $1"
|
other_args="$other_args $1"
|
||||||
shift
|
shift
|
||||||
|
|
@ -119,7 +109,6 @@ if [[ "$env_type" == "venv" ]]; then
|
||||||
llama stack run \
|
llama stack run \
|
||||||
$yaml_config_arg \
|
$yaml_config_arg \
|
||||||
--port "$port" \
|
--port "$port" \
|
||||||
$env_vars \
|
|
||||||
$other_args
|
$other_args
|
||||||
elif [[ "$env_type" == "container" ]]; then
|
elif [[ "$env_type" == "container" ]]; then
|
||||||
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
|
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,10 @@ class DiskDistributionRegistry(DistributionRegistry):
|
||||||
existing_obj = await self.get(obj.type, obj.identifier)
|
existing_obj = await self.get(obj.type, obj.identifier)
|
||||||
# dont register if the object's providerid already exists
|
# dont register if the object's providerid already exists
|
||||||
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
||||||
return False
|
raise ValueError(
|
||||||
|
f"Provider '{obj.provider_id}' is already registered."
|
||||||
|
f"Unregister the existing provider first before registering it again."
|
||||||
|
)
|
||||||
|
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
||||||
|
|
|
||||||
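Re-registering an object whose `provider_id` already exists now raises instead of silently returning `False`. A hedged sketch of how a caller might absorb that, assuming the registry's async `register(obj)` method shown in this hunk:

```python
async def register_or_skip(registry, obj) -> bool:
    try:
        return await registry.register(obj)
    except ValueError as exc:
        # e.g. "Provider 'faiss' is already registered. ..."
        print(f"skipping registration: {exc}")
        return False
```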
|
|
@ -117,11 +117,11 @@ docker run -it \
|
||||||
# NOTE: mount the llama-stack directory if testing local changes else not needed
|
# NOTE: mount the llama-stack directory if testing local changes else not needed
|
||||||
-v $HOME/git/llama-stack:/app/llama-stack-source \
|
-v $HOME/git/llama-stack:/app/llama-stack-source \
|
||||||
# localhost/distribution-dell:dev if building / testing locally
|
# localhost/distribution-dell:dev if building / testing locally
|
||||||
|
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||||
|
-e DEH_URL=$DEH_URL \
|
||||||
|
-e CHROMA_URL=$CHROMA_URL \
|
||||||
llamastack/distribution-{{ name }}\
|
llamastack/distribution-{{ name }}\
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT
|
||||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
|
||||||
--env DEH_URL=$DEH_URL \
|
|
||||||
--env CHROMA_URL=$CHROMA_URL
|
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -142,14 +142,14 @@ docker run \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v $HOME/.llama:/root/.llama \
|
-v $HOME/.llama:/root/.llama \
|
||||||
-v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \
|
-v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \
|
||||||
|
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||||
|
-e DEH_URL=$DEH_URL \
|
||||||
|
-e SAFETY_MODEL=$SAFETY_MODEL \
|
||||||
|
-e DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||||
|
-e CHROMA_URL=$CHROMA_URL \
|
||||||
llamastack/distribution-{{ name }} \
|
llamastack/distribution-{{ name }} \
|
||||||
--config /root/my-run.yaml \
|
--config /root/my-run.yaml \
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT
|
||||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
|
||||||
--env DEH_URL=$DEH_URL \
|
|
||||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
|
||||||
--env DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
|
||||||
--env CHROMA_URL=$CHROMA_URL
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Via Conda
|
### Via Conda
|
||||||
|
|
@ -158,21 +158,21 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
llama stack build --distro {{ name }} --image-type conda
|
llama stack build --distro {{ name }} --image-type conda
|
||||||
llama stack run {{ name }}
|
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||||
--port $LLAMA_STACK_PORT \
|
DEH_URL=$DEH_URL \
|
||||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
CHROMA_URL=$CHROMA_URL \
|
||||||
--env DEH_URL=$DEH_URL \
|
llama stack run {{ name }} \
|
||||||
--env CHROMA_URL=$CHROMA_URL
|
--port $LLAMA_STACK_PORT
|
||||||
```
|
```
|
||||||
|
|
||||||
If you are using Llama Stack Safety / Shield APIs, use:
|
If you are using Llama Stack Safety / Shield APIs, use:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||||
|
DEH_URL=$DEH_URL \
|
||||||
|
SAFETY_MODEL=$SAFETY_MODEL \
|
||||||
|
DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||||
|
CHROMA_URL=$CHROMA_URL \
|
||||||
llama stack run ./run-with-safety.yaml \
|
llama stack run ./run-with-safety.yaml \
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT
|
||||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
|
||||||
--env DEH_URL=$DEH_URL \
|
|
||||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
|
||||||
--env DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
|
||||||
--env CHROMA_URL=$CHROMA_URL
|
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -72,9 +72,9 @@ docker run \
|
||||||
--gpu all \
|
--gpu all \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ~/.llama:/root/.llama \
|
-v ~/.llama:/root/.llama \
|
||||||
|
-e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||||
llamastack/distribution-{{ name }} \
|
llamastack/distribution-{{ name }} \
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT
|
||||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
|
||||||
```
|
```
|
||||||
|
|
||||||
If you are using Llama Stack Safety / Shield APIs, use:
|
If you are using Llama Stack Safety / Shield APIs, use:
|
||||||
|
|
@ -86,10 +86,10 @@ docker run \
|
||||||
--gpu all \
|
--gpu all \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ~/.llama:/root/.llama \
|
-v ~/.llama:/root/.llama \
|
||||||
|
-e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||||
|
-e SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
|
||||||
llamastack/distribution-{{ name }} \
|
llamastack/distribution-{{ name }} \
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT
|
||||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
|
||||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Via venv
|
### Via venv
|
||||||
|
|
@ -98,16 +98,16 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
llama stack build --distro {{ name }} --image-type venv
|
llama stack build --distro {{ name }} --image-type venv
|
||||||
|
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||||
llama stack run distributions/{{ name }}/run.yaml \
|
llama stack run distributions/{{ name }}/run.yaml \
|
||||||
--port 8321 \
|
--port 8321
|
||||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
|
||||||
```
|
```
|
||||||
|
|
||||||
If you are using Llama Stack Safety / Shield APIs, use:
|
If you are using Llama Stack Safety / Shield APIs, use:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||||
|
SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
|
||||||
llama stack run distributions/{{ name }}/run-with-safety.yaml \
|
llama stack run distributions/{{ name }}/run-with-safety.yaml \
|
||||||
--port 8321 \
|
--port 8321
|
||||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
|
||||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -118,10 +118,10 @@ docker run \
|
||||||
--pull always \
|
--pull always \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ./run.yaml:/root/my-run.yaml \
|
-v ./run.yaml:/root/my-run.yaml \
|
||||||
|
-e NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||||
llamastack/distribution-{{ name }} \
|
llamastack/distribution-{{ name }} \
|
||||||
--config /root/my-run.yaml \
|
--config /root/my-run.yaml \
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT
|
||||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Via venv
|
### Via venv
|
||||||
|
|
@ -131,10 +131,10 @@ If you've set up your local development environment, you can also build the imag
|
||||||
```bash
|
```bash
|
||||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
|
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
|
||||||
llama stack build --distro nvidia --image-type venv
|
llama stack build --distro nvidia --image-type venv
|
||||||
|
NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||||
|
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||||
llama stack run ./run.yaml \
|
llama stack run ./run.yaml \
|
||||||
--port 8321 \
|
--port 8321
|
||||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
|
||||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Example Notebooks
|
## Example Notebooks
|
||||||
|
|
|
||||||
|
|
@ -3,3 +3,5 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .watsonx import get_distribution_template # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -3,44 +3,33 @@ distribution_spec:
|
||||||
description: Use watsonx for running LLM inference
|
description: Use watsonx for running LLM inference
|
||||||
providers:
|
providers:
|
||||||
inference:
|
inference:
|
||||||
- provider_id: watsonx
|
- provider_type: remote::watsonx
|
||||||
provider_type: remote::watsonx
|
- provider_type: inline::sentence-transformers
|
||||||
- provider_id: sentence-transformers
|
|
||||||
provider_type: inline::sentence-transformers
|
|
||||||
vector_io:
|
vector_io:
|
||||||
- provider_id: faiss
|
- provider_type: inline::faiss
|
||||||
provider_type: inline::faiss
|
|
||||||
safety:
|
safety:
|
||||||
- provider_id: llama-guard
|
- provider_type: inline::llama-guard
|
||||||
provider_type: inline::llama-guard
|
|
||||||
agents:
|
agents:
|
||||||
- provider_id: meta-reference
|
- provider_type: inline::meta-reference
|
||||||
provider_type: inline::meta-reference
|
|
||||||
telemetry:
|
telemetry:
|
||||||
- provider_id: meta-reference
|
- provider_type: inline::meta-reference
|
||||||
provider_type: inline::meta-reference
|
|
||||||
eval:
|
eval:
|
||||||
- provider_id: meta-reference
|
- provider_type: inline::meta-reference
|
||||||
provider_type: inline::meta-reference
|
|
||||||
datasetio:
|
datasetio:
|
||||||
- provider_id: huggingface
|
- provider_type: remote::huggingface
|
||||||
provider_type: remote::huggingface
|
- provider_type: inline::localfs
|
||||||
- provider_id: localfs
|
|
||||||
provider_type: inline::localfs
|
|
||||||
scoring:
|
scoring:
|
||||||
- provider_id: basic
|
- provider_type: inline::basic
|
||||||
provider_type: inline::basic
|
- provider_type: inline::llm-as-judge
|
||||||
- provider_id: llm-as-judge
|
- provider_type: inline::braintrust
|
||||||
provider_type: inline::llm-as-judge
|
|
||||||
- provider_id: braintrust
|
|
||||||
provider_type: inline::braintrust
|
|
||||||
tool_runtime:
|
tool_runtime:
|
||||||
- provider_type: remote::brave-search
|
- provider_type: remote::brave-search
|
||||||
- provider_type: remote::tavily-search
|
- provider_type: remote::tavily-search
|
||||||
- provider_type: inline::rag-runtime
|
- provider_type: inline::rag-runtime
|
||||||
- provider_type: remote::model-context-protocol
|
- provider_type: remote::model-context-protocol
|
||||||
|
files:
|
||||||
|
- provider_type: inline::localfs
|
||||||
image_type: venv
|
image_type: venv
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
- aiosqlite
|
|
||||||
- aiosqlite
|
|
||||||
|
|
|
||||||
|
|
@ -4,13 +4,13 @@ apis:
|
||||||
- agents
|
- agents
|
||||||
- datasetio
|
- datasetio
|
||||||
- eval
|
- eval
|
||||||
|
- files
|
||||||
- inference
|
- inference
|
||||||
- safety
|
- safety
|
||||||
- scoring
|
- scoring
|
||||||
- telemetry
|
- telemetry
|
||||||
- tool_runtime
|
- tool_runtime
|
||||||
- vector_io
|
- vector_io
|
||||||
- files
|
|
||||||
providers:
|
providers:
|
||||||
inference:
|
inference:
|
||||||
- provider_id: watsonx
|
- provider_id: watsonx
|
||||||
|
|
@ -19,8 +19,6 @@ providers:
|
||||||
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
|
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
|
||||||
api_key: ${env.WATSONX_API_KEY:=}
|
api_key: ${env.WATSONX_API_KEY:=}
|
||||||
project_id: ${env.WATSONX_PROJECT_ID:=}
|
project_id: ${env.WATSONX_PROJECT_ID:=}
|
||||||
- provider_id: sentence-transformers
|
|
||||||
provider_type: inline::sentence-transformers
|
|
||||||
vector_io:
|
vector_io:
|
||||||
- provider_id: faiss
|
- provider_id: faiss
|
||||||
provider_type: inline::faiss
|
provider_type: inline::faiss
|
||||||
|
|
@ -48,7 +46,7 @@ providers:
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
config:
|
config:
|
||||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
|
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
|
||||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||||
eval:
|
eval:
|
||||||
|
|
@ -109,102 +107,7 @@ metadata_store:
|
||||||
inference_store:
|
inference_store:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
|
||||||
models:
|
models: []
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/llama-3-3-70b-instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/llama-2-13b-chat
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-2-13b-chat
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-2-13b
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-2-13b-chat
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/llama-3-1-70b-instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.1-70B-Instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/llama-3-1-8b-instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/llama-3-2-11b-vision-instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/llama-3-2-1b-instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/llama-3-2-3b-instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.2-3B-Instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/llama-3-2-90b-vision-instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/llama-guard-3-11b-vision
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-Guard-3-11B-Vision
|
|
||||||
provider_id: watsonx
|
|
||||||
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
|
||||||
model_type: llm
|
|
||||||
- metadata:
|
|
||||||
embedding_dimension: 384
|
|
||||||
model_id: all-MiniLM-L6-v2
|
|
||||||
provider_id: sentence-transformers
|
|
||||||
model_type: embedding
|
|
||||||
shields: []
|
shields: []
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
datasets: []
|
datasets: []
|
||||||
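With the static `models:` block removed (`models: []`), the watsonx models are discovered from the provider at runtime via the `list_models()` refresh earlier in this diff. A hedged client-side sketch for enumerating whatever the running stack actually serves, assuming a server on `localhost:8321` and the `llama-stack-client` package:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
for model in client.models.list():
    # identifier and provider_id are the fields the routing table registers.
    print(model.identifier, model.provider_id)
```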
|
|
|
||||||
|
|
@ -4,17 +4,11 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
|
||||||
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
|
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
|
||||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
|
||||||
SentenceTransformersInferenceConfig,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
||||||
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
|
|
||||||
|
|
||||||
|
|
||||||
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
||||||
|
|
@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
||||||
config=WatsonXConfig.sample_run_config(),
|
config=WatsonXConfig.sample_run_config(),
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_provider = Provider(
|
|
||||||
provider_id="sentence-transformers",
|
|
||||||
provider_type="inline::sentence-transformers",
|
|
||||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
|
||||||
)
|
|
||||||
|
|
||||||
available_models = {
|
|
||||||
"watsonx": MODEL_ENTRIES,
|
|
||||||
}
|
|
||||||
default_tool_groups = [
|
default_tool_groups = [
|
||||||
ToolGroupInput(
|
ToolGroupInput(
|
||||||
toolgroup_id="builtin::websearch",
|
toolgroup_id="builtin::websearch",
|
||||||
|
|
@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
embedding_model = ModelInput(
|
|
||||||
model_id="all-MiniLM-L6-v2",
|
|
||||||
provider_id="sentence-transformers",
|
|
||||||
model_type=ModelType.embedding,
|
|
||||||
metadata={
|
|
||||||
"embedding_dimension": 384,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
files_provider = Provider(
|
files_provider = Provider(
|
||||||
provider_id="meta-reference-files",
|
provider_id="meta-reference-files",
|
||||||
provider_type="inline::localfs",
|
provider_type="inline::localfs",
|
||||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
)
|
)
|
||||||
default_models, _ = get_model_registry(available_models)
|
|
||||||
return DistributionTemplate(
|
return DistributionTemplate(
|
||||||
name=name,
|
name=name,
|
||||||
distro_type="remote_hosted",
|
distro_type="remote_hosted",
|
||||||
description="Use watsonx for running LLM inference",
|
description="Use watsonx for running LLM inference",
|
||||||
container_image=None,
|
container_image=None,
|
||||||
template_path=Path(__file__).parent / "doc_template.md",
|
template_path=None,
|
||||||
providers=providers,
|
providers=providers,
|
||||||
available_models_by_provider=available_models,
|
|
||||||
run_configs={
|
run_configs={
|
||||||
"run.yaml": RunConfigSettings(
|
"run.yaml": RunConfigSettings(
|
||||||
provider_overrides={
|
provider_overrides={
|
||||||
"inference": [inference_provider, embedding_provider],
|
"inference": [inference_provider],
|
||||||
"files": [files_provider],
|
"files": [files_provider],
|
||||||
},
|
},
|
||||||
default_models=default_models + [embedding_model],
|
default_models=[],
|
||||||
default_tool_groups=default_tool_groups,
|
default_tool_groups=default_tool_groups,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -31,12 +31,17 @@ CATEGORIES = [
|
||||||
"client",
|
"client",
|
||||||
"telemetry",
|
"telemetry",
|
||||||
"openai_responses",
|
"openai_responses",
|
||||||
|
"openai_conversations",
|
||||||
"testing",
|
"testing",
|
||||||
"providers",
|
"providers",
|
||||||
"models",
|
"models",
|
||||||
"files",
|
"files",
|
||||||
"vector_io",
|
"vector_io",
|
||||||
"tool_runtime",
|
"tool_runtime",
|
||||||
|
"cli",
|
||||||
|
"post_training",
|
||||||
|
"scoring",
|
||||||
|
"tests",
|
||||||
]
|
]
|
||||||
UNCATEGORIZED = "uncategorized"
|
UNCATEGORIZED = "uncategorized"
|
||||||
|
|
||||||
|
|
@ -264,11 +269,12 @@ def get_logger(
|
||||||
if root_category in _category_levels:
|
if root_category in _category_levels:
|
||||||
log_level = _category_levels[root_category]
|
log_level = _category_levels[root_category]
|
||||||
else:
|
else:
|
||||||
log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
|
|
||||||
if category != UNCATEGORIZED:
|
if category != UNCATEGORIZED:
|
||||||
logging.warning(
|
raise ValueError(
|
||||||
f"Unknown logging category: {category}. Falling back to default 'root' level: {log_level}"
|
f"Unknown logging category: {category}. To resolve, choose a valid category from the CATEGORIES list "
|
||||||
|
f"or add it to the CATEGORIES list. Available categories: {CATEGORIES}"
|
||||||
)
|
)
|
||||||
|
log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
return logging.LoggerAdapter(logger, {"category": category})
|
return logging.LoggerAdapter(logger, {"category": category})
|
||||||
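Because unknown categories now raise instead of logging a warning, loggers must use one of the entries in `CATEGORIES` (which gains `openai_conversations`, `cli`, `post_training`, `scoring`, and `tests`). A small usage sketch based on this hunk:

```python
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="openai_conversations")
logger.info("conversation store initialized")

try:
    get_logger(name=__name__, category="not-a-real-category")
except ValueError as exc:
    print(exc)  # the error lists the valid categories
```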
|
|
||||||
|
|
|
||||||
|
|
@ -11,19 +11,13 @@
|
||||||
# top-level folder for each specific model found within the models/ directory at
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
# the top-level of this source tree.
|
# the top-level of this source tree.
|
||||||
|
|
||||||
import json
|
|
||||||
import textwrap
|
import textwrap
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.models.llama.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
RawContent,
|
RawContent,
|
||||||
RawMediaItem,
|
|
||||||
RawMessage,
|
RawMessage,
|
||||||
RawTextItem,
|
|
||||||
StopReason,
|
|
||||||
ToolCall,
|
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
|
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
|
||||||
|
|
@ -175,25 +169,6 @@ def llama3_1_builtin_code_interpreter_dialog(tool_prompt_format=ToolPromptFormat
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def llama3_1_builtin_tool_call_with_image_dialog(
|
|
||||||
tool_prompt_format=ToolPromptFormat.json,
|
|
||||||
):
|
|
||||||
this_dir = Path(__file__).parent
|
|
||||||
with open(this_dir / "llama3/dog.jpg", "rb") as f:
|
|
||||||
img = f.read()
|
|
||||||
|
|
||||||
interface = LLama31Interface(tool_prompt_format)
|
|
||||||
|
|
||||||
messages = interface.system_messages(**system_message_builtin_tools_only())
|
|
||||||
messages += interface.user_message(content=[RawMediaItem(data=img), RawTextItem(text="What is this dog breed?")])
|
|
||||||
messages += interface.assistant_response_messages(
|
|
||||||
"Based on the description of the dog in the image, it appears to be a small breed dog, possibly a terrier mix",
|
|
||||||
StopReason.end_of_turn,
|
|
||||||
)
|
|
||||||
messages += interface.user_message("Search the web for some food recommendations for the indentified breed")
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
|
def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
|
||||||
interface = LLama31Interface(tool_prompt_format)
|
interface = LLama31Interface(tool_prompt_format)
|
||||||
|
|
||||||
|
|
@ -202,35 +177,6 @@ def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def llama3_1_e2e_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
|
|
||||||
tool_response = json.dumps(["great song1", "awesome song2", "cool song3"])
|
|
||||||
interface = LLama31Interface(tool_prompt_format)
|
|
||||||
|
|
||||||
messages = interface.system_messages(**system_message_custom_tools_only())
|
|
||||||
messages += interface.user_message(content="Use tools to get latest trending songs")
|
|
||||||
messages.append(
|
|
||||||
RawMessage(
|
|
||||||
role="assistant",
|
|
||||||
content="",
|
|
||||||
stop_reason=StopReason.end_of_message,
|
|
||||||
tool_calls=[
|
|
||||||
ToolCall(
|
|
||||||
call_id="call_id",
|
|
||||||
tool_name="trending_songs",
|
|
||||||
arguments={"n": "10", "genre": "latest"},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
messages.append(
|
|
||||||
RawMessage(
|
|
||||||
role="assistant",
|
|
||||||
content=tool_response,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
def llama3_2_user_assistant_conversation():
|
def llama3_2_user_assistant_conversation():
|
||||||
return UseCase(
|
return UseCase(
|
||||||
title="User and assistant conversation",
|
title="User and assistant conversation",
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
|
||||||
deps[Api.tool_runtime],
|
deps[Api.tool_runtime],
|
||||||
deps[Api.tool_groups],
|
deps[Api.tool_groups],
|
||||||
policy,
|
policy,
|
||||||
|
Api.telemetry in deps,
|
||||||
)
|
)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,6 @@
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import secrets
|
|
||||||
import string
|
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
|
@ -84,11 +82,6 @@ from llama_stack.providers.utils.telemetry import tracing
|
||||||
from .persistence import AgentPersistence
|
from .persistence import AgentPersistence
|
||||||
from .safety import SafetyException, ShieldRunnerMixin
|
from .safety import SafetyException, ShieldRunnerMixin
|
||||||
|
|
||||||
|
|
||||||
def make_random_string(length: int = 8):
|
|
||||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
|
||||||
|
|
||||||
|
|
||||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||||
MEMORY_QUERY_TOOL = "knowledge_search"
|
MEMORY_QUERY_TOOL = "knowledge_search"
|
||||||
WEB_SEARCH_TOOL = "web_search"
|
WEB_SEARCH_TOOL = "web_search"
|
||||||
|
|
@ -110,6 +103,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
persistence_store: KVStore,
|
persistence_store: KVStore,
|
||||||
created_at: str,
|
created_at: str,
|
||||||
policy: list[AccessRule],
|
policy: list[AccessRule],
|
||||||
|
telemetry_enabled: bool = False,
|
||||||
):
|
):
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
self.agent_config = agent_config
|
self.agent_config = agent_config
|
||||||
|
|
@ -120,6 +114,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
self.tool_runtime_api = tool_runtime_api
|
self.tool_runtime_api = tool_runtime_api
|
||||||
self.tool_groups_api = tool_groups_api
|
self.tool_groups_api = tool_groups_api
|
||||||
self.created_at = created_at
|
self.created_at = created_at
|
||||||
|
self.telemetry_enabled = telemetry_enabled
|
||||||
|
|
||||||
ShieldRunnerMixin.__init__(
|
ShieldRunnerMixin.__init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -188,28 +183,30 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||||
turn_id = str(uuid.uuid4())
|
turn_id = str(uuid.uuid4())
|
||||||
span = tracing.get_current_span()
|
if self.telemetry_enabled:
|
||||||
if span:
|
span = tracing.get_current_span()
|
||||||
span.set_attribute("session_id", request.session_id)
|
if span is not None:
|
||||||
span.set_attribute("agent_id", self.agent_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
span.set_attribute("request", request.model_dump_json())
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
span.set_attribute("turn_id", turn_id)
|
span.set_attribute("request", request.model_dump_json())
|
||||||
if self.agent_config.name:
|
span.set_attribute("turn_id", turn_id)
|
||||||
span.set_attribute("agent_name", self.agent_config.name)
|
if self.agent_config.name:
|
||||||
|
span.set_attribute("agent_name", self.agent_config.name)
|
||||||
|
|
||||||
await self._initialize_tools(request.toolgroups)
|
await self._initialize_tools(request.toolgroups)
|
||||||
async for chunk in self._run_turn(request, turn_id):
|
async for chunk in self._run_turn(request, turn_id):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||||
span = tracing.get_current_span()
|
if self.telemetry_enabled:
|
||||||
if span:
|
span = tracing.get_current_span()
|
||||||
span.set_attribute("agent_id", self.agent_id)
|
if span is not None:
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
span.set_attribute("request", request.model_dump_json())
|
span.set_attribute("session_id", request.session_id)
|
||||||
span.set_attribute("turn_id", request.turn_id)
|
span.set_attribute("request", request.model_dump_json())
|
||||||
if self.agent_config.name:
|
span.set_attribute("turn_id", request.turn_id)
|
||||||
span.set_attribute("agent_name", self.agent_config.name)
|
if self.agent_config.name:
|
||||||
|
span.set_attribute("agent_name", self.agent_config.name)
|
||||||
|
|
||||||
await self._initialize_tools()
|
await self._initialize_tools()
|
||||||
async for chunk in self._run_turn(request):
|
async for chunk in self._run_turn(request):
|
||||||
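The hunks above repeat the same guard: only touch the tracing span when telemetry is enabled and a span is actually active. A compact helper expressing that guard once; this is a sketch, not part of the change:

```python
def set_span_attributes(span, telemetry_enabled: bool, **attributes) -> None:
    """Set attributes on a tracing span only when telemetry is on and a span exists."""
    if not telemetry_enabled or span is None:
        return
    for key, value in attributes.items():
        span.set_attribute(key, value)
```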
|
|
@ -395,9 +392,12 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
touchpoint: str,
|
touchpoint: str,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
async with tracing.span("run_shields") as span:
|
async with tracing.span("run_shields") as span:
|
||||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
if self.telemetry_enabled and span is not None:
|
||||||
|
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||||
|
if len(shields) == 0:
|
||||||
|
span.set_attribute("output", "no shields")
|
||||||
|
|
||||||
if len(shields) == 0:
|
if len(shields) == 0:
|
||||||
span.set_attribute("output", "no shields")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
|
|
@ -430,7 +430,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
span.set_attribute("output", e.violation.model_dump_json())
|
if self.telemetry_enabled and span is not None:
|
||||||
|
span.set_attribute("output", e.violation.model_dump_json())
|
||||||
|
|
||||||
yield CompletionMessage(
|
yield CompletionMessage(
|
||||||
content=str(e),
|
content=str(e),
|
||||||
|
|
@ -453,7 +454,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
span.set_attribute("output", "no violations")
|
if self.telemetry_enabled and span is not None:
|
||||||
|
span.set_attribute("output", "no violations")
|
||||||
|
|
||||||
async def _run(
|
async def _run(
|
||||||
self,
|
self,
|
||||||
|
|
@ -518,8 +520,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
stop_reason: StopReason | None = None
|
stop_reason: StopReason | None = None
|
||||||
|
|
||||||
async with tracing.span("inference") as span:
|
async with tracing.span("inference") as span:
|
||||||
if self.agent_config.name:
|
if self.telemetry_enabled and span is not None:
|
||||||
span.set_attribute("agent_name", self.agent_config.name)
|
if self.agent_config.name:
|
||||||
|
span.set_attribute("agent_name", self.agent_config.name)
|
||||||
|
|
||||||
def _serialize_nested(value):
|
def _serialize_nested(value):
|
||||||
"""Recursively serialize nested Pydantic models to dicts."""
|
"""Recursively serialize nested Pydantic models to dicts."""
|
||||||
|
|
@ -637,18 +640,19 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected delta type {type(delta)}")
|
raise ValueError(f"Unexpected delta type {type(delta)}")
|
||||||
|
|
||||||
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
|
if self.telemetry_enabled and span is not None:
|
||||||
span.set_attribute(
|
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
|
||||||
"input",
|
span.set_attribute(
|
||||||
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
"input",
|
||||||
)
|
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
||||||
output_attr = json.dumps(
|
)
|
||||||
{
|
output_attr = json.dumps(
|
||||||
"content": content,
|
{
|
||||||
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
"content": content,
|
||||||
}
|
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
||||||
)
|
}
|
||||||
span.set_attribute("output", output_attr)
|
)
|
||||||
|
span.set_attribute("output", output_attr)
|
||||||
|
|
||||||
n_iter += 1
|
n_iter += 1
|
||||||
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
|
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
|
||||||
|
|
@ -756,7 +760,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
{
|
{
|
||||||
"tool_name": tool_call.tool_name,
|
"tool_name": tool_call.tool_name,
|
||||||
"input": message.model_dump_json(),
|
"input": message.model_dump_json(),
|
||||||
},
|
}
|
||||||
|
if self.telemetry_enabled
|
||||||
|
else {},
|
||||||
) as span:
|
) as span:
|
||||||
tool_execution_start_time = datetime.now(UTC).isoformat()
|
tool_execution_start_time = datetime.now(UTC).isoformat()
|
||||||
tool_result = await self.execute_tool_call_maybe(
|
tool_result = await self.execute_tool_call_maybe(
|
||||||
|
|
@ -771,7 +777,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
call_id=tool_call.call_id,
|
call_id=tool_call.call_id,
|
||||||
content=tool_result.content,
|
content=tool_result.content,
|
||||||
)
|
)
|
||||||
span.set_attribute("output", result_message.model_dump_json())
|
if self.telemetry_enabled and span is not None:
|
||||||
|
span.set_attribute("output", result_message.model_dump_json())
|
||||||
|
|
||||||
# Store tool execution step
|
# Store tool execution step
|
||||||
tool_execution_step = ToolExecutionStep(
|
tool_execution_step = ToolExecutionStep(
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
tool_runtime_api: ToolRuntime,
|
tool_runtime_api: ToolRuntime,
|
||||||
tool_groups_api: ToolGroups,
|
tool_groups_api: ToolGroups,
|
||||||
policy: list[AccessRule],
|
policy: list[AccessRule],
|
||||||
|
telemetry_enabled: bool = False,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
|
|
@ -71,6 +72,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
self.tool_runtime_api = tool_runtime_api
|
self.tool_runtime_api = tool_runtime_api
|
||||||
self.tool_groups_api = tool_groups_api
|
self.tool_groups_api = tool_groups_api
|
||||||
|
self.telemetry_enabled = telemetry_enabled
|
||||||
|
|
||||||
self.in_memory_store = InmemoryKVStoreImpl()
|
self.in_memory_store = InmemoryKVStoreImpl()
|
||||||
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
||||||
|
|
@ -135,6 +137,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
),
|
),
|
||||||
created_at=agent_info.created_at,
|
created_at=agent_info.created_at,
|
||||||
policy=self.policy,
|
policy=self.policy,
|
||||||
|
telemetry_enabled=self.telemetry_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_agent_session(
|
async def create_agent_session(
|
||||||
|
|
|
||||||
|
|
@ -269,7 +269,7 @@ class OpenAIResponsesImpl:
|
||||||
response_tools=tools,
|
response_tools=tools,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
     response_format=response_format,
-    inputs=input,
+    inputs=all_input,
 )

 # Create orchestrator and delegate streaming logic

@@ -97,6 +97,8 @@ class StreamingResponseOrchestrator:
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
         # Track final messages after all tool executions
         self.final_messages: list[OpenAIMessageParam] = []
+        # mapping for annotations
+        self.citation_files: dict[str, str] = {}

     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Initialize output messages
@@ -126,6 +128,7 @@ class StreamingResponseOrchestrator:
             # Text is the default response format for chat completion so don't need to pass it
             # (some providers don't support non-empty response_format when tools are present)
             response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
+            logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
             completion_result = await self.inference_api.openai_chat_completion(
                 model=self.ctx.model,
                 messages=messages,
@@ -160,7 +163,7 @@ class StreamingResponseOrchestrator:
             # Handle choices with no tool calls
             for choice in current_response.choices:
                 if not (choice.message.tool_calls and self.ctx.response_tools):
-                    output_messages.append(await convert_chat_choice_to_response_message(choice))
+                    output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files))

             # Execute tool calls and coordinate results
             async for stream_event in self._coordinate_tool_execution(
@@ -172,6 +175,8 @@ class StreamingResponseOrchestrator:
             ):
                 yield stream_event

+            messages = next_turn_messages
+
             if not function_tool_calls and not non_function_tool_calls:
                 break

@@ -184,9 +189,7 @@ class StreamingResponseOrchestrator:
                 logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
                 break

-            messages = next_turn_messages
-
-        self.final_messages = messages.copy() + [current_response.choices[0].message]
+            self.final_messages = messages.copy()

         # Create final response
         final_response = OpenAIResponseObject(
@@ -211,6 +214,8 @@ class StreamingResponseOrchestrator:

         for choice in current_response.choices:
             next_turn_messages.append(choice.message)
+            logger.debug(f"Choice message content: {choice.message.content}")
+            logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")

             if choice.message.tool_calls and self.ctx.response_tools:
                 for tool_call in choice.message.tool_calls:
@@ -227,9 +232,11 @@ class StreamingResponseOrchestrator:
                             non_function_tool_calls.append(tool_call)
                         else:
                             logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
+                            next_turn_messages.pop()
                     else:
                         logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
                         approvals.append(tool_call)
+                        next_turn_messages.pop()
                 else:
                     non_function_tool_calls.append(tool_call)

@@ -470,6 +477,8 @@ class StreamingResponseOrchestrator:
            tool_call_log = result.final_output_message
            tool_response_message = result.final_input_message
            self.sequence_number = result.sequence_number
+           if result.citation_files:
+               self.citation_files.update(result.citation_files)

            if tool_call_log:
                output_messages.append(tool_call_log)
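The orchestrator change above threads a `citation_files` mapping from tool execution into message conversion. A minimal standalone sketch of that accumulate-then-convert pattern, using hypothetical result payloads rather than Llama Stack objects:

```python
# Standalone sketch (not Llama Stack code): accumulate citation metadata from
# tool-call results, then hand the mapping to the message-conversion step.
citation_files: dict[str, str] = {}

tool_results = [
    {"citation_files": {"file-abc123": "guide.pdf"}},  # hypothetical payloads
    {"citation_files": None},
]

for result in tool_results:
    files = result.get("citation_files")
    if files:
        citation_files.update(files)

print(citation_files)  # {'file-abc123': 'guide.pdf'}
```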
@@ -94,7 +94,10 @@ class ToolExecutor:

        # Yield the final result
        yield ToolExecutionResult(
-           sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
+           sequence_number=sequence_number,
+           final_output_message=output_message,
+           final_input_message=input_message,
+           citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
        )

    async def _execute_knowledge_search_via_vector_store(
@@ -129,8 +132,6 @@ class ToolExecutor:
        for results in all_results:
            search_results.extend(results)

-       # Convert search results to tool result format matching memory.py
-       # Format the results as interleaved content similar to memory.py
        content_items = []
        content_items.append(
            TextContentItem(
@@ -138,27 +139,58 @@ class ToolExecutor:
            )
        )

+       unique_files = set()
        for i, result_item in enumerate(search_results):
            chunk_text = result_item.content[0].text if result_item.content else ""
-           metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
+           # Get file_id from attributes if result_item.file_id is empty
+           file_id = result_item.file_id or (
+               result_item.attributes.get("document_id") if result_item.attributes else None
+           )
+           metadata_text = f"document_id: {file_id}, score: {result_item.score}"
            if result_item.attributes:
                metadata_text += f", attributes: {result_item.attributes}"
-           text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
+           text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
            content_items.append(TextContentItem(text=text_content))
+           unique_files.add(file_id)

        content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))

+       citation_instruction = ""
+       if unique_files:
+           citation_instruction = (
+               " Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
+               "Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
+           )
+
        content_items.append(
            TextContentItem(
-               text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
+               text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
            )
        )

+       # handling missing attributes for old versions
+       citation_files = {}
+       for result in search_results:
+           file_id = result.file_id
+           if not file_id and result.attributes:
+               file_id = result.attributes.get("document_id")
+
+           filename = result.filename
+           if not filename and result.attributes:
+               filename = result.attributes.get("filename")
+           if not filename:
+               filename = "unknown"
+
+           citation_files[file_id] = filename
+
        return ToolInvocationResult(
            content=content_items,
            metadata={
                "document_ids": [r.file_id for r in search_results],
                "chunks": [r.content[0].text if r.content else "" for r in search_results],
                "scores": [r.score for r in search_results],
+               "citation_files": citation_files,
            },
        )
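The knowledge-search change above does two things per hit: it appends a "cite as `<|file-id|>`" hint to the numbered chunk text, and it builds an id-to-filename map for later annotation. A small standalone sketch of that formatting, with hypothetical hits:

```python
# Standalone sketch (not Llama Stack code): format search hits the way the new
# knowledge_search path does, and collect an id -> filename map alongside.
hits = [
    {"file_id": "file-abc123", "filename": "guide.pdf", "score": 0.92, "text": "Chunk one."},
    {"file_id": "file-def456", "filename": None, "score": 0.81, "text": "Chunk two."},
]

lines = []
citation_files = {}
for i, hit in enumerate(hits):
    file_id = hit["file_id"]
    lines.append(
        f"[{i + 1}] document_id: {file_id}, score: {hit['score']} (cite as <|{file_id}|>)\n{hit['text']}\n"
    )
    citation_files[file_id] = hit["filename"] or "unknown"

print("".join(lines))
print(citation_files)  # {'file-abc123': 'guide.pdf', 'file-def456': 'unknown'}
```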
@@ -27,6 +27,7 @@ class ToolExecutionResult(BaseModel):
    sequence_number: int
    final_output_message: OpenAIResponseOutput | None = None
    final_input_message: OpenAIMessageParam | None = None
+   citation_files: dict[str, str] | None = None


@dataclass
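Because the new `citation_files` field defaults to `None`, existing callers that never pass it keep working. A simplified, hypothetical stand-in model showing just that behaviour:

```python
# Simplified stand-in (not the real ToolExecutionResult): an optional
# citation_files mapping that defaults to None, so older callers are unaffected.
from pydantic import BaseModel


class ToolResultSketch(BaseModel):  # hypothetical name; mirrors the fields shown above
    sequence_number: int
    citation_files: dict[str, str] | None = None


print(ToolResultSketch(sequence_number=1).citation_files)  # None
print(ToolResultSketch(sequence_number=2, citation_files={"file-abc": "a.pdf"}).citation_files)
```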
@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import re
 import uuid

 from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseAnnotationFileCitation,
     OpenAIResponseInput,
     OpenAIResponseInputFunctionToolCallOutput,
     OpenAIResponseInputMessageContent,
@@ -45,7 +47,9 @@ from llama_stack.apis.inference import (
 )


-async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
+async def convert_chat_choice_to_response_message(
+    choice: OpenAIChoice, citation_files: dict[str, str] | None = None
+) -> OpenAIResponseMessage:
     """Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
     output_content = ""
     if isinstance(choice.message.content, str):
@@ -57,9 +61,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
             f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
         )

+    annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
+
     return OpenAIResponseMessage(
         id=f"msg_{uuid.uuid4()}",
-        content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
+        content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
         status="completed",
         role="assistant",
     )
@@ -200,6 +206,53 @@ async def get_message_type_by_role(role: str):
     return role_to_type.get(role)


+def _extract_citations_from_text(
+    text: str, citation_files: dict[str, str]
+) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
+    """Extract citation markers from text and create annotations
+
+    Args:
+        text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A]
+        citation_files: Dictionary mapping file_id to filename
+
+    Returns:
+        Tuple of (annotations_list, clean_text_without_markers)
+    """
+    file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
+
+    annotations = []
+    parts = []
+    total_len = 0
+    last_end = 0
+
+    for m in file_id_regex.finditer(text):
+        # segment before the marker
+        prefix = text[last_end : m.start()]
+
+        # drop one space if it exists (since marker is at sentence end)
+        if prefix.endswith(" "):
+            prefix = prefix[:-1]
+
+        parts.append(prefix)
+        total_len += len(prefix)
+
+        fid = m.group(1)
+        if fid in citation_files:
+            annotations.append(
+                OpenAIResponseAnnotationFileCitation(
+                    file_id=fid,
+                    filename=citation_files[fid],
+                    index=total_len,  # index points to punctuation
+                )
+            )
+
+        last_end = m.end()
+
+    parts.append(text[last_end:])
+    cleaned_text = "".join(parts)
+    return annotations, cleaned_text
+
+
 def is_function_tool_call(
     tool_call: OpenAIChatCompletionToolCall,
     tools: list[OpenAIResponseInputTool],
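To make the extraction behaviour concrete, here is a self-contained sketch of the same regex-based logic (reimplemented with plain dicts so it runs without Llama Stack): markers like `<|file-abc123|>` are stripped from the text, and each recognised marker becomes an annotation whose index points at the character the marker used to precede.

```python
# Standalone sketch of the citation-extraction behaviour added above.
import re

FILE_ID = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")


def extract_citations(text: str, citation_files: dict[str, str]) -> tuple[list[dict], str]:
    annotations, parts, total_len, last_end = [], [], 0, 0
    for m in FILE_ID.finditer(text):
        prefix = text[last_end : m.start()]
        if prefix.endswith(" "):  # marker sits at the end of a sentence
            prefix = prefix[:-1]
        parts.append(prefix)
        total_len += len(prefix)
        fid = m.group("file_id")
        if fid in citation_files:
            annotations.append({"file_id": fid, "filename": citation_files[fid], "index": total_len})
        last_end = m.end()
    parts.append(text[last_end:])
    return annotations, "".join(parts)


anns, clean = extract_citations(
    "Paris is the capital of France <|file-abc123|>.",
    {"file-abc123": "geography.pdf"},
)
print(clean)  # Paris is the capital of France.
print(anns)   # [{'file_id': 'file-abc123', 'filename': 'geography.pdf', 'index': 30}]
```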
@@ -8,8 +8,6 @@ import asyncio
 import base64
 import io
 import mimetypes
-import secrets
-import string
 from typing import Any

 import httpx
@@ -52,10 +50,6 @@ from .context_retriever import generate_rag_query
 log = get_logger(name=__name__, category="tool_runtime")


-def make_random_string(length: int = 8):
-    return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
-
-
 async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
     """Get raw binary data and mime type from a RAGDocument for file upload."""
     if isinstance(doc.content, URL):
@@ -331,5 +325,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti

         return ToolInvocationResult(
             content=result.content or [],
-            metadata=result.metadata,
+            metadata={
+                **(result.metadata or {}),
+                "citation_files": getattr(result, "citation_files", None),
+            },
         )
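The metadata change above is a defensive merge: start from whatever metadata the wrapped tool returned (which may be `None`) and layer the `citation_files` key on top. A tiny sketch of that pattern with hypothetical values:

```python
# Standalone sketch of the None-safe metadata merge used above.
result_metadata = None  # hypothetical: the wrapped tool returned no metadata

merged = {
    **(result_metadata or {}),
    "citation_files": {"file-abc123": "guide.pdf"},
}
print(merged)  # {'citation_files': {'file-abc123': 'guide.pdf'}}
```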
@@ -225,8 +225,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
         await self.initialize_openai_vector_stores()

     async def shutdown(self) -> None:
-        # Cleanup if needed
-        pass
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()

     async def health(self) -> HealthResponse:
         """
@@ -434,8 +434,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
         await self.initialize_openai_vector_stores()

     async def shutdown(self) -> None:
-        # nothing to do since we don't maintain a persistent connection
-        pass
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()

     async def list_vector_dbs(self) -> list[VectorDB]:
         return [v.vector_db for v in self.cache.values()]
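Both adapters replace a no-op `pass` with a call up the MRO so the shared vector-store mixin can release its resources. A minimal, hypothetical sketch of that cooperative shutdown pattern (class names here are stand-ins, not the real mixin):

```python
# Standalone sketch of delegating shutdown to a mixin instead of doing nothing.
import asyncio


class StoreMixinSketch:  # hypothetical stand-in for the vector-store mixin
    async def shutdown(self) -> None:
        print("mixin cleanup: cancel background file-batch tasks")


class AdapterSketch(StoreMixinSketch):
    async def shutdown(self) -> None:
        # Delegate so shared resources are cleaned up rather than silently skipped.
        await super().shutdown()


asyncio.run(AdapterSketch().shutdown())
```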
@@ -36,6 +36,9 @@ def available_providers() -> list[ProviderSpec]:
                 Api.tool_runtime,
                 Api.tool_groups,
             ],
+            optional_api_dependencies=[
+                Api.telemetry,
+            ],
             description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
         ),
     ]
@@ -268,7 +268,7 @@ Available Models:
             api=Api.inference,
             adapter_type="watsonx",
             provider_type="remote::watsonx",
-            pip_packages=["ibm_watsonx_ai"],
+            pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.watsonx",
             config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
             provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
@@ -11,6 +11,7 @@ from llama_stack.providers.datatypes import (
     ProviderSpec,
     RemoteProviderSpec,
 )
+from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS


 def available_providers() -> list[ProviderSpec]:
@@ -18,9 +19,8 @@ def available_providers() -> list[ProviderSpec]:
         InlineProviderSpec(
             api=Api.tool_runtime,
             provider_type="inline::rag-runtime",
-            pip_packages=[
-                "chardet",
-                "pypdf",
+            pip_packages=DEFAULT_VECTOR_IO_DEPS
+            + [
                 "tqdm",
                 "numpy",
                 "scikit-learn",
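The dependency refactor above replaces repeated inline package lists with one shared constant that each provider concatenates onto its own requirements. A short sketch of the resulting lists, assuming the two values shown in the diff:

```python
# Standalone sketch of the shared-dependency pattern introduced above.
DEFAULT_VECTOR_IO_DEPS = ["chardet", "pypdf"]

faiss_packages = ["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS
rag_runtime_packages = DEFAULT_VECTOR_IO_DEPS + ["tqdm", "numpy", "scikit-learn"]

print(faiss_packages)        # ['faiss-cpu', 'chardet', 'pypdf']
print(rag_runtime_packages)  # ['chardet', 'pypdf', 'tqdm', 'numpy', 'scikit-learn']
```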
@@ -12,13 +12,16 @@ from llama_stack.providers.datatypes import (
     RemoteProviderSpec,
 )

+# Common dependencies for all vector IO providers that support document processing
+DEFAULT_VECTOR_IO_DEPS = ["chardet", "pypdf"]
+

 def available_providers() -> list[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::meta-reference",
-            pip_packages=["faiss-cpu"],
+            pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.faiss",
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",
@@ -29,7 +32,7 @@ def available_providers() -> list[ProviderSpec]:
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::faiss",
-            pip_packages=["faiss-cpu"],
+            pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.faiss",
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
             api_dependencies=[Api.inference],
@@ -82,7 +85,7 @@ more details about Faiss in general.
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::sqlite-vec",
-            pip_packages=["sqlite-vec"],
+            pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.sqlite_vec",
             config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
             api_dependencies=[Api.inference],
@@ -289,7 +292,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::sqlite_vec",
-            pip_packages=["sqlite-vec"],
+            pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.sqlite_vec",
             config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
             deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
@@ -303,7 +306,7 @@ Please refer to the sqlite-vec provider documentation.
             api=Api.vector_io,
             adapter_type="chromadb",
             provider_type="remote::chromadb",
-            pip_packages=["chromadb-client"],
+            pip_packages=["chromadb-client"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.remote.vector_io.chroma",
             config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
             api_dependencies=[Api.inference],
@@ -345,7 +348,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::chromadb",
-            pip_packages=["chromadb"],
+            pip_packages=["chromadb"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.chroma",
             config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
             api_dependencies=[Api.inference],
@@ -389,7 +392,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
             api=Api.vector_io,
             adapter_type="pgvector",
             provider_type="remote::pgvector",
-            pip_packages=["psycopg2-binary"],
+            pip_packages=["psycopg2-binary"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.remote.vector_io.pgvector",
             config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
             api_dependencies=[Api.inference],
@@ -500,7 +503,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
             api=Api.vector_io,
             adapter_type="weaviate",
             provider_type="remote::weaviate",
-            pip_packages=["weaviate-client>=4.16.5"],
+            pip_packages=["weaviate-client>=4.16.5"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.remote.vector_io.weaviate",
             config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
             provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
@@ -541,7 +544,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::qdrant",
-            pip_packages=["qdrant-client"],
+            pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.qdrant",
             config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
             api_dependencies=[Api.inference],
@@ -594,7 +597,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
             api=Api.vector_io,
             adapter_type="qdrant",
             provider_type="remote::qdrant",
-            pip_packages=["qdrant-client"],
+            pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.remote.vector_io.qdrant",
             config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
             api_dependencies=[Api.inference],
@@ -607,7 +610,7 @@ Please refer to the inline provider documentation.
             api=Api.vector_io,
             adapter_type="milvus",
             provider_type="remote::milvus",
-            pip_packages=["pymilvus>=2.4.10"],
+            pip_packages=["pymilvus>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.remote.vector_io.milvus",
             config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
             api_dependencies=[Api.inference],
@@ -813,7 +816,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::milvus",
-            pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
+            pip_packages=["pymilvus[milvus-lite]>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.milvus",
             config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
             api_dependencies=[Api.inference],
@@ -41,9 +41,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
             ).serving_endpoints.list()  # TODO: this is not async
         ]

-    async def should_refresh_models(self) -> bool:
-        return False
-
     async def openai_completion(
         self,
         model: str,
@ -1,217 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from openai import AsyncStream
|
|
||||||
from openai.types.chat.chat_completion import (
|
|
||||||
Choice as OpenAIChoice,
|
|
||||||
)
|
|
||||||
from openai.types.completion import Completion as OpenAICompletion
|
|
||||||
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
|
||||||
ChatCompletionRequest,
|
|
||||||
CompletionRequest,
|
|
||||||
CompletionResponse,
|
|
||||||
CompletionResponseStreamChunk,
|
|
||||||
GreedySamplingStrategy,
|
|
||||||
JsonSchemaResponseFormat,
|
|
||||||
TokenLogProbs,
|
|
||||||
TopKSamplingStrategy,
|
|
||||||
TopPSamplingStrategy,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
|
||||||
_convert_openai_finish_reason,
|
|
||||||
convert_message_to_openai_dict_new,
|
|
||||||
convert_tooldef_to_openai_tool,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def convert_chat_completion_request(
|
|
||||||
request: ChatCompletionRequest,
|
|
||||||
n: int = 1,
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
|
|
||||||
"""
|
|
||||||
# model -> model
|
|
||||||
# messages -> messages
|
|
||||||
# sampling_params TODO(mattf): review strategy
|
|
||||||
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
|
|
||||||
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
|
|
||||||
# strategy=top_k -> nvext.top_k = top_k
|
|
||||||
# temperature -> temperature
|
|
||||||
# top_p -> top_p
|
|
||||||
# top_k -> nvext.top_k
|
|
||||||
# max_tokens -> max_tokens
|
|
||||||
# repetition_penalty -> nvext.repetition_penalty
|
|
||||||
# response_format -> GrammarResponseFormat TODO(mf)
|
|
||||||
# response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
|
|
||||||
# tools -> tools
|
|
||||||
# tool_choice ("auto", "required") -> tool_choice
|
|
||||||
# tool_prompt_format -> TBD
|
|
||||||
# stream -> stream
|
|
||||||
# logprobs -> logprobs
|
|
||||||
|
|
||||||
if request.response_format and not isinstance(request.response_format, JsonSchemaResponseFormat):
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported response format: {request.response_format}. Only JsonSchemaResponseFormat is supported."
|
|
||||||
)
|
|
||||||
|
|
||||||
nvext = {}
|
|
||||||
payload: dict[str, Any] = dict(
|
|
||||||
model=request.model,
|
|
||||||
messages=[await convert_message_to_openai_dict_new(message) for message in request.messages],
|
|
||||||
stream=request.stream,
|
|
||||||
n=n,
|
|
||||||
extra_body=dict(nvext=nvext),
|
|
||||||
extra_headers={
|
|
||||||
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if request.response_format:
|
|
||||||
# server bug - setting guided_json changes the behavior of response_format resulting in an error
|
|
||||||
# payload.update(response_format="json_object")
|
|
||||||
nvext.update(guided_json=request.response_format.json_schema)
|
|
||||||
|
|
||||||
if request.tools:
|
|
||||||
payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools])
|
|
||||||
if request.tool_config.tool_choice:
|
|
||||||
payload.update(
|
|
||||||
tool_choice=request.tool_config.tool_choice.value
|
|
||||||
) # we cannot include tool_choice w/o tools, server will complain
|
|
||||||
|
|
||||||
if request.logprobs:
|
|
||||||
payload.update(logprobs=True)
|
|
||||||
payload.update(top_logprobs=request.logprobs.top_k)
|
|
||||||
|
|
||||||
if request.sampling_params:
|
|
||||||
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
|
|
||||||
|
|
||||||
if request.sampling_params.max_tokens:
|
|
||||||
payload.update(max_tokens=request.sampling_params.max_tokens)
|
|
||||||
|
|
||||||
strategy = request.sampling_params.strategy
|
|
||||||
if isinstance(strategy, TopPSamplingStrategy):
|
|
||||||
nvext.update(top_k=-1)
|
|
||||||
payload.update(top_p=strategy.top_p)
|
|
||||||
payload.update(temperature=strategy.temperature)
|
|
||||||
elif isinstance(strategy, TopKSamplingStrategy):
|
|
||||||
if strategy.top_k != -1 and strategy.top_k < 1:
|
|
||||||
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
|
|
||||||
nvext.update(top_k=strategy.top_k)
|
|
||||||
elif isinstance(strategy, GreedySamplingStrategy):
|
|
||||||
nvext.update(top_k=-1)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported sampling strategy: {strategy}")
|
|
||||||
|
|
||||||
return payload
|
|
||||||
|
|
||||||
|
|
||||||
def convert_completion_request(
|
|
||||||
request: CompletionRequest,
|
|
||||||
n: int = 1,
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
|
|
||||||
"""
|
|
||||||
# model -> model
|
|
||||||
# prompt -> prompt
|
|
||||||
# sampling_params TODO(mattf): review strategy
|
|
||||||
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
|
|
||||||
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
|
|
||||||
# strategy=top_k -> nvext.top_k = top_k
|
|
||||||
# temperature -> temperature
|
|
||||||
# top_p -> top_p
|
|
||||||
# top_k -> nvext.top_k
|
|
||||||
# max_tokens -> max_tokens
|
|
||||||
# repetition_penalty -> nvext.repetition_penalty
|
|
||||||
# response_format -> nvext.guided_json
|
|
||||||
# stream -> stream
|
|
||||||
# logprobs.top_k -> logprobs
|
|
||||||
|
|
||||||
nvext = {}
|
|
||||||
payload: dict[str, Any] = dict(
|
|
||||||
model=request.model,
|
|
||||||
prompt=request.content,
|
|
||||||
stream=request.stream,
|
|
||||||
extra_body=dict(nvext=nvext),
|
|
||||||
extra_headers={
|
|
||||||
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
|
|
||||||
},
|
|
||||||
n=n,
|
|
||||||
)
|
|
||||||
|
|
||||||
if request.response_format:
|
|
||||||
# this is not openai compliant, it is a nim extension
|
|
||||||
nvext.update(guided_json=request.response_format.json_schema)
|
|
||||||
|
|
||||||
if request.logprobs:
|
|
||||||
payload.update(logprobs=request.logprobs.top_k)
|
|
||||||
|
|
||||||
if request.sampling_params:
|
|
||||||
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
|
|
||||||
|
|
||||||
if request.sampling_params.max_tokens:
|
|
||||||
payload.update(max_tokens=request.sampling_params.max_tokens)
|
|
||||||
|
|
||||||
if request.sampling_params.strategy == "top_p":
|
|
||||||
nvext.update(top_k=-1)
|
|
||||||
payload.update(top_p=request.sampling_params.top_p)
|
|
||||||
elif request.sampling_params.strategy == "top_k":
|
|
||||||
if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
|
|
||||||
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
|
|
||||||
nvext.update(top_k=request.sampling_params.top_k)
|
|
||||||
elif request.sampling_params.strategy == "greedy":
|
|
||||||
nvext.update(top_k=-1)
|
|
||||||
payload.update(temperature=request.sampling_params.temperature)
|
|
||||||
|
|
||||||
return payload
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_openai_completion_logprobs(
|
|
||||||
logprobs: OpenAICompletionLogprobs | None,
|
|
||||||
) -> list[TokenLogProbs] | None:
|
|
||||||
"""
|
|
||||||
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
|
|
||||||
"""
|
|
||||||
if not logprobs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return [TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs]
|
|
||||||
|
|
||||||
|
|
||||||
def convert_openai_completion_choice(
|
|
||||||
choice: OpenAIChoice,
|
|
||||||
) -> CompletionResponse:
|
|
||||||
"""
|
|
||||||
Convert an OpenAI Completion Choice into a CompletionResponse.
|
|
||||||
"""
|
|
||||||
return CompletionResponse(
|
|
||||||
content=choice.text,
|
|
||||||
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
|
|
||||||
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def convert_openai_completion_stream(
|
|
||||||
stream: AsyncStream[OpenAICompletion],
|
|
||||||
) -> AsyncGenerator[CompletionResponse, None]:
|
|
||||||
"""
|
|
||||||
Convert a stream of OpenAI Completions into a stream
|
|
||||||
of ChatCompletionResponseStreamChunks.
|
|
||||||
"""
|
|
||||||
async for chunk in stream:
|
|
||||||
choice = chunk.choices[0]
|
|
||||||
yield CompletionResponseStreamChunk(
|
|
||||||
delta=choice.text,
|
|
||||||
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
|
|
||||||
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
|
|
||||||
)
|
|
||||||
|
|
@@ -4,53 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import httpx
-
-from llama_stack.log import get_logger
-
 from . import NVIDIAConfig

-logger = get_logger(name=__name__, category="inference::nvidia")
-

 def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
     return "integrate.api.nvidia.com" in config.url
-
-
-async def _get_health(url: str) -> tuple[bool, bool]:
-    """
-    Query {url}/v1/health/{live,ready} to check if the server is running and ready
-
-    Args:
-        url (str): URL of the server
-
-    Returns:
-        Tuple[bool, bool]: (is_live, is_ready)
-    """
-    async with httpx.AsyncClient() as client:
-        live = await client.get(f"{url}/v1/health/live")
-        ready = await client.get(f"{url}/v1/health/ready")
-        return live.status_code == 200, ready.status_code == 200
-
-
-async def check_health(config: NVIDIAConfig) -> None:
-    """
-    Check if the server is running and ready
-
-    Args:
-        url (str): URL of the server
-
-    Raises:
-        RuntimeError: If the server is not running or ready
-    """
-    if not _is_nvidia_hosted(config):
-        logger.info("Checking NVIDIA NIM health...")
-        try:
-            is_live, is_ready = await _get_health(config.url)
-            if not is_live:
-                raise ConnectionError("NVIDIA NIM is not running")
-            if not is_ready:
-                raise ConnectionError("NVIDIA NIM is not ready")
-            # TODO(mf): should we wait for the server to be ready?
-        except httpx.ConnectError as e:
-            raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
@@ -6,8 +6,6 @@

 from typing import Any

-from pydantic import Field
-
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig

 DEFAULT_OLLAMA_URL = "http://localhost:11434"
@@ -15,10 +13,6 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"

 class OllamaImplConfig(RemoteInferenceProviderConfig):
     url: str = DEFAULT_OLLAMA_URL
-    refresh_models: bool = Field(
-        default=False,
-        description="Whether to refresh models periodically",
-    )

     @classmethod
     def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
@@ -72,9 +72,6 @@ class OllamaInferenceAdapter(OpenAIMixin):
             f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
         )

-    async def should_refresh_models(self) -> bool:
-        return self.config.refresh_models
-
     async def health(self) -> HealthResponse:
         """
         Performs a health check by verifying connectivity to the Ollama server.
@@ -11,6 +11,6 @@ async def get_adapter_impl(config: RunpodImplConfig, _deps):
     from .runpod import RunpodInferenceAdapter

     assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
-    impl = RunpodInferenceAdapter(config)
+    impl = RunpodInferenceAdapter(config=config)
     await impl.initialize()
     return impl
@ -4,69 +4,86 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import (
|
||||||
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
|
OpenAIMessageParam,
|
||||||
|
OpenAIResponseFormatParam,
|
||||||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
|
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
|
||||||
get_sampling_options,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
||||||
chat_completion_request_to_prompt,
|
|
||||||
)
|
)
|
||||||
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
|
||||||
from .config import RunpodImplConfig
|
from .config import RunpodImplConfig
|
||||||
|
|
||||||
# https://docs.runpod.io/serverless/vllm/overview#compatible-models
|
|
||||||
# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
|
|
||||||
RUNPOD_SUPPORTED_MODELS = {
|
|
||||||
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
|
|
||||||
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
|
|
||||||
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
|
|
||||||
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
|
|
||||||
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
|
|
||||||
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
|
|
||||||
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
|
|
||||||
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
|
|
||||||
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
|
|
||||||
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
|
|
||||||
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
|
|
||||||
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
|
|
||||||
}
|
|
||||||
|
|
||||||
SAFETY_MODELS_ENTRIES = []
|
class RunpodInferenceAdapter(OpenAIMixin):
|
||||||
|
"""
|
||||||
|
Adapter for RunPod's OpenAI-compatible API endpoints.
|
||||||
|
Supports VLLM for serverless endpoint self-hosted or public endpoints.
|
||||||
|
Can work with any runpod endpoints that support OpenAI-compatible API
|
||||||
|
"""
|
||||||
|
|
||||||
# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
|
config: RunpodImplConfig
|
||||||
MODEL_ENTRIES = [
|
|
||||||
build_hf_repo_model_entry(provider_model_id, model_descriptor)
|
|
||||||
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
|
|
||||||
] + SAFETY_MODELS_ENTRIES
|
|
||||||
|
|
||||||
|
def get_api_key(self) -> str:
|
||||||
|
"""Get API key for OpenAI client."""
|
||||||
|
return self.config.api_token
|
||||||
|
|
||||||
class RunpodInferenceAdapter(
|
def get_base_url(self) -> str:
|
||||||
ModelRegistryHelper,
|
"""Get base URL for OpenAI client."""
|
||||||
Inference,
|
return self.config.url
|
||||||
):
|
|
||||||
def __init__(self, config: RunpodImplConfig) -> None:
|
|
||||||
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
async def openai_chat_completion(
|
||||||
return {
|
|
||||||
"model": self.map_to_provider_model(request.model),
|
|
||||||
"prompt": chat_completion_request_to_prompt(request),
|
|
||||||
"stream": request.stream,
|
|
||||||
**get_sampling_options(request.sampling_params),
|
|
||||||
}
|
|
||||||
|
|
||||||
async def openai_embeddings(
|
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
input: str | list[str],
|
messages: list[OpenAIMessageParam],
|
||||||
encoding_format: str | None = "float",
|
frequency_penalty: float | None = None,
|
||||||
dimensions: int | None = None,
|
function_call: str | dict[str, Any] | None = None,
|
||||||
|
functions: list[dict[str, Any]] | None = None,
|
||||||
|
logit_bias: dict[str, float] | None = None,
|
||||||
|
logprobs: bool | None = None,
|
||||||
|
max_completion_tokens: int | None = None,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
n: int | None = None,
|
||||||
|
parallel_tool_calls: bool | None = None,
|
||||||
|
presence_penalty: float | None = None,
|
||||||
|
response_format: OpenAIResponseFormatParam | None = None,
|
||||||
|
seed: int | None = None,
|
||||||
|
stop: str | list[str] | None = None,
|
||||||
|
stream: bool | None = None,
|
||||||
|
stream_options: dict[str, Any] | None = None,
|
||||||
|
temperature: float | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
top_logprobs: int | None = None,
|
||||||
|
top_p: float | None = None,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
) -> OpenAIEmbeddingsResponse:
|
):
|
||||||
raise NotImplementedError()
|
"""Override to add RunPod-specific stream_options requirement."""
|
||||||
|
if stream and not stream_options:
|
||||||
|
stream_options = {"include_usage": True}
|
||||||
|
|
||||||
|
return await super().openai_chat_completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
function_call=function_call,
|
||||||
|
functions=functions,
|
||||||
|
logit_bias=logit_bias,
|
||||||
|
logprobs=logprobs,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
n=n,
|
||||||
|
parallel_tool_calls=parallel_tool_calls,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
response_format=response_format,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop,
|
||||||
|
stream=stream,
|
||||||
|
stream_options=stream_options,
|
||||||
|
temperature=temperature,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
tools=tools,
|
||||||
|
top_logprobs=top_logprobs,
|
||||||
|
top_p=top_p,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@@ -63,9 +63,6 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
         # Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
         return [m.id for m in await self._get_client().models.list()]

-    async def should_refresh_models(self) -> bool:
-        return True
-
     async def openai_embeddings(
         self,
         model: str,
@@ -30,10 +30,6 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         default=True,
         description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
     )
-    refresh_models: bool = Field(
-        default=False,
-        description="Whether to refresh models periodically",
-    )

     @field_validator("tls_verify")
     @classmethod
@@ -53,10 +53,6 @@ class VLLMInferenceAdapter(OpenAIMixin):
             "You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
         )

-    async def should_refresh_models(self) -> bool:
-        # Strictly respecting the refresh_models directive
-        return self.config.refresh_models
-
     async def health(self) -> HealthResponse:
         """
         Performs a health check by verifying connectivity to the remote vLLM server.
@@ -4,19 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
-
 from .config import WatsonXConfig


-async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
-    # import dynamically so `llama stack build` does not fail due to missing dependencies
+async def get_adapter_impl(config: WatsonXConfig, _deps):
+    # import dynamically so the import is used only when it is needed
     from .watsonx import WatsonXInferenceAdapter

-    if not isinstance(config, WatsonXConfig):
-        raise RuntimeError(f"Unexpected config type: {type(config)}")
     adapter = WatsonXInferenceAdapter(config)
     return adapter
-
-
-__all__ = ["get_adapter_impl", "WatsonXConfig"]
@@ -7,16 +7,18 @@
 import os
 from typing import Any

-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, ConfigDict, Field, SecretStr

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type


 class WatsonXProviderDataValidator(BaseModel):
-    url: str
-    api_key: str
-    project_id: str
+    model_config = ConfigDict(
+        from_attributes=True,
+        extra="forbid",
+    )
+    watsonx_api_key: str | None


 @json_schema_type
@@ -25,13 +27,17 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the watsonx.ai",
     )
+    # This seems like it should be required, but none of the other remote inference
+    # providers require it, so this is optional here too for consistency.
+    # The OpenAIConfig uses default=None instead, so this is following that precedent.
     api_key: SecretStr | None = Field(
-        default_factory=lambda: os.getenv("WATSONX_API_KEY"),
-        description="The watsonx API key",
+        default=None,
+        description="The watsonx.ai API key",
     )
+    # As above, this is optional here too for consistency.
     project_id: str | None = Field(
-        default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
-        description="The Project ID key",
+        default=None,
+        description="The watsonx.ai project ID",
    )
     timeout: int = Field(
         default=60,
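After this change only the base URL keeps an environment-variable default; the API key and project ID default to `None` and are expected to come from run config or per-request provider data. A hypothetical, standalone pydantic sketch mirroring just the fields shown in the diff (not the real `WatsonXConfig` class):

```python
# Standalone sketch mirroring the defaults shown above; field names are taken
# from the diff, the class itself is a stand-in.
import os

from pydantic import BaseModel, Field, SecretStr


class WatsonXConfigSketch(BaseModel):  # hypothetical stand-in for WatsonXConfig
    url: str = Field(default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"))
    api_key: SecretStr | None = None
    project_id: str | None = None
    timeout: int = 60


cfg = WatsonXConfigSketch()
print(cfg.url, cfg.api_key, cfg.project_id, cfg.timeout)
```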
@@ -1,47 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
-
-MODEL_ENTRIES = [
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-3-70b-instruct",
-        CoreModelId.llama3_3_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-2-13b-chat",
-        CoreModelId.llama2_13b.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-1-70b-instruct",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-1-8b-instruct",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-11b-vision-instruct",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-1b-instruct",
-        CoreModelId.llama3_2_1b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-3b-instruct",
-        CoreModelId.llama3_2_3b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-90b-vision-instruct",
-        CoreModelId.llama3_2_90b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-guard-3-11b-vision",
-        CoreModelId.llama_guard_3_11b_vision.value,
-    ),
-]
@@ -4,240 +4,120 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any

-from ibm_watsonx_ai.foundation_models import Model
-from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
-from openai import AsyncOpenAI
+import requests

-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    CompletionRequest,
-    GreedySamplingStrategy,
-    Inference,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
-from llama_stack.providers.utils.inference.openai_compat import (
-    prepare_openai_completion_params,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-    completion_request_to_prompt,
-    request_has_media,
-)
+from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.models import Model
+from llama_stack.apis.models.models import ModelType
+from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

-from . import WatsonXConfig
-from .models import MODEL_ENTRIES

-logger = get_logger(name=__name__, category="inference::watsonx")


-# Note on structured output
-# WatsonX returns responses with a json embedded into a string.
-# Examples:
+class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
+    _model_cache: dict[str, Model] = {}

-# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
-# "first_name": "Michael",\n "last_name": "Jordan",\n'...)
-# Not even a valid JSON, but we can still extract the JSON from the content
+    def __init__(self, config: WatsonXConfig):
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            litellm_provider_name="watsonx",
+            api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
+            provider_data_api_key_field="watsonx_api_key",
+        )
+        self.available_models = None
+        self.config = config

-# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
-# "year_born": "1963", "year_retired": "2003"\\}}$')
-# Find the start of the boxed content
+    def get_base_url(self) -> str:
+        return self.config.url

+    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
+        # Get base parameters from parent
+        params = await super()._get_params(request)

-class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
-    def __init__(self, config: WatsonXConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
+        # Add watsonx.ai specific parameters
+        params["project_id"] = self.config.project_id
+        params["time_limit"] = self.config.timeout

-        logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
-        self._config = config
-        self._openai_client: AsyncOpenAI | None = None

-        self._project_id = self._config.project_id

-    def _get_client(self, model_id) -> Model:
-        config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
-        config_url = self._config.url
-        project_id = self._config.project_id
-        credentials = {"url": config_url, "apikey": config_api_key}

-        return Model(model_id=model_id, credentials=credentials, project_id=project_id)

-    def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            self._openai_client = AsyncOpenAI(
-                base_url=f"{self._config.url}/openai/v1",
-                api_key=self._config.api_key,
-            )
-        return self._openai_client

-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
-        input_dict = {"params": {}}
-        media_present = request_has_media(request)
-        llama_model = self.get_llama_model(request.model)
-        if isinstance(request, ChatCompletionRequest):
-            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
-        else:
-            assert not media_present, "Together does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request)
-        if request.sampling_params:
-            if request.sampling_params.strategy:
-                input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
-            if request.sampling_params.max_tokens:
-                input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
-            if request.sampling_params.repetition_penalty:
-                input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty

-            if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
-                input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
-                input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
-            if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-                input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
-            if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
-                input_dict["params"][GenParams.TEMPERATURE] = 0.0

-        input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]

-        params = {
-            **input_dict,
-        }
         return params

-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+    # Copied from OpenAIMixin
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from the provider's /v1/models.

-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-        )
-        return await self._get_openai_client().completions.create(**params)  # type: ignore
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        if not self._model_cache:
+            await self.list_models()
+        return model in self._model_cache

-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        if params.get("stream", False):
-            return self._stream_openai_chat_completion(params)
-        return await self._get_openai_client().chat.completions.create(**params)  # type: ignore
+    async def list_models(self) -> list[Model] | None:
+        self._model_cache = {}
+        models = []
+        for model_spec in self._get_model_specs():
+            functions = [f["id"] for f in model_spec.get("functions", [])]
+            # Format: {"embedding_dimension": 1536, "context_length": 8192}

-    async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
-        # watsonx.ai sometimes adds usage data to the stream
-        include_usage = False
-        if params.get("stream_options", None):
-            include_usage = params["stream_options"].get("include_usage", False)
-        stream = await self._get_openai_client().chat.completions.create(**params)
+            # Example of an embedding model:
+            # {'model_id': 'ibm/granite-embedding-278m-multilingual',
+            # 'label': 'granite-embedding-278m-multilingual',
+            # 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
+            # ...
+            provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
+            if "embedding" in functions:
+                embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
+                context_length = model_spec["model_limits"]["max_sequence_length"]
+                embedding_metadata = {
+                    "embedding_dimension": embedding_dimension,
+                    "context_length": context_length,
+                }
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata=embedding_metadata,
+                    model_type=ModelType.embedding,
+                )
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
+            if "text_chat" in functions:
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=ModelType.llm,
+                )
+                # In theory, I guess it is possible that a model could be both an embedding model and a text chat model.
+                # In that case, the cache will record the generator Model object, and the list which we return will have
+                # both the generator Model object and the text chat Model object. That's fine because the cache is
+                # only used for check_model_availability() anyway.
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
+        return models

-        seen_finish_reason = False
-        async for chunk in stream:
-            # Final usage chunk with no choices that the user didn't request, so discard
-            if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
-                break
-            yield chunk
-            for choice in chunk.choices:
-                if choice.finish_reason:
-                    seen_finish_reason = True
-                    break
+    # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
+    # So we need to implement our own method to list models by calling the watsonx.ai API.
+    def _get_model_specs(self) -> list[dict[str, Any]]:
+        """
+        Retrieves foundation model specifications from the watsonx.ai API.
+        """
+        url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
+        headers = {
+            # Note that there is no authorization header. Listing models does not require authentication.
+            "Content-Type": "application/json",
+        }
+
+        response = requests.get(url, headers=headers)
+
+        # --- Process the Response ---
+        # Raise an exception for bad status codes (4xx or 5xx)
+        response.raise_for_status()
+
+        # If the request is successful, parse and return the JSON response.
+        # The response should contain a list of model specifications
+        response_data = response.json()
+        if "resources" not in response_data:
+            raise ValueError("Resources not found in response")
+        return response_data["resources"]
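The new `list_models()` path depends entirely on the public `foundation_model_specs` endpoint shown above. As a standalone illustration (not part of the diff), the same discovery call can be exercised outside the adapter; the base URL below is an assumed example, substitute your region's watsonx.ai endpoint.

```python
# Minimal sketch of the model-discovery flow the new adapter relies on.
import requests

WATSONX_BASE_URL = "https://us-south.ml.cloud.ibm.com"  # assumed example endpoint


def list_watsonx_model_ids() -> dict[str, list[str]]:
    """Group model ids by declared function ('text_chat', 'embedding', ...)."""
    url = f"{WATSONX_BASE_URL}/ml/v1/foundation_model_specs?version=2023-10-25"
    # No Authorization header: listing model specs does not require authentication.
    response = requests.get(url, headers={"Content-Type": "application/json"})
    response.raise_for_status()
    grouped: dict[str, list[str]] = {}
    for spec in response.json().get("resources", []):
        for function in spec.get("functions", []):
            grouped.setdefault(function["id"], []).append(spec["model_id"])
    return grouped


if __name__ == "__main__":
    for function_id, model_ids in list_watsonx_model_ids().items():
        print(function_id, len(model_ids))
```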
@@ -167,7 +167,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self.openai_vector_stores = await self._load_openai_vector_stores()

     async def shutdown(self) -> None:
-        pass
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()

     async def register_vector_db(
         self,
@@ -349,6 +349,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP

     async def shutdown(self) -> None:
         self.client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()

     async def register_vector_db(
         self,
@@ -390,6 +390,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
         if self.conn is not None:
             self.conn.close()
             log.info("Connection to PGVector database server closed")
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()

     async def register_vector_db(self, vector_db: VectorDB) -> None:
         # Persist vector DB metadata in the KV store
@@ -191,6 +191,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP

     async def shutdown(self) -> None:
         await self.client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()

     async def register_vector_db(
         self,
@@ -347,6 +347,8 @@ class WeaviateVectorIOAdapter(
     async def shutdown(self) -> None:
         for client in self.client_cache.values():
             client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()

     async def register_vector_db(
         self,
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import base64
+import struct
 from collections.abc import AsyncIterator
 from typing import Any

@@ -16,6 +18,7 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
@@ -26,7 +29,6 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
 from llama_stack.providers.utils.inference.openai_compat import (
-    b64_encode_openai_embeddings_response,
     convert_message_to_openai_dict_new,
     convert_tooldef_to_openai_tool,
     get_sampling_options,
@@ -349,3 +351,28 @@ class LiteLLMOpenAIMixin(
             return False

         return model in litellm.models_by_provider[self.litellm_provider_name]
+
+
+def b64_encode_openai_embeddings_response(
+    response_data: list[dict], encoding_format: str | None = "float"
+) -> list[OpenAIEmbeddingData]:
+    """
+    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
+    """
+    data = []
+    for i, embedding_data in enumerate(response_data):
+        if encoding_format == "base64":
+            byte_array = bytearray()
+            for embedding_value in embedding_data["embedding"]:
+                byte_array.extend(struct.pack("f", float(embedding_value)))
+
+            response_embedding = base64.b64encode(byte_array).decode("utf-8")
+        else:
+            response_embedding = embedding_data["embedding"]
+        data.append(
+            OpenAIEmbeddingData(
+                embedding=response_embedding,
+                index=i,
+            )
+        )
+    return data
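A small round-trip sketch (not from the diff) of what this base64 encoding produces: each float is packed as a 4-byte IEEE-754 value with `struct.pack("f", ...)`, so a consumer can recover the vector with `struct.unpack` using the same (native) byte order.

```python
import base64
import struct

embedding = [0.25, -1.5, 3.0]

# Encode the way b64_encode_openai_embeddings_response does for encoding_format="base64".
byte_array = bytearray()
for value in embedding:
    byte_array.extend(struct.pack("f", float(value)))
encoded = base64.b64encode(byte_array).decode("utf-8")

# Decode: 4 bytes per float, same native byte order as the encoder.
raw = base64.b64decode(encoded)
decoded = [struct.unpack("f", raw[i : i + 4])[0] for i in range(0, len(raw), 4)]
print(encoded, decoded)  # decoded is approximately [0.25, -1.5, 3.0]
```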
@@ -24,6 +24,10 @@ class RemoteInferenceProviderConfig(BaseModel):
         default=None,
         description="List of models that should be registered with the model registry. If None, all models are allowed.",
     )
+    refresh_models: bool = Field(
+        default=False,
+        description="Whether to refresh models periodically from the provider",
+    )


 # TODO: this class is more confusing than useful right now. We need to make it
@@ -3,9 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
 import json
-import struct
 import time
 import uuid
 import warnings
@@ -103,7 +101,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     Message,
     OpenAIChatCompletion,
-    OpenAIEmbeddingData,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     SamplingParams,
@@ -1402,28 +1399,3 @@ def prepare_openai_embeddings_params(
         params["user"] = user

     return params
-
-
-def b64_encode_openai_embeddings_response(
-    response_data: dict, encoding_format: str | None = "float"
-) -> list[OpenAIEmbeddingData]:
-    """
-    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
-    """
-    data = []
-    for i, embedding_data in enumerate(response_data):
-        if encoding_format == "base64":
-            byte_array = bytearray()
-            for embedding_value in embedding_data.embedding:
-                byte_array.extend(struct.pack("f", float(embedding_value)))
-
-            response_embedding = base64.b64encode(byte_array).decode("utf-8")
-        else:
-            response_embedding = embedding_data.embedding
-        data.append(
-            OpenAIEmbeddingData(
-                embedding=response_embedding,
-                index=i,
-            )
-        )
-    return data
@@ -474,17 +474,23 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):

     async def check_model_availability(self, model: str) -> bool:
         """
-        Check if a specific model is available from the provider's /v1/models.
+        Check if a specific model is available from the provider's /v1/models or pre-registered.

         :param model: The model identifier to check.
-        :return: True if the model is available dynamically, False otherwise.
+        :return: True if the model is available dynamically or pre-registered, False otherwise.
         """
+        # First check if the model is pre-registered in the model store
+        if hasattr(self, "model_store") and self.model_store:
+            if await self.model_store.has_model(model):
+                return True
+
+        # Then check the provider's dynamic model cache
         if not self._model_cache:
             await self.list_models()
         return model in self._model_cache

     async def should_refresh_models(self) -> bool:
-        return False
+        return self.config.refresh_models

     #
     # The model_dump implementations are to avoid serializing the extra fields,
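Together with the new `refresh_models` config flag, the availability check now runs in two steps: a pre-registered entry in the model store short-circuits the dynamic `/v1/models` lookup, and the dynamic cache is only populated on demand. A toy sketch of that ordering, using stand-in classes rather than the real `OpenAIMixin` and model store:

```python
import asyncio


class FakeModelStore:
    def __init__(self, registered: set[str]):
        self._registered = registered

    async def has_model(self, model: str) -> bool:
        return model in self._registered


class FakeProvider:
    def __init__(self, store: FakeModelStore, dynamic_models: set[str]):
        self.model_store = store
        self._model_cache: dict[str, object] = {}
        self._dynamic_models = dynamic_models

    async def list_models(self) -> None:
        # Stand-in for the provider's /v1/models (or foundation_model_specs) call.
        self._model_cache = {m: object() for m in self._dynamic_models}

    async def check_model_availability(self, model: str) -> bool:
        if self.model_store and await self.model_store.has_model(model):
            return True  # pre-registered model wins, no provider call needed
        if not self._model_cache:
            await self.list_models()
        return model in self._model_cache


async def main() -> None:
    provider = FakeProvider(FakeModelStore({"my-alias"}), {"ibm/granite-13b-instruct-v2"})
    print(await provider.check_model_availability("my-alias"))                      # True (pre-registered)
    print(await provider.check_model_availability("ibm/granite-13b-instruct-v2"))   # True (dynamic)
    print(await provider.check_model_availability("missing-model"))                 # False


asyncio.run(main())
```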
@@ -293,6 +293,18 @@ class OpenAIVectorStoreMixin(ABC):
             await self._resume_incomplete_batches()
         self._last_file_batch_cleanup_time = 0

+    async def shutdown(self) -> None:
+        """Clean up mixin resources including background tasks."""
+        # Cancel any running file batch tasks gracefully
+        tasks_to_cancel = list(self._file_batch_tasks.items())
+        for _, task in tasks_to_cancel:
+            if not task.done():
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+
     @abstractmethod
     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
         """Delete chunks from a vector store."""
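This is the shutdown the vector IO adapters now reach via `await super().shutdown()`. A self-contained sketch of the cancel-and-await pattern it uses (illustrative names, not the real mixin): cancel the task, then await it and swallow only `CancelledError`, so a clean shutdown does not mask real errors.

```python
import asyncio


async def background_batch_job(batch_id: str) -> None:
    while True:  # pretend to process a file batch indefinitely
        await asyncio.sleep(0.1)


async def shutdown_tasks(tasks: dict[str, asyncio.Task]) -> None:
    for _, task in list(tasks.items()):
        if not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass  # expected during shutdown


async def main() -> None:
    tasks = {"batch-1": asyncio.create_task(background_batch_job("batch-1"))}
    await asyncio.sleep(0.05)
    await shutdown_tasks(tasks)
    print("all background tasks stopped:", all(t.done() for t in tasks.values()))


asyncio.run(main())
```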
@@ -587,7 +599,7 @@ class OpenAIVectorStoreMixin(ABC):
             content = self._chunk_to_vector_store_content(chunk)

             response_data_item = VectorStoreSearchResponse(
-                file_id=chunk.metadata.get("file_id", ""),
+                file_id=chunk.metadata.get("document_id", ""),
                 filename=chunk.metadata.get("filename", ""),
                 score=score,
                 attributes=chunk.metadata,
@@ -746,12 +758,15 @@ class OpenAIVectorStoreMixin(ABC):

             content = content_from_data_and_mime_type(content_response.body, mime_type)

+            chunk_attributes = attributes.copy()
+            chunk_attributes["filename"] = file_response.filename
+
             chunks = make_overlapped_chunks(
                 file_id,
                 content,
                 max_chunk_size_tokens,
                 chunk_overlap_tokens,
-                attributes,
+                chunk_attributes,
             )
             if not chunks:
                 vector_store_file_object.status = "failed"
@@ -20,7 +20,6 @@ from pydantic import BaseModel
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
-    TextContentItem,
 )
 from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import VectorDB
@@ -129,26 +128,6 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en
     return ""


-def concat_interleaved_content(content: list[InterleavedContent]) -> InterleavedContent:
-    """concatenate interleaved content into a single list. ensure that 'str's are converted to TextContentItem when in a list"""
-
-    ret = []
-
-    def _process(c):
-        if isinstance(c, str):
-            ret.append(TextContentItem(text=c))
-        elif isinstance(c, list):
-            for item in c:
-                _process(item)
-        else:
-            ret.append(c)
-
-    for c in content:
-        _process(c)
-
-    return ret
-
-
 async def content_from_doc(doc: RAGDocument) -> str:
     if isinstance(doc.content, URL):
         if doc.content.uri.startswith("data:"):
@@ -221,8 +221,8 @@ fi
 cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
   --network llama-net \
   -p "${PORT}:${PORT}" \
-  "${SERVER_IMAGE}" --port "${PORT}" \
-  --env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}")
+  -e OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" \
+  "${SERVER_IMAGE}" --port "${PORT}")

 log "🦙 Starting Llama Stack..."
 if ! execute_with_log $ENGINE "${cmd[@]}"; then
@@ -191,9 +191,11 @@ if [[ "$STACK_CONFIG" == *"server:"* ]]; then
     echo "Llama Stack Server is already running, skipping start"
   else
     echo "=== Starting Llama Stack Server ==="
-    # Set a reasonable log width for better readability in server.log
     export LLAMA_STACK_LOG_WIDTH=120
-    nohup llama stack run ci-tests --image-type venv > server.log 2>&1 &
+    # remove "server:" from STACK_CONFIG
+    stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://')
+    nohup llama stack run $stack_config > server.log 2>&1 &

     echo "Waiting for Llama Stack Server to start..."
     for i in {1..30}; do
@@ -16,10 +16,19 @@

 set -Eeuo pipefail

-CONTAINER_RUNTIME=${CONTAINER_RUNTIME:-docker}
-SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+if command -v podman &> /dev/null; then
+    CONTAINER_RUNTIME="podman"
+elif command -v docker &> /dev/null; then
+    CONTAINER_RUNTIME="docker"
+else
+    echo "🚨 Neither Podman nor Docker could be found"
+    echo "Install Docker: https://docs.docker.com/get-docker/ or Podman: https://podman.io/getting-started/installation"
+    exit 1
+fi

-echo "🚀 Setting up telemetry stack for Llama Stack using Podman..."
+echo "🚀 Setting up telemetry stack for Llama Stack using $CONTAINER_RUNTIME..."
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

 if ! command -v "$CONTAINER_RUNTIME" &> /dev/null; then
   echo "🚨 $CONTAINER_RUNTIME could not be found"
Some files were not shown because too many files have changed in this diff.