Merge remote-tracking branch 'origin/main' into math_500

This commit is contained in:
Botao Chen 2025-03-09 22:40:53 -07:00
commit 599873e485
92 changed files with 23070 additions and 4653 deletions

View file

@ -0,0 +1,9 @@
---
description: General rules always applicable across the project
globs:
alwaysApply: true
---
# Style
- Comments must add value to code. Don't write filler comments explaining what you are doing next; they just add noise.
- Add a comment to clarify surprising behavior which would not be obvious. Good variable naming and clear code organization is more important.

8
.github/dependabot.yml vendored Normal file
View file

@ -0,0 +1,8 @@
# GitHub Dependabot configuration
version: 2
updates:
# Enable version updates for GitHub Actions
- package-ecosystem: "github-actions"
directory: "/" # Will use the default workflow location of `.github/workflows`
schedule:
interval: "daily"

View file

@ -310,7 +310,7 @@ jobs:
- name: "PR - Upload Test Summary"
id: pr_test_summary_upload
if: github.event_name == 'pull_request_target'
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: test-summary
path: test-summary.md
@ -320,7 +320,7 @@ jobs:
- name: "PR - Update comment"
id: pr_update_comment
if: github.event_name == 'pull_request_target'
uses: thollander/actions-comment-pull-request@v2
uses: thollander/actions-comment-pull-request@v3
with:
filePath: test-summary.md

View file

@ -12,12 +12,14 @@ on:
- main
paths:
- 'docs/**'
- 'pyproject.toml'
- '.github/workflows/update-readthedocs.yml'
pull_request:
branches:
- main
paths:
- 'docs/**'
- 'pyproject.toml'
- '.github/workflows/update-readthedocs.yml'
jobs:

1
.gitignore vendored
View file

@ -20,3 +20,4 @@ _build
docs/src
pyrightconfig.json
venv/
pytest-report.xml

0
.gitmodules vendored
View file

View file

@ -15,10 +15,6 @@ repos:
- id: end-of-file-fixer
exclude: '^(.*\.svg)$'
# Temporarily disabling this
# - id: no-commit-to-branch
# args: ['--branch=main']
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.5.4
hooks:
@ -68,12 +64,6 @@ repos:
- pydantic
pass_filenames: false
# - repo: https://github.com/jsh9/pydoclint
# rev: d88180a8632bb1602a4d81344085cf320f288c5a
# hooks:
# - id: pydoclint
# args: [--config=pyproject.toml]
# - repo: https://github.com/tcort/markdown-link-check
# rev: v3.11.2
# hooks:

File diff suppressed because it is too large Load diff

View file

@ -5,3 +5,4 @@ include llama_stack/distribution/*.sh
include llama_stack/cli/scripts/*.sh
include llama_stack/templates/*/*.yaml
include llama_stack/providers/tests/test_cases/inference/*.json
include llama_stack/models/llama/*/*.md

File diff suppressed because it is too large Load diff

View file

@ -127,15 +127,11 @@ MCP tools require:
## Adding Custom Tools
When you want to use tools other than the built-in tools, you can implement a python function and decorate it with `@client_tool`.
When you want to use tools other than the built-in tools, you just need to implement a python function with a docstring. The content of the docstring will be used to describe the tool and the parameters and passed
along to the generative model.
To define a custom tool, you need to use the `@client_tool` decorator.
```python
from llama_stack_client.lib.agents.client_tool import client_tool
# Example tool definition
@client_tool
def my_tool(input: int) -> int:
"""
Runs my awesome tool.

View file

@ -24,6 +24,56 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
- Associated with `Benchmark` resource.
## Open-benchmark Eval
### List of open-benchmarks Llama Stack support
Llama stack pre-registers several popular open-benchmarks to easily evaluate model perfomance via CLI.
The list of open-benchmarks we currently support:
- [MMLU-COT](https://arxiv.org/abs/2009.03300) (Measuring Massive Multitask Language Understanding): Benchmark designed to comprehensively evaluate the breadth and depth of a model's academic and professional understanding
- [GPQA-COT](https://arxiv.org/abs/2311.12022) (A Graduate-Level Google-Proof Q&A Benchmark): A challenging benchmark of 448 multiple-choice questions written by domain experts in biology, physics, and chemistry.
- [SimpleQA](https://openai.com/index/introducing-simpleqa/): Benchmark designed to access models to answer short, fact-seeking questions.
- [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI)]: Benchmark designed to evaluate multimodal models.
You can follow this [contributing guide](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) to add more open-benchmarks to Llama Stack
### Run evaluation on open-benchmarks via CLI
We have built-in functionality to run the supported open-benckmarks using llama-stack-client CLI
#### Spin up Llama Stack server
Spin up llama stack server with 'open-benchmark' template
```
llama stack run llama_stack/templates/open-benchmark/run.yaml
```
#### Run eval CLI
There are 3 necessary inputs to run a benchmark eval
- `list of benchmark_ids`: The list of benchmark ids to run evaluation on
- `model-id`: The model id to evaluate on
- `utput_dir`: Path to store the evaluate results
```
llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
--model_id <model id to evaluate on> \
--output_dir <directory to store the evaluate results> \
```
You can run
```
llama-stack-client eval run-benchmark help
```
to see the description of all the flags that eval run-benchmark has
In the output log, you can find the file path that has your evaluation results. Open that file and you can see you aggrgate
evaluation results over there.
## What's Next?
- Check out our Colab notebook on working examples with running benchmark evaluations [here](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb#scrollTo=mxLCsP4MvFqP).

View file

@ -34,7 +34,7 @@ We are working on adding a few more APIs to complete the application lifecycle.
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
Providers come in two flavors:

View file

@ -13,16 +13,18 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
from docutils import nodes
import tomli # Import tomli for TOML parsing
from pathlib import Path
import requests
import json
# Read version from pyproject.toml
with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
pyproject = tomli.load(f)
llama_stack_version = pyproject["project"]["version"]
pypi_url = "https://pypi.org/pypi/llama-stack/json"
version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"]
print(f"{version_tag=}")
# generate the full link including text and url here
llama_stack_version_url = f"https://github.com/meta-llama/llama-stack/releases/tag/v{llama_stack_version}"
llama_stack_version_url = f"https://github.com/meta-llama/llama-stack/releases/tag/v{version_tag}"
llama_stack_version_link = f"<a href='{llama_stack_version_url}'>release notes</a>"
project = "llama-stack"
@ -77,7 +79,7 @@ myst_enable_extensions = [
myst_substitutions = {
"docker_hub": "https://hub.docker.com/repository/docker/llamastack",
"llama_stack_version": llama_stack_version,
"llama_stack_version": version_tag,
"llama_stack_version_link": llama_stack_version_link,
}

View file

@ -51,25 +51,25 @@ The main points to consider are:
```
llama stack build -h
usage: llama stack build [-h] [--config CONFIG] [--template TEMPLATE] [--list-templates]
[--image-type {conda,container,venv}] [--image-name IMAGE_NAME] [--print-deps-only]
usage: llama stack build [-h] [--config CONFIG] [--template TEMPLATE] [--list-templates] [--image-type {conda,container,venv}] [--image-name IMAGE_NAME] [--print-deps-only] [--run]
Build a Llama stack container
options:
-h, --help show this help message and exit
--config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml.
If this argument is not provided, you will be prompted to enter information interactively
--template TEMPLATE Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates
--list-templates Show the available templates for building a Llama Stack distribution
--config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will
be prompted to enter information interactively (default: None)
--template TEMPLATE Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates (default: None)
--list-templates Show the available templates for building a Llama Stack distribution (default: False)
--image-type {conda,container,venv}
Image Type to use for the build. This can be either conda or container or venv. If not specified, will use the image type from the template config.
Image Type to use for the build. This can be either conda or container or venv. If not specified, will use the image type from the template config. (default:
conda)
--image-name IMAGE_NAME
[for image-type=conda] Name of the conda environment to use for the build. If
not specified, currently active Conda environment will be used. If no Conda
environment is active, you must specify a name.
--print-deps-only Print the dependencies for the stack only, without building the stack
[for image-type=conda|venv] Name of the conda or virtual environment to use for the build. If not specified, currently active Conda environment will be used if
found. (default: None)
--print-deps-only Print the dependencies for the stack only, without building the stack (default: False)
--run Run the stack after building using the same image type, name, and other applicable arguments (default: False)
```
After this step is complete, a file named `<name>-build.yaml` and template file `<name>-run.yaml` will be generated and saved at the output file path specified at the end of the command.
@ -212,8 +212,8 @@ Now, let's start the Llama Stack Distribution Server. You will need the YAML con
```
llama stack run -h
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE]
[--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] [--tls-certfile TLS_CERTFILE]
[--image-type {conda,container,venv}]
config
Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
@ -223,17 +223,17 @@ positional arguments:
options:
-h, --help show this help message and exit
--port PORT Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321
--port PORT Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
--image-name IMAGE_NAME
Name of the image to run. Defaults to the current conda environment
--disable-ipv6 Disable IPv6 support
--env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.
Name of the image to run. Defaults to the current conda environment (default: None)
--disable-ipv6 Disable IPv6 support (default: False)
--env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: [])
--tls-keyfile TLS_KEYFILE
Path to TLS key file for HTTPS
Path to TLS key file for HTTPS (default: None)
--tls-certfile TLS_CERTFILE
Path to TLS certificate file for HTTPS
Path to TLS certificate file for HTTPS (default: None)
--image-type {conda,container,venv}
Image Type used during the build. This can be either conda or container or venv.
Image Type used during the build. This can be either conda or container or venv. (default: conda)
```

View file

@ -17,26 +17,4 @@ $ llama-stack-client configure --endpoint https://llamastack-preview.fireworks.a
$ llama-stack-client models list
```
You will see outputs:
```
$ llama-stack-client models list
+------------------------------+------------------------------+---------------+------------+
| identifier | llama_model | provider_id | metadata |
+==============================+==============================+===============+============+
| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-70B-Instruct | Llama3.1-70B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-405B-Instruct | Llama3.1-405B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-1B-Instruct | Llama3.2-1B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-3B-Instruct | Llama3.2-3B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-11B-Vision-Instruct | Llama3.2-11B-Vision-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-90B-Vision-Instruct | Llama3.2-90B-Vision-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
```
Checkout the [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python/blob/main/docs/cli_reference.md) repo for more details on how to use the `llama-stack-client` CLI. Checkout [llama-stack-app](https://github.com/meta-llama/llama-stack-apps/tree/main) for examples applications built on top of Llama Stack.

View file

@ -68,6 +68,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
| FAISS | Single Node |
| SQLite-Vec| Single Node |
| Chroma | Hosted and Single Node |
| Milvus | Hosted and Single Node |
| Postgres (PGVector) | Hosted and Single Node |
| Weaviate | Hosted |

View file

@ -2,7 +2,7 @@
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
Providers come in two flavors:
@ -55,5 +55,6 @@ vector_io/sqlite-vec
vector_io/chromadb
vector_io/pgvector
vector_io/qdrant
vector_io/milvus
vector_io/weaviate
```

View file

@ -0,0 +1,31 @@
---
orphan: true
---
# Milvus
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database.
That means you're not limited to storing vectors in memory or in a separate service.
## Features
- Easy to use
- Fully integrated with Llama Stack
## Usage
To use Milvus in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Milvus.
3. Start storing and querying vectors.
## Installation
You can install Milvus using pymilvus:
```bash
pip install pymilvus
```
## Documentation
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.

View file

@ -275,18 +275,25 @@ response = client.scoring.score(
The following examples give the quick steps to start running evaluations using the llama-stack-client CLI.
#### Benchmark Evaluation CLI
Usage: There are 2 inputs necessary for running a benchmark eval
- `eval-task-id`: the identifier associated with the eval task. Each `Benchmark` is parametrized by
- `dataset_id`: the identifier associated with the dataset.
- `List[scoring_function_id]`: list of scoring function identifiers.
- `eval-task-config`: specifies the configuration of the model / agent to evaluate on.
There are 3 necessary input for running a benchmark eval
- `list of benchmark_ids`: The list of benchmark ids to run evaluation on
- `model-id`: The model id to evaluate on
- `utput_dir`: Path to store the evaluate results
```
llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
--model_id <model id to evaluate on> \
--output_dir <directory to store the evaluate results> \
```
You can run
```
llama-stack-client eval run-benchmark help
```
to see the description of all the flags to run benckmark eval
```
llama-stack-client eval run_benchmark <eval-task-id> \
--eval-task-config ~/benchmark_config.json \
--visualize
```
In the output log, you can find the path to the file that has your evaluation results. Open that file and you can see you aggrgate
evaluation results over there.
#### Application Evaluation CLI
@ -338,3 +345,52 @@ The `BenchmarkConfig` are user specified config to define:
}
}
```
## Open-benchmark Contributing Guide
### Create the new dataset for your new benchmark
An eval open-benchmark essentially contains 2 parts:
- `raw data`: The raw dataset associated with the benchmark. You typically need to search the original paper that introduces the benchmark and find the canonical dataset (usually hosted on huggingface)
- `prompt template`: How to ask the candidate model to generate the answer (prompt template plays a critical role to the evaluation results). Tyically, you can find the reference prompt template associated with the benchmark in benchmarks author's repo ([exmaple](https://github.com/idavidrein/gpqa/blob/main/prompts/chain_of_thought.txt)) or some other popular open source repos ([example](https://github.com/openai/simple-evals/blob/0a6e8f62e52bc5ae915f752466be3af596caf392/common.py#L14))
To create new open-benmark in llama stack, you need to combine the prompt template and the raw data into the `chat_completion_input` column in the evaluation dataset.
Llama stack enforeces the evaluate dataset schema to contain at least 3 columns:
- `chat_completion_input`: The actual input to the model to run the generation for eval
- `input_query`: The raw input from the raw dataset without the prompt template
- `expected_answer`: The ground truth for scoring functions to calcalate the score from.
You need to write a script [example convert script](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840) to convert the benchmark raw dataset to llama stack format eval dataset and update the dataset to huggingface [example benchmark dataset](https://huggingface.co/datasets/llamastack/mmmu)
### Find scoring function for your new benchmark
The purpose of scoring function is to calculate the score for each example based on candidate model generation result and expected_answer. It also aggregates the scores from all the examples and generate the final evaluate results.
Firstly, you can see if the existing [llama stack scoring functions](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/inline/scoring) can fulfill your need. If not, you need to write a new scoring function based on what benchmark author / other open source repo describe.
### Add new benchmark into template
Firstly, you need to add the evaluation dataset associated with your benchmark under `datasets` resource in the [open-benchmark](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/open-benchmark/run.yaml)
Secondly, you need to add the new benchmark you just created under the `benchmarks` resource in the same template. To add the new benchmark, you need to have
- `benchmark_id`: identifier of the benchmark
- `dataset_id`: identifier of the dataset associated with your benchmark
- `scoring_functions`: scoring function to calculate the score based on generation results and expected_answer
### Test the new benchmark
Spin up llama stack server with 'open-benchmark' templates
```
llama stack run llama_stack/templates/open-benchmark/run.yaml
```
Run eval benchmark CLI with your new benchmark id
```
llama-stack-client eval run-benchmark <new_benchmark_id> \
--model_id <model id to evaluate on> \
--output_dir <directory to store the evaluate results> \
```

View file

@ -199,7 +199,7 @@ AgentToolGroup = register_schema(
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)

View file

@ -40,7 +40,7 @@ class BatchInference(Protocol):
self,
model: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ...
@ -50,7 +50,7 @@ class BatchInference(Protocol):
self,
model: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,

View file

@ -278,14 +278,14 @@ ResponseFormat = register_schema(
class CompletionRequest(BaseModel):
model: str
content: InterleavedContent
sampling_params: Optional[SamplingParams] = SamplingParams()
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@json_schema_type
class CompletionResponse(MetricResponseMixin):
class CompletionResponse(BaseModel):
"""Response from a completion request.
:param content: The generated completion text
@ -299,7 +299,7 @@ class CompletionResponse(MetricResponseMixin):
@json_schema_type
class CompletionResponseStreamChunk(MetricResponseMixin):
class CompletionResponseStreamChunk(BaseModel):
"""A chunk of a streamed completion response.
:param delta: New content generated since last chunk. This can be one or more tokens.
@ -357,7 +357,7 @@ class ToolConfig(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Message]
sampling_params: Optional[SamplingParams] = SamplingParams()
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):
@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
"""A chunk of a streamed chat completion response.
:param event: The event containing the new content
@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin):
@json_schema_type
class ChatCompletionResponse(MetricResponseMixin):
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
"""Response from a chat completion request.
:param completion_message: The complete response message
@ -444,7 +444,7 @@ class Inference(Protocol):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -467,7 +467,7 @@ class Inference(Protocol):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,

View file

@ -13,7 +13,7 @@ from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
ROOT_DIR = Path(__file__).parent.parent
ROOT_DIR = Path(__file__).parent.parent.parent
class ModelPromptFormat(Subcommand):
@ -44,6 +44,12 @@ class ModelPromptFormat(Subcommand):
default="llama3_1",
help="Model Family (llama3_1, llama3_X, etc.)",
)
self.parser.add_argument(
"-l",
"--list",
action="store_true",
help="List all available models",
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
import importlib.resources

View file

@ -16,7 +16,7 @@ class StackBuild(Subcommand):
"build",
prog="llama stack build",
description="Build a Llama stack container",
formatter_class=argparse.RawTextHelpFormatter,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_build_command)

View file

@ -5,15 +5,15 @@
# the root directory of this source tree.
import argparse
import logging
import os
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="server")
class StackRun(Subcommand):
@ -23,7 +23,7 @@ class StackRun(Subcommand):
"run",
prog="llama stack run",
description="""Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
formatter_class=argparse.RawTextHelpFormatter,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_run_cmd)
@ -37,12 +37,13 @@ class StackRun(Subcommand):
self.parser.add_argument(
"--port",
type=int,
help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321",
help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT.",
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
)
self.parser.add_argument(
"--image-name",
type=str,
default=os.environ.get("CONDA_DEFAULT_ENV"),
help="Name of the image to run. Defaults to the current conda environment",
)
self.parser.add_argument(

View file

@ -32,7 +32,10 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.request_headers import (
preserve_headers_context_async_generator,
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
@ -160,6 +163,9 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
except StopAsyncIteration:
pass
finally:
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
return sync_generator()
@ -262,21 +268,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
# Create headers with provider data if available
headers = {}
if self.provider_data:
set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
if stream:
response = await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
response = await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
return response
# Use context manager for provider data
with request_provider_data_context(headers):
if stream:
response = await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
response = await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
"""Find the matching endpoint implementation for a given method and path.
@ -374,9 +384,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
finally:
await end_trace()
# Wrap the generator to preserve context across iterations
wrapped_gen = preserve_headers_context_async_generator(gen())
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=gen(),
content=wrapped_gen,
headers={
"Content-Type": "application/json",
},

View file

@ -4,16 +4,62 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import contextvars
import json
import logging
import threading
from typing import Any, Dict
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
_THREAD_LOCAL = threading.local()
# Context variable for request provider data
_provider_data_var = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(ContextManager):
"""Context manager for request provider data"""
def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
self.provider_data = provider_data
self.token = None
def __enter__(self):
# Save the current value and set the new one
self.token = _provider_data_var.set(self.provider_data)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Restore the previous value
if self.token is not None:
_provider_data_var.reset(self.token)
T = TypeVar("T")
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve request headers context variables across iterations.
This ensures that context variables set during generator creation are
available during each iteration of the generator, even if the original
context manager has exited.
"""
# Capture the current context value right now
context_value = _provider_data_var.get()
async def wrapper():
while True:
# Set context before each anext() call
_ = _provider_data_var.set(context_value)
try:
item = await gen.__anext__()
yield item
except StopAsyncIteration:
break
return wrapper()
class NeedsRequestProviderData:
@ -26,7 +72,7 @@ class NeedsRequestProviderData:
if not validator_class:
raise ValueError(f"Provider {provider_type} does not have a validator")
val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
val = _provider_data_var.get()
if not val:
return None
@ -36,25 +82,32 @@ class NeedsRequestProviderData:
return provider_data
except Exception as e:
log.error(f"Error parsing provider data: {e}")
return None
def set_request_provider_data(headers: Dict[str, str]):
def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
"""Parse provider data from request headers"""
keys = [
"X-LlamaStack-Provider-Data",
"x-llamastack-provider-data",
]
val = None
for key in keys:
val = headers.get(key, None)
if val:
break
if not val:
return
return None
try:
val = json.loads(val)
return json.loads(val)
except json.JSONDecodeError:
log.error("Provider data not encoded as a JSON object!", val)
return
log.error("Provider data not encoded as a JSON object!")
return None
_THREAD_LOCAL.provider_data_header_value = val
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
"""Context manager that sets request provider data from headers for the duration of the context"""
provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data)

View file

@ -7,7 +7,6 @@ import importlib
import inspect
from typing import Any, Dict, List, Set, Tuple
from llama_stack import logcat
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
@ -35,6 +34,7 @@ from llama_stack.distribution.datatypes import (
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
Api,
BenchmarksProtocolPrivate,
@ -50,6 +50,8 @@ from llama_stack.providers.datatypes import (
VectorDBsProtocolPrivate,
)
logger = get_logger(name=__name__, category="core")
class InvalidProviderError(Exception):
pass
@ -163,9 +165,7 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
# Add telemetry as an optional dependency to all auto-routed providers
optional_api_dependencies=[Api.telemetry],
deps__=([info.routing_table_api.value, Api.telemetry.value]),
deps__=[info.routing_table_api.value],
),
)
}
@ -186,7 +186,7 @@ def validate_and_prepare_providers(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)
@ -208,11 +208,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
logcat.error("core", p.deprecation_error)
logger.error(p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
logcat.warning(
"core",
logger.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
@ -246,9 +245,10 @@ def sort_providers_by_deps(
)
)
logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logcat.debug("core", f" {api_str} => {provider.provider_id}")
logger.debug(f" {api_str} => {provider.provider_id}")
logger.debug("")
return sorted_providers
@ -389,7 +389,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class

View file

@ -45,7 +45,7 @@ async def get_routing_table_impl(
return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,
@ -65,17 +65,9 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
}
api_to_deps = {
"inference": {"telemetry": Api.telemetry},
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")
api_to_dep_impl = {}
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
if dep_api in deps:
api_to_dep_impl[dep_name] = deps[dep_api]
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
impl = api_to_routers[api.value](routing_table)
await impl.initialize()
return impl

View file

@ -4,10 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
@ -22,10 +20,6 @@ from llama_stack.apis.eval import (
JobStatus,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -33,14 +27,13 @@ from llama_stack.apis.inference import (
Message,
ResponseFormat,
SamplingParams,
StopReason,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
ScoreBatchResponse,
@ -49,7 +42,6 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, Telemetry
from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
@ -59,10 +51,10 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
class VectorIORouter(VectorIO):
@ -72,15 +64,15 @@ class VectorIORouter(VectorIO):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing VectorIORouter")
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "VectorIORouter.initialize")
logger.debug("VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "VectorIORouter.shutdown")
logger.debug("VectorIORouter.shutdown")
pass
async def register_vector_db(
@ -91,10 +83,7 @@ class VectorIORouter(VectorIO):
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
) -> None:
logcat.debug(
"core",
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
)
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
@ -109,8 +98,7 @@ class VectorIORouter(VectorIO):
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
logcat.debug(
"core",
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
@ -121,7 +109,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@ -131,21 +119,16 @@ class InferenceRouter(Inference):
def __init__(
self,
routing_table: RoutingTable,
telemetry: Optional[Telemetry] = None,
) -> None:
logcat.debug("core", "Initializing InferenceRouter")
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry = telemetry
if self.telemetry:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def initialize(self) -> None:
logcat.debug("core", "InferenceRouter.initialize")
logger.debug("InferenceRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "InferenceRouter.shutdown")
logger.debug("InferenceRouter.shutdown")
pass
async def register_model(
@ -156,68 +139,16 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
logcat.debug(
"core",
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
) -> List[MetricEvent]:
span = get_current_span()
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
("total_tokens", total_tokens),
]
metric_events = []
for metric_name, value in metrics:
metric_events.append(
MetricEvent(
trace_id=span.trace_id,
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=time.time(),
unit="tokens",
attributes={
"model_id": model.model_id,
"provider_id": model.provider_id,
},
)
)
return metric_events
async def _compute_and_log_token_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricEvent]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
await self.telemetry.log_event(metric)
return metrics
async def _count_tokens(
self,
messages: List[Message] | InterleavedContent,
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
encoded = self.formatter.encode_content(messages)
return len(encoded.tokens) if encoded and encoded.tokens else 0
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = None,
@ -225,11 +156,12 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
logcat.debug(
"core",
) -> AsyncGenerator:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@ -274,59 +206,23 @@ class InferenceRouter(Inference):
tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream:
async def stream_generator():
completion_text = ""
async for chunk in await provider.chat_completion(**params):
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
return (chunk async for chunk in await provider.chat_completion(**params))
else:
response = await provider.chat_completion(**params)
completion_tokens = await self._count_tokens(
[response.completion_message],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
return await provider.chat_completion(**params)
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
logcat.debug(
"core",
if sampling_params is None:
sampling_params = SamplingParams()
logger.debug(
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
@ -343,41 +239,10 @@ class InferenceRouter(Inference):
stream=stream,
logprobs=logprobs,
)
prompt_tokens = await self._count_tokens(content)
if stream:
async def stream_generator():
completion_text = ""
async for chunk in await provider.completion(**params):
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
return (chunk async for chunk in await provider.completion(**params))
else:
response = await provider.completion(**params)
completion_tokens = await self._count_tokens(response.content)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
return await provider.completion(**params)
async def embeddings(
self,
@ -387,7 +252,7 @@ class InferenceRouter(Inference):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@ -407,15 +272,15 @@ class SafetyRouter(Safety):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing SafetyRouter")
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "SafetyRouter.initialize")
logger.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "SafetyRouter.shutdown")
logger.debug("SafetyRouter.shutdown")
pass
async def register_shield(
@ -425,7 +290,7 @@ class SafetyRouter(Safety):
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
@ -434,7 +299,7 @@ class SafetyRouter(Safety):
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
@ -447,15 +312,15 @@ class DatasetIORouter(DatasetIO):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing DatasetIORouter")
logger.debug("Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "DatasetIORouter.initialize")
logger.debug("DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "DatasetIORouter.shutdown")
logger.debug("DatasetIORouter.shutdown")
pass
async def get_rows_paginated(
@ -465,8 +330,7 @@ class DatasetIORouter(DatasetIO):
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
logcat.debug(
"core",
logger.debug(
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
)
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
@ -477,7 +341,7 @@ class DatasetIORouter(DatasetIO):
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
rows=rows,
@ -489,15 +353,15 @@ class ScoringRouter(Scoring):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ScoringRouter")
logger.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "ScoringRouter.initialize")
logger.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "ScoringRouter.shutdown")
logger.debug("ScoringRouter.shutdown")
pass
async def score_batch(
@ -506,7 +370,7 @@ class ScoringRouter(Scoring):
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@ -527,10 +391,7 @@ class ScoringRouter(Scoring):
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
logcat.debug(
"core",
f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
)
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
@ -548,15 +409,15 @@ class EvalRouter(Eval):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing EvalRouter")
logger.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "EvalRouter.initialize")
logger.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "EvalRouter.shutdown")
logger.debug("EvalRouter.shutdown")
pass
async def run_eval(
@ -564,7 +425,7 @@ class EvalRouter(Eval):
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
@ -577,7 +438,7 @@ class EvalRouter(Eval):
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
@ -590,7 +451,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> Optional[JobStatus]:
logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel(
@ -598,7 +459,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> None:
logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
@ -609,7 +470,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
@ -622,7 +483,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
@ -631,7 +492,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
@ -642,9 +503,8 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logcat.debug(
"core",
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
@ -654,7 +514,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter")
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
@ -663,15 +523,15 @@ class ToolRuntimeRouter(ToolRuntime):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.initialize")
logger.debug("ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.shutdown")
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
@ -680,5 +540,5 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

View file

@ -9,7 +9,6 @@ import asyncio
import functools
import inspect
import json
import logging
import os
import signal
import sys
@ -28,10 +27,12 @@ from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from typing_extensions import Annotated
from llama_stack import logcat
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.request_headers import (
preserve_headers_context_async_generator,
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import (
construct_stack,
@ -39,6 +40,7 @@ from llama_stack.distribution.stack import (
replace_env_vars,
validate_env_pair,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
@ -54,8 +56,7 @@ from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
logcat.init()
logger = get_logger(name=__name__, category="server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -142,23 +143,23 @@ def handle_signal(app, signum, _) -> None:
not block the current execution.
"""
signame = signal.Signals(signum).name
logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown():
try:
# Gracefully shut down implementations
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logcat.info("server", f"Shutting down {impl_name}")
logger.info("Shutting down %s", impl_name)
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logcat.warning("server", f"No shutdown method for {impl_name}")
logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError:
logcat.exception("server", f"Shutdown timeout for {impl_name}")
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except Exception as e:
logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
# Gather all running tasks
loop = asyncio.get_running_loop()
@ -172,7 +173,7 @@ def handle_signal(app, signum, _) -> None:
try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logcat.exception("server", "Timeout while waiting for tasks to finish")
logger.exception("Timeout while waiting for tasks to finish")
except asyncio.CancelledError:
pass
finally:
@ -184,9 +185,9 @@ def handle_signal(app, signum, _) -> None:
@asynccontextmanager
async def lifespan(app: FastAPI):
logcat.info("server", "Starting up")
logger.info("Starting up")
yield
logcat.info("server", "Shutting down")
logger.info("Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
@ -204,16 +205,14 @@ async def maybe_await(value):
async def sse_generator(event_gen):
try:
event_gen = await event_gen
async for item in event_gen:
async for item in await event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logcat.info("server", "Generator cancelled")
logger.info("Generator cancelled")
await event_gen.aclose()
except Exception as e:
logcat.exception("server", f"Error in sse_generator: {e}")
logcat.exception("server", f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
logger.exception("Error in sse_generator")
yield create_sse_event(
{
"error": {
@ -225,18 +224,20 @@ async def sse_generator(event_gen):
def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
set_request_provider_data(request.headers)
# Use context manager for request provider data
with request_provider_data_context(request.headers):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
else:
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
logcat.exception("server", f"Error in {func.__name__}")
raise translate_exception(e) from e
try:
if is_streaming:
gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
return StreamingResponse(gen, media_type="text/event-stream")
else:
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
logger.exception("Error executing endpoint %s", method, route)
raise translate_exception(e) from e
sig = inspect.signature(func)
@ -314,8 +315,6 @@ class ClientVersionMiddleware:
def main():
logcat.init()
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
@ -355,10 +354,10 @@ def main():
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
logcat.info("server", f"Setting CLI environment variable {key} => {value}")
logger.info(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logcat.error("server", f"Error: {str(e)}")
logger.error(f"Error: {str(e)}")
sys.exit(1)
if args.yaml_config:
@ -366,12 +365,12 @@ def main():
config_file = Path(args.yaml_config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
logcat.info("server", f"Using config file: {config_file}")
logger.info(f"Using config file: {config_file}")
elif args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
logcat.info("server", f"Using template {args.template} config file: {config_file}")
logger.info(f"Using template {args.template} config file: {config_file}")
else:
raise ValueError("Either --yaml-config or --template must be provided")
@ -379,10 +378,9 @@ def main():
config = replace_env_vars(yaml.safe_load(fp))
config = StackRunConfig(**config)
logcat.info("server", "Run configuration:")
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump())
for log_line in yaml.dump(safe_config, indent=2).split("\n"):
logcat.info("server", log_line)
logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
@ -392,7 +390,7 @@ def main():
try:
impls = asyncio.run(construct_stack(config))
except InvalidProviderError as e:
logcat.error("server", f"Error: {str(e)}")
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
@ -437,7 +435,7 @@ def main():
)
)
logcat.debug("server", f"serving APIs: {apis_to_serve}")
logger.debug(f"serving APIs: {apis_to_serve}")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
@ -464,10 +462,10 @@ def main():
"ssl_keyfile": keyfile,
"ssl_certfile": certfile,
}
logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
logcat.info("server", f"Listening on {listen_host}:{port}")
logger.info(f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,

View file

@ -11,9 +11,7 @@ import tempfile
from typing import Any, Dict, Optional
import yaml
from termcolor import colored
from llama_stack import logcat
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks
@ -39,8 +37,11 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
logger = get_logger(name=__name__, category="core")
class LlamaStack(
VectorDBs,
@ -101,9 +102,8 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process:
logcat.debug(
"core",
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
logger.debug(
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)

View file

@ -100,12 +100,15 @@ esac
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
$PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
--port "$port" \
$env_vars \
$other_args
elif [[ "$env_type" == "container" ]]; then
set -x
# Check if container command is available
if ! is_command_available $CONTAINER_BINARY; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
@ -141,8 +144,6 @@ elif [[ "$env_type" == "container" ]]; then
version_tag=$(curl -s $URL | jq -r '.info.version')
fi
set -x
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \
$env_vars \

182
llama_stack/log.py Normal file
View file

@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from logging.config import dictConfig
from typing import Dict
from rich.console import Console
from rich.errors import MarkupError
from rich.logging import RichHandler
# Default log level
DEFAULT_LOG_LEVEL = logging.INFO
# Predefined categories
CATEGORIES = [
"core",
"server",
"router",
"inference",
"agents",
"safety",
"eval",
"tools",
"client",
]
# Initialize category levels with default level
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
def parse_environment_config(env_config: str) -> Dict[str, int]:
"""
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
Parameters:
env_config (str): The value of the LLAMA_STACK_LOGGING environment variable.
Returns:
Dict[str, int]: A dictionary mapping categories to their log levels.
"""
category_levels = {}
for pair in env_config.split(";"):
if not pair.strip():
continue
try:
category, level = pair.split("=", 1)
category = category.strip().lower()
level = level.strip().upper() # Convert to uppercase for logging._nameToLevel
level_value = logging._nameToLevel.get(level)
if level_value is None:
logging.warning(
f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'."
)
continue
if category == "all":
# Apply the log level to all categories and the root logger
for cat in CATEGORIES:
category_levels[cat] = level_value
# Set the root logger's level to the specified level
category_levels["root"] = level_value
elif category in CATEGORIES:
category_levels[category] = level_value
logging.info(f"Setting '{category}' category to level '{level}'.")
else:
logging.warning(f"Unknown logging category: {category}. No changes made.")
except ValueError:
logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")
return category_levels
class CustomRichHandler(RichHandler):
def __init__(self, *args, **kwargs):
kwargs["console"] = Console(width=120)
super().__init__(*args, **kwargs)
def emit(self, record):
"""Override emit to handle markup errors gracefully."""
try:
super().emit(record)
except MarkupError:
original_markup = self.markup
self.markup = False
try:
super().emit(record)
finally:
self.markup = original_markup
def setup_logging(category_levels: Dict[str, int]) -> None:
"""
Configure logging based on the provided category log levels.
Parameters:
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
"""
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
class CategoryFilter(logging.Filter):
"""Ensure category is always present in log records."""
def filter(self, record):
if not hasattr(record, "category"):
record.category = "uncategorized" # Default to 'uncategorized' if no category found
return True
# Determine the root logger's level (default to WARNING if not specified)
root_level = category_levels.get("root", logging.WARNING)
logging_config = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"rich": {
"()": logging.Formatter,
"format": log_format,
}
},
"handlers": {
"console": {
"()": CustomRichHandler, # Use our custom handler class
"formatter": "rich",
"rich_tracebacks": True,
"show_time": False,
"show_path": False,
"markup": True,
"filters": ["category_filter"],
}
},
"filters": {
"category_filter": {
"()": CategoryFilter,
}
},
"loggers": {
category: {
"handlers": ["console"],
"level": category_levels.get(category, DEFAULT_LOG_LEVEL),
"propagate": False, # Disable propagation to root logger
}
for category in CATEGORIES
},
"root": {
"handlers": ["console"],
"level": root_level, # Set root logger's level dynamically
},
}
dictConfig(logging_config)
def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
"""
Returns a logger with the specified name and category.
If no category is provided, defaults to 'uncategorized'.
Parameters:
name (str): The name of the logger (e.g., module or filename).
category (str): The category of the logger (default 'uncategorized').
Returns:
logging.LoggerAdapter: Configured logger with category support.
"""
logger = logging.getLogger(name)
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
return logging.LoggerAdapter(logger, {"category": category})
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config:
print(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}")
_category_levels.update(parse_environment_config(env_config))
setup_logging(_category_levels)

View file

@ -1,204 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Category-based logging utility for llama-stack.
This module provides a wrapper over the standard Python logging module that supports
categorized logging with environment variable control.
Usage:
from llama_stack import logcat
logcat.info("server", "Starting up...")
logcat.debug("inference", "Processing request...")
Environment variable:
LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
Example: "server=debug;inference=warning"
"""
import datetime
import logging
import os
from typing import Dict
# ANSI color codes for terminal output
COLORS = {
"RESET": "\033[0m",
"DEBUG": "\033[36m", # Cyan
"INFO": "\033[32m", # Green
"WARNING": "\033[33m", # Yellow
"ERROR": "\033[31m", # Red
"CRITICAL": "\033[35m", # Magenta
"DIM": "\033[2m", # Dimmed text
"YELLOW_DIM": "\033[2;33m", # Dimmed yellow
}
# Static list of valid categories representing various parts of the Llama Stack
# server codebase
CATEGORIES = [
"core",
"server",
"router",
"inference",
"agents",
"safety",
"eval",
"tools",
"client",
]
_logger = logging.getLogger("llama_stack")
_logger.propagate = False
_default_level = logging.INFO
# Category-level mapping (can be modified by environment variables)
_category_levels: Dict[str, int] = {}
class TerminalStreamHandler(logging.StreamHandler):
def __init__(self, stream=None):
super().__init__(stream)
self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
def format(self, record):
record.is_tty = self.is_tty
return super().format(record)
class ColoredFormatter(logging.Formatter):
"""Custom formatter with colors and fixed-width level names"""
def format(self, record):
levelname = record.levelname
# Use only time with milliseconds, not date
timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format
file_info = f"{record.filename}:{record.lineno}"
# Get category from extra if available
category = getattr(record, "category", None)
msg = record.getMessage()
if getattr(record, "is_tty", False):
color = COLORS.get(levelname, COLORS["RESET"])
if category:
category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
formatted_msg = (
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
f"{file_info:<20} {category_formatted}{msg}"
)
else:
formatted_msg = (
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
f"{file_info:<20} {msg}"
)
else:
if category:
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
else:
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
return formatted_msg
def init(default_level: int = logging.INFO) -> None:
global _default_level, _category_levels, _logger
_default_level = default_level
_logger.setLevel(logging.DEBUG)
_logger.handlers = [] # Clear existing handlers
# Add our custom handler with the colored formatter
handler = TerminalStreamHandler()
formatter = ColoredFormatter()
handler.setFormatter(formatter)
_logger.addHandler(handler)
for category in CATEGORIES:
_category_levels[category] = default_level
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config:
for pair in env_config.split(";"):
if not pair.strip():
continue
try:
category, level = pair.split("=", 1)
category = category.strip().lower()
level = level.strip().lower()
level_value = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"warn": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}.get(level)
if level_value is None:
_logger.warning(f"Unknown log level '{level}' for category '{category}'")
continue
if category == "all":
for cat in CATEGORIES:
_category_levels[cat] = level_value
else:
if category in CATEGORIES:
_category_levels[category] = level_value
else:
_logger.warning(f"Unknown logging category: {category}")
except ValueError:
_logger.warning(f"Invalid logging configuration: {pair}")
def _should_log(level: int, category: str) -> bool:
category = category.lower()
if category not in _category_levels:
return False
category_level = _category_levels[category]
return level >= category_level
def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
if _should_log(level, category):
kwargs.setdefault("extra", {})["category"] = category.lower()
getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
def debug(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
def info(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.INFO, "info", category, msg, *args, **kwargs)
def warning(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.WARNING, "warning", category, msg, *args, **kwargs)
def warn(category: str, msg: str, *args, **kwargs) -> None:
warning(category, msg, *args, **kwargs)
def error(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.ERROR, "error", category, msg, *args, **kwargs)
def critical(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
def exception(category: str, msg: str, *args, **kwargs) -> None:
if _should_log(logging.ERROR, category):
kwargs.setdefault("extra", {})["category"] = category.lower()
_logger.exception(msg, *args, stacklevel=2, **kwargs)

View file

@ -17,7 +17,6 @@ from urllib.parse import urlparse
import httpx
from llama_stack import logcat
from llama_stack.apis.agents import (
AgentConfig,
AgentToolGroup,
@ -67,6 +66,7 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
@ -88,6 +88,8 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents")
class ChatAgent(ShieldRunnerMixin):
def __init__(
@ -609,7 +611,7 @@ class ChatAgent(ShieldRunnerMixin):
)
if n_iter >= self.agent_config.max_infer_iters:
logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.")
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point
message.stop_reason = StopReason.end_of_turn
@ -617,7 +619,7 @@ class ChatAgent(ShieldRunnerMixin):
break
if stop_reason == StopReason.out_of_tokens:
logcat.info("agents", "out of token budget, exiting.")
logger.info("out of token budget, exiting.")
yield message
break
@ -631,16 +633,10 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + output_attachments
yield message
else:
logcat.debug(
"agents",
f"completion message with EOM (iter: {n_iter}): {str(message)}",
)
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message]
else:
logcat.debug(
"agents",
f"completion message (iter: {n_iter}) from the model: {str(message)}",
)
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
# 1. Start the tool execution step and progress
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -983,7 +979,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
logcat.info("agents", f"Downloading {url} -> {filepath}")
logger.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
@ -1023,7 +1019,7 @@ async def execute_tool_call_maybe(
else:
name = name.value
logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}")
logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
result = await tool_runtime_api.invoke_tool(
tool_name=name,
kwargs={
@ -1033,7 +1029,7 @@ async def execute_tool_call_maybe(
**toolgroup_args.get(group_name, {}),
},
)
logcat.debug("agents", f"tool call {name} completed with result: {result}")
logger.info(f"tool call {name} completed with result: {result}")
return result

View file

@ -136,11 +136,13 @@ class MetaReferenceInferenceImpl(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
@ -244,7 +246,7 @@ class MetaReferenceInferenceImpl(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -253,6 +255,8 @@ class MetaReferenceInferenceImpl(
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

View file

@ -53,7 +53,7 @@ class SentenceTransformersInferenceImpl(
self,
model_id: str,
content: str,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -64,7 +64,7 @@ class SentenceTransformersInferenceImpl(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,

View file

@ -4,20 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference import supported_inference_models
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class VLLMConfig(BaseModel):
"""Configuration for the vLLM inference provider."""
"""Configuration for the vLLM inference provider.
Note that the model name is no longer part of this static configuration.
You can bind an instance of this provider to a specific model with the
``models.register()`` API call."""
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
)
tensor_parallel_size: int = Field(
default=1,
description="Number of tensor parallel replicas (number of GPUs to use).",
@ -26,32 +25,27 @@ class VLLMConfig(BaseModel):
default=4096,
description="Maximum number of tokens to generate.",
)
max_model_len: int = Field(default=4096, description="Maximum context length to use during serving.")
max_num_seqs: int = Field(default=4, description="Maximum parallel batch size for generation.")
enforce_eager: bool = Field(
default=False,
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
)
gpu_memory_utilization: float = Field(
default=0.3,
description=(
"How much GPU memory will be allocated when this provider has finished "
"loading, including memory that was already allocated before loading."
),
)
@classmethod
def sample_run_config(cls):
return {
"model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.MAX_TOKENS:4096}",
"max_model_len": "${env.MAX_MODEL_LEN:4096}",
"max_num_seqs": "${env.MAX_NUM_SEQS:4}",
"enforce_eager": "${env.ENFORCE_EAGER:False}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.7}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}",
}
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
return model

View file

@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
import vllm
from llama_stack.apis.inference import (
ChatCompletionRequest,
GrammarResponseFormat,
JsonSchemaResponseFormat,
Message,
ToolChoice,
UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
)
###############################################################################
# This file contains OpenAI compatibility code that is currently only used
# by the inline vLLM connector. Some or all of this code may be moved to a
# central location at a later date.
def _merge_context_into_content(message: Message) -> Message: # type: ignore
"""
Merge the ``context`` field of a Llama Stack ``Message`` object into
the content field for compabilitiy with OpenAI-style APIs.
Generates a content string that emulates the current behavior
of ``llama_models.llama3.api.chat_format.encode_message()``.
:param message: Message that may include ``context`` field
:returns: A version of ``message`` with any context merged into the
``content`` field.
"""
if not isinstance(message, UserMessage): # Separate type check for linter
return message
if message.context is None:
return message
return UserMessage(
role=message.role,
# Emumate llama_models.llama3.api.chat_format.encode_message()
content=message.content + "\n\n" + message.context,
context=None,
)
def _llama_stack_tools_to_openai_tools(
tools: Optional[List[ToolDefinition]] = None,
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
"""
Convert the list of available tools from Llama Stack's format to vLLM's
version of OpenAI's format.
"""
if tools is None:
return []
result = []
for t in tools:
if isinstance(t.tool_name, BuiltinTool):
raise NotImplementedError("Built-in tools not yet implemented")
if t.parameters is None:
parameters = None
else: # if t.parameters is not None
# Convert the "required" flags to a list of required params
required_params = [k for k, v in t.parameters.items() if v.required]
parameters = {
"type": "object", # Mystery value that shows up in OpenAI docs
"properties": {
k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
},
"required": required_params,
}
function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
name=t.tool_name, description=t.description, parameters=parameters
)
# Every tool definition is double-boxed in a ChatCompletionToolsParam
result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
return result
async def llama_stack_chat_completion_to_openai_chat_completion_dict(
request: ChatCompletionRequest,
) -> dict:
"""
Convert a chat completion request in Llama Stack format into an
equivalent set of arguments to pass to an OpenAI-compatible
chat completions API.
:param request: Bundled request parameters in Llama Stack format.
:returns: Dictionary of key-value pairs to use as an initializer
for a dataclass or to be converted directly to JSON and sent
over the wire.
"""
converted_messages = [
# This mystery async call makes the parent function also be async
await convert_message_to_openai_dict(_merge_context_into_content(m), download=True)
for m in request.messages
]
converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
# Llama will try to use built-in tools with no tool catalog, so don't enable
# tool choice unless at least one tool is enabled.
converted_tool_choice = "none"
if (
request.tool_config is not None
and request.tool_config.tool_choice == ToolChoice.auto
and request.tools is not None
and len(request.tools) > 0
):
converted_tool_choice = "auto"
# TODO: Figure out what to do with the tool_prompt_format argument.
# Other connectors appear to drop it quietly.
# Use Llama Stack shared code to translate sampling parameters.
sampling_options = get_sampling_options(request.sampling_params)
# get_sampling_options() translates repetition penalties to an option that
# OpenAI's APIs don't know about.
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
# For now, translate repetition penalties into a format that vLLM's broken
# API will handle correctly. Two wrongs make a right...
if "repeat_penalty" in sampling_options:
del sampling_options["repeat_penalty"]
if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0:
sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty
# Convert a single response format into four different parameters, per
# the OpenAI spec
guided_decoding_options = dict()
if request.response_format is None:
# Use defaults
pass
elif isinstance(request.response_format, JsonSchemaResponseFormat):
guided_decoding_options["guided_json"] = request.response_format.json_schema
elif isinstance(request.response_format, GrammarResponseFormat):
guided_decoding_options["guided_grammar"] = request.response_format.bnf
else:
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'")
logprob_options = dict()
if request.logprobs is not None:
logprob_options["logprobs"] = request.logprobs.top_k
# Marshall together all the arguments for a ChatCompletionRequest
request_options = {
"model": request.model,
"messages": converted_messages,
"tools": converted_tools,
"tool_choice": converted_tool_choice,
"stream": request.stream,
**sampling_options,
**guided_decoding_options,
**logprob_options,
}
return request_options

View file

@ -4,45 +4,71 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import json
import re
import uuid
from typing import AsyncGenerator, List, Optional
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
# These vLLM modules contain names that overlap with Llama Stack names, so we import
# fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextDelta,
ToolCallDelta,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
InterleavedContentItem,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
TokenLogProbs,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama import sku_list
from llama_stack.models.llama.datatypes import (
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
process_chat_completion_response,
get_stop_reason,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -50,188 +76,322 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import VLLMConfig
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
log = logging.getLogger(__name__)
# Map from Hugging Face model architecture name to appropriate tool parser.
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
# available parsers.
# TODO: Expand this list
CONFIG_TYPE_TO_TOOL_PARSER = {
"GraniteConfig": "granite",
"MllamaConfig": "llama3_json",
"LlamaConfig": "llama3_json",
}
DEFAULT_TOOL_PARSER = "pythonic"
def _random_uuid() -> str:
logger = get_logger(__name__, category="inference")
def _random_uuid_str() -> str:
return str(uuid.uuid4().hex)
def _response_format_to_guided_decoding_params(
response_format: Optional[ResponseFormat], # type: ignore
) -> vllm.sampling_params.GuidedDecodingParams:
"""
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
:param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
indicating no constraints.
:returns: The equivalent dataclass object for the low-level inference layer of vLLM.
"""
if response_format is None:
# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
# value that crashes the executor on some code paths. Use ``None`` instead.
return None
# Llama Stack currently implements fewer types of constrained decoding than vLLM does.
# Translate the types that exist and detect if Llama Stack adds new ones.
if isinstance(response_format, JsonSchemaResponseFormat):
return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
elif isinstance(response_format, GrammarResponseFormat):
# BNF grammar.
# Llama Stack uses the parse tree of the grammar, while vLLM uses the string
# representation of the grammar.
raise TypeError(
"Constrained decoding with BNF grammars is not currently implemented, because the "
"reference implementation does not implement it."
)
else:
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
def _convert_sampling_params(
sampling_params: Optional[SamplingParams],
response_format: Optional[ResponseFormat], # type: ignore
log_prob_config: Optional[LogProbConfig],
) -> vllm.SamplingParams:
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
format."""
# In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
# Stack dataclasses. These defaults are different from vLLM's defaults.
if sampling_params is None:
sampling_params = SamplingParams()
if log_prob_config is None:
log_prob_config = LogProbConfig()
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
if sampling_params.strategy.top_k == 0:
# vLLM treats "k" differently for top-k sampling
vllm_top_k = -1
else:
vllm_top_k = sampling_params.strategy.top_k
else:
vllm_top_k = -1
if isinstance(sampling_params.strategy, TopPSamplingStrategy):
vllm_top_p = sampling_params.strategy.top_p
# Llama Stack only allows temperature with top-P.
vllm_temperature = sampling_params.strategy.temperature
else:
vllm_top_p = 1.0
vllm_temperature = 0.0
# vLLM allows top-p and top-k at the same time.
vllm_sampling_params = vllm.SamplingParams.from_optional(
max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
temperature=vllm_temperature,
top_p=vllm_top_p,
top_k=vllm_top_k,
repetition_penalty=sampling_params.repetition_penalty,
guided_decoding=_response_format_to_guided_decoding_params(response_format),
logprobs=log_prob_config.top_k,
)
return vllm_sampling_params
class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
"""Inference implementation for vLLM."""
"""
vLLM-based inference model adapter for Llama Stack with support for multiple models.
Requires the configuration parameters documented in the :class:`VllmConfig2` class.
"""
config: VLLMConfig
register_helper: ModelRegistryHelper
model_ids: set[str]
resolved_model_id: str | None
engine: AsyncLLMEngine | None
chat: OpenAIServingChat | None
is_meta_llama_model: bool
def __init__(self, config: VLLMConfig):
self.config = config
logger.info(f"Config is: {self.config}")
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.formatter = ChatFormat(Tokenizer.get_instance())
# The following are initialized when paths are bound to this provider
self.resolved_model_id = None
self.model_ids = set()
self.engine = None
self.chat = None
self.is_meta_llama_model = False
async def initialize(self):
log.info("Initializing vLLM inference provider.")
###########################################################################
# METHODS INHERITED FROM IMPLICIT BASE CLASS.
# TODO: Make this class inherit from the new base class ProviderBase once that class exists.
# Disable usage stats reporting. This would be a surprising thing for most
# people to find out was on by default.
# https://docs.vllm.ai/en/latest/serving/usage_stats.html
if "VLLM_NO_USAGE_STATS" not in os.environ:
os.environ["VLLM_NO_USAGE_STATS"] = "1"
async def initialize(self) -> None:
"""
Callback that is invoked through many levels of indirection during provider class
instantiation, sometime after when __init__() is called and before any model registration
methods or methods connected to a REST API are called.
model = resolve_model(self.config.model)
if model is None:
raise ValueError(f"Unknown model {self.config.model}")
It's not clear what assumptions the class can make about the platform's initialization
state here that can't be made during __init__(), and vLLM can't be started until we know
what model it's supposed to be serving, so nothing happens here currently.
"""
pass
if model.huggingface_repo is None:
raise ValueError(f"Model {self.config.model} needs a huggingface repo")
# TODO -- there are a ton of options supported here ...
engine_args = AsyncEngineArgs(
model=model.huggingface_repo,
tokenizer=model.huggingface_repo,
tensor_parallel_size=self.config.tensor_parallel_size,
enforce_eager=self.config.enforce_eager,
gpu_memory_utilization=self.config.gpu_memory_utilization,
guided_decoding_backend="lm-format-enforcer",
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
async def shutdown(self):
"""Shut down the vLLM inference adapter."""
log.info("Shutting down vLLM inference provider.")
if self.engine:
async def shutdown(self) -> None:
logger.info(f"Shutting down inline vLLM inference provider {self}.")
if self.engine is not None:
self.engine.shutdown_background_loop()
self.engine = None
self.chat = None
self.model_ids = set()
self.resolved_model_id = None
###########################################################################
# METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE
# Note that the return type of the superclass method is WRONG
async def register_model(self, model: Model) -> Model:
"""
Callback that is called when the server associates an inference endpoint
with an inference provider.
Callback that is called when the server associates an inference endpoint with an
inference provider.
:param model: Object that encapsulates parameters necessary for identifying
a specific LLM.
:param model: Object that encapsulates parameters necessary for identifying a specific
LLM.
:returns: The input ``Model`` object. It may or may not be permissible
to change fields before returning this object.
:returns: The input ``Model`` object. It may or may not be permissible to change fields
before returning this object.
"""
log.info(f"Registering model {model.identifier} with vLLM inference provider.")
# The current version of this provided is hard-coded to serve only
# the model specified in the YAML config file.
configured_model = resolve_model(self.config.model)
registered_model = resolve_model(model.model_id)
logger.debug(f"In register_model({model})")
# First attempt to interpret the model coordinates as a Llama model name
resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
if resolved_llama_model is not None:
# Load from Hugging Face repo into default local cache dir
model_id_for_vllm = resolved_llama_model.huggingface_repo
# Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
# Don't set self.is_meta_llama_model until we actually load the model.
is_meta_llama_model = True
else: # if resolved_llama_model is None
# Not a Llama model name. Pass the model id through to vLLM's loader
model_id_for_vllm = model.provider_model_id
is_meta_llama_model = False
if self.resolved_model_id is not None:
if model_id_for_vllm != self.resolved_model_id:
raise ValueError(
f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple "
f"copies of the provider instead."
)
else:
# Model already loaded
logger.info(
f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
)
self.model_ids.add(model.model_id)
return model
logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
if is_meta_llama_model:
logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")
self.is_meta_llama_model = is_meta_llama_model
# If we get here, this is the first time registering a model.
# Preload so that the first inference request won't time out.
engine_args = AsyncEngineArgs(
model=model_id_for_vllm,
tokenizer=model_id_for_vllm,
tensor_parallel_size=self.config.tensor_parallel_size,
enforce_eager=self.config.enforce_eager,
gpu_memory_utilization=self.config.gpu_memory_utilization,
max_num_seqs=self.config.max_num_seqs,
max_model_len=self.config.max_model_len,
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
# vLLM currently requires the user to specify the tool parser manually. To choose a tool
# parser, we need to determine what model architecture is being used. For now, we infer
# that information from what config class the model uses.
low_level_model_config = self.engine.engine.get_model_config()
hf_config = low_level_model_config.hf_config
hf_config_class_name = hf_config.__class__.__name__
if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
else:
# No info -- choose a default so we can at least attempt tool
# use.
tool_parser = DEFAULT_TOOL_PARSER
logger.debug(f"{hf_config_class_name=}")
logger.debug(f"{tool_parser=}")
# Wrap the lower-level engine in an OpenAI-compatible chat API
model_config = await self.engine.get_model_config()
self.chat = OpenAIServingChat(
engine_client=self.engine,
model_config=model_config,
models=OpenAIServingModels(
engine_client=self.engine,
model_config=model_config,
base_model_paths=[
# The layer below us will only see resolved model IDs
BaseModelPath(model_id_for_vllm, model_id_for_vllm)
],
),
response_role="assistant",
request_logger=None, # Use default logging
chat_template=None, # Use default template from model checkpoint
enable_auto_tools=True,
tool_parser=tool_parser,
chat_template_content_format="auto",
)
self.resolved_model_id = model_id_for_vllm
self.model_ids.add(model.model_id)
logger.info(f"Finished preloading model: {model_id_for_vllm}")
if configured_model.core_model_id != registered_model.core_model_id:
raise ValueError(
f"Requested model '{model.identifier}' is different from "
f"model '{self.config.model}' that this provider "
f"is configured to serve"
)
return model
def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
if sampling_params is None:
return VLLMSamplingParams(max_tokens=self.config.max_tokens)
options = get_sampling_options(sampling_params)
if "repeat_penalty" in options:
options["repetition_penalty"] = options["repeat_penalty"]
del options["repeat_penalty"]
return VLLMSamplingParams(**options)
async def unregister_model(self, model_id: str) -> None:
pass
"""
Callback that is called when the server removes an inference endpoint from an inference
provider.
:param model_id: The same external ID that the higher layers of the stack previously passed
to :func:`register_model()`
"""
if model_id not in self.model_ids:
raise ValueError(
f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
)
self.model_ids.remove(model_id)
if len(self.model_ids) == 0:
# Last model was just unregistered. Shut down the connection to vLLM and free up
# resources.
# Note that this operation may cause in-flight chat completion requests on the
# now-unregistered model to return errors.
self.resolved_model_id = None
self.chat = None
self.engine.shutdown_background_loop()
self.engine = None
###########################################################################
# METHODS INHERITED FROM Inference INTERFACE
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
raise NotImplementedError("Completion not implemented for vLLM")
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
)
if not isinstance(content, str):
raise NotImplementedError("Multimodal input not currently supported")
if sampling_params is None:
sampling_params = SamplingParams()
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
assert self.engine is not None
converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
logger.debug(f"{converted_sampling_params=}")
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = await chat_completion_request_to_prompt(request, self.config.model)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
if stream:
return self._stream_chat_completion(request, results_generator)
return self._streaming_completion(content, converted_sampling_params)
else:
return await self._nonstream_chat_completion(request, results_generator)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> ChatCompletionResponse:
outputs = [o async for o in results_generator]
final_output = outputs[-1]
assert final_output is not None
outputs = final_output.outputs
finish_reason = outputs[-1].stop_reason
choice = OpenAICompatCompletionChoice(
finish_reason=finish_reason,
text="".join([output.text for output in outputs]),
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> AsyncGenerator:
tokenizer = Tokenizer.get_instance()
async def _generate_and_convert_to_openai_compat():
cur = []
async for chunk in results_generator:
if not chunk.outputs:
log.warning("Empty chunk received")
continue
output = chunk.outputs[-1]
new_tokens = output.token_ids[len(cur) :]
text = tokenizer.decode(new_tokens)
cur.extend(new_tokens)
choice = OpenAICompatCompletionChoice(
finish_reason=output.finish_reason,
text=text,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
streaming_result = None
async for _ in self._streaming_completion(content, converted_sampling_params):
pass
return CompletionResponse(
content=streaming_result.delta,
stop_reason=streaming_result.stop_reason,
logprobs=streaming_result.logprobs,
)
async def embeddings(
self,
@ -242,3 +402,391 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: List[Message], # type: ignore
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None, # type: ignore
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
sampling_params = sampling_params or SamplingParams()
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
)
# Convert to Llama Stack internal format for consistency
request = ChatCompletionRequest(
model=self.resolved_model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
if self.is_meta_llama_model:
# Bypass vLLM chat templating layer for Meta Llama models, because the
# templating layer in Llama Stack currently produces better results.
logger.debug(
f"Routing {self.resolved_model_id} chat completion through "
f"Llama Stack's templating layer instead of vLLM's."
)
return await self._chat_completion_for_meta_llama(request)
logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
logger.debug(f"Converted request: {chat_completion_request}")
vllm_result = await self.chat.create_chat_completion(chat_completion_request)
logger.debug(f"Result from vLLM: {vllm_result}")
if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
raise ValueError(f"Error from vLLM layer: {vllm_result}")
# Return type depends on "stream" argument
if stream:
if not isinstance(vllm_result, AsyncGenerator):
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
# vLLM client returns a stream of strings, which need to be parsed.
# Stream comes in the form of an async generator.
return self._convert_streaming_results(vllm_result)
else:
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
return self._convert_non_streaming_results(vllm_result)
###########################################################################
# INTERNAL METHODS
async def _streaming_completion(
self, content: str, sampling_params: vllm.SamplingParams
) -> AsyncIterator[CompletionResponseStreamChunk]:
"""Internal implementation of :func:`completion()` API for the streaming case. Assumes
that arguments have been validated upstream.
:param content: Must be a string
:param sampling_params: Paramters from public API's ``response_format``
and ``sampling_params`` arguments, converted to VLLM format
"""
# We run agains the vLLM generate() call directly instead of using the OpenAI-compatible
# layer, because doing so simplifies the code here.
# The vLLM engine requires a unique identifier for each call to generate()
request_id = _random_uuid_str()
# The vLLM generate() API is streaming-only and returns an async generator.
# The generator returns objects of type vllm.RequestOutput.
results_generator = self.engine.generate(content, sampling_params, request_id)
# Need to know the model's EOS token ID for the conversion code below.
# AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if
# we drill down to the LLMEngine inside the AsyncLLMEngine.
# Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup,
# and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup.
llm_engine = self.engine.engine
tokenizer_group = llm_engine.tokenizer
eos_token_id = tokenizer_group.tokenizer.eos_token_id
request_output: vllm.RequestOutput = None
async for request_output in results_generator:
# Check for weird inference failures
if request_output.outputs is None or len(request_output.outputs) == 0:
# This case also should never happen
raise ValueError("Inference produced empty result")
# If we get here, then request_output contains the final output of the generate() call.
# The result may include multiple alternate outputs, but Llama Stack APIs only allow
# us to return one.
output: vllm.CompletionOutput = request_output.outputs[0]
completion_string = output.text
# Convert logprobs from vLLM's format to Llama Stack's format
logprobs = [
TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
for logprob_dict in output.logprobs
]
# The final output chunk should be labeled with the reason that the overall generate()
# call completed.
logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
if output.stop_reason is None:
stop_reason = None # Still going
elif output.stop_reason == "stop":
stop_reason = StopReason.end_of_turn
elif output.stop_reason == "length":
stop_reason = StopReason.out_of_tokens
elif isinstance(output.stop_reason, int):
# If the model config specifies multiple end-of-sequence tokens, then vLLM
# will return the token ID of the EOS token in the stop_reason field.
stop_reason = StopReason.end_of_turn
else:
raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")
# vLLM's protocol outputs the stop token, then sets end of message on the next step for
# some reason.
if request_output.outputs[-1].token_ids[-1] == eos_token_id:
stop_reason = StopReason.end_of_message
yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
# Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
# provide one if it runs out of tokens.
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta=completion_string,
stop_reason=StopReason.out_of_tokens,
logprobs=logprobs,
)
def _convert_non_streaming_results(
self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
) -> ChatCompletionResponse:
"""
Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
equivalent Llama Stack object.
The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
the fields that aren't currently present in the Llama Stack dataclass.
"""
# There may be multiple responses, but we can only pass through the first one.
if len(vllm_result.choices) == 0:
raise ValueError("Don't know how to convert response object without any responses")
vllm_message = vllm_result.choices[0].message
vllm_finish_reason = vllm_result.choices[0].finish_reason
converted_message = CompletionMessage(
role=vllm_message.role,
# Llama Stack API won't accept None for content field.
content=("" if vllm_message.content is None else vllm_message.content),
stop_reason=get_stop_reason(vllm_finish_reason),
tool_calls=[
ToolCall(
call_id=t.id,
tool_name=t.function.name,
# vLLM function args come back as a string. Llama Stack expects JSON.
arguments=json.loads(t.function.arguments),
)
for t in vllm_message.tool_calls
],
)
# TODO: Convert logprobs
logger.debug(f"Converted message: {converted_message}")
return ChatCompletionResponse(
completion_message=converted_message,
)
async def _chat_completion_for_meta_llama(
self, request: ChatCompletionRequest
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
"""
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
chat template instead of using vLLM's version of that template. The Llama Stack version
of the chat template currently produces more reliable outputs.
Once vLLM's support for Meta Llama models has matured more, we should consider routing
Meta Llama requests through the vLLM chat completions API instead of using this method.
"""
formatter = ChatFormat(Tokenizer.get_instance())
# Note that this function call modifies `request` in place.
prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)
model_id = list(self.model_ids)[0] # Any model ID will do here
completion_response_or_iterator = await self.completion(
model_id=model_id,
content=prompt,
sampling_params=request.sampling_params,
response_format=request.response_format,
stream=request.stream,
logprobs=request.logprobs,
)
if request.stream:
if not isinstance(completion_response_or_iterator, AsyncIterator):
raise TypeError(
f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
)
return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)
# elsif not request.stream:
if not isinstance(completion_response_or_iterator, CompletionResponse):
raise TypeError(
f"Received unexpected result type {type(completion_response_or_iterator)}for non-streaming request."
)
completion_response: CompletionResponse = completion_response_or_iterator
raw_message = formatter.decode_assistant_message_from_content(
completion_response.content, completion_response.stop_reason
)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=completion_response.logprobs,
)
async def _chat_completion_for_meta_llama_streaming(
self, results_iterator: AsyncIterator, request: ChatCompletionRequest
) -> AsyncIterator:
"""
Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
method to keep asyncio happy.
"""
# Convert to OpenAI format, then use shared code to convert to Llama Stack format.
async def _generate_and_convert_to_openai_compat():
chunk: CompletionResponseStreamChunk # Make Pylance happy
last_text_len = 0
async for chunk in results_iterator:
if chunk.stop_reason == StopReason.end_of_turn:
finish_reason = "stop"
elif chunk.stop_reason == StopReason.end_of_message:
finish_reason = "eos"
elif chunk.stop_reason == StopReason.out_of_tokens:
finish_reason = "length"
else:
finish_reason = None
# Convert delta back to an actual delta
text_delta = chunk.delta[last_text_len:]
last_text_len = len(chunk.delta)
logger.debug(f"{text_delta=}; {finish_reason=}")
yield OpenAICompatCompletionResponse(
choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
logger.debug(f"Returning chunk: {chunk}")
yield chunk
async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
"""
Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
API into a second async iterator that returns Llama Stack objects.
:param vllm_result: Stream of strings that need to be parsed
"""
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
# those chunks and output them at the end.
# This data structure holds the current set of partial tool calls.
index_to_tool_call: Dict[int, Dict] = dict()
# The Llama Stack event stream must always start with a start event. Use an empty one to
# simplify logic below
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
stop_reason=None,
)
)
converted_stop_reason = None
async for chunk_str in vllm_result:
# Due to OpenAI compatibility, each event in the stream will start with "data: " and
# end with "\n\n".
_prefix = "data: "
_suffix = "\n\n"
if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")
# In between the "data: " and newlines is an event record
data_str = chunk_str[len(_prefix) : -len(_suffix)]
# The end of the stream is indicated with "[DONE]"
if data_str == "[DONE]":
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=converted_stop_reason,
)
)
return
# Anything that is not "[DONE]" should be a JSON record
parsed_chunk = json.loads(data_str)
logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
# The result may contain multiple completions, but Llama Stack APIs only support
# returning one.
first_choice = parsed_chunk["choices"][0]
converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
delta_record = first_choice["delta"]
if "content" in delta_record:
# Text delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text=delta_record["content"]),
stop_reason=converted_stop_reason,
)
)
elif "tool_calls" in delta_record:
# Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
# calls, so buffer until we get a "tool calls" stop reason
for tc in delta_record["tool_calls"]:
index = tc["index"]
if index not in index_to_tool_call:
# First time this tool call is showing up
index_to_tool_call[index] = dict()
tool_call = index_to_tool_call[index]
if "id" in tc:
tool_call["call_id"] = tc["id"]
if "function" in tc:
if "name" in tc["function"]:
tool_call["tool_name"] = tc["function"]["name"]
if "arguments" in tc["function"]:
# Arguments comes in as pieces of a string
if "arguments_str" not in tool_call:
tool_call["arguments_str"] = ""
tool_call["arguments_str"] += tc["function"]["arguments"]
else:
raise ValueError(f"Don't know how to parse event delta: {delta_record}")
if first_choice["finish_reason"] == "tool_calls":
# Special OpenAI code for "tool calls complete".
# Output the buffered tool calls. Llama Stack requires a separate event per tool
# call.
for tool_call_record in index_to_tool_call.values():
# Arguments come in as a string. Parse the completed string.
tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
del tool_call_record["arguments_str"]
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
stop_reason=converted_stop_reason,
)
)
# If we get here, we've lost the connection with the vLLM event stream before it ended
# normally.
raise ValueError("vLLM event stream ended without [DONE] message.")

View file

@ -73,7 +73,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
self.config = config
self.datasetio_api = deps.get(Api.datasetio)
self.meter = None
resource = Resource.create(
{
@ -172,8 +171,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
if self.meter is None:
return
if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import io
import json
@ -99,7 +100,7 @@ class FaissIndex(EmbeddingIndex):
await self._save_index()
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
distances, indices = self.index.search(embedding.reshape(1, -1).astype(np.float32), k)
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
chunks = []
scores = []

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import MilvusVectorIOConfig
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class MilvusVectorIOConfig(BaseModel):
db_path: str
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {"db_path": "${env.MILVUS_DB_PATH}"}

View file

@ -110,4 +110,22 @@ def available_providers() -> List[ProviderSpec]:
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="milvus",
pip_packages=["pymilvus"],
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
),
api_dependencies=[Api.inference],
),
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::milvus",
pip_packages=["pymilvus"],
module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
),
]

View file

@ -72,7 +72,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -83,7 +83,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -92,6 +92,8 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,

View file

@ -72,11 +72,13 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
@ -112,7 +114,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
@ -121,6 +123,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,

View file

@ -71,7 +71,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
self,
model: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -82,7 +82,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -91,6 +91,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
request = ChatCompletionRequest(
model=model,
messages=messages,

View file

@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -55,6 +55,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
@ -68,8 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
pass
def _get_api_key(self) -> str:
if self.config.api_key is not None:
return self.config.api_key.get_secret_value()
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
return config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:
@ -86,11 +89,13 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
@ -157,7 +162,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
@ -166,6 +171,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
@ -233,7 +240,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
logcat.debug("inference", f"params to fireworks: {params}")
logger.debug(f"params to fireworks: {params}")
return params
async def embeddings(

View file

@ -93,11 +93,13 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
if sampling_params is None:
sampling_params = SamplingParams()
if content_has_media(content):
raise NotImplementedError("Media is not supported")
@ -188,7 +190,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -197,6 +199,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if sampling_params is None:
sampling_params = SamplingParams()
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)

View file

@ -4,13 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator, List, Optional, Union
import httpx
from ollama import AsyncClient
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
@ -35,6 +34,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
@ -59,7 +59,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import model_entries
log = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference")
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
@ -72,7 +72,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncClient(host=self.url)
async def initialize(self) -> None:
log.info(f"checking connectivity to Ollama at `{self.url}`...")
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
try:
await self.client.ps()
except httpx.ConnectError as e:
@ -90,11 +90,13 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
@ -145,7 +147,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -154,6 +156,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
@ -210,7 +214,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
"options": sampling_options,
"stream": request.stream,
}
logcat.debug("inference", f"params to ollama: {params}")
logger.debug(f"params to ollama: {params}")
return params
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
@ -286,7 +291,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
if model.model_type == ModelType.embedding:
log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
await self.client.pull(model.provider_resource_id)
response = await self.client.list()
else:

View file

@ -81,11 +81,13 @@ class PassthroughInferenceAdapter(Inference):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
client = self._get_client()
model = await self.model_store.get_model(model_id)
@ -107,7 +109,7 @@ class PassthroughInferenceAdapter(Inference):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
@ -116,6 +118,8 @@ class PassthroughInferenceAdapter(Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
client = self._get_client()
model = await self.model_store.get_model(model_id)

View file

@ -54,7 +54,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
self,
model: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -65,7 +65,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -74,6 +74,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
request = ChatCompletionRequest(
model=model,
messages=messages,

View file

@ -74,7 +74,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -85,7 +85,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -94,6 +94,8 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
tool_config: Optional[ToolConfig] = None,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(

View file

@ -98,11 +98,13 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
@ -201,7 +203,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
@ -210,6 +212,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,

View file

@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
from together import Together
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
@ -32,9 +31,8 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
@ -54,6 +52,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
@ -70,11 +70,13 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
@ -91,8 +93,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
def _get_client(self) -> Together:
together_api_key = None
if self.config.api_key is not None:
together_api_key = self.config.api_key.get_secret_value()
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
together_api_key = config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
@ -151,7 +154,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
@ -160,6 +163,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
@ -220,7 +225,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
"stream": request.stream,
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
}
logcat.debug("inference", f"params to together: {params}")
logger.debug(f"params to together: {params}")
return params
async def embeddings(

View file

@ -7,7 +7,7 @@ import json
import logging
from typing import AsyncGenerator, List, Optional, Union
from openai import OpenAI
from openai import AsyncOpenAI
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
@ -229,7 +229,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def initialize(self) -> None:
log.info(f"Initializing VLLM client with base_url={self.config.url}")
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
self.client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
async def shutdown(self) -> None:
pass
@ -241,11 +241,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
@ -264,7 +266,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -273,6 +275,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
@ -296,10 +300,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
return await self._nonstream_chat_completion(request, self.client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> ChatCompletionResponse:
params = await self._get_params(request)
r = client.chat.completions.create(**params)
r = await client.chat.completions.create(**params)
choice = r.choices[0]
result = ChatCompletionResponse(
completion_message=CompletionMessage(
@ -311,17 +315,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
)
return result
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
params = await self._get_params(request)
# TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
# generator so this wrapper is not necessary?
async def _to_async_generator():
s = client.chat.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
stream = await client.chat.completions.create(**params)
if len(request.tools) > 0:
res = _process_vllm_chat_completion_stream_response(stream)
else:
@ -331,26 +328,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = self.client.completions.create(**params)
r = await self.client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
# Wrapper for async generator similar
async def _to_async_generator():
stream = self.client.completions.create(**params)
for chunk in stream:
yield chunk
stream = _to_async_generator()
stream = await self.client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
res = self.client.models.list()
available_models = [m.id for m in res]
res = await self.client.models.list()
available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(
f"Model {model.provider_resource_id} is not being served by vLLM. "
@ -406,7 +397,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
assert model.metadata.get("embedding_dimension")
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
response = self.client.embeddings.create(
response = await self.client.embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
**kwargs,

View file

@ -7,7 +7,7 @@
import json
from typing import Any, Dict, List, Optional
import requests
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
@ -31,7 +31,7 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
async def register_tool(self, tool: Tool) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
@ -77,12 +77,13 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
"q": kwargs["query"],
}
response = requests.get(
url=self.url,
params=params,
headers=headers,
)
response.raise_for_status()
async with httpx.AsyncClient() as client:
response = await client.get(
url=self.url,
params=params,
headers=headers,
)
response.raise_for_status()
return ToolInvocationResult(content=json.dumps(self._clean_response(response.json())))

View file

@ -6,7 +6,7 @@
from typing import Any, Dict, List, Optional
import requests
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
@ -30,7 +30,7 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
async def register_tool(self, tool: Tool) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
@ -74,8 +74,13 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
"Accept": "application/json",
}
payload = {"q": kwargs["query"]}
response = requests.get(url=url, params=payload, headers=headers)
response.raise_for_status()
async with httpx.AsyncClient() as client:
response = await client.get(
url=url,
params=payload,
headers=headers,
)
response.raise_for_status()
results = self._clean_brave_response(response.json())
content_items = "\n".join([str(result) for result in results])
return ToolInvocationResult(

View file

@ -7,7 +7,7 @@
import json
from typing import Any, Dict, List, Optional
import requests
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
@ -30,7 +30,7 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
async def register_tool(self, tool: Tool) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
@ -66,10 +66,12 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()
response = requests.post(
"https://api.tavily.com/search",
json={"api_key": api_key, "query": kwargs["query"]},
)
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.tavily.com/search",
json={"api_key": api_key, "query": kwargs["query"]},
)
response.raise_for_status()
return ToolInvocationResult(content=json.dumps(self._clean_tavily_response(response.json())))

View file

@ -7,7 +7,7 @@
import json
from typing import Any, Dict, List, Optional
import requests
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
@ -31,7 +31,7 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
async def register_tool(self, tool: Tool) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
@ -73,11 +73,9 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
"format": "plaintext",
"output": "json",
}
response = requests.get(
self.url,
params=params,
)
async with httpx.AsyncClient() as client:
response = await client.get(params=params, url=self.url)
response.raise_for_status()
return ToolInvocationResult(content=json.dumps(self._clean_wolfram_alpha_response(response.json())))
def _clean_wolfram_alpha_response(self, wa_response):

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import MilvusVectorIOConfig
async def get_adapter_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .milvus import MilvusVectorIOAdapter
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class MilvusVectorIOConfig(BaseModel):
uri: str
token: Optional[str] = None
consistency_level: str = "Strong"
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}

View file

@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import hashlib
import logging
import os
import uuid
from typing import Any, Dict, List, Optional, Union
from numpy.typing import NDArray
from pymilvus import MilvusClient
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = logging.getLogger(__name__)
class MilvusIndex(EmbeddingIndex):
def __init__(self, client: MilvusClient, collection_name: str, consistency_level="Strong"):
self.client = client
self.collection_name = collection_name.replace("-", "_")
self.consistency_level = consistency_level
async def delete(self):
if self.client.has_collection(self.collection_name):
self.client.drop_collection(collection_name=self.collection_name)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not self.client.has_collection(self.collection_name):
self.client.create_collection(
self.collection_name,
dimension=len(embeddings[0]),
auto_id=True,
consistency_level=self.consistency_level,
)
data = []
for chunk, embedding in zip(chunks, embeddings, strict=False):
chunk_id = generate_chunk_id(chunk.metadata["document_id"], chunk.content)
data.append(
{
"chunk_id": chunk_id,
"vector": embedding,
"chunk_content": chunk.model_dump(),
}
)
try:
self.client.insert(
self.collection_name,
data=data,
)
except Exception as e:
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = self.client.search(
collection_name=self.collection_name,
data=[embedding],
limit=k,
output_fields=["*"],
search_params={"params": {"radius": score_threshold}},
)
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
scores = [res["distance"] for res in search_res[0]]
return QueryChunksResponse(chunks=chunks, scores=scores)
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(
self, config: Union[RemoteMilvusVectorIOConfig, InlineMilvusVectorIOConfig], inference_api: Api.inference
) -> None:
self.config = config
self.cache = {}
self.client = None
self.inference_api = inference_api
async def initialize(self) -> None:
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)
async def shutdown(self) -> None:
self.client.close()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
if isinstance(self.config, RemoteMilvusVectorIOConfig):
consistency_level = self.config.consistency_level
else:
consistency_level = "Strong"
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found")
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
return index
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id in self.cache:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def insert_chunks(
self,
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
# TODO: refactor this generate_chunk_id along with the `sqlite-vec` implementation into a separate utils file

View file

@ -8,7 +8,6 @@ from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
import litellm
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -47,6 +47,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
logger = get_logger(name=__name__, category="inference")
class LiteLLMOpenAIMixin(
ModelRegistryHelper,
@ -74,7 +76,7 @@ class LiteLLMOpenAIMixin(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -85,7 +87,7 @@ class LiteLLMOpenAIMixin(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
@ -94,6 +96,8 @@ class LiteLLMOpenAIMixin(
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
@ -107,8 +111,7 @@ class LiteLLMOpenAIMixin(
)
params = await self._get_params(request)
logcat.debug("inference", f"params to litellm (openai compat): {params}")
logger.debug(f"params to litellm (openai compat): {params}")
# unfortunately, we need to use synchronous litellm.completion here because litellm
# caches various httpx.client objects in a non-eventloop aware manner
response = litellm.completion(**params)

View file

@ -8,14 +8,12 @@ import asyncio
import base64
import io
import json
import logging
import re
from typing import List, Optional, Tuple, Union
import httpx
from PIL import Image as PIL_Image
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
@ -34,6 +32,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
UserMessage,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
ModelFamily,
RawContent,
@ -58,7 +57,7 @@ from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference import supported_inference_models
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
@ -464,7 +463,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
llama_model = resolve_model(model)
if llama_model is None:
logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format")
log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
return ToolPromptFormat.json
if llama_model.model_family == ModelFamily.llama3_1 or (

View file

@ -8,9 +8,11 @@ import logging
from datetime import datetime
from typing import List, Optional
from pymongo import MongoClient
from pymongo import AsyncMongoClient
from llama_stack.providers.utils.kvstore import KVStore, MongoDBKVStoreConfig
from llama_stack.providers.utils.kvstore import KVStore
from ..config import MongoDBKVStoreConfig
log = logging.getLogger(__name__)
@ -30,7 +32,7 @@ class MongoDBKVStoreImpl(KVStore):
"password": self.config.password,
}
conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
self.conn = MongoClient(**conn_creds)
self.conn = AsyncMongoClient(**conn_creds)
self.collection = self.conn[self.config.db][self.config.collection_name]
except Exception as e:
log.exception("Could not connect to MongoDB database server")
@ -44,17 +46,17 @@ class MongoDBKVStoreImpl(KVStore):
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
key = self._namespaced_key(key)
update_query = {"$set": {"value": value, "expiration": expiration}}
self.collection.update_one({"key": key}, update_query, upsert=True)
await self.collection.update_one({"key": key}, update_query, upsert=True)
async def get(self, key: str) -> Optional[str]:
key = self._namespaced_key(key)
query = {"key": key}
result = self.collection.find_one(query, {"value": 1, "_id": 0})
result = await self.collection.find_one(query, {"value": 1, "_id": 0})
return result["value"] if result else None
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
self.collection.delete_one({"key": key})
await self.collection.delete_one({"key": key})
async def range(self, start_key: str, end_key: str) -> List[str]:
start_key = self._namespaced_key(start_key)
@ -63,4 +65,7 @@ class MongoDBKVStoreImpl(KVStore):
"key": {"$gte": start_key, "$lt": end_key},
}
cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
return [doc["value"] for doc in cursor]
result = []
async for doc in cursor:
result.append(doc["value"])
return result

View file

@ -12,11 +12,9 @@ from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib.parse import unquote
import chardet
import httpx
import numpy as np
from numpy.typing import NDArray
from pypdf import PdfReader
from llama_stack.apis.common.content_types import (
URL,
@ -38,6 +36,8 @@ log = logging.getLogger(__name__)
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
pdf_bytes = io.BytesIO(data)
from pypdf import PdfReader
pdf_reader = PdfReader(pdf_bytes)
return "\n".join([page.extract_text() for page in pdf_reader.pages])
@ -75,6 +75,8 @@ def content_from_data(data_url: str) -> str:
encoding = parts["encoding"]
if not encoding:
import chardet
detected = chardet.detect(data)
encoding = detected["encoding"]

View file

@ -0,0 +1,36 @@
version: '2'
distribution_spec:
description: Distribution for running open benchmarks
providers:
inference:
- remote::openai
- remote::anthropic
- remote::gemini
- remote::groq
- remote::together
vector_io:
- inline::sqlite-vec
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime
- remote::model-context-protocol
image_type: conda

View file

@ -0,0 +1,212 @@
version: '2'
image_name: open-benchmark
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: openai
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:}
- provider_id: anthropic
provider_type: remote::anthropic
config:
api_key: ${env.ANTHROPIC_API_KEY:}
- provider_id: gemini
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:}
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY:}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY}
vector_io:
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:}
- provider_id: ${env.ENABLE_PGVECTOR+pgvector}
provider_type: remote::pgvector
config:
host: ${env.PGVECTOR_HOST:localhost}
port: ${env.PGVECTOR_PORT:5432}
db: ${env.PGVECTOR_DB:}
user: ${env.PGVECTOR_USER:}
password: ${env.PGVECTOR_PASSWORD:}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config: {}
- provider_id: localfs
provider_type: inline::localfs
config: {}
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db
models:
- metadata: {}
model_id: openai/gpt-4o
provider_id: openai
provider_model_id: openai/gpt-4o
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct
provider_id: together
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
model_type: llm
- metadata: {}
model_id: anthropic/claude-3-5-sonnet-latest
provider_id: anthropic
provider_model_id: anthropic/claude-3-5-sonnet-latest
model_type: llm
- metadata: {}
model_id: gemini/gemini-1.5-flash
provider_id: gemini
provider_model_id: gemini/gemini-1.5-flash
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: groq
provider_model_id: groq/llama-3.3-70b-versatile
model_type: llm
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets:
- dataset_id: simpleqa
provider_id: huggingface
url:
uri: https://huggingface.co/datasets/llamastack/simpleqa
metadata:
path: llamastack/simpleqa
name:
split: train
dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
- dataset_id: mmlu_cot
provider_id: huggingface
url:
uri: https://huggingface.co/datasets/llamastack/mmlu_cot
metadata:
path: llamastack/mmlu_cot
name: all
split: test
dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
- dataset_id: gpqa_cot
provider_id: huggingface
url:
uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
metadata:
path: llamastack/gpqa_0shot_cot
name: gpqa_main
split: train
dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
scoring_fns: []
benchmarks:
- benchmark_id: meta-reference-simpleqa
dataset_id: simpleqa
scoring_functions: ["llm-as-judge::405b-simpleqa"]
- benchmark_id: meta-reference-mmlu-cot
dataset_id: mmlu_cot
scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
- benchmark_id: meta-reference-gpqa-cot
dataset_id: gpqa_cot
scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server:
port: 8321

View file

@ -15,11 +15,12 @@ providers:
- provider_id: vllm
provider_type: inline::vllm
config:
model: ${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}
tensor_parallel_size: ${env.TENSOR_PARALLEL_SIZE:1}
max_tokens: ${env.MAX_TOKENS:4096}
max_model_len: ${env.MAX_MODEL_LEN:4096}
max_num_seqs: ${env.MAX_NUM_SEQS:4}
enforce_eager: ${env.ENFORCE_EAGER:False}
gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:0.7}
gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:0.3}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "llama_stack"
version = "0.1.5"
version = "0.1.6"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack"
readme = "README.md"
@ -26,7 +26,7 @@ dependencies = [
"httpx",
"huggingface-hub",
"jsonschema",
"llama-stack-client>=0.1.4",
"llama-stack-client>=0.1.6",
"prompt-toolkit",
"python-dotenv",
"pydantic>=2",
@ -34,6 +34,8 @@ dependencies = [
"rich",
"setuptools",
"termcolor",
"tiktoken",
"pillow",
]
[project.optional-dependencies]
@ -63,7 +65,6 @@ test = [
"groq",
"opentelemetry-sdk",
"opentelemetry-exporter-otlp-proto-http",
"tiktoken",
"chardet",
"pypdf",
]
@ -79,7 +80,7 @@ docs = [
"sphinxcontrib.mermaid",
"tomli",
]
codegen = ["rich", "pydantic", "jinja2"]
codegen = ["rich", "pydantic", "jinja2>=3.1.6"]
[project.urls]
Homepage = "https://github.com/meta-llama/llama-stack"
@ -136,8 +137,6 @@ ignore = [
# These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
"C901", # Complexity of the function is too high
# these ignores are from flake8-bugbear; please fix!
"B008",
]
[tool.mypy]
@ -153,7 +152,6 @@ exclude = [
"llama_stack/distribution",
"llama_stack/apis",
"llama_stack/cli",
"llama_stack/logcat.py",
"llama_stack/models",
"llama_stack/strong_typing",
"llama_stack/templates",
@ -165,5 +163,5 @@ module = ["yaml", "fire"]
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "llama_stack.distribution.resolver"
follow_imports = "normal" # This will force type checking on this module
module = ["llama_stack.distribution.resolver", "llama_stack.log"]
follow_imports = "normal" # This will force type checking on this module

View file

@ -20,13 +20,14 @@ huggingface-hub==0.29.0
idna==3.10
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
llama-stack-client==0.1.4
llama-stack-client==0.1.6
lxml==5.3.1
markdown-it-py==3.0.0
mdurl==0.1.2
numpy==2.2.3
packaging==24.2
pandas==2.2.3
pillow==11.1.0
prompt-toolkit==3.0.50
pyaml==25.1.0
pycryptodomex==3.21.0
@ -38,6 +39,7 @@ python-dotenv==1.0.1
pytz==2025.1
pyyaml==6.0.2
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
rich==13.9.4
rpds-py==0.22.3
@ -45,6 +47,7 @@ setuptools==75.8.0
six==1.17.0
sniffio==1.3.1
termcolor==2.5.0
tiktoken==0.9.0
tqdm==4.67.1
typing-extensions==4.12.2
tzdata==2025.1

75
scripts/gen-changelog.py Normal file
View file

@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import requests
def get_all_releases(token):
url = f"https://api.github.com/repos/meta-llama/llama-stack/releases"
headers = {"Accept": "application/vnd.github.v3+json"}
if token:
headers["Authorization"] = f"token {token}"
response = requests.get(url, headers=headers)
if response.status_code == 200:
return response.json()
else:
raise Exception(
f"Error fetching releases: {response.status_code}, {response.text}"
)
def clean_release_body(body):
"""Remove '## All changes' sections from release notes."""
lines = body.split("\n")
cleaned_lines = []
skip_mode = False
for line in lines:
if line.strip() in [
"## All changes",
"### What's Changed",
"## What's Changed",
"## New Contributors",
]:
skip_mode = True
elif skip_mode and line.startswith("##"):
# Found a new section, stop skipping
skip_mode = False
cleaned_lines.append(line)
elif not skip_mode:
cleaned_lines.append(line)
return "\n".join(cleaned_lines)
def merge_release_notes(output_file, token=None):
releases = get_all_releases(token)
with open(output_file, "w", encoding="utf-8") as md_file:
md_file.write(f"# Changelog\n\n")
for release in releases:
md_file.write(f"# {release['tag_name']}\n")
md_file.write(f"Published on: {release['published_at']}\n\n")
# Clean the release body to remove "## All changes" sections
cleaned_body = clean_release_body(release["body"])
md_file.write(f"{cleaned_body}\n\n")
md_file.write("---\n\n")
print(f"Merged release notes saved to {output_file}")
if __name__ == "__main__":
OUTPUT_FILE = "CHANGELOG.md"
TOKEN = os.getenv("GITHUB_TOKEN")
merge_release_notes(OUTPUT_FILE, TOKEN)

View file

@ -3,4 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

View file

@ -3,4 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

View file

@ -9,7 +9,6 @@ from uuid import uuid4
import pytest
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.client_tool import client_tool
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
from llama_stack_client.types.memory_insert_params import Document
@ -23,7 +22,6 @@ from llama_stack.apis.agents.agents import (
)
@client_tool
def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
"""
Returns the boiling point of a liquid in Celcius or Fahrenheit
@ -41,7 +39,6 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
return -1
@client_tool
def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
"""
Returns the boiling point of a liquid in Celcius or Fahrenheit
@ -276,7 +273,6 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
agent_config = {
**agent_config,
"tools": ["builtin::websearch", client_tool],
"client_tools": [client_tool.get_tool_definition()],
}
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
@ -571,7 +567,10 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
assert expected_kw in response.output_message.content.lower()
@pytest.mark.parametrize("client_tools", [(get_boiling_point, False), (get_boiling_point_with_metadata, True)])
@pytest.mark.parametrize(
"client_tools",
[(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
)
def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools):
client_tool, expectes_metadata = client_tools
agent_config = {

View file

@ -42,7 +42,7 @@ def provider_data():
for key, value in keymap.items():
if os.environ.get(key):
provider_data[value] = os.environ[key]
return provider_data if len(provider_data) > 0 else None
return provider_data
@pytest.fixture(scope="session")

View file

@ -121,6 +121,9 @@ class RecordableMock:
# Replace temporary file paths created by tempfile.mkdtemp()
key = re.sub(r"/var/folders/[^,'\"\s]+", "<TEMP_FILE>", key)
# Replace /tmp/ paths which are also commonly used for temporary files
key = re.sub(r"/tmp/[^,'\"\s]+", "<TEMP_FILE>", key)
return key
def _save_cache(self):

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -3,4 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

View file

@ -27,6 +27,7 @@ def base64_image_url(base64_image_data, image_path):
return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"
@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
message = {
"role": "user",
@ -55,6 +56,7 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
def test_image_chat_completion_streaming(client_with_models, vision_model_id):
message = {
"role": "user",

View file

@ -3,4 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

View file

@ -81,8 +81,6 @@ def test_scoring_functions_register(
def test_scoring_score(llama_stack_client):
register_dataset(llama_stack_client, for_rag=True)
response = llama_stack_client.datasets.list()
assert len(response) == 1
# scoring individual rows
rows = llama_stack_client.datasetio.get_rows_paginated(
@ -119,8 +117,6 @@ def test_scoring_score(llama_stack_client):
def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge_prompt_template, judge_model_id):
register_dataset(llama_stack_client, for_rag=True)
response = llama_stack_client.datasets.list()
assert len(response) == 1
# scoring individual rows
rows = llama_stack_client.datasetio.get_rows_paginated(

View file

@ -3,4 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

View file

@ -4,6 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import logging
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict
from unittest.mock import AsyncMock, patch
import pytest
@ -39,9 +46,41 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
# -v -s --tb=short --disable-warnings
class MockInferenceAdapterWithSleep:
def __init__(self, sleep_time: int, response: Dict[str, Any]):
self.httpd = None
class DelayedRequestHandler(BaseHTTPRequestHandler):
# ruff: noqa: N802
def do_POST(self):
time.sleep(sleep_time)
self.send_response(code=200)
self.end_headers()
self.wfile.write(json.dumps(response).encode("utf-8"))
self.request_handler = DelayedRequestHandler
def __enter__(self):
httpd = HTTPServer(("", 0), self.request_handler)
self.httpd = httpd
host, port = httpd.server_address
httpd_thread = threading.Thread(target=httpd.serve_forever)
httpd_thread.daemon = True # stop server if this thread terminates
httpd_thread.start()
config = VLLMInferenceAdapterConfig(url=f"http://{host}:{port}")
inference_adapter = VLLMInferenceAdapter(config)
return inference_adapter
def __exit__(self, _exc_type, _exc_value, _traceback):
if self.httpd:
self.httpd.shutdown()
self.httpd.server_close()
@pytest.fixture(scope="module")
def mock_openai_models_list():
with patch("openai.resources.models.Models.list") as mock_list:
with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list:
yield mock_list
@ -56,10 +95,10 @@ async def vllm_inference_adapter():
@pytest.mark.asyncio
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
mock_openai_models = [
OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
]
mock_openai_models_list.return_value = mock_openai_models
async def mock_openai_models():
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
mock_openai_models_list.return_value = mock_openai_models()
foo_model = Model(identifier="foo", provider_resource_id="foo", provider_id="vllm-inference")
@ -141,3 +180,55 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 0
def test_chat_completion_doesnt_block_event_loop(caplog):
loop = asyncio.new_event_loop()
loop.set_debug(True)
caplog.set_level(logging.WARNING)
# Log when event loop is blocked for more than 100ms
loop.slow_callback_duration = 0.1
# Sleep for 500ms in our delayed http response
sleep_time = 0.5
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
mock_response = {
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1,
"modle": "mock-model",
"choices": [
{
"message": {"content": ""},
"logprobs": None,
"finish_reason": "stop",
"index": 0,
}
],
}
async def do_chat_completion():
await inference_adapter.chat_completion(
"mock-model",
[],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
with MockInferenceAdapterWithSleep(sleep_time, mock_response) as inference_adapter:
inference_adapter.model_store = AsyncMock()
inference_adapter.model_store.get_model.return_value = mock_model
loop.run_until_complete(inference_adapter.initialize())
# Clear the logs so far and run the actual chat completion we care about
caplog.clear()
loop.run_until_complete(do_chat_completion())
# Ensure we don't have any asyncio warnings in the captured log
# records from our chat completion call. A message gets logged
# here any time we exceed the slow_callback_duration configured
# above.
asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
assert not asyncio_warnings

View file

@ -1,88 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
import logging
import os
import unittest
from llama_stack import logcat
class TestLogcat(unittest.TestCase):
def setUp(self):
self.original_env = os.environ.get("LLAMA_STACK_LOGGING")
self.log_output = io.StringIO()
self._init_logcat()
def tearDown(self):
if self.original_env is not None:
os.environ["LLAMA_STACK_LOGGING"] = self.original_env
else:
os.environ.pop("LLAMA_STACK_LOGGING", None)
def _init_logcat(self):
logcat.init(default_level=logging.DEBUG)
self.handler = logging.StreamHandler(self.log_output)
self.handler.setFormatter(logging.Formatter("[%(category)s] %(message)s"))
logcat._logger.handlers.clear()
logcat._logger.addHandler(self.handler)
def test_basic_logging(self):
logcat.info("server", "Info message")
logcat.warning("server", "Warning message")
logcat.error("server", "Error message")
output = self.log_output.getvalue()
self.assertIn("[server] Info message", output)
self.assertIn("[server] Warning message", output)
self.assertIn("[server] Error message", output)
def test_different_categories(self):
# Log messages with different categories
logcat.info("server", "Server message")
logcat.info("inference", "Inference message")
logcat.info("router", "Router message")
output = self.log_output.getvalue()
self.assertIn("[server] Server message", output)
self.assertIn("[inference] Inference message", output)
self.assertIn("[router] Router message", output)
def test_env_var_control(self):
os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"
self._init_logcat()
# These should be visible based on the environment settings
logcat.debug("server", "Server debug message")
logcat.info("server", "Server info message")
logcat.warning("inference", "Inference warning message")
logcat.error("inference", "Inference error message")
# These should be filtered out based on the environment settings
logcat.debug("inference", "Inference debug message")
logcat.info("inference", "Inference info message")
output = self.log_output.getvalue()
self.assertIn("[server] Server debug message", output)
self.assertIn("[server] Server info message", output)
self.assertIn("[inference] Inference warning message", output)
self.assertIn("[inference] Inference error message", output)
self.assertNotIn("[inference] Inference debug message", output)
self.assertNotIn("[inference] Inference info message", output)
def test_invalid_category(self):
logcat.info("nonexistent", "This message should not be logged")
# Check that the message was not logged
output = self.log_output.getvalue()
self.assertNotIn("[nonexistent] This message should not be logged", output)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import sys
from typing import Any, Dict, Protocol
from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Inference
from llama_stack.distribution.datatypes import (
Api,
Provider,
StackRunConfig,
)
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.routers.routers import InferenceRouter
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
"""Dynamically add protocol methods to a class by inspecting the protocol."""
for name, value in inspect.getmembers(protocol):
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
# Get the signature
sig = inspect.signature(value)
# Create an async function with the same signature that returns a MagicMock
async def mock_impl(*args, **kwargs):
return MagicMock()
# Set the signature on our mock implementation
mock_impl.__signature__ = sig
# Add it to the class
setattr(cls, name, mock_impl)
class SampleConfig(BaseModel):
foo: str = Field(
default="bar",
description="foo",
)
@classmethod
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
return {
"foo": "baz",
}
class SampleImpl:
def __init__(self, config: SampleConfig, deps: Dict[Api, Any], provider_spec: ProviderSpec = None):
self.__provider_id__ = "test_provider"
self.__provider_spec__ = provider_spec
self.__provider_config__ = config
self.__deps__ = deps
self.foo = config.foo
async def initialize(self):
pass
@pytest.mark.asyncio
async def test_resolve_impls_basic():
# Create a real provider spec
provider_spec = InlineProviderSpec(
api=Api.inference,
provider_type="sample",
module="test_module",
config_class="test_resolver.SampleConfig",
api_dependencies=[],
)
# Create provider registry with our provider
provider_registry = {Api.inference: {provider_spec.provider_type: provider_spec}}
run_config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="sample_provider",
provider_type="sample",
config=SampleConfig.sample_run_config(),
)
]
},
)
dist_registry = MagicMock()
mock_module = MagicMock()
impl = SampleImpl(SampleConfig(foo="baz"), {}, provider_spec)
add_protocol_methods(SampleImpl, Inference)
mock_module.get_provider_impl = AsyncMock(return_value=impl)
sys.modules["test_module"] = mock_module
impls = await resolve_impls(run_config, provider_registry, dist_registry)
assert Api.inference in impls
assert isinstance(impls[Api.inference], InferenceRouter)
table = impls[Api.inference].routing_table
assert isinstance(table, ModelsRoutingTable)
impl = table.impls_by_provider_id["sample_provider"]
assert isinstance(impl, SampleImpl)
assert impl.foo == "baz"
assert impl.__provider_id__ == "sample_provider"
assert impl.__provider_spec__ == provider_spec

26
uv.lock generated
View file

@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.10"
resolution-markers = [
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
@ -733,14 +734,14 @@ wheels = [
[[package]]
name = "jinja2"
version = "3.1.5"
version = "3.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markupsafe" },
]
sdist = { url = "https://files.pythonhosted.org/packages/af/92/b3130cbbf5591acf9ade8708c365f3238046ac7cb8ccba6e81abccb0ccff/jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb", size = 244674 }
sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bd/0f/2ba5fbcd631e3e88689309dbe978c5769e883e4b84ebfe7da30b43275c5a/jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb", size = 134596 },
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899 },
]
[[package]]
@ -861,7 +862,7 @@ wheels = [
[[package]]
name = "llama-stack"
version = "0.1.5"
version = "0.1.6"
source = { editable = "." }
dependencies = [
{ name = "blobfile" },
@ -870,6 +871,7 @@ dependencies = [
{ name = "huggingface-hub" },
{ name = "jsonschema" },
{ name = "llama-stack-client" },
{ name = "pillow" },
{ name = "prompt-toolkit" },
{ name = "pydantic" },
{ name = "python-dotenv" },
@ -877,6 +879,7 @@ dependencies = [
{ name = "rich" },
{ name = "setuptools" },
{ name = "termcolor" },
{ name = "tiktoken" },
]
[package.optional-dependencies]
@ -923,7 +926,6 @@ test = [
{ name = "opentelemetry-sdk" },
{ name = "pypdf" },
{ name = "sqlite-vec" },
{ name = "tiktoken" },
{ name = "torch", version = "2.6.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" },
{ name = "torch", version = "2.6.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform != 'darwin'" },
{ name = "torchvision", version = "0.21.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
@ -942,9 +944,9 @@ requires-dist = [
{ name = "groq", marker = "extra == 'test'" },
{ name = "httpx" },
{ name = "huggingface-hub" },
{ name = "jinja2", marker = "extra == 'codegen'" },
{ name = "jinja2", marker = "extra == 'codegen'", specifier = ">=3.1.6" },
{ name = "jsonschema" },
{ name = "llama-stack-client", specifier = ">=0.1.4" },
{ name = "llama-stack-client", specifier = ">=0.1.6" },
{ name = "lm-format-enforcer", marker = "extra == 'test'", specifier = ">=0.10.9" },
{ name = "myst-parser", marker = "extra == 'docs'" },
{ name = "nbval", marker = "extra == 'dev'" },
@ -952,6 +954,7 @@ requires-dist = [
{ name = "openai", marker = "extra == 'test'" },
{ name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" },
{ name = "opentelemetry-sdk", marker = "extra == 'test'" },
{ name = "pillow" },
{ name = "pre-commit", marker = "extra == 'dev'" },
{ name = "prompt-toolkit" },
{ name = "pydantic", specifier = ">=2" },
@ -977,7 +980,7 @@ requires-dist = [
{ name = "sphinxcontrib-video", marker = "extra == 'docs'" },
{ name = "sqlite-vec", marker = "extra == 'test'" },
{ name = "termcolor" },
{ name = "tiktoken", marker = "extra == 'test'" },
{ name = "tiktoken" },
{ name = "tomli", marker = "extra == 'docs'" },
{ name = "torch", marker = "extra == 'test'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu" },
{ name = "torchvision", marker = "extra == 'test'", specifier = ">=0.21.0", index = "https://download.pytorch.org/whl/cpu" },
@ -985,10 +988,11 @@ requires-dist = [
{ name = "types-setuptools", marker = "extra == 'dev'" },
{ name = "uvicorn", marker = "extra == 'dev'" },
]
provides-extras = ["dev", "test", "docs", "codegen"]
[[package]]
name = "llama-stack-client"
version = "0.1.4"
version = "0.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@ -1005,9 +1009,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/71/6b/0c9900bcefe683b1186c272f372ac643ebd307db9efa95fa2c4418e207b3/llama_stack_client-0.1.4.tar.gz", hash = "sha256:539ff9b8c40272d4f3b023605aff9b70e66958b6bd952a04f9e9a5b2bfde00dd", size = 260958 }
sdist = { url = "https://files.pythonhosted.org/packages/b5/48/70ffdc7ab655234794e9559de9b1776b39610c09aaee8d3bc74bfbd570b4/llama_stack_client-0.1.6.tar.gz", hash = "sha256:92c6c55c3281839e690df7bfc289c36a5dde0f491574bbdb6b8b665dc3d5a16c", size = 264874 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1f/00/56d7699354677e584610d5457baf09b0fde7ca71946532ba0f867d5e47c2/llama_stack_client-0.1.4-py3-none-any.whl", hash = "sha256:5034e7b3aac099a3ad88868b3ba1d2ba19285151ec40776ceda18e500b866a8e", size = 369327 },
{ url = "https://files.pythonhosted.org/packages/38/51/1102914f819cf4412a5c9fd3f7dcc28175608e5f01ee164885972c3ec30b/llama_stack_client-0.1.6-py3-none-any.whl", hash = "sha256:708e20630d4e97a1cb03a19b933f4da6748cc857fe170998c392cf0f30f0f4c7", size = 373941 },
]
[[package]]