Merge branch 'meta-llama:main' into fix-ollama-rag

This commit is contained in:
Kai Wu 2025-02-26 15:26:45 -08:00 committed by GitHub
commit f42dc48986
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
77 changed files with 9426 additions and 1046 deletions

3
.gitmodules vendored
View file

@ -1,3 +0,0 @@
[submodule "llama_stack/providers/impls/ios/inference/executorch"]
path = llama_stack/providers/inline/ios/inference/executorch
url = https://github.com/pytorch/executorch

View file

@ -42,8 +42,9 @@ repos:
- black==24.3.0
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.5.26
rev: 0.6.3
hooks:
- id: uv-lock
- id: uv-export
args: [
"--frozen",
@ -51,8 +52,6 @@ repos:
"--no-emit-project",
"--output-file=requirements.txt"
]
files: ^pyproject\.toml$
- id: uv-sync
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
@ -91,8 +90,7 @@ repos:
language: python
pass_filenames: false
require_serial: true
files: ^llama_stack/templates/.*$
files: ^llama_stack/providers/.*/inference/.*/models\.py$
files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

View file

@ -3,4 +3,4 @@ include distributions/dependencies.json
include llama_stack/distribution/*.sh
include llama_stack/cli/scripts/*.sh
include llama_stack/templates/*/*.yaml
include llama_stack/providers/tests/test_cases/*.json
include llama_stack/providers/tests/test_cases/inference/*.json

View file

@ -136,6 +136,42 @@
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"dev": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"fireworks": [
"aiosqlite",
"autoevals",
@ -171,6 +207,37 @@
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"groq": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"groq",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn"
],
"hf-endpoint": [
"aiohttp",
"aiosqlite",

View file

@ -803,7 +803,7 @@
}
],
"source": [
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
"model_id = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
"\n",
"model_id\n"
]
@ -1688,7 +1688,7 @@
" enable_session_persistence=False,\n",
" toolgroups = [\n",
" {\n",
" \"name\": \"builtin::rag\",\n",
" \"name\": \"builtin::rag/knowledge_search\",\n",
" \"args\" : {\n",
" \"vector_db_ids\": [vector_db_id],\n",
" }\n",

File diff suppressed because one or more lines are too long

View file

@ -7,12 +7,12 @@ Each agent turn follows these key steps:
1. **Initial Safety Check**: The user's input is first screened through configured safety shields
2. **Context Retrieval**:
- If RAG is enabled, the agent queries relevant documents from memory banks
- For new documents, they are first inserted into the memory bank
- Retrieved context is augmented to the user's prompt
- If RAG is enabled, the agent can choose to query relevant documents from memory banks. You can use the `instructions` field to steer the agent.
- For new documents, they are first inserted into the memory bank.
- Retrieved context is provided to the LLM as a tool response in the message history.
3. **Inference Loop**: The agent enters its main execution loop:
- The LLM receives the augmented prompt (with context and/or previous tool outputs)
- The LLM receives a user prompt (with previous tool outputs)
- The LLM generates a response, potentially with tool calls
- If tool calls are present:
- Tool inputs are safety-checked
@ -40,19 +40,16 @@ sequenceDiagram
S->>E: Input Safety Check
deactivate S
E->>M: 2.1 Query Context
M-->>E: 2.2 Retrieved Documents
loop Inference Loop
E->>L: 3.1 Augment with Context
L-->>E: 3.2 Response (with/without tool calls)
E->>L: 2.1 Augment with Context
L-->>E: 2.2 Response (with/without tool calls)
alt Has Tool Calls
E->>S: Check Tool Input
S->>T: 4.1 Execute Tool
T-->>E: 4.2 Tool Response
E->>L: 5.1 Tool Response
L-->>E: 5.2 Synthesized Response
S->>T: 3.1 Execute Tool
T-->>E: 3.2 Tool Response
E->>L: 4.1 Tool Response
L-->>E: 4.2 Synthesized Response
end
opt Stop Conditions
@ -64,7 +61,7 @@ sequenceDiagram
end
E->>S: Output Safety Check
S->>U: 6. Final Response
S->>U: 5. Final Response
```
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
@ -77,7 +74,10 @@ agent_config = AgentConfig(
instructions="You are a helpful assistant",
# Enable both RAG and tool usage
toolgroups=[
{"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}},
{
"name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": ["my_docs"]},
},
"builtin::code_interpreter",
],
# Configure safety

View file

@ -91,7 +91,7 @@ agent_config = AgentConfig(
enable_session_persistence=False,
toolgroups=[
{
"name": "builtin::rag",
"name": "builtin::rag/knowledge_search",
"args": {
"vector_db_ids": [vector_db_id],
},

View file

@ -13,6 +13,13 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
from docutils import nodes
import tomli # Import tomli for TOML parsing
from pathlib import Path
# Read version from pyproject.toml
with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
pyproject = tomli.load(f)
llama_stack_version = pyproject["project"]["version"]
project = "llama-stack"
copyright = "2025, Meta"
@ -66,6 +73,7 @@ myst_enable_extensions = [
myst_substitutions = {
"docker_hub": "https://hub.docker.com/repository/docker/llamastack",
"llama_stack_version": llama_stack_version,
}
suppress_warnings = ['myst.header']

View file

@ -0,0 +1,77 @@
---
orphan: true
---
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# Groq Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-groq` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::groq` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
| vector_io | `inline::faiss` |
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `GROQ_API_KEY`: Groq API Key (default: ``)
### Models
The following models are available by default:
- `meta-llama/Llama-3.1-8B-Instruct (llama3-8b-8192)`
- `meta-llama/Llama-3.1-8B-Instruct (llama-3.1-8b-instant)`
- `meta-llama/Llama-3-70B-Instruct (llama3-70b-8192)`
- `meta-llama/Llama-3.3-70B-Instruct (llama-3.3-70b-versatile)`
- `meta-llama/Llama-3.2-3B-Instruct (llama-3.2-3b-preview)`
### Prerequisite: API Keys
Make sure you have access to a Groq API Key. You can get one by visiting [Groq](https://api.groq.com/).
## Running Llama Stack with Groq
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-groq \
--port $LLAMA_STACK_PORT \
--env GROQ_API_KEY=$GROQ_API_KEY
```
### Via Conda
```bash
llama stack build --template groq --image-type conda
llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env GROQ_API_KEY=$GROQ_API_KEY
```

View file

@ -243,7 +243,7 @@ agent_config = AgentConfig(
# Define tools available to the agent
toolgroups=[
{
"name": "builtin::rag",
"name": "builtin::rag/knowledge_search",
"args": {
"vector_db_ids": [vector_db_id],
},

View file

@ -1,8 +1,7 @@
```{admonition} News
:class: tip
Llama Stack 0.1.3 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.3) for more details.
Llama Stack {{ llama_stack_version }} is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v{{ llama_stack_version }}) for more details.
```
# Llama Stack

View file

@ -6,7 +6,7 @@ The `llama-stack-client` CLI allows you to query information about the distribut
### `llama-stack-client`
```bash
$ llama-stack-client -h
llama-stack-client -h
usage: llama-stack-client [-h] {models,memory_banks,shields} ...
@ -21,7 +21,7 @@ subcommands:
### `llama-stack-client configure`
```bash
$ llama-stack-client configure
llama-stack-client configure
> Enter the host name of the Llama Stack distribution server: localhost
> Enter the port number of the Llama Stack distribution server: 8321
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
@ -29,7 +29,7 @@ Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:
### `llama-stack-client providers list`
```bash
$ llama-stack-client providers list
llama-stack-client providers list
```
```
+-----------+----------------+-----------------+
@ -55,7 +55,7 @@ $ llama-stack-client providers list
### `llama-stack-client models list`
```bash
$ llama-stack-client models list
llama-stack-client models list
```
```
+----------------------+----------------------+---------------+----------------------------------------------------------+
@ -67,7 +67,7 @@ $ llama-stack-client models list
### `llama-stack-client models get`
```bash
$ llama-stack-client models get Llama3.1-8B-Instruct
llama-stack-client models get Llama3.1-8B-Instruct
```
```
@ -80,7 +80,7 @@ $ llama-stack-client models get Llama3.1-8B-Instruct
```bash
$ llama-stack-client models get Random-Model
llama-stack-client models get Random-Model
Model RandomModel is not found at distribution endpoint host:port. Please ensure endpoint is serving specified model.
```
@ -88,26 +88,26 @@ Model RandomModel is not found at distribution endpoint host:port. Please ensure
### `llama-stack-client models register`
```bash
$ llama-stack-client models register <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>]
llama-stack-client models register <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>]
```
### `llama-stack-client models update`
```bash
$ llama-stack-client models update <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>]
llama-stack-client models update <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>]
```
### `llama-stack-client models delete`
```bash
$ llama-stack-client models delete <model_id>
llama-stack-client models delete <model_id>
```
## Vector DB Management
### `llama-stack-client vector_dbs list`
```bash
$ llama-stack-client vector_dbs list
llama-stack-client vector_dbs list
```
```
+--------------+----------------+---------------------+---------------+------------------------+
@ -120,7 +120,7 @@ $ llama-stack-client vector_dbs list
### `llama-stack-client vector_dbs register`
```bash
$ llama-stack-client vector_dbs register <vector-db-id> [--provider-id <provider-id>] [--provider-vector-db-id <provider-vector-db-id>] [--embedding-model <embedding-model>] [--embedding-dimension <embedding-dimension>]
llama-stack-client vector_dbs register <vector-db-id> [--provider-id <provider-id>] [--provider-vector-db-id <provider-vector-db-id>] [--embedding-model <embedding-model>] [--embedding-dimension <embedding-dimension>]
```
Options:
@ -131,13 +131,13 @@ Options:
### `llama-stack-client vector_dbs unregister`
```bash
$ llama-stack-client vector_dbs unregister <vector-db-id>
llama-stack-client vector_dbs unregister <vector-db-id>
```
## Shield Management
### `llama-stack-client shields list`
```bash
$ llama-stack-client shields list
llama-stack-client shields list
```
```
@ -150,7 +150,7 @@ $ llama-stack-client shields list
### `llama-stack-client shields register`
```bash
$ llama-stack-client shields register --shield-id <shield-id> [--provider-id <provider-id>] [--provider-shield-id <provider-shield-id>] [--params <params>]
llama-stack-client shields register --shield-id <shield-id> [--provider-id <provider-id>] [--provider-shield-id <provider-shield-id>] [--params <params>]
```
Options:
@ -163,12 +163,12 @@ Options:
### `llama-stack-client benchmarks list`
```bash
$ llama-stack-client benchmarks list
llama-stack-client benchmarks list
```
### `llama-stack-client benchmarks register`
```bash
$ llama-stack-client benchmarks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
llama-stack-client benchmarks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
```
Options:
@ -182,7 +182,7 @@ Options:
## Eval execution
### `llama-stack-client eval run-benchmark`
```bash
$ llama-stack-client eval run-benchmark <eval-task-id1> [<eval-task-id2> ...] --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
llama-stack-client eval run-benchmark <eval-task-id1> [<eval-task-id2> ...] --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
```
Options:
@ -207,7 +207,7 @@ Example benchmark_config.json:
### `llama-stack-client eval run-scoring`
```bash
$ llama-stack-client eval run-scoring <eval-task-id> --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
llama-stack-client eval run-scoring <eval-task-id> --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
```
Options:
@ -220,7 +220,7 @@ Options:
### `llama-stack-client toolgroups list`
```bash
$ llama-stack-client toolgroups list
llama-stack-client toolgroups list
```
```
+---------------------------+------------------+------+---------------+
@ -236,14 +236,14 @@ $ llama-stack-client toolgroups list
### `llama-stack-client toolgroups get`
```bash
$ llama-stack-client toolgroups get <toolgroup_id>
llama-stack-client toolgroups get <toolgroup_id>
```
Shows detailed information about a specific toolgroup. If the toolgroup is not found, displays an error message.
### `llama-stack-client toolgroups register`
```bash
$ llama-stack-client toolgroups register <toolgroup_id> [--provider-id <provider-id>] [--provider-toolgroup-id <provider-toolgroup-id>] [--mcp-config <mcp-config>] [--args <args>]
llama-stack-client toolgroups register <toolgroup_id> [--provider-id <provider-id>] [--provider-toolgroup-id <provider-toolgroup-id>] [--mcp-config <mcp-config>] [--args <args>]
```
Options:
@ -254,5 +254,5 @@ Options:
### `llama-stack-client toolgroups unregister`
```bash
$ llama-stack-client toolgroups unregister <toolgroup_id>
llama-stack-client toolgroups unregister <toolgroup_id>
```

View file

@ -6,7 +6,7 @@ This guide will walk you through an end-to-end workflow with Llama Stack with Ol
If you're looking for more specific topics, we have a [Zero to Hero Guide](#next-steps) that covers everything from Tool Calling to Agents in detail. Feel free to skip to the end to explore the advanced topics you're interested in.
> If you'd prefer not to set up a local server, explore our notebook on [tool calling with the Together API](Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb). This notebook will show you how to leverage together.ai's Llama Stack Server API, allowing you to get started with Llama Stack without the need for a locally built and running server.
> If you'd prefer not to set up a local server, explore our notebook on [tool calling with the Together API](Tool_Calling101_Using_Together_Llama_Stack_Server.ipynb). This notebook will show you how to leverage together.ai's Llama Stack Server API, allowing you to get started with Llama Stack without the need for a locally built and running server.
## Table of Contents
1. [Setup and run ollama](#setup-ollama)

View file

@ -343,7 +343,7 @@ def _hf_download(
"You can find your token by visiting https://huggingface.co/settings/tokens"
)
except RepositoryNotFoundError:
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub or incorrect Hugging Face token.")
except Exception as e:
parser.error(e)

View file

@ -7,8 +7,6 @@
import argparse
import json
from termcolor import colored
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.sku_list import resolve_model
@ -52,11 +50,12 @@ class ModelDescribe(Subcommand):
)
return
headers = [
"Model",
model.descriptor(),
]
rows = [
(
colored("Model", "white", attrs=["bold"]),
colored(model.descriptor(), "white", attrs=["bold"]),
),
("Hugging Face ID", model.huggingface_repo or "<Not Available>"),
("Description", model.description),
("Context Length", f"{model.max_seq_length // 1024}K tokens"),
@ -77,5 +76,6 @@ class ModelDescribe(Subcommand):
print_table(
rows,
headers,
separate_rows=True,
)

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import argparse
import logging
import os
from pathlib import Path
@ -12,6 +13,8 @@ from llama_stack.cli.subcommand import Subcommand
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = logging.getLogger(__name__)
class StackRun(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
@ -75,7 +78,6 @@ class StackRun(Subcommand):
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
from termcolor import cprint
from llama_stack.distribution.build import ImageType
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
@ -85,10 +87,6 @@ class StackRun(Subcommand):
)
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
if not args.config:
self.parser.error("Must specify a config file to run")
return
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
@ -115,11 +113,23 @@ class StackRun(Subcommand):
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
return
print(f"Using run configuration: {config_file}")
config_dict = yaml.safe_load(config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file} is not a file: type={type(config_file)}"
)
logger.info(f"Using run configuration: {config_file}")
try:
config_dict = yaml.safe_load(config_file.read_text())
except yaml.parser.ParserError as e:
self.parser.error(f"failed to load config file '{config_file}':\n {e}")
try:
config = parse_and_maybe_upgrade_config(config_dict)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
run_args = formulate_run_args(args.image_type, args.image_name, config, template_name)
@ -129,18 +139,10 @@ class StackRun(Subcommand):
for env_var in args.env:
if "=" not in env_var:
cprint(
f"Environment variable '{env_var}' must be in KEY=VALUE format",
color="red",
)
return
self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
key, value = env_var.split("=", 1) # split on first = only
if not key:
cprint(
f"Environment variable '{env_var}' has empty key",
color="red",
)
return
self.parser.error(f"Environment variable '{env_var}' has empty key")
run_args.extend(["--env", f"{key}={value}"])
if args.tls_keyfile and args.tls_certfile:

View file

@ -441,7 +441,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
return await self.routing_table.get_provider_impl("query_from_memory").query(
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)

View file

@ -142,7 +142,7 @@ def handle_signal(app, signum, _) -> None:
not block the current execution.
"""
signame = signal.Signals(signum).name
print(f"Received signal {signame} ({signum}). Exiting gracefully...")
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown():
try:
@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None:
@asynccontextmanager
async def lifespan(app: FastAPI):
print("Starting up")
logger.info("Starting up")
yield
print("Shutting down")
logger.info("Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
@ -352,10 +352,10 @@ def main():
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
print(f"Setting CLI environment variable {key} => {value}")
logger.info(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
print(f"Error: {str(e)}")
logger.error(f"Error: {str(e)}")
sys.exit(1)
if args.yaml_config:
@ -363,12 +363,12 @@ def main():
config_file = Path(args.yaml_config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
print(f"Using config file: {config_file}")
logger.info(f"Using config file: {config_file}")
elif args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
print(f"Using template {args.template} config file: {config_file}")
logger.info(f"Using template {args.template} config file: {config_file}")
else:
raise ValueError("Either --yaml-config or --template must be provided")
@ -376,9 +376,9 @@ def main():
config = replace_env_vars(yaml.safe_load(fp))
config = StackRunConfig(**config)
print("Run configuration:")
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump())
print(yaml.dump(safe_config, indent=2))
logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
@ -387,7 +387,8 @@ def main():
try:
impls = asyncio.run(construct_stack(config))
except InvalidProviderError:
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
@ -432,7 +433,7 @@ def main():
)
)
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
logger.info(f"Serving API {api_str}")
for endpoint in endpoints:
cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
@ -462,10 +463,10 @@ def main():
"ssl_keyfile": keyfile,
"ssl_certfile": certfile,
}
print(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
logger.info(f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,

View file

@ -33,7 +33,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v7"
KEY_VERSION = "v8"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

View file

@ -132,7 +132,7 @@ def rag_chat_page():
},
toolgroups=[
dict(
name="builtin::rag",
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
},

View file

@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from pydantic import TypeAdapter
from llama_stack.apis.agents import (
AgentConfig,
@ -62,7 +61,7 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools import RAGDocument, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.models.llama.datatypes import (
BuiltinTool,
@ -70,7 +69,6 @@ from llama_stack.models.llama.datatypes import (
ToolParamDefinition,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
@ -84,7 +82,7 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "query_from_memory"
MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
@ -499,111 +497,18 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: simplify all of this code, it can be simpler
toolgroup_args = {}
toolgroups = set()
for toolgroup in self.agent_config.toolgroups:
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroups.add(toolgroup.name)
toolgroup_args[toolgroup.name] = toolgroup.args
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
toolgroups.add(tool_group_name)
toolgroup_args[tool_group_name] = toolgroup.args
else:
toolgroups.add(toolgroup)
if toolgroups_for_turn:
for toolgroup in toolgroups_for_turn:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroups.add(toolgroup.name)
toolgroup_args[toolgroup.name] = toolgroup.args
else:
toolgroups.add(toolgroup)
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(session_id, documents, input_messages, tool_defs)
if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
args = toolgroup_args.get(RAG_TOOL_GROUP, {})
vector_db_ids = args.get("vector_db_ids", [])
query_config = args.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.vector_db_id:
vector_db_ids.append(session_info.vector_db_id)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.succeeded,
tool_call=ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
),
),
)
)
)
result = await self.tool_runtime_api.rag_tool.query(
content=concat_interleaved_content([msg.content for msg in input_messages]),
vector_db_ids=vector_db_ids,
query_config=query_config,
)
retrieved_context = result.content
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[
ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
)
],
tool_responses=[
ToolResponse(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
content=retrieved_context or [],
metadata=result.metadata,
)
],
),
)
)
)
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
span.set_attribute("output", retrieved_context)
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
# append retrieved_context to the last user message
for message in input_messages[::-1]:
if isinstance(message, UserMessage):
message.context = retrieved_context
break
output_attachments = []
n_iter = 0
@ -631,9 +536,7 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=[
tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
],
tools=tool_defs,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format,
stream=True,
@ -837,7 +740,7 @@ class ChatAgent(ShieldRunnerMixin):
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(),
completed_at=datetime.now().astimezone().isoformat(),
),
)
)
@ -845,8 +748,9 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if out_attachment := _interpret_content_as_attachment(result_message.content):
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
@ -858,7 +762,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _get_tool_defs(
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
# Determine which tools to include
agent_config_toolgroups = set(
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
@ -873,13 +777,13 @@ class ChatAgent(ShieldRunnerMixin):
}
)
tool_def_map = {}
tool_name_to_def = {}
tool_to_group = {}
for tool_def in self.agent_config.client_tools:
if tool_def_map.get(tool_def.name, None):
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
tool_def_map[tool_def.name] = ToolDefinition(
tool_name_to_def[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
@ -893,10 +797,17 @@ class ChatAgent(ShieldRunnerMixin):
},
)
tool_to_group[tool_def.name] = "__client_tools__"
for toolgroup_name in agent_config_toolgroups:
if toolgroup_name not in toolgroups_for_turn_set:
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
if toolgroup_name_with_maybe_tool_name not in toolgroups_for_turn_set:
continue
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
raise ValueError(
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
)
for tool_def in tools.data:
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
tool_name = tool_def.identifier
@ -906,31 +817,61 @@ class ChatAgent(ShieldRunnerMixin):
else:
built_in_type = BuiltinTool(tool_name)
if tool_def_map.get(built_in_type, None):
if tool_name_to_def.get(built_in_type, None):
raise ValueError(f"Tool {built_in_type} already exists")
tool_def_map[built_in_type] = ToolDefinition(tool_name=built_in_type)
tool_name_to_def[built_in_type] = ToolDefinition(
tool_name=built_in_type,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[built_in_type] = tool_def.toolgroup_id
continue
if tool_def_map.get(tool_def.identifier, None):
if tool_name_to_def.get(tool_def.identifier, None):
raise ValueError(f"Tool {tool_def.identifier} already exists")
tool_def_map[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
if tool_name in (None, tool_def.identifier):
tool_name_to_def[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
return tool_def_map, tool_to_group
return list(tool_name_to_def.values()), tool_to_group
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components.
Args:
toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
Returns:
A tuple of (tool_type, tool_group, tool_name)
"""
split_names = toolgroup_name_with_maybe_tool_name.split("/")
if len(split_names) == 2:
# e.g. "builtin::rag"
tool_group, tool_name = split_names
else:
tool_group, tool_name = split_names[0], None
return tool_group, tool_name
async def handle_documents(
self,
@ -939,8 +880,8 @@ class ChatAgent(ShieldRunnerMixin):
input_messages: List[Message],
tool_defs: Dict[str, ToolDefinition],
) -> None:
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
@ -1060,7 +1001,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
else:
raise ValueError(f"Unsupported URL {url}")
content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))
content.append(
TextContentItem(
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
)
)
return ToolResponseMessage(
call_id="",

@ -1 +0,0 @@
Subproject commit 9b6d4b4a7b9b8f811bb6b269b0c2ce254e3a0c1b

View file

@ -4,15 +4,25 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
import torch
from safetensors.torch import save_file
from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
from torchtune.training.checkpointing._utils import (
ADAPTER_CONFIG_FNAME,
ADAPTER_MODEL_FNAME,
REPO_ID_FNAME,
SUFFIXES_TO_NOT_COPY,
ModelType,
copy_files,
safe_torch_load,
)
from torchtune.utils._logging import get_logger
logger = get_logger("DEBUG")
@ -75,9 +85,24 @@ class TorchtuneCheckpointer:
state_dict: Dict[str, Any],
epoch: int,
adapter_only: bool = False,
checkpoint_format: str = "meta",
) -> str:
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
if checkpoint_format == "meta":
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
elif checkpoint_format == "huggingface":
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
self._save_hf_format_checkpoint(model_file_path, state_dict)
else:
raise ValueError(f"Unsupported checkpoint format: {format}")
return str(model_file_path)
def _save_meta_format_checkpoint(
self,
model_file_path: Path,
state_dict: Dict[str, Any],
adapter_only: bool = False,
) -> None:
model_file_path.mkdir(parents=True, exist_ok=True)
# copy the related files for inference
@ -140,6 +165,76 @@ class TorchtuneCheckpointer:
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)
print("model_file_path", str(model_file_path))
def _save_hf_format_checkpoint(
self,
model_file_path: Path,
state_dict: Dict[str, Any],
) -> None:
# the config.json file contains model params needed for state dict conversion
config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
return str(model_file_path)
# repo_id is necessary for when saving an adapter config, so its compatible with HF.
# This json file is produced and saved in the download step.
# contents are {"repo_id": "some_model/some_model_version"}
repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
self.repo_id = None
if repo_id_path.exists():
with open(repo_id_path, "r") as json_file:
data = json.load(json_file)
self.repo_id = data.get("repo_id")
if training.ADAPTER_KEY in state_dict:
# TODO: saving it "as is" is a requirement because, if we only save with
# convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
# convert_weights.peft_to_tune. The .pt format is not needed, but
# it is an easy way to distinguish the adapters. Ideally we should save only one.
output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME).with_suffix(".pt")
output_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(state_dict[training.ADAPTER_KEY], output_path)
logger.info(
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
)
state_dict[training.ADAPTER_KEY] = convert_weights.tune_to_peft_adapter_weights(
state_dict[training.ADAPTER_KEY],
num_heads=config["num_attention_heads"],
num_kv_heads=config["num_key_value_heads"],
dim=config["hidden_size"],
head_dim=config.get("head_dim", None),
)
output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_MODEL_FNAME)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path = output_path.with_suffix(".safetensors")
save_file(
state_dict[training.ADAPTER_KEY],
output_path,
metadata={"format": "pt"},
)
logger.info(
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
)
else:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)
if training.ADAPTER_CONFIG in state_dict:
state_dict[training.ADAPTER_CONFIG] = convert_weights.tune_to_peft_adapter_config(
adapter_config=state_dict[training.ADAPTER_CONFIG],
base_model_name_or_path=self.repo_id,
)
output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_CONFIG_FNAME).with_suffix(".json")
with open(output_path, "w") as f:
json.dump(state_dict[training.ADAPTER_CONFIG], f)
logger.info(
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
)
# Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
# So its easy to run inference with the model using this epoch's checkpoint
copy_files(
self._checkpoint_dir.parent,
model_file_path,
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
)

View file

@ -4,10 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Literal, Optional
from pydantic import BaseModel
class TorchtunePostTrainingConfig(BaseModel):
torch_seed: Optional[int] = None
checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"

View file

@ -117,6 +117,7 @@ class LoraFinetuningSingleDevice:
self.checkpoint_dir = model_checkpoint_dir(model)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self._checkpoint_format = config.checkpoint_format
self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
@ -419,6 +420,7 @@ class LoraFinetuningSingleDevice:
return self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
checkpoint_format=self._checkpoint_format,
)
async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
@ -460,7 +462,7 @@ class LoraFinetuningSingleDevice:
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}")
metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log")
self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0
@ -547,10 +549,11 @@ class LoraFinetuningSingleDevice:
checkpoints.append(checkpoint)
# clean up the memory after training finishes
self._model.to("cpu")
if self._device.type != "cpu":
self._model.to("cpu")
torch.cuda.empty_cache()
del self._model
gc.collect()
torch.cuda.empty_cache()
return (memory_stats, checkpoints)

View file

@ -7,6 +7,7 @@
import json
import os
import sqlite3
import threading
from datetime import datetime
from opentelemetry.sdk.trace import SpanProcessor
@ -17,14 +18,18 @@ class SQLiteSpanProcessor(SpanProcessor):
def __init__(self, conn_string):
"""Initialize the SQLite span processor with a connection string."""
self.conn_string = conn_string
self.conn = None
self._local = threading.local() # Thread-local storage for connections
self.setup_database()
def _get_connection(self) -> sqlite3.Connection:
"""Get the database connection."""
if self.conn is None:
self.conn = sqlite3.connect(self.conn_string, check_same_thread=False)
return self.conn
def _get_connection(self):
"""Get a thread-local database connection."""
if not hasattr(self._local, "conn"):
try:
self._local.conn = sqlite3.connect(self.conn_string)
except Exception as e:
print(f"Error connecting to SQLite database: {e}")
raise e
return self._local.conn
def setup_database(self):
"""Create the necessary tables if they don't exist."""
@ -168,9 +173,14 @@ class SQLiteSpanProcessor(SpanProcessor):
def shutdown(self):
"""Cleanup any resources."""
if self.conn:
self.conn.close()
self.conn = None
# We can't access other threads' connections, so we just close our own
if hasattr(self._local, "conn"):
try:
self._local.conn.close()
except Exception as e:
print(f"Error closing SQLite connection: {e}")
finally:
del self._local.conn
def force_flush(self, timeout_millis=30000):
"""Force export of spans."""

View file

@ -10,6 +10,8 @@ import secrets
import string
from typing import Any, Dict, List, Optional
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
@ -23,6 +25,7 @@ from llama_stack.apis.tools import (
RAGToolRuntime,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
@ -120,9 +123,14 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
chunks = chunks[: query_config.max_chunks]
tokens = 0
picked = []
for c in chunks:
picked = [
TextContentItem(
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
]
for i, c in enumerate(chunks):
metadata = c.metadata
tokens += metadata["token_count"]
if tokens > query_config.max_tokens_in_context:
@ -132,20 +140,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
break
picked.append(
TextContentItem(
text=f"id:{metadata['document_id']}; content:{c.content}",
text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
)
)
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
return RAGQueryResult(
content=[
TextContentItem(
text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
),
*picked,
TextContentItem(
text="\n=== END-RETRIEVED-CONTEXT ===\n",
),
],
content=picked,
metadata={
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
},
@ -158,17 +159,40 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
# by the LLM. The method is only implemented so things like /tools can list without
# encountering fatals.
return [
ToolDef(
name="query_from_memory",
description="Retrieve context from memory",
),
ToolDef(
name="insert_into_memory",
description="Insert documents into memory",
),
ToolDef(
name="knowledge_search",
description="Search for information in a database.",
parameters=[
ToolParameter(
name="query",
description="The query to search for. Can be a natural language sentence or keywords.",
parameter_type="string",
),
],
),
]
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
raise RuntimeError(
"This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
vector_db_ids = kwargs.get("vector_db_ids", [])
query_config = kwargs.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
query = kwargs["query"]
result = await self.query(
content=query,
vector_db_ids=vector_db_ids,
query_config=query_config,
)
return ToolInvocationResult(
content=result.content,
metadata=result.metadata,
)

View file

@ -207,6 +207,33 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="openai",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="anthropic",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="gemini",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
from .config import AnthropicConfig
class AnthropicProviderDataValidator(BaseModel):
anthropic_api_key: Optional[str] = None
async def get_adapter_impl(config: AnthropicConfig, _deps):
from .anthropic import AnthropicInferenceAdapter
impl = AnthropicInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from .config import AnthropicConfig
from .models import MODEL_ENTRIES
class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: AnthropicConfig) -> None:
LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class AnthropicConfig(BaseModel):
api_key: Optional[str] = Field(
default=None,
description="API key for Anthropic models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> Dict[str, Any]:
return {
"api_key": api_key,
}

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
"anthropic/claude-3-5-sonnet-latest",
"anthropic/claude-3-7-sonnet-latest",
"anthropic/claude-3-5-haiku-latest",
]
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
ProviderModelEntry(
provider_model_id="anthropic/voyage-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-3-lite",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 512, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-code-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
]

View file

@ -23,8 +23,8 @@ class FireworksImplConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> Dict[str, Any]:
return {
"url": "https://api.fireworks.ai/inference/v1",
"api_key": "${env.FIREWORKS_API_KEY}",
"api_key": api_key,
}

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
from .config import GeminiConfig
class GeminiProviderDataValidator(BaseModel):
gemini_api_key: Optional[str] = None
async def get_adapter_impl(config: GeminiConfig, _deps):
from .gemini import GeminiInferenceAdapter
impl = GeminiInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class GeminiConfig(BaseModel):
api_key: Optional[str] = Field(
default=None,
description="API key for Gemini models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> Dict[str, Any]:
return {
"api_key": api_key,
}

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from .config import GeminiConfig
from .models import MODEL_ENTRIES
class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: GeminiConfig) -> None:
LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass

View file

@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
"gemini/gemini-1.5-flash",
"gemini/gemini-1.5-pro",
]
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
ProviderModelEntry(
provider_model_id="gemini/text-embedding-004",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 768, "context_length": 2048},
),
]

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
@ -18,3 +18,15 @@ class GroqConfig(BaseModel):
default=None,
description="The Groq API key",
)
url: str = Field(
default="https://api.groq.com",
description="The URL for the Groq AI server",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"url": "https://api.groq.com",
"api_key": "${env.GROQ_API_KEY}",
}

View file

@ -29,17 +29,10 @@ from llama_stack.apis.inference import (
ToolConfig,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import (
SamplingParams,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
build_model_entry,
)
from .groq_utils import (
@ -47,33 +40,7 @@ from .groq_utils import (
convert_chat_completion_response,
convert_chat_completion_response_stream,
)
_MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value,
),
build_hf_repo_model_entry(
"llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value,
),
# Groq only contains a preview version for llama-3.2-3b
# Preview models aren't recommended for production use, but we include this one
# to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it
build_hf_repo_model_entry(
"llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value,
),
]
from .models import _MODEL_ENTRIES
class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.utils.inference.model_registry import build_model_entry
_MODEL_ENTRIES = [
build_model_entry(
"llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value,
),
build_model_entry(
"llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value,
),
# Groq only contains a preview version for llama-3.2-3b
# Preview models aren't recommended for production use, but we include this one
# to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it
build_model_entry(
"llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value,
),
]

View file

@ -40,6 +40,10 @@ from llama_stack.models.llama.datatypes import (
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
)
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
from . import NVIDIAConfig
@ -47,8 +51,6 @@ from .models import _MODEL_ENTRIES
from .openai_utils import (
convert_chat_completion_request,
convert_completion_request,
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
convert_openai_completion_choice,
convert_openai_completion_stream,
)
@ -201,7 +203,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
if stream:
return convert_openai_chat_completion_stream(response)
return convert_openai_chat_completion_stream(response, enable_incremental_tool_calls=False)
else:
# we pass n=1 to get only one completion
return convert_openai_chat_completion_choice(response.choices[0])

View file

@ -4,249 +4,36 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import warnings
from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Optional
from openai import AsyncStream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
)
from openai.types.chat import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat import (
ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
)
from openai.types.chat import (
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
)
from openai.types.chat import (
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
)
from openai.types.chat import (
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
)
from openai.types.chat.chat_completion import (
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as OpenAIFunction,
)
from openai.types.completion import Completion as OpenAICompletion
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
TextContentItem,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
JsonSchemaResponseFormat,
Message,
SystemMessage,
TokenLogProbs,
ToolResponseMessage,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
GreedySamplingStrategy,
StopReason,
ToolCall,
ToolDefinition,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
from llama_stack.providers.utils.inference.openai_compat import (
_convert_openai_finish_reason,
convert_message_to_openai_dict_new,
convert_tooldef_to_openai_tool,
)
def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
"""
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
ToolDefinition:
tool_name: str | BuiltinTool
description: Optional[str]
parameters: Optional[Dict[str, ToolParamDefinition]]
ToolParamDefinition:
param_type: str
description: Optional[str]
required: Optional[bool]
default: Optional[Any]
OpenAI spec -
{
"type": "function",
"function": {
"name": tool_name,
"description": description,
"parameters": {
"type": "object",
"properties": {
param_name: {
"type": param_type,
"description": description,
"default": default,
},
...
},
"required": [param_name, ...],
},
},
}
"""
out = {
"type": "function",
"function": {},
}
function = out["function"]
if isinstance(tool.tool_name, BuiltinTool):
function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient?
else:
function.update(name=tool.tool_name)
if tool.description:
function.update(description=tool.description)
if tool.parameters:
parameters = {
"type": "object",
"properties": {},
}
properties = parameters["properties"]
required = []
for param_name, param in tool.parameters.items():
properties[param_name] = {"type": param.param_type}
if param.description:
properties[param_name].update(description=param.description)
if param.default:
properties[param_name].update(default=param.default)
if param.required:
required.append(param_name)
if required:
parameters.update(required=required)
function.update(parameters=parameters)
return out
async def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
"""
Convert a Message to an OpenAI API-compatible dictionary.
"""
# users can supply a dict instead of a Message object, we'll
# convert it to a Message object and proceed with some type safety.
if isinstance(message, dict):
if "role" not in message:
raise ValueError("role is required in message")
if message["role"] == "user":
message = UserMessage(**message)
elif message["role"] == "assistant":
message = CompletionMessage(**message)
elif message["role"] == "tool":
message = ToolResponseMessage(**message)
elif message["role"] == "system":
message = SystemMessage(**message)
else:
raise ValueError(f"Unsupported message role: {message['role']}")
# Map Llama Stack spec to OpenAI spec -
# str -> str
# {"type": "text", "text": ...} -> {"type": "text", "text": ...}
# {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
# {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
# List[...] -> List[...]
async def _convert_user_message_content(
content: InterleavedContent,
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
# Llama Stack and OpenAI spec match for str and text input
if isinstance(content, str) or isinstance(content, TextContentItem):
return content
elif isinstance(content, ImageContentItem):
return OpenAIChatCompletionContentPartImageParam(
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
type="image_url",
)
elif isinstance(content, List):
return [await _convert_user_message_content(item) for item in content]
else:
raise ValueError(f"Unsupported content type: {type(content)}")
out: OpenAIChatCompletionMessage = None
if isinstance(message, UserMessage):
out = OpenAIChatCompletionUserMessage(
role="user",
content=await _convert_user_message_content(message.content),
)
elif isinstance(message, CompletionMessage):
out = OpenAIChatCompletionAssistantMessage(
role="assistant",
content=message.content,
tool_calls=[
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=tool.tool_name,
arguments=json.dumps(tool.arguments),
),
type="function",
)
for tool in message.tool_calls
],
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
role="tool",
tool_call_id=message.call_id,
content=message.content,
)
elif isinstance(message, SystemMessage):
out = OpenAIChatCompletionSystemMessage(
role="system",
content=message.content,
)
else:
raise ValueError(f"Unsupported message type: {type(message)}")
return out
async def convert_chat_completion_request(
request: ChatCompletionRequest,
n: int = 1,
@ -281,7 +68,7 @@ async def convert_chat_completion_request(
nvext = {}
payload: Dict[str, Any] = dict(
model=request.model,
messages=[await _convert_message(message) for message in request.messages],
messages=[await convert_message_to_openai_dict_new(message) for message in request.messages],
stream=request.stream,
n=n,
extra_body=dict(nvext=nvext),
@ -296,7 +83,7 @@ async def convert_chat_completion_request(
nvext.update(guided_json=request.response_format.json_schema)
if request.tools:
payload.update(tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools])
payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools])
if request.tool_config.tool_choice:
payload.update(
tool_choice=request.tool_config.tool_choice.value
@ -329,239 +116,6 @@ async def convert_chat_completion_request(
return payload
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
"""
Convert an OpenAI chat completion finish_reason to a StopReason.
finish_reason: Literal["stop", "length", "tool_calls", ...]
- stop: model hit a natural stop point or a provided stop sequence
- length: maximum number of tokens specified in the request was reached
- tool_calls: model called a tool
->
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
# TODO(mf): are end_of_turn and end_of_message semantics correct?
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
def _convert_openai_tool_calls(
tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]:
"""
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
OpenAI ChatCompletionMessageToolCall:
id: str
function: Function
type: Literal["function"]
OpenAI Function:
arguments: str
name: str
->
ToolCall:
call_id: str
tool_name: str
arguments: Dict[str, ...]
"""
if not tool_calls:
return [] # CompletionMessage tool_calls is not optional
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
)
for call in tool_calls
]
def _convert_openai_logprobs(
logprobs: OpenAIChoiceLogprobs,
) -> Optional[List[TokenLogProbs]]:
"""
Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.
OpenAI ChoiceLogprobs:
content: Optional[List[ChatCompletionTokenLogprob]]
OpenAI ChatCompletionTokenLogprob:
token: str
logprob: float
top_logprobs: List[TopLogprob]
OpenAI TopLogprob:
token: str
logprob: float
->
TokenLogProbs:
logprobs_by_token: Dict[str, float]
- token, logprob
"""
if not logprobs:
return None
return [
TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
for content in logprobs.content
]
def convert_openai_chat_completion_choice(
choice: OpenAIChoice,
) -> ChatCompletionResponse:
"""
Convert an OpenAI Choice into a ChatCompletionResponse.
OpenAI Choice:
message: ChatCompletionMessage
finish_reason: str
logprobs: Optional[ChoiceLogprobs]
OpenAI ChatCompletionMessage:
role: Literal["assistant"]
content: Optional[str]
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
->
ChatCompletionResponse:
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]]
CompletionMessage:
role: Literal["assistant"]
content: str | ImageMedia | List[str | ImageMedia]
stop_reason: StopReason
tool_calls: List[ToolCall]
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
assert hasattr(choice, "finish_reason") and choice.finish_reason, (
"error in server response: finish_reason not found"
)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content or "", # CompletionMessage content is not optional
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
async def convert_openai_chat_completion_stream(
stream: AsyncStream[OpenAIChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
"""
Convert a stream of OpenAI chat completion chunks into a stream
of ChatCompletionResponseStreamChunk.
"""
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]:
yield ChatCompletionResponseEventType.start
while True:
yield ChatCompletionResponseEventType.progress
event_type = _event_type_generator()
# we implement NIM specific semantics, the main difference from OpenAI
# is that tool_calls are always produced as a complete call. there is no
# intermediate / partial tool call streamed. because of this, we can
# simplify the logic and not concern outselves with parse_status of
# started/in_progress/failed. we can always assume success.
#
# a stream of ChatCompletionResponseStreamChunk consists of
# 0. a start event
# 1. zero or more progress events
# - each progress event has a delta
# - each progress event may have a stop_reason
# - each progress event may have logprobs
# - each progress event may have tool_calls
# if a progress event has tool_calls,
# it is fully formed and
# can be emitted with a parse_status of success
# 2. a complete event
stop_reason = None
async for chunk in stream:
choice = chunk.choices[0] # assuming only one choice per chunk
# we assume there's only one finish_reason in the stream
stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason
# if there's a tool call, emit an event for each tool in the list
# if tool call and content, emit both separately
if choice.delta.tool_calls:
# the call may have content and a tool call. ChatCompletionResponseEvent
# does not support both, so we emit the content first
if choice.delta.content:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=TextDelta(text=choice.delta.content),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
# it is possible to have parallel tool calls in stream, but
# ChatCompletionResponseEvent only supports one per stream
if len(choice.delta.tool_calls) > 1:
warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest")
# NIM only produces fully formed tool calls, so we can assume success
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=ToolCallDelta(
tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
parse_status=ToolCallParseStatus.succeeded,
),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=TextDelta(text=choice.delta.content or ""),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
def convert_completion_request(
request: CompletionRequest,
n: int = 1,

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
from .config import OpenAIConfig
class OpenAIProviderDataValidator(BaseModel):
openai_api_key: Optional[str] = None
async def get_adapter_impl(config: OpenAIConfig, _deps):
from .openai import OpenAIInferenceAdapter
impl = OpenAIInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class OpenAIConfig(BaseModel):
api_key: Optional[str] = Field(
default=None,
description="API key for OpenAI models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> Dict[str, Any]:
return {
"api_key": api_key,
}

View file

@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
"openai/gpt-4o",
"openai/gpt-4o-mini",
"openai/chatgpt-4o-latest",
]
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
ProviderModelEntry(
provider_model_id="openai/text-embedding-3-small",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1536, "context_length": 8192},
),
ProviderModelEntry(
provider_model_id="openai/text-embedding-3-large",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 3072, "context_length": 8192},
),
]

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass

View file

@ -169,7 +169,7 @@ async def _process_vllm_chat_completion_stream_response(
args = {} if not args_str else json.loads(args_str)
except Exception as e:
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
if args is not None:
if args:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
@ -183,7 +183,7 @@ async def _process_vllm_chat_completion_stream_response(
),
)
)
else:
elif args_str:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,

View file

@ -46,9 +46,10 @@ def pytest_generate_tests(metafunc):
if ("Vision" in cls_name and "Vision" in model) or ("Vision" not in cls_name and "Vision" not in model):
params.append(pytest.param(model, id=model))
print(f"params: {params}")
if not params:
model = metafunc.config.getoption("--inference-model")
params = [pytest.param(model, id="")]
params = [pytest.param(model, id=model)]
metafunc.parametrize(
"inference_model",

View file

@ -197,7 +197,7 @@ def inference_nvidia() -> ProviderFixture:
Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIAConfig().model_dump(),
config=NVIDIAConfig(api_key=get_env_or_fail("NVIDIA_API_KEY")).model_dump(),
)
],
)

View file

@ -27,8 +27,6 @@ from llama_stack.models.llama.datatypes import (
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
)
from llama_stack.providers.tests.test_cases.test_case import TestCase
@ -58,28 +56,6 @@ def common_params(inference_model):
}
@pytest.fixture
def sample_messages():
return [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="What's the weather like today?"),
]
@pytest.fixture
def sample_tool_definition():
return ToolDefinition(
tool_name="get_weather",
description="Get the current weather",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
),
},
)
class TestInference:
# Session scope for asyncio because the tests in this class all
# share the same provider instance.
@ -100,12 +76,20 @@ class TestInference:
assert model_def is not None
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:non_streaming",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_completion(self, inference_model, inference_stack):
async def test_text_completion_non_streaming(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
content=tc["content"],
stream=False,
model_id=inference_model,
sampling_params=SamplingParams(
@ -114,12 +98,24 @@ class TestInference:
)
assert isinstance(response, CompletionResponse)
assert "1963" in response.content
assert tc["expected"] in response.content
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:streaming",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_text_completion_streaming(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
content=tc["content"],
stream=True,
model_id=inference_model,
sampling_params=SamplingParams(
@ -133,12 +129,20 @@ class TestInference:
last = chunks[-1]
assert last.stop_reason == StopReason.out_of_tokens
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:logprobs_non_streaming",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_completion_logprobs(self, inference_model, inference_stack):
async def test_text_completion_logprobs_non_streaming(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
content=tc["content"],
stream=False,
model_id=inference_model,
sampling_params=SamplingParams(
@ -154,10 +158,22 @@ class TestInference:
assert response.logprobs, "Logprobs should not be empty"
assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:logprobs_streaming",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_text_completion_logprobs_streaming(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
content=tc["content"],
stream=True,
model_id=inference_model,
sampling_params=SamplingParams(
@ -180,9 +196,14 @@ class TestInference:
else: # no token, no logprobs
assert not chunk.logprobs, "Logprobs should be empty"
@pytest.mark.parametrize("test_case", ["completion-01"])
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:structured_output",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_completion_structured_output(self, inference_model, inference_stack, test_case):
async def test_text_completion_structured_output(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
class Output(BaseModel):
@ -213,14 +234,20 @@ class TestInference:
assert answer.year_born == expected["year_born"]
assert answer.year_retired == expected["year_retired"]
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:sample_messages",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_non_streaming(
self, inference_model, inference_stack, common_params, sample_messages
):
async def test_text_chat_completion_non_streaming(self, inference_model, inference_stack, common_params, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=sample_messages,
messages=messages,
stream=False,
**common_params,
)
@ -230,9 +257,16 @@ class TestInference:
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0
@pytest.mark.parametrize("test_case", ["chat_completion-01"])
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:structured_output",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_structured_output(self, inference_model, inference_stack, common_params, test_case):
async def test_text_chat_completion_structured_output(
self, inference_model, inference_stack, common_params, test_case
):
inference_impl, _ = inference_stack
class AnswerFormat(BaseModel):
@ -281,14 +315,22 @@ class TestInference:
with pytest.raises(ValidationError):
AnswerFormat.model_validate_json(response.completion_message.content)
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:sample_messages",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_streaming(self, inference_model, inference_stack, common_params, sample_messages):
async def test_text_chat_completion_streaming(self, inference_model, inference_stack, common_params, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = [
r
async for r in await inference_impl.chat_completion(
model_id=inference_model,
messages=sample_messages,
messages=messages,
stream=True,
**common_params,
)
@ -304,26 +346,28 @@ class TestInference:
end = grouped[ChatCompletionResponseEventType.complete][0]
assert end.event.stop_reason == StopReason.end_of_turn
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:sample_messages_tool_calling",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_with_tool_calling(
async def test_text_chat_completion_with_tool_calling(
self,
inference_model,
inference_stack,
common_params,
sample_messages,
sample_tool_definition,
test_case,
):
inference_impl, _ = inference_stack
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
)
]
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=messages,
tools=[sample_tool_definition],
tools=tc["tools"],
stream=False,
**common_params,
)
@ -339,32 +383,35 @@ class TestInference:
assert len(message.tool_calls) > 0
call = message.tool_calls[0]
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
assert call.tool_name == tc["tools"][0]["tool_name"]
for name, value in tc["expected"].items():
assert name in call.arguments
assert value in call.arguments[name]
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:sample_messages_tool_calling",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_with_tool_calling_streaming(
async def test_text_chat_completion_with_tool_calling_streaming(
self,
inference_model,
inference_stack,
common_params,
sample_messages,
sample_tool_definition,
test_case,
):
inference_impl, _ = inference_stack
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
)
]
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = [
r
async for r in await inference_impl.chat_completion(
model_id=inference_model,
messages=messages,
tools=[sample_tool_definition],
tools=tc["tools"],
stream=True,
**common_params,
)
@ -397,6 +444,7 @@ class TestInference:
assert isinstance(last.event.delta.tool_call, ToolCall)
call = last.event.delta.tool_call
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
assert call.tool_name == tc["tools"][0]["tool_name"]
for name, value in tc["expected"].items():
assert name in call.arguments
assert value in call.arguments[name]

View file

@ -1,24 +0,0 @@
{
"01": {
"name": "structured output",
"data": {
"notes": "We include context about Michael Jordan in the prompt so that the test is focused on the funtionality of the model and not on the information embedded in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant. Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
},
{
"role": "user",
"content": "Please give me information about Michael Jordan."
}
],
"expected": {
"first_name": "Michael",
"last_name": "Jordan",
"year_of_birth": 1963,
"num_seasons_in_nba": 15
}
}
}
}

View file

@ -1,13 +0,0 @@
{
"01": {
"name": "structured output",
"data": {
"user_input": "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003.",
"expected": {
"name": "Michael Jordan",
"year_born": "1963",
"year_retired": "2003"
}
}
}
}

View file

@ -0,0 +1,171 @@
{
"non_streaming_01": {
"data": {
"question": "Which planet do humans live on?",
"expected": "Earth"
}
},
"non_streaming_02": {
"data": {
"question": "Which planet has rings around it with a name starting with letter S?",
"expected": "Saturn"
}
},
"sample_messages": {
"data": {
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What's the weather like today?"
}
]
}
},
"streaming_01": {
"data": {
"question": "What's the name of the Sun in latin?",
"expected": "Sol"
}
},
"streaming_02": {
"data": {
"question": "What is the name of the US captial?",
"expected": "Washington"
}
},
"tool_calling": {
"data": {
"messages": [
{"role": "system", "content": "Pretend you are a weather assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"}
],
"tools": [
{
"tool_name": "get_weather",
"description": "Get the current weather",
"parameters": {
"location": {
"param_type": "string",
"description": "The city and state, e.g. San Francisco, CA"
}
}
}
],
"expected": {
"location": "San Francisco, CA"
}
}
},
"sample_messages_tool_calling": {
"data": {
"messages": [
{
"role": "system",
"content": "Pretend you are a weather assistant."
},
{
"role": "user",
"content": "What's the weather like today?"
},
{
"role": "user",
"content": "What's the weather like in San Francisco?"
}
],
"tools": [
{
"tool_name": "get_weather",
"description": "Get the current weather",
"parameters": {
"location": {
"param_type": "string",
"description": "The city and state, e.g. San Francisco, CA",
"required": true
}
}
}
],
"expected": {
"location": "San Francisco"
}
}
},
"structured_output": {
"data": {
"notes": "We include context about Michael Jordan in the prompt so that the test is focused on the funtionality of the model and not on the information embedded in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant. Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
},
{
"role": "user",
"content": "Please give me information about Michael Jordan."
}
],
"expected": {
"first_name": "Michael",
"last_name": "Jordan",
"year_of_birth": 1963,
"num_seasons_in_nba": 15
}
}
},
"tool_calling_tools_absent": {
"data": {
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What pods are in the namespace openshift-lightspeed?"
},
{
"role": "assistant",
"content": "",
"stop_reason": "end_of_turn",
"tool_calls": [
{
"call_id": "1",
"tool_name": "get_object_namespace_list",
"arguments": {
"kind": "pod",
"namespace": "openshift-lightspeed"
}
}
]
},
{
"role": "tool",
"call_id": "1",
"tool_name": "get_object_namespace_list",
"content": "the objects are pod1, pod2, pod3"
}
],
"tools": [
{
"tool_name": "get_object_namespace_list",
"description": "Get the list of objects in a namespace",
"parameters": {
"kind": {
"param_type": "string",
"description": "the type of object",
"required": true
},
"namespace": {
"param_type": "string",
"description": "the name of the namespace",
"required": true
}
}
}
]
}
}
}

View file

@ -0,0 +1,43 @@
{
"sanity": {
"data": {
"content": "Complete the sentence using one word: Roses are red, violets are "
}
},
"non_streaming": {
"data": {
"content": "Micheael Jordan is born in ",
"expected": "1963"
}
},
"streaming": {
"data": {
"content": "Roses are red,"
}
},
"log_probs": {
"data": {
"content": "Complete the sentence: Micheael Jordan is born in "
}
},
"logprobs_non_streaming": {
"data": {
"content": "Micheael Jordan is born in "
}
},
"logprobs_streaming": {
"data": {
"content": "Roses are red,"
}
},
"structured_output": {
"data": {
"user_input": "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003.",
"expected": {
"name": "Michael Jordan",
"year_born": "1963",
"year_retired": "2003"
}
}
}
}

View file

@ -9,7 +9,10 @@ import pathlib
class TestCase:
_apis = ["chat_completion", "completion"]
_apis = [
"inference/chat_completion",
"inference/completion",
]
_jsonblob = {}
def __init__(self, name):
@ -17,7 +20,12 @@ class TestCase:
if self._jsonblob == {}:
for api in self._apis:
with open(pathlib.Path(__file__).parent / f"{api}.json", "r") as f:
TestCase._jsonblob.update({f"{api}-{k}": v for k, v in json.load(f).items()})
coloned = api.replace("/", ":")
try:
loaded = json.load(f)
except json.JSONDecodeError as e:
raise ValueError(f"There is a syntax error in {api}.json: {e}") from e
TestCase._jsonblob.update({f"{coloned}:{k}": v for k, v in loaded.items()})
# loading this test case
tc = self._jsonblob.get(name)
@ -25,7 +33,6 @@ class TestCase:
raise ValueError(f"Test case {name} not found")
# these are the only fields we need
self.name = tc.get("name")
self.data = tc.get("data")
def __getitem__(self, key):

View file

@ -43,7 +43,7 @@ class SentenceTransformerEmbeddingMixin:
)
return EmbeddingsResponse(embeddings=embeddings)
def _load_sentence_transformer_model(self, model: str) -> SentenceTransformer:
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS
loaded_model = EMBEDDING_MODELS.get(model)

View file

@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
import litellm
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models.models import Model
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict_new,
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
convert_tooldef_to_openai_tool,
get_sampling_options,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
class LiteLLMOpenAIMixin(
ModelRegistryHelper,
Inference,
):
def __init__(self, model_entries) -> None:
self.model_entries = model_entries
ModelRegistryHelper.__init__(self, model_entries)
async def register_model(self, model: Model) -> Model:
model_id = self.get_provider_model_id(model.provider_resource_id)
if model_id is None:
raise ValueError(f"Unsupported model: {model.provider_resource_id}")
return model
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError("LiteLLM does not support completion requests")
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
params = await self._get_params(request)
# unfortunately, we need to use synchronous litellm.completion here because litellm
# caches various httpx.client objects in a non-eventloop aware manner
response = litellm.completion(**params)
if stream:
return self._stream_chat_completion(response)
else:
return convert_openai_chat_completion_choice(response.choices[0])
async def _stream_chat_completion(
self, response: litellm.ModelResponse
) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
async def _stream_generator():
for chunk in response:
yield chunk
async for chunk in convert_openai_chat_completion_stream(
_stream_generator(), enable_incremental_tool_calls=True
):
yield chunk
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
if fmt := request.response_format:
if not isinstance(fmt, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
fmt = fmt.json_schema
name = fmt["title"]
del fmt["title"]
fmt["additionalProperties"] = False
input_dict["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": name,
"schema": fmt,
"strict": True,
},
}
if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config.tool_choice:
input_dict["tool_choice"] = request.tool_config.tool_choice.value
return {
"model": request.model,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
response = litellm.embedding(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)
embeddings = [data["embedding"] for data in response["data"]]
return EmbeddingsResponse(embeddings=embeddings)

View file

@ -5,13 +5,58 @@
# the root directory of this source tree.
import json
import logging
from typing import AsyncGenerator, Dict, List, Optional, Union
import warnings
from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
from openai import AsyncStream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
)
from openai.types.chat import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat import (
ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
from openai.types.chat import (
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
)
from openai.types.chat import ChatCompletionMessageToolCall
from openai.types.chat import (
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
)
from openai.types.chat import (
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
)
from openai.types.chat import (
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
)
from openai.types.chat.chat_completion import (
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as OpenAIFunction,
)
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
TextContentItem,
TextDelta,
ToolCallDelta,
@ -27,13 +72,18 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
Message,
SystemMessage,
TokenLogProbs,
ToolResponseMessage,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
GreedySamplingStrategy,
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
@ -177,6 +227,31 @@ def process_chat_completion_response(
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
choice = response.choices[0]
if choice.finish_reason == "tool_calls":
if not choice.message or not choice.message.tool_calls:
raise ValueError("Tool calls are not present in the response")
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
# If we couldn't parse a tool call, jsonify the tool calls and return them
return ChatCompletionResponse(
completion_message=CompletionMessage(
stop_reason=StopReason.end_of_turn,
content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
),
logprobs=None,
)
else:
# Otherwise, return tool calls as normal
return ChatCompletionResponse(
completion_message=CompletionMessage(
tool_calls=tool_calls,
stop_reason=StopReason.end_of_turn,
# Content is not optional
content="",
),
logprobs=None,
)
# TODO: This does not work well with tool calls for vLLM remote provider
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
@ -417,6 +492,95 @@ class UnparseableToolCall(BaseModel):
arguments: str = ""
async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIChatCompletionMessage:
"""
Convert a Message to an OpenAI API-compatible dictionary.
"""
# users can supply a dict instead of a Message object, we'll
# convert it to a Message object and proceed with some type safety.
if isinstance(message, dict):
if "role" not in message:
raise ValueError("role is required in message")
if message["role"] == "user":
message = UserMessage(**message)
elif message["role"] == "assistant":
message = CompletionMessage(**message)
elif message["role"] == "tool":
message = ToolResponseMessage(**message)
elif message["role"] == "system":
message = SystemMessage(**message)
else:
raise ValueError(f"Unsupported message role: {message['role']}")
# Map Llama Stack spec to OpenAI spec -
# str -> str
# {"type": "text", "text": ...} -> {"type": "text", "text": ...}
# {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
# {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
# List[...] -> List[...]
async def _convert_user_message_content(
content: InterleavedContent,
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
# Llama Stack and OpenAI spec match for str and text input
if isinstance(content, str):
return OpenAIChatCompletionContentPartTextParam(
type="text",
text=content,
)
elif isinstance(content, TextContentItem):
return OpenAIChatCompletionContentPartTextParam(
type="text",
text=content.text,
)
elif isinstance(content, ImageContentItem):
return OpenAIChatCompletionContentPartImageParam(
type="image_url",
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
)
elif isinstance(content, List):
return [await _convert_user_message_content(item) for item in content]
else:
raise ValueError(f"Unsupported content type: {type(content)}")
out: OpenAIChatCompletionMessage = None
if isinstance(message, UserMessage):
out = OpenAIChatCompletionUserMessage(
role="user",
content=await _convert_user_message_content(message.content),
)
elif isinstance(message, CompletionMessage):
out = OpenAIChatCompletionAssistantMessage(
role="assistant",
content=message.content,
tool_calls=[
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=tool.tool_name,
arguments=json.dumps(tool.arguments),
),
type="function",
)
for tool in message.tool_calls
],
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
role="tool",
tool_call_id=message.call_id,
content=message.content,
)
elif isinstance(message, SystemMessage):
out = OpenAIChatCompletionSystemMessage(
role="system",
content=message.content,
)
else:
raise ValueError(f"Unsupported message type: {type(message)}")
return out
def convert_tool_call(
tool_call: ChatCompletionMessageToolCall,
) -> Union[ToolCall, UnparseableToolCall]:
@ -439,3 +603,365 @@ def convert_tool_call(
)
return valid_tool_call
def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
"""
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
ToolDefinition:
tool_name: str | BuiltinTool
description: Optional[str]
parameters: Optional[Dict[str, ToolParamDefinition]]
ToolParamDefinition:
param_type: str
description: Optional[str]
required: Optional[bool]
default: Optional[Any]
OpenAI spec -
{
"type": "function",
"function": {
"name": tool_name,
"description": description,
"parameters": {
"type": "object",
"properties": {
param_name: {
"type": param_type,
"description": description,
"default": default,
},
...
},
"required": [param_name, ...],
},
},
}
"""
out = {
"type": "function",
"function": {},
}
function = out["function"]
if isinstance(tool.tool_name, BuiltinTool):
function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient?
else:
function.update(name=tool.tool_name)
if tool.description:
function.update(description=tool.description)
if tool.parameters:
parameters = {
"type": "object",
"properties": {},
}
properties = parameters["properties"]
required = []
for param_name, param in tool.parameters.items():
properties[param_name] = {"type": param.param_type}
if param.description:
properties[param_name].update(description=param.description)
if param.default:
properties[param_name].update(default=param.default)
if param.required:
required.append(param_name)
if required:
parameters.update(required=required)
function.update(parameters=parameters)
return out
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
"""
Convert an OpenAI chat completion finish_reason to a StopReason.
finish_reason: Literal["stop", "length", "tool_calls", ...]
- stop: model hit a natural stop point or a provided stop sequence
- length: maximum number of tokens specified in the request was reached
- tool_calls: model called a tool
->
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
# TODO(mf): are end_of_turn and end_of_message semantics correct?
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
def _convert_openai_tool_calls(
tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]:
"""
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
OpenAI ChatCompletionMessageToolCall:
id: str
function: Function
type: Literal["function"]
OpenAI Function:
arguments: str
name: str
->
ToolCall:
call_id: str
tool_name: str
arguments: Dict[str, ...]
"""
if not tool_calls:
return [] # CompletionMessage tool_calls is not optional
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
)
for call in tool_calls
]
def _convert_openai_logprobs(
logprobs: OpenAIChoiceLogprobs,
) -> Optional[List[TokenLogProbs]]:
"""
Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.
OpenAI ChoiceLogprobs:
content: Optional[List[ChatCompletionTokenLogprob]]
OpenAI ChatCompletionTokenLogprob:
token: str
logprob: float
top_logprobs: List[TopLogprob]
OpenAI TopLogprob:
token: str
logprob: float
->
TokenLogProbs:
logprobs_by_token: Dict[str, float]
- token, logprob
"""
if not logprobs:
return None
return [
TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
for content in logprobs.content
]
def convert_openai_chat_completion_choice(
choice: OpenAIChoice,
) -> ChatCompletionResponse:
"""
Convert an OpenAI Choice into a ChatCompletionResponse.
OpenAI Choice:
message: ChatCompletionMessage
finish_reason: str
logprobs: Optional[ChoiceLogprobs]
OpenAI ChatCompletionMessage:
role: Literal["assistant"]
content: Optional[str]
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
->
ChatCompletionResponse:
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]]
CompletionMessage:
role: Literal["assistant"]
content: str | ImageMedia | List[str | ImageMedia]
stop_reason: StopReason
tool_calls: List[ToolCall]
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
assert hasattr(choice, "finish_reason") and choice.finish_reason, (
"error in server response: finish_reason not found"
)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content or "", # CompletionMessage content is not optional
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
),
logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)),
)
async def convert_openai_chat_completion_stream(
stream: AsyncStream[OpenAIChatCompletionChunk],
enable_incremental_tool_calls: bool,
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
"""
Convert a stream of OpenAI chat completion chunks into a stream
of ChatCompletionResponseStreamChunk.
"""
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]:
yield ChatCompletionResponseEventType.start
while True:
yield ChatCompletionResponseEventType.progress
event_type = _event_type_generator()
stop_reason = None
toolcall_buffer = {}
async for chunk in stream:
choice = chunk.choices[0] # assuming only one choice per chunk
# we assume there's only one finish_reason in the stream
stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason
logprobs = getattr(choice, "logprobs", None)
# if there's a tool call, emit an event for each tool in the list
# if tool call and content, emit both separately
if choice.delta.tool_calls:
# the call may have content and a tool call. ChatCompletionResponseEvent
# does not support both, so we emit the content first
if choice.delta.content:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=TextDelta(text=choice.delta.content),
logprobs=_convert_openai_logprobs(logprobs),
)
)
# it is possible to have parallel tool calls in stream, but
# ChatCompletionResponseEvent only supports one per stream
if len(choice.delta.tool_calls) > 1:
warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest")
if not enable_incremental_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=ToolCallDelta(
tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
parse_status=ToolCallParseStatus.succeeded,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
else:
tool_call = choice.delta.tool_calls[0]
if "name" not in toolcall_buffer:
toolcall_buffer["call_id"] = tool_call.id
toolcall_buffer["name"] = None
toolcall_buffer["content"] = ""
if "arguments" not in toolcall_buffer:
toolcall_buffer["arguments"] = ""
if tool_call.function.name:
toolcall_buffer["name"] = tool_call.function.name
delta = f"{toolcall_buffer['name']}("
if tool_call.function.arguments:
toolcall_buffer["arguments"] += tool_call.function.arguments
delta = toolcall_buffer["arguments"]
toolcall_buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=TextDelta(text=choice.delta.content or ""),
logprobs=_convert_openai_logprobs(logprobs),
)
)
if toolcall_buffer:
delta = ")"
toolcall_buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
try:
arguments = json.loads(toolcall_buffer["arguments"])
tool_call = ToolCall(
call_id=toolcall_buffer["call_id"],
tool_name=toolcall_buffer["name"],
arguments=arguments,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
except json.JSONDecodeError:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=ToolCallDelta(
tool_call=toolcall_buffer["content"],
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)

View file

@ -57,17 +57,6 @@ def get_distribution_template() -> DistributionTemplate:
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
default_models = [
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
provider_model_id=m.provider_model_id,
provider_id="fireworks",
metadata=m.metadata,
model_type=m.model_type,
)
for m in MODEL_ENTRIES
]
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
@ -82,6 +71,16 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="code-interpreter",
),
]
core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
default_models = [
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
provider_id="fireworks",
model_type=m.model_type,
metadata=m.metadata,
)
for m in MODEL_ENTRIES
]
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
@ -98,7 +97,7 @@ def get_distribution_template() -> DistributionTemplate:
container_image=None,
template_path=None,
providers=providers,
default_models=default_models,
default_models=default_models + [embedding_model],
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={

View file

@ -93,59 +93,48 @@ models:
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-8B
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-11B-Vision
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata:
embedding_dimension: 768
context_length: 8192
model_id: nomic-ai/nomic-embed-text-v1.5
provider_id: fireworks
provider_model_id: nomic-ai/nomic-embed-text-v1.5
model_type: embedding
- metadata:
embedding_dimension: 384

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .dev import get_distribution_template # noqa: F401

View file

@ -0,0 +1,36 @@
version: '2'
distribution_spec:
description: Distribution for running e2e tests in CI
providers:
inference:
- remote::openai
- remote::fireworks
- remote::anthropic
- remote::gemini
- inline::sentence-transformers
vector_io:
- inline::sqlite-vec
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime
- remote::model-context-protocol
image_type: conda

View file

@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Tuple
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
from llama_stack.providers.remote.inference.anthropic.models import MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.gemini.models import MODEL_ENTRIES as GEMINI_MODEL_ENTRIES
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
# in this template, we allow each API key to be optional
providers = [
(
"openai",
OPENAI_MODEL_ENTRIES,
OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"),
),
(
"fireworks",
FIREWORKS_MODEL_ENTRIES,
FireworksImplConfig.sample_run_config(api_key="${env.FIREWORKS_API_KEY:}"),
),
(
"anthropic",
ANTHROPIC_MODEL_ENTRIES,
AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"),
),
(
"gemini",
GEMINI_MODEL_ENTRIES,
GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"),
),
]
inference_providers = []
default_models = []
core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
for provider_id, model_entries, config in providers:
inference_providers.append(
Provider(
provider_id=provider_id,
provider_type=f"remote::{provider_id}",
config=config,
)
)
default_models.extend(
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
provider_model_id=m.provider_model_id,
provider_id=provider_id,
model_type=m.model_type,
metadata=m.metadata,
)
for m in model_entries
)
return inference_providers, default_models
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": [
"remote::openai",
"remote::fireworks",
"remote::anthropic",
"remote::gemini",
"inline::sentence-transformers",
],
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
}
name = "dev"
vector_io_provider = Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
)
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id=embedding_provider.provider_id,
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
inference_providers, default_models = get_inference_providers()
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Distribution for running e2e tests in CI",
container_image=None,
template_path=None,
providers=providers,
default_models=[],
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": inference_providers + [embedding_provider],
"vector_io": [vector_io_provider],
},
default_models=default_models + [embedding_model],
default_tool_groups=default_tool_groups,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"FIREWORKS_API_KEY": (
"",
"Fireworks API Key",
),
"OPENAI_API_KEY": (
"",
"OpenAI API Key",
),
},
)

View file

@ -0,0 +1,263 @@
version: '2'
image_name: dev
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: openai
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:}
- provider_id: anthropic
provider_type: remote::anthropic
config:
api_key: ${env.ANTHROPIC_API_KEY:}
- provider_id: gemini
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config: {}
- provider_id: localfs
provider_type: inline::localfs
config: {}
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db
models:
- metadata: {}
model_id: openai/gpt-4o
provider_id: openai
provider_model_id: openai/gpt-4o
model_type: llm
- metadata: {}
model_id: openai/gpt-4o-mini
provider_id: openai
provider_model_id: openai/gpt-4o-mini
model_type: llm
- metadata: {}
model_id: openai/chatgpt-4o-latest
provider_id: openai
provider_model_id: openai/chatgpt-4o-latest
model_type: llm
- metadata:
embedding_dimension: 1536
context_length: 8192
model_id: openai/text-embedding-3-small
provider_id: openai
provider_model_id: openai/text-embedding-3-small
model_type: embedding
- metadata:
embedding_dimension: 3072
context_length: 8192
model_id: openai/text-embedding-3-large
provider_id: openai
provider_model_id: openai/text-embedding-3-large
model_type: embedding
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-8B
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-11B-Vision
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata:
embedding_dimension: 768
context_length: 8192
model_id: nomic-ai/nomic-embed-text-v1.5
provider_id: fireworks
provider_model_id: nomic-ai/nomic-embed-text-v1.5
model_type: embedding
- metadata: {}
model_id: anthropic/claude-3-5-sonnet-latest
provider_id: anthropic
provider_model_id: anthropic/claude-3-5-sonnet-latest
model_type: llm
- metadata: {}
model_id: anthropic/claude-3-7-sonnet-latest
provider_id: anthropic
provider_model_id: anthropic/claude-3-7-sonnet-latest
model_type: llm
- metadata: {}
model_id: anthropic/claude-3-5-haiku-latest
provider_id: anthropic
provider_model_id: anthropic/claude-3-5-haiku-latest
model_type: llm
- metadata:
embedding_dimension: 1024
context_length: 32000
model_id: anthropic/voyage-3
provider_id: anthropic
provider_model_id: anthropic/voyage-3
model_type: embedding
- metadata:
embedding_dimension: 512
context_length: 32000
model_id: anthropic/voyage-3-lite
provider_id: anthropic
provider_model_id: anthropic/voyage-3-lite
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 32000
model_id: anthropic/voyage-code-3
provider_id: anthropic
provider_model_id: anthropic/voyage-code-3
model_type: embedding
- metadata: {}
model_id: gemini/gemini-1.5-flash
provider_id: gemini
provider_model_id: gemini/gemini-1.5-flash
model_type: llm
- metadata: {}
model_id: gemini/gemini-1.5-pro
provider_id: gemini
provider_model_id: gemini/gemini-1.5-pro
model_type: llm
- metadata:
embedding_dimension: 768
context_length: 2048
model_id: gemini/text-embedding-004
provider_id: gemini
provider_model_id: gemini/text-embedding-004
model_type: embedding
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server:
port: 8321

View file

@ -6,6 +6,7 @@ distribution_spec:
providers:
inference:
- inline::meta-reference
- remote::ollama
eval:
- inline::meta-reference
scoring:
@ -15,7 +16,6 @@ distribution_spec:
- inline::torchtune
datasetio:
- inline::localfs
- remote::huggingface
telemetry:
- inline::meta-reference
agents:

View file

@ -21,6 +21,10 @@ providers:
max_seq_len: 4096
checkpoint_dir: null
create_distributed_process_group: False
- provider_id: ollama
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:http://localhost:11434}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
@ -34,9 +38,6 @@ providers:
config:
openai_api_key: ${env.OPENAI_API_KEY:}
datasetio:
- provider_id: huggingface-0
provider_type: remote::huggingface
config: {}
- provider_id: localfs
provider_type: inline::localfs
config: {}
@ -47,7 +48,9 @@ providers:
post_training:
- provider_id: torchtune-post-training
provider_type: inline::torchtune
config: {}
config: {
checkpoint_format: huggingface
}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .groq import get_distribution_template # noqa: F401

View file

@ -0,0 +1,29 @@
version: '2'
distribution_spec:
description: Use Groq for running LLM inference
providers:
inference:
- remote::groq
vector_io:
- inline::faiss
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime
image_type: conda

View file

@ -0,0 +1,68 @@
---
orphan: true
---
# Groq Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} ({{ model.provider_model_id }})`
{% endfor %}
{% endif %}
### Prerequisite: API Keys
Make sure you have access to a Groq API Key. You can get one by visiting [Groq](https://api.groq.com/).
## Running Llama Stack with Groq
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \
--env GROQ_API_KEY=$GROQ_API_KEY
```
### Via Conda
```bash
llama stack build --template groq --image-type conda
llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env GROQ_API_KEY=$GROQ_API_KEY
```

View file

@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ToolGroupInput,
)
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.groq.models import _MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::groq"],
"vector_io": ["inline::faiss"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
],
}
name = "groq"
inference_provider = Provider(
provider_id=name,
provider_type=f"remote::{name}",
config=GroqConfig.sample_run_config(),
)
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
)
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
default_models = [
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model],
provider_model_id=m.provider_model_id,
provider_id=name,
)
for m in _MODEL_ENTRIES
]
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Use Groq for running LLM inference",
docker_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
default_models=default_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider, embedding_provider],
},
default_models=default_models + [embedding_model],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"GROQ_API_KEY": (
"",
"Groq API Key",
),
},
)

View file

@ -0,0 +1,136 @@
version: '2'
image_name: groq
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/groq/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config: {}
- provider_id: localfs
provider_type: inline::localfs
config: {}
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/registry.db
models:
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: groq
provider_model_id: llama3-8b-8192
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: groq
provider_model_id: llama-3.1-8b-instant
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3-70B-Instruct
provider_id: groq
provider_model_id: llama3-70b-8192
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: groq
provider_model_id: llama-3.3-70b-versatile
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: groq
provider_model_id: llama-3.2-3b-preview
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server:
port: 8321

View file

@ -74,6 +74,7 @@ docs = [
"sphinxcontrib.redoc",
"sphinxcontrib.video",
"sphinxcontrib.mermaid",
"tomli",
]
[project.urls]

View file

@ -96,7 +96,7 @@ def agent_config(llama_stack_client, text_model_id):
sampling_params={
"strategy": {
"type": "top_p",
"temperature": 1.0,
"temperature": 0.0001,
"top_p": 0.9,
},
},
@ -441,7 +441,8 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
assert "get_boiling_point" in logs_str
def test_rag_agent(llama_stack_client, agent_config):
@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
Document(
@ -469,7 +470,7 @@ def test_rag_agent(llama_stack_client, agent_config):
**agent_config,
"toolgroups": [
dict(
name="builtin::rag",
name=rag_tool_name,
args={
"vector_db_ids": [vector_db_id],
},
@ -483,10 +484,6 @@ def test_rag_agent(llama_stack_client, agent_config):
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
"grouped",
),
(
"What `tune` command to use for getting access to Llama3-8B-Instruct ?",
"download",
),
]
for prompt, expected_kw in user_prompts:
response = rag_agent.create_turn(
@ -496,23 +493,36 @@ def test_rag_agent(llama_stack_client, agent_config):
)
# rag is called
tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
assert tool_execution_step.tool_calls[0].tool_name == "query_from_memory"
assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
# document ids are present in metadata
assert "num-0" in tool_execution_step.tool_responses[0].metadata["document_ids"]
assert expected_kw in response.output_message.content.lower()
assert all(
doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
)
if expected_kw:
assert expected_kw in response.output_message.content.lower()
def test_rag_and_code_agent(llama_stack_client, agent_config):
urls = ["chat.rst"]
documents = [
documents = []
documents.append(
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
document_id="nba_wiki",
content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
metadata={},
)
for i, url in enumerate(urls)
]
)
documents.append(
Document(
document_id="perplexity_wiki",
content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
Srinivas, the CEO, worked at OpenAI as an AI researcher.
Konwinski was among the founding team at Databricks.
Yarats, the CTO, was an AI research scientist at Meta.
Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
metadata={},
)
)
vector_db_id = f"test-vector-db-{uuid4()}"
llama_stack_client.vector_dbs.register(
vector_db_id=vector_db_id,
@ -528,7 +538,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
**agent_config,
"toolgroups": [
dict(
name="builtin::rag",
name="builtin::rag/knowledge_search",
args={"vector_db_ids": [vector_db_id]},
),
"builtin::code_interpreter",
@ -546,24 +556,34 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
"Here is a csv file, can you describe it?",
[inflation_doc],
"code_interpreter",
"",
),
(
"What are the top 5 topics that were explained? Only list succinct bullet points.",
"when was Perplexity the company founded?",
[],
"query_from_memory",
"knowledge_search",
"2022",
),
(
"when was the nba created?",
[],
"knowledge_search",
"1949",
),
]
for prompt, docs, tool_name in user_prompts:
for prompt, docs, tool_name, expected_kw in user_prompts:
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
documents=docs,
stream=False,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert f"Tool:{tool_name}" in logs_str
tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
assert tool_execution_step.tool_calls[0].tool_name == tool_name
if expected_kw:
assert expected_kw in response.output_message.content.lower()
def test_create_turn_response(llama_stack_client, agent_config):

View file

@ -116,12 +116,14 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
providers = [p for p in client.providers.list() if p.api == "inference"]
assert len(providers) > 0, "No inference providers found"
inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
if text_model_id:
model_ids = [m.identifier for m in client.models.list()]
if text_model_id and text_model_id not in model_ids:
client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
if vision_model_id:
if vision_model_id and vision_model_id not in model_ids:
client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
if embedding_model_id and embedding_dimension:
if embedding_model_id and embedding_dimension and embedding_model_id not in model_ids:
# try to find a provider that supports embeddings, if sentence-transformers is not available
selected_provider = None
for p in providers:

View file

@ -19,6 +19,16 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
def skip_if_model_doesnt_support_completion(client_with_models, model_id):
models = {m.identifier: m for m in client_with_models.models.list()}
provider_id = models[model_id].provider_id
providers = {p.provider_id: p for p in client_with_models.providers.list()}
provider = providers[provider_id]
print(f"Provider: {provider.provider_type} for model {model_id}")
if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
@pytest.fixture(scope="session")
def provider_tool_format(inference_provider_type):
return (
@ -28,23 +38,18 @@ def provider_tool_format(inference_provider_type):
)
@pytest.fixture
def get_weather_tool_definition():
return {
"tool_name": "get_weather",
"description": "Get the current weather",
"parameters": {
"location": {
"param_type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
},
}
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:sanity",
],
)
def test_text_completion_non_streaming(client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
def test_text_completion_non_streaming(client_with_models, text_model_id):
response = client_with_models.inference.completion(
content="Complete the sentence using one word: Roses are red, violets are ",
content=tc["content"],
stream=False,
model_id=text_model_id,
sampling_params={
@ -55,9 +60,18 @@ def test_text_completion_non_streaming(client_with_models, text_model_id):
# assert "blue" in response.content.lower().strip()
def test_text_completion_streaming(client_with_models, text_model_id):
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:sanity",
],
)
def test_text_completion_streaming(client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
response = client_with_models.inference.completion(
content="Complete the sentence using one word: Roses are red, violets are ",
content=tc["content"],
stream=True,
model_id=text_model_id,
sampling_params={
@ -70,12 +84,21 @@ def test_text_completion_streaming(client_with_models, text_model_id):
assert len(content_str) > 10
def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type):
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:log_probs",
],
)
def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
tc = TestCase(test_case)
response = client_with_models.inference.completion(
content="Complete the sentence: Micheael Jordan is born in ",
content=tc["content"],
stream=False,
model_id=text_model_id,
sampling_params={
@ -90,12 +113,21 @@ def test_completion_log_probs_non_streaming(client_with_models, text_model_id, i
assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type):
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:log_probs",
],
)
def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
tc = TestCase(test_case)
response = client_with_models.inference.completion(
content="Complete the sentence: Micheael Jordan is born in ",
content=tc["content"],
stream=True,
model_id=text_model_id,
sampling_params={
@ -114,8 +146,15 @@ def test_completion_log_probs_streaming(client_with_models, text_model_id, infer
assert not chunk.logprobs, "Logprobs should be empty"
@pytest.mark.parametrize("test_case", ["completion-01"])
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:structured_output",
],
)
def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
class AnswerFormat(BaseModel):
name: str
year_born: str
@ -144,16 +183,17 @@ def test_text_completion_structured_output(client_with_models, text_model_id, te
@pytest.mark.parametrize(
"question,expected",
"test_case",
[
("Which planet do humans live on?", "Earth"),
(
"Which planet has rings around it with a name starting with letter S?",
"Saturn",
),
"inference:chat_completion:non_streaming_01",
"inference:chat_completion:non_streaming_02",
],
)
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected):
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
@ -170,13 +210,17 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, q
@pytest.mark.parametrize(
"question,expected",
"test_case",
[
("What's the name of the Sun in latin?", "Sol"),
("What is the name of the US captial?", "Washington"),
"inference:chat_completion:streaming_01",
"inference:chat_completion:streaming_02",
],
)
def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected):
def test_text_chat_completion_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[{"role": "user", "content": question}],
@ -187,28 +231,34 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, quest
assert expected.lower() in "".join(streamed_content)
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_calling_and_non_streaming(
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
client_with_models, text_model_id, provider_tool_format, test_case
):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
messages=tc["messages"],
tools=tc["tools"],
tool_choice="auto",
tool_prompt_format=provider_tool_format,
tool_prompt_format=tool_prompt_format,
stream=False,
)
# No content is returned for the system message since we expect the
# response to be a tool call
assert response.completion_message.content == ""
# some models can return content for the response in addition to the tool call
assert response.completion_message.role == "assistant"
assert len(response.completion_message.tool_calls) == 1
assert response.completion_message.tool_calls[0].tool_name == "get_weather"
assert response.completion_message.tool_calls[0].arguments == {"location": "San Francisco, CA"}
assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
assert response.completion_message.tool_calls[0].arguments == tc["expected"]
# Will extract streamed text and separate it from tool invocation content
@ -224,57 +274,80 @@ def extract_tool_invocation_content(response):
return tool_invocation_content
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_calling_and_streaming(
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
client_with_models, text_model_id, provider_tool_format, test_case
):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
messages=tc["messages"],
tools=tc["tools"],
tool_choice="auto",
tool_prompt_format=provider_tool_format,
tool_prompt_format=tool_prompt_format,
stream=True,
)
tool_invocation_content = extract_tool_invocation_content(response)
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
expected_tool_name = tc["tools"][0]["tool_name"]
expected_argument = tc["expected"]
assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_choice_required(
client_with_models,
text_model_id,
get_weather_tool_definition,
provider_tool_format,
test_case,
):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
messages=tc["messages"],
tools=tc["tools"],
tool_config={
"tool_choice": "required",
"tool_prompt_format": provider_tool_format,
"tool_prompt_format": tool_prompt_format,
},
stream=True,
)
tool_invocation_content = extract_tool_invocation_content(response)
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
expected_tool_name = tc["tools"][0]["tool_name"]
expected_argument = tc["expected"]
assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"
def test_text_chat_completion_with_tool_choice_none(
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
):
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case):
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
messages=tc["messages"],
tools=tc["tools"],
tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
stream=True,
)
@ -282,7 +355,12 @@ def test_text_chat_completion_with_tool_choice_none(
assert tool_invocation_content == ""
@pytest.mark.parametrize("test_case", ["chat_completion-01"])
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:structured_output",
],
)
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
class AnswerFormat(BaseModel):
first_name: str
@ -309,64 +387,24 @@ def test_text_chat_completion_structured_output(client_with_models, text_model_i
assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"]
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize(
"streaming",
"test_case",
[
True,
False,
"inference:chat_completion:tool_calling_tools_absent",
],
)
def test_text_chat_completion_tool_calling_tools_not_in_request(client_with_models, text_model_id, streaming):
def test_text_chat_completion_tool_calling_tools_not_in_request(
client_with_models, text_model_id, test_case, streaming
):
tc = TestCase(test_case)
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
request = {
"model_id": text_model_id,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "What pods are in the namespace openshift-lightspeed?",
},
{
"role": "assistant",
"content": "",
"stop_reason": "end_of_turn",
"tool_calls": [
{
"call_id": "1",
"tool_name": "get_object_namespace_list",
"arguments": {
"kind": "pod",
"namespace": "openshift-lightspeed",
},
}
],
},
{
"role": "tool",
"call_id": "1",
"tool_name": "get_object_namespace_list",
"content": "the objects are pod1, pod2, pod3",
},
],
"tools": [
{
"tool_name": "get_object_namespace_list",
"description": "Get the list of objects in a namespace",
"parameters": {
"kind": {
"param_type": "string",
"description": "the type of object",
"required": True,
},
"namespace": {
"param_type": "string",
"description": "the name of the namespace",
"required": True,
},
},
}
],
"messages": tc["messages"],
"tools": tc["tools"],
"tool_choice": "auto",
"tool_prompt_format": tool_prompt_format,
"stream": streaming,

4
uv.lock generated
View file

@ -1,5 +1,4 @@
version = 1
revision = 1
requires-python = ">=3.10"
resolution-markers = [
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
@ -913,6 +912,7 @@ docs = [
{ name = "sphinxcontrib-mermaid" },
{ name = "sphinxcontrib-redoc" },
{ name = "sphinxcontrib-video" },
{ name = "tomli" },
]
test = [
{ name = "aiosqlite" },
@ -971,13 +971,13 @@ requires-dist = [
{ name = "sphinxcontrib-redoc", marker = "extra == 'docs'" },
{ name = "sphinxcontrib-video", marker = "extra == 'docs'" },
{ name = "termcolor" },
{ name = "tomli", marker = "extra == 'docs'" },
{ name = "torch", marker = "extra == 'test'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu" },
{ name = "torchvision", marker = "extra == 'test'", specifier = ">=0.21.0", index = "https://download.pytorch.org/whl/cpu" },
{ name = "types-requests", marker = "extra == 'dev'" },
{ name = "types-setuptools", marker = "extra == 'dev'" },
{ name = "uvicorn", marker = "extra == 'dev'" },
]
provides-extras = ["dev", "test", "docs"]
[[package]]
name = "llama-stack-client"