mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-11 20:40:40 +00:00)

Merge remote-tracking branch 'origin/main' into math_500

This commit is contained in: commit 599873e485

92 changed files with 23070 additions and 4653 deletions
9  .cursor/rules/general.mdc  Normal file

@@ -0,0 +1,9 @@
---
description: General rules always applicable across the project
globs:
alwaysApply: true
---
# Style

- Comments must add value to code. Don't write filler comments explaining what you are doing next; they just add noise.
- Add a comment to clarify surprising behavior that would not be obvious. Good variable naming and clear code organization are more important.
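A tiny illustration of the distinction the rule draws (illustrative snippet, not part of the rules file): the first comment restates the code, the second explains behavior a reader could not infer from it.

```python
# Illustration of the comment rule above: filler vs. value-adding comments.
def parse_timeout(raw: str) -> float:
    # Convert the string to a float   <- filler: restates the next line
    value = float(raw)
    # Config files written before v0.2 stored timeouts in milliseconds;
    # anything above 1000 is assumed to be one of those legacy values.
    return value / 1000 if value > 1000 else value


print(parse_timeout("250"), parse_timeout("2500"))  # 250.0 2.5
```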
8  .github/dependabot.yml  vendored  Normal file

@@ -0,0 +1,8 @@
# GitHub Dependabot configuration
version: 2
updates:
  # Enable version updates for GitHub Actions
  - package-ecosystem: "github-actions"
    directory: "/" # Will use the default workflow location of `.github/workflows`
    schedule:
      interval: "daily"
@@ -310,7 +310,7 @@ jobs:
      - name: "PR - Upload Test Summary"
        id: pr_test_summary_upload
        if: github.event_name == 'pull_request_target'
        uses: actions/upload-artifact@v3
        uses: actions/upload-artifact@v4
        with:
          name: test-summary
          path: test-summary.md

@@ -320,7 +320,7 @@ jobs:
      - name: "PR - Update comment"
        id: pr_update_comment
        if: github.event_name == 'pull_request_target'
        uses: thollander/actions-comment-pull-request@v2
        uses: thollander/actions-comment-pull-request@v3
        with:
          filePath: test-summary.md
2  .github/workflows/update-readthedocs.yml  vendored

@@ -12,12 +12,14 @@ on:
      - main
    paths:
      - 'docs/**'
      - 'pyproject.toml'
      - '.github/workflows/update-readthedocs.yml'
  pull_request:
    branches:
      - main
    paths:
      - 'docs/**'
      - 'pyproject.toml'
      - '.github/workflows/update-readthedocs.yml'

jobs:
1  .gitignore  vendored

@@ -20,3 +20,4 @@ _build
docs/src
pyrightconfig.json
venv/
pytest-report.xml
0  .gitmodules  vendored

@@ -15,10 +15,6 @@ repos:
      - id: end-of-file-fixer
        exclude: '^(.*\.svg)$'

  # Temporarily disabling this
  # - id: no-commit-to-branch
  #   args: ['--branch=main']

  - repo: https://github.com/Lucas-C/pre-commit-hooks
    rev: v1.5.4
    hooks:

@@ -68,12 +64,6 @@ repos:
          - pydantic
        pass_filenames: false

  # - repo: https://github.com/jsh9/pydoclint
  #   rev: d88180a8632bb1602a4d81344085cf320f288c5a
  #   hooks:
  #     - id: pydoclint
  #       args: [--config=pyproject.toml]

  # - repo: https://github.com/tcort/markdown-link-check
  #   rev: v3.11.2
  #   hooks:
1444  CHANGELOG.md

File diff suppressed because it is too large
@@ -5,3 +5,4 @@ include llama_stack/distribution/*.sh
include llama_stack/cli/scripts/*.sh
include llama_stack/templates/*/*.yaml
include llama_stack/providers/tests/test_cases/inference/*.json
include llama_stack/models/llama/*/*.md
1427  docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb  Normal file

File diff suppressed because it is too large
@@ -127,15 +127,11 @@ MCP tools require:

## Adding Custom Tools

When you want to use tools other than the built-in tools, you can implement a python function and decorate it with `@client_tool`.
When you want to use tools other than the built-in tools, you just need to implement a Python function with a docstring. The content of the docstring will be used to describe the tool and its parameters, and will be passed along to the generative model.

To define a custom tool, you need to use the `@client_tool` decorator.
```python
from llama_stack_client.lib.agents.client_tool import client_tool


# Example tool definition
@client_tool
def my_tool(input: int) -> int:
    """
    Runs my awesome tool.
@@ -24,6 +24,56 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
- Associated with `Benchmark` resource.


## Open-benchmark Eval

### List of open-benchmarks Llama Stack supports

Llama Stack pre-registers several popular open-benchmarks to easily evaluate model performance via CLI.

The list of open-benchmarks we currently support:
- [MMLU-COT](https://arxiv.org/abs/2009.03300) (Measuring Massive Multitask Language Understanding): Benchmark designed to comprehensively evaluate the breadth and depth of a model's academic and professional understanding
- [GPQA-COT](https://arxiv.org/abs/2311.12022) (A Graduate-Level Google-Proof Q&A Benchmark): A challenging benchmark of 448 multiple-choice questions written by domain experts in biology, physics, and chemistry.
- [SimpleQA](https://openai.com/index/introducing-simpleqa/): Benchmark designed to assess a model's ability to answer short, fact-seeking questions.
- [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI): Benchmark designed to evaluate multimodal models.

You can follow this [contributing guide](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) to add more open-benchmarks to Llama Stack.

### Run evaluation on open-benchmarks via CLI

We have built-in functionality to run the supported open-benchmarks using the llama-stack-client CLI.

#### Spin up Llama Stack server

Spin up the Llama Stack server with the 'open-benchmark' template:
```
llama stack run llama_stack/templates/open-benchmark/run.yaml
```

#### Run eval CLI
There are 3 necessary inputs to run a benchmark eval:
- `list of benchmark_ids`: The list of benchmark ids to run evaluation on
- `model-id`: The model id to evaluate on
- `output_dir`: Path to store the evaluation results
```
llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
--model_id <model id to evaluate on> \
--output_dir <directory to store the evaluation results> \
```

You can run
```
llama-stack-client eval run-benchmark help
```
to see the description of all the flags that eval run-benchmark has.

In the output log, you can find the file path that has your evaluation results. Open that file and you can see your aggregate evaluation results there.


## What's Next?

- Check out our Colab notebook on working examples with running benchmark evaluations [here](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb#scrollTo=mxLCsP4MvFqP).
@@ -34,7 +34,7 @@ We are working on adding a few more APIs to complete the application lifecycle.

The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)

Providers come in two flavors:
@@ -13,16 +13,18 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

from docutils import nodes
import tomli  # Import tomli for TOML parsing
from pathlib import Path
import requests
import json

# Read version from pyproject.toml
with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
    pyproject = tomli.load(f)
    llama_stack_version = pyproject["project"]["version"]
pypi_url = "https://pypi.org/pypi/llama-stack/json"
version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"]
print(f"{version_tag=}")

# generate the full link including text and url here
llama_stack_version_url = f"https://github.com/meta-llama/llama-stack/releases/tag/v{llama_stack_version}"
llama_stack_version_url = f"https://github.com/meta-llama/llama-stack/releases/tag/v{version_tag}"
llama_stack_version_link = f"<a href='{llama_stack_version_url}'>release notes</a>"

project = "llama-stack"

@@ -77,7 +79,7 @@ myst_enable_extensions = [

myst_substitutions = {
    "docker_hub": "https://hub.docker.com/repository/docker/llamastack",
    "llama_stack_version": llama_stack_version,
    "llama_stack_version": version_tag,
    "llama_stack_version_link": llama_stack_version_link,
}
@@ -51,25 +51,25 @@ The main points to consider are:

```
llama stack build -h

usage: llama stack build [-h] [--config CONFIG] [--template TEMPLATE] [--list-templates]
                         [--image-type {conda,container,venv}] [--image-name IMAGE_NAME] [--print-deps-only]
usage: llama stack build [-h] [--config CONFIG] [--template TEMPLATE] [--list-templates] [--image-type {conda,container,venv}] [--image-name IMAGE_NAME] [--print-deps-only] [--run]

Build a Llama stack container

options:
  -h, --help            show this help message and exit
  --config CONFIG       Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml.
                        If this argument is not provided, you will be prompted to enter information interactively
  --template TEMPLATE   Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates
  --list-templates      Show the available templates for building a Llama Stack distribution
  --config CONFIG       Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will
                        be prompted to enter information interactively (default: None)
  --template TEMPLATE   Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates (default: None)
  --list-templates      Show the available templates for building a Llama Stack distribution (default: False)
  --image-type {conda,container,venv}
                        Image Type to use for the build. This can be either conda or container or venv. If not specified, will use the image type from the template config.
                        Image Type to use for the build. This can be either conda or container or venv. If not specified, will use the image type from the template config. (default:
                        conda)
  --image-name IMAGE_NAME
                        [for image-type=conda] Name of the conda environment to use for the build. If
                        not specified, currently active Conda environment will be used. If no Conda
                        environment is active, you must specify a name.
  --print-deps-only     Print the dependencies for the stack only, without building the stack
                        [for image-type=conda|venv] Name of the conda or virtual environment to use for the build. If not specified, currently active Conda environment will be used if
                        found. (default: None)
  --print-deps-only     Print the dependencies for the stack only, without building the stack (default: False)
  --run                 Run the stack after building using the same image type, name, and other applicable arguments (default: False)

```

After this step is complete, a file named `<name>-build.yaml` and template file `<name>-run.yaml` will be generated and saved at the output file path specified at the end of the command.
@@ -212,8 +212,8 @@ Now, let's start the Llama Stack Distribution Server. You will need the YAML con

```
llama stack run -h
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE]
                       [--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] [--tls-certfile TLS_CERTFILE]
                       [--image-type {conda,container,venv}]
                       config

Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.

@@ -223,17 +223,17 @@ positional arguments:

options:
  -h, --help            show this help message and exit
  --port PORT           Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321
  --port PORT           Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
  --image-name IMAGE_NAME
                        Name of the image to run. Defaults to the current conda environment
  --disable-ipv6        Disable IPv6 support
  --env KEY=VALUE       Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.
                        Name of the image to run. Defaults to the current conda environment (default: None)
  --disable-ipv6        Disable IPv6 support (default: False)
  --env KEY=VALUE       Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: [])
  --tls-keyfile TLS_KEYFILE
                        Path to TLS key file for HTTPS
                        Path to TLS key file for HTTPS (default: None)
  --tls-certfile TLS_CERTFILE
                        Path to TLS certificate file for HTTPS
                        Path to TLS certificate file for HTTPS (default: None)
  --image-type {conda,container,venv}
                        Image Type used during the build. This can be either conda or container or venv.
                        Image Type used during the build. This can be either conda or container or venv. (default: conda)

```
@@ -17,26 +17,4 @@ $ llama-stack-client configure --endpoint https://llamastack-preview.fireworks.a
$ llama-stack-client models list
```

You will see outputs:
```
$ llama-stack-client models list
+------------------------------+------------------------------+---------------+------------+
| identifier                   | llama_model                  | provider_id   | metadata   |
+==============================+==============================+===============+============+
| Llama3.1-8B-Instruct         | Llama3.1-8B-Instruct         | fireworks0    | {}         |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-70B-Instruct        | Llama3.1-70B-Instruct        | fireworks0    | {}         |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-405B-Instruct       | Llama3.1-405B-Instruct       | fireworks0    | {}         |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-1B-Instruct         | Llama3.2-1B-Instruct         | fireworks0    | {}         |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-3B-Instruct         | Llama3.2-3B-Instruct         | fireworks0    | {}         |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-11B-Vision-Instruct | Llama3.2-11B-Vision-Instruct | fireworks0    | {}         |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-90B-Vision-Instruct | Llama3.2-90B-Vision-Instruct | fireworks0    | {}         |
+------------------------------+------------------------------+---------------+------------+
```

Check out the [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python/blob/main/docs/cli_reference.md) repo for more details on how to use the `llama-stack-client` CLI. Check out [llama-stack-app](https://github.com/meta-llama/llama-stack-apps/tree/main) for example applications built on top of Llama Stack.
@@ -68,6 +68,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
| FAISS               | Single Node            |
| SQLite-Vec          | Single Node            |
| Chroma              | Hosted and Single Node |
| Milvus              | Hosted and Single Node |
| Postgres (PGVector) | Hosted and Single Node |
| Weaviate            | Hosted                 |
|
|||
|
||||
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
|
||||
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
|
||||
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
|
||||
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
|
||||
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
|
||||
|
||||
Providers come in two flavors:
|
||||
|
@@ -55,5 +55,6 @@ vector_io/sqlite-vec
vector_io/chromadb
vector_io/pgvector
vector_io/qdrant
vector_io/milvus
vector_io/weaviate
```
31  docs/source/providers/vector_io/mivus.md  Normal file

@@ -0,0 +1,31 @@
---
orphan: true
---
# Milvus

[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database.
That means you're not limited to storing vectors in memory or in a separate service.

## Features

- Easy to use
- Fully integrated with Llama Stack

## Usage

To use Milvus in your Llama Stack project, follow these steps:

1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Milvus.
3. Start storing and querying vectors.

## Installation

You can install Milvus using pymilvus:

```bash
pip install pymilvus
```

## Documentation
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
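The provider page above stops at installation. As a hedged sanity check of the pymilvus install itself (assuming the `MilvusClient` API and its Milvus Lite local-file mode; this is separate from configuring the Llama Stack provider), a minimal sketch of storing and querying vectors could look like:

```python
# Quick pymilvus-level check, independent of Llama Stack. The file name and
# collection name are illustrative only.
from pymilvus import MilvusClient

client = MilvusClient("milvus_demo.db")  # Milvus Lite keeps data in a local file
client.create_collection(collection_name="demo", dimension=4)

# Insert a couple of toy vectors, then run a nearest-neighbour query.
client.insert(
    collection_name="demo",
    data=[
        {"id": 0, "vector": [0.1, 0.2, 0.3, 0.4]},
        {"id": 1, "vector": [0.4, 0.3, 0.2, 0.1]},
    ],
)
hits = client.search(collection_name="demo", data=[[0.1, 0.2, 0.3, 0.4]], limit=1)
print(hits)
```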
@@ -275,18 +275,25 @@ response = client.scoring.score(
The following examples give the quick steps to start running evaluations using the llama-stack-client CLI.

#### Benchmark Evaluation CLI
Usage: There are 2 inputs necessary for running a benchmark eval
- `eval-task-id`: the identifier associated with the eval task. Each `Benchmark` is parametrized by
  - `dataset_id`: the identifier associated with the dataset.
  - `List[scoring_function_id]`: list of scoring function identifiers.
- `eval-task-config`: specifies the configuration of the model / agent to evaluate on.
There are 3 necessary inputs for running a benchmark eval
- `list of benchmark_ids`: The list of benchmark ids to run evaluation on
- `model-id`: The model id to evaluate on
- `output_dir`: Path to store the evaluation results
```
llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
--model_id <model id to evaluate on> \
--output_dir <directory to store the evaluation results> \
```

You can run
```
llama-stack-client eval run-benchmark help
```
to see the description of all the flags to run a benchmark eval.

```
llama-stack-client eval run_benchmark <eval-task-id> \
--eval-task-config ~/benchmark_config.json \
--visualize
```
In the output log, you can find the path to the file that has your evaluation results. Open that file and you can see your aggregate evaluation results there.

#### Application Evaluation CLI
@@ -338,3 +345,52 @@ The `BenchmarkConfig` are user specified config to define:
    }
}
```


## Open-benchmark Contributing Guide

### Create the new dataset for your new benchmark
An eval open-benchmark essentially contains 2 parts:
- `raw data`: The raw dataset associated with the benchmark. You typically need to search the original paper that introduces the benchmark and find the canonical dataset (usually hosted on huggingface)
- `prompt template`: How to ask the candidate model to generate the answer (the prompt template plays a critical role in the evaluation results). Typically, you can find the reference prompt template associated with the benchmark in the benchmark author's repo ([example](https://github.com/idavidrein/gpqa/blob/main/prompts/chain_of_thought.txt)) or some other popular open source repos ([example](https://github.com/openai/simple-evals/blob/0a6e8f62e52bc5ae915f752466be3af596caf392/common.py#L14))

To create a new open-benchmark in Llama Stack, you need to combine the prompt template and the raw data into the `chat_completion_input` column in the evaluation dataset.

Llama Stack enforces that the evaluation dataset schema contains at least 3 columns:
- `chat_completion_input`: The actual input to the model to run the generation for eval
- `input_query`: The raw input from the raw dataset without the prompt template
- `expected_answer`: The ground truth for scoring functions to calculate the score from.

You need to write a script ([example convert script](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840)) to convert the benchmark raw dataset to the llama stack format eval dataset and upload the dataset to huggingface ([example benchmark dataset](https://huggingface.co/datasets/llamastack/mmmu)). A minimal sketch of such a conversion is shown below.
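The sketch below assumes the Hugging Face `datasets` library; the source dataset id, its `question`/`answer` fields, and the prompt template are illustrative placeholders. Only the three required output columns come from the schema above.

```python
# Hedged sketch: convert a raw benchmark split into the three columns Llama Stack
# expects (chat_completion_input, input_query, expected_answer).
import json
from datasets import load_dataset

PROMPT_TEMPLATE = "Answer the following question. Think step by step.\n\nQuestion: {question}\nAnswer:"

raw = load_dataset("some-org/some-benchmark", split="test")  # hypothetical dataset id


def to_eval_row(row):
    question = row["question"]
    return {
        "input_query": question,
        "expected_answer": row["answer"],
        # chat_completion_input holds the serialized chat messages sent to the model
        "chat_completion_input": json.dumps(
            [{"role": "user", "content": PROMPT_TEMPLATE.format(question=question)}]
        ),
    }


converted = raw.map(to_eval_row, remove_columns=raw.column_names)
converted.push_to_hub("your-org/your-benchmark-eval")  # then reference this dataset from the template
```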
### Find scoring function for your new benchmark
The purpose of a scoring function is to calculate the score for each example based on the candidate model's generation result and the expected_answer. It also aggregates the scores from all the examples and generates the final evaluation results.

Firstly, you can see if the existing [llama stack scoring functions](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/inline/scoring) can fulfill your need. If not, you need to write a new scoring function based on what the benchmark author / other open source repos describe.
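As a conceptual illustration only (this is not the Llama Stack scoring-provider interface), a scoring function boils down to a per-example score plus an aggregation step:

```python
# Concept sketch: per-example scoring and aggregation. Real benchmarks often
# first parse a letter choice or normalize the answer before comparing.
from typing import Dict, List


def score_example(generation: str, expected_answer: str) -> float:
    return 1.0 if generation.strip().lower() == expected_answer.strip().lower() else 0.0


def aggregate(scores: List[float]) -> Dict[str, float]:
    return {"accuracy": sum(scores) / len(scores) if scores else 0.0}


rows = [("Paris", "paris"), ("42", "41")]
print(aggregate([score_example(gen, exp) for gen, exp in rows]))  # {'accuracy': 0.5}
```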
### Add new benchmark into template
Firstly, you need to add the evaluation dataset associated with your benchmark under the `datasets` resource in the [open-benchmark](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/open-benchmark/run.yaml) template.

Secondly, you need to add the new benchmark you just created under the `benchmarks` resource in the same template. To add the new benchmark, you need to have
- `benchmark_id`: identifier of the benchmark
- `dataset_id`: identifier of the dataset associated with your benchmark
- `scoring_functions`: scoring function to calculate the score based on generation results and expected_answer
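If you prefer to sanity-check the registration before editing the template, a hedged sketch of registering the same three fields through the Python client (the `benchmarks.register` call and all identifiers here are assumptions, substitute your own) might look like:

```python
# Hedged sketch: dynamic registration of a benchmark via the Python client.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
client.benchmarks.register(
    benchmark_id="my-new-benchmark",                # placeholder benchmark id
    dataset_id="my-new-benchmark-dataset",          # placeholder dataset id
    scoring_functions=["basic::subset_of"],         # placeholder scoring function
)
```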
### Test the new benchmark

Spin up the Llama Stack server with the 'open-benchmark' template
```
llama stack run llama_stack/templates/open-benchmark/run.yaml
```

Run the eval benchmark CLI with your new benchmark id
```
llama-stack-client eval run-benchmark <new_benchmark_id> \
--model_id <model id to evaluate on> \
--output_dir <directory to store the evaluation results> \
```
@@ -199,7 +199,7 @@ AgentToolGroup = register_schema(


class AgentConfigCommon(BaseModel):
    sampling_params: Optional[SamplingParams] = SamplingParams()
    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)

    input_shields: Optional[List[str]] = Field(default_factory=list)
    output_shields: Optional[List[str]] = Field(default_factory=list)
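This change and the signature changes in the inference protocols below both replace an eagerly constructed `SamplingParams()` default. For the function signatures, the `= None` plus "create it inside the call" pattern sidesteps Python's shared mutable default argument; a small self-contained illustration (the class below is a stand-in, not the real `SamplingParams`):

```python
# Why `params=SomeObject()` is risky as a function default: the default object is
# created once at definition time and shared by every call.
class SamplingParamsDemo:
    def __init__(self):
        self.stop = []  # mutable state


def risky(params=SamplingParamsDemo()):  # one shared instance for all calls
    params.stop.append("###")
    return len(params.stop)


def safe(params=None):  # build a fresh instance per call
    if params is None:
        params = SamplingParamsDemo()
    params.stop.append("###")
    return len(params.stop)


print(risky(), risky())  # 1 2 -> state leaks across calls
print(safe(), safe())    # 1 1 -> each call gets its own params
```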
@@ -40,7 +40,7 @@ class BatchInference(Protocol):
        self,
        model: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchCompletionResponse: ...

@@ -50,7 +50,7 @@ class BatchInference(Protocol):
        self,
        model: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        sampling_params: Optional[SamplingParams] = None,
        # zero-shot tool definitions as input to the model
        tools: Optional[List[ToolDefinition]] = list,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -278,14 +278,14 @@ ResponseFormat = register_schema(
class CompletionRequest(BaseModel):
    model: str
    content: InterleavedContent
    sampling_params: Optional[SamplingParams] = SamplingParams()
    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
    response_format: Optional[ResponseFormat] = None
    stream: Optional[bool] = False
    logprobs: Optional[LogProbConfig] = None


@json_schema_type
class CompletionResponse(MetricResponseMixin):
class CompletionResponse(BaseModel):
    """Response from a completion request.

    :param content: The generated completion text

@@ -299,7 +299,7 @@ class CompletionResponse(MetricResponseMixin):


@json_schema_type
class CompletionResponseStreamChunk(MetricResponseMixin):
class CompletionResponseStreamChunk(BaseModel):
    """A chunk of a streamed completion response.

    :param delta: New content generated since last chunk. This can be one or more tokens.
@@ -357,7 +357,7 @@ class ToolConfig(BaseModel):
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    sampling_params: Optional[SamplingParams] = SamplingParams()
    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)

    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
    tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)

@@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):


@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
    """A chunk of a streamed chat completion response.

    :param event: The event containing the new content

@@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin):


@json_schema_type
class ChatCompletionResponse(MetricResponseMixin):
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
    """Response from a chat completion request.

    :param completion_message: The complete response message
|
@ -444,7 +444,7 @@ class Inference(Protocol):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
@ -467,7 +467,7 @@ class Inference(Protocol):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
|
|
@@ -13,7 +13,7 @@ from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family

ROOT_DIR = Path(__file__).parent.parent
ROOT_DIR = Path(__file__).parent.parent.parent


class ModelPromptFormat(Subcommand):

@@ -44,6 +44,12 @@ class ModelPromptFormat(Subcommand):
            default="llama3_1",
            help="Model Family (llama3_1, llama3_X, etc.)",
        )
        self.parser.add_argument(
            "-l",
            "--list",
            action="store_true",
            help="List all available models",
        )

    def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
        import importlib.resources
@@ -16,7 +16,7 @@ class StackBuild(Subcommand):
            "build",
            prog="llama stack build",
            description="Build a Llama stack container",
            formatter_class=argparse.RawTextHelpFormatter,
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_stack_build_command)
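Switching to `ArgumentDefaultsHelpFormatter` is what produces the `(default: ...)` suffixes in the `llama stack build -h` and `llama stack run -h` output shown in the docs above; a standalone sketch of that argparse behavior:

```python
# argparse.ArgumentDefaultsHelpFormatter appends "(default: ...)" to each
# option's help text automatically.
import argparse

parser = argparse.ArgumentParser(
    prog="demo",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--port", type=int, default=8321, help="Port to run the server on.")
parser.add_argument("--image-type", default="conda", help="Image type used during the build.")
parser.print_help()
# --port PORT           Port to run the server on. (default: 8321)
# --image-type IMAGE_TYPE
#                       Image type used during the build. (default: conda)
```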
@@ -5,15 +5,15 @@
# the root directory of this source tree.

import argparse
import logging
import os
from pathlib import Path

from llama_stack.cli.subcommand import Subcommand
from llama_stack.log import get_logger

REPO_ROOT = Path(__file__).parent.parent.parent.parent

logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="server")


class StackRun(Subcommand):

@@ -23,7 +23,7 @@ class StackRun(Subcommand):
            "run",
            prog="llama stack run",
            description="""Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
            formatter_class=argparse.RawTextHelpFormatter,
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_stack_run_cmd)

@@ -37,12 +37,13 @@ class StackRun(Subcommand):
        self.parser.add_argument(
            "--port",
            type=int,
            help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321",
            help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT.",
            default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
        )
        self.parser.add_argument(
            "--image-name",
            type=str,
            default=os.environ.get("CONDA_DEFAULT_ENV"),
            help="Name of the image to run. Defaults to the current conda environment",
        )
        self.parser.add_argument(
@@ -32,7 +32,10 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.request_headers import (
    preserve_headers_context_async_generator,
    request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (

@@ -160,6 +163,9 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
            except StopAsyncIteration:
                pass
            finally:
                pending = asyncio.all_tasks(loop)
                if pending:
                    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
                loop.close()

        return sync_generator()

@@ -262,21 +268,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
        if not self.endpoint_impls:
            raise ValueError("Client not initialized")

        # Create headers with provider data if available
        headers = {}
        if self.provider_data:
            set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
            headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)

        if stream:
            response = await self._call_streaming(
                cast_to=cast_to,
                options=options,
                stream_cls=stream_cls,
            )
        else:
            response = await self._call_non_streaming(
                cast_to=cast_to,
                options=options,
            )
        return response
        # Use context manager for provider data
        with request_provider_data_context(headers):
            if stream:
                response = await self._call_streaming(
                    cast_to=cast_to,
                    options=options,
                    stream_cls=stream_cls,
                )
            else:
                response = await self._call_non_streaming(
                    cast_to=cast_to,
                    options=options,
                )
            return response

    def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
        """Find the matching endpoint implementation for a given method and path.

@@ -374,9 +384,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
        finally:
            await end_trace()

        # Wrap the generator to preserve context across iterations
        wrapped_gen = preserve_headers_context_async_generator(gen())
        mock_response = httpx.Response(
            status_code=httpx.codes.OK,
            content=gen(),
            content=wrapped_gen,
            headers={
                "Content-Type": "application/json",
            },
@@ -4,16 +4,62 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import contextvars
import json
import logging
import threading
from typing import Any, Dict
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar

from .utils.dynamic import instantiate_class_type

log = logging.getLogger(__name__)

_THREAD_LOCAL = threading.local()
# Context variable for request provider data
_provider_data_var = contextvars.ContextVar("provider_data", default=None)


class RequestProviderDataContext(ContextManager):
    """Context manager for request provider data"""

    def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
        self.provider_data = provider_data
        self.token = None

    def __enter__(self):
        # Save the current value and set the new one
        self.token = _provider_data_var.set(self.provider_data)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the previous value
        if self.token is not None:
            _provider_data_var.reset(self.token)


T = TypeVar("T")


def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    """
    Wraps an async generator to preserve request headers context variables across iterations.

    This ensures that context variables set during generator creation are
    available during each iteration of the generator, even if the original
    context manager has exited.
    """
    # Capture the current context value right now
    context_value = _provider_data_var.get()

    async def wrapper():
        while True:
            # Set context before each anext() call
            _ = _provider_data_var.set(context_value)
            try:
                item = await gen.__anext__()
                yield item
            except StopAsyncIteration:
                break

    return wrapper()


class NeedsRequestProviderData:

@@ -26,7 +72,7 @@ class NeedsRequestProviderData:
        if not validator_class:
            raise ValueError(f"Provider {provider_type} does not have a validator")

        val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
        val = _provider_data_var.get()
        if not val:
            return None

@@ -36,25 +82,32 @@ class NeedsRequestProviderData:
            return provider_data
        except Exception as e:
            log.error(f"Error parsing provider data: {e}")
            return None


def set_request_provider_data(headers: Dict[str, str]):
def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
    """Parse provider data from request headers"""
    keys = [
        "X-LlamaStack-Provider-Data",
        "x-llamastack-provider-data",
    ]
    val = None
    for key in keys:
        val = headers.get(key, None)
        if val:
            break

    if not val:
        return
        return None

    try:
        val = json.loads(val)
        return json.loads(val)
    except json.JSONDecodeError:
        log.error("Provider data not encoded as a JSON object!", val)
        return
        log.error("Provider data not encoded as a JSON object!")
        return None

    _THREAD_LOCAL.provider_data_header_value = val


def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
    """Context manager that sets request provider data from headers for the duration of the context"""
    provider_data = parse_request_provider_data(headers)
    return RequestProviderDataContext(provider_data)
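A hedged usage sketch of how the new context manager and generator wrapper introduced above are meant to combine (the `handle_request` and `stream_tokens` names are illustrative only, not Llama Stack APIs):

```python
# Sketch: provider data set from headers is visible inside the context block,
# and the wrapper keeps it visible across later iterations of the generator.
import asyncio

from llama_stack.distribution.request_headers import (
    preserve_headers_context_async_generator,
    request_provider_data_context,
)


async def stream_tokens():
    for token in ["hello", "world"]:
        yield token


async def handle_request(headers: dict):
    with request_provider_data_context(headers):
        # The wrapper captures the current provider data at creation time.
        gen = preserve_headers_context_async_generator(stream_tokens())
    # Even after the context manager has exited, iteration still sees that value.
    return [tok async for tok in gen]


print(asyncio.run(handle_request({"X-LlamaStack-Provider-Data": '{"api_key": "..."}'})))
```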
@@ -7,7 +7,6 @@ import importlib
import inspect
from typing import Any, Dict, List, Set, Tuple

from llama_stack import logcat
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO

@@ -35,6 +34,7 @@ from llama_stack.distribution.datatypes import (
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    Api,
    BenchmarksProtocolPrivate,

@@ -50,6 +50,8 @@ from llama_stack.providers.datatypes import (
    VectorDBsProtocolPrivate,
)

logger = get_logger(name=__name__, category="core")


class InvalidProviderError(Exception):
    pass

@@ -163,9 +165,7 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
                module="llama_stack.distribution.routers",
                routing_table_api=info.routing_table_api,
                api_dependencies=[info.routing_table_api],
                # Add telemetry as an optional dependency to all auto-routed providers
                optional_api_dependencies=[Api.telemetry],
                deps__=([info.routing_table_api.value, Api.telemetry.value]),
                deps__=[info.routing_table_api.value],
            ),
        )
    }

@@ -186,7 +186,7 @@ def validate_and_prepare_providers(
    specs = {}
    for provider in providers:
        if not provider.provider_id or provider.provider_id == "__disabled__":
            logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
            logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
            continue

        validate_provider(provider, api, provider_registry)

@@ -208,11 +208,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR

    p = provider_registry[api][provider.provider_type]
    if p.deprecation_error:
        logcat.error("core", p.deprecation_error)
        logger.error(p.deprecation_error)
        raise InvalidProviderError(p.deprecation_error)
    elif p.deprecation_warning:
        logcat.warning(
            "core",
        logger.warning(
            f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
        )

@@ -246,9 +245,10 @@ def sort_providers_by_deps(
        )
    )

    logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
    logger.debug(f"Resolved {len(sorted_providers)} providers")
    for api_str, provider in sorted_providers:
        logcat.debug("core", f" {api_str} => {provider.provider_id}")
        logger.debug(f" {api_str} => {provider.provider_id}")
    logger.debug("")
    return sorted_providers


@@ -389,7 +389,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
    obj_params = set(obj_sig.parameters)
    obj_params.discard("self")
    if not (proto_params <= obj_params):
        logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
        logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
        missing_methods.append((name, "signature_mismatch"))
    else:
        # Check if the method is actually implemented in the class
@@ -45,7 +45,7 @@ async def get_routing_table_impl(
    return impl


async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
    from .routers import (
        DatasetIORouter,
        EvalRouter,

@@ -65,17 +65,9 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict
        "eval": EvalRouter,
        "tool_runtime": ToolRuntimeRouter,
    }
    api_to_deps = {
        "inference": {"telemetry": Api.telemetry},
    }
    if api.value not in api_to_routers:
        raise ValueError(f"API {api.value} not found in router map")

    api_to_dep_impl = {}
    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
        if dep_api in deps:
            api_to_dep_impl[dep_name] = deps[dep_api]

    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
    impl = api_to_routers[api.value](routing_table)
    await impl.initialize()
    return impl
@@ -4,10 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Optional

from llama_stack import logcat
from llama_stack.apis.common.content_types import (
    URL,
    InterleavedContent,

@@ -22,10 +20,6 @@ from llama_stack.apis.eval import (
    JobStatus,
)
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,

@@ -33,14 +27,13 @@ from llama_stack.apis.inference import (
    Message,
    ResponseFormat,
    SamplingParams,
    StopReason,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
    ScoreBatchResponse,

@@ -49,7 +42,6 @@ from llama_stack.apis.scoring import (
    ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, Telemetry
from llama_stack.apis.tools import (
    RAGDocument,
    RAGQueryConfig,

@@ -59,10 +51,10 @@ from llama_stack.apis.tools import (
    ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.telemetry.tracing import get_current_span

logger = get_logger(name=__name__, category="core")


class VectorIORouter(VectorIO):
@@ -72,15 +64,15 @@ class VectorIORouter(VectorIO):
        self,
        routing_table: RoutingTable,
    ) -> None:
        logcat.debug("core", "Initializing VectorIORouter")
        logger.debug("Initializing VectorIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logcat.debug("core", "VectorIORouter.initialize")
        logger.debug("VectorIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logcat.debug("core", "VectorIORouter.shutdown")
        logger.debug("VectorIORouter.shutdown")
        pass

    async def register_vector_db(

@@ -91,10 +83,7 @@ class VectorIORouter(VectorIO):
        provider_id: Optional[str] = None,
        provider_vector_db_id: Optional[str] = None,
    ) -> None:
        logcat.debug(
            "core",
            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
        )
        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
        await self.routing_table.register_vector_db(
            vector_db_id,
            embedding_model,

@@ -109,8 +98,7 @@ class VectorIORouter(VectorIO):
        chunks: List[Chunk],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        logcat.debug(
            "core",
        logger.debug(
            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
        )
        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

@@ -121,7 +109,7 @@ class VectorIORouter(VectorIO):
        query: InterleavedContent,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryChunksResponse:
        logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@@ -131,21 +119,16 @@ class InferenceRouter(Inference):
    def __init__(
        self,
        routing_table: RoutingTable,
        telemetry: Optional[Telemetry] = None,
    ) -> None:
        logcat.debug("core", "Initializing InferenceRouter")
        logger.debug("Initializing InferenceRouter")
        self.routing_table = routing_table
        self.telemetry = telemetry
        if self.telemetry:
            self.tokenizer = Tokenizer.get_instance()
            self.formatter = ChatFormat(self.tokenizer)

    async def initialize(self) -> None:
        logcat.debug("core", "InferenceRouter.initialize")
        logger.debug("InferenceRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logcat.debug("core", "InferenceRouter.shutdown")
        logger.debug("InferenceRouter.shutdown")
        pass

    async def register_model(

@@ -156,68 +139,16 @@ class InferenceRouter(Inference):
        metadata: Optional[Dict[str, Any]] = None,
        model_type: Optional[ModelType] = None,
    ) -> None:
        logcat.debug(
            "core",
        logger.debug(
            f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
        )
        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

    def _construct_metrics(
        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
    ) -> List[MetricEvent]:
        span = get_current_span()
        metrics = [
            ("prompt_tokens", prompt_tokens),
            ("completion_tokens", completion_tokens),
            ("total_tokens", total_tokens),
        ]
        metric_events = []
        for metric_name, value in metrics:
            metric_events.append(
                MetricEvent(
                    trace_id=span.trace_id,
                    span_id=span.span_id,
                    metric=metric_name,
                    value=value,
                    timestamp=time.time(),
                    unit="tokens",
                    attributes={
                        "model_id": model.model_id,
                        "provider_id": model.provider_id,
                    },
                )
            )
        return metric_events

    async def _compute_and_log_token_usage(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        model: Model,
    ) -> List[MetricEvent]:
        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
        if self.telemetry:
            for metric in metrics:
                await self.telemetry.log_event(metric)
        return metrics

    async def _count_tokens(
        self,
        messages: List[Message] | InterleavedContent,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
    ) -> Optional[int]:
        if isinstance(messages, list):
            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
        else:
            encoded = self.formatter.encode_content(messages)
        return len(encoded.tokens) if encoded and encoded.tokens else 0

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = None,
@@ -225,11 +156,12 @@ class InferenceRouter(Inference):
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        logcat.debug(
            "core",
    ) -> AsyncGenerator:
        logger.debug(
            f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
        )
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self.routing_table.get_model(model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
@@ -274,59 +206,23 @@ class InferenceRouter(Inference):
            tool_config=tool_config,
        )
        provider = self.routing_table.get_provider_impl(model_id)
        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)

        if stream:

            async def stream_generator():
                completion_text = ""
                async for chunk in await provider.chat_completion(**params):
                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                        if chunk.event.delta.type == "text":
                            completion_text += chunk.event.delta.text
                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                        completion_tokens = await self._count_tokens(
                            [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
                            tool_config.tool_prompt_format,
                        )
                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                        metrics = await self._compute_and_log_token_usage(
                            prompt_tokens or 0,
                            completion_tokens or 0,
                            total_tokens,
                            model,
                        )
                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                    yield chunk

            return stream_generator()
            return (chunk async for chunk in await provider.chat_completion(**params))
        else:
            response = await provider.chat_completion(**params)
            completion_tokens = await self._count_tokens(
                [response.completion_message],
                tool_config.tool_prompt_format,
            )
            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
            metrics = await self._compute_and_log_token_usage(
                prompt_tokens or 0,
                completion_tokens or 0,
                total_tokens,
                model,
            )
            response.metrics = metrics if response.metrics is None else response.metrics + metrics
            return response
            return await provider.chat_completion(**params)

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        logcat.debug(
            "core",
        if sampling_params is None:
            sampling_params = SamplingParams()
        logger.debug(
            f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
        )
        model = await self.routing_table.get_model(model_id)
@@ -343,41 +239,10 @@ class InferenceRouter(Inference):
            stream=stream,
            logprobs=logprobs,
        )

        prompt_tokens = await self._count_tokens(content)

        if stream:

            async def stream_generator():
                completion_text = ""
                async for chunk in await provider.completion(**params):
                    if hasattr(chunk, "delta"):
                        completion_text += chunk.delta
                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                        completion_tokens = await self._count_tokens(completion_text)
                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                        metrics = await self._compute_and_log_token_usage(
                            prompt_tokens or 0,
                            completion_tokens or 0,
                            total_tokens,
                            model,
                        )
                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                    yield chunk

            return stream_generator()
            return (chunk async for chunk in await provider.completion(**params))
        else:
            response = await provider.completion(**params)
            completion_tokens = await self._count_tokens(response.content)
            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
            metrics = await self._compute_and_log_token_usage(
                prompt_tokens or 0,
                completion_tokens or 0,
                total_tokens,
                model,
            )
            response.metrics = metrics if response.metrics is None else response.metrics + metrics
            return response
            return await provider.completion(**params)

    async def embeddings(
        self,

@@ -387,7 +252,7 @@ class InferenceRouter(Inference):
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
        logger.debug(f"InferenceRouter.embeddings: {model_id}")
        model = await self.routing_table.get_model(model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
@@ -407,15 +272,15 @@ class SafetyRouter(Safety):
        self,
        routing_table: RoutingTable,
    ) -> None:
        logcat.debug("core", "Initializing SafetyRouter")
        logger.debug("Initializing SafetyRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logcat.debug("core", "SafetyRouter.initialize")
        logger.debug("SafetyRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logcat.debug("core", "SafetyRouter.shutdown")
        logger.debug("SafetyRouter.shutdown")
        pass

    async def register_shield(

@@ -425,7 +290,7 @@ class SafetyRouter(Safety):
        provider_id: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Shield:
        logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

    async def run_shield(

@@ -434,7 +299,7 @@ class SafetyRouter(Safety):
        messages: List[Message],
        params: Dict[str, Any] = None,
    ) -> RunShieldResponse:
        logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
        return await self.routing_table.get_provider_impl(shield_id).run_shield(
            shield_id=shield_id,
            messages=messages,
@ -447,15 +312,15 @@ class DatasetIORouter(DatasetIO):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing DatasetIORouter")
|
||||
logger.debug("Initializing DatasetIORouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "DatasetIORouter.initialize")
|
||||
logger.debug("DatasetIORouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "DatasetIORouter.shutdown")
|
||||
logger.debug("DatasetIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def get_rows_paginated(
|
||||
|
@ -465,8 +330,7 @@ class DatasetIORouter(DatasetIO):
|
|||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
logcat.debug(
|
||||
"core",
|
||||
logger.debug(
|
||||
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
|
||||
|
@ -477,7 +341,7 @@ class DatasetIORouter(DatasetIO):
|
|||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||
dataset_id=dataset_id,
|
||||
rows=rows,
|
||||
|
@ -489,15 +353,15 @@ class ScoringRouter(Scoring):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing ScoringRouter")
|
||||
logger.debug("Initializing ScoringRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "ScoringRouter.initialize")
|
||||
logger.debug("ScoringRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "ScoringRouter.shutdown")
|
||||
logger.debug("ScoringRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def score_batch(
|
||||
|
@ -506,7 +370,7 @@ class ScoringRouter(Scoring):
|
|||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
|
||||
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
||||
|
@ -527,10 +391,7 @@ class ScoringRouter(Scoring):
|
|||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
) -> ScoreResponse:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
|
||||
)
|
||||
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
|
||||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
|
@ -548,15 +409,15 @@ class EvalRouter(Eval):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing EvalRouter")
|
||||
logger.debug("Initializing EvalRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "EvalRouter.initialize")
|
||||
logger.debug("EvalRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "EvalRouter.shutdown")
|
||||
logger.debug("EvalRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def run_eval(
|
||||
|
@ -564,7 +425,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
|
||||
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
||||
benchmark_id=benchmark_id,
|
||||
benchmark_config=benchmark_config,
|
||||
|
@ -577,7 +438,7 @@ class EvalRouter(Eval):
|
|||
scoring_functions: List[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=input_rows,
|
||||
|
@ -590,7 +451,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||
|
||||
async def job_cancel(
|
||||
|
@ -598,7 +459,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> None:
|
||||
logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
|
@ -609,7 +470,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> EvaluateResponse:
|
||||
logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
|
@ -622,7 +483,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def query(
|
||||
|
@ -631,7 +492,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_ids: List[str],
|
||||
query_config: Optional[RAGQueryConfig] = None,
|
||||
) -> RAGQueryResult:
|
||||
logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
||||
content, vector_db_ids, query_config
|
||||
)
|
||||
|
@ -642,9 +503,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
|
||||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens
|
||||
|
@ -654,7 +514,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing ToolRuntimeRouter")
|
||||
logger.debug("Initializing ToolRuntimeRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
|
||||
|
@ -663,15 +523,15 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "ToolRuntimeRouter.initialize")
|
||||
logger.debug("ToolRuntimeRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "ToolRuntimeRouter.shutdown")
|
||||
logger.debug("ToolRuntimeRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
|
||||
logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
||||
tool_name=tool_name,
|
||||
kwargs=kwargs,
|
||||
|
@ -680,5 +540,5 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
||||
|
|
|
@ -9,7 +9,6 @@ import asyncio
|
|||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
@ -28,10 +27,12 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.request_headers import (
|
||||
preserve_headers_context_async_generator,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
|
@ -39,6 +40,7 @@ from llama_stack.distribution.stack import (
|
|||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||
|
@ -54,8 +56,7 @@ from .endpoints import get_all_api_endpoints
|
|||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
|
||||
logcat.init()
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
|
||||
|
||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||
|
@ -142,23 +143,23 @@ def handle_signal(app, signum, _) -> None:
|
|||
not block the current execution.
|
||||
"""
|
||||
signame = signal.Signals(signum).name
|
||||
logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||
|
||||
async def shutdown():
|
||||
try:
|
||||
# Gracefully shut down implementations
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logcat.info("server", f"Shutting down {impl_name}")
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logcat.warning("server", f"No shutdown method for {impl_name}")
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except asyncio.TimeoutError:
|
||||
logcat.exception("server", f"Shutdown timeout for {impl_name}")
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except Exception as e:
|
||||
logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
|
||||
# Gather all running tasks
|
||||
loop = asyncio.get_running_loop()
|
||||
|
@ -172,7 +173,7 @@ def handle_signal(app, signum, _) -> None:
|
|||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logcat.exception("server", "Timeout while waiting for tasks to finish")
|
||||
logger.exception("Timeout while waiting for tasks to finish")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
|
@ -184,9 +185,9 @@ def handle_signal(app, signum, _) -> None:
|
|||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logcat.info("server", "Starting up")
|
||||
logger.info("Starting up")
|
||||
yield
|
||||
logcat.info("server", "Shutting down")
|
||||
logger.info("Shutting down")
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
await impl.shutdown()
|
||||
|
||||
|
@ -204,16 +205,14 @@ async def maybe_await(value):
|
|||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
event_gen = await event_gen
|
||||
async for item in event_gen:
|
||||
async for item in await event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
logcat.info("server", "Generator cancelled")
|
||||
logger.info("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
logcat.exception("server", f"Error in sse_generator: {e}")
|
||||
logcat.exception("server", f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
|
||||
logger.exception("Error in sse_generator")
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
|
@ -225,18 +224,20 @@ async def sse_generator(event_gen):
|
|||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
set_request_provider_data(request.headers)
|
||||
# Use context manager for request provider data
|
||||
with request_provider_data_context(request.headers):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
try:
|
||||
if is_streaming:
|
||||
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
logcat.exception("server", f"Error in {func.__name__}")
|
||||
raise translate_exception(e) from e
|
||||
try:
|
||||
if is_streaming:
|
||||
gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
logger.exception("Error executing endpoint %s", method, route)
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
|
@ -314,8 +315,6 @@ class ClientVersionMiddleware:
|
|||
|
||||
|
||||
def main():
|
||||
logcat.init()
|
||||
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
parser.add_argument(
|
||||
|
@ -355,10 +354,10 @@ def main():
|
|||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logcat.info("server", f"Setting CLI environment variable {key} => {value}")
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logcat.error("server", f"Error: {str(e)}")
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.yaml_config:
|
||||
|
@ -366,12 +365,12 @@ def main():
|
|||
config_file = Path(args.yaml_config)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
logcat.info("server", f"Using config file: {config_file}")
|
||||
logger.info(f"Using config file: {config_file}")
|
||||
elif args.template:
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
logcat.info("server", f"Using template {args.template} config file: {config_file}")
|
||||
logger.info(f"Using template {args.template} config file: {config_file}")
|
||||
else:
|
||||
raise ValueError("Either --yaml-config or --template must be provided")
|
||||
|
||||
|
@ -379,10 +378,9 @@ def main():
|
|||
config = replace_env_vars(yaml.safe_load(fp))
|
||||
config = StackRunConfig(**config)
|
||||
|
||||
logcat.info("server", "Run configuration:")
|
||||
logger.info("Run configuration:")
|
||||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
for log_line in yaml.dump(safe_config, indent=2).split("\n"):
|
||||
logcat.info("server", log_line)
|
||||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
|
@ -392,7 +390,7 @@ def main():
|
|||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
except InvalidProviderError as e:
|
||||
logcat.error("server", f"Error: {str(e)}")
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if Api.telemetry in impls:
|
||||
|
@ -437,7 +435,7 @@ def main():
|
|||
)
|
||||
)
|
||||
|
||||
logcat.debug("server", f"serving APIs: {apis_to_serve}")
|
||||
logger.debug(f"serving APIs: {apis_to_serve}")
|
||||
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
@ -464,10 +462,10 @@ def main():
|
|||
"ssl_keyfile": keyfile,
|
||||
"ssl_certfile": certfile,
|
||||
}
|
||||
logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
|
||||
logcat.info("server", f"Listening on {listen_host}:{port}")
|
||||
logger.info(f"Listening on {listen_host}:{port}")
|
||||
|
||||
uvicorn_config = {
|
||||
"app": app,
|
||||
|
|
|
@ -11,9 +11,7 @@ import tempfile
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batch_inference import BatchInference
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
|
@ -39,8 +37,11 @@ from llama_stack.distribution.distribution import get_provider_registry
|
|||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class LlamaStack(
|
||||
VectorDBs,
|
||||
|
@ -101,9 +102,8 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
|||
objects_to_process = response.data if hasattr(response, "data") else response
|
||||
|
||||
for obj in objects_to_process:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
|
||||
logger.debug(
|
||||
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -100,12 +100,15 @@ esac
|
|||
|
||||
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
||||
set -x
|
||||
|
||||
$PYTHON_BINARY -m llama_stack.distribution.server.server \
|
||||
--yaml-config "$yaml_config" \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
$other_args
|
||||
elif [[ "$env_type" == "container" ]]; then
|
||||
set -x
|
||||
|
||||
# Check if container command is available
|
||||
if ! is_command_available $CONTAINER_BINARY; then
|
||||
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
|
||||
|
@ -141,8 +144,6 @@ elif [[ "$env_type" == "container" ]]; then
|
|||
version_tag=$(curl -s $URL | jq -r '.info.version')
|
||||
fi
|
||||
|
||||
set -x
|
||||
|
||||
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
|
||||
-p $port:$port \
|
||||
$env_vars \
|
||||
|
|
182 llama_stack/log.py Normal file
|
@ -0,0 +1,182 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from logging.config import dictConfig
|
||||
from typing import Dict
|
||||
|
||||
from rich.console import Console
|
||||
from rich.errors import MarkupError
|
||||
from rich.logging import RichHandler
|
||||
|
||||
# Default log level
|
||||
DEFAULT_LOG_LEVEL = logging.INFO
|
||||
|
||||
# Predefined categories
|
||||
CATEGORIES = [
|
||||
"core",
|
||||
"server",
|
||||
"router",
|
||||
"inference",
|
||||
"agents",
|
||||
"safety",
|
||||
"eval",
|
||||
"tools",
|
||||
"client",
|
||||
]
|
||||
|
||||
# Initialize category levels with default level
|
||||
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
|
||||
|
||||
|
||||
def parse_environment_config(env_config: str) -> Dict[str, int]:
|
||||
"""
|
||||
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
|
||||
|
||||
Parameters:
|
||||
env_config (str): The value of the LLAMA_STACK_LOGGING environment variable.
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
category_levels = {}
|
||||
for pair in env_config.split(";"):
|
||||
if not pair.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
category, level = pair.split("=", 1)
|
||||
category = category.strip().lower()
|
||||
level = level.strip().upper() # Convert to uppercase for logging._nameToLevel
|
||||
|
||||
level_value = logging._nameToLevel.get(level)
|
||||
if level_value is None:
|
||||
logging.warning(
|
||||
f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'."
|
||||
)
|
||||
continue
|
||||
|
||||
if category == "all":
|
||||
# Apply the log level to all categories and the root logger
|
||||
for cat in CATEGORIES:
|
||||
category_levels[cat] = level_value
|
||||
# Set the root logger's level to the specified level
|
||||
category_levels["root"] = level_value
|
||||
elif category in CATEGORIES:
|
||||
category_levels[category] = level_value
|
||||
logging.info(f"Setting '{category}' category to level '{level}'.")
|
||||
else:
|
||||
logging.warning(f"Unknown logging category: {category}. No changes made.")
|
||||
|
||||
except ValueError:
|
||||
logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")
|
||||
|
||||
return category_levels
|
||||
|
||||
|
||||
class CustomRichHandler(RichHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["console"] = Console(width=120)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def emit(self, record):
|
||||
"""Override emit to handle markup errors gracefully."""
|
||||
try:
|
||||
super().emit(record)
|
||||
except MarkupError:
|
||||
original_markup = self.markup
|
||||
self.markup = False
|
||||
try:
|
||||
super().emit(record)
|
||||
finally:
|
||||
self.markup = original_markup
|
||||
|
||||
|
||||
def setup_logging(category_levels: Dict[str, int]) -> None:
|
||||
"""
|
||||
Configure logging based on the provided category log levels.
|
||||
|
||||
Parameters:
|
||||
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
|
||||
|
||||
class CategoryFilter(logging.Filter):
|
||||
"""Ensure category is always present in log records."""
|
||||
|
||||
def filter(self, record):
|
||||
if not hasattr(record, "category"):
|
||||
record.category = "uncategorized" # Default to 'uncategorized' if no category found
|
||||
return True
|
||||
|
||||
# Determine the root logger's level (default to WARNING if not specified)
|
||||
root_level = category_levels.get("root", logging.WARNING)
|
||||
|
||||
logging_config = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"rich": {
|
||||
"()": logging.Formatter,
|
||||
"format": log_format,
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"()": CustomRichHandler, # Use our custom handler class
|
||||
"formatter": "rich",
|
||||
"rich_tracebacks": True,
|
||||
"show_time": False,
|
||||
"show_path": False,
|
||||
"markup": True,
|
||||
"filters": ["category_filter"],
|
||||
}
|
||||
},
|
||||
"filters": {
|
||||
"category_filter": {
|
||||
"()": CategoryFilter,
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
category: {
|
||||
"handlers": ["console"],
|
||||
"level": category_levels.get(category, DEFAULT_LOG_LEVEL),
|
||||
"propagate": False, # Disable propagation to root logger
|
||||
}
|
||||
for category in CATEGORIES
|
||||
},
|
||||
"root": {
|
||||
"handlers": ["console"],
|
||||
"level": root_level, # Set root logger's level dynamically
|
||||
},
|
||||
}
|
||||
dictConfig(logging_config)
|
||||
|
||||
|
||||
def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
|
||||
"""
|
||||
Returns a logger with the specified name and category.
|
||||
If no category is provided, defaults to 'uncategorized'.
|
||||
|
||||
Parameters:
|
||||
name (str): The name of the logger (e.g., module or filename).
|
||||
category (str): The category of the logger (default 'uncategorized').
|
||||
|
||||
Returns:
|
||||
logging.LoggerAdapter: Configured logger with category support.
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
|
||||
return logging.LoggerAdapter(logger, {"category": category})
|
||||
|
||||
|
||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||
if env_config:
|
||||
print(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}")
|
||||
_category_levels.update(parse_environment_config(env_config))
|
||||
|
||||
setup_logging(_category_levels)
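A short usage sketch of the module above; the `LLAMA_STACK_LOGGING` value is illustrative and must be set before `llama_stack.log` is imported, since the module reads it at import time:

```python
# Sketch: configure per-category levels, then grab a category-scoped logger.
import os

# Semicolon-separated category=level pairs; "all" applies one level to every category.
os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"  # illustrative value

from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="server")
logger.debug("Starting up")        # emitted, because the 'server' category is at DEBUG
logger.info("Run configuration:")  # INFO and above are emitted at the default level
```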
|
|
@ -1,204 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Category-based logging utility for llama-stack.
|
||||
|
||||
This module provides a wrapper over the standard Python logging module that supports
|
||||
categorized logging with environment variable control.
|
||||
|
||||
Usage:
|
||||
from llama_stack import logcat
|
||||
logcat.info("server", "Starting up...")
|
||||
logcat.debug("inference", "Processing request...")
|
||||
|
||||
Environment variable:
|
||||
LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
|
||||
Example: "server=debug;inference=warning"
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
# ANSI color codes for terminal output
|
||||
COLORS = {
|
||||
"RESET": "\033[0m",
|
||||
"DEBUG": "\033[36m", # Cyan
|
||||
"INFO": "\033[32m", # Green
|
||||
"WARNING": "\033[33m", # Yellow
|
||||
"ERROR": "\033[31m", # Red
|
||||
"CRITICAL": "\033[35m", # Magenta
|
||||
"DIM": "\033[2m", # Dimmed text
|
||||
"YELLOW_DIM": "\033[2;33m", # Dimmed yellow
|
||||
}
|
||||
|
||||
# Static list of valid categories representing various parts of the Llama Stack
|
||||
# server codebase
|
||||
CATEGORIES = [
|
||||
"core",
|
||||
"server",
|
||||
"router",
|
||||
"inference",
|
||||
"agents",
|
||||
"safety",
|
||||
"eval",
|
||||
"tools",
|
||||
"client",
|
||||
]
|
||||
|
||||
_logger = logging.getLogger("llama_stack")
|
||||
_logger.propagate = False
|
||||
|
||||
_default_level = logging.INFO
|
||||
|
||||
# Category-level mapping (can be modified by environment variables)
|
||||
_category_levels: Dict[str, int] = {}
|
||||
|
||||
|
||||
class TerminalStreamHandler(logging.StreamHandler):
|
||||
def __init__(self, stream=None):
|
||||
super().__init__(stream)
|
||||
self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
|
||||
|
||||
def format(self, record):
|
||||
record.is_tty = self.is_tty
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""Custom formatter with colors and fixed-width level names"""
|
||||
|
||||
def format(self, record):
|
||||
levelname = record.levelname
|
||||
# Use only time with milliseconds, not date
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format
|
||||
|
||||
file_info = f"{record.filename}:{record.lineno}"
|
||||
|
||||
# Get category from extra if available
|
||||
category = getattr(record, "category", None)
|
||||
msg = record.getMessage()
|
||||
|
||||
if getattr(record, "is_tty", False):
|
||||
color = COLORS.get(levelname, COLORS["RESET"])
|
||||
if category:
|
||||
category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
|
||||
formatted_msg = (
|
||||
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
|
||||
f"{file_info:<20} {category_formatted}{msg}"
|
||||
)
|
||||
else:
|
||||
formatted_msg = (
|
||||
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
|
||||
f"{file_info:<20} {msg}"
|
||||
)
|
||||
else:
|
||||
if category:
|
||||
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
|
||||
else:
|
||||
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
|
||||
|
||||
return formatted_msg
|
||||
|
||||
|
||||
def init(default_level: int = logging.INFO) -> None:
|
||||
global _default_level, _category_levels, _logger
|
||||
|
||||
_default_level = default_level
|
||||
|
||||
_logger.setLevel(logging.DEBUG)
|
||||
_logger.handlers = [] # Clear existing handlers
|
||||
|
||||
# Add our custom handler with the colored formatter
|
||||
handler = TerminalStreamHandler()
|
||||
formatter = ColoredFormatter()
|
||||
handler.setFormatter(formatter)
|
||||
_logger.addHandler(handler)
|
||||
|
||||
for category in CATEGORIES:
|
||||
_category_levels[category] = default_level
|
||||
|
||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||
if env_config:
|
||||
for pair in env_config.split(";"):
|
||||
if not pair.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
category, level = pair.split("=", 1)
|
||||
category = category.strip().lower()
|
||||
level = level.strip().lower()
|
||||
|
||||
level_value = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"warn": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
"critical": logging.CRITICAL,
|
||||
}.get(level)
|
||||
|
||||
if level_value is None:
|
||||
_logger.warning(f"Unknown log level '{level}' for category '{category}'")
|
||||
continue
|
||||
|
||||
if category == "all":
|
||||
for cat in CATEGORIES:
|
||||
_category_levels[cat] = level_value
|
||||
else:
|
||||
if category in CATEGORIES:
|
||||
_category_levels[category] = level_value
|
||||
else:
|
||||
_logger.warning(f"Unknown logging category: {category}")
|
||||
|
||||
except ValueError:
|
||||
_logger.warning(f"Invalid logging configuration: {pair}")
|
||||
|
||||
|
||||
def _should_log(level: int, category: str) -> bool:
|
||||
category = category.lower()
|
||||
if category not in _category_levels:
|
||||
return False
|
||||
category_level = _category_levels[category]
|
||||
return level >= category_level
|
||||
|
||||
|
||||
def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
|
||||
if _should_log(level, category):
|
||||
kwargs.setdefault("extra", {})["category"] = category.lower()
|
||||
getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
|
||||
|
||||
|
||||
def debug(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def info(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.INFO, "info", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warning(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.WARNING, "warning", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warn(category: str, msg: str, *args, **kwargs) -> None:
|
||||
warning(category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def error(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.ERROR, "error", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def critical(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def exception(category: str, msg: str, *args, **kwargs) -> None:
|
||||
if _should_log(logging.ERROR, category):
|
||||
kwargs.setdefault("extra", {})["category"] = category.lower()
|
||||
_logger.exception(msg, *args, stacklevel=2, **kwargs)
|
|
@ -17,7 +17,6 @@ from urllib.parse import urlparse
|
|||
|
||||
import httpx
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentToolGroup,
|
||||
|
@ -67,6 +66,7 @@ from llama_stack.apis.tools import (
|
|||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolCall,
|
||||
|
@ -88,6 +88,8 @@ MEMORY_QUERY_TOOL = "knowledge_search"
|
|||
WEB_SEARCH_TOOL = "web_search"
|
||||
RAG_TOOL_GROUP = "builtin::rag"
|
||||
|
||||
logger = get_logger(name=__name__, category="agents")
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
def __init__(
|
||||
|
@ -609,7 +611,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
if n_iter >= self.agent_config.max_infer_iters:
|
||||
logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.")
|
||||
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
|
||||
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||
# Do not continue the tool call loop after this point
|
||||
message.stop_reason = StopReason.end_of_turn
|
||||
|
@ -617,7 +619,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
break
|
||||
|
||||
if stop_reason == StopReason.out_of_tokens:
|
||||
logcat.info("agents", "out of token budget, exiting.")
|
||||
logger.info("out of token budget, exiting.")
|
||||
yield message
|
||||
break
|
||||
|
||||
|
@ -631,16 +633,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
message.content = [message.content] + output_attachments
|
||||
yield message
|
||||
else:
|
||||
logcat.debug(
|
||||
"agents",
|
||||
f"completion message with EOM (iter: {n_iter}): {str(message)}",
|
||||
)
|
||||
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
||||
input_messages = input_messages + [message]
|
||||
else:
|
||||
logcat.debug(
|
||||
"agents",
|
||||
f"completion message (iter: {n_iter}) from the model: {str(message)}",
|
||||
)
|
||||
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
|
||||
# 1. Start the tool execution step and progress
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
@ -983,7 +979,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
path = urlparse(uri).path
|
||||
basename = os.path.basename(path)
|
||||
filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||
logcat.info("agents", f"Downloading {url} -> {filepath}")
|
||||
logger.info(f"Downloading {url} -> {filepath}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
|
@ -1023,7 +1019,7 @@ async def execute_tool_call_maybe(
|
|||
else:
|
||||
name = name.value
|
||||
|
||||
logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}")
|
||||
logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
kwargs={
|
||||
|
@ -1033,7 +1029,7 @@ async def execute_tool_call_maybe(
|
|||
**toolgroup_args.get(group_name, {}),
|
||||
},
|
||||
)
|
||||
logcat.debug("agents", f"tool call {name} completed with result: {result}")
|
||||
logger.info(f"tool call {name} completed with result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -136,11 +136,13 @@ class MetaReferenceInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
|
@ -244,7 +246,7 @@ class MetaReferenceInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
@ -253,6 +255,8 @@ class MetaReferenceInferenceImpl(
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ class SentenceTransformersInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: str,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
@ -64,7 +64,7 @@ class SentenceTransformersInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
|
@ -4,20 +4,19 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMConfig(BaseModel):
|
||||
"""Configuration for the vLLM inference provider."""
|
||||
"""Configuration for the vLLM inference provider.
|
||||
|
||||
Note that the model name is no longer part of this static configuration.
|
||||
You can bind an instance of this provider to a specific model with the
|
||||
``models.register()`` API call."""
|
||||
|
||||
model: str = Field(
|
||||
default="Llama3.2-3B-Instruct",
|
||||
description="Model descriptor from `llama model list`",
|
||||
)
|
||||
tensor_parallel_size: int = Field(
|
||||
default=1,
|
||||
description="Number of tensor parallel replicas (number of GPUs to use).",
|
||||
|
@ -26,32 +25,27 @@ class VLLMConfig(BaseModel):
|
|||
default=4096,
|
||||
description="Maximum number of tokens to generate.",
|
||||
)
|
||||
max_model_len: int = Field(default=4096, description="Maximum context length to use during serving.")
|
||||
max_num_seqs: int = Field(default=4, description="Maximum parallel batch size for generation.")
|
||||
enforce_eager: bool = Field(
|
||||
default=False,
|
||||
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
|
||||
)
|
||||
gpu_memory_utilization: float = Field(
|
||||
default=0.3,
|
||||
description=(
|
||||
"How much GPU memory will be allocated when this provider has finished "
|
||||
"loading, including memory that was already allocated before loading."
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
return {
|
||||
"model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
|
||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.MAX_TOKENS:4096}",
|
||||
"max_model_len": "${env.MAX_MODEL_LEN:4096}",
|
||||
"max_num_seqs": "${env.MAX_NUM_SEQS:4}",
|
||||
"enforce_eager": "${env.ENFORCE_EAGER:False}",
|
||||
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.7}",
|
||||
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}",
|
||||
}
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = supported_inference_models()
|
||||
|
||||
descriptors = [m.descriptor() for m in permitted_models]
|
||||
repos = [m.huggingface_repo for m in permitted_models]
|
||||
if model not in (descriptors + repos):
|
||||
model_list = "\n\t".join(repos)
|
||||
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
|
||||
return model
|
||||
|
|
170 llama_stack/providers/inline/inference/vllm/openai_utils.py Normal file
|
@ -0,0 +1,170 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import vllm
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
GrammarResponseFormat,
|
||||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
ToolChoice,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
)
|
||||
|
||||
###############################################################################
|
||||
# This file contains OpenAI compatibility code that is currently only used
|
||||
# by the inline vLLM connector. Some or all of this code may be moved to a
|
||||
# central location at a later date.
|
||||
|
||||
|
||||
def _merge_context_into_content(message: Message) -> Message: # type: ignore
|
||||
"""
|
||||
Merge the ``context`` field of a Llama Stack ``Message`` object into
|
||||
the content field for compatibility with OpenAI-style APIs.
|
||||
|
||||
Generates a content string that emulates the current behavior
|
||||
of ``llama_models.llama3.api.chat_format.encode_message()``.
|
||||
|
||||
:param message: Message that may include ``context`` field
|
||||
|
||||
:returns: A version of ``message`` with any context merged into the
|
||||
``content`` field.
|
||||
"""
|
||||
if not isinstance(message, UserMessage): # Separate type check for linter
|
||||
return message
|
||||
if message.context is None:
|
||||
return message
|
||||
return UserMessage(
|
||||
role=message.role,
|
||||
# Emulate llama_models.llama3.api.chat_format.encode_message()
|
||||
content=message.content + "\n\n" + message.context,
|
||||
context=None,
|
||||
)
|
||||
|
||||
|
||||
def _llama_stack_tools_to_openai_tools(
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
|
||||
"""
|
||||
Convert the list of available tools from Llama Stack's format to vLLM's
|
||||
version of OpenAI's format.
|
||||
"""
|
||||
if tools is None:
|
||||
return []
|
||||
|
||||
result = []
|
||||
for t in tools:
|
||||
if isinstance(t.tool_name, BuiltinTool):
|
||||
raise NotImplementedError("Built-in tools not yet implemented")
|
||||
if t.parameters is None:
|
||||
parameters = None
|
||||
else: # if t.parameters is not None
|
||||
# Convert the "required" flags to a list of required params
|
||||
required_params = [k for k, v in t.parameters.items() if v.required]
|
||||
parameters = {
|
||||
"type": "object", # Mystery value that shows up in OpenAI docs
|
||||
"properties": {
|
||||
k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
|
||||
},
|
||||
"required": required_params,
|
||||
}
|
||||
|
||||
function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
|
||||
name=t.tool_name, description=t.description, parameters=parameters
|
||||
)
|
||||
|
||||
# Every tool definition is double-boxed in a ChatCompletionToolsParam
|
||||
result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
|
||||
return result
|
||||
|
||||
|
||||
async def llama_stack_chat_completion_to_openai_chat_completion_dict(
|
||||
request: ChatCompletionRequest,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a chat completion request in Llama Stack format into an
|
||||
equivalent set of arguments to pass to an OpenAI-compatible
|
||||
chat completions API.
|
||||
|
||||
:param request: Bundled request parameters in Llama Stack format.
|
||||
|
||||
:returns: Dictionary of key-value pairs to use as an initializer
|
||||
for a dataclass or to be converted directly to JSON and sent
|
||||
over the wire.
|
||||
"""
|
||||
|
||||
converted_messages = [
|
||||
# convert_message_to_openai_dict() is a coroutine (note download=True), which is why
# this conversion, and therefore the parent function, must be async
|
||||
await convert_message_to_openai_dict(_merge_context_into_content(m), download=True)
|
||||
for m in request.messages
|
||||
]
|
||||
converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
|
||||
|
||||
# Llama will try to use built-in tools with no tool catalog, so don't enable
|
||||
# tool choice unless at least one tool is enabled.
|
||||
converted_tool_choice = "none"
|
||||
if (
|
||||
request.tool_config is not None
|
||||
and request.tool_config.tool_choice == ToolChoice.auto
|
||||
and request.tools is not None
|
||||
and len(request.tools) > 0
|
||||
):
|
||||
converted_tool_choice = "auto"
|
||||
|
||||
# TODO: Figure out what to do with the tool_prompt_format argument.
|
||||
# Other connectors appear to drop it quietly.
|
||||
|
||||
# Use Llama Stack shared code to translate sampling parameters.
|
||||
sampling_options = get_sampling_options(request.sampling_params)
|
||||
|
||||
# get_sampling_options() translates repetition penalties to an option that
|
||||
# OpenAI's APIs don't know about.
|
||||
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
|
||||
# For now, translate repetition penalties into a format that vLLM's broken
|
||||
# API will handle correctly. Two wrongs make a right...
|
||||
if "repeat_penalty" in sampling_options:
|
||||
del sampling_options["repeat_penalty"]
|
||||
if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0:
|
||||
sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty
|
||||
|
||||
# Convert a single response format into four different parameters, per
|
||||
# the OpenAI spec
|
||||
guided_decoding_options = dict()
|
||||
if request.response_format is None:
|
||||
# Use defaults
|
||||
pass
|
||||
elif isinstance(request.response_format, JsonSchemaResponseFormat):
|
||||
guided_decoding_options["guided_json"] = request.response_format.json_schema
|
||||
elif isinstance(request.response_format, GrammarResponseFormat):
|
||||
guided_decoding_options["guided_grammar"] = request.response_format.bnf
|
||||
else:
|
||||
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'")
|
||||
|
||||
logprob_options = dict()
|
||||
if request.logprobs is not None:
|
||||
logprob_options["logprobs"] = request.logprobs.top_k
|
||||
|
||||
# Marshall together all the arguments for a ChatCompletionRequest
|
||||
request_options = {
|
||||
"model": request.model,
|
||||
"messages": converted_messages,
|
||||
"tools": converted_tools,
|
||||
"tool_choice": converted_tool_choice,
|
||||
"stream": request.stream,
|
||||
**sampling_options,
|
||||
**guided_decoding_options,
|
||||
**logprob_options,
|
||||
}
|
||||
|
||||
return request_options
|
|
@ -4,45 +4,71 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
# These vLLM modules contain names that overlap with Llama Stack names, so we import
|
||||
# fully-qualified names
|
||||
import vllm.entrypoints.openai.protocol
|
||||
import vllm.sampling_params
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
InterleavedContentItem,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
TokenLogProbs,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama import sku_list
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
get_stop_reason,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
|
@ -50,188 +76,322 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
|
||||
from .config import VLLMConfig
|
||||
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
# Map from Hugging Face model architecture name to appropriate tool parser.
|
||||
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
|
||||
# available parsers.
|
||||
# TODO: Expand this list
|
||||
CONFIG_TYPE_TO_TOOL_PARSER = {
|
||||
"GraniteConfig": "granite",
|
||||
"MllamaConfig": "llama3_json",
|
||||
"LlamaConfig": "llama3_json",
|
||||
}
|
||||
DEFAULT_TOOL_PARSER = "pythonic"
|
||||
|
||||
|
||||
def _random_uuid() -> str:
|
||||
logger = get_logger(__name__, category="inference")
|
||||
|
||||
|
||||
def _random_uuid_str() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
def _response_format_to_guided_decoding_params(
|
||||
response_format: Optional[ResponseFormat], # type: ignore
|
||||
) -> vllm.sampling_params.GuidedDecodingParams:
|
||||
"""
|
||||
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
|
||||
|
||||
:param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
|
||||
indicating no constraints.
|
||||
:returns: The equivalent dataclass object for the low-level inference layer of vLLM.
|
||||
"""
|
||||
if response_format is None:
|
||||
# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
|
||||
# value that crashes the executor on some code paths. Use ``None`` instead.
|
||||
return None
|
||||
|
||||
# Llama Stack currently implements fewer types of constrained decoding than vLLM does.
|
||||
# Translate the types that exist and detect if Llama Stack adds new ones.
|
||||
if isinstance(response_format, JsonSchemaResponseFormat):
|
||||
return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
|
||||
elif isinstance(response_format, GrammarResponseFormat):
|
||||
# BNF grammar.
|
||||
# Llama Stack uses the parse tree of the grammar, while vLLM uses the string
|
||||
# representation of the grammar.
|
||||
raise TypeError(
|
||||
"Constrained decoding with BNF grammars is not currently implemented, because the "
|
||||
"reference implementation does not implement it."
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
|
||||
|
||||
|
||||
def _convert_sampling_params(
|
||||
sampling_params: Optional[SamplingParams],
|
||||
response_format: Optional[ResponseFormat], # type: ignore
|
||||
log_prob_config: Optional[LogProbConfig],
|
||||
) -> vllm.SamplingParams:
|
||||
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
|
||||
format."""
|
||||
# In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
|
||||
# Stack dataclasses. These defaults are different from vLLM's defaults.
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if log_prob_config is None:
|
||||
log_prob_config = LogProbConfig()
|
||||
|
||||
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
|
||||
if sampling_params.strategy.top_k == 0:
|
||||
# vLLM treats "k" differently for top-k sampling
|
||||
vllm_top_k = -1
|
||||
else:
|
||||
vllm_top_k = sampling_params.strategy.top_k
|
||||
else:
|
||||
vllm_top_k = -1
|
||||
|
||||
if isinstance(sampling_params.strategy, TopPSamplingStrategy):
|
||||
vllm_top_p = sampling_params.strategy.top_p
|
||||
# Llama Stack only allows temperature with top-P.
|
||||
vllm_temperature = sampling_params.strategy.temperature
|
||||
else:
|
||||
vllm_top_p = 1.0
|
||||
vllm_temperature = 0.0
|
||||
|
||||
# vLLM allows top-p and top-k at the same time.
|
||||
vllm_sampling_params = vllm.SamplingParams.from_optional(
|
||||
max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
|
||||
temperature=vllm_temperature,
|
||||
top_p=vllm_top_p,
|
||||
top_k=vllm_top_k,
|
||||
repetition_penalty=sampling_params.repetition_penalty,
|
||||
guided_decoding=_response_format_to_guided_decoding_params(response_format),
|
||||
logprobs=log_prob_config.top_k,
|
||||
)
|
||||
return vllm_sampling_params
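A short, hedged usage sketch of the conversion above (the parameter values are illustrative):

# Top-p sampling with a token budget; no constrained decoding and default logprob settings.
vllm_params = _convert_sampling_params(
    SamplingParams(strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9), max_tokens=128),
    response_format=None,
    log_prob_config=None,
)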
|
||||
|
||||
|
||||
class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||
"""Inference implementation for vLLM."""
|
||||
"""
|
||||
vLLM-based inference model adapter for Llama Stack with support for multiple models.
|
||||
|
||||
Requires the configuration parameters documented in the :class:`VLLMConfig` class.
|
||||
"""
|
||||
|
||||
config: VLLMConfig
|
||||
register_helper: ModelRegistryHelper
|
||||
model_ids: set[str]
|
||||
resolved_model_id: str | None
|
||||
engine: AsyncLLMEngine | None
|
||||
chat: OpenAIServingChat | None
|
||||
is_meta_llama_model: bool
|
||||
|
||||
def __init__(self, config: VLLMConfig):
|
||||
self.config = config
|
||||
logger.info(f"Config is: {self.config}")
|
||||
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
# The following are initialized when paths are bound to this provider
|
||||
self.resolved_model_id = None
|
||||
self.model_ids = set()
|
||||
self.engine = None
|
||||
self.chat = None
|
||||
self.is_meta_llama_model = False
|
||||
|
||||
async def initialize(self):
|
||||
log.info("Initializing vLLM inference provider.")
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM IMPLICIT BASE CLASS.
|
||||
# TODO: Make this class inherit from the new base class ProviderBase once that class exists.
|
||||
|
||||
# Disable usage stats reporting. Most users would be surprised to learn that it is
|
||||
# enabled by default.
|
||||
# https://docs.vllm.ai/en/latest/serving/usage_stats.html
|
||||
if "VLLM_NO_USAGE_STATS" not in os.environ:
|
||||
os.environ["VLLM_NO_USAGE_STATS"] = "1"
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Callback that is invoked through many levels of indirection during provider class
|
||||
instantiation, sometime after __init__() is called and before any model registration
|
||||
methods or methods connected to a REST API are called.
|
||||
|
||||
model = resolve_model(self.config.model)
|
||||
if model is None:
|
||||
raise ValueError(f"Unknown model {self.config.model}")
|
||||
It's not clear what assumptions the class can make about the platform's initialization
|
||||
state here that can't be made during __init__(), and vLLM can't be started until we know
|
||||
what model it's supposed to be serving, so nothing happens here currently.
|
||||
"""
|
||||
pass
|
||||
|
||||
if model.huggingface_repo is None:
|
||||
raise ValueError(f"Model {self.config.model} needs a huggingface repo")
|
||||
|
||||
# TODO -- there are a ton of options supported here ...
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model.huggingface_repo,
|
||||
tokenizer=model.huggingface_repo,
|
||||
tensor_parallel_size=self.config.tensor_parallel_size,
|
||||
enforce_eager=self.config.enforce_eager,
|
||||
gpu_memory_utilization=self.config.gpu_memory_utilization,
|
||||
guided_decoding_backend="lm-format-enforcer",
|
||||
)
|
||||
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shut down the vLLM inference adapter."""
|
||||
log.info("Shutting down vLLM inference provider.")
|
||||
if self.engine:
|
||||
async def shutdown(self) -> None:
|
||||
logger.info(f"Shutting down inline vLLM inference provider {self}.")
|
||||
if self.engine is not None:
|
||||
self.engine.shutdown_background_loop()
|
||||
self.engine = None
|
||||
self.chat = None
|
||||
self.model_ids = set()
|
||||
self.resolved_model_id = None
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE
|
||||
|
||||
# Note that the return type of the superclass method is WRONG
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
"""
|
||||
Callback that is called when the server associates an inference endpoint
|
||||
with an inference provider.
|
||||
Callback that is called when the server associates an inference endpoint with an
|
||||
inference provider.
|
||||
|
||||
:param model: Object that encapsulates parameters necessary for identifying
|
||||
a specific LLM.
|
||||
:param model: Object that encapsulates parameters necessary for identifying a specific
|
||||
LLM.
|
||||
|
||||
:returns: The input ``Model`` object. It may or may not be permissible
|
||||
to change fields before returning this object.
|
||||
:returns: The input ``Model`` object. It may or may not be permissible to change fields
|
||||
before returning this object.
|
||||
"""
|
||||
log.info(f"Registering model {model.identifier} with vLLM inference provider.")
|
||||
# The current version of this provider is hard-coded to serve only
|
||||
# the model specified in the YAML config file.
|
||||
configured_model = resolve_model(self.config.model)
|
||||
registered_model = resolve_model(model.model_id)
|
||||
logger.debug(f"In register_model({model})")
|
||||
|
||||
# First attempt to interpret the model coordinates as a Llama model name
|
||||
resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
|
||||
if resolved_llama_model is not None:
|
||||
# Load from Hugging Face repo into default local cache dir
|
||||
model_id_for_vllm = resolved_llama_model.huggingface_repo
|
||||
|
||||
# Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
|
||||
# Don't set self.is_meta_llama_model until we actually load the model.
|
||||
is_meta_llama_model = True
|
||||
else: # if resolved_llama_model is None
|
||||
# Not a Llama model name. Pass the model id through to vLLM's loader
|
||||
model_id_for_vllm = model.provider_model_id
|
||||
is_meta_llama_model = False
|
||||
|
||||
if self.resolved_model_id is not None:
|
||||
if model_id_for_vllm != self.resolved_model_id:
|
||||
raise ValueError(
|
||||
f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
|
||||
f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple "
|
||||
f"copies of the provider instead."
|
||||
)
|
||||
else:
|
||||
# Model already loaded
|
||||
logger.info(
|
||||
f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
|
||||
)
|
||||
self.model_ids.add(model.model_id)
|
||||
return model
|
||||
|
||||
logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
|
||||
if is_meta_llama_model:
|
||||
logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")
|
||||
self.is_meta_llama_model = is_meta_llama_model
|
||||
|
||||
# If we get here, this is the first time registering a model.
|
||||
# Preload so that the first inference request won't time out.
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model_id_for_vllm,
|
||||
tokenizer=model_id_for_vllm,
|
||||
tensor_parallel_size=self.config.tensor_parallel_size,
|
||||
enforce_eager=self.config.enforce_eager,
|
||||
gpu_memory_utilization=self.config.gpu_memory_utilization,
|
||||
max_num_seqs=self.config.max_num_seqs,
|
||||
max_model_len=self.config.max_model_len,
|
||||
)
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
# vLLM currently requires the user to specify the tool parser manually. To choose a tool
|
||||
# parser, we need to determine what model architecture is being used. For now, we infer
|
||||
# that information from what config class the model uses.
|
||||
low_level_model_config = self.engine.engine.get_model_config()
|
||||
hf_config = low_level_model_config.hf_config
|
||||
hf_config_class_name = hf_config.__class__.__name__
|
||||
if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
|
||||
tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
|
||||
else:
|
||||
# No info -- choose a default so we can at least attempt tool
|
||||
# use.
|
||||
tool_parser = DEFAULT_TOOL_PARSER
|
||||
logger.debug(f"{hf_config_class_name=}")
|
||||
logger.debug(f"{tool_parser=}")
|
||||
|
||||
# Wrap the lower-level engine in an OpenAI-compatible chat API
|
||||
model_config = await self.engine.get_model_config()
|
||||
self.chat = OpenAIServingChat(
|
||||
engine_client=self.engine,
|
||||
model_config=model_config,
|
||||
models=OpenAIServingModels(
|
||||
engine_client=self.engine,
|
||||
model_config=model_config,
|
||||
base_model_paths=[
|
||||
# The layer below us will only see resolved model IDs
|
||||
BaseModelPath(model_id_for_vllm, model_id_for_vllm)
|
||||
],
|
||||
),
|
||||
response_role="assistant",
|
||||
request_logger=None, # Use default logging
|
||||
chat_template=None, # Use default template from model checkpoint
|
||||
enable_auto_tools=True,
|
||||
tool_parser=tool_parser,
|
||||
chat_template_content_format="auto",
|
||||
)
|
||||
self.resolved_model_id = model_id_for_vllm
|
||||
self.model_ids.add(model.model_id)
|
||||
|
||||
logger.info(f"Finished preloading model: {model_id_for_vllm}")
|
||||
|
||||
if configured_model.core_model_id != registered_model.core_model_id:
|
||||
raise ValueError(
|
||||
f"Requested model '{model.identifier}' is different from "
|
||||
f"model '{self.config.model}' that this provider "
|
||||
f"is configured to serve"
|
||||
)
|
||||
return model
|
||||
|
||||
def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
|
||||
if sampling_params is None:
|
||||
return VLLMSamplingParams(max_tokens=self.config.max_tokens)
|
||||
|
||||
options = get_sampling_options(sampling_params)
|
||||
if "repeat_penalty" in options:
|
||||
options["repetition_penalty"] = options["repeat_penalty"]
|
||||
del options["repeat_penalty"]
|
||||
|
||||
return VLLMSamplingParams(**options)
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
"""
|
||||
Callback that is called when the server removes an inference endpoint from an inference
|
||||
provider.
|
||||
|
||||
:param model_id: The same external ID that the higher layers of the stack previously passed
|
||||
to :func:`register_model()`
|
||||
"""
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
|
||||
)
|
||||
self.model_ids.remove(model_id)
|
||||
|
||||
if len(self.model_ids) == 0:
|
||||
# Last model was just unregistered. Shut down the connection to vLLM and free up
|
||||
# resources.
|
||||
# Note that this operation may cause in-flight chat completion requests on the
|
||||
# now-unregistered model to return errors.
|
||||
self.resolved_model_id = None
|
||||
self.chat = None
|
||||
self.engine.shutdown_background_loop()
|
||||
self.engine = None
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM Inference INTERFACE
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> CompletionResponse | CompletionResponseStreamChunk:
|
||||
raise NotImplementedError("Completion not implemented for vLLM")
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
)
|
||||
if not isinstance(content, str):
|
||||
raise NotImplementedError("Multimodal input not currently supported")
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
assert self.engine is not None
|
||||
converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
logger.debug(f"{converted_sampling_params=}")
|
||||
|
||||
log.info("Sampling params: %s", sampling_params)
|
||||
request_id = _random_uuid()
|
||||
|
||||
prompt = await chat_completion_request_to_prompt(request, self.config.model)
|
||||
vllm_sampling_params = self._sampling_params(request.sampling_params)
|
||||
results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, results_generator)
|
||||
return self._streaming_completion(content, converted_sampling_params)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, results_generator)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
|
||||
) -> ChatCompletionResponse:
|
||||
outputs = [o async for o in results_generator]
|
||||
final_output = outputs[-1]
|
||||
|
||||
assert final_output is not None
|
||||
outputs = final_output.outputs
|
||||
finish_reason = outputs[-1].stop_reason
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=finish_reason,
|
||||
text="".join([output.text for output in outputs]),
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
|
||||
) -> AsyncGenerator:
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
cur = []
|
||||
async for chunk in results_generator:
|
||||
if not chunk.outputs:
|
||||
log.warning("Empty chunk received")
|
||||
continue
|
||||
|
||||
output = chunk.outputs[-1]
|
||||
|
||||
new_tokens = output.token_ids[len(cur) :]
|
||||
text = tokenizer.decode(new_tokens)
|
||||
cur.extend(new_tokens)
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=output.finish_reason,
|
||||
text=text,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
streaming_result = None
|
||||
async for streaming_result in self._streaming_completion(content, converted_sampling_params):
|
||||
pass
|
||||
return CompletionResponse(
|
||||
content=streaming_result.delta,
|
||||
stop_reason=streaming_result.stop_reason,
|
||||
logprobs=streaming_result.logprobs,
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
@ -242,3 +402,391 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message], # type: ignore
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None, # type: ignore
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
sampling_params = sampling_params or SamplingParams()
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
)
|
||||
|
||||
# Convert to Llama Stack internal format for consistency
|
||||
request = ChatCompletionRequest(
|
||||
model=self.resolved_model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
if self.is_meta_llama_model:
|
||||
# Bypass vLLM chat templating layer for Meta Llama models, because the
|
||||
# templating layer in Llama Stack currently produces better results.
|
||||
logger.debug(
|
||||
f"Routing {self.resolved_model_id} chat completion through "
|
||||
f"Llama Stack's templating layer instead of vLLM's."
|
||||
)
|
||||
return await self._chat_completion_for_meta_llama(request)
|
||||
|
||||
logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
|
||||
|
||||
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
|
||||
# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
|
||||
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
|
||||
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
|
||||
|
||||
logger.debug(f"Converted request: {chat_completion_request}")
|
||||
|
||||
vllm_result = await self.chat.create_chat_completion(chat_completion_request)
|
||||
logger.debug(f"Result from vLLM: {vllm_result}")
|
||||
if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
|
||||
raise ValueError(f"Error from vLLM layer: {vllm_result}")
|
||||
|
||||
# Return type depends on "stream" argument
|
||||
if stream:
|
||||
if not isinstance(vllm_result, AsyncGenerator):
|
||||
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
|
||||
# vLLM client returns a stream of strings, which need to be parsed.
|
||||
# Stream comes in the form of an async generator.
|
||||
return self._convert_streaming_results(vllm_result)
|
||||
else:
|
||||
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
|
||||
raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
|
||||
return self._convert_non_streaming_results(vllm_result)
|
||||
|
||||
###########################################################################
|
||||
# INTERNAL METHODS
|
||||
|
||||
async def _streaming_completion(
|
||||
self, content: str, sampling_params: vllm.SamplingParams
|
||||
) -> AsyncIterator[CompletionResponseStreamChunk]:
|
||||
"""Internal implementation of :func:`completion()` API for the streaming case. Assumes
|
||||
that arguments have been validated upstream.
|
||||
|
||||
:param content: Must be a string
|
||||
:param sampling_params: Parameters from the public API's ``response_format``
|
||||
and ``sampling_params`` arguments, converted to VLLM format
|
||||
"""
|
||||
# We call the vLLM generate() API directly instead of using the OpenAI-compatible
|
||||
# layer, because doing so simplifies the code here.
|
||||
|
||||
# The vLLM engine requires a unique identifier for each call to generate()
|
||||
request_id = _random_uuid_str()
|
||||
|
||||
# The vLLM generate() API is streaming-only and returns an async generator.
|
||||
# The generator returns objects of type vllm.RequestOutput.
|
||||
results_generator = self.engine.generate(content, sampling_params, request_id)
|
||||
|
||||
# Need to know the model's EOS token ID for the conversion code below.
|
||||
# AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if
|
||||
# we drill down to the LLMEngine inside the AsyncLLMEngine.
|
||||
# Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup,
|
||||
# and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup.
|
||||
llm_engine = self.engine.engine
|
||||
tokenizer_group = llm_engine.tokenizer
|
||||
eos_token_id = tokenizer_group.tokenizer.eos_token_id
|
||||
|
||||
request_output: vllm.RequestOutput = None
|
||||
async for request_output in results_generator:
|
||||
# Check for weird inference failures
|
||||
if request_output.outputs is None or len(request_output.outputs) == 0:
|
||||
# This case also should never happen
|
||||
raise ValueError("Inference produced empty result")
|
||||
|
||||
# If we get here, then request_output contains the final output of the generate() call.
|
||||
# The result may include multiple alternate outputs, but Llama Stack APIs only allow
|
||||
# us to return one.
|
||||
output: vllm.CompletionOutput = request_output.outputs[0]
|
||||
completion_string = output.text
|
||||
|
||||
# Convert logprobs from vLLM's format to Llama Stack's format
|
||||
logprobs = [
|
||||
TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
|
||||
for logprob_dict in output.logprobs
|
||||
]
|
||||
|
||||
# The final output chunk should be labeled with the reason that the overall generate()
|
||||
# call completed.
|
||||
logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
|
||||
if output.stop_reason is None:
|
||||
stop_reason = None # Still going
|
||||
elif output.stop_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif output.stop_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
elif isinstance(output.stop_reason, int):
|
||||
# If the model config specifies multiple end-of-sequence tokens, then vLLM
|
||||
# will return the token ID of the EOS token in the stop_reason field.
|
||||
stop_reason = StopReason.end_of_turn
|
||||
else:
|
||||
raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")
|
||||
|
||||
# vLLM's protocol outputs the stop token, then sets end of message on the next step for
|
||||
# some reason.
|
||||
if request_output.outputs[-1].token_ids[-1] == eos_token_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
|
||||
|
||||
# Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
|
||||
# provide one if it runs out of tokens.
|
||||
if stop_reason is None:
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=completion_string,
|
||||
stop_reason=StopReason.out_of_tokens,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
def _convert_non_streaming_results(
|
||||
self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
|
||||
equivalent Llama Stack object.
|
||||
|
||||
The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
|
||||
Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
|
||||
the fields that aren't currently present in the Llama Stack dataclass.
|
||||
"""
|
||||
|
||||
# There may be multiple responses, but we can only pass through the first one.
|
||||
if len(vllm_result.choices) == 0:
|
||||
raise ValueError("Don't know how to convert response object without any responses")
|
||||
vllm_message = vllm_result.choices[0].message
|
||||
vllm_finish_reason = vllm_result.choices[0].finish_reason
|
||||
|
||||
converted_message = CompletionMessage(
|
||||
role=vllm_message.role,
|
||||
# Llama Stack API won't accept None for content field.
|
||||
content=("" if vllm_message.content is None else vllm_message.content),
|
||||
stop_reason=get_stop_reason(vllm_finish_reason),
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id=t.id,
|
||||
tool_name=t.function.name,
|
||||
# vLLM function args come back as a string. Llama Stack expects JSON.
|
||||
arguments=json.loads(t.function.arguments),
|
||||
)
|
||||
for t in vllm_message.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Convert logprobs
|
||||
|
||||
logger.debug(f"Converted message: {converted_message}")
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=converted_message,
|
||||
)
|
||||
|
||||
async def _chat_completion_for_meta_llama(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
"""
|
||||
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
|
||||
chat template instead of using vLLM's version of that template. The Llama Stack version
|
||||
of the chat template currently produces more reliable outputs.
|
||||
|
||||
Once vLLM's support for Meta Llama models has matured more, we should consider routing
|
||||
Meta Llama requests through the vLLM chat completions API instead of using this method.
|
||||
"""
|
||||
formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
# Note that this function call modifies `request` in place.
|
||||
prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)
|
||||
|
||||
model_id = list(self.model_ids)[0] # Any model ID will do here
|
||||
completion_response_or_iterator = await self.completion(
|
||||
model_id=model_id,
|
||||
content=prompt,
|
||||
sampling_params=request.sampling_params,
|
||||
response_format=request.response_format,
|
||||
stream=request.stream,
|
||||
logprobs=request.logprobs,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
if not isinstance(completion_response_or_iterator, AsyncIterator):
|
||||
raise TypeError(
|
||||
f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
|
||||
)
|
||||
return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)
|
||||
|
||||
# else: not request.stream
|
||||
if not isinstance(completion_response_or_iterator, CompletionResponse):
|
||||
raise TypeError(
|
||||
f"Received unexpected result type {type(completion_response_or_iterator)}for non-streaming request."
|
||||
)
|
||||
completion_response: CompletionResponse = completion_response_or_iterator
|
||||
raw_message = formatter.decode_assistant_message_from_content(
|
||||
completion_response.content, completion_response.stop_reason
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=completion_response.logprobs,
|
||||
)
|
||||
|
||||
async def _chat_completion_for_meta_llama_streaming(
|
||||
self, results_iterator: AsyncIterator, request: ChatCompletionRequest
|
||||
) -> AsyncIterator:
|
||||
"""
|
||||
Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
|
||||
method to keep asyncio happy.
|
||||
"""
|
||||
|
||||
# Convert to OpenAI format, then use shared code to convert to Llama Stack format.
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
chunk: CompletionResponseStreamChunk # Make Pylance happy
|
||||
last_text_len = 0
|
||||
async for chunk in results_iterator:
|
||||
if chunk.stop_reason == StopReason.end_of_turn:
|
||||
finish_reason = "stop"
|
||||
elif chunk.stop_reason == StopReason.end_of_message:
|
||||
finish_reason = "eos"
|
||||
elif chunk.stop_reason == StopReason.out_of_tokens:
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = None
|
||||
|
||||
# Convert delta back to an actual delta
|
||||
text_delta = chunk.delta[last_text_len:]
|
||||
last_text_len = len(chunk.delta)
|
||||
|
||||
logger.debug(f"{text_delta=}; {finish_reason=}")
|
||||
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
logger.debug(f"Returning chunk: {chunk}")
|
||||
yield chunk
|
||||
|
||||
async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
|
||||
"""
|
||||
Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
|
||||
API into a second async iterator that returns Llama Stack objects.
|
||||
|
||||
:param vllm_result: Stream of strings that need to be parsed
|
||||
"""
|
||||
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
|
||||
# those chunks and output them at the end.
|
||||
# This data structure holds the current set of partial tool calls.
|
||||
index_to_tool_call: Dict[int, Dict] = dict()
|
||||
|
||||
# The Llama Stack event stream must always start with a start event. Use an empty one to
|
||||
# simplify logic below
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=None,
|
||||
)
|
||||
)
|
||||
|
||||
converted_stop_reason = None
|
||||
async for chunk_str in vllm_result:
|
||||
# Due to OpenAI compatibility, each event in the stream will start with "data: " and
|
||||
# end with "\n\n".
|
||||
_prefix = "data: "
|
||||
_suffix = "\n\n"
|
||||
if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
|
||||
raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")
|
||||
|
||||
# In between the "data: " and newlines is an event record
|
||||
data_str = chunk_str[len(_prefix) : -len(_suffix)]
|
||||
|
||||
# The end of the stream is indicated with "[DONE]"
|
||||
if data_str == "[DONE]":
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Anything that is not "[DONE]" should be a JSON record
|
||||
parsed_chunk = json.loads(data_str)
|
||||
|
||||
logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
|
||||
|
||||
# The result may contain multiple completions, but Llama Stack APIs only support
|
||||
# returning one.
|
||||
first_choice = parsed_chunk["choices"][0]
|
||||
converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
|
||||
delta_record = first_choice["delta"]
|
||||
|
||||
if "content" in delta_record:
|
||||
# Text delta
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=TextDelta(text=delta_record["content"]),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
elif "tool_calls" in delta_record:
|
||||
# Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
|
||||
# calls, so buffer until we get a "tool calls" stop reason
|
||||
for tc in delta_record["tool_calls"]:
|
||||
index = tc["index"]
|
||||
if index not in index_to_tool_call:
|
||||
# First time this tool call is showing up
|
||||
index_to_tool_call[index] = dict()
|
||||
tool_call = index_to_tool_call[index]
|
||||
if "id" in tc:
|
||||
tool_call["call_id"] = tc["id"]
|
||||
if "function" in tc:
|
||||
if "name" in tc["function"]:
|
||||
tool_call["tool_name"] = tc["function"]["name"]
|
||||
if "arguments" in tc["function"]:
|
||||
# Arguments comes in as pieces of a string
|
||||
if "arguments_str" not in tool_call:
|
||||
tool_call["arguments_str"] = ""
|
||||
tool_call["arguments_str"] += tc["function"]["arguments"]
|
||||
else:
|
||||
raise ValueError(f"Don't know how to parse event delta: {delta_record}")
|
||||
|
||||
if first_choice["finish_reason"] == "tool_calls":
|
||||
# Special OpenAI code for "tool calls complete".
|
||||
# Output the buffered tool calls. Llama Stack requires a separate event per tool
|
||||
# call.
|
||||
for tool_call_record in index_to_tool_call.values():
|
||||
# Arguments come in as a string. Parse the completed string.
|
||||
tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
|
||||
del tool_call_record["arguments_str"]
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# If we get here, we've lost the connection with the vLLM event stream before it ended
|
||||
# normally.
|
||||
raise ValueError("vLLM event stream ended without [DONE] message.")
|
||||
|
|
|
@ -73,7 +73,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = deps.get(Api.datasetio)
|
||||
self.meter = None
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
|
@ -172,8 +171,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
if self.meter is None:
|
||||
return
|
||||
if isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
|
@ -99,7 +100,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
await self._save_index()
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
distances, indices = self.index.search(embedding.reshape(1, -1).astype(np.float32), k)
|
||||
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
||||
|
||||
chunks = []
|
||||
scores = []
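The hunk above moves the blocking FAISS search off the event loop. A generic, hedged sketch of the same asyncio.to_thread pattern (names are illustrative, not the provider's code):

import asyncio

async def search_off_loop(index, query_vector, k: int):
    # index.search is CPU-bound; running it in a worker thread keeps the event loop responsive.
    distances, indices = await asyncio.to_thread(index.search, query_vector, k)
    return distances, indices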
|
||||
|
|
19
llama_stack/providers/inline/vector_io/milvus/__init__.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MilvusVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
20
llama_stack/providers/inline/vector_io/milvus/config.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MilvusVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {"db_path": "${env.MILVUS_DB_PATH}"}
|
|
@ -110,4 +110,22 @@ def available_providers() -> List[ProviderSpec]:
|
|||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="milvus",
|
||||
pip_packages=["pymilvus"],
|
||||
module="llama_stack.providers.remote.vector_io.milvus",
|
||||
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::milvus",
|
||||
pip_packages=["pymilvus"],
|
||||
module="llama_stack.providers.inline.vector_io.milvus",
|
||||
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -72,7 +72,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
@ -83,7 +83,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
@ -92,6 +92,8 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
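This hunk, like the parallel hunks in the adapters below and the inline provider above, applies one recurring change: the signature default SamplingParams() becomes None, with a fresh object created in the body. A generic, self-contained sketch of the usual motivation for this pattern (illustrative names only):

from dataclasses import dataclass, field

@dataclass
class _Params:
    stop: list = field(default_factory=list)

def chat(messages, params=None):
    # A default evaluated in the signature is created once and shared by every call;
    # constructing it in the body gives each call its own instance, so per-request mutation cannot leak.
    if params is None:
        params = _Params()
    return params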
|
||||
|
|
|
@ -72,11 +72,13 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
@ -112,7 +114,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
@ -121,6 +123,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
|
@ -71,7 +71,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
@ -82,7 +82,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
@ -91,6 +91,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
|
@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
|
||||
from fireworks.client import Fireworks
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
|
@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
@ -55,6 +55,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import FireworksImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
|
@ -68,8 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
pass
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key is not None:
|
||||
return self.config.api_key.get_secret_value()
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
return config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.fireworks_api_key:
|
||||
|
@ -86,11 +89,13 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
@ -157,7 +162,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
@ -166,6 +171,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
@ -233,7 +240,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
|
||||
}
|
||||
logcat.debug("inference", f"params to fireworks: {params}")
|
||||
logger.debug(f"params to fireworks: {params}")
|
||||
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
|
|
|
@ -93,11 +93,13 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if content_has_media(content):
|
||||
raise NotImplementedError("Media is not supported")
|
||||
|
||||
|
@ -188,7 +190,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
@ -197,6 +199,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if tool_prompt_format:
|
||||
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
|
||||
|
||||
|
|
|
@ -4,13 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from ollama import AsyncClient
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
|
@ -35,6 +34,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
|
@ -59,7 +59,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .models import model_entries
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
|
@ -72,7 +72,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
return AsyncClient(host=self.url)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
try:
|
||||
await self.client.ps()
|
||||
except httpx.ConnectError as e:
|
||||
|
@ -90,11 +90,13 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
@ -145,7 +147,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
@ -154,6 +156,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
@ -210,7 +214,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
"options": sampling_options,
|
||||
"stream": request.stream,
|
||||
}
|
||||
logcat.debug("inference", f"params to ollama: {params}")
|
||||
logger.debug(f"params to ollama: {params}")
|
||||
|
||||
return params
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
|
@ -286,7 +291,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
if model.model_type == ModelType.embedding:
|
||||
log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
|
||||
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
|
||||
await self.client.pull(model.provider_resource_id)
|
||||
response = await self.client.list()
|
||||
else:
|
||||
|
|
|
@ -81,11 +81,13 @@ class PassthroughInferenceAdapter(Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
|
@ -107,7 +109,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
@ -116,6 +118,8 @@ class PassthroughInferenceAdapter(Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
@ -65,7 +65,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
@ -74,6 +74,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
|
@ -74,7 +74,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
@ -85,7 +85,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
@ -94,6 +94,8 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
tool_config: Optional[ToolConfig] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
|
|
|
@@ -98,11 +98,13 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
@@ -201,7 +203,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -210,6 +212,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
|
||||
from together import Together
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
|
@ -32,9 +31,8 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
|
@ -54,6 +52,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import TogetherImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
|
@ -70,11 +70,13 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
@ -91,8 +93,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
|
||||
def _get_client(self) -> Together:
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
together_api_key = self.config.api_key.get_secret_value()
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
|
@ -151,7 +154,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
@ -160,6 +163,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
@ -220,7 +225,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
|
||||
}
|
||||
logcat.debug("inference", f"params to together: {params}")
|
||||
logger.debug(f"params to together: {params}")
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
|
|
|
@ -7,7 +7,7 @@ import json
|
|||
import logging
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from openai import OpenAI
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
|
@ -229,7 +229,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing VLLM client with base_url={self.config.url}")
|
||||
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
self.client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
@ -241,11 +241,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
@ -264,7 +266,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
@ -273,6 +275,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
|
||||
# References:
|
||||
|
@ -296,10 +300,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
return await self._nonstream_chat_completion(request, self.client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
self, request: ChatCompletionRequest, client: AsyncOpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = client.chat.completions.create(**params)
|
||||
r = await client.chat.completions.create(**params)
|
||||
choice = r.choices[0]
|
||||
result = ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
|
@ -311,17 +315,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
return result
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
|
||||
# generator so this wrapper is not necessary?
|
||||
async def _to_async_generator():
|
||||
s = client.chat.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
stream = await client.chat.completions.create(**params)
|
||||
if len(request.tools) > 0:
|
||||
res = _process_vllm_chat_completion_stream_response(stream)
|
||||
else:
|
||||
|
@ -331,26 +328,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self.client.completions.create(**params)
|
||||
r = await self.client.completions.create(**params)
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# Wrapper for async generator similar
|
||||
async def _to_async_generator():
|
||||
stream = self.client.completions.create(**params)
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
stream = await self.client.completions.create(**params)
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
res = self.client.models.list()
|
||||
available_models = [m.id for m in res]
|
||||
res = await self.client.models.list()
|
||||
available_models = [m.id async for m in res]
|
||||
if model.provider_resource_id not in available_models:
|
||||
raise ValueError(
|
||||
f"Model {model.provider_resource_id} is not being served by vLLM. "
|
||||
|
@ -406,7 +397,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
assert model.metadata.get("embedding_dimension")
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
|
||||
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
|
||||
response = self.client.embeddings.create(
|
||||
response = await self.client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
|
|
|
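The vLLM adapter hunks above swap the synchronous `OpenAI` client for `AsyncOpenAI` and await every call (`chat.completions.create`, `completions.create`, `models.list`, `embeddings.create`), so requests no longer block the event loop. A minimal sketch of the pattern; the base URL and model name are placeholders, not the adapter's real configuration:

```python
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # Placeholder endpoint; a vLLM server exposes an OpenAI-compatible API.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="fake")

    # Non-streaming: the call itself is awaited.
    response = await client.chat.completions.create(
        model="my-model",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)

    # Streaming: awaiting the call returns an async iterator of chunks,
    # so the manual sync-to-async generator wrapper is no longer needed.
    stream = await client.chat.completions.create(
        model="my-model",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())
```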
@ -7,7 +7,7 @@
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
@ -31,7 +31,7 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
@ -77,12 +77,13 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
|||
"q": kwargs["query"],
|
||||
}
|
||||
|
||||
response = requests.get(
|
||||
url=self.url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
url=self.url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return ToolInvocationResult(content=json.dumps(self._clean_response(response.json())))
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
@ -30,7 +30,7 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
@ -74,8 +74,13 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
|
|||
"Accept": "application/json",
|
||||
}
|
||||
payload = {"q": kwargs["query"]}
|
||||
response = requests.get(url=url, params=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
url=url,
|
||||
params=payload,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = self._clean_brave_response(response.json())
|
||||
content_items = "\n".join([str(result) for result in results])
|
||||
return ToolInvocationResult(
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
@ -30,7 +30,7 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
@ -66,10 +66,12 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
response = requests.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={"api_key": api_key, "query": kwargs["query"]},
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={"api_key": api_key, "query": kwargs["query"]},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return ToolInvocationResult(content=json.dumps(self._clean_tavily_response(response.json())))
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
@ -31,7 +31,7 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
@ -73,11 +73,9 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
"format": "plaintext",
|
||||
"output": "json",
|
||||
}
|
||||
response = requests.get(
|
||||
self.url,
|
||||
params=params,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(params=params, url=self.url)
|
||||
response.raise_for_status()
|
||||
return ToolInvocationResult(content=json.dumps(self._clean_wolfram_alpha_response(response.json())))
|
||||
|
||||
def _clean_wolfram_alpha_response(self, wa_response):
|
||||
|
|
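The four search-tool runtimes above (Bing, Brave, Tavily, WolframAlpha) all move from blocking `requests` calls to `httpx.AsyncClient`, so a slow search API no longer stalls the server's event loop. A minimal before/after sketch of the pattern, using a placeholder URL rather than any of the real endpoints:

```python
import httpx
import requests

URL = "https://example.com/search"  # placeholder endpoint


def search_blocking(query: str) -> dict:
    # Before: requests blocks the calling thread (and the event loop when
    # called from async code).
    response = requests.get(URL, params={"q": query})
    response.raise_for_status()
    return response.json()


async def search_async(query: str) -> dict:
    # After: the request is awaited inside an AsyncClient context manager,
    # which also closes the connection pool when the block exits.
    async with httpx.AsyncClient() as client:
        response = await client.get(URL, params={"q": query})
        response.raise_for_status()
        return response.json()
```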
llama_stack/providers/remote/vector_io/milvus/__init__.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
+from .config import MilvusVectorIOConfig
+
+
+async def get_adapter_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
+    from .milvus import MilvusVectorIOAdapter
+
+    assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
+
+    impl = MilvusVectorIOAdapter(config, deps[Api.inference])
+    await impl.initialize()
+    return impl
llama_stack/providers/remote/vector_io/milvus/config.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel
+
+from llama_stack.schema_utils import json_schema_type
+
+
+@json_schema_type
+class MilvusVectorIOConfig(BaseModel):
+    uri: str
+    token: Optional[str] = None
+    consistency_level: str = "Strong"
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+        return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
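For reference, the remote Milvus provider is configured with just a server URI, an optional token, and a consistency level. A small illustrative sketch of constructing the config directly; the endpoint and token values here are placeholders (in a distribution's run.yaml the same fields are filled from `${env.MILVUS_ENDPOINT}` / `${env.MILVUS_TOKEN}` as `sample_run_config` shows):

```python
from llama_stack.providers.remote.vector_io.milvus.config import MilvusVectorIOConfig

config = MilvusVectorIOConfig(
    uri="https://my-milvus-server:19530",  # placeholder endpoint
    token="my-token",                      # optional; omit for unauthenticated servers
    consistency_level="Strong",
)
print(config.model_dump())
```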
llama_stack/providers/remote/vector_io/milvus/milvus.py (new file, 175 lines)
@ -0,0 +1,175 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
||||
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MilvusIndex(EmbeddingIndex):
|
||||
def __init__(self, client: MilvusClient, collection_name: str, consistency_level="Strong"):
|
||||
self.client = client
|
||||
self.collection_name = collection_name.replace("-", "_")
|
||||
self.consistency_level = consistency_level
|
||||
|
||||
async def delete(self):
|
||||
if self.client.has_collection(self.collection_name):
|
||||
self.client.drop_collection(collection_name=self.collection_name)
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
)
|
||||
if not self.client.has_collection(self.collection_name):
|
||||
self.client.create_collection(
|
||||
self.collection_name,
|
||||
dimension=len(embeddings[0]),
|
||||
auto_id=True,
|
||||
consistency_level=self.consistency_level,
|
||||
)
|
||||
|
||||
data = []
|
||||
for chunk, embedding in zip(chunks, embeddings, strict=False):
|
||||
chunk_id = generate_chunk_id(chunk.metadata["document_id"], chunk.content)
|
||||
|
||||
data.append(
|
||||
{
|
||||
"chunk_id": chunk_id,
|
||||
"vector": embedding,
|
||||
"chunk_content": chunk.model_dump(),
|
||||
}
|
||||
)
|
||||
try:
|
||||
self.client.insert(
|
||||
self.collection_name,
|
||||
data=data,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||
raise e
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
search_res = self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
data=[embedding],
|
||||
limit=k,
|
||||
output_fields=["*"],
|
||||
search_params={"params": {"radius": score_threshold}},
|
||||
)
|
||||
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
|
||||
scores = [res["distance"] for res in search_res[0]]
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
self, config: Union[RemoteMilvusVectorIOConfig, InlineMilvusVectorIOConfig], inference_api: Api.inference
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
self.client = None
|
||||
self.inference_api = inference_api
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
||||
else:
|
||||
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
||||
uri = os.path.expanduser(self.config.db_path)
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
consistency_level = self.config.consistency_level
|
||||
else:
|
||||
consistency_level = "Strong"
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
return index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id in self.cache:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
|
||||
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
|
||||
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
|
||||
hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
|
||||
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
|
||||
|
||||
|
||||
# TODO: refactor this generate_chunk_id along with the `sqlite-vec` implementation into a separate utils file
|
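A note on `generate_chunk_id` above: it derives a deterministic, UUID-shaped identifier from the document ID plus the chunk text, so re-inserting the same chunk always yields the same ID. A standalone sketch of the same logic:

```python
import hashlib
import uuid


def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    """Deterministic chunk ID: an MD5 digest of "<doc_id>:<text>" formatted as a UUID."""
    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))


# The same inputs always map to the same ID, so duplicates are easy to detect.
assert generate_chunk_id("doc-1", "hello world") == generate_chunk_id("doc-1", "hello world")
assert generate_chunk_id("doc-1", "hello world") != generate_chunk_id("doc-2", "hello world")
```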
|
@ -8,7 +8,6 @@ from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
|
|||
|
||||
import litellm
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
|
@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.models.models import Model
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
@ -47,6 +47,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class LiteLLMOpenAIMixin(
|
||||
ModelRegistryHelper,
|
||||
|
@ -74,7 +76,7 @@ class LiteLLMOpenAIMixin(
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
@ -85,7 +87,7 @@ class LiteLLMOpenAIMixin(
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
@ -94,6 +96,8 @@ class LiteLLMOpenAIMixin(
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
@ -107,8 +111,7 @@ class LiteLLMOpenAIMixin(
|
|||
)
|
||||
|
||||
params = await self._get_params(request)
|
||||
logcat.debug("inference", f"params to litellm (openai compat): {params}")
|
||||
|
||||
logger.debug(f"params to litellm (openai compat): {params}")
|
||||
# unfortunately, we need to use synchronous litellm.completion here because litellm
|
||||
# caches various httpx.client objects in a non-eventloop aware manner
|
||||
response = litellm.completion(**params)
|
||||
|
|
|
@ -8,14 +8,12 @@ import asyncio
|
|||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
|
@ -34,6 +32,7 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
ModelFamily,
|
||||
RawContent,
|
||||
|
@ -58,7 +57,7 @@ from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
|||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
|
||||
|
@ -464,7 +463,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
|
|||
def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
|
||||
llama_model = resolve_model(model)
|
||||
if llama_model is None:
|
||||
logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format")
|
||||
log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
|
||||
return ToolPromptFormat.json
|
||||
|
||||
if llama_model.model_family == ModelFamily.llama3_1 or (
|
||||
|
|
|
@@ -8,9 +8,11 @@ import logging
 from datetime import datetime
 from typing import List, Optional
 
-from pymongo import MongoClient
+from pymongo import AsyncMongoClient
 
-from llama_stack.providers.utils.kvstore import KVStore, MongoDBKVStoreConfig
+from llama_stack.providers.utils.kvstore import KVStore
+
+from ..config import MongoDBKVStoreConfig
 
 log = logging.getLogger(__name__)
 
@@ -30,7 +32,7 @@ class MongoDBKVStoreImpl(KVStore):
                 "password": self.config.password,
             }
             conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
-            self.conn = MongoClient(**conn_creds)
+            self.conn = AsyncMongoClient(**conn_creds)
             self.collection = self.conn[self.config.db][self.config.collection_name]
         except Exception as e:
             log.exception("Could not connect to MongoDB database server")
@@ -44,17 +46,17 @@ class MongoDBKVStoreImpl(KVStore):
     async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
         key = self._namespaced_key(key)
         update_query = {"$set": {"value": value, "expiration": expiration}}
-        self.collection.update_one({"key": key}, update_query, upsert=True)
+        await self.collection.update_one({"key": key}, update_query, upsert=True)
 
     async def get(self, key: str) -> Optional[str]:
         key = self._namespaced_key(key)
         query = {"key": key}
-        result = self.collection.find_one(query, {"value": 1, "_id": 0})
+        result = await self.collection.find_one(query, {"value": 1, "_id": 0})
         return result["value"] if result else None
 
     async def delete(self, key: str) -> None:
         key = self._namespaced_key(key)
-        self.collection.delete_one({"key": key})
+        await self.collection.delete_one({"key": key})
 
     async def range(self, start_key: str, end_key: str) -> List[str]:
         start_key = self._namespaced_key(start_key)
@@ -63,4 +65,7 @@ class MongoDBKVStoreImpl(KVStore):
             "key": {"$gte": start_key, "$lt": end_key},
         }
         cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
-        return [doc["value"] for doc in cursor]
+        result = []
+        async for doc in cursor:
+            result.append(doc["value"])
+        return result
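The MongoDB KV store above is converted to PyMongo's asynchronous client so every operation is awaited and cursors are consumed with `async for`. A minimal sketch of that usage pattern; the connection string and collection names are placeholders, and this assumes a PyMongo release that ships `AsyncMongoClient` (4.9 or newer):

```python
import asyncio

from pymongo import AsyncMongoClient


async def main() -> None:
    client = AsyncMongoClient("mongodb://localhost:27017")  # placeholder connection string
    collection = client["kvstore"]["kv"]

    # Upsert, read, and delete are all awaited with the async client.
    await collection.update_one({"key": "greeting"}, {"$set": {"value": "hello"}}, upsert=True)
    doc = await collection.find_one({"key": "greeting"}, {"value": 1, "_id": 0})
    print(doc)

    # find() returns a cursor that is consumed with `async for`, as in range() above.
    async for item in collection.find({}, {"value": 1, "_id": 0}).sort("key", 1):
        print(item)

    await collection.delete_one({"key": "greeting"})
    await client.close()


asyncio.run(main())
```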
|
@ -12,11 +12,9 @@ from dataclasses import dataclass
|
|||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
import httpx
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from pypdf import PdfReader
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
|
@ -38,6 +36,8 @@ log = logging.getLogger(__name__)
|
|||
def parse_pdf(data: bytes) -> str:
|
||||
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
||||
pdf_bytes = io.BytesIO(data)
|
||||
from pypdf import PdfReader
|
||||
|
||||
pdf_reader = PdfReader(pdf_bytes)
|
||||
return "\n".join([page.extract_text() for page in pdf_reader.pages])
|
||||
|
||||
|
@ -75,6 +75,8 @@ def content_from_data(data_url: str) -> str:
|
|||
|
||||
encoding = parts["encoding"]
|
||||
if not encoding:
|
||||
import chardet
|
||||
|
||||
detected = chardet.detect(data)
|
||||
encoding = detected["encoding"]
|
||||
|
||||
|
|
llama_stack/templates/open-benchmark/build.yaml (new file, 36 lines)
@ -0,0 +1,36 @@
|
|||
version: '2'
|
||||
distribution_spec:
|
||||
description: Distribution for running open benchmarks
|
||||
providers:
|
||||
inference:
|
||||
- remote::openai
|
||||
- remote::anthropic
|
||||
- remote::gemini
|
||||
- remote::groq
|
||||
- remote::together
|
||||
vector_io:
|
||||
- inline::sqlite-vec
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
safety:
|
||||
- inline::llama-guard
|
||||
agents:
|
||||
- inline::meta-reference
|
||||
telemetry:
|
||||
- inline::meta-reference
|
||||
eval:
|
||||
- inline::meta-reference
|
||||
datasetio:
|
||||
- remote::huggingface
|
||||
- inline::localfs
|
||||
scoring:
|
||||
- inline::basic
|
||||
- inline::llm-as-judge
|
||||
- inline::braintrust
|
||||
tool_runtime:
|
||||
- remote::brave-search
|
||||
- remote::tavily-search
|
||||
- inline::code-interpreter
|
||||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
llama_stack/templates/open-benchmark/run.yaml (new file, 212 lines)
@ -0,0 +1,212 @@
|
|||
version: '2'
|
||||
image_name: open-benchmark
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
config:
|
||||
api_key: ${env.OPENAI_API_KEY:}
|
||||
- provider_id: anthropic
|
||||
provider_type: remote::anthropic
|
||||
config:
|
||||
api_key: ${env.ANTHROPIC_API_KEY:}
|
||||
- provider_id: gemini
|
||||
provider_type: remote::gemini
|
||||
config:
|
||||
api_key: ${env.GEMINI_API_KEY:}
|
||||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
url: https://api.groq.com
|
||||
api_key: ${env.GROQ_API_KEY:}
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY}
|
||||
vector_io:
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
|
||||
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:}
|
||||
- provider_id: ${env.ENABLE_PGVECTOR+pgvector}
|
||||
provider_type: remote::pgvector
|
||||
config:
|
||||
host: ${env.PGVECTOR_HOST:localhost}
|
||||
port: ${env.PGVECTOR_PORT:5432}
|
||||
db: ${env.PGVECTOR_DB:}
|
||||
user: ${env.PGVECTOR_USER:}
|
||||
password: ${env.PGVECTOR_PASSWORD:}
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
config: {}
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
config: {}
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:}
|
||||
max_results: 3
|
||||
- provider_id: code-interpreter
|
||||
provider_type: inline::code-interpreter
|
||||
config: {}
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: openai/gpt-4o
|
||||
provider_id: openai
|
||||
provider_model_id: openai/gpt-4o
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-405B-Instruct
|
||||
provider_id: together
|
||||
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: anthropic/claude-3-5-sonnet-latest
|
||||
provider_id: anthropic
|
||||
provider_model_id: anthropic/claude-3-5-sonnet-latest
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: gemini/gemini-1.5-flash
|
||||
provider_id: gemini
|
||||
provider_model_id: gemini/gemini-1.5-flash
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/llama-3.3-70b-versatile
|
||||
model_type: llm
|
||||
shields:
|
||||
- shield_id: meta-llama/Llama-Guard-3-8B
|
||||
vector_dbs: []
|
||||
datasets:
|
||||
- dataset_id: simpleqa
|
||||
provider_id: huggingface
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/simpleqa
|
||||
metadata:
|
||||
path: llamastack/simpleqa
|
||||
name:
|
||||
split: train
|
||||
dataset_schema:
|
||||
input_query:
|
||||
type: string
|
||||
expected_answer:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
- dataset_id: mmlu_cot
|
||||
provider_id: huggingface
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/mmlu_cot
|
||||
metadata:
|
||||
path: llamastack/mmlu_cot
|
||||
name: all
|
||||
split: test
|
||||
dataset_schema:
|
||||
input_query:
|
||||
type: string
|
||||
expected_answer:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
- dataset_id: gpqa_cot
|
||||
provider_id: huggingface
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
|
||||
metadata:
|
||||
path: llamastack/gpqa_0shot_cot
|
||||
name: gpqa_main
|
||||
split: train
|
||||
dataset_schema:
|
||||
input_query:
|
||||
type: string
|
||||
expected_answer:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
scoring_fns: []
|
||||
benchmarks:
|
||||
- benchmark_id: meta-reference-simpleqa
|
||||
dataset_id: simpleqa
|
||||
scoring_functions: ["llm-as-judge::405b-simpleqa"]
|
||||
- benchmark_id: meta-reference-mmlu-cot
|
||||
dataset_id: mmlu_cot
|
||||
scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
|
||||
- benchmark_id: meta-reference-gpqa-cot
|
||||
dataset_id: gpqa_cot
|
||||
scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
- toolgroup_id: builtin::code_interpreter
|
||||
provider_id: code-interpreter
|
||||
server:
|
||||
port: 8321
|
|
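The `${env.VAR}` / `${env.VAR:default}` placeholders used throughout run.yaml are resolved from environment variables when the stack starts, with the text after `:` acting as a fallback (an empty fallback such as `${env.OPENAI_API_KEY:}` makes the setting optional). The following is only an illustrative sketch of that substitution convention, not the stack's actual implementation:

```python
import os
import re

_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+)(?::([^}]*))?\}")


def resolve_env_placeholders(value: str) -> str:
    """Replace ${env.NAME} / ${env.NAME:default} with the env value or the default."""

    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        env_value = os.environ.get(name)
        if env_value is not None:
            return env_value
        if default is not None:
            return default
        raise ValueError(f"Environment variable {name} is required but not set")

    return _PATTERN.sub(_sub, value)


os.environ["SQLITE_STORE_DIR"] = "/tmp/llama"
print(resolve_env_placeholders("${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db"))
# -> /tmp/llama/registry.db
```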
@ -15,11 +15,12 @@ providers:
|
|||
- provider_id: vllm
|
||||
provider_type: inline::vllm
|
||||
config:
|
||||
model: ${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}
|
||||
tensor_parallel_size: ${env.TENSOR_PARALLEL_SIZE:1}
|
||||
max_tokens: ${env.MAX_TOKENS:4096}
|
||||
max_model_len: ${env.MAX_MODEL_LEN:4096}
|
||||
max_num_seqs: ${env.MAX_NUM_SEQS:4}
|
||||
enforce_eager: ${env.ENFORCE_EAGER:False}
|
||||
gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:0.7}
|
||||
gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:0.3}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
|
|
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||
|
||||
[project]
|
||||
name = "llama_stack"
|
||||
version = "0.1.5"
|
||||
version = "0.1.6"
|
||||
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
|
||||
description = "Llama Stack"
|
||||
readme = "README.md"
|
||||
|
@ -26,7 +26,7 @@ dependencies = [
|
|||
"httpx",
|
||||
"huggingface-hub",
|
||||
"jsonschema",
|
||||
"llama-stack-client>=0.1.4",
|
||||
"llama-stack-client>=0.1.6",
|
||||
"prompt-toolkit",
|
||||
"python-dotenv",
|
||||
"pydantic>=2",
|
||||
|
@ -34,6 +34,8 @@ dependencies = [
|
|||
"rich",
|
||||
"setuptools",
|
||||
"termcolor",
|
||||
"tiktoken",
|
||||
"pillow",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
@ -63,7 +65,6 @@ test = [
|
|||
"groq",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"tiktoken",
|
||||
"chardet",
|
||||
"pypdf",
|
||||
]
|
||||
|
@ -79,7 +80,7 @@ docs = [
|
|||
"sphinxcontrib.mermaid",
|
||||
"tomli",
|
||||
]
|
||||
codegen = ["rich", "pydantic", "jinja2"]
|
||||
codegen = ["rich", "pydantic", "jinja2>=3.1.6"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/meta-llama/llama-stack"
|
||||
|
@ -136,8 +137,6 @@ ignore = [
|
|||
|
||||
# These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
|
||||
"C901", # Complexity of the function is too high
|
||||
# these ignores are from flake8-bugbear; please fix!
|
||||
"B008",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
|
@ -153,7 +152,6 @@ exclude = [
|
|||
"llama_stack/distribution",
|
||||
"llama_stack/apis",
|
||||
"llama_stack/cli",
|
||||
"llama_stack/logcat.py",
|
||||
"llama_stack/models",
|
||||
"llama_stack/strong_typing",
|
||||
"llama_stack/templates",
|
||||
|
@ -165,5 +163,5 @@ module = ["yaml", "fire"]
|
|||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "llama_stack.distribution.resolver"
|
||||
follow_imports = "normal" # This will force type checking on this module
|
||||
module = ["llama_stack.distribution.resolver", "llama_stack.log"]
|
||||
follow_imports = "normal" # This will force type checking on this module
|
||||
|
|
|
@ -20,13 +20,14 @@ huggingface-hub==0.29.0
|
|||
idna==3.10
|
||||
jsonschema==4.23.0
|
||||
jsonschema-specifications==2024.10.1
|
||||
llama-stack-client==0.1.4
|
||||
llama-stack-client==0.1.6
|
||||
lxml==5.3.1
|
||||
markdown-it-py==3.0.0
|
||||
mdurl==0.1.2
|
||||
numpy==2.2.3
|
||||
packaging==24.2
|
||||
pandas==2.2.3
|
||||
pillow==11.1.0
|
||||
prompt-toolkit==3.0.50
|
||||
pyaml==25.1.0
|
||||
pycryptodomex==3.21.0
|
||||
|
@ -38,6 +39,7 @@ python-dotenv==1.0.1
|
|||
pytz==2025.1
|
||||
pyyaml==6.0.2
|
||||
referencing==0.36.2
|
||||
regex==2024.11.6
|
||||
requests==2.32.3
|
||||
rich==13.9.4
|
||||
rpds-py==0.22.3
|
||||
|
@ -45,6 +47,7 @@ setuptools==75.8.0
|
|||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
termcolor==2.5.0
|
||||
tiktoken==0.9.0
|
||||
tqdm==4.67.1
|
||||
typing-extensions==4.12.2
|
||||
tzdata==2025.1
|
||||
|
|
scripts/gen-changelog.py (new file, 75 lines)
@ -0,0 +1,75 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def get_all_releases(token):
|
||||
url = f"https://api.github.com/repos/meta-llama/llama-stack/releases"
|
||||
headers = {"Accept": "application/vnd.github.v3+json"}
|
||||
|
||||
if token:
|
||||
headers["Authorization"] = f"token {token}"
|
||||
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
raise Exception(
|
||||
f"Error fetching releases: {response.status_code}, {response.text}"
|
||||
)
|
||||
|
||||
|
||||
def clean_release_body(body):
|
||||
"""Remove '## All changes' sections from release notes."""
|
||||
lines = body.split("\n")
|
||||
cleaned_lines = []
|
||||
skip_mode = False
|
||||
|
||||
for line in lines:
|
||||
if line.strip() in [
|
||||
"## All changes",
|
||||
"### What's Changed",
|
||||
"## What's Changed",
|
||||
"## New Contributors",
|
||||
]:
|
||||
skip_mode = True
|
||||
elif skip_mode and line.startswith("##"):
|
||||
# Found a new section, stop skipping
|
||||
skip_mode = False
|
||||
cleaned_lines.append(line)
|
||||
elif not skip_mode:
|
||||
cleaned_lines.append(line)
|
||||
|
||||
return "\n".join(cleaned_lines)
|
||||
|
||||
|
||||
def merge_release_notes(output_file, token=None):
|
||||
releases = get_all_releases(token)
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as md_file:
|
||||
md_file.write(f"# Changelog\n\n")
|
||||
|
||||
for release in releases:
|
||||
md_file.write(f"# {release['tag_name']}\n")
|
||||
md_file.write(f"Published on: {release['published_at']}\n\n")
|
||||
|
||||
# Clean the release body to remove "## All changes" sections
|
||||
cleaned_body = clean_release_body(release["body"])
|
||||
md_file.write(f"{cleaned_body}\n\n")
|
||||
|
||||
md_file.write("---\n\n")
|
||||
|
||||
print(f"Merged release notes saved to {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
OUTPUT_FILE = "CHANGELOG.md"
|
||||
TOKEN = os.getenv("GITHUB_TOKEN")
|
||||
merge_release_notes(OUTPUT_FILE, TOKEN)
|
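The new `scripts/gen-changelog.py` regenerates CHANGELOG.md from the repository's GitHub release notes; the token is optional but avoids unauthenticated rate limits. It is normally run as `python scripts/gen-changelog.py` with `GITHUB_TOKEN` exported. Because the file name contains a hyphen, calling it from Python requires loading it by path, as in this sketch:

```python
import importlib.util
import os

# Load the hyphenated script as a module so merge_release_notes() can be called directly.
spec = importlib.util.spec_from_file_location("gen_changelog", "scripts/gen-changelog.py")
gen_changelog = importlib.util.module_from_spec(spec)
spec.loader.exec_module(gen_changelog)

# Fetches all releases, strips the auto-generated "What's Changed" sections,
# and writes the merged notes to CHANGELOG.md.
gen_changelog.merge_release_notes("CHANGELOG.md", os.getenv("GITHUB_TOKEN"))
```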
|
@ -3,4 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
# ruff: noqa: N999
|
||||
|
|
|
@ -3,4 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
# ruff: noqa: N999
|
||||
|
|
|
@ -9,7 +9,6 @@ from uuid import uuid4
|
|||
|
||||
import pytest
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.client_tool import client_tool
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
|
||||
from llama_stack_client.types.memory_insert_params import Document
|
||||
|
@ -23,7 +22,6 @@ from llama_stack.apis.agents.agents import (
|
|||
)
|
||||
|
||||
|
||||
@client_tool
|
||||
def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
|
||||
"""
|
||||
Returns the boiling point of a liquid in Celcius or Fahrenheit
|
||||
|
@ -41,7 +39,6 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
|
|||
return -1
|
||||
|
||||
|
||||
@client_tool
|
||||
def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns the boiling point of a liquid in Celcius or Fahrenheit
|
||||
|
@ -276,7 +273,6 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
|
|||
agent_config = {
|
||||
**agent_config,
|
||||
"tools": ["builtin::websearch", client_tool],
|
||||
"client_tools": [client_tool.get_tool_definition()],
|
||||
}
|
||||
|
||||
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
|
||||
|
@ -571,7 +567,10 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
|
|||
assert expected_kw in response.output_message.content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("client_tools", [(get_boiling_point, False), (get_boiling_point_with_metadata, True)])
|
||||
@pytest.mark.parametrize(
|
||||
"client_tools",
|
||||
[(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
|
||||
)
|
||||
def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools):
|
||||
client_tool, expectes_metadata = client_tools
|
||||
agent_config = {
|
||||
|
|
|
@ -42,7 +42,7 @@ def provider_data():
|
|||
for key, value in keymap.items():
|
||||
if os.environ.get(key):
|
||||
provider_data[value] = os.environ[key]
|
||||
return provider_data if len(provider_data) > 0 else None
|
||||
return provider_data
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
|
@@ -121,6 +121,9 @@ class RecordableMock:
         # Replace temporary file paths created by tempfile.mkdtemp()
         key = re.sub(r"/var/folders/[^,'\"\s]+", "<TEMP_FILE>", key)
 
+        # Replace /tmp/ paths which are also commonly used for temporary files
+        key = re.sub(r"/tmp/[^,'\"\s]+", "<TEMP_FILE>", key)
+
         return key
 
     def _save_cache(self):
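The RecordableMock change above extends cache-key normalization so Linux-style `/tmp/...` paths are canonicalized the same way as macOS `/var/folders/...` temp paths; without it, recorded keys would differ between runs and platforms. A small self-contained check of the two substitutions:

```python
import re


def normalize_temp_paths(key: str) -> str:
    # macOS tempfile.mkdtemp() locations
    key = re.sub(r"/var/folders/[^,'\"\s]+", "<TEMP_FILE>", key)
    # Linux-style /tmp/ locations (the newly added rule)
    key = re.sub(r"/tmp/[^,'\"\s]+", "<TEMP_FILE>", key)
    return key


assert normalize_temp_paths("load /tmp/tmpabc123/data.json") == "load <TEMP_FILE>"
assert normalize_temp_paths("load /var/folders/ab/xyz/data.json") == "load <TEMP_FILE>"
```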
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -3,4 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
# ruff: noqa: N999
|
||||
|
|
|
@ -27,6 +27,7 @@ def base64_image_url(base64_image_data, image_path):
|
|||
return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
|
||||
def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
|
||||
message = {
|
||||
"role": "user",
|
||||
|
@ -55,6 +56,7 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
|
|||
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
|
||||
def test_image_chat_completion_streaming(client_with_models, vision_model_id):
|
||||
message = {
|
||||
"role": "user",
|
||||
|
|
|
@ -3,4 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
# ruff: noqa: N999
|
||||
|
|
|
@ -81,8 +81,6 @@ def test_scoring_functions_register(
|
|||
|
||||
def test_scoring_score(llama_stack_client):
|
||||
register_dataset(llama_stack_client, for_rag=True)
|
||||
response = llama_stack_client.datasets.list()
|
||||
assert len(response) == 1
|
||||
|
||||
# scoring individual rows
|
||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
||||
|
@ -119,8 +117,6 @@ def test_scoring_score(llama_stack_client):
|
|||
|
||||
def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge_prompt_template, judge_model_id):
|
||||
register_dataset(llama_stack_client, for_rag=True)
|
||||
response = llama_stack_client.datasets.list()
|
||||
assert len(response) == 1
|
||||
|
||||
# scoring individual rows
|
||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
||||
|
|
|
@ -3,4 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
# ruff: noqa: N999
|
||||
|
|
|
@ -4,6 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
@ -39,9 +46,41 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
|
|||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
class MockInferenceAdapterWithSleep:
|
||||
def __init__(self, sleep_time: int, response: Dict[str, Any]):
|
||||
self.httpd = None
|
||||
|
||||
class DelayedRequestHandler(BaseHTTPRequestHandler):
|
||||
# ruff: noqa: N802
|
||||
def do_POST(self):
|
||||
time.sleep(sleep_time)
|
||||
self.send_response(code=200)
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(response).encode("utf-8"))
|
||||
|
||||
self.request_handler = DelayedRequestHandler
|
||||
|
||||
def __enter__(self):
|
||||
httpd = HTTPServer(("", 0), self.request_handler)
|
||||
self.httpd = httpd
|
||||
host, port = httpd.server_address
|
||||
httpd_thread = threading.Thread(target=httpd.serve_forever)
|
||||
httpd_thread.daemon = True # stop server if this thread terminates
|
||||
httpd_thread.start()
|
||||
|
||||
config = VLLMInferenceAdapterConfig(url=f"http://{host}:{port}")
|
||||
inference_adapter = VLLMInferenceAdapter(config)
|
||||
return inference_adapter
|
||||
|
||||
def __exit__(self, _exc_type, _exc_value, _traceback):
|
||||
if self.httpd:
|
||||
self.httpd.shutdown()
|
||||
self.httpd.server_close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_openai_models_list():
|
||||
with patch("openai.resources.models.Models.list") as mock_list:
|
||||
with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list:
|
||||
yield mock_list
|
||||
|
||||
|
||||
|
@ -56,10 +95,10 @@ async def vllm_inference_adapter():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
|
||||
mock_openai_models = [
|
||||
OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
|
||||
]
|
||||
mock_openai_models_list.return_value = mock_openai_models
|
||||
async def mock_openai_models():
|
||||
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
|
||||
|
||||
mock_openai_models_list.return_value = mock_openai_models()
|
||||
|
||||
foo_model = Model(identifier="foo", provider_resource_id="foo", provider_id="vllm-inference")
|
||||
|
||||
|
@ -141,3 +180,55 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
    assert len(chunks) == 0


def test_chat_completion_doesnt_block_event_loop(caplog):
    loop = asyncio.new_event_loop()
    loop.set_debug(True)
    caplog.set_level(logging.WARNING)

    # Log when event loop is blocked for more than 100ms
    loop.slow_callback_duration = 0.1
    # Sleep for 500ms in our delayed http response
    sleep_time = 0.5

    mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
    mock_response = {
        "id": "chatcmpl-abc123",
        "object": "chat.completion",
        "created": 1,
        "model": "mock-model",
        "choices": [
            {
                "message": {"content": ""},
                "logprobs": None,
                "finish_reason": "stop",
                "index": 0,
            }
        ],
    }

    async def do_chat_completion():
        await inference_adapter.chat_completion(
            "mock-model",
            [],
            stream=False,
            tools=None,
            tool_config=ToolConfig(tool_choice=ToolChoice.auto),
        )

    with MockInferenceAdapterWithSleep(sleep_time, mock_response) as inference_adapter:
        inference_adapter.model_store = AsyncMock()
        inference_adapter.model_store.get_model.return_value = mock_model
        loop.run_until_complete(inference_adapter.initialize())

        # Clear the logs so far and run the actual chat completion we care about
        caplog.clear()
        loop.run_until_complete(do_chat_completion())

        # Ensure we don't have any asyncio warnings in the captured log
        # records from our chat completion call. A message gets logged
        # here any time we exceed the slow_callback_duration configured
        # above.
        asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
        assert not asyncio_warnings

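Editor's note: the test above leans on asyncio's debug mode. Once loop.set_debug(True) is on, any callback or task step that holds the loop longer than loop.slow_callback_duration is reported on the asyncio logger, which is exactly what the caplog assertion checks. A standalone sketch of that mechanism, with a deliberately blocking time.sleep standing in for a blocking HTTP call:

import asyncio
import logging
import time

logging.basicConfig(level=logging.WARNING)

async def blocking_task():
    time.sleep(0.3)  # blocks the event loop; a cooperative task would await asyncio.sleep(0.3)

loop = asyncio.new_event_loop()
loop.set_debug(True)
loop.slow_callback_duration = 0.1  # warn when one step holds the loop for more than 100ms
loop.run_until_complete(blocking_task())
loop.close()
# The asyncio logger emits a warning along the lines of:
#   Executing <Task finished ... coro=blocking_task() ...> took 0.300 seconds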
@ -1,88 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import io
import logging
import os
import unittest

from llama_stack import logcat


class TestLogcat(unittest.TestCase):
    def setUp(self):
        self.original_env = os.environ.get("LLAMA_STACK_LOGGING")

        self.log_output = io.StringIO()
        self._init_logcat()

    def tearDown(self):
        if self.original_env is not None:
            os.environ["LLAMA_STACK_LOGGING"] = self.original_env
        else:
            os.environ.pop("LLAMA_STACK_LOGGING", None)

    def _init_logcat(self):
        logcat.init(default_level=logging.DEBUG)
        self.handler = logging.StreamHandler(self.log_output)
        self.handler.setFormatter(logging.Formatter("[%(category)s] %(message)s"))
        logcat._logger.handlers.clear()
        logcat._logger.addHandler(self.handler)

    def test_basic_logging(self):
        logcat.info("server", "Info message")
        logcat.warning("server", "Warning message")
        logcat.error("server", "Error message")

        output = self.log_output.getvalue()
        self.assertIn("[server] Info message", output)
        self.assertIn("[server] Warning message", output)
        self.assertIn("[server] Error message", output)

    def test_different_categories(self):
        # Log messages with different categories
        logcat.info("server", "Server message")
        logcat.info("inference", "Inference message")
        logcat.info("router", "Router message")

        output = self.log_output.getvalue()
        self.assertIn("[server] Server message", output)
        self.assertIn("[inference] Inference message", output)
        self.assertIn("[router] Router message", output)

    def test_env_var_control(self):
        os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"
        self._init_logcat()

        # These should be visible based on the environment settings
        logcat.debug("server", "Server debug message")
        logcat.info("server", "Server info message")
        logcat.warning("inference", "Inference warning message")
        logcat.error("inference", "Inference error message")

        # These should be filtered out based on the environment settings
        logcat.debug("inference", "Inference debug message")
        logcat.info("inference", "Inference info message")

        output = self.log_output.getvalue()
        self.assertIn("[server] Server debug message", output)
        self.assertIn("[server] Server info message", output)
        self.assertIn("[inference] Inference warning message", output)
        self.assertIn("[inference] Inference error message", output)

        self.assertNotIn("[inference] Inference debug message", output)
        self.assertNotIn("[inference] Inference info message", output)

    def test_invalid_category(self):
        logcat.info("nonexistent", "This message should not be logged")

        # Check that the message was not logged
        output = self.log_output.getvalue()
        self.assertNotIn("[nonexistent] This message should not be logged", output)


if __name__ == "__main__":
    unittest.main()

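Editor's note: the LLAMA_STACK_LOGGING value exercised above is a semicolon-separated list of category=level pairs. The sketch below shows one way such a spec could be parsed; it is an illustration only, not the removed module's actual parser.

import logging

def parse_logging_spec(spec: str) -> dict:
    """Turn 'server=debug;inference=warning' into {category: numeric level}."""
    levels = {}
    for entry in spec.split(";"):
        if not entry.strip():
            continue
        category, _, level = entry.partition("=")
        levels[category.strip()] = getattr(logging, level.strip().upper(), logging.INFO)
    return levels

assert parse_logging_spec("server=debug;inference=warning") == {
    "server": logging.DEBUG,
    "inference": logging.WARNING,
}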
117
tests/unit/server/test_resolver.py
Normal file
@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect
import sys
from typing import Any, Dict, Protocol
from unittest.mock import AsyncMock, MagicMock

import pytest
from pydantic import BaseModel, Field

from llama_stack.apis.inference import Inference
from llama_stack.distribution.datatypes import (
    Api,
    Provider,
    StackRunConfig,
)
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.routers.routers import InferenceRouter
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec


def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
    """Dynamically add protocol methods to a class by inspecting the protocol."""
    for name, value in inspect.getmembers(protocol):
        if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
            # Get the signature
            sig = inspect.signature(value)

            # Create an async function with the same signature that returns a MagicMock
            async def mock_impl(*args, **kwargs):
                return MagicMock()

            # Set the signature on our mock implementation
            mock_impl.__signature__ = sig
            # Add it to the class
            setattr(cls, name, mock_impl)


class SampleConfig(BaseModel):
    foo: str = Field(
        default="bar",
        description="foo",
    )

    @classmethod
    def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
        return {
            "foo": "baz",
        }


class SampleImpl:
    def __init__(self, config: SampleConfig, deps: Dict[Api, Any], provider_spec: ProviderSpec = None):
        self.__provider_id__ = "test_provider"
        self.__provider_spec__ = provider_spec
        self.__provider_config__ = config
        self.__deps__ = deps
        self.foo = config.foo

    async def initialize(self):
        pass


@pytest.mark.asyncio
async def test_resolve_impls_basic():
    # Create a real provider spec
    provider_spec = InlineProviderSpec(
        api=Api.inference,
        provider_type="sample",
        module="test_module",
        config_class="test_resolver.SampleConfig",
        api_dependencies=[],
    )

    # Create provider registry with our provider
    provider_registry = {Api.inference: {provider_spec.provider_type: provider_spec}}

    run_config = StackRunConfig(
        image_name="test_image",
        providers={
            "inference": [
                Provider(
                    provider_id="sample_provider",
                    provider_type="sample",
                    config=SampleConfig.sample_run_config(),
                )
            ]
        },
    )

    dist_registry = MagicMock()

    mock_module = MagicMock()
    impl = SampleImpl(SampleConfig(foo="baz"), {}, provider_spec)
    add_protocol_methods(SampleImpl, Inference)

    mock_module.get_provider_impl = AsyncMock(return_value=impl)
    sys.modules["test_module"] = mock_module

    impls = await resolve_impls(run_config, provider_registry, dist_registry)

    assert Api.inference in impls
    assert isinstance(impls[Api.inference], InferenceRouter)

    table = impls[Api.inference].routing_table
    assert isinstance(table, ModelsRoutingTable)

    impl = table.impls_by_provider_id["sample_provider"]
    assert isinstance(impl, SampleImpl)
    assert impl.foo == "baz"
    assert impl.__provider_id__ == "sample_provider"
    assert impl.__provider_spec__ == provider_spec

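Editor's note on the add_protocol_methods helper above: it relies on inspect.signature plus the __signature__ attribute so the generated async stubs advertise the protocol's real signatures to any signature-based checks in the resolver. A small standalone sketch of that idea follows; the Greeter protocol and the stub name are invented for illustration.

import inspect
from typing import Protocol

class Greeter(Protocol):
    async def greet(self, name: str, punctuation: str = "!") -> str: ...

class Impl:
    pass

for attr, member in inspect.getmembers(Greeter, inspect.isfunction):
    if attr.startswith("_"):
        continue  # skip Protocol plumbing such as __init__
    sig = inspect.signature(member)

    async def stub(*args, **kwargs):
        return None

    # inspect.signature() honors __signature__, so the stub reports the protocol's signature
    stub.__signature__ = sig
    setattr(Impl, attr, stub)

print(inspect.signature(Impl.greet))  # (self, name: str, punctuation: str = '!') -> str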
26
uv.lock
generated
@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.10"
resolution-markers = [
    "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",

@ -733,14 +734,14 @@ wheels = [

[[package]]
name = "jinja2"
version = "3.1.5"
version = "3.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "markupsafe" },
]
sdist = { url = "https://files.pythonhosted.org/packages/af/92/b3130cbbf5591acf9ade8708c365f3238046ac7cb8ccba6e81abccb0ccff/jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb", size = 244674 }
sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/bd/0f/2ba5fbcd631e3e88689309dbe978c5769e883e4b84ebfe7da30b43275c5a/jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb", size = 134596 },
    { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899 },
]

[[package]]

@ -861,7 +862,7 @@ wheels = [

[[package]]
name = "llama-stack"
version = "0.1.5"
version = "0.1.6"
source = { editable = "." }
dependencies = [
    { name = "blobfile" },

@ -870,6 +871,7 @@ dependencies = [
    { name = "huggingface-hub" },
    { name = "jsonschema" },
    { name = "llama-stack-client" },
    { name = "pillow" },
    { name = "prompt-toolkit" },
    { name = "pydantic" },
    { name = "python-dotenv" },

@ -877,6 +879,7 @@ dependencies = [
    { name = "rich" },
    { name = "setuptools" },
    { name = "termcolor" },
    { name = "tiktoken" },
]

[package.optional-dependencies]

@ -923,7 +926,6 @@ test = [
    { name = "opentelemetry-sdk" },
    { name = "pypdf" },
    { name = "sqlite-vec" },
    { name = "tiktoken" },
    { name = "torch", version = "2.6.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" },
    { name = "torch", version = "2.6.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform != 'darwin'" },
    { name = "torchvision", version = "0.21.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },

@ -942,9 +944,9 @@ requires-dist = [
    { name = "groq", marker = "extra == 'test'" },
    { name = "httpx" },
    { name = "huggingface-hub" },
    { name = "jinja2", marker = "extra == 'codegen'" },
    { name = "jinja2", marker = "extra == 'codegen'", specifier = ">=3.1.6" },
    { name = "jsonschema" },
    { name = "llama-stack-client", specifier = ">=0.1.4" },
    { name = "llama-stack-client", specifier = ">=0.1.6" },
    { name = "lm-format-enforcer", marker = "extra == 'test'", specifier = ">=0.10.9" },
    { name = "myst-parser", marker = "extra == 'docs'" },
    { name = "nbval", marker = "extra == 'dev'" },

@ -952,6 +954,7 @@ requires-dist = [
    { name = "openai", marker = "extra == 'test'" },
    { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" },
    { name = "opentelemetry-sdk", marker = "extra == 'test'" },
    { name = "pillow" },
    { name = "pre-commit", marker = "extra == 'dev'" },
    { name = "prompt-toolkit" },
    { name = "pydantic", specifier = ">=2" },

@ -977,7 +980,7 @@ requires-dist = [
    { name = "sphinxcontrib-video", marker = "extra == 'docs'" },
    { name = "sqlite-vec", marker = "extra == 'test'" },
    { name = "termcolor" },
    { name = "tiktoken", marker = "extra == 'test'" },
    { name = "tiktoken" },
    { name = "tomli", marker = "extra == 'docs'" },
    { name = "torch", marker = "extra == 'test'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu" },
    { name = "torchvision", marker = "extra == 'test'", specifier = ">=0.21.0", index = "https://download.pytorch.org/whl/cpu" },

@ -985,10 +988,11 @@ requires-dist = [
    { name = "types-setuptools", marker = "extra == 'dev'" },
    { name = "uvicorn", marker = "extra == 'dev'" },
]
provides-extras = ["dev", "test", "docs", "codegen"]

[[package]]
name = "llama-stack-client"
version = "0.1.4"
version = "0.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "anyio" },

@ -1005,9 +1009,9 @@ dependencies = [
    { name = "tqdm" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/71/6b/0c9900bcefe683b1186c272f372ac643ebd307db9efa95fa2c4418e207b3/llama_stack_client-0.1.4.tar.gz", hash = "sha256:539ff9b8c40272d4f3b023605aff9b70e66958b6bd952a04f9e9a5b2bfde00dd", size = 260958 }
sdist = { url = "https://files.pythonhosted.org/packages/b5/48/70ffdc7ab655234794e9559de9b1776b39610c09aaee8d3bc74bfbd570b4/llama_stack_client-0.1.6.tar.gz", hash = "sha256:92c6c55c3281839e690df7bfc289c36a5dde0f491574bbdb6b8b665dc3d5a16c", size = 264874 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/1f/00/56d7699354677e584610d5457baf09b0fde7ca71946532ba0f867d5e47c2/llama_stack_client-0.1.4-py3-none-any.whl", hash = "sha256:5034e7b3aac099a3ad88868b3ba1d2ba19285151ec40776ceda18e500b866a8e", size = 369327 },
    { url = "https://files.pythonhosted.org/packages/38/51/1102914f819cf4412a5c9fd3f7dcc28175608e5f01ee164885972c3ec30b/llama_stack_client-0.1.6-py3-none-any.whl", hash = "sha256:708e20630d4e97a1cb03a19b933f4da6748cc857fe170998c392cf0f30f0f4c7", size = 373941 },
]

[[package]]