mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
Merge remote-tracking branch 'origin/main' into test_isolation_server
This commit is contained in:
commit
889b2716ef
107 changed files with 817 additions and 1298 deletions
2
.github/workflows/integration-auth-tests.yml
vendored
2
.github/workflows/integration-auth-tests.yml
vendored
|
|
@ -86,7 +86,7 @@ jobs:
|
|||
|
||||
# avoid line breaks in the server log, especially because we grep it below.
|
||||
export COLUMNS=1984
|
||||
nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
|
||||
nohup uv run llama stack run $run_dir/run.yaml > server.log 2>&1 &
|
||||
|
||||
- name: Wait for Llama Stack server to be ready
|
||||
run: |
|
||||
|
|
|
|||
2
.github/workflows/stale_bot.yml
vendored
2
.github/workflows/stale_bot.yml
vendored
|
|
@ -24,7 +24,7 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Stale Action
|
||||
uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
|
||||
uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0
|
||||
with:
|
||||
stale-issue-label: 'stale'
|
||||
stale-issue-message: >
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ jobs:
|
|||
# Use the virtual environment created by the build step (name comes from build config)
|
||||
source ramalama-stack-test/bin/activate
|
||||
uv pip list
|
||||
nohup llama stack run tests/external/ramalama-stack/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
|
||||
nohup llama stack run tests/external/ramalama-stack/run.yaml > server.log 2>&1 &
|
||||
|
||||
- name: Wait for Llama Stack server to be ready
|
||||
run: |
|
||||
|
|
|
|||
2
.github/workflows/test-external.yml
vendored
2
.github/workflows/test-external.yml
vendored
|
|
@ -59,7 +59,7 @@ jobs:
|
|||
# Use the virtual environment created by the build step (name comes from build config)
|
||||
source ci-test/bin/activate
|
||||
uv pip list
|
||||
nohup llama stack run tests/external/run-byoa.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
|
||||
nohup llama stack run tests/external/run-byoa.yaml > server.log 2>&1 &
|
||||
|
||||
- name: Wait for Llama Stack server to be ready
|
||||
run: |
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ You can access the HuggingFace trainer via the `starter` distribution:
|
|||
|
||||
```bash
|
||||
llama stack build --distro starter --image-type venv
|
||||
llama stack run --image-type venv ~/.llama/distributions/starter/starter-run.yaml
|
||||
llama stack run ~/.llama/distributions/starter/starter-run.yaml
|
||||
```
|
||||
|
||||
### Usage Example
|
||||
|
|
|
|||
|
|
@ -219,13 +219,10 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
|
|||
<TabItem value="setup" label="Setup & Configuration">
|
||||
|
||||
1. Start by registering a Tavily API key at [Tavily](https://tavily.com/).
|
||||
2. [Optional] Provide the API key directly to the Llama Stack server
|
||||
2. [Optional] Set the API key in your environment before starting the Llama Stack server
|
||||
```bash
|
||||
export TAVILY_SEARCH_API_KEY="your key"
|
||||
```
|
||||
```bash
|
||||
--env TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY}
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="implementation" label="Implementation">
|
||||
|
|
@ -273,9 +270,9 @@ for log in EventLogger().log(response):
|
|||
<TabItem value="setup" label="Setup & Configuration">
|
||||
|
||||
1. Start by registering for a WolframAlpha API key at [WolframAlpha Developer Portal](https://developer.wolframalpha.com/access).
|
||||
2. Provide the API key either when starting the Llama Stack server:
|
||||
2. Provide the API key either by setting it in your environment before starting the Llama Stack server:
|
||||
```bash
|
||||
--env WOLFRAM_ALPHA_API_KEY=${WOLFRAM_ALPHA_API_KEY}
|
||||
export WOLFRAM_ALPHA_API_KEY="your key"
|
||||
```
|
||||
or from the client side:
|
||||
```python
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ Integration tests are located in [tests/integration](https://github.com/meta-lla
|
|||
Consult [tests/integration/README.md](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md) for more details on how to run the tests.
|
||||
|
||||
Note that each provider's `sample_run_config()` method (in the configuration class for that provider)
|
||||
typically references some environment variables for specifying API keys and the like. You can set these in the environment or pass these via the `--env` flag to the test command.
|
||||
typically references some environment variables for specifying API keys and the like. You can set these in the environment before running the test command.
|
||||
|
||||
|
||||
### 2. Unit Testing
|
||||
|
|
|
|||
|
|
@ -289,10 +289,10 @@ After this step is successful, you should be able to find the built container im
|
|||
docker run -d \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
-e OLLAMA_URL=http://host.docker.internal:11434 \
|
||||
localhost/distribution-ollama:dev \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env OLLAMA_URL=http://host.docker.internal:11434
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
Here are the docker flags and their uses:
|
||||
|
|
@ -305,12 +305,12 @@ Here are the docker flags and their uses:
|
|||
|
||||
* `localhost/distribution-ollama:dev`: The name and tag of the container image to run
|
||||
|
||||
* `-e INFERENCE_MODEL=$INFERENCE_MODEL`: Sets the INFERENCE_MODEL environment variable in the container
|
||||
|
||||
* `-e OLLAMA_URL=http://host.docker.internal:11434`: Sets the OLLAMA_URL environment variable in the container
|
||||
|
||||
* `--port $LLAMA_STACK_PORT`: Port number for the server to listen on
|
||||
|
||||
* `--env INFERENCE_MODEL=$INFERENCE_MODEL`: Sets the model to use for inference
|
||||
|
||||
* `--env OLLAMA_URL=http://host.docker.internal:11434`: Configures the URL for the Ollama service
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
|
@ -320,23 +320,22 @@ Now, let's start the Llama Stack Distribution Server. You will need the YAML con
|
|||
|
||||
```
|
||||
llama stack run -h
|
||||
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--env KEY=VALUE]
|
||||
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME]
|
||||
[--image-type {venv}] [--enable-ui]
|
||||
[config | template]
|
||||
[config | distro]
|
||||
|
||||
Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
|
||||
|
||||
positional arguments:
|
||||
config | template Path to config file to use for the run or name of known template (`llama stack list` for a list). (default: None)
|
||||
config | distro Path to config file to use for the run or name of known distro (`llama stack list` for a list). (default: None)
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--port PORT Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
|
||||
--image-name IMAGE_NAME
|
||||
Name of the image to run. Defaults to the current environment (default: None)
|
||||
--env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: None)
|
||||
[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running. (default: None)
|
||||
--image-type {venv}
|
||||
Image Type used during the build. This should be venv. (default: None)
|
||||
[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running. (default: None)
|
||||
--enable-ui Start the UI server (default: False)
|
||||
```
|
||||
|
||||
|
|
@ -348,9 +347,6 @@ llama stack run tgi
|
|||
|
||||
# Start using config file
|
||||
llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
|
||||
|
||||
# Start using a venv
|
||||
llama stack run --image-type venv ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
|
||||
```
|
||||
|
||||
```
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ A few things to note:
|
|||
- The id is a string you can choose freely.
|
||||
- You can instantiate any number of provider instances of the same type.
|
||||
- The configuration dictionary is provider-specific.
|
||||
- Notice that configuration can reference environment variables (with default values), which are expanded at runtime. When you run a stack server (via docker or via `llama stack run`), you can specify `--env OLLAMA_URL=http://my-server:11434` to override the default value.
|
||||
- Notice that configuration can reference environment variables (with default values), which are expanded at runtime. When you run a stack server, you can set environment variables in your shell before running `llama stack run` to override the default values.
|
||||
|
||||
### Environment Variable Substitution
|
||||
|
||||
|
|
@ -173,13 +173,10 @@ optional_token: ${env.OPTIONAL_TOKEN:+}
|
|||
|
||||
#### Runtime Override
|
||||
|
||||
You can override environment variables at runtime when starting the server:
|
||||
You can override environment variables at runtime by setting them in your shell before starting the server:
|
||||
|
||||
```bash
|
||||
# Override specific environment variables
|
||||
llama stack run --config run.yaml --env API_KEY=sk-123 --env BASE_URL=https://custom-api.com
|
||||
|
||||
# Or set them in your shell
|
||||
# Set environment variables in your shell
|
||||
export API_KEY=sk-123
|
||||
export BASE_URL=https://custom-api.com
|
||||
llama stack run --config run.yaml
|
||||
|
|
|
|||
|
|
@ -69,10 +69,10 @@ docker run \
|
|||
-it \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
-e WATSONX_API_KEY=$WATSONX_API_KEY \
|
||||
-e WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
|
||||
-e WATSONX_BASE_URL=$WATSONX_BASE_URL \
|
||||
llamastack/distribution-watsonx \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env WATSONX_API_KEY=$WATSONX_API_KEY \
|
||||
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
|
||||
--env WATSONX_BASE_URL=$WATSONX_BASE_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
|
|
|||
|
|
@ -129,11 +129,11 @@ docker run -it \
|
|||
# NOTE: mount the llama-stack / llama-model directories if testing local changes else not needed
|
||||
-v $HOME/git/llama-stack:/app/llama-stack-source -v $HOME/git/llama-models:/app/llama-models-source \
|
||||
# localhost/distribution-dell:dev if building / testing locally
|
||||
llamastack/distribution-dell\
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
-e DEH_URL=$DEH_URL \
|
||||
-e CHROMA_URL=$CHROMA_URL \
|
||||
llamastack/distribution-dell \
|
||||
--port $LLAMA_STACK_PORT
|
||||
|
||||
```
|
||||
|
||||
|
|
@ -154,14 +154,14 @@ docker run \
|
|||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v $HOME/.llama:/root/.llama \
|
||||
-v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \
|
||||
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
-e DEH_URL=$DEH_URL \
|
||||
-e SAFETY_MODEL=$SAFETY_MODEL \
|
||||
-e DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
-e CHROMA_URL=$CHROMA_URL \
|
||||
llamastack/distribution-dell \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||
--env DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
|
@ -170,21 +170,21 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a
|
|||
|
||||
```bash
|
||||
llama stack build --distro dell --image-type venv
|
||||
llama stack run dell
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
DEH_URL=$DEH_URL \
|
||||
CHROMA_URL=$CHROMA_URL \
|
||||
llama stack run dell \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, use:
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
DEH_URL=$DEH_URL \
|
||||
SAFETY_MODEL=$SAFETY_MODEL \
|
||||
DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
CHROMA_URL=$CHROMA_URL \
|
||||
llama stack run ./run-with-safety.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||
--env DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
|
|
|||
|
|
@ -84,9 +84,9 @@ docker run \
|
|||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
llamastack/distribution-meta-reference-gpu \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, use:
|
||||
|
|
@ -98,10 +98,10 @@ docker run \
|
|||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
-e SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
|
||||
llamastack/distribution-meta-reference-gpu \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
|
@ -110,16 +110,16 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
|
|||
|
||||
```bash
|
||||
llama stack build --distro meta-reference-gpu --image-type venv
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
llama stack run distributions/meta-reference-gpu/run.yaml \
|
||||
--port 8321 \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
--port 8321
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, use:
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
|
||||
llama stack run distributions/meta-reference-gpu/run-with-safety.yaml \
|
||||
--port 8321 \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
--port 8321
|
||||
```
|
||||
|
|
|
|||
|
|
@ -129,10 +129,10 @@ docker run \
|
|||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
-e NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
llamastack/distribution-nvidia \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
|
@ -142,10 +142,10 @@ If you've set up your local development environment, you can also build the imag
|
|||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
--port 8321
|
||||
```
|
||||
|
||||
## Example Notebooks
|
||||
|
|
|
|||
|
|
@ -86,9 +86,9 @@ docker run -it \
|
|||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e OLLAMA_URL=http://host.docker.internal:11434 \
|
||||
llamastack/distribution-starter \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env OLLAMA_URL=http://host.docker.internal:11434
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
Note to start the container with Podman, you can do the same but replace `docker` at the start of the command with
|
||||
`podman`. If you are using `podman` older than `4.7.0`, please also replace `host.docker.internal` in the `OLLAMA_URL`
|
||||
|
|
@ -106,9 +106,9 @@ docker run -it \
|
|||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
--network=host \
|
||||
-e OLLAMA_URL=http://localhost:11434 \
|
||||
llamastack/distribution-starter \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env OLLAMA_URL=http://localhost:11434
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
:::
|
||||
You will see output like below:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `api_key` | `str \| None` | No | | API key for Anthropic models |
|
||||
|
||||
## Sample Configuration
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Azure API key for Azure |
|
||||
| `api_base` | `<class 'pydantic.networks.HttpUrl'>` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
|
||||
| `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
|
||||
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
|
||||
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `base_url` | `<class 'str'>` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
|
||||
| `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Cerebras API Key |
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ Databricks inference provider for running models on Databricks' unified analytic
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
|
||||
| `api_token` | `<class 'pydantic.types.SecretStr'>` | No | | The Databricks API token |
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `url` | `<class 'str'>` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
|
||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `api_key` | `str \| None` | No | | API key for Gemini models |
|
||||
|
||||
## Sample Configuration
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `api_key` | `str \| None` | No | | The Groq API key |
|
||||
| `url` | `<class 'str'>` | No | https://api.groq.com | The URL for the Groq AI server |
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `api_key` | `str \| None` | No | | The Llama API key |
|
||||
| `openai_compat_api_base` | `<class 'str'>` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
|
||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed of using the hosted service |
|
||||
| `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ Ollama inference provider for running local models through the Ollama runtime.
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `url` | `<class 'str'>` | No | http://localhost:11434 | |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `api_key` | `str \| None` | No | | API key for OpenAI models |
|
||||
| `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ Passthrough inference provider for connecting to any external inference service
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `url` | `<class 'str'>` | No | | The URL for the passthrough endpoint |
|
||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrouth endpoint |
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
|
||||
| `api_token` | `str \| None` | No | | The API token |
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ SambaNova inference provider for running models on SambaNova's dataflow architec
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `url` | `<class 'str'>` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
|
||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The SambaNova cloud API Key |
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `url` | `<class 'str'>` | No | | The URL for the TGI serving endpoint |
|
||||
|
||||
## Sample Configuration
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ Together AI inference provider for open-source models and collaborative AI devel
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `url` | `<class 'str'>` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
|
||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |
|
||||
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ Available Models:
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `project` | `<class 'str'>` | No | | Google Cloud project ID for Vertex AI |
|
||||
| `location` | `<class 'str'>` | No | us-central1 | Google Cloud location for Vertex AI |
|
||||
|
||||
|
|
|
|||
|
|
@ -15,11 +15,11 @@ Remote vLLM inference provider for connecting to vLLM servers.
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
|
||||
| `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
|
||||
| `api_token` | `str \| None` | No | fake | The API token |
|
||||
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
|
|
|
|||
|
|
@ -15,9 +15,10 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
|
||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key |
|
||||
| `project_id` | `str \| None` | No | | The Project ID key |
|
||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key |
|
||||
| `project_id` | `str \| None` | No | | The watsonx.ai project ID |
|
||||
| `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
|
||||
|
||||
## Sample Configuration
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
|
||||
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
|
||||
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
|
||||
|
|
|
|||
|
|
@ -123,12 +123,12 @@
|
|||
" del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
|
||||
"\n",
|
||||
"# this command installs all the dependencies needed for the llama stack server with the together inference provider\n",
|
||||
"!uv run --with llama-stack llama stack build --distro together --image-type venv\n",
|
||||
"!uv run --with llama-stack llama stack build --distro together\n",
|
||||
"\n",
|
||||
"def run_llama_stack_server_background():\n",
|
||||
" log_file = open(\"llama_stack_server.log\", \"w\")\n",
|
||||
" process = subprocess.Popen(\n",
|
||||
" \"uv run --with llama-stack llama stack run together --image-type venv\",\n",
|
||||
" \"uv run --with llama-stack llama stack run together\",\n",
|
||||
" shell=True,\n",
|
||||
" stdout=log_file,\n",
|
||||
" stderr=log_file,\n",
|
||||
|
|
|
|||
|
|
@ -233,12 +233,12 @@
|
|||
" del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
|
||||
"\n",
|
||||
"# this command installs all the dependencies needed for the llama stack server\n",
|
||||
"!uv run --with llama-stack llama stack build --distro meta-reference-gpu --image-type venv\n",
|
||||
"!uv run --with llama-stack llama stack build --distro meta-reference-gpu\n",
|
||||
"\n",
|
||||
"def run_llama_stack_server_background():\n",
|
||||
" log_file = open(\"llama_stack_server.log\", \"w\")\n",
|
||||
" process = subprocess.Popen(\n",
|
||||
" f\"uv run --with llama-stack llama stack run meta-reference-gpu --image-type venv --env INFERENCE_MODEL={model_id}\",\n",
|
||||
" f\"INFERENCE_MODEL={model_id} uv run --with llama-stack llama stack run meta-reference-gpu\",\n",
|
||||
" shell=True,\n",
|
||||
" stdout=log_file,\n",
|
||||
" stderr=log_file,\n",
|
||||
|
|
|
|||
|
|
@ -223,12 +223,12 @@
|
|||
" del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
|
||||
"\n",
|
||||
"# this command installs all the dependencies needed for the llama stack server\n",
|
||||
"!uv run --with llama-stack llama stack build --distro llama_api --image-type venv\n",
|
||||
"!uv run --with llama-stack llama stack build --distro llama_api\n",
|
||||
"\n",
|
||||
"def run_llama_stack_server_background():\n",
|
||||
" log_file = open(\"llama_stack_server.log\", \"w\")\n",
|
||||
" process = subprocess.Popen(\n",
|
||||
" \"uv run --with llama-stack llama stack run llama_api --image-type venv\",\n",
|
||||
" \"uv run --with llama-stack llama stack run llama_api\",\n",
|
||||
" shell=True,\n",
|
||||
" stdout=log_file,\n",
|
||||
" stderr=log_file,\n",
|
||||
|
|
|
|||
|
|
@ -145,12 +145,12 @@
|
|||
" del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
|
||||
"\n",
|
||||
"# this command installs all the dependencies needed for the llama stack server with the ollama inference provider\n",
|
||||
"!uv run --with llama-stack llama stack build --distro starter --image-type venv\n",
|
||||
"!uv run --with llama-stack llama stack build --distro starter\n",
|
||||
"\n",
|
||||
"def run_llama_stack_server_background():\n",
|
||||
" log_file = open(\"llama_stack_server.log\", \"w\")\n",
|
||||
" process = subprocess.Popen(\n",
|
||||
" f\"OLLAMA_URL=http://localhost:11434 uv run --with llama-stack llama stack run starter --image-type venv\n",
|
||||
" f\"OLLAMA_URL=http://localhost:11434 uv run --with llama-stack llama stack run starter\n",
|
||||
" shell=True,\n",
|
||||
" stdout=log_file,\n",
|
||||
" stderr=log_file,\n",
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
|||
...
|
||||
Build Successful!
|
||||
You can find the newly-built template here: ~/.llama/distributions/starter/starter-run.yaml
|
||||
You can run the new Llama Stack Distro via: uv run --with llama-stack llama stack run starter --image-type venv
|
||||
You can run the new Llama Stack Distro via: uv run --with llama-stack llama stack run starter
|
||||
```
|
||||
|
||||
3. **Set the ENV variables by exporting them to the terminal**:
|
||||
|
|
@ -102,12 +102,11 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
|||
3. **Run the Llama Stack**:
|
||||
Run the stack using uv:
|
||||
```bash
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
SAFETY_MODEL=$SAFETY_MODEL \
|
||||
OLLAMA_URL=$OLLAMA_URL \
|
||||
uv run --with llama-stack llama stack run starter \
|
||||
--image-type venv \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||
--env OLLAMA_URL=$OLLAMA_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
|
||||
|
||||
|
|
|
|||
|
|
@ -444,12 +444,24 @@ def _run_stack_build_command_from_build_config(
|
|||
|
||||
cprint("Build Successful!", color="green", file=sys.stderr)
|
||||
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
|
||||
cprint(
|
||||
"You can run the new Llama Stack distro via: "
|
||||
+ colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"),
|
||||
color="green",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if build_config.image_type == LlamaStackImageType.VENV:
|
||||
cprint(
|
||||
"You can run the new Llama Stack distro (after activating "
|
||||
+ colored(image_name, "cyan")
|
||||
+ ") via: "
|
||||
+ colored(f"llama stack run {run_config_file}", "blue"),
|
||||
color="green",
|
||||
file=sys.stderr,
|
||||
)
|
||||
elif build_config.image_type == LlamaStackImageType.CONTAINER:
|
||||
cprint(
|
||||
"You can run the container with: "
|
||||
+ colored(
|
||||
f"docker run -p 8321:8321 -v ~/.llama:/root/.llama localhost/{image_name} --port 8321", "blue"
|
||||
),
|
||||
color="green",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return distro_path
|
||||
else:
|
||||
return _generate_run_config(build_config, build_dir, image_name)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import yaml
|
|||
from llama_stack.cli.stack.utils import ImageType
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.core.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars, validate_env_pair
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
|
@ -55,18 +55,12 @@ class StackRun(Subcommand):
|
|||
"--image-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Name of the image to run. Defaults to the current environment",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
|
||||
metavar="KEY=VALUE",
|
||||
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--image-type",
|
||||
type=str,
|
||||
help="Image Type used during the build. This can be only venv.",
|
||||
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
|
||||
choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value],
|
||||
)
|
||||
self.parser.add_argument(
|
||||
|
|
@ -75,48 +69,22 @@ class StackRun(Subcommand):
|
|||
help="Start the UI server",
|
||||
)
|
||||
|
||||
def _resolve_config_and_distro(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
|
||||
"""Resolve config file path and distribution name from args.config"""
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
if not args.config:
|
||||
return None, None
|
||||
|
||||
config_file = Path(args.config)
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
distro_name = None
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if this is a distribution
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "distributions" / args.config / "run.yaml"
|
||||
if config_file.exists():
|
||||
distro_name = args.config
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to ~/.llama dir
|
||||
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists():
|
||||
self.parser.error(
|
||||
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
||||
)
|
||||
|
||||
if not config_file.is_file():
|
||||
self.parser.error(
|
||||
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
|
||||
)
|
||||
|
||||
return config_file, distro_name
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
||||
from llama_stack.core.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.core.utils.exec import formulate_run_args, run_command
|
||||
|
||||
if args.image_type or args.image_name:
|
||||
self.parser.error(
|
||||
"The --image-type and --image-name flags are no longer supported.\n\n"
|
||||
"Please activate your virtual environment manually before running `llama stack run`.\n\n"
|
||||
"For example:\n"
|
||||
" source /path/to/venv/bin/activate\n"
|
||||
" llama stack run <config>\n"
|
||||
)
|
||||
|
||||
if args.enable_ui:
|
||||
self._start_ui_development_server(args.port)
|
||||
image_type, image_name = args.image_type, args.image_name
|
||||
|
||||
if args.config:
|
||||
try:
|
||||
|
|
@ -128,10 +96,6 @@ class StackRun(Subcommand):
|
|||
else:
|
||||
config_file = None
|
||||
|
||||
# Check if config is required based on image type
|
||||
if image_type == ImageType.VENV.value and not config_file:
|
||||
self.parser.error("Config file is required for venv environment")
|
||||
|
||||
if config_file:
|
||||
logger.info(f"Using run configuration: {config_file}")
|
||||
|
||||
|
|
@ -146,50 +110,13 @@ class StackRun(Subcommand):
|
|||
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
||||
except AttributeError as e:
|
||||
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
|
||||
else:
|
||||
config = None
|
||||
|
||||
# If neither image type nor image name is provided, assume the server should be run directly
|
||||
# using the current environment packages.
|
||||
if not image_type and not image_name:
|
||||
logger.info("No image type or image name provided. Assuming environment packages.")
|
||||
self._uvicorn_run(config_file, args)
|
||||
else:
|
||||
run_args = formulate_run_args(image_type, image_name)
|
||||
|
||||
run_args.extend([str(args.port)])
|
||||
|
||||
if config_file:
|
||||
run_args.extend(["--config", str(config_file)])
|
||||
|
||||
if args.env:
|
||||
for env_var in args.env:
|
||||
if "=" not in env_var:
|
||||
self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
|
||||
return
|
||||
key, value = env_var.split("=", 1) # split on first = only
|
||||
if not key:
|
||||
self.parser.error(f"Environment variable '{env_var}' has empty key")
|
||||
return
|
||||
run_args.extend(["--env", f"{key}={value}"])
|
||||
|
||||
run_command(run_args)
|
||||
self._uvicorn_run(config_file, args)
|
||||
|
||||
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
|
||||
if not config_file:
|
||||
self.parser.error("Config file is required")
|
||||
|
||||
# Set environment variables if provided
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
self.parser.error(f"Invalid environment variable format: {env_pair}")
|
||||
|
||||
config_file = resolve_config_or_distro(str(config_file), Mode.RUN)
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import (
|
|||
sqlstore_impl,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="openai::conversations")
|
||||
logger = get_logger(name=__name__, category="openai_conversations")
|
||||
|
||||
|
||||
class ConversationServiceConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -611,7 +611,7 @@ class InferenceRouter(Inference):
|
|||
completion_text += "".join(choice_data["content_parts"])
|
||||
|
||||
# Add metrics to the chunk
|
||||
if self.telemetry and chunk.usage:
|
||||
if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
try:
|
||||
models = await provider.list_models()
|
||||
except Exception as e:
|
||||
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
logger.debug(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
continue
|
||||
|
||||
self.listed_providers.add(provider_id)
|
||||
|
|
@ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
|
||||
return self.impls_by_provider_id[model.provider_id]
|
||||
|
||||
async def has_model(self, model_id: str) -> bool:
|
||||
"""
|
||||
Check if a model exists in the routing table.
|
||||
|
||||
:param model_id: The model identifier to check
|
||||
:return: True if the model exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
await lookup_model(self, model_id)
|
||||
return True
|
||||
except ModelNotFoundError:
|
||||
return False
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
|
|||
|
|
@ -274,22 +274,6 @@ def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|||
return config_dict
|
||||
|
||||
|
||||
def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
||||
"""Validate and split an environment variable key-value pair."""
|
||||
try:
|
||||
key, value = env_pair.split("=", 1)
|
||||
key = key.strip()
|
||||
if not key:
|
||||
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
|
||||
if not all(c.isalnum() or c == "_" for c in key):
|
||||
raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
|
||||
return key, value
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value"
|
||||
) from e
|
||||
|
||||
|
||||
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ error_handler() {
|
|||
trap 'error_handler ${LINENO}' ERR
|
||||
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
|
||||
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
@ -43,7 +43,6 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
|||
|
||||
# Initialize variables
|
||||
yaml_config=""
|
||||
env_vars=""
|
||||
other_args=""
|
||||
|
||||
# Process remaining arguments
|
||||
|
|
@ -58,15 +57,6 @@ while [[ $# -gt 0 ]]; do
|
|||
exit 1
|
||||
fi
|
||||
;;
|
||||
--env)
|
||||
if [[ -n "$2" ]]; then
|
||||
env_vars="$env_vars --env $2"
|
||||
shift 2
|
||||
else
|
||||
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
other_args="$other_args $1"
|
||||
shift
|
||||
|
|
@ -119,7 +109,6 @@ if [[ "$env_type" == "venv" ]]; then
|
|||
llama stack run \
|
||||
$yaml_config_arg \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
$other_args
|
||||
elif [[ "$env_type" == "container" ]]; then
|
||||
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
|
||||
|
|
|
|||
|
|
@ -98,7 +98,10 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
existing_obj = await self.get(obj.type, obj.identifier)
|
||||
# dont register if the object's providerid already exists
|
||||
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
||||
return False
|
||||
raise ValueError(
|
||||
f"Provider '{obj.provider_id}' is already registered."
|
||||
f"Unregister the existing provider first before registering it again."
|
||||
)
|
||||
|
||||
await self.kvstore.set(
|
||||
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
||||
|
|
|
|||
|
|
@ -117,11 +117,11 @@ docker run -it \
|
|||
# NOTE: mount the llama-stack directory if testing local changes else not needed
|
||||
-v $HOME/git/llama-stack:/app/llama-stack-source \
|
||||
# localhost/distribution-dell:dev if building / testing locally
|
||||
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
-e DEH_URL=$DEH_URL \
|
||||
-e CHROMA_URL=$CHROMA_URL \
|
||||
llamastack/distribution-{{ name }}\
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
|
||||
```
|
||||
|
||||
|
|
@ -142,14 +142,14 @@ docker run \
|
|||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v $HOME/.llama:/root/.llama \
|
||||
-v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \
|
||||
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
-e DEH_URL=$DEH_URL \
|
||||
-e SAFETY_MODEL=$SAFETY_MODEL \
|
||||
-e DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
-e CHROMA_URL=$CHROMA_URL \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||
--env DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via Conda
|
||||
|
|
@ -158,21 +158,21 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a
|
|||
|
||||
```bash
|
||||
llama stack build --distro {{ name }} --image-type conda
|
||||
llama stack run {{ name }}
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
DEH_URL=$DEH_URL \
|
||||
CHROMA_URL=$CHROMA_URL \
|
||||
llama stack run {{ name }} \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, use:
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
DEH_URL=$DEH_URL \
|
||||
SAFETY_MODEL=$SAFETY_MODEL \
|
||||
DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
CHROMA_URL=$CHROMA_URL \
|
||||
llama stack run ./run-with-safety.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||
--env DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
|
|
|||
|
|
@ -72,9 +72,9 @@ docker run \
|
|||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, use:
|
||||
|
|
@ -86,10 +86,10 @@ docker run \
|
|||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
-e SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
|
@ -98,16 +98,16 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
|
|||
|
||||
```bash
|
||||
llama stack build --distro {{ name }} --image-type venv
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
llama stack run distributions/{{ name }}/run.yaml \
|
||||
--port 8321 \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
--port 8321
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, use:
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
|
||||
llama stack run distributions/{{ name }}/run-with-safety.yaml \
|
||||
--port 8321 \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
--port 8321
|
||||
```
|
||||
|
|
|
|||
|
|
@ -118,10 +118,10 @@ docker run \
|
|||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
-e NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
|
@ -131,10 +131,10 @@ If you've set up your local development environment, you can also build the imag
|
|||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
--port 8321
|
||||
```
|
||||
|
||||
## Example Notebooks
|
||||
|
|
|
|||
|
|
@ -3,3 +3,5 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .watsonx import get_distribution_template # noqa: F401
|
||||
|
|
|
|||
|
|
@ -3,44 +3,33 @@ distribution_spec:
|
|||
description: Use watsonx for running LLM inference
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: watsonx
|
||||
provider_type: remote::watsonx
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
- provider_type: remote::watsonx
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
- provider_type: inline::faiss
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
- provider_type: remote::huggingface
|
||||
- provider_type: inline::localfs
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
- provider_type: inline::basic
|
||||
- provider_type: inline::llm-as-judge
|
||||
- provider_type: inline::braintrust
|
||||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
- aiosqlite
|
||||
- aiosqlite
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ apis:
|
|||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
- files
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: watsonx
|
||||
|
|
@ -19,8 +19,6 @@ providers:
|
|||
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
|
||||
api_key: ${env.WATSONX_API_KEY:=}
|
||||
project_id: ${env.WATSONX_PROJECT_ID:=}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
|
|
@ -48,7 +46,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
@ -109,102 +107,7 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-3-70b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-2-13b-chat
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-2-13b-chat
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-2-13b
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-2-13b-chat
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-1-70b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-70B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-1-8b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-1b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-3b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-3B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-guard-3-11b-vision
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-Guard-3-11B-Vision
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: sentence-transformers
|
||||
model_type: embedding
|
||||
models: []
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
|
|
|
|||
|
|
@ -4,17 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
||||
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
||||
|
|
@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
|||
config=WatsonXConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
embedding_provider = Provider(
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
available_models = {
|
||||
"watsonx": MODEL_ENTRIES,
|
||||
}
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
|
|
@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
|||
),
|
||||
]
|
||||
|
||||
embedding_model = ModelInput(
|
||||
model_id="all-MiniLM-L6-v2",
|
||||
provider_id="sentence-transformers",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
)
|
||||
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
provider_type="inline::localfs",
|
||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
default_models, _ = get_model_registry(available_models)
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="remote_hosted",
|
||||
description="Use watsonx for running LLM inference",
|
||||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider, embedding_provider],
|
||||
"inference": [inference_provider],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_models=default_models + [embedding_model],
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -31,12 +31,17 @@ CATEGORIES = [
|
|||
"client",
|
||||
"telemetry",
|
||||
"openai_responses",
|
||||
"openai_conversations",
|
||||
"testing",
|
||||
"providers",
|
||||
"models",
|
||||
"files",
|
||||
"vector_io",
|
||||
"tool_runtime",
|
||||
"cli",
|
||||
"post_training",
|
||||
"scoring",
|
||||
"tests",
|
||||
]
|
||||
UNCATEGORIZED = "uncategorized"
|
||||
|
||||
|
|
@ -264,11 +269,12 @@ def get_logger(
|
|||
if root_category in _category_levels:
|
||||
log_level = _category_levels[root_category]
|
||||
else:
|
||||
log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
|
||||
if category != UNCATEGORIZED:
|
||||
logging.warning(
|
||||
f"Unknown logging category: {category}. Falling back to default 'root' level: {log_level}"
|
||||
raise ValueError(
|
||||
f"Unknown logging category: {category}. To resolve, choose a valid category from the CATEGORIES list "
|
||||
f"or add it to the CATEGORIES list. Available categories: {CATEGORIES}"
|
||||
)
|
||||
log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
|
||||
logger.setLevel(log_level)
|
||||
return logging.LoggerAdapter(logger, {"category": category})
|
||||
|
||||
|
|
|
|||
|
|
@ -11,19 +11,13 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import json
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
RawMessage,
|
||||
RawTextItem,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
|
||||
|
|
@ -175,25 +169,6 @@ def llama3_1_builtin_code_interpreter_dialog(tool_prompt_format=ToolPromptFormat
|
|||
return messages
|
||||
|
||||
|
||||
def llama3_1_builtin_tool_call_with_image_dialog(
|
||||
tool_prompt_format=ToolPromptFormat.json,
|
||||
):
|
||||
this_dir = Path(__file__).parent
|
||||
with open(this_dir / "llama3/dog.jpg", "rb") as f:
|
||||
img = f.read()
|
||||
|
||||
interface = LLama31Interface(tool_prompt_format)
|
||||
|
||||
messages = interface.system_messages(**system_message_builtin_tools_only())
|
||||
messages += interface.user_message(content=[RawMediaItem(data=img), RawTextItem(text="What is this dog breed?")])
|
||||
messages += interface.assistant_response_messages(
|
||||
"Based on the description of the dog in the image, it appears to be a small breed dog, possibly a terrier mix",
|
||||
StopReason.end_of_turn,
|
||||
)
|
||||
messages += interface.user_message("Search the web for some food recommendations for the indentified breed")
|
||||
return messages
|
||||
|
||||
|
||||
def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
|
||||
interface = LLama31Interface(tool_prompt_format)
|
||||
|
||||
|
|
@ -202,35 +177,6 @@ def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
|
|||
return messages
|
||||
|
||||
|
||||
def llama3_1_e2e_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
|
||||
tool_response = json.dumps(["great song1", "awesome song2", "cool song3"])
|
||||
interface = LLama31Interface(tool_prompt_format)
|
||||
|
||||
messages = interface.system_messages(**system_message_custom_tools_only())
|
||||
messages += interface.user_message(content="Use tools to get latest trending songs")
|
||||
messages.append(
|
||||
RawMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="call_id",
|
||||
tool_name="trending_songs",
|
||||
arguments={"n": "10", "genre": "latest"},
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
messages.append(
|
||||
RawMessage(
|
||||
role="assistant",
|
||||
content=tool_response,
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
def llama3_2_user_assistant_conversation():
|
||||
return UseCase(
|
||||
title="User and assistant conversation",
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
|
|||
deps[Api.tool_runtime],
|
||||
deps[Api.tool_groups],
|
||||
policy,
|
||||
Api.telemetry in deps,
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -7,8 +7,6 @@
|
|||
import copy
|
||||
import json
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
|
|
@ -84,11 +82,6 @@ from llama_stack.providers.utils.telemetry import tracing
|
|||
from .persistence import AgentPersistence
|
||||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||
|
||||
|
||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||
MEMORY_QUERY_TOOL = "knowledge_search"
|
||||
WEB_SEARCH_TOOL = "web_search"
|
||||
|
|
@ -110,6 +103,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
persistence_store: KVStore,
|
||||
created_at: str,
|
||||
policy: list[AccessRule],
|
||||
telemetry_enabled: bool = False,
|
||||
):
|
||||
self.agent_id = agent_id
|
||||
self.agent_config = agent_config
|
||||
|
|
@ -120,6 +114,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.created_at = created_at
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
|
||||
ShieldRunnerMixin.__init__(
|
||||
self,
|
||||
|
|
@ -188,28 +183,30 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
turn_id = str(uuid.uuid4())
|
||||
span = tracing.get_current_span()
|
||||
if span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
if self.telemetry_enabled:
|
||||
span = tracing.get_current_span()
|
||||
if span is not None:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
async for chunk in self._run_turn(request, turn_id):
|
||||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
span = tracing.get_current_span()
|
||||
if span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
if self.telemetry_enabled:
|
||||
span = tracing.get_current_span()
|
||||
if span is not None:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
|
||||
await self._initialize_tools()
|
||||
async for chunk in self._run_turn(request):
|
||||
|
|
@ -395,9 +392,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
async with tracing.span("run_shields") as span:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if self.telemetry_enabled and span is not None:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
|
|
@ -430,7 +430,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute("output", e.violation.model_dump_json())
|
||||
if self.telemetry_enabled and span is not None:
|
||||
span.set_attribute("output", e.violation.model_dump_json())
|
||||
|
||||
yield CompletionMessage(
|
||||
content=str(e),
|
||||
|
|
@ -453,7 +454,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute("output", "no violations")
|
||||
if self.telemetry_enabled and span is not None:
|
||||
span.set_attribute("output", "no violations")
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
|
|
@ -518,8 +520,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
stop_reason: StopReason | None = None
|
||||
|
||||
async with tracing.span("inference") as span:
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
if self.telemetry_enabled and span is not None:
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
|
||||
def _serialize_nested(value):
|
||||
"""Recursively serialize nested Pydantic models to dicts."""
|
||||
|
|
@ -637,18 +640,19 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
raise ValueError(f"Unexpected delta type {type(delta)}")
|
||||
|
||||
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
|
||||
span.set_attribute(
|
||||
"input",
|
||||
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
||||
)
|
||||
output_attr = json.dumps(
|
||||
{
|
||||
"content": content,
|
||||
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
||||
}
|
||||
)
|
||||
span.set_attribute("output", output_attr)
|
||||
if self.telemetry_enabled and span is not None:
|
||||
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
|
||||
span.set_attribute(
|
||||
"input",
|
||||
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
||||
)
|
||||
output_attr = json.dumps(
|
||||
{
|
||||
"content": content,
|
||||
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
||||
}
|
||||
)
|
||||
span.set_attribute("output", output_attr)
|
||||
|
||||
n_iter += 1
|
||||
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
|
||||
|
|
@ -756,7 +760,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
}
|
||||
if self.telemetry_enabled
|
||||
else {},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now(UTC).isoformat()
|
||||
tool_result = await self.execute_tool_call_maybe(
|
||||
|
|
@ -771,7 +777,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
call_id=tool_call.call_id,
|
||||
content=tool_result.content,
|
||||
)
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
if self.telemetry_enabled and span is not None:
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
||||
# Store tool execution step
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tool_runtime_api: ToolRuntime,
|
||||
tool_groups_api: ToolGroups,
|
||||
policy: list[AccessRule],
|
||||
telemetry_enabled: bool = False,
|
||||
):
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
|
|
@ -71,6 +72,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.safety_api = safety_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
||||
|
|
@ -135,6 +137,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
),
|
||||
created_at=agent_info.created_at,
|
||||
policy=self.policy,
|
||||
telemetry_enabled=self.telemetry_enabled,
|
||||
)
|
||||
|
||||
async def create_agent_session(
|
||||
|
|
|
|||
|
|
@ -269,7 +269,7 @@ class OpenAIResponsesImpl:
|
|||
response_tools=tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
inputs=input,
|
||||
inputs=all_input,
|
||||
)
|
||||
|
||||
# Create orchestrator and delegate streaming logic
|
||||
|
|
|
|||
|
|
@ -97,6 +97,8 @@ class StreamingResponseOrchestrator:
|
|||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
# Track final messages after all tool executions
|
||||
self.final_messages: list[OpenAIMessageParam] = []
|
||||
# mapping for annotations
|
||||
self.citation_files: dict[str, str] = {}
|
||||
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Initialize output messages
|
||||
|
|
@ -126,6 +128,7 @@ class StreamingResponseOrchestrator:
|
|||
# Text is the default response format for chat completion so don't need to pass it
|
||||
# (some providers don't support non-empty response_format when tools are present)
|
||||
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
|
||||
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
|
||||
completion_result = await self.inference_api.openai_chat_completion(
|
||||
model=self.ctx.model,
|
||||
messages=messages,
|
||||
|
|
@ -160,7 +163,7 @@ class StreamingResponseOrchestrator:
|
|||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||
output_messages.append(await convert_chat_choice_to_response_message(choice))
|
||||
output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files))
|
||||
|
||||
# Execute tool calls and coordinate results
|
||||
async for stream_event in self._coordinate_tool_execution(
|
||||
|
|
@ -172,6 +175,8 @@ class StreamingResponseOrchestrator:
|
|||
):
|
||||
yield stream_event
|
||||
|
||||
messages = next_turn_messages
|
||||
|
||||
if not function_tool_calls and not non_function_tool_calls:
|
||||
break
|
||||
|
||||
|
|
@ -184,9 +189,7 @@ class StreamingResponseOrchestrator:
|
|||
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
|
||||
break
|
||||
|
||||
messages = next_turn_messages
|
||||
|
||||
self.final_messages = messages.copy() + [current_response.choices[0].message]
|
||||
self.final_messages = messages.copy()
|
||||
|
||||
# Create final response
|
||||
final_response = OpenAIResponseObject(
|
||||
|
|
@ -211,6 +214,8 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
for choice in current_response.choices:
|
||||
next_turn_messages.append(choice.message)
|
||||
logger.debug(f"Choice message content: {choice.message.content}")
|
||||
logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")
|
||||
|
||||
if choice.message.tool_calls and self.ctx.response_tools:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
|
|
@ -227,9 +232,11 @@ class StreamingResponseOrchestrator:
|
|||
non_function_tool_calls.append(tool_call)
|
||||
else:
|
||||
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
|
||||
next_turn_messages.pop()
|
||||
else:
|
||||
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
|
||||
approvals.append(tool_call)
|
||||
next_turn_messages.pop()
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
|
||||
|
|
@ -470,6 +477,8 @@ class StreamingResponseOrchestrator:
|
|||
tool_call_log = result.final_output_message
|
||||
tool_response_message = result.final_input_message
|
||||
self.sequence_number = result.sequence_number
|
||||
if result.citation_files:
|
||||
self.citation_files.update(result.citation_files)
|
||||
|
||||
if tool_call_log:
|
||||
output_messages.append(tool_call_log)
|
||||
|
|
|
|||
|
|
@ -94,7 +94,10 @@ class ToolExecutor:
|
|||
|
||||
# Yield the final result
|
||||
yield ToolExecutionResult(
|
||||
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
|
||||
sequence_number=sequence_number,
|
||||
final_output_message=output_message,
|
||||
final_input_message=input_message,
|
||||
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
|
||||
)
|
||||
|
||||
async def _execute_knowledge_search_via_vector_store(
|
||||
|
|
@ -129,8 +132,6 @@ class ToolExecutor:
|
|||
for results in all_results:
|
||||
search_results.extend(results)
|
||||
|
||||
# Convert search results to tool result format matching memory.py
|
||||
# Format the results as interleaved content similar to memory.py
|
||||
content_items = []
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
|
|
@ -138,27 +139,58 @@ class ToolExecutor:
|
|||
)
|
||||
)
|
||||
|
||||
unique_files = set()
|
||||
for i, result_item in enumerate(search_results):
|
||||
chunk_text = result_item.content[0].text if result_item.content else ""
|
||||
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
|
||||
# Get file_id from attributes if result_item.file_id is empty
|
||||
file_id = result_item.file_id or (
|
||||
result_item.attributes.get("document_id") if result_item.attributes else None
|
||||
)
|
||||
metadata_text = f"document_id: {file_id}, score: {result_item.score}"
|
||||
if result_item.attributes:
|
||||
metadata_text += f", attributes: {result_item.attributes}"
|
||||
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
|
||||
|
||||
text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
|
||||
content_items.append(TextContentItem(text=text_content))
|
||||
unique_files.add(file_id)
|
||||
|
||||
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
|
||||
citation_instruction = ""
|
||||
if unique_files:
|
||||
citation_instruction = (
|
||||
" Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
|
||||
"Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
|
||||
)
|
||||
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
|
||||
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
|
||||
)
|
||||
)
|
||||
|
||||
# handling missing attributes for old versions
|
||||
citation_files = {}
|
||||
for result in search_results:
|
||||
file_id = result.file_id
|
||||
if not file_id and result.attributes:
|
||||
file_id = result.attributes.get("document_id")
|
||||
|
||||
filename = result.filename
|
||||
if not filename and result.attributes:
|
||||
filename = result.attributes.get("filename")
|
||||
if not filename:
|
||||
filename = "unknown"
|
||||
|
||||
citation_files[file_id] = filename
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
metadata={
|
||||
"document_ids": [r.file_id for r in search_results],
|
||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||
"scores": [r.score for r in search_results],
|
||||
"citation_files": citation_files,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ class ToolExecutionResult(BaseModel):
|
|||
sequence_number: int
|
||||
final_output_message: OpenAIResponseOutput | None = None
|
||||
final_input_message: OpenAIMessageParam | None = None
|
||||
citation_files: dict[str, str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -4,9 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseAnnotationFileCitation,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContent,
|
||||
|
|
@ -45,7 +47,9 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
|
||||
|
||||
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
|
||||
async def convert_chat_choice_to_response_message(
|
||||
choice: OpenAIChoice, citation_files: dict[str, str] | None = None
|
||||
) -> OpenAIResponseMessage:
|
||||
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
|
||||
output_content = ""
|
||||
if isinstance(choice.message.content, str):
|
||||
|
|
@ -57,9 +61,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
|
|||
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
|
||||
)
|
||||
|
||||
annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
|
||||
|
||||
return OpenAIResponseMessage(
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
|
|
@ -200,6 +206,53 @@ async def get_message_type_by_role(role: str):
|
|||
return role_to_type.get(role)
|
||||
|
||||
|
||||
def _extract_citations_from_text(
|
||||
text: str, citation_files: dict[str, str]
|
||||
) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
|
||||
"""Extract citation markers from text and create annotations
|
||||
|
||||
Args:
|
||||
text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A]
|
||||
citation_files: Dictionary mapping file_id to filename
|
||||
|
||||
Returns:
|
||||
Tuple of (annotations_list, clean_text_without_markers)
|
||||
"""
|
||||
file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
|
||||
|
||||
annotations = []
|
||||
parts = []
|
||||
total_len = 0
|
||||
last_end = 0
|
||||
|
||||
for m in file_id_regex.finditer(text):
|
||||
# segment before the marker
|
||||
prefix = text[last_end : m.start()]
|
||||
|
||||
# drop one space if it exists (since marker is at sentence end)
|
||||
if prefix.endswith(" "):
|
||||
prefix = prefix[:-1]
|
||||
|
||||
parts.append(prefix)
|
||||
total_len += len(prefix)
|
||||
|
||||
fid = m.group(1)
|
||||
if fid in citation_files:
|
||||
annotations.append(
|
||||
OpenAIResponseAnnotationFileCitation(
|
||||
file_id=fid,
|
||||
filename=citation_files[fid],
|
||||
index=total_len, # index points to punctuation
|
||||
)
|
||||
)
|
||||
|
||||
last_end = m.end()
|
||||
|
||||
parts.append(text[last_end:])
|
||||
cleaned_text = "".join(parts)
|
||||
return annotations, cleaned_text
|
||||
|
||||
|
||||
def is_function_tool_call(
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
tools: list[OpenAIResponseInputTool],
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@ import asyncio
|
|||
import base64
|
||||
import io
|
||||
import mimetypes
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
|
@ -52,10 +50,6 @@ from .context_retriever import generate_rag_query
|
|||
log = get_logger(name=__name__, category="tool_runtime")
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||
|
||||
|
||||
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
|
||||
"""Get raw binary data and mime type from a RAGDocument for file upload."""
|
||||
if isinstance(doc.content, URL):
|
||||
|
|
@ -331,5 +325,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
|
||||
return ToolInvocationResult(
|
||||
content=result.content or [],
|
||||
metadata=result.metadata,
|
||||
metadata={
|
||||
**(result.metadata or {}),
|
||||
"citation_files": getattr(result, "citation_files", None),
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -225,8 +225,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
await self.initialize_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
# Cleanup if needed
|
||||
pass
|
||||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -434,8 +434,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
await self.initialize_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
# nothing to do since we don't maintain a persistent connection
|
||||
pass
|
||||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return [v.vector_db for v in self.cache.values()]
|
||||
|
|
|
|||
|
|
@ -36,6 +36,9 @@ def available_providers() -> list[ProviderSpec]:
|
|||
Api.tool_runtime,
|
||||
Api.tool_groups,
|
||||
],
|
||||
optional_api_dependencies=[
|
||||
Api.telemetry,
|
||||
],
|
||||
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -268,7 +268,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter_type="watsonx",
|
||||
provider_type="remote::watsonx",
|
||||
pip_packages=["ibm_watsonx_ai"],
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from llama_stack.providers.datatypes import (
|
|||
ProviderSpec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
|
|
@ -18,9 +19,8 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
provider_type="inline::rag-runtime",
|
||||
pip_packages=[
|
||||
"chardet",
|
||||
"pypdf",
|
||||
pip_packages=DEFAULT_VECTOR_IO_DEPS
|
||||
+ [
|
||||
"tqdm",
|
||||
"numpy",
|
||||
"scikit-learn",
|
||||
|
|
|
|||
|
|
@ -12,13 +12,16 @@ from llama_stack.providers.datatypes import (
|
|||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
# Common dependencies for all vector IO providers that support document processing
|
||||
DEFAULT_VECTOR_IO_DEPS = ["chardet", "pypdf"]
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=["faiss-cpu"],
|
||||
pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.faiss",
|
||||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
deprecation_warning="Please use the `inline::faiss` provider instead.",
|
||||
|
|
@ -29,7 +32,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::faiss",
|
||||
pip_packages=["faiss-cpu"],
|
||||
pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.faiss",
|
||||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
@ -82,7 +85,7 @@ more details about Faiss in general.
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::sqlite-vec",
|
||||
pip_packages=["sqlite-vec"],
|
||||
pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.sqlite_vec",
|
||||
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
@ -289,7 +292,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::sqlite_vec",
|
||||
pip_packages=["sqlite-vec"],
|
||||
pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.sqlite_vec",
|
||||
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
|
||||
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
|
||||
|
|
@ -303,7 +306,7 @@ Please refer to the sqlite-vec provider documentation.
|
|||
api=Api.vector_io,
|
||||
adapter_type="chromadb",
|
||||
provider_type="remote::chromadb",
|
||||
pip_packages=["chromadb-client"],
|
||||
pip_packages=["chromadb-client"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.remote.vector_io.chroma",
|
||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
@ -345,7 +348,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::chromadb",
|
||||
pip_packages=["chromadb"],
|
||||
pip_packages=["chromadb"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.chroma",
|
||||
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
@ -389,7 +392,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
api=Api.vector_io,
|
||||
adapter_type="pgvector",
|
||||
provider_type="remote::pgvector",
|
||||
pip_packages=["psycopg2-binary"],
|
||||
pip_packages=["psycopg2-binary"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
@ -500,7 +503,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
|
|||
api=Api.vector_io,
|
||||
adapter_type="weaviate",
|
||||
provider_type="remote::weaviate",
|
||||
pip_packages=["weaviate-client>=4.16.5"],
|
||||
pip_packages=["weaviate-client>=4.16.5"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
|
|
@ -541,7 +544,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::qdrant",
|
||||
pip_packages=["qdrant-client"],
|
||||
pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
@ -594,7 +597,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
|
|||
api=Api.vector_io,
|
||||
adapter_type="qdrant",
|
||||
provider_type="remote::qdrant",
|
||||
pip_packages=["qdrant-client"],
|
||||
pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
@ -607,7 +610,7 @@ Please refer to the inline provider documentation.
|
|||
api=Api.vector_io,
|
||||
adapter_type="milvus",
|
||||
provider_type="remote::milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
pip_packages=["pymilvus>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.remote.vector_io.milvus",
|
||||
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
@ -813,7 +816,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::milvus",
|
||||
pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
|
||||
pip_packages=["pymilvus[milvus-lite]>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.milvus",
|
||||
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
|
|||
|
|
@ -41,9 +41,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
|
|||
).serving_endpoints.list() # TODO: this is not async
|
||||
]
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
|||
|
|
@ -1,217 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion import (
|
||||
Choice as OpenAIChoice,
|
||||
)
|
||||
from openai.types.completion import Completion as OpenAICompletion
|
||||
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
GreedySamplingStrategy,
|
||||
JsonSchemaResponseFormat,
|
||||
TokenLogProbs,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
_convert_openai_finish_reason,
|
||||
convert_message_to_openai_dict_new,
|
||||
convert_tooldef_to_openai_tool,
|
||||
)
|
||||
|
||||
|
||||
async def convert_chat_completion_request(
|
||||
request: ChatCompletionRequest,
|
||||
n: int = 1,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
|
||||
"""
|
||||
# model -> model
|
||||
# messages -> messages
|
||||
# sampling_params TODO(mattf): review strategy
|
||||
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
|
||||
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
|
||||
# strategy=top_k -> nvext.top_k = top_k
|
||||
# temperature -> temperature
|
||||
# top_p -> top_p
|
||||
# top_k -> nvext.top_k
|
||||
# max_tokens -> max_tokens
|
||||
# repetition_penalty -> nvext.repetition_penalty
|
||||
# response_format -> GrammarResponseFormat TODO(mf)
|
||||
# response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
|
||||
# tools -> tools
|
||||
# tool_choice ("auto", "required") -> tool_choice
|
||||
# tool_prompt_format -> TBD
|
||||
# stream -> stream
|
||||
# logprobs -> logprobs
|
||||
|
||||
if request.response_format and not isinstance(request.response_format, JsonSchemaResponseFormat):
|
||||
raise ValueError(
|
||||
f"Unsupported response format: {request.response_format}. Only JsonSchemaResponseFormat is supported."
|
||||
)
|
||||
|
||||
nvext = {}
|
||||
payload: dict[str, Any] = dict(
|
||||
model=request.model,
|
||||
messages=[await convert_message_to_openai_dict_new(message) for message in request.messages],
|
||||
stream=request.stream,
|
||||
n=n,
|
||||
extra_body=dict(nvext=nvext),
|
||||
extra_headers={
|
||||
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
|
||||
},
|
||||
)
|
||||
|
||||
if request.response_format:
|
||||
# server bug - setting guided_json changes the behavior of response_format resulting in an error
|
||||
# payload.update(response_format="json_object")
|
||||
nvext.update(guided_json=request.response_format.json_schema)
|
||||
|
||||
if request.tools:
|
||||
payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools])
|
||||
if request.tool_config.tool_choice:
|
||||
payload.update(
|
||||
tool_choice=request.tool_config.tool_choice.value
|
||||
) # we cannot include tool_choice w/o tools, server will complain
|
||||
|
||||
if request.logprobs:
|
||||
payload.update(logprobs=True)
|
||||
payload.update(top_logprobs=request.logprobs.top_k)
|
||||
|
||||
if request.sampling_params:
|
||||
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
|
||||
|
||||
if request.sampling_params.max_tokens:
|
||||
payload.update(max_tokens=request.sampling_params.max_tokens)
|
||||
|
||||
strategy = request.sampling_params.strategy
|
||||
if isinstance(strategy, TopPSamplingStrategy):
|
||||
nvext.update(top_k=-1)
|
||||
payload.update(top_p=strategy.top_p)
|
||||
payload.update(temperature=strategy.temperature)
|
||||
elif isinstance(strategy, TopKSamplingStrategy):
|
||||
if strategy.top_k != -1 and strategy.top_k < 1:
|
||||
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
|
||||
nvext.update(top_k=strategy.top_k)
|
||||
elif isinstance(strategy, GreedySamplingStrategy):
|
||||
nvext.update(top_k=-1)
|
||||
else:
|
||||
raise ValueError(f"Unsupported sampling strategy: {strategy}")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def convert_completion_request(
|
||||
request: CompletionRequest,
|
||||
n: int = 1,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
|
||||
"""
|
||||
# model -> model
|
||||
# prompt -> prompt
|
||||
# sampling_params TODO(mattf): review strategy
|
||||
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
|
||||
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
|
||||
# strategy=top_k -> nvext.top_k = top_k
|
||||
# temperature -> temperature
|
||||
# top_p -> top_p
|
||||
# top_k -> nvext.top_k
|
||||
# max_tokens -> max_tokens
|
||||
# repetition_penalty -> nvext.repetition_penalty
|
||||
# response_format -> nvext.guided_json
|
||||
# stream -> stream
|
||||
# logprobs.top_k -> logprobs
|
||||
|
||||
nvext = {}
|
||||
payload: dict[str, Any] = dict(
|
||||
model=request.model,
|
||||
prompt=request.content,
|
||||
stream=request.stream,
|
||||
extra_body=dict(nvext=nvext),
|
||||
extra_headers={
|
||||
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
|
||||
},
|
||||
n=n,
|
||||
)
|
||||
|
||||
if request.response_format:
|
||||
# this is not openai compliant, it is a nim extension
|
||||
nvext.update(guided_json=request.response_format.json_schema)
|
||||
|
||||
if request.logprobs:
|
||||
payload.update(logprobs=request.logprobs.top_k)
|
||||
|
||||
if request.sampling_params:
|
||||
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
|
||||
|
||||
if request.sampling_params.max_tokens:
|
||||
payload.update(max_tokens=request.sampling_params.max_tokens)
|
||||
|
||||
if request.sampling_params.strategy == "top_p":
|
||||
nvext.update(top_k=-1)
|
||||
payload.update(top_p=request.sampling_params.top_p)
|
||||
elif request.sampling_params.strategy == "top_k":
|
||||
if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
|
||||
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
|
||||
nvext.update(top_k=request.sampling_params.top_k)
|
||||
elif request.sampling_params.strategy == "greedy":
|
||||
nvext.update(top_k=-1)
|
||||
payload.update(temperature=request.sampling_params.temperature)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _convert_openai_completion_logprobs(
|
||||
logprobs: OpenAICompletionLogprobs | None,
|
||||
) -> list[TokenLogProbs] | None:
|
||||
"""
|
||||
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
|
||||
"""
|
||||
if not logprobs:
|
||||
return None
|
||||
|
||||
return [TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs]
|
||||
|
||||
|
||||
def convert_openai_completion_choice(
|
||||
choice: OpenAIChoice,
|
||||
) -> CompletionResponse:
|
||||
"""
|
||||
Convert an OpenAI Completion Choice into a CompletionResponse.
|
||||
"""
|
||||
return CompletionResponse(
|
||||
content=choice.text,
|
||||
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
|
||||
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
|
||||
)
|
||||
|
||||
|
||||
async def convert_openai_completion_stream(
|
||||
stream: AsyncStream[OpenAICompletion],
|
||||
) -> AsyncGenerator[CompletionResponse, None]:
|
||||
"""
|
||||
Convert a stream of OpenAI Completions into a stream
|
||||
of ChatCompletionResponseStreamChunks.
|
||||
"""
|
||||
async for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=choice.text,
|
||||
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
|
||||
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
|
||||
)
|
||||
|
|
@ -4,53 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from . import NVIDIAConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||
|
||||
|
||||
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
|
||||
return "integrate.api.nvidia.com" in config.url
|
||||
|
||||
|
||||
async def _get_health(url: str) -> tuple[bool, bool]:
|
||||
"""
|
||||
Query {url}/v1/health/{live,ready} to check if the server is running and ready
|
||||
|
||||
Args:
|
||||
url (str): URL of the server
|
||||
|
||||
Returns:
|
||||
Tuple[bool, bool]: (is_live, is_ready)
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
live = await client.get(f"{url}/v1/health/live")
|
||||
ready = await client.get(f"{url}/v1/health/ready")
|
||||
return live.status_code == 200, ready.status_code == 200
|
||||
|
||||
|
||||
async def check_health(config: NVIDIAConfig) -> None:
|
||||
"""
|
||||
Check if the server is running and ready
|
||||
|
||||
Args:
|
||||
url (str): URL of the server
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the server is not running or ready
|
||||
"""
|
||||
if not _is_nvidia_hosted(config):
|
||||
logger.info("Checking NVIDIA NIM health...")
|
||||
try:
|
||||
is_live, is_ready = await _get_health(config.url)
|
||||
if not is_live:
|
||||
raise ConnectionError("NVIDIA NIM is not running")
|
||||
if not is_ready:
|
||||
raise ConnectionError("NVIDIA NIM is not ready")
|
||||
# TODO(mf): should we wait for the server to be ready?
|
||||
except httpx.ConnectError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
|
@ -15,10 +13,6 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
|||
|
||||
class OllamaImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
refresh_models: bool = Field(
|
||||
default=False,
|
||||
description="Whether to refresh models periodically",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -72,9 +72,6 @@ class OllamaInferenceAdapter(OpenAIMixin):
|
|||
f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the Ollama server.
|
||||
|
|
|
|||
|
|
@ -11,6 +11,6 @@ async def get_adapter_impl(config: RunpodImplConfig, _deps):
|
|||
from .runpod import RunpodInferenceAdapter
|
||||
|
||||
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = RunpodInferenceAdapter(config)
|
||||
impl = RunpodInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,69 +4,86 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
|
||||
|
||||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import RunpodImplConfig
|
||||
|
||||
# https://docs.runpod.io/serverless/vllm/overview#compatible-models
|
||||
# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
|
||||
RUNPOD_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
|
||||
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
|
||||
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
|
||||
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
|
||||
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
|
||||
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
|
||||
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
|
||||
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
|
||||
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
|
||||
}
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
class RunpodInferenceAdapter(OpenAIMixin):
|
||||
"""
|
||||
Adapter for RunPod's OpenAI-compatible API endpoints.
|
||||
Supports VLLM for serverless endpoint self-hosted or public endpoints.
|
||||
Can work with any runpod endpoints that support OpenAI-compatible API
|
||||
"""
|
||||
|
||||
# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(provider_model_id, model_descriptor)
|
||||
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
config: RunpodImplConfig
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""Get API key for OpenAI client."""
|
||||
return self.config.api_token
|
||||
|
||||
class RunpodInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
):
|
||||
def __init__(self, config: RunpodImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
|
||||
self.config = config
|
||||
def get_base_url(self) -> str:
|
||||
"""Get base URL for OpenAI client."""
|
||||
return self.config.url
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": chat_completion_request_to_prompt(request),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def openai_embeddings(
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
):
|
||||
"""Override to add RunPod-specific stream_options requirement."""
|
||||
if stream and not stream_options:
|
||||
stream_options = {"include_usage": True}
|
||||
|
||||
return await super().openai_chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -63,9 +63,6 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
|
|||
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
|
||||
return [m.id for m in await self._get_client().models.list()]
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return True
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
|||
|
|
@ -30,10 +30,6 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
|
|||
default=True,
|
||||
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
|
||||
)
|
||||
refresh_models: bool = Field(
|
||||
default=False,
|
||||
description="Whether to refresh models periodically",
|
||||
)
|
||||
|
||||
@field_validator("tls_verify")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -53,10 +53,6 @@ class VLLMInferenceAdapter(OpenAIMixin):
|
|||
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
|
||||
)
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
# Strictly respecting the refresh_models directive
|
||||
return self.config.refresh_models
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the remote vLLM server.
|
||||
|
|
|
|||
|
|
@ -4,19 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import WatsonXConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
|
||||
# import dynamically so `llama stack build` does not fail due to missing dependencies
|
||||
async def get_adapter_impl(config: WatsonXConfig, _deps):
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .watsonx import WatsonXInferenceAdapter
|
||||
|
||||
if not isinstance(config, WatsonXConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
adapter = WatsonXInferenceAdapter(config)
|
||||
return adapter
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "WatsonXConfig"]
|
||||
|
|
|
|||
|
|
@ -7,16 +7,18 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class WatsonXProviderDataValidator(BaseModel):
|
||||
url: str
|
||||
api_key: str
|
||||
project_id: str
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="forbid",
|
||||
)
|
||||
watsonx_api_key: str | None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -25,13 +27,17 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
|
|||
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||
description="A base url for accessing the watsonx.ai",
|
||||
)
|
||||
# This seems like it should be required, but none of the other remote inference
|
||||
# providers require it, so this is optional here too for consistency.
|
||||
# The OpenAIConfig uses default=None instead, so this is following that precedent.
|
||||
api_key: SecretStr | None = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
|
||||
description="The watsonx API key",
|
||||
default=None,
|
||||
description="The watsonx.ai API key",
|
||||
)
|
||||
# As above, this is optional here too for consistency.
|
||||
project_id: str | None = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
|
||||
description="The Project ID key",
|
||||
default=None,
|
||||
description="The watsonx.ai project ID",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60,
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-3-70b-instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-2-13b-chat",
|
||||
CoreModelId.llama2_13b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-1-70b-instruct",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-1-8b-instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-11b-vision-instruct",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-1b-instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-3b-instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-90b-vision-instruct",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-guard-3-11b-vision",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
]
|
||||
|
|
@ -4,240 +4,120 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from ibm_watsonx_ai.foundation_models import Model
|
||||
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
|
||||
from openai import AsyncOpenAI
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
GreedySamplingStrategy,
|
||||
Inference,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from . import WatsonXConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::watsonx")
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
|
||||
# Note on structured output
|
||||
# WatsonX returns responses with a json embedded into a string.
|
||||
# Examples:
|
||||
class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
||||
# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
|
||||
# "first_name": "Michael",\n "last_name": "Jordan",\n'...)
|
||||
# Not even a valid JSON, but we can still extract the JSON from the content
|
||||
def __init__(self, config: WatsonXConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="watsonx",
|
||||
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
|
||||
provider_data_api_key_field="watsonx_api_key",
|
||||
)
|
||||
self.available_models = None
|
||||
self.config = config
|
||||
|
||||
# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
|
||||
# "year_born": "1963", "year_retired": "2003"\\}}$')
|
||||
# Find the start of the boxed content
|
||||
def get_base_url(self) -> str:
|
||||
return self.config.url
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
|
||||
# Get base parameters from parent
|
||||
params = await super()._get_params(request)
|
||||
|
||||
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
def __init__(self, config: WatsonXConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
||||
logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
|
||||
self._config = config
|
||||
self._openai_client: AsyncOpenAI | None = None
|
||||
|
||||
self._project_id = self._config.project_id
|
||||
|
||||
def _get_client(self, model_id) -> Model:
|
||||
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
|
||||
config_url = self._config.url
|
||||
project_id = self._config.project_id
|
||||
credentials = {"url": config_url, "apikey": config_api_key}
|
||||
|
||||
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
if not self._openai_client:
|
||||
self._openai_client = AsyncOpenAI(
|
||||
base_url=f"{self._config.url}/openai/v1",
|
||||
api_key=self._config.api_key,
|
||||
)
|
||||
return self._openai_client
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
|
||||
input_dict = {"params": {}}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||
else:
|
||||
assert not media_present, "Together does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
if request.sampling_params:
|
||||
if request.sampling_params.strategy:
|
||||
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
|
||||
if request.sampling_params.max_tokens:
|
||||
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
|
||||
if request.sampling_params.repetition_penalty:
|
||||
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
|
||||
|
||||
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
|
||||
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
|
||||
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
|
||||
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
|
||||
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
|
||||
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
|
||||
input_dict["params"][GenParams.TEMPERATURE] = 0.0
|
||||
|
||||
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
|
||||
|
||||
params = {
|
||||
**input_dict,
|
||||
}
|
||||
# Add watsonx.ai specific parameters
|
||||
params["project_id"] = self.config.project_id
|
||||
params["time_limit"] = self.config.timeout
|
||||
return params
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
# Copied from OpenAIMixin
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available from the provider's /v1/models.
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self._get_openai_client().completions.create(**params) # type: ignore
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
if not self._model_cache:
|
||||
await self.list_models()
|
||||
return model in self._model_cache
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.get("stream", False):
|
||||
return self._stream_openai_chat_completion(params)
|
||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {}
|
||||
models = []
|
||||
for model_spec in self._get_model_specs():
|
||||
functions = [f["id"] for f in model_spec.get("functions", [])]
|
||||
# Format: {"embedding_dimension": 1536, "context_length": 8192}
|
||||
|
||||
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
|
||||
# watsonx.ai sometimes adds usage data to the stream
|
||||
include_usage = False
|
||||
if params.get("stream_options", None):
|
||||
include_usage = params["stream_options"].get("include_usage", False)
|
||||
stream = await self._get_openai_client().chat.completions.create(**params)
|
||||
# Example of an embedding model:
|
||||
# {'model_id': 'ibm/granite-embedding-278m-multilingual',
|
||||
# 'label': 'granite-embedding-278m-multilingual',
|
||||
# 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
|
||||
# ...
|
||||
provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
|
||||
if "embedding" in functions:
|
||||
embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
|
||||
context_length = model_spec["model_limits"]["max_sequence_length"]
|
||||
embedding_metadata = {
|
||||
"embedding_dimension": embedding_dimension,
|
||||
"context_length": context_length,
|
||||
}
|
||||
model = Model(
|
||||
identifier=model_spec["model_id"],
|
||||
provider_resource_id=provider_resource_id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata=embedding_metadata,
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
self._model_cache[provider_resource_id] = model
|
||||
models.append(model)
|
||||
if "text_chat" in functions:
|
||||
model = Model(
|
||||
identifier=model_spec["model_id"],
|
||||
provider_resource_id=provider_resource_id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
# In theory, I guess it is possible that a model could be both an embedding model and a text chat model.
|
||||
# In that case, the cache will record the generator Model object, and the list which we return will have
|
||||
# both the generator Model object and the text chat Model object. That's fine because the cache is
|
||||
# only used for check_model_availability() anyway.
|
||||
self._model_cache[provider_resource_id] = model
|
||||
models.append(model)
|
||||
return models
|
||||
|
||||
seen_finish_reason = False
|
||||
async for chunk in stream:
|
||||
# Final usage chunk with no choices that the user didn't request, so discard
|
||||
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
|
||||
break
|
||||
yield chunk
|
||||
for choice in chunk.choices:
|
||||
if choice.finish_reason:
|
||||
seen_finish_reason = True
|
||||
break
|
||||
# LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
|
||||
# So we need to implement our own method to list models by calling the watsonx.ai API.
|
||||
def _get_model_specs(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieves foundation model specifications from the watsonx.ai API.
|
||||
"""
|
||||
url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
|
||||
headers = {
|
||||
# Note that there is no authorization header. Listing models does not require authentication.
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
# --- Process the Response ---
|
||||
# Raise an exception for bad status codes (4xx or 5xx)
|
||||
response.raise_for_status()
|
||||
|
||||
# If the request is successful, parse and return the JSON response.
|
||||
# The response should contain a list of model specifications
|
||||
response_data = response.json()
|
||||
if "resources" not in response_data:
|
||||
raise ValueError("Resources not found in response")
|
||||
return response_data["resources"]
|
||||
|
|
|
|||
|
|
@ -167,7 +167,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -349,6 +349,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -390,6 +390,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
if self.conn is not None:
|
||||
self.conn.close()
|
||||
log.info("Connection to PGVector database server closed")
|
||||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
# Persist vector DB metadata in the KV store
|
||||
|
|
|
|||
|
|
@ -191,6 +191,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
await self.client.close()
|
||||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -347,6 +347,8 @@ class WeaviateVectorIOAdapter(
|
|||
async def shutdown(self) -> None:
|
||||
for client in self.client_cache.values():
|
||||
client.close()
|
||||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import struct
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -16,6 +18,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
|
|
@ -26,7 +29,6 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
b64_encode_openai_embeddings_response,
|
||||
convert_message_to_openai_dict_new,
|
||||
convert_tooldef_to_openai_tool,
|
||||
get_sampling_options,
|
||||
|
|
@ -349,3 +351,28 @@ class LiteLLMOpenAIMixin(
|
|||
return False
|
||||
|
||||
return model in litellm.models_by_provider[self.litellm_provider_name]
|
||||
|
||||
|
||||
def b64_encode_openai_embeddings_response(
|
||||
response_data: list[dict], encoding_format: str | None = "float"
|
||||
) -> list[OpenAIEmbeddingData]:
|
||||
"""
|
||||
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
|
||||
"""
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response_data):
|
||||
if encoding_format == "base64":
|
||||
byte_array = bytearray()
|
||||
for embedding_value in embedding_data["embedding"]:
|
||||
byte_array.extend(struct.pack("f", float(embedding_value)))
|
||||
|
||||
response_embedding = base64.b64encode(byte_array).decode("utf-8")
|
||||
else:
|
||||
response_embedding = embedding_data["embedding"]
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=response_embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -24,6 +24,10 @@ class RemoteInferenceProviderConfig(BaseModel):
|
|||
default=None,
|
||||
description="List of models that should be registered with the model registry. If None, all models are allowed.",
|
||||
)
|
||||
refresh_models: bool = Field(
|
||||
default=False,
|
||||
description="Whether to refresh models periodically from the provider",
|
||||
)
|
||||
|
||||
|
||||
# TODO: this class is more confusing than useful right now. We need to make it
|
||||
|
|
|
|||
|
|
@ -3,9 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import base64
|
||||
import json
|
||||
import struct
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
|
|
@ -103,7 +101,6 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
SamplingParams,
|
||||
|
|
@ -1402,28 +1399,3 @@ def prepare_openai_embeddings_params(
|
|||
params["user"] = user
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def b64_encode_openai_embeddings_response(
|
||||
response_data: dict, encoding_format: str | None = "float"
|
||||
) -> list[OpenAIEmbeddingData]:
|
||||
"""
|
||||
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
|
||||
"""
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response_data):
|
||||
if encoding_format == "base64":
|
||||
byte_array = bytearray()
|
||||
for embedding_value in embedding_data.embedding:
|
||||
byte_array.extend(struct.pack("f", float(embedding_value)))
|
||||
|
||||
response_embedding = base64.b64encode(byte_array).decode("utf-8")
|
||||
else:
|
||||
response_embedding = embedding_data.embedding
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=response_embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -474,17 +474,23 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available from the provider's /v1/models.
|
||||
Check if a specific model is available from the provider's /v1/models or pre-registered.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
:return: True if the model is available dynamically or pre-registered, False otherwise.
|
||||
"""
|
||||
# First check if the model is pre-registered in the model store
|
||||
if hasattr(self, "model_store") and self.model_store:
|
||||
if await self.model_store.has_model(model):
|
||||
return True
|
||||
|
||||
# Then check the provider's dynamic model cache
|
||||
if not self._model_cache:
|
||||
await self.list_models()
|
||||
return model in self._model_cache
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
return self.config.refresh_models
|
||||
|
||||
#
|
||||
# The model_dump implementations are to avoid serializing the extra fields,
|
||||
|
|
|
|||
|
|
@ -293,6 +293,18 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
await self._resume_incomplete_batches()
|
||||
self._last_file_batch_cleanup_time = 0
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Clean up mixin resources including background tasks."""
|
||||
# Cancel any running file batch tasks gracefully
|
||||
tasks_to_cancel = list(self._file_batch_tasks.items())
|
||||
for _, task in tasks_to_cancel:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a vector store."""
|
||||
|
|
@ -587,7 +599,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
content = self._chunk_to_vector_store_content(chunk)
|
||||
|
||||
response_data_item = VectorStoreSearchResponse(
|
||||
file_id=chunk.metadata.get("file_id", ""),
|
||||
file_id=chunk.metadata.get("document_id", ""),
|
||||
filename=chunk.metadata.get("filename", ""),
|
||||
score=score,
|
||||
attributes=chunk.metadata,
|
||||
|
|
@ -746,12 +758,15 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
|
||||
content = content_from_data_and_mime_type(content_response.body, mime_type)
|
||||
|
||||
chunk_attributes = attributes.copy()
|
||||
chunk_attributes["filename"] = file_response.filename
|
||||
|
||||
chunks = make_overlapped_chunks(
|
||||
file_id,
|
||||
content,
|
||||
max_chunk_size_tokens,
|
||||
chunk_overlap_tokens,
|
||||
attributes,
|
||||
chunk_attributes,
|
||||
)
|
||||
if not chunks:
|
||||
vector_store_file_object.status = "failed"
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ from pydantic import BaseModel
|
|||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.tools import RAGDocument
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
|
|
@ -129,26 +128,6 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en
|
|||
return ""
|
||||
|
||||
|
||||
def concat_interleaved_content(content: list[InterleavedContent]) -> InterleavedContent:
|
||||
"""concatenate interleaved content into a single list. ensure that 'str's are converted to TextContentItem when in a list"""
|
||||
|
||||
ret = []
|
||||
|
||||
def _process(c):
|
||||
if isinstance(c, str):
|
||||
ret.append(TextContentItem(text=c))
|
||||
elif isinstance(c, list):
|
||||
for item in c:
|
||||
_process(item)
|
||||
else:
|
||||
ret.append(c)
|
||||
|
||||
for c in content:
|
||||
_process(c)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
async def content_from_doc(doc: RAGDocument) -> str:
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
|
|
|
|||
|
|
@ -221,8 +221,8 @@ fi
|
|||
cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
|
||||
--network llama-net \
|
||||
-p "${PORT}:${PORT}" \
|
||||
"${SERVER_IMAGE}" --port "${PORT}" \
|
||||
--env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}")
|
||||
-e OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" \
|
||||
"${SERVER_IMAGE}" --port "${PORT}")
|
||||
|
||||
log "🦙 Starting Llama Stack..."
|
||||
if ! execute_with_log $ENGINE "${cmd[@]}"; then
|
||||
|
|
|
|||
|
|
@ -191,9 +191,11 @@ if [[ "$STACK_CONFIG" == *"server:"* ]]; then
|
|||
echo "Llama Stack Server is already running, skipping start"
|
||||
else
|
||||
echo "=== Starting Llama Stack Server ==="
|
||||
# Set a reasonable log width for better readability in server.log
|
||||
export LLAMA_STACK_LOG_WIDTH=120
|
||||
nohup llama stack run ci-tests --image-type venv > server.log 2>&1 &
|
||||
|
||||
# remove "server:" from STACK_CONFIG
|
||||
stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://')
|
||||
nohup llama stack run $stack_config > server.log 2>&1 &
|
||||
|
||||
echo "Waiting for Llama Stack Server to start..."
|
||||
for i in {1..30}; do
|
||||
|
|
|
|||
|
|
@ -16,10 +16,19 @@
|
|||
|
||||
set -Eeuo pipefail
|
||||
|
||||
CONTAINER_RUNTIME=${CONTAINER_RUNTIME:-docker}
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
if command -v podman &> /dev/null; then
|
||||
CONTAINER_RUNTIME="podman"
|
||||
elif command -v docker &> /dev/null; then
|
||||
CONTAINER_RUNTIME="docker"
|
||||
else
|
||||
echo "🚨 Neither Podman nor Docker could be found"
|
||||
echo "Install Docker: https://docs.docker.com/get-docker/ or Podman: https://podman.io/getting-started/installation"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "🚀 Setting up telemetry stack for Llama Stack using Podman..."
|
||||
echo "🚀 Setting up telemetry stack for Llama Stack using $CONTAINER_RUNTIME..."
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
|
||||
if ! command -v "$CONTAINER_RUNTIME" &> /dev/null; then
|
||||
echo "🚨 $CONTAINER_RUNTIME could not be found"
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue