Merge branch 'main' into fiddlecube-guard
Commit 7fdc50a837
324 changed files with 12802 additions and 3145 deletions
.github/PULL_REQUEST_TEMPLATE.md (1 line changed)
@@ -8,4 +8,3 @@
[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]

[//]: # (## Documentation)
[//]: # (- [ ] Added a Changelog entry if the change is significant)
@@ -30,10 +30,7 @@ repos:
rev: v0.9.4
hooks:
- id: ruff
args: [
--fix,
--exit-non-zero-on-fix
]
exclude: ^llama_stack/strong_typing/.*$
- id: ruff-format

- repo: https://github.com/adamchainz/blacken-docs
@@ -47,7 +44,13 @@ repos:
rev: 0.5.26
hooks:
- id: uv-export
args: ["--frozen", "--no-hashes", "--no-emit-project"]
args: [
"--frozen",
"--no-hashes",
"--no-emit-project",
"--output-file=requirements.txt"
]
files: ^pyproject\.toml$
- id: uv-sync

# - repo: https://github.com/pre-commit/mirrors-mypy
.ruff.toml (37 lines changed)
@@ -1,37 +0,0 @@
# Suggested config from pytorch that we can adapt
lint.select = ["B", "C", "E" , "F" , "N", "W", "B9"]

line-length = 120

# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
# N812 ignored because import torch.nn.functional as F is PyTorch convention
# N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
# E731 allow usage of assigning lambda expressions
# E701 let black auto-format statements on one line
# E704 let black auto-format statements on one line
lint.ignore = [
"E203", "E305", "E402", "E501", "E721", "E741", "F405", "F821", "F841",
"C408", "E302", "W291", "E303", "N812", "N817", "E731", "E701",
# These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
"C901", "C405", "C414", "N803", "N999", "C403", "C416", "B028", "C419", "C401", "B023",
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
"EXE001",
# random naming hints don't need
"N802",
# these ignores are from flake8-bugbear; please fix!
"B007", "B008"
]

exclude = [
"./.git",
"./docs/*",
"./build",
"./scripts",
"./venv",
"*.pyi",
".pre-commit-config.yaml",
"*.md",
".flake8"
]
@@ -40,6 +40,7 @@ If you need help or guidance, comment on the issue. Issues that are extra friend
3. Ensure the test suite passes.
4. Make sure your code lints using `pre-commit`.
5. If you haven't already, complete the Contributor License Agreement ("CLA").
6. Ensure your pull request follows the [conventional commits format](https://www.conventionalcommits.org/en/v1.0.0/).

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
@@ -98,7 +99,8 @@ $ uv sync
```

## Coding Style
* 2 spaces for indentation rather than tabs
* 4 spaces for indentation rather than tabs
* 80 character line length
* ...
README.md (33 lines changed)
@@ -7,13 +7,13 @@

[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)

Llama Stack defines and standardizes the core building blocks that simplify AI application development. It codified best practices across the Llama ecosystem. More specifically, it provides
Llama Stack standardizes the core building blocks that simplify AI application development. It codifies best practices across the Llama ecosystem. More specifically, it provides

- **Unified API layer** for Inference, RAG, Agents, Tools, Safety, Evals, and Telemetry.
- **Plugin architecture** to support the rich ecosystem of implementations of the different APIs in different environments like local development, on-premises, cloud, and mobile.
- **Prepackaged verified distributions** which offer a one-stop solution for developers to get started quickly and reliably in any environment
- **Multiple developer interfaces** like CLI and SDKs for Python, Typescript, iOS, and Android
- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack
- **Plugin architecture** to support the rich ecosystem of different API implementations in various environments, including local development, on-premises, cloud, and mobile.
- **Prepackaged verified distributions** which offer a one-stop solution for developers to get started quickly and reliably in any environment.
- **Multiple developer interfaces** like CLI and SDKs for Python, Typescript, iOS, and Android.
- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack.

<div style="text-align: center;">
<img
@@ -25,14 +25,14 @@ Llama Stack defines and standardizes the core building blocks that simplify AI a
</div>

### Llama Stack Benefits
- **Flexible Options**: Developers can choose their preferred infrastructure without changing APIs and enjoy flexible deployment choice.
- **Consistent Experience**: With its unified APIs Llama Stack makes it easier to build, test, and deploy AI applications with consistent application behavior.
- **Flexible Options**: Developers can choose their preferred infrastructure without changing APIs and enjoy flexible deployment choices.
- **Consistent Experience**: With its unified APIs, Llama Stack makes it easier to build, test, and deploy AI applications with consistent application behavior.
- **Robust Ecosystem**: Llama Stack is already integrated with distribution partners (cloud providers, hardware vendors, and AI-focused companies) that offer tailored infrastructure, software, and services for deploying Llama models.

By reducing friction and complexity, Llama Stack empowers developers to focus on what they do best: building transformative generative AI applications.

### API Providers
Here is a list of the various API providers and available distributions to developers started easily,
Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.

| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
@@ -71,15 +71,15 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider

You have two ways to install this repository:

1. **Install as a package**:
* **Install as a package**:
You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
```bash
pip install llama-stack
```

2. **Install from source**:
* **Install from source**:
If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable).
Then, follow these steps:
Then, run the following commands:
```bash
mkdir -p ~/local
cd ~/local
@@ -96,10 +96,11 @@ You have two ways to install this repository:

Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest/index.html) page for more details.

* [CLI reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html)
* Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
* [Getting Started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html)
* Quick guide to start a Llama Stack server.
* CLI references
* [llama (server-side) CLI Reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html): Guide for using the `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
* [llama (client-side) CLI Reference](https://llama-stack.readthedocs.io/en/latest/references/llama_stack_client_cli_reference.html): Guide for using the `llama-stack-client` CLI, which allows you to query information about the distribution.
* Getting Started
* [Quick guide to start a Llama Stack server](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html).
* [Jupyter notebook](./docs/getting_started.ipynb) to walk-through how to use simple text and vision inference llama_stack_client APIs
* The complete Llama Stack lesson [Colab notebook](https://colab.research.google.com/drive/1dtVmxotBsI4cGZQNsJRYPrLiDeT0Wnwt) of the new [Llama 3.2 course on Deeplearning.ai](https://learn.deeplearning.ai/courses/introducing-multimodal-llama-3-2/lesson/8/llama-stack).
* A [Zero-to-Hero Guide](https://github.com/meta-llama/llama-stack/tree/main/docs/zero_to_hero_guide) that guide you through all the key components of llama stack with code samples.
@@ -115,6 +116,6 @@ Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest
| Typescript | [llama-stack-client-typescript](https://github.com/meta-llama/llama-stack-client-typescript) | [](https://npmjs.org/package/llama-stack-client)
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)

Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.
Check out our client SDKs for connecting to a Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.

You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repo.
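For a quick feel of the Python SDK mentioned above, here is a minimal sketch of connecting to a locally running distribution and issuing an inference call. The server URL and the model id are assumptions that depend on how your distribution was built and run; this is an illustration, not part of this commit.

```python
from llama_stack_client import LlamaStackClient

# Assumes a distribution is already running locally on the default port.
client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.2-3B-Instruct",  # replace with a model your distribution serves
    messages=[{"role": "user", "content": "Write a haiku about unified APIs."}],
)
print(response.completion_message.content)
```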
docs/_static/llama-stack-spec.html (2460 lines changed; diff suppressed because it is too large)
docs/_static/llama-stack-spec.yaml (1573 lines changed; diff suppressed because it is too large)
|
|
@ -324,7 +324,7 @@
|
|||
"- vector_io\n",
|
||||
"container_image: null\n",
|
||||
"datasets: <span style=\"font-weight: bold\">[]</span>\n",
|
||||
"eval_tasks: <span style=\"font-weight: bold\">[]</span>\n",
|
||||
"benchmarks: <span style=\"font-weight: bold\">[]</span>\n",
|
||||
"image_name: together\n",
|
||||
"metadata_store:\n",
|
||||
" db_path: <span style=\"color: #800080; text-decoration-color: #800080\">/Users/ashwin/.llama/distributions/together/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">registry.db</span>\n",
|
||||
|
|
@ -508,7 +508,7 @@
|
|||
"- vector_io\n",
|
||||
"container_image: null\n",
|
||||
"datasets: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
|
||||
"eval_tasks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
|
||||
"benchmarks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
|
||||
"image_name: together\n",
|
||||
"metadata_store:\n",
|
||||
" db_path: \u001b[35m/Users/ashwin/.llama/distributions/together/\u001b[0m\u001b[95mregistry.db\u001b[0m\n",
|
||||
|
|
@ -3419,22 +3419,22 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "865fc5a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install llama-stack-client==0.1.0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"id": "44e05e16",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
|
||||
" Dload Upload Total Spent Left Speed\n",
|
||||
"100 275k 100 275k 0 0 780k 0 --:--:-- --:--:-- --:--:-- 780k\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!wget https://raw.githubusercontent.com/meta-llama/llama-models/refs/heads/main/Llama_Repo.jpeg"
|
||||
"!curl -O https://raw.githubusercontent.com/meta-llama/llama-models/refs/heads/main/Llama_Repo.jpeg"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -3444,6 +3444,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# NBVAL_SKIP\n",
|
||||
"from PIL import Image\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
|
|
@ -3580,6 +3581,7 @@
|
|||
" model=LLAMA32_11B_INSTRUCT,\n",
|
||||
" instructions=\"You are a helpful assistant\",\n",
|
||||
" enable_session_persistence=False,\n",
|
||||
" toolgroups=[],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" agent = Agent(client, agent_config)\n",
|
||||
|
|
@ -3630,7 +3632,7 @@
|
|||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "toolchain",
|
||||
"display_name": "master",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
|
@ -3644,7 +3646,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.15"
|
||||
"version": "3.10.16"
|
||||
},
|
||||
"widgets": {
|
||||
"application/vnd.jupyter.widget-state+json": {
|
||||
|
|
|
|||
|
|
@ -370,7 +370,7 @@
|
|||
"- tool_runtime\n",
|
||||
"datasets: <span style=\"font-weight: bold\">[]</span>\n",
|
||||
"container_image: null\n",
|
||||
"eval_tasks: <span style=\"font-weight: bold\">[]</span>\n",
|
||||
"benchmarks: <span style=\"font-weight: bold\">[]</span>\n",
|
||||
"image_name: together\n",
|
||||
"memory_banks: <span style=\"font-weight: bold\">[]</span>\n",
|
||||
"metadata_store:\n",
|
||||
|
|
@ -551,7 +551,7 @@
|
|||
"- tool_runtime\n",
|
||||
"datasets: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
|
||||
"container_image: null\n",
|
||||
"eval_tasks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
|
||||
"benchmarks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
|
||||
"image_name: together\n",
|
||||
"memory_banks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
|
||||
"metadata_store:\n",
|
||||
|
|
|
|||
|
|
@@ -1,4 +1,4 @@
The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/[<subdir>]/api/endpoints.py` using the `generate.py` utility.
The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/distribution/server/endpoints.py` using the `generate.py` utility.

Please install the following packages before running the script:

@@ -6,4 +6,4 @@ Please install the following packages before running the script:
pip install python-openapi json-strong-typing fire PyYAML llama-models
```

Then simply run `sh run_openapi_generator.sh <OUTPUT_DIR>`
Then simply run `sh run_openapi_generator.sh`
|||
|
|
@ -16,18 +16,6 @@ from pathlib import Path
|
|||
import fire
|
||||
import ruamel.yaml as yaml
|
||||
|
||||
from llama_models import schema_utils
|
||||
|
||||
# We do some monkey-patching to ensure our definitions only use the minimal
|
||||
# (json_schema_type, webmethod) definitions from the llama_models package. For
|
||||
# generation though, we need the full definitions and implementations from the
|
||||
# (json-strong-typing) package.
|
||||
|
||||
from .strong_typing.schema import json_schema_type, register_schema
|
||||
|
||||
schema_utils.json_schema_type = json_schema_type
|
||||
schema_utils.register_schema = register_schema
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION # noqa: E402
|
||||
from llama_stack.distribution.stack import LlamaStack # noqa: E402
|
||||
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ import typing
|
|||
from dataclasses import make_dataclass
|
||||
from typing import Any, Dict, Set, Union
|
||||
|
||||
from ..strong_typing.core import JsonType
|
||||
from ..strong_typing.docstring import Docstring, parse_type
|
||||
from ..strong_typing.inspection import (
|
||||
from llama_stack.strong_typing.core import JsonType
|
||||
from llama_stack.strong_typing.docstring import Docstring, parse_type
|
||||
from llama_stack.strong_typing.inspection import (
|
||||
is_generic_list,
|
||||
is_type_optional,
|
||||
is_type_union,
|
||||
|
|
@ -20,15 +20,15 @@ from ..strong_typing.inspection import (
|
|||
unwrap_optional_type,
|
||||
unwrap_union_types,
|
||||
)
|
||||
from ..strong_typing.name import python_type_to_name
|
||||
from ..strong_typing.schema import (
|
||||
from llama_stack.strong_typing.name import python_type_to_name
|
||||
from llama_stack.strong_typing.schema import (
|
||||
get_schema_identifier,
|
||||
JsonSchemaGenerator,
|
||||
register_schema,
|
||||
Schema,
|
||||
SchemaOptions,
|
||||
)
|
||||
from ..strong_typing.serialization import json_dump_string, object_to_json
|
||||
from llama_stack.strong_typing.serialization import json_dump_string, object_to_json
|
||||
|
||||
from .operations import (
|
||||
EndpointOperation,
|
||||
|
|
@ -644,7 +644,10 @@ class Generator:
|
|||
else:
|
||||
callbacks = None
|
||||
|
||||
description = "\n".join(filter(None, [doc_string.short_description, doc_string.long_description]))
|
||||
description = "\n".join(
|
||||
filter(None, [doc_string.short_description, doc_string.long_description])
|
||||
)
|
||||
|
||||
return Operation(
|
||||
tags=[op.defining_class.__name__],
|
||||
summary=None,
|
||||
|
|
@ -654,6 +657,7 @@ class Generator:
|
|||
requestBody=requestBody,
|
||||
responses=responses,
|
||||
callbacks=callbacks,
|
||||
deprecated=True if "DEPRECATED" in op.func_name else None,
|
||||
security=[] if op.public else None,
|
||||
)
|
||||
|
||||
|
|
@ -681,6 +685,7 @@ class Generator:
|
|||
raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
|
||||
|
||||
route = op.get_route()
|
||||
route = route.replace(":path", "")
|
||||
print(f"route: {route}")
|
||||
if route in paths:
|
||||
paths[route].update(pathItem)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
|||
|
||||
from termcolor import colored
|
||||
|
||||
from ..strong_typing.inspection import get_signature
|
||||
from llama_stack.strong_typing.inspection import get_signature
|
||||
|
||||
|
||||
def split_prefix(
|
||||
|
|
@ -130,6 +130,8 @@ class _FormatParameterExtractor:
|
|||
|
||||
def _get_route_parameters(route: str) -> List[str]:
|
||||
extractor = _FormatParameterExtractor()
|
||||
# Replace all occurrences of ":path" with empty string
|
||||
route = route.replace(":path", "")
|
||||
route.format_map(extractor)
|
||||
return extractor.keys
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import enum
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Union
|
||||
|
||||
from ..strong_typing.schema import JsonType, Schema, StrictJsonType
|
||||
from llama_stack.strong_typing.schema import JsonType, Schema, StrictJsonType
|
||||
|
||||
URL = str
|
||||
|
||||
|
|
@ -117,6 +117,7 @@ class Operation:
|
|||
requestBody: Optional[RequestBody] = None
|
||||
callbacks: Optional[Dict[str, "Callback"]] = None
|
||||
security: Optional[List["SecurityRequirement"]] = None
|
||||
deprecated: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import typing
|
|||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
|
||||
from ..strong_typing.schema import object_to_json, StrictJsonType
|
||||
from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
|
||||
|
||||
from .generator import Generator
|
||||
from .options import Options
|
||||
|
|
|
|||
|
|
@ -41,14 +41,14 @@ system_message = {
|
|||
"content": SYSTEM_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
client.eval_tasks.register(
|
||||
eval_task_id="meta-reference::mmmu",
|
||||
client.benchmarks.register(
|
||||
benchmark_id="meta-reference::mmmu",
|
||||
dataset_id=f"mmmu-{subset}-{split}",
|
||||
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
|
||||
)
|
||||
|
||||
response = client.eval.evaluate_rows(
|
||||
task_id="meta-reference::mmmu",
|
||||
benchmark_id="meta-reference::mmmu",
|
||||
input_rows=eval_rows,
|
||||
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
|
||||
task_config={
|
||||
|
|
@ -99,14 +99,14 @@ eval_rows = client.datasetio.get_rows_paginated(
|
|||
```
|
||||
|
||||
```python
|
||||
client.eval_tasks.register(
|
||||
eval_task_id="meta-reference::simpleqa",
|
||||
client.benchmarks.register(
|
||||
benchmark_id="meta-reference::simpleqa",
|
||||
dataset_id=simpleqa_dataset_id,
|
||||
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||
)
|
||||
|
||||
response = client.eval.evaluate_rows(
|
||||
task_id="meta-reference::simpleqa",
|
||||
benchmark_id="meta-reference::simpleqa",
|
||||
input_rows=eval_rows.rows,
|
||||
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||
task_config={
|
||||
|
|
@ -156,7 +156,7 @@ agent_config = {
|
|||
}
|
||||
|
||||
response = client.eval.evaluate_rows(
|
||||
task_id="meta-reference::simpleqa",
|
||||
benchmark_id="meta-reference::simpleqa",
|
||||
input_rows=eval_rows.rows,
|
||||
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||
task_config={
|
||||
|
|
|
|||
|
|
@@ -10,15 +10,15 @@ Here's how to set up basic evaluation:

```python
# Create an evaluation task
response = client.eval_tasks.register(
    eval_task_id="my_eval",
response = client.benchmarks.register(
    benchmark_id="my_eval",
    dataset_id="my_dataset",
    scoring_functions=["accuracy", "relevance"],
)

# Run evaluation
job = client.eval.run_eval(
    task_id="my_eval",
    benchmark_id="my_eval",
    task_config={
        "type": "app",
        "eval_candidate": {"type": "agent", "config": agent_config},
@@ -26,5 +26,5 @@ job = client.eval.run_eval(
)

# Get results
result = client.eval.job_result(task_id="my_eval", job_id=job.job_id)
result = client.eval.job_result(benchmark_id="my_eval", job_id=job.job_id)
```
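For orientation, here is a minimal end-to-end sketch of the renamed flow using the client surface shown elsewhere in this diff (`client.benchmarks.register`, `client.eval.run_eval`, `client.eval.jobs.status`, `client.eval.job_result`). The dataset id, model id, candidate fields, and the completion check on the status value are illustrative assumptions, not definitions from this commit.

```python
import time

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # default port from `llama stack run`

# Register a benchmark (formerly an "eval task") against an existing dataset.
client.benchmarks.register(
    benchmark_id="my_eval",
    dataset_id="my_dataset",  # assumed to be registered already
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)

# Kick off an eval job for that benchmark; ModelCandidate fields are assumptions.
job = client.eval.run_eval(
    benchmark_id="my_eval",
    task_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.2-3B-Instruct",
            "sampling_params": {},
        },
    },
)

# Poll until the job finishes; the exact status values are an assumption here.
while True:
    status = client.eval.jobs.status(job_id=job.job_id, benchmark_id="my_eval")
    if status is None or str(status) in ("completed", "failed"):
        break
    time.sleep(5)

# EvaluateResponse; field names may differ slightly by client version.
result = client.eval.job_result(benchmark_id="my_eval", job_id=job.job_id)
print(result.scores)
```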
|||
|
|
@@ -5,7 +5,7 @@ The Llama Stack Evaluation flow allows you to run evaluations on your GenAI appl
We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
- `/datasetio` + `/datasets` API
- `/scoring` + `/scoring_functions` API
- `/eval` + `/eval_tasks` API
- `/eval` + `/benchmarks` API

This guide goes over the sets of APIs and developer experience flow of using Llama Stack to run evaluations for different use cases. Checkout our Colab notebook on working examples with evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).

@@ -21,7 +21,7 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
- **Scoring**: evaluate outputs of the system.
  - Associated with `ScoringFunction` resource. We provide a suite of out-of-the box scoring functions and also the ability for you to add custom evaluators. These scoring functions are the core part of defining an evaluation task to output evaluation metrics.
- **Eval**: generate outputs (via Inference or Agents) and perform scoring.
  - Associated with `EvalTask` resource.
  - Associated with `Benchmark` resource.

Use the following decision tree to decide how to use LlamaStack Evaluation flow.
@@ -42,7 +42,7 @@ Some of these APIs are associated with a set of **Resources**. Here is the mappi
- **Tool Runtime** is associated with `ToolGroup` resources.
- **DatasetIO** is associated with `Dataset` resources.
- **Scoring** is associated with `ScoringFunction` resources.
- **Eval** is associated with `Model` and `EvalTask` resources.
- **Eval** is associated with `Model` and `Benchmark` resources.

Furthermore, we allow these resources to be **federated** across multiple providers. For example, you may have some Llama models served by Fireworks while others are served by AWS Bedrock. Regardless, they will all work seamlessly with the same uniform Inference API provided by Llama Stack.
@@ -180,12 +180,45 @@ After this step is successful, you should be able to find the built container im
### Running your Stack server
Now, let's start the Llama Stack Distribution Server. You will need the YAML configuration file which was written out at the end by the `llama stack build` step.

```
llama stack run -h
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE]
                       [--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
                       config

start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.

positional arguments:
  config                Path to config file to use for the run

options:
  -h, --help            show this help message and exit
  --port PORT           Port to run the server on. Defaults to 8321
  --image-name IMAGE_NAME
                        Name of the image to run. Defaults to the current conda environment
  --disable-ipv6        Disable IPv6 support
  --env KEY=VALUE       Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.
  --tls-keyfile TLS_KEYFILE
                        Path to TLS key file for HTTPS
  --tls-certfile TLS_CERTFILE
                        Path to TLS certificate file for HTTPS
  --image-type {conda,container,venv}
                        Image Type used during the build. This can be either conda or container or venv.

```

```
# Start using template name
llama stack run tgi

# Start using config file
llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml

# Start using a venv
llama stack run --image-type venv ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml

# Start using a conda environment
llama stack run --image-type conda ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
```

```
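With the server running, a client can talk to it over HTTP. A minimal sketch, assuming the `llama-stack-client` Python package is installed and the server is on the default port from the help text above:

```python
from llama_stack_client import LlamaStackClient

# `llama stack run` listens on port 8321 unless --port is given.
client = LlamaStackClient(base_url="http://localhost:8321")

# Sanity-check the connection by listing the models the distribution exposes;
# the exact entries depend on the template/config used above.
for model in client.models.list():
    print(model.identifier)
```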
@@ -2,7 +2,7 @@
```{admonition} News
:class: tip

Llama Stack 0.1.2 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.2) for more details.
Llama Stack 0.1.3 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.3) for more details.
```

# Llama Stack

@@ -64,7 +64,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
```

```bash
$ llama-stack-client eval_tasks register \
$ llama-stack-client benchmarks register \
  --eval-task-id meta-reference-mmlu \
  --provider-id meta-reference \
  --dataset-id mmlu \
@@ -86,7 +86,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
  - Under the hood, it uses Llama Stack's `/providers` API to get information about the providers.

- **API Resources**: Inspect Llama Stack API resources
  - This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `eval_tasks`, `shields`).
  - This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `benchmarks`, `shields`).
  - Under the hood, it uses Llama Stack's `/<resources>/list` API to get information about each resources.
  - Please visit [Core Concepts](https://llama-stack.readthedocs.io/en/latest/concepts/index.html) for more details about the resources.
|||
|
|
@ -5,7 +5,7 @@ The Llama Stack Evaluation flow allows you to run evaluations on your GenAI appl
|
|||
We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
|
||||
- `/datasetio` + `/datasets` API
|
||||
- `/scoring` + `/scoring_functions` API
|
||||
- `/eval` + `/eval_tasks` API
|
||||
- `/eval` + `/benchmarks` API
|
||||
|
||||
This guide goes over the sets of APIs and developer experience flow of using Llama Stack to run evaluations for different use cases. Checkout our Colab notebook on working examples with evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).
|
||||
|
||||
|
|
@ -21,7 +21,7 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
|
|||
- **Scoring**: evaluate outputs of the system.
|
||||
- Associated with `ScoringFunction` resource. We provide a suite of out-of-the box scoring functions and also the ability for you to add custom evaluators. These scoring functions are the core part of defining an evaluation task to output evaluation metrics.
|
||||
- **Eval**: generate outputs (via Inference or Agents) and perform scoring.
|
||||
- Associated with `EvalTask` resource.
|
||||
- Associated with `Benchmark` resource.
|
||||
|
||||
|
||||
Use the following decision tree to decide how to use LlamaStack Evaluation flow.
|
||||
|
|
@ -77,14 +77,14 @@ system_message = {
|
|||
"content": SYSTEM_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
client.eval_tasks.register(
|
||||
eval_task_id="meta-reference::mmmu",
|
||||
client.benchmarks.register(
|
||||
benchmark_id="meta-reference::mmmu",
|
||||
dataset_id=f"mmmu-{subset}-{split}",
|
||||
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
|
||||
)
|
||||
|
||||
response = client.eval.evaluate_rows(
|
||||
task_id="meta-reference::mmmu",
|
||||
benchmark_id="meta-reference::mmmu",
|
||||
input_rows=eval_rows,
|
||||
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
|
||||
task_config={
|
||||
|
|
@ -135,14 +135,14 @@ eval_rows = client.datasetio.get_rows_paginated(
|
|||
```
|
||||
|
||||
```python
|
||||
client.eval_tasks.register(
|
||||
eval_task_id="meta-reference::simpleqa",
|
||||
client.benchmarks.register(
|
||||
benchmark_id="meta-reference::simpleqa",
|
||||
dataset_id=simpleqa_dataset_id,
|
||||
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||
)
|
||||
|
||||
response = client.eval.evaluate_rows(
|
||||
task_id="meta-reference::simpleqa",
|
||||
benchmark_id="meta-reference::simpleqa",
|
||||
input_rows=eval_rows.rows,
|
||||
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||
task_config={
|
||||
|
|
@ -192,7 +192,7 @@ agent_config = {
|
|||
}
|
||||
|
||||
response = client.eval.evaluate_rows(
|
||||
task_id="meta-reference::simpleqa",
|
||||
benchmark_id="meta-reference::simpleqa",
|
||||
input_rows=eval_rows.rows,
|
||||
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||
task_config={
|
||||
|
|
@@ -281,7 +281,7 @@ The following examples give the quick steps to start running evaluations using t

#### Benchmark Evaluation CLI
Usage: There are 2 inputs necessary for running a benchmark eval
- `eval-task-id`: the identifier associated with the eval task. Each `EvalTask` is parametrized by
- `eval-task-id`: the identifier associated with the eval task. Each `Benchmark` is parametrized by
  - `dataset_id`: the identifier associated with the dataset.
  - `List[scoring_function_id]`: list of scoring function identifiers.
- `eval-task-config`: specifies the configuration of the model / agent to evaluate on.
@@ -289,7 +289,7 @@ Usage: There are 2 inputs necessary for running a benchmark eval

```
llama-stack-client eval run_benchmark <eval-task-id> \
--eval-task-config ~/eval_task_config.json \
--eval-task-config ~/benchmark_config.json \
--visualize
```

@@ -309,15 +309,15 @@ llama-stack-client eval run_scoring <scoring_fn_id_1> <scoring_fn_id_2> ... <sco
--output-dir ./
```

#### Defining EvalTaskConfig
The `EvalTaskConfig` are user specified config to define:
#### Defining BenchmarkConfig
The `BenchmarkConfig` are user specified config to define:
1. `EvalCandidate` to run generation on:
   - `ModelCandidate`: The model will be used for generation through LlamaStack /inference API.
   - `AgentCandidate`: The agentic system specified by AgentConfig will be used for generation through LlamaStack /agents API.
2. Optionally scoring function params to allow customization of scoring function behaviour. This is useful to parameterize generic scoring functions such as LLMAsJudge with custom `judge_model` / `judge_prompt`.


**Example Benchmark EvalTaskConfig**
**Example Benchmark BenchmarkConfig**
```json
{
    "type": "benchmark",
@@ -335,7 +335,7 @@ The `EvalTaskConfig` are user specified config to define:
}
```

**Example Application EvalTaskConfig**
**Example Application BenchmarkConfig**
```json
{
    "type": "app",
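To make the two candidate types concrete, here is a sketch of the corresponding configs written as plain Python dicts, mirroring the (truncated) JSON examples above. The model ids, the `ModelCandidate` field names, and the agent config fields are illustrative assumptions, not definitions from this commit.

```python
# A small AgentConfig-style dict, with fields borrowed from the notebook example
# elsewhere in this diff; values are illustrative.
agent_config = {
    "model": "meta-llama/Llama-3.2-3B-Instruct",
    "instructions": "You are a helpful assistant",
    "enable_session_persistence": False,
}

# "benchmark" type config: generation via a ModelCandidate through the /inference API.
benchmark_config = {
    "type": "benchmark",
    "eval_candidate": {
        "type": "model",
        "model": "meta-llama/Llama-3.1-405B-Instruct",  # assumed model id
        "sampling_params": {"temperature": 0.0},  # assumed field names
    },
}

# "app" type config: generation via an AgentCandidate through the /agents API.
# Optional scoring-function params (e.g. a custom judge model/prompt for
# LLM-as-judge scorers) would also go in this dict; field names vary by scorer.
app_config = {
    "type": "app",
    "eval_candidate": {"type": "agent", "config": agent_config},
}

# Either dict is what gets passed as `task_config` / `--eval-task-config` when running the eval.
```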
|||
|
|
@ -161,14 +161,14 @@ Options:
|
|||
|
||||
## Eval Task Management
|
||||
|
||||
### `llama-stack-client eval_tasks list`
|
||||
### `llama-stack-client benchmarks list`
|
||||
```bash
|
||||
$ llama-stack-client eval_tasks list
|
||||
$ llama-stack-client benchmarks list
|
||||
```
|
||||
|
||||
### `llama-stack-client eval_tasks register`
|
||||
### `llama-stack-client benchmarks register`
|
||||
```bash
|
||||
$ llama-stack-client eval_tasks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
|
||||
$ llama-stack-client benchmarks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
|
||||
```
|
||||
|
||||
Options:
|
||||
|
|
@ -191,7 +191,7 @@ Options:
|
|||
- `--num-examples`: Optional. Number of examples to evaluate (useful for debugging)
|
||||
- `--visualize`: Optional flag. If set, visualizes evaluation results after completion
|
||||
|
||||
Example eval_task_config.json:
|
||||
Example benchmark_config.json:
|
||||
```json
|
||||
{
|
||||
"type": "benchmark",
|
||||
|
|
|
|||
|
|
@ -181,8 +181,8 @@ from llama_stack_client.types import EvaluateResponse, Job
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/eval/tasks/{task_id}/evaluations">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">evaluate_rows</a>(task_id, \*\*<a href="src/llama_stack_client/types/eval_evaluate_rows_params.py">params</a>) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
|
||||
- <code title="post /v1/eval/tasks/{task_id}/jobs">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">run_eval</a>(task_id, \*\*<a href="src/llama_stack_client/types/eval_run_eval_params.py">params</a>) -> <a href="./src/llama_stack_client/types/job.py">Job</a></code>
|
||||
- <code title="post /v1/eval/tasks/{benchmark_id}/evaluations">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">evaluate_rows</a>(benchmark_id, \*\*<a href="src/llama_stack_client/types/eval_evaluate_rows_params.py">params</a>) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
|
||||
- <code title="post /v1/eval/tasks/{benchmark_id}/jobs">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">run_eval</a>(benchmark_id, \*\*<a href="src/llama_stack_client/types/eval_run_eval_params.py">params</a>) -> <a href="./src/llama_stack_client/types/job.py">Job</a></code>
|
||||
|
||||
### Jobs
|
||||
|
||||
|
|
@ -194,9 +194,9 @@ from llama_stack_client.types.eval import JobStatusResponse
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/eval/tasks/{task_id}/jobs/{job_id}/result">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">retrieve</a>(job_id, \*, task_id) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
|
||||
- <code title="delete /v1/eval/tasks/{task_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">cancel</a>(job_id, \*, task_id) -> None</code>
|
||||
- <code title="get /v1/eval/tasks/{task_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">status</a>(job_id, \*, task_id) -> Optional[JobStatusResponse]</code>
|
||||
- <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}/result">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">retrieve</a>(job_id, \*, benchmark_id) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
|
||||
- <code title="delete /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">cancel</a>(job_id, \*, benchmark_id) -> None</code>
|
||||
- <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">status</a>(job_id, \*, benchmark_id) -> Optional[JobStatusResponse]</code>
|
||||
|
||||
## Inspect
|
||||
|
||||
|
|
@ -443,20 +443,20 @@ Methods:
|
|||
- <code title="get /v1/scoring-functions">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">list</a>() -> <a href="./src/llama_stack_client/types/scoring_function_list_response.py">ScoringFunctionListResponse</a></code>
|
||||
- <code title="post /v1/scoring-functions">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">register</a>(\*\*<a href="src/llama_stack_client/types/scoring_function_register_params.py">params</a>) -> None</code>
|
||||
|
||||
## EvalTasks
|
||||
## Benchmarks
|
||||
|
||||
Types:
|
||||
|
||||
```python
|
||||
from llama_stack_client.types import (
|
||||
EvalTask,
|
||||
ListEvalTasksResponse,
|
||||
EvalTaskListResponse,
|
||||
Benchmark,
|
||||
ListBenchmarksResponse,
|
||||
BenchmarkListResponse,
|
||||
)
|
||||
```
|
||||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/eval-tasks/{eval_task_id}">client.eval_tasks.<a href="./src/llama_stack_client/resources/eval_tasks.py">retrieve</a>(eval_task_id) -> <a href="./src/llama_stack_client/types/eval_task.py">Optional[EvalTask]</a></code>
|
||||
- <code title="get /v1/eval-tasks">client.eval_tasks.<a href="./src/llama_stack_client/resources/eval_tasks.py">list</a>() -> <a href="./src/llama_stack_client/types/eval_task_list_response.py">EvalTaskListResponse</a></code>
|
||||
- <code title="post /v1/eval-tasks">client.eval_tasks.<a href="./src/llama_stack_client/resources/eval_tasks.py">register</a>(\*\*<a href="src/llama_stack_client/types/eval_task_register_params.py">params</a>) -> None</code>
|
||||
- <code title="get /v1/eval-tasks/{benchmark_id}">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">retrieve</a>(benchmark_id) -> <a href="./src/llama_stack_client/types/benchmark.py">Optional[Benchmark]</a></code>
|
||||
- <code title="get /v1/eval-tasks">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">list</a>() -> <a href="./src/llama_stack_client/types/benchmark_list_response.py">BenchmarkListResponse</a></code>
|
||||
- <code title="post /v1/eval-tasks">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">register</a>(\*\*<a href="src/llama_stack_client/types/benchmark_register_params.py">params</a>) -> None</code>
|
||||
|
|
|
|||
|
|
@ -15,29 +15,29 @@ from typing import (
|
|||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
|
||||
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolCall,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolPromptFormat,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
|
|
@ -154,7 +154,7 @@ class AgentConfigCommon(BaseModel):
|
|||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
|
||||
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead")
|
||||
tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_config: Optional[ToolConfig] = Field(default=None)
|
||||
|
||||
|
|
@ -166,18 +166,20 @@ class AgentConfigCommon(BaseModel):
|
|||
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
|
||||
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
|
||||
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
|
||||
if self.tool_config is None:
|
||||
self.tool_config = ToolConfig(
|
||||
tool_choice=self.tool_choice,
|
||||
tool_prompt_format=self.tool_prompt_format,
|
||||
)
|
||||
else:
|
||||
params = {}
|
||||
if self.tool_choice:
|
||||
params["tool_choice"] = self.tool_choice
|
||||
if self.tool_prompt_format:
|
||||
params["tool_prompt_format"] = self.tool_prompt_format
|
||||
self.tool_config = ToolConfig(**params)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentConfig(AgentConfigCommon):
|
||||
model: str
|
||||
instructions: str
|
||||
enable_session_persistence: bool
|
||||
enable_session_persistence: Optional[bool] = False
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
|
||||
|
|
@ -333,7 +335,10 @@ class Agents(Protocol):
|
|||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||
method="GET",
|
||||
)
|
||||
async def get_agents_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
|
|
|||
|
|
@ -1,207 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||
from llama_models.llama3.api.tool_utils import ToolUtils
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
|
||||
from llama_stack.apis.common.content_types import ToolCallParseStatus
|
||||
from llama_stack.apis.inference import ToolResponseMessage
|
||||
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
|
||||
class LogEvent:
|
||||
def __init__(
|
||||
self,
|
||||
role: Optional[str] = None,
|
||||
content: str = "",
|
||||
end: str = "\n",
|
||||
color="white",
|
||||
):
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.color = color
|
||||
self.end = "\n" if end is None else end
|
||||
|
||||
def __str__(self):
|
||||
if self.role is not None:
|
||||
return f"{self.role}> {self.content}"
|
||||
else:
|
||||
return f"{self.content}"
|
||||
|
||||
def print(self, flush=True):
|
||||
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
|
||||
|
||||
|
||||
EventType = AgentTurnResponseEventType
|
||||
|
||||
|
||||
class EventLogger:
|
||||
async def log(
|
||||
self,
|
||||
event_generator,
|
||||
stream=True,
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
):
|
||||
previous_event_type = None
|
||||
previous_step_type = None
|
||||
|
||||
async for chunk in event_generator:
|
||||
if not hasattr(chunk, "event"):
|
||||
# Need to check for custom tool first
|
||||
# since it does not produce event but instead
|
||||
# a Message
|
||||
if isinstance(chunk, ToolResponseMessage):
|
||||
yield (
|
||||
chunk,
|
||||
LogEvent(role="CustomTool", content=chunk.content, color="grey"),
|
||||
)
|
||||
continue
|
||||
|
||||
event = chunk.event
|
||||
event_type = event.payload.event_type
|
||||
if event_type in {
|
||||
EventType.turn_start.value,
|
||||
EventType.turn_complete.value,
|
||||
}:
|
||||
# Currently not logging any turn realted info
|
||||
yield event, None
|
||||
continue
|
||||
|
||||
step_type = event.payload.step_type
|
||||
# handle safety
|
||||
if step_type == StepType.shield_call and event_type == EventType.step_complete.value:
|
||||
violation = event.payload.step_details.violation
|
||||
if not violation:
|
||||
yield (
|
||||
event,
|
||||
LogEvent(role=step_type, content="No Violation", color="magenta"),
|
||||
)
|
||||
else:
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type,
|
||||
content=f"{violation.metadata} {violation.user_message}",
|
||||
color="red",
|
||||
),
|
||||
)
|
||||
|
||||
# handle inference
|
||||
if step_type == StepType.inference:
|
||||
if stream:
|
||||
if event_type == EventType.step_start.value:
|
||||
# TODO: Currently this event is never received
|
||||
yield (
|
||||
event,
|
||||
LogEvent(role=step_type, content="", end="", color="yellow"),
|
||||
)
|
||||
elif event_type == EventType.step_progress.value:
|
||||
# HACK: if previous was not step/event was not inference's step_progress
|
||||
# this is the first time we are getting model inference response
|
||||
# aka equivalent to step_start for inference. Hence,
|
||||
# start with "Model>".
|
||||
if (
|
||||
previous_event_type != EventType.step_progress.value
|
||||
and previous_step_type != StepType.inference
|
||||
):
|
||||
yield (
|
||||
event,
|
||||
LogEvent(role=step_type, content="", end="", color="yellow"),
|
||||
)
|
||||
|
||||
delta = event.payload.delta
|
||||
if delta.type == "tool_call":
|
||||
if delta.parse_status == ToolCallParseStatus.succeeded:
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=None,
|
||||
content=delta.tool_call,
|
||||
end="",
|
||||
color="cyan",
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=None,
|
||||
content=delta.text,
|
||||
end="",
|
||||
color="yellow",
|
||||
),
|
||||
)
|
||||
else:
|
||||
# step_complete
|
||||
yield event, LogEvent(role=None, content="")
|
||||
|
||||
else:
|
||||
# Not streaming
|
||||
if event_type == EventType.step_complete.value:
|
||||
response = event.payload.step_details.model_response
|
||||
if response.tool_calls:
|
||||
content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format)
|
||||
else:
|
||||
content = response.content
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type,
|
||||
content=content,
|
||||
color="yellow",
|
||||
),
|
||||
)
|
||||
|
||||
# handle tool_execution
|
||||
if (
|
||||
step_type == StepType.tool_execution
|
||||
and
|
||||
# Only print tool calls and responses at the step_complete event
|
||||
event_type == EventType.step_complete.value
|
||||
):
|
||||
details = event.payload.step_details
|
||||
for t in details.tool_calls:
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type,
|
||||
content=f"Tool:{t.tool_name} Args:{t.arguments}",
|
||||
color="green",
|
||||
),
|
||||
)
|
||||
for r in details.tool_responses:
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type,
|
||||
content=f"Tool:{r.tool_name} Response:{r.content}",
|
||||
color="green",
|
||||
),
|
||||
)
|
||||
|
||||
if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value:
|
||||
details = event.payload.step_details
|
||||
inserted_context = interleaved_content_as_str(details.inserted_context)
|
||||
content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}"
|
||||
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type,
|
||||
content=content,
|
||||
color="cyan",
|
||||
),
|
||||
)
|
||||
|
||||
previous_event_type = event_type
|
||||
previous_step_type = step_type
|
||||
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
from typing import List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
|
|
@ -21,6 +20,7 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .eval_tasks import * # noqa: F401 F403
from .benchmarks import * # noqa: F401 F403

llama_stack/apis/benchmarks/benchmarks.py (86 lines, new file)
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

from pydantic import BaseModel, Field

from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod


class CommonBenchmarkFields(BaseModel):
    dataset_id: str
    scoring_functions: List[str]
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Metadata for this evaluation task",
    )


@json_schema_type
class Benchmark(CommonBenchmarkFields, Resource):
    type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value

    @property
    def benchmark_id(self) -> str:
        return self.identifier

    @property
    def provider_benchmark_id(self) -> str:
        return self.provider_resource_id


class BenchmarkInput(CommonBenchmarkFields, BaseModel):
    benchmark_id: str
    provider_id: Optional[str] = None
    provider_benchmark_id: Optional[str] = None


class ListBenchmarksResponse(BaseModel):
    data: List[Benchmark]


@runtime_checkable
class Benchmarks(Protocol):
    @webmethod(route="/eval/benchmarks", method="GET")
    async def list_benchmarks(self) -> ListBenchmarksResponse: ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
    async def get_benchmark(
        self,
        benchmark_id: str,
    ) -> Optional[Benchmark]: ...

    @webmethod(route="/eval/benchmarks", method="POST")
    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: List[str],
        provider_benchmark_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None: ...

    @webmethod(route="/eval-tasks", method="GET")
    async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse: ...

    @webmethod(route="/eval-tasks/{eval_task_id}", method="GET")
    async def DEPRECATED_get_eval_task(
        self,
        eval_task_id: str,
    ) -> Optional[Benchmark]: ...

    @webmethod(route="/eval-tasks", method="POST")
    async def DEPRECATED_register_eval_task(
        self,
        eval_task_id: str,
        dataset_id: str,
        scoring_functions: List[str],
        provider_benchmark_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None: ...
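To illustrate how the new resource and protocol fit together, here is a minimal, hypothetical in-memory registry that structurally satisfies the `Benchmarks` protocol above. It is illustration only, not part of this commit; the constructor fields inherited from `Resource` (`identifier`, `provider_resource_id`, `provider_id`) are assumed from the base class, which is not shown in this hunk.

```python
from typing import Any, Dict, List, Optional

from llama_stack.apis.benchmarks import Benchmark, ListBenchmarksResponse


class InMemoryBenchmarks:
    """Toy registry that satisfies the Benchmarks protocol structurally."""

    def __init__(self) -> None:
        self._store: Dict[str, Benchmark] = {}

    async def list_benchmarks(self) -> ListBenchmarksResponse:
        return ListBenchmarksResponse(data=list(self._store.values()))

    async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]:
        return self._store.get(benchmark_id)

    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: List[str],
        provider_benchmark_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Resource fields (identifier, provider_resource_id, provider_id) are
        # assumed to come from the Resource base class, not shown in this diff.
        self._store[benchmark_id] = Benchmark(
            identifier=benchmark_id,
            provider_resource_id=provider_benchmark_id or benchmark_id,
            provider_id=provider_id or "inline",
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata or {},
        )
```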
|
@ -7,11 +7,11 @@
|
|||
from enum import Enum
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolCall
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from llama_stack.models.llama.datatypes import ToolCall
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class URL(BaseModel):
|
||||
|
|
|
|||
|
|
@ -7,11 +7,10 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -5,9 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
from enum import Enum
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Job(BaseModel):
|
||||
|
|
|
|||
|
|
@ -7,9 +7,10 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingMetric(BaseModel):
|
||||
|
|
|
|||
|
|
@@ -6,10 +6,11 @@

from typing import Literal, Union

from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from llama_stack.schema_utils import json_schema_type, register_schema


@json_schema_type
class StringType(BaseModel):

@@ -6,10 +6,10 @@

from typing import Any, Dict, List, Optional, Protocol, runtime_checkable

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel

from llama_stack.apis.datasets import Dataset
from llama_stack.schema_utils import json_schema_type, webmethod


@json_schema_type

@@ -6,12 +6,12 @@

from typing import Any, Dict, List, Literal, Optional, Protocol

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod


class CommonDatasetFields(BaseModel):
@@ -58,7 +58,7 @@ class Datasets(Protocol):
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None: ...

    @webmethod(route="/datasets/{dataset_id}", method="GET")
    @webmethod(route="/datasets/{dataset_id:path}", method="GET")
    async def get_dataset(
        self,
        dataset_id: str,

@@ -67,7 +67,7 @@ class Datasets(Protocol):
    @webmethod(route="/datasets", method="GET")
    async def list_datasets(self) -> ListDatasetsResponse: ...

    @webmethod(route="/datasets/{dataset_id}", method="DELETE")
    @webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
    async def unregister_dataset(
        self,
        dataset_id: str,

@@ -6,7 +6,7 @@

from enum import Enum

from llama_models.schema_utils import json_schema_type
from llama_stack.schema_utils import json_schema_type


@json_schema_type

@@ -28,7 +28,7 @@ class Api(Enum):
    vector_dbs = "vector_dbs"
    datasets = "datasets"
    scoring_functions = "scoring_functions"
    eval_tasks = "eval_tasks"
    benchmarks = "benchmarks"
    tool_groups = "tool_groups"

    # built-in API

@@ -6,7 +6,6 @@

from typing import Any, Dict, List, Literal, Optional, Protocol, Union

from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated

@@ -15,6 +14,7 @@ from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


@json_schema_type

@@ -38,19 +38,9 @@ EvalCandidate = register_schema(


@json_schema_type
class BenchmarkEvalTaskConfig(BaseModel):
class BenchmarkConfig(BaseModel):
    type: Literal["benchmark"] = "benchmark"
    eval_candidate: EvalCandidate
    num_examples: Optional[int] = Field(
        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
        default=None,
    )


@json_schema_type
class AppEvalTaskConfig(BaseModel):
    type: Literal["app"] = "app"
    eval_candidate: EvalCandidate
    scoring_params: Dict[str, ScoringFnParams] = Field(
        description="Map between scoring function id and parameters for each scoring function you want to run",
        default_factory=dict,

@@ -62,12 +52,6 @@ class AppEvalTaskConfig(BaseModel):
    # we could optinally add any specific dataset config here


EvalTaskConfig = register_schema(
    Annotated[Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")],
    name="EvalTaskConfig",
)


@json_schema_type
class EvaluateResponse(BaseModel):
    generations: List[Dict[str, Any]]

@@ -76,27 +60,52 @@ class EvaluateResponse(BaseModel):


class Eval(Protocol):
    @webmethod(route="/eval/tasks/{task_id}/jobs", method="POST")
    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
    async def run_eval(
        self,
        benchmark_id: str,
        task_config: BenchmarkConfig,
    ) -> Job: ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
    async def evaluate_rows(
        self,
        benchmark_id: str,
        input_rows: List[Dict[str, Any]],
        scoring_functions: List[str],
        task_config: BenchmarkConfig,
    ) -> EvaluateResponse: ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
    async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
    async def job_cancel(self, benchmark_id: str, job_id: str) -> None: ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ...

    @webmethod(route="/eval/tasks/{task_id}/jobs", method="POST")
    async def DEPRECATED_run_eval(
        self,
        task_id: str,
        task_config: EvalTaskConfig,
        task_config: BenchmarkConfig,
    ) -> Job: ...

    @webmethod(route="/eval/tasks/{task_id}/evaluations", method="POST")
    async def evaluate_rows(
    async def DEPRECATED_evaluate_rows(
        self,
        task_id: str,
        input_rows: List[Dict[str, Any]],
        scoring_functions: List[str],
        task_config: EvalTaskConfig,
        task_config: BenchmarkConfig,
    ) -> EvaluateResponse: ...

    @webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="GET")
    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
    async def DEPRECATED_job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...

    @webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="DELETE")
    async def job_cancel(self, task_id: str, job_id: str) -> None: ...
    async def DEPRECATED_job_cancel(self, task_id: str, job_id: str) -> None: ...

    @webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}/result", method="GET")
    async def job_result(self, job_id: str, task_id: str) -> EvaluateResponse: ...
    async def DEPRECATED_job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...

@@ -1,66 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field

from llama_stack.apis.resource import Resource, ResourceType


class CommonEvalTaskFields(BaseModel):
    dataset_id: str
    scoring_functions: List[str]
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Metadata for this evaluation task",
    )


@json_schema_type
class EvalTask(CommonEvalTaskFields, Resource):
    type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value

    @property
    def eval_task_id(self) -> str:
        return self.identifier

    @property
    def provider_eval_task_id(self) -> str:
        return self.provider_resource_id


class EvalTaskInput(CommonEvalTaskFields, BaseModel):
    eval_task_id: str
    provider_id: Optional[str] = None
    provider_eval_task_id: Optional[str] = None


class ListEvalTasksResponse(BaseModel):
    data: List[EvalTask]


@runtime_checkable
class EvalTasks(Protocol):
    @webmethod(route="/eval-tasks", method="GET")
    async def list_eval_tasks(self) -> ListEvalTasksResponse: ...

    @webmethod(route="/eval-tasks/{eval_task_id}", method="GET")
    async def get_eval_task(
        self,
        eval_task_id: str,
    ) -> Optional[EvalTask]: ...

    @webmethod(route="/eval-tasks", method="POST")
    async def register_eval_task(
        self,
        eval_task_id: str,
        dataset_id: str,
        scoring_functions: List[str],
        provider_eval_task_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None: ...

@ -13,11 +13,17 @@ from typing import (
|
|||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
|
|
@ -25,13 +31,8 @@ from llama_models.llama3.api.datatypes import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
class LogProbConfig(BaseModel):
|
||||
|
|
@ -357,7 +358,7 @@ class ChatCompletionRequest(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponseStreamChunk(BaseModel):
|
||||
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
|
||||
"""A chunk of a streamed chat completion response.
|
||||
|
||||
:param event: The event containing the new content
|
||||
|
|
@ -367,7 +368,7 @@ class ChatCompletionResponseStreamChunk(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
|
||||
"""Response from a chat completion request.
|
||||
|
||||
:param completion_message: The complete response message
|
||||
|
|
|
|||
|
|
@ -6,9 +6,10 @@
|
|||
|
||||
from typing import List, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
class CommonModelFields(BaseModel):
|
||||
|
|
@ -62,7 +62,7 @@ class Models(Protocol):
|
|||
@webmethod(route="/models", method="GET")
|
||||
async def list_models(self) -> ListModelsResponse: ...
|
||||
|
||||
@webmethod(route="/models/{model_id}", method="GET")
|
||||
@webmethod(route="/models/{model_id:path}", method="GET")
|
||||
async def get_model(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
@ -78,7 +78,7 @@ class Models(Protocol):
|
|||
model_type: Optional[ModelType] = None,
|
||||
) -> Model: ...
|
||||
|
||||
@webmethod(route="/models/{model_id}", method="DELETE")
|
||||
@webmethod(route="/models/{model_id:path}", method="DELETE")
|
||||
async def unregister_model(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
|
|||
|
|
@ -8,13 +8,13 @@ from datetime import datetime
|
|||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.common.training_types import Checkpoint
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class ResourceType(Enum):
|
|||
vector_db = "vector_db"
|
||||
dataset = "dataset"
|
||||
scoring_function = "scoring_function"
|
||||
eval_task = "eval_task"
|
||||
benchmark = "benchmark"
|
||||
tool = "tool"
|
||||
tool_group = "tool_group"
|
||||
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@
|
|||
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
# mapping of metric to value
|
||||
ScoringResultRow = Dict[str, Any]
|
||||
|
|
|
|||
|
|
@ -12,16 +12,16 @@ from typing import (
|
|||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
|
||||
|
|
@ -134,7 +134,7 @@ class ScoringFunctions(Protocol):
|
|||
@webmethod(route="/scoring-functions", method="GET")
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
|
||||
|
||||
@webmethod(route="/scoring-functions", method="POST")
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@
|
|||
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
class CommonShieldFields(BaseModel):
|
||||
|
|
@ -48,7 +48,7 @@ class Shields(Protocol):
|
|||
@webmethod(route="/shields", method="GET")
|
||||
async def list_shields(self) -> ListShieldsResponse: ...
|
||||
|
||||
@webmethod(route="/shields/{identifier}", method="GET")
|
||||
@webmethod(route="/shields/{identifier:path}", method="GET")
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
|
||||
|
||||
@webmethod(route="/shields", method="POST")
|
||||
|
|
|
|||
|
|
@ -5,14 +5,12 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
class FilteringFunction(Enum):
|
||||
|
|
|
|||
|
|
@ -13,14 +13,16 @@ from typing import (
|
|||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.models.llama.datatypes import Primitive
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
# Add this constant near the top of the file, after the imports
|
||||
DEFAULT_TTL_DAYS = 7
|
||||
|
||||
|
|
@ -76,7 +78,7 @@ class EventCommon(BaseModel):
|
|||
trace_id: str
|
||||
span_id: str
|
||||
timestamp: datetime
|
||||
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
attributes: Optional[Dict[str, Primitive]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@@ -94,6 +96,30 @@ class MetricEvent(EventCommon):
    unit: str


# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be inlcuded with the response
# To do this, we will need to augment all response types with a metrics field.
# We have hit a blocker from stainless SDK that prevents us from doing this.
# The blocker is that if we were to augment the response types that have a data field
# in them like so
# class ListModelsResponse(BaseModel):
#     metrics: Optional[List[MetricEvent]] = None
#     data: List[Models]
#     ...
# The client SDK will need to access the data by using a .data field, which is not
# ergonomic. Stainless SDK does support unwrapping the response type, but it
# requires that the response type to only have a single field.

# We will need a way in the client SDK to signal that the metrics are needed
# and if they are needed, the client SDK has to return the full response type
# without unwrapping it.


class MetricResponseMixin(BaseModel):
    metrics: Optional[List[MetricEvent]] = None


@json_schema_type
class StructuredLogType(Enum):
    SPAN_START = "span_start"

@ -199,13 +225,13 @@ class Telemetry(Protocol):
|
|||
order_by: Optional[List[str]] = None,
|
||||
) -> QueryTracesResponse: ...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id}", method="GET")
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
|
||||
async def get_trace(self, trace_id: str) -> Trace: ...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id}/spans/{span_id}", method="GET")
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span: ...
|
||||
|
||||
@webmethod(route="/telemetry/spans/{span_id}/tree", method="GET")
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="GET")
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
|
|
|
|||
|
|
@ -4,5 +4,5 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .tools import * # noqa: F401 F403
|
||||
from .rag_tool import * # noqa: F401 F403
|
||||
from .tools import * # noqa: F401 F403
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated, Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from .rag_tool import RAGToolRuntime
|
||||
|
||||
|
|
@ -101,7 +101,7 @@ class ToolGroups(Protocol):
|
|||
"""Register a tool group"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id}", method="GET")
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
|
||||
async def get_tool_group(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
|
|
@ -117,13 +117,13 @@ class ToolGroups(Protocol):
|
|||
"""List tools with optional tool group"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools/{tool_name}", method="GET")
|
||||
@webmethod(route="/tools/{tool_name:path}", method="GET")
|
||||
async def get_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
) -> Tool: ...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id}", method="DELETE")
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
|
||||
async def unregister_toolgroup(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@
|
|||
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -46,7 +46,7 @@ class VectorDBs(Protocol):
|
|||
@webmethod(route="/vector-dbs", method="GET")
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id}", method="GET")
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
|
||||
async def get_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
|
|
@ -62,5 +62,5 @@ class VectorDBs(Protocol):
|
|||
provider_vector_db_id: Optional[str] = None,
|
||||
) -> VectorDB: ...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id}", method="DELETE")
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@
|
|||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
class Chunk(BaseModel):
|
||||
|
|
|
|||
|
|
@ -16,11 +16,7 @@ from pathlib import Path
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_models.datatypes import Model
|
||||
from llama_models.sku_list import LlamaDownloadInfo
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
|
|
@ -33,6 +29,8 @@ from rich.progress import (
|
|||
from termcolor import cprint
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
|
||||
|
||||
|
||||
class Download(Subcommand):
|
||||
|
|
@ -85,8 +83,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
|
|||
type=str,
|
||||
required=False,
|
||||
default="*.safetensors",
|
||||
help="""
|
||||
For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
|
||||
help="""For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
|
||||
safetensors files to avoid downloading duplicate weights.
|
||||
""",
|
||||
)
|
||||
|
|
@ -456,7 +453,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
|||
# Handle comma-separated model IDs
|
||||
model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
|
||||
|
||||
from llama_models.sku_list import llama_meta_net_info, resolve_model
|
||||
from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model
|
||||
|
||||
from .model.safety_models import (
|
||||
prompt_guard_download_info,
|
||||
|
|
|
|||
|
|
@ -7,12 +7,11 @@
|
|||
import argparse
|
||||
import json
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
|
||||
class ModelDescribe(Subcommand):
|
||||
|
|
@ -35,6 +34,7 @@ class ModelDescribe(Subcommand):
|
|||
"--model-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="See `llama model list` or `llama model list --show-all` for the list of available models",
|
||||
)
|
||||
|
||||
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
|
||||
|
|
|
|||
|
|
@ -6,10 +6,9 @@
|
|||
|
||||
import argparse
|
||||
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
|
||||
|
||||
class ModelList(Subcommand):
|
||||
|
|
@ -38,7 +37,7 @@ class ModelList(Subcommand):
|
|||
|
||||
headers = [
|
||||
"Model Descriptor",
|
||||
"Hugging Face Repo",
|
||||
"Model ID",
|
||||
"Context Length",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from llama_stack.cli.model.download import ModelDownload
|
|||
from llama_stack.cli.model.list import ModelList
|
||||
from llama_stack.cli.model.prompt_format import ModelPromptFormat
|
||||
from llama_stack.cli.model.verify_download import ModelVerifyDownload
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
|
|
@ -26,6 +25,8 @@ class ModelParser(Subcommand):
|
|||
description="Work with llama models",
|
||||
)
|
||||
|
||||
self.parser.set_defaults(func=lambda args: self.parser.print_help())
|
||||
|
||||
subparsers = self.parser.add_subparsers(title="model_subcommands")
|
||||
|
||||
# Add sub-commands
|
||||
|
|
|
|||
|
|
@ -8,9 +8,8 @@ import argparse
|
|||
import textwrap
|
||||
from io import StringIO
|
||||
|
||||
from llama_models.datatypes import CoreModelId, is_multimodal, model_family, ModelFamily
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
|
||||
|
||||
|
||||
class ModelPromptFormat(Subcommand):
|
||||
|
|
|
|||
|
|
@ -6,12 +6,11 @@
|
|||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.datatypes import CheckpointQuantizationFormat
|
||||
from llama_models.llama3.api.datatypes import SamplingParams
|
||||
from llama_models.sku_list import LlamaDownloadInfo
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
|
||||
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
|
||||
|
||||
|
||||
class PromptGuardModel(BaseModel):
|
||||
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
|
||||
|
|
|
|||
|
|
@ -21,12 +21,11 @@ from prompt_toolkit.validation import Validator
|
|||
from termcolor import cprint
|
||||
|
||||
from llama_stack.cli.table import print_table
|
||||
|
||||
from llama_stack.distribution.build import (
|
||||
SERVER_DEPENDENCIES,
|
||||
ImageType,
|
||||
build_image,
|
||||
get_provider_dependencies,
|
||||
ImageType,
|
||||
SERVER_DEPENDENCIES,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BuildConfig,
|
||||
|
|
|
|||
|
|
@ -56,9 +56,8 @@ class StackBuild(Subcommand):
|
|||
"--image-name",
|
||||
type=str,
|
||||
help=textwrap.dedent(
|
||||
"""[for image-type=conda] Name of the conda environment to use for the build. If
|
||||
not specified, currently active Conda environment will be used. If no Conda
|
||||
environment is active, you must specify a name.
|
||||
"""[for image-type=conda|venv] Name of the conda or virtual environment to use for
|
||||
the build. If not specified, currently active Conda environment will be used if found.
|
||||
"""
|
||||
),
|
||||
default=None,
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class StackConfigure(Subcommand):
|
|||
self.parser = subparsers.add_parser(
|
||||
"configure",
|
||||
prog="llama stack configure",
|
||||
description="configure a llama stack distribution",
|
||||
description="Configure a llama stack distribution",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
|
|
|
|||
|
|
@ -21,15 +21,19 @@ class StackListProviders(Subcommand):
|
|||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_providers_list_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
@property
|
||||
def providable_apis(self):
|
||||
from llama_stack.distribution.distribution import providable_apis
|
||||
|
||||
api_values = [api.value for api in providable_apis()]
|
||||
return [api.value for api in providable_apis()]
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"api",
|
||||
type=str,
|
||||
choices=api_values,
|
||||
help="API to list providers for (one of: {})".format(api_values),
|
||||
choices=self.providable_apis,
|
||||
nargs="?",
|
||||
help="API to list providers for. List all if not specified.",
|
||||
)
|
||||
|
||||
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
|
|
@ -37,20 +41,29 @@ class StackListProviders(Subcommand):
|
|||
from llama_stack.distribution.distribution import Api, get_provider_registry
|
||||
|
||||
all_providers = get_provider_registry()
|
||||
providers_for_api = all_providers[Api(args.api)]
|
||||
if args.api:
|
||||
providers = [(args.api, all_providers[Api(args.api)])]
|
||||
else:
|
||||
providers = [(k.value, prov) for k, prov in all_providers.items()]
|
||||
|
||||
providers = [p for api, p in providers if api in self.providable_apis]
|
||||
|
||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||
headers = [
|
||||
"API Type",
|
||||
"Provider Type",
|
||||
"PIP Package Dependencies",
|
||||
]
|
||||
|
||||
rows = []
|
||||
for spec in providers_for_api.values():
|
||||
if spec.provider_type == "sample":
|
||||
|
||||
specs = [spec for p in providers for spec in p.values()]
|
||||
for spec in specs:
|
||||
if spec.is_sample:
|
||||
continue
|
||||
rows.append(
|
||||
[
|
||||
spec.api.value,
|
||||
spec.provider_type,
|
||||
",".join(spec.pip_packages),
|
||||
]
|
||||
|
|
@ -59,4 +72,5 @@ class StackListProviders(Subcommand):
|
|||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
sort_by=(0, 1),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class StackRun(Subcommand):
|
|||
self.parser = subparsers.add_parser(
|
||||
"run",
|
||||
prog="llama stack run",
|
||||
description="""start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
|
||||
description="""Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
|
|
@ -65,6 +65,13 @@ class StackRun(Subcommand):
|
|||
type=str,
|
||||
help="Path to TLS certificate file for HTTPS",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--image-type",
|
||||
type=str,
|
||||
help="Image Type used during the build. This can be either conda or container or venv.",
|
||||
choices=["conda", "container", "venv"],
|
||||
default="conda",
|
||||
)
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import importlib.resources
|
||||
|
|
@ -118,11 +125,11 @@ class StackRun(Subcommand):
|
|||
config_dict = yaml.safe_load(config_file.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
|
||||
if config.container_image:
|
||||
if args.image_type == ImageType.container.value or config.container_image:
|
||||
script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
|
||||
image_name = f"distribution-{template_name}" if template_name else config.container_image
|
||||
run_args = [script, image_name]
|
||||
else:
|
||||
elif args.image_type == ImageType.conda.value:
|
||||
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||
image_name = args.image_name or current_conda_env
|
||||
if not image_name:
|
||||
|
|
@ -167,6 +174,15 @@ class StackRun(Subcommand):
|
|||
script,
|
||||
image_name,
|
||||
]
|
||||
else:
|
||||
# else must be venv since that is the only valid option left.
|
||||
current_venv = os.environ.get("VIRTUAL_ENV")
|
||||
venv = args.image_name or current_venv
|
||||
script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh"
|
||||
run_args = [
|
||||
script,
|
||||
venv,
|
||||
]
|
||||
|
||||
run_args.extend([str(config_file), str(args.port)])
|
||||
if args.disable_ipv6:
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ class StackParser(Subcommand):
|
|||
version=f"{version('llama-stack')}",
|
||||
)
|
||||
|
||||
self.parser.set_defaults(func=lambda args: self.parser.print_help())
|
||||
|
||||
subparsers = self.parser.add_subparsers(title="stack_subcommands")
|
||||
|
||||
# Add sub-commands
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import re
|
||||
import textwrap
|
||||
from typing import Iterable
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
|
|
@ -25,13 +26,13 @@ def format_row(row, col_widths):
|
|||
lines.extend(textwrap.wrap(line, width, break_long_words=False, replace_whitespace=False))
|
||||
return lines
|
||||
|
||||
wrapped = [wrap(item, width) for item, width in zip(row, col_widths)]
|
||||
wrapped = [wrap(item, width) for item, width in zip(row, col_widths, strict=False)]
|
||||
max_lines = max(len(subrow) for subrow in wrapped)
|
||||
|
||||
lines = []
|
||||
for i in range(max_lines):
|
||||
line = []
|
||||
for cell_lines, width in zip(wrapped, col_widths):
|
||||
for cell_lines, width in zip(wrapped, col_widths, strict=False):
|
||||
value = cell_lines[i] if i < len(cell_lines) else ""
|
||||
line.append(value + " " * (width - len(strip_ansi_colors(value))))
|
||||
lines.append("| " + (" | ".join(line)) + " |")
|
||||
|
|
@ -39,20 +40,24 @@ def format_row(row, col_widths):
|
|||
return "\n".join(lines)
|
||||
|
||||
|
||||
def print_table(rows, headers=None, separate_rows: bool = False):
|
||||
def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()):
|
||||
def itemlen(item):
|
||||
return max([len(line) for line in strip_ansi_colors(item).split("\n")])
|
||||
|
||||
rows = [[x or "" for x in row] for row in rows]
|
||||
|
||||
if sort_by:
|
||||
rows.sort(key=lambda x: tuple(x[i] for i in sort_by))
|
||||
|
||||
if not headers:
|
||||
col_widths = [max(itemlen(item) for item in col) for col in zip(*rows)]
|
||||
col_widths = [max(itemlen(item) for item in col) for col in zip(*rows, strict=False)]
|
||||
else:
|
||||
col_widths = [
|
||||
max(
|
||||
itemlen(header),
|
||||
max(itemlen(item) for item in col),
|
||||
)
|
||||
for header, col in zip(headers, zip(*rows))
|
||||
for header, col in zip(headers, zip(*rows, strict=False), strict=False)
|
||||
]
|
||||
col_widths = [min(w, 80) for w in col_widths]
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from datetime import datetime
|
|||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from llama_stack.distribution.configure import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
parse_and_maybe_upgrade_config,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import importlib.resources
|
|||
import logging
|
||||
import sys
|
||||
from enum import Enum
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
|
|
@ -16,11 +15,8 @@ from pydantic import BaseModel
|
|||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.datatypes import BuildConfig, Provider
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||
|
||||
from llama_stack.distribution.utils.exec import run_command, run_with_pty
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
|
@ -130,7 +126,6 @@ def build_image(
|
|||
args = [
|
||||
script,
|
||||
str(image_name),
|
||||
str(build_file_path),
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -24,23 +24,21 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
|
|||
fi
|
||||
|
||||
if [ "$#" -lt 3 ]; then
|
||||
echo "Usage: $0 <distribution_type> <build_name> <build_file_path> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
special_pip_deps="$4"
|
||||
special_pip_deps="$3"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
build_name="$1"
|
||||
env_name="llamastack-$build_name"
|
||||
build_file_path="$2"
|
||||
pip_dependencies="$3"
|
||||
pip_dependencies="$2"
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# this is set if we actually create a new conda in which case we need to clean up
|
||||
|
|
@ -49,34 +47,63 @@ ENVNAME=""
|
|||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
# pre-run checks to make sure we can proceed with the installation
|
||||
pre_run_checks() {
|
||||
local env_name="$1"
|
||||
|
||||
if ! is_command_available uv; then
|
||||
echo "uv is not installed, trying to install it."
|
||||
if ! is_command_available pip; then
|
||||
echo "pip is not installed, cannot automatically install 'uv'."
|
||||
echo "Follow this link to install it:"
|
||||
echo "https://docs.astral.sh/uv/getting-started/installation/"
|
||||
exit 1
|
||||
else
|
||||
pip install uv
|
||||
fi
|
||||
fi
|
||||
|
||||
# checking if an environment with the same name already exists
|
||||
if [ -d "$env_name" ]; then
|
||||
echo "Environment '$env_name' already exists, re-using it."
|
||||
fi
|
||||
}
|
||||
|
||||
run() {
|
||||
local env_name="$1"
|
||||
local pip_dependencies="$2"
|
||||
local special_pip_deps="$3"
|
||||
|
||||
pip install uv
|
||||
echo "Using virtual environment $env_name"
|
||||
uv venv "$env_name"
|
||||
# shellcheck source=/dev/null
|
||||
source "$env_name/bin/activate"
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
uv pip install fastapi libcst
|
||||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION \
|
||||
llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
|
||||
$pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
else
|
||||
# Re-installing llama-stack in the new conda environment
|
||||
# Re-installing llama-stack in the new virtual environment
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
|
||||
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
|
||||
printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
|
||||
else
|
||||
uv pip install --no-cache-dir llama-stack
|
||||
|
|
@ -84,26 +111,31 @@ run() {
|
|||
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
|
||||
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_MODELS_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
|
||||
printf "Installing from LLAMA_MODELS_DIR: %s\n" "$LLAMA_MODELS_DIR"
|
||||
uv pip uninstall llama-models
|
||||
uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
|
||||
fi
|
||||
|
||||
# Install pip dependencies
|
||||
printf "Installing pip dependencies\n"
|
||||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
uv pip install $pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
pre_run_checks "$env_name"
|
||||
run "$env_name" "$pip_dependencies" "$special_pip_deps"
|
||||
|
|
|
|||
|
|
@ -5,18 +5,16 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import Enum
|
||||
from typing import Any, get_args, get_origin, Type, Union
|
||||
from typing import Any, Type, Union, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
|
||||
from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||
|
||||
_CLIENT_CLASSES = {}
|
||||
|
|
@ -188,33 +186,3 @@ def extract_async_iterator_type(type_hint):
|
|||
inner_args = get_args(arg)
|
||||
return inner_args[0]
|
||||
return None
|
||||
|
||||
|
||||
async def example(model: str = None):
|
||||
from llama_stack.apis.inference import Inference, UserMessage # noqa: F403
|
||||
from llama_stack.apis.inference.event_logger import EventLogger
|
||||
|
||||
client_class = create_api_client_class(Inference)
|
||||
client = client_class("http://localhost:5003")
|
||||
|
||||
if not model:
|
||||
model = "Llama3.2-3B-Instruct"
|
||||
|
||||
message = UserMessage(content="hello world, write me a 2 sentence poem about the moon")
|
||||
cprint(f"User>{message.content}", "green")
|
||||
|
||||
stream = True
|
||||
iterator = await client.chat_completion(
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
async for log in EventLogger().log(iterator):
|
||||
log.print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(example())
|
||||
|
|
|
|||
|
|
@ -38,3 +38,8 @@ setup_cleanup_handlers() {
|
|||
|
||||
conda deactivate
|
||||
}
|
||||
|
||||
# check if a command is present
|
||||
is_command_available() {
|
||||
command -v "$1" &>/dev/null
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,12 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
import logging
|
||||
import textwrap
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import (
|
||||
DistributionSpec,
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
DistributionSpec,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
|
|
@ -20,7 +19,6 @@ from llama_stack.distribution.distribution import (
|
|||
)
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ from typing import Annotated, Any, Dict, List, Optional, Union
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Dataset, DatasetInput
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model, ModelInput
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
|
@ -37,7 +37,7 @@ RoutableObject = Union[
|
|||
VectorDB,
|
||||
Dataset,
|
||||
ScoringFn,
|
||||
EvalTask,
|
||||
Benchmark,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
]
|
||||
|
|
@ -50,7 +50,7 @@ RoutableObjectWithProvider = Annotated[
|
|||
VectorDB,
|
||||
Dataset,
|
||||
ScoringFn,
|
||||
EvalTask,
|
||||
Benchmark,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
],
|
||||
|
|
@ -173,7 +173,7 @@ a default SQLite store will be used.""",
|
|||
vector_dbs: List[VectorDBInput] = Field(default_factory=list)
|
||||
datasets: List[DatasetInput] = Field(default_factory=list)
|
||||
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
|
||||
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
|
||||
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
|
||||
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
||||
|
||||
server: ServerConfig = Field(
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
|||
router_api=Api.scoring,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.eval_tasks,
|
||||
routing_table_api=Api.benchmarks,
|
||||
router_api=Api.eval,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
|
|
|
|||
|
|
@ -82,3 +82,6 @@ class DistributionInspectImpl(Inspect):
|
|||
|
||||
async def version(self) -> VersionInfo:
|
||||
return VersionInfo(version=version("llama-stack"))
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -13,10 +13,21 @@ import re
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, get_args, get_origin, Optional, TypeVar
|
||||
from typing import Any, Optional, TypeVar, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from llama_stack_client import (
|
||||
NOT_GIVEN,
|
||||
APIResponse,
|
||||
AsyncAPIResponse,
|
||||
AsyncLlamaStackClient,
|
||||
AsyncStream,
|
||||
LlamaStackClient,
|
||||
)
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from rich.console import Console
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
|
|
@ -35,17 +46,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
from llama_stack_client import (
|
||||
APIResponse,
|
||||
AsyncAPIResponse,
|
||||
AsyncLlamaStackClient,
|
||||
AsyncStream,
|
||||
LlamaStackClient,
|
||||
NOT_GIVEN,
|
||||
)
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from rich.console import Console
|
||||
from termcolor import cprint
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
|
@ -231,7 +231,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
def _convert_path_to_regex(path: str) -> str:
|
||||
# Convert {param} to named capture groups
|
||||
pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path)
|
||||
# handle {param:path} as well which allows for forward slashes in the param value
|
||||
pattern = re.sub(
|
||||
r"{(\w+)(?::path)?}",
|
||||
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
|
||||
path,
|
||||
)
|
||||
|
||||
return f"^{pattern}$"
|
||||
|
||||
for api, api_endpoints in endpoints.items():
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ import logging
|
|||
from typing import Any, Dict, List, Set
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.eval_tasks import EvalTasks
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
|
|
@ -37,8 +37,8 @@ from llama_stack.distribution.store import DistributionRegistry
|
|||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.providers.datatypes import (
|
||||
Api,
|
||||
BenchmarksProtocolPrivate,
|
||||
DatasetsProtocolPrivate,
|
||||
EvalTasksProtocolPrivate,
|
||||
InlineProviderSpec,
|
||||
ModelsProtocolPrivate,
|
||||
ProviderSpec,
|
||||
|
|
@ -73,7 +73,7 @@ def api_protocol_map() -> Dict[Api, Any]:
|
|||
Api.scoring: Scoring,
|
||||
Api.scoring_functions: ScoringFunctions,
|
||||
Api.eval: Eval,
|
||||
Api.eval_tasks: EvalTasks,
|
||||
Api.benchmarks: Benchmarks,
|
||||
Api.post_training: PostTraining,
|
||||
Api.tool_groups: ToolGroups,
|
||||
Api.tool_runtime: ToolRuntime,
|
||||
|
|
@ -92,7 +92,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
|
|||
ScoringFunctions,
|
||||
Api.scoring_functions,
|
||||
),
|
||||
Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
|
||||
Api.eval: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,13 +7,12 @@
|
|||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutedProtocol
|
||||
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
from .routing_tables import (
|
||||
BenchmarksRoutingTable,
|
||||
DatasetsRoutingTable,
|
||||
EvalTasksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ScoringFunctionsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
|
|
@ -34,7 +33,7 @@ async def get_routing_table_impl(
|
|||
"shields": ShieldsRoutingTable,
|
||||
"datasets": DatasetsRoutingTable,
|
||||
"scoring_functions": ScoringFunctionsRoutingTable,
|
||||
"eval_tasks": EvalTasksRoutingTable,
|
||||
"benchmarks": BenchmarksRoutingTable,
|
||||
"tool_groups": ToolGroupsRoutingTable,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,12 +6,11 @@
|
|||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.eval import (
|
||||
AppEvalTaskConfig,
|
||||
BenchmarkConfig,
|
||||
Eval,
|
||||
EvalTaskConfig,
|
||||
EvaluateResponse,
|
||||
Job,
|
||||
JobStatus,
|
||||
|
|
@ -347,23 +346,23 @@ class EvalRouter(Eval):
|
|||
|
||||
async def run_eval(
|
||||
self,
|
||||
task_id: str,
|
||||
task_config: AppEvalTaskConfig,
|
||||
benchmark_id: str,
|
||||
task_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
return await self.routing_table.get_provider_impl(task_id).run_eval(
|
||||
task_id=task_id,
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
||||
benchmark_id=benchmark_id,
|
||||
task_config=task_config,
|
||||
)
|
||||
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
task_id: str,
|
||||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
task_config: EvalTaskConfig,
|
||||
task_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
|
||||
task_id=task_id,
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=input_rows,
|
||||
scoring_functions=scoring_functions,
|
||||
task_config=task_config,
|
||||
|
|
@@ -371,30 +370,72 @@ class EvalRouter(Eval):

    async def job_status(
        self,
        task_id: str,
        benchmark_id: str,
        job_id: str,
    ) -> Optional[JobStatus]:
        return await self.routing_table.get_provider_impl(task_id).job_status(task_id, job_id)
        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

    async def job_cancel(
        self,
        task_id: str,
        benchmark_id: str,
        job_id: str,
    ) -> None:
        await self.routing_table.get_provider_impl(task_id).job_cancel(
            task_id,
        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
            benchmark_id,
            job_id,
        )

    async def job_result(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> EvaluateResponse:
        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
            benchmark_id,
            job_id,
        )

    async def DEPRECATED_run_eval(
        self,
        task_id: str,
        task_config: BenchmarkConfig,
    ) -> Job:
        return await self.run_eval(benchmark_id=task_id, task_config=task_config)

    async def DEPRECATED_evaluate_rows(
        self,
        task_id: str,
        input_rows: List[Dict[str, Any]],
        scoring_functions: List[str],
        task_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        return await self.evaluate_rows(
            benchmark_id=task_id,
            input_rows=input_rows,
            scoring_functions=scoring_functions,
            task_config=task_config,
        )

    async def DEPRECATED_job_status(
        self,
        task_id: str,
        job_id: str,
    ) -> Optional[JobStatus]:
        return await self.job_status(benchmark_id=task_id, job_id=job_id)

    async def DEPRECATED_job_cancel(
        self,
        task_id: str,
        job_id: str,
    ) -> None:
        return await self.job_cancel(benchmark_id=task_id, job_id=job_id)

    async def DEPRECATED_job_result(
        self,
        task_id: str,
        job_id: str,
    ) -> EvaluateResponse:
        return await self.routing_table.get_provider_impl(task_id).job_result(
            task_id,
            job_id,
        )
        return await self.job_result(benchmark_id=task_id, job_id=job_id)


class ToolRuntimeRouter(ToolRuntime):
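For callers of the Eval router, the change swaps the `task_id` parameter for `benchmark_id`, with `DEPRECATED_*` wrappers forwarding old calls. A minimal migration sketch (not part of the diff; `router` is assumed to be an `EvalRouter`, `cfg` a `BenchmarkConfig`, and `Job` is assumed to expose a `job_id` field):

```python
# Hedged migration sketch; `router` (EvalRouter) and `cfg` (BenchmarkConfig)
# are assumed to exist, and Job is assumed to expose a `job_id` field.
async def migrate_example(router, cfg):
    # Before: job = await router.run_eval(task_id="mmlu", task_config=cfg)
    # (old callers now land on DEPRECATED_run_eval, which forwards here)
    job = await router.run_eval(benchmark_id="mmlu", task_config=cfg)
    return await router.job_status(benchmark_id="mmlu", job_id=job.job_id)
```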
@@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
from typing import Any, Dict, List, Optional

from pydantic import TypeAdapter

from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
@@ -38,6 +39,8 @@ from llama_stack.distribution.datatypes import (
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable

logger = logging.getLogger(__name__)


def get_impl_api(p: Any) -> Api:
    return p.__provider_spec__.api
@@ -60,7 +63,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
    elif api == Api.scoring:
        return await p.register_scoring_function(obj)
    elif api == Api.eval:
        return await p.register_eval_task(obj)
        return await p.register_benchmark(obj)
    elif api == Api.tool_runtime:
        return await p.register_tool(obj)
    else:
@@ -121,7 +124,7 @@ class CommonRoutingTableImpl(RoutingTable):
                scoring_functions = await p.list_scoring_functions()
                await add_objects(scoring_functions, pid, ScoringFn)
            elif api == Api.eval:
                p.eval_task_store = self
                p.benchmark_store = self
            elif api == Api.tool_runtime:
                p.tool_store = self
@@ -141,8 +144,8 @@ class CommonRoutingTableImpl(RoutingTable):
            return ("DatasetIO", "dataset")
        elif isinstance(self, ScoringFunctionsRoutingTable):
            return ("Scoring", "scoring_function")
        elif isinstance(self, EvalTasksRoutingTable):
            return ("Eval", "eval_task")
        elif isinstance(self, BenchmarksRoutingTable):
            return ("Eval", "benchmark")
        elif isinstance(self, ToolGroupsRoutingTable):
            return ("Tools", "tool")
        else:
@@ -428,20 +431,20 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
        await self.register_object(scoring_fn)


class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
    async def list_eval_tasks(self) -> ListEvalTasksResponse:
        return ListEvalTasksResponse(data=await self.get_all_with_type("eval_task"))
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
    async def list_benchmarks(self) -> ListBenchmarksResponse:
        return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))

    async def get_eval_task(self, eval_task_id: str) -> Optional[EvalTask]:
        return await self.get_object_by_identifier("eval_task", eval_task_id)
    async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]:
        return await self.get_object_by_identifier("benchmark", benchmark_id)

    async def register_eval_task(
    async def register_benchmark(
        self,
        eval_task_id: str,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: List[str],
        metadata: Optional[Dict[str, Any]] = None,
        provider_eval_task_id: Optional[str] = None,
        provider_benchmark_id: Optional[str] = None,
        provider_id: Optional[str] = None,
    ) -> None:
        if metadata is None:
@@ -453,17 +456,46 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
            raise ValueError(
                "No provider specified and multiple providers available. Please specify a provider_id."
            )
        if provider_eval_task_id is None:
            provider_eval_task_id = eval_task_id
        eval_task = EvalTask(
            identifier=eval_task_id,
        if provider_benchmark_id is None:
            provider_benchmark_id = benchmark_id
        benchmark = Benchmark(
            identifier=benchmark_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata,
            provider_id=provider_id,
            provider_resource_id=provider_eval_task_id,
            provider_resource_id=provider_benchmark_id,
        )
        await self.register_object(benchmark)

    async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse:
        logger.warning("DEPRECATED: Use /eval/benchmarks instead")
        return await self.list_benchmarks()

    async def DEPRECATED_get_eval_task(
        self,
        eval_task_id: str,
    ) -> Optional[Benchmark]:
        logger.warning("DEPRECATED: Use /eval/benchmarks instead")
        return await self.get_benchmark(eval_task_id)

    async def DEPRECATED_register_eval_task(
        self,
        eval_task_id: str,
        dataset_id: str,
        scoring_functions: List[str],
        provider_benchmark_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        logger.warning("DEPRECATED: Use /eval/benchmarks instead")
        return await self.register_benchmark(
            benchmark_id=eval_task_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata,
            provider_benchmark_id=provider_benchmark_id,
        )
        await self.register_object(eval_task)


class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
@@ -537,3 +569,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
        for tool in tools:
            await self.unregister_object(tool)
        await self.unregister_object(tool_group)

    async def shutdown(self) -> None:
        pass
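A minimal sketch (not part of the diff) of the renamed routing-table surface; the identifiers and scoring function are illustrative, and `table` is assumed to be a `BenchmarksRoutingTable`:

```python
# Illustrative identifiers; `table` is assumed to be a BenchmarksRoutingTable.
async def register_example(table):
    await table.register_benchmark(
        benchmark_id="meta-reference-mmlu",
        dataset_id="mmlu",
        scoring_functions=["basic::regex_parser_multiple_choice_answer"],
    )
    listed = await table.list_benchmarks()
    return [b.identifier for b in listed.data]
```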
@@ -10,11 +10,8 @@ from typing import Dict, List
from pydantic import BaseModel

from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api
@@ -9,6 +9,7 @@ import asyncio
import functools
import inspect
import json
import logging
import os
import signal
import sys
@@ -20,7 +21,8 @@ from pathlib import Path
from typing import Any, List, Union

import yaml
from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
@@ -52,6 +54,9 @@ from .endpoints import get_all_api_endpoints

REPO_ROOT = Path(__file__).parent.parent.parent.parent

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
logger = logging.getLogger(__name__)


def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    log = file if hasattr(file, "write") else sys.stderr
@@ -112,21 +117,69 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
    )


def handle_sigint(app, *args, **kwargs):
    print("SIGINT or CTRL-C detected. Exiting gracefully...")
def handle_signal(app, signum, _) -> None:
    """
    Handle incoming signals and initiate a graceful shutdown of the application.

    async def run_shutdown():
        for impl in app.__llama_stack_impls__.values():
            print(f"Shutting down {impl}")
            await impl.shutdown()
    This function is intended to be used as a signal handler for various signals
    (e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
    indicating the received signal and initiate a shutdown process.

    asyncio.run(run_shutdown())
    Args:
        app: The application instance containing implementations to be shut down.
        signum (int): The signal number received.
        frame: The current stack frame (not used in this function).

    loop = asyncio.get_event_loop()
    for task in asyncio.all_tasks(loop):
        task.cancel()
    The shutdown process involves:
    - Shutting down all implementations registered in the application.
    - Gathering all running asyncio tasks.
    - Cancelling all gathered tasks.
    - Waiting for all tasks to finish.
    - Stopping the event loop.

    loop.stop()
    Note:
        This function schedules the shutdown process as an asyncio task and does
        not block the current execution.
    """
    signame = signal.Signals(signum).name
    print(f"Received signal {signame} ({signum}). Exiting gracefully...")

    async def shutdown():
        try:
            # Gracefully shut down implementations
            for impl in app.__llama_stack_impls__.values():
                impl_name = impl.__class__.__name__
                logger.info("Shutting down %s", impl_name)
                try:
                    if hasattr(impl, "shutdown"):
                        await asyncio.wait_for(impl.shutdown(), timeout=5)
                    else:
                        logger.warning("No shutdown method for %s", impl_name)
                except asyncio.TimeoutError:
                    logger.exception("Shutdown timeout for %s", impl_name)
                except Exception as e:
                    logger.exception("Failed to shutdown %s: %s", impl_name, e)

            # Gather all running tasks
            loop = asyncio.get_running_loop()
            tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]

            # Cancel all tasks
            for task in tasks:
                task.cancel()

            # Wait for all tasks to finish
            try:
                await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
            except asyncio.TimeoutError:
                logger.exception("Timeout while waiting for tasks to finish")
        except asyncio.CancelledError:
            pass
        finally:
            loop.stop()

    loop = asyncio.get_running_loop()
    loop.create_task(shutdown())


@asynccontextmanager
@@ -386,7 +439,8 @@ def main():
        print("")
    app.exception_handler(RequestValidationError)(global_exception_handler)
    app.exception_handler(Exception)(global_exception_handler)
    signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
    signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
    signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))

    app.__llama_stack_impls__ = impls
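The new handler replaces the old blocking `asyncio.run(...)` approach: it schedules an async `shutdown()` task on the already-running loop, bounds each implementation's shutdown with a timeout, then cancels and drains outstanding tasks before stopping the loop. A standalone sketch of the same pattern (not from the diff):

```python
# A self-contained sketch of the graceful-shutdown pattern: a signal handler
# that cancels the running tasks instead of blocking inside the handler.
import asyncio
import signal


async def main() -> None:
    loop = asyncio.get_running_loop()

    def handle(signum, _frame) -> None:
        print(f"Received {signal.Signals(signum).name}, shutting down...")
        for task in asyncio.all_tasks(loop):
            task.cancel()

    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, handle)
    try:
        await asyncio.sleep(3600)  # stand-in for the server's serve loop
    except asyncio.CancelledError:
        pass


asyncio.run(main())
```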
@@ -15,10 +15,10 @@ from termcolor import colored

from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTasks
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
@@ -53,7 +53,7 @@ class LlamaStack(
    PostTraining,
    VectorIO,
    Eval,
    EvalTasks,
    Benchmarks,
    Scoring,
    ScoringFunctions,
    DatasetIO,
@@ -78,7 +78,7 @@ RESOURCES = [
        "register_scoring_function",
        "list_scoring_functions",
    ),
    ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
    ("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
    ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]
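Each `RESOURCES` entry is a `(config key, API, register method, list method)` tuple. A hedged sketch (not from the diff) of how such a table can drive a generic registration loop; `run_config` and `impls` are assumed to exist in the caller's scope, with pydantic models under each config key:

```python
# Hedged sketch: generic registration driven by a (name, api, register_method,
# list_method) table; run_config entries are assumed to be pydantic models.
async def register_resources(run_config, impls, resources) -> None:
    for name, api, register_method, list_method in resources:
        table = impls[api]
        # Register every object declared under this key in the run config.
        for obj in getattr(run_config, name, []):
            await getattr(table, register_method)(**obj.model_dump())
        # Then log what the distribution ended up with.
        response = await getattr(table, list_method)()
        print(f"{name}: {[x.identifier for x in response.data]}")
```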
llama_stack/distribution/start_venv.sh (new executable file, 71 lines)
@@ -0,0 +1,71 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -euo pipefail

RED='\033[0;31m'
NC='\033[0m' # No Color

error_handler() {
    echo "Error occurred in script at line: ${1}" >&2
    exit 1
}

trap 'error_handler ${LINENO}' ERR

if [ $# -lt 3 ]; then
    echo "Usage: $0 <venv_path> <yaml_config> <port> <script_args...>"
    exit 1
fi

venv_path="$1"
shift

yaml_config="$1"
shift

port="$1"
shift

# Initialize env_vars as an empty string
env_vars=""
other_args=""
# Process environment variables from --env arguments
while [[ $# -gt 0 ]]; do
    case "$1" in
    --env)
        if [[ -n "$2" ]]; then
            env_vars="$env_vars --env $2"
            shift 2
        else
            echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
            exit 1
        fi
        ;;
    *)
        other_args="$other_args $1"
        shift
        ;;
    esac
done

# Activate virtual environment
if [ ! -d "$venv_path" ]; then
    echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
    exit 1
fi

source "$venv_path/bin/activate"

set -x
python -m llama_stack.distribution.server.server \
    --yaml-config "$yaml_config" \
    --port "$port" \
    $env_vars \
    $other_args
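A hypothetical invocation of the new script; the venv path, config path, port, and environment variable below are placeholders:

```bash
./llama_stack/distribution/start_venv.sh \
    ~/.venvs/llama-stack \
    ~/.llama/distributions/my-distro/run.yaml \
    5001 \
    --env INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
```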
@@ -8,9 +8,9 @@ import os

import pytest
import pytest_asyncio

from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.store.registry import (
    CachedDiskDistributionRegistry,
    DiskDistributionRegistry,
@@ -26,7 +26,7 @@ $ llama-stack-client datasets register \
```

```bash
$ llama-stack-client eval_tasks register \
$ llama-stack-client benchmarks register \
  --eval-task-id meta-reference-mmlu \
  --provider-id meta-reference \
  --dataset-id mmlu \
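The same registration can presumably be done from Python. This sketch assumes the `llama_stack_client` SDK mirrors the server-side `register_benchmark` signature shown earlier, which this diff does not confirm; the URL and identifiers are placeholders:

```python
# Hedged sketch: assumes client.benchmarks.register mirrors the server-side
# register_benchmark signature; URL and identifiers are illustrative.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")  # placeholder URL
client.benchmarks.register(
    benchmark_id="meta-reference-mmlu",
    dataset_id="mmlu",
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)
```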
@@ -5,7 +5,6 @@
# the root directory of this source tree.

import os
from typing import Optional

from llama_stack_client import LlamaStackClient
@@ -8,12 +8,12 @@ import streamlit as st
from modules.api import llama_stack_api


def eval_tasks():
    # Eval Tasks Section
    st.header("Eval Tasks")
def benchmarks():
    # Benchmarks Section
    st.header("Benchmarks")

    eval_tasks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()}
    benchmarks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.benchmarks.list()}

    if len(eval_tasks_info) > 0:
        selected_eval_task = st.selectbox("Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect")
        st.json(eval_tasks_info[selected_eval_task], expanded=True)
    if len(benchmarks_info) > 0:
        selected_benchmark = st.selectbox("Select an eval task", list(benchmarks_info.keys()), key="benchmark_inspect")
        st.json(benchmarks_info[selected_benchmark], expanded=True)
|||
|
|
@ -4,13 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from page.distribution.benchmarks import benchmarks
|
||||
from page.distribution.datasets import datasets
|
||||
from page.distribution.eval_tasks import eval_tasks
|
||||
from page.distribution.models import models
|
||||
from page.distribution.scoring_functions import scoring_functions
|
||||
from page.distribution.shields import shields
|
||||
from page.distribution.vector_dbs import vector_dbs
|
||||
|
||||
from streamlit_option_menu import option_menu
|
||||
|
||||
|
||||
|
|
@@ -21,7 +20,7 @@ def resources_page():
        "Shields",
        "Scoring Functions",
        "Datasets",
        "Eval Tasks",
        "Benchmarks",
    ]
    icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
    selected_resource = option_menu(
@@ -35,8 +34,8 @@ def resources_page():
            },
        },
    )
    if selected_resource == "Eval Tasks":
        eval_tasks()
    if selected_resource == "Benchmarks":
        benchmarks()
    elif selected_resource == "Vector Databases":
        vector_dbs()
    elif selected_resource == "Datasets":
|
|||
|
|
@ -8,7 +8,6 @@ import json
|
|||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from modules.api import llama_stack_api
|
||||
from modules.utils import process_dataset
|
||||
|
||||
|
|
|
|||
|
|
@ -7,34 +7,32 @@
|
|||
import json
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
|
||||
def select_eval_task_1():
|
||||
# Select Eval Tasks
|
||||
def select_benchmark_1():
|
||||
# Select Benchmarks
|
||||
st.subheader("1. Choose An Eval Task")
|
||||
eval_tasks = llama_stack_api.client.eval_tasks.list()
|
||||
eval_tasks = {et.identifier: et for et in eval_tasks}
|
||||
eval_tasks_names = list(eval_tasks.keys())
|
||||
selected_eval_task = st.selectbox(
|
||||
benchmarks = llama_stack_api.client.benchmarks.list()
|
||||
benchmarks = {et.identifier: et for et in benchmarks}
|
||||
benchmarks_names = list(benchmarks.keys())
|
||||
selected_benchmark = st.selectbox(
|
||||
"Choose an eval task.",
|
||||
options=eval_tasks_names,
|
||||
options=benchmarks_names,
|
||||
help="Choose an eval task. Each eval task is parameterized by a dataset, and list of scoring functions.",
|
||||
)
|
||||
with st.expander("View Eval Task"):
|
||||
st.json(eval_tasks[selected_eval_task], expanded=True)
|
||||
st.json(benchmarks[selected_benchmark], expanded=True)
|
||||
|
||||
st.session_state["selected_eval_task"] = selected_eval_task
|
||||
st.session_state["eval_tasks"] = eval_tasks
|
||||
st.session_state["selected_benchmark"] = selected_benchmark
|
||||
st.session_state["benchmarks"] = benchmarks
|
||||
if st.button("Confirm", key="confirm_1"):
|
||||
st.session_state["selected_eval_task_1_next"] = True
|
||||
st.session_state["selected_benchmark_1_next"] = True
|
||||
|
||||
|
||||
def define_eval_candidate_2():
|
||||
if not st.session_state.get("selected_eval_task_1_next", None):
|
||||
if not st.session_state.get("selected_benchmark_1_next", None):
|
||||
return
|
||||
|
||||
st.subheader("2. Define Eval Candidate")
|
||||
|
|
@@ -163,11 +161,11 @@ def run_evaluation_3():
        Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button.
        """
    )
    selected_eval_task = st.session_state["selected_eval_task"]
    eval_tasks = st.session_state["eval_tasks"]
    selected_benchmark = st.session_state["selected_benchmark"]
    benchmarks = st.session_state["benchmarks"]
    eval_candidate = st.session_state["eval_candidate"]

    dataset_id = eval_tasks[selected_eval_task].dataset_id
    dataset_id = benchmarks[selected_benchmark].dataset_id
    rows = llama_stack_api.client.datasetio.get_rows_paginated(
        dataset_id=dataset_id,
        rows_in_page=-1,
@@ -182,16 +180,16 @@ def run_evaluation_3():
        help="Number of examples from the dataset to evaluate. ",
    )

    eval_task_config = {
    benchmark_config = {
        "type": "benchmark",
        "eval_candidate": eval_candidate,
        "scoring_params": {},
    }

    with st.expander("View Evaluation Task", expanded=True):
        st.json(eval_tasks[selected_eval_task], expanded=True)
        st.json(benchmarks[selected_benchmark], expanded=True)
    with st.expander("View Evaluation Task Configuration", expanded=True):
        st.json(eval_task_config, expanded=True)
        st.json(benchmark_config, expanded=True)

    # Add run button and handle evaluation
    if st.button("Run Evaluation"):
@@ -211,10 +209,10 @@ def run_evaluation_3():
            progress_bar.progress(progress, text=progress_text)
            # Run evaluation for current row
            eval_res = llama_stack_api.client.eval.evaluate_rows(
                task_id=selected_eval_task,
                benchmark_id=selected_benchmark,
                input_rows=[r],
                scoring_functions=eval_tasks[selected_eval_task].scoring_functions,
                task_config=eval_task_config,
                scoring_functions=benchmarks[selected_benchmark].scoring_functions,
                task_config=benchmark_config,
            )

            for k in r.keys():
@@ -227,7 +225,7 @@ def run_evaluation_3():
                    output_res[k] = []
                output_res[k].append(eval_res.generations[0][k])

            for scoring_fn in eval_tasks[selected_eval_task].scoring_functions:
            for scoring_fn in benchmarks[selected_benchmark].scoring_functions:
                if scoring_fn not in output_res:
                    output_res[scoring_fn] = []
                output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
@@ -247,7 +245,7 @@ def native_evaluation_page():
    st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
    st.title("📊 Evaluations (Generation + Scoring)")

    select_eval_task_1()
    select_benchmark_1()
    define_eval_candidate_2()
    run_evaluation_3()
@@ -9,7 +9,6 @@ from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.memory_insert_params import Document
from modules.api import llama_stack_api
from modules.utils import data_url_from_file
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
|
||||
|
||||
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
|
||||
|
|
|
|||
|
|
@@ -8,13 +8,11 @@ import inspect
import json
import logging
from enum import Enum
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin

from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
from typing_extensions import Annotated

log = logging.getLogger(__name__)
llama_stack/models/llama/datatypes.py (new file, 277 lines)
@@ -0,0 +1,277 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from enum import Enum
from typing import Any, Dict, Literal, Optional, Union

# import all for backwards compatibility
from llama_models.datatypes import *  # noqa: F403
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import Annotated

from llama_stack.schema_utils import json_schema_type, register_schema

register_schema(ToolCall)


@json_schema_type
class ToolParamDefinition(BaseModel):
    param_type: str
    description: Optional[str] = None
    required: Optional[bool] = True
    default: Optional[Any] = None


@json_schema_type
class ToolDefinition(BaseModel):
    tool_name: Union[BuiltinTool, str]
    description: Optional[str] = None
    parameters: Optional[Dict[str, ToolParamDefinition]] = None

    @field_validator("tool_name", mode="before")
    @classmethod
    def validate_field(cls, v):
        if isinstance(v, str):
            try:
                return BuiltinTool(v)
            except ValueError:
                return v
        return v


@json_schema_type
class GreedySamplingStrategy(BaseModel):
    type: Literal["greedy"] = "greedy"


@json_schema_type
class TopPSamplingStrategy(BaseModel):
    type: Literal["top_p"] = "top_p"
    temperature: Optional[float] = Field(..., gt=0.0)
    top_p: Optional[float] = 0.95


@json_schema_type
class TopKSamplingStrategy(BaseModel):
    type: Literal["top_k"] = "top_k"
    top_k: int = Field(..., ge=1)


SamplingStrategy = register_schema(
    Annotated[
        Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
        Field(discriminator="type"),
    ],
    name="SamplingStrategy",
)


@json_schema_type
class SamplingParams(BaseModel):
    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)

    max_tokens: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0


class CheckpointQuantizationFormat(Enum):
    # default format
    bf16 = "bf16"

    # used for enabling fp8_rowwise inference, some weights are bf16
    fp8_mixed = "fp8-mixed"

    int8 = "int8"

    int4 = "int4"


class ModelFamily(Enum):
    llama2 = "llama2"
    llama3 = "llama3"
    llama3_1 = "llama3_1"
    llama3_2 = "llama3_2"
    llama3_3 = "llama3_3"
    safety = "safety"


class CoreModelId(Enum):
    """Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""

    # Llama 2 family
    llama2_7b = "Llama-2-7b"
    llama2_13b = "Llama-2-13b"
    llama2_70b = "Llama-2-70b"
    llama2_7b_chat = "Llama-2-7b-chat"
    llama2_13b_chat = "Llama-2-13b-chat"
    llama2_70b_chat = "Llama-2-70b-chat"

    # Llama 3 family
    llama3_8b = "Llama-3-8B"
    llama3_70b = "Llama-3-70B"
    llama3_8b_instruct = "Llama-3-8B-Instruct"
    llama3_70b_instruct = "Llama-3-70B-Instruct"

    # Llama 3.1 family
    llama3_1_8b = "Llama3.1-8B"
    llama3_1_70b = "Llama3.1-70B"
    llama3_1_405b = "Llama3.1-405B"
    llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
    llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
    llama3_1_405b_instruct = "Llama3.1-405B-Instruct"

    # Llama 3.2 family
    llama3_2_1b = "Llama3.2-1B"
    llama3_2_3b = "Llama3.2-3B"
    llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
    llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
    llama3_2_11b_vision = "Llama3.2-11B-Vision"
    llama3_2_90b_vision = "Llama3.2-90B-Vision"
    llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
    llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"

    # Llama 3.3 family
    llama3_3_70b_instruct = "Llama3.3-70B-Instruct"

    # Safety models
    llama_guard_3_8b = "Llama-Guard-3-8B"
    llama_guard_2_8b = "Llama-Guard-2-8B"
    llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
    llama_guard_3_1b = "Llama-Guard-3-1B"


def is_multimodal(model_id) -> bool:
    if model_id in [
        CoreModelId.llama3_2_11b_vision,
        CoreModelId.llama3_2_90b_vision,
        CoreModelId.llama3_2_11b_vision_instruct,
        CoreModelId.llama3_2_90b_vision_instruct,
    ]:
        return True
    else:
        return False


def model_family(model_id) -> ModelFamily:
    if model_id in [
        CoreModelId.llama2_7b,
        CoreModelId.llama2_13b,
        CoreModelId.llama2_70b,
        CoreModelId.llama2_7b_chat,
        CoreModelId.llama2_13b_chat,
        CoreModelId.llama2_70b_chat,
    ]:
        return ModelFamily.llama2
    elif model_id in [
        CoreModelId.llama3_8b,
        CoreModelId.llama3_70b,
        CoreModelId.llama3_8b_instruct,
        CoreModelId.llama3_70b_instruct,
    ]:
        return ModelFamily.llama3
    elif model_id in [
        CoreModelId.llama3_1_8b,
        CoreModelId.llama3_1_70b,
        CoreModelId.llama3_1_405b,
        CoreModelId.llama3_1_8b_instruct,
        CoreModelId.llama3_1_70b_instruct,
        CoreModelId.llama3_1_405b_instruct,
    ]:
        return ModelFamily.llama3_1
    elif model_id in [
        CoreModelId.llama3_2_1b,
        CoreModelId.llama3_2_3b,
        CoreModelId.llama3_2_1b_instruct,
        CoreModelId.llama3_2_3b_instruct,
        CoreModelId.llama3_2_11b_vision,
        CoreModelId.llama3_2_90b_vision,
        CoreModelId.llama3_2_11b_vision_instruct,
        CoreModelId.llama3_2_90b_vision_instruct,
    ]:
        return ModelFamily.llama3_2
    elif model_id in [
        CoreModelId.llama3_3_70b_instruct,
    ]:
        return ModelFamily.llama3_3
    elif model_id in [
        CoreModelId.llama_guard_3_8b,
        CoreModelId.llama_guard_2_8b,
        CoreModelId.llama_guard_3_11b_vision,
        CoreModelId.llama_guard_3_1b,
    ]:
        return ModelFamily.safety
    else:
        raise ValueError(f"Unknown model family for {model_id}")


class Model(BaseModel):
    core_model_id: CoreModelId
    description: str
    huggingface_repo: Optional[str] = None
    recommended_sampling_params: Optional[SamplingParams] = None
    arch_args: Dict[str, Any]
    variant: str = ""

    quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
    pth_file_count: int
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)

    # silence pydantic until we remove the `model_` fields
    model_config = ConfigDict(protected_namespaces=())

    @property
    def model_family(self) -> ModelFamily:
        return model_family(self.core_model_id)

    # The SKU is uniquely identified by (model_id, variant) combo
    def descriptor(self, shorten_default_variant: bool = True) -> str:
        if not self.variant:
            return self.core_model_id.value
        return f"{self.core_model_id.value}:{self.variant}"

    @property
    def is_instruct_model(self) -> bool:
        return "instruct" in self.id.name

    # Featured models are shown in the non-exhaustive model list
    @property
    def is_featured(self) -> bool:
        return self.model_family in [
            ModelFamily.llama3_1,
            ModelFamily.llama3_2,
            ModelFamily.llama3_3,
            ModelFamily.safety,
        ]

    @property
    def max_seq_length(self) -> int:
        if self.model_family == ModelFamily.llama2:
            return 4096
        elif self.core_model_id == CoreModelId.llama_guard_2_8b:
            return 4096
        elif self.model_family == ModelFamily.llama3:
            return 8192
        elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
            return 131072
        elif self.model_family == ModelFamily.llama3_2:
            if self.quantization_format == CheckpointQuantizationFormat.int4:
                return 8192
            return 131072
        elif self.core_model_id in [
            CoreModelId.llama_guard_3_8b,
            CoreModelId.llama_guard_3_11b_vision,
            CoreModelId.llama_guard_3_1b,
        ]:
            return 131072
        else:
            raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
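`SamplingStrategy` above is an annotated union discriminated on the `type` field, so pydantic can pick the right subclass when validating untyped payloads. A self-contained sketch of that mechanism (stand-in models, not the llama_stack classes):

```python
# Stand-in models demonstrating pydantic's discriminated-union dispatch on
# a `type` field, as used by SamplingStrategy above.
from typing import Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class Greedy(BaseModel):
    type: Literal["greedy"] = "greedy"


class TopP(BaseModel):
    type: Literal["top_p"] = "top_p"
    temperature: float = Field(..., gt=0.0)
    top_p: float = 0.95


Strategy = Annotated[Union[Greedy, TopP], Field(discriminator="type")]

adapter = TypeAdapter(Strategy)
print(adapter.validate_python({"type": "top_p", "temperature": 0.7}))  # -> TopP instance
```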
llama_stack/models/llama/llama3/dog.jpg (new binary file, 39 KiB; binary content not shown)
llama_stack/models/llama/llama3/interface.py (new file, 257 lines)
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from pathlib import Path
from typing import List, Optional

from llama_models.datatypes import (
    BuiltinTool,
    RawMessage,
    StopReason,
    ToolCall,
    ToolPromptFormat,
)
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from termcolor import colored

from llama_stack.models.llama.datatypes import ToolDefinition

from . import template_data
from .prompt_templates import (
    BuiltinToolGenerator,
    FunctionTagCustomToolGenerator,
    JsonCustomToolGenerator,
    SystemDefaultGenerator,
    ToolResponseGenerator,
)

THIS_DIR = Path(__file__).parent


class Template:
    def __init__(
        self,
        role,
        template_name,
        data_provider=None,
        notes=None,
    ):
        self.role = role
        self.template_name = template_name
        self.data_provider = data_provider or ""
        self._notes = notes or ""

    @property
    def notes(self):
        default = "↵ represents newline"
        notes = default
        if self._notes:
            notes += "\n"
            notes += self._notes
        return notes


TEMPLATES = [
    Template(
        "user",
        "user-default",
        "user_default",
    ),
    Template(
        "user",
        "user-images",
        "user_images",
    ),
    Template("user", "user-interleaved-images", "user_interleaved_images"),
    Template(
        "assistant",
        "assistant-builtin-tool-call",
        "assistant_builtin_tool_call",
        "Notice <|python_tag|>",
    ),
    Template(
        "assistant",
        "assistant-custom-tool-call",
        "assistant_custom_tool_call",
        "Notice <function=...> format",
    ),
    Template(
        "assistant",
        "assistant-default",
        "assistant_default",
    ),
    Template(
        "system",
        "system-builtin-and-custom-tools",
        "system_message_builtin_and_custom_tools",
    ),
    Template(
        "system",
        "system-builtin-tools-only",
        "system_message_builtin_tools_only",
    ),
    Template(
        "system",
        "system-custom-tools-only",
        "system_message_custom_tools_only",
    ),
    Template(
        "system",
        "system-default",
        "system_default",
    ),
    Template(
        "tool",
        "tool-success",
        "tool_success",
        "Note ipython header and [stdout]",
    ),
    Template(
        "tool",
        "tool-failure",
        "tool_failure",
        "Note ipython header and [stderr]",
    ),
]


class LLama31Interface:
    def __init__(self, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json):
        self.tokenizer = Tokenizer.get_instance()
        self.formatter = ChatFormat(self.tokenizer)
        self.tool_prompt_format = tool_prompt_format

    def get_tokens(self, messages: List[RawMessage]) -> List[int]:
        model_input = self.formatter.encode_dialog_prompt(
            messages,
            self.tool_prompt_format,
        )
        return model_input.tokens

    def tool_response_messages(self, *args, **kwargs):
        template = ToolResponseGenerator().gen(*args, **kwargs)
        return [
            RawMessage(
                role="tool",
                content=template.render(),
            )
        ]

    def system_messages(
        self,
        builtin_tools: List[BuiltinTool],
        custom_tools: List[ToolDefinition],
        instruction: Optional[str] = None,
    ) -> List[RawMessage]:
        messages = []

        default_gen = SystemDefaultGenerator()
        default_template = default_gen.gen()

        sys_content = ""

        tool_template = None
        if builtin_tools or custom_tools:
            tool_gen = BuiltinToolGenerator()
            tool_template = tool_gen.gen(builtin_tools + custom_tools)

            sys_content += tool_template.render()
            sys_content += "\n"

        sys_content += default_template.render()

        if instruction:
            sys_content += "\n\n"
            sys_content += instruction

        sys_content += "\n"
        messages.append(RawMessage(role="system", content=sys_content))

        if custom_tools:
            if self.tool_prompt_format == ToolPromptFormat.json:
                tool_gen = JsonCustomToolGenerator()
            elif self.tool_prompt_format == ToolPromptFormat.function_tag:
                tool_gen = FunctionTagCustomToolGenerator()
            else:
                raise ValueError(f"Non supported ToolPromptFormat {self.tool_prompt_format}")

            custom_template = tool_gen.gen(custom_tools)
            messages.append(RawMessage(role="user", content=custom_template.render()))

        return messages

    def assistant_response_messages(
        self,
        content: str,
        stop_reason: StopReason,
        tool_call: Optional[ToolCall] = None,
    ) -> List[RawMessage]:
        tool_calls = []
        if tool_call:
            tool_calls.append(tool_call)
        return [
            RawMessage(
                role="assistant",
                content=content,
                tool_calls=tool_calls,
                stop_reason=stop_reason,
            )
        ]

    def user_message(self, content: str) -> List[RawMessage]:
        return [RawMessage(role="user", content=content)]

    def display_message_as_tokens(self, message: RawMessage) -> None:
        """Util to print tokenized string to shell"""
        tokens = self.formatter.encode_message(message, self.tool_prompt_format)
        on_colors = [
            "on_red",
            "on_green",
            "on_yellow",
            "on_blue",
            "on_magenta",
            "on_cyan",
        ]
        for i, t in enumerate(tokens):
            on_col = on_colors[i % len(on_colors)]
            print(colored(self.tokenizer.decode([t]), "white", on_col), end="")
        print("\n", end="")


def list_jinja_templates() -> List[Template]:
    return TEMPLATES


def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
    by_name = {t.template_name: t for t in TEMPLATES}
    if name not in by_name:
        raise ValueError(f"No template found for `{name}`")

    template = by_name[name]
    interface = LLama31Interface(tool_prompt_format)

    data_func = getattr(template_data, template.data_provider)
    if template.role == "system":
        messages = interface.system_messages(**data_func())
    elif template.role == "tool":
        messages = interface.tool_response_messages(**data_func())
    elif template.role == "assistant":
        messages = interface.assistant_response_messages(**data_func())
    elif template.role == "user":
        messages = interface.user_message(**data_func())

    tokens = interface.get_tokens(messages)
    special_tokens = list(interface.tokenizer.special_tokens.values())
    tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
    return template, tokens
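A hedged usage sketch for the new interface module (not part of the diff); it assumes the `llama-models` package and its tokenizer assets are installed locally:

```python
# Hedged sketch: enumerate the prompt templates and render one of them.
# Requires the llama-models package (tokenizer assets) to be installed.
from llama_models.datatypes import ToolPromptFormat

from llama_stack.models.llama.llama3.interface import (
    list_jinja_templates,
    render_jinja_template,
)

for template in list_jinja_templates():
    print(template.role, template.template_name)

template, tokens = render_jinja_template("user-default", ToolPromptFormat.json)
print(tokens[:5])  # (decoded_piece, is_special_token) pairs
```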
Some files were not shown because too many files have changed in this diff.