mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-10 04:08:31 +00:00)
Merge remote-tracking branch 'upstream/main' into add_nvidia_safety_provider
Merging upstream changes
commit 688e1806d1
227 changed files with 7536 additions and 3147 deletions
@@ -29,13 +29,8 @@ repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
  rev: v0.9.4
  hooks:
- # Run the linter with import sorting.
  - id: ruff
- args: [
+ exclude: ^llama_stack/strong_typing/.*$
- --fix,
- --exit-non-zero-on-fix,
- --select, I,
- ]
  - id: ruff-format

  - repo: https://github.com/adamchainz/blacken-docs
@@ -49,7 +44,13 @@ repos:
  rev: 0.5.26
  hooks:
  - id: uv-export
- args: ["--frozen", "--no-hashes", "--no-emit-project"]
+ args: [
+ "--frozen",
+ "--no-hashes",
+ "--no-emit-project",
+ "--output-file=requirements.txt"
+ ]
+ files: ^pyproject\.toml$
  - id: uv-sync

  # - repo: https://github.com/pre-commit/mirrors-mypy
.ruff.toml (37 lines changed)
@@ -1,37 +0,0 @@
- # Suggested config from pytorch that we can adapt
- lint.select = ["B", "C", "E" , "F" , "N", "W", "B9"]
-
- line-length = 120
-
- # C408 ignored because we like the dict keyword argument syntax
- # E501 is not flexible enough, we're using B950 instead
- # N812 ignored because import torch.nn.functional as F is PyTorch convention
- # N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
- # E731 allow usage of assigning lambda expressions
- # E701 let black auto-format statements on one line
- # E704 let black auto-format statements on one line
- lint.ignore = [
- "E203", "E305", "E402", "E501", "E721", "E741", "F405", "F821", "F841",
- "C408", "E302", "W291", "E303", "N812", "N817", "E731", "E701",
- # These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
- "C901", "C405", "C414", "N803", "N999", "C403", "C416", "B028", "C419", "C401", "B023",
- # shebang has extra meaning in fbcode lints, so I think it's not worth trying
- # to line this up with executable bit
- "EXE001",
- # random naming hints don't need
- "N802",
- # these ignores are from flake8-bugbear; please fix!
- "B007", "B008"
- ]
-
- exclude = [
- "./.git",
- "./docs/*",
- "./build",
- "./scripts",
- "./venv",
- "*.pyi",
- ".pre-commit-config.yaml",
- "*.md",
- ".flake8"
- ]
docs/_static/llama-stack-spec.html (vendored, 2771 lines changed): file diff suppressed because it is too large
docs/_static/llama-stack-spec.yaml (vendored, 1670 lines changed): file diff suppressed because it is too large
@@ -324,7 +324,7 @@
  "- vector_io\n",
  "container_image: null\n",
  "datasets: <span style=\"font-weight: bold\">[]</span>\n",
- "eval_tasks: <span style=\"font-weight: bold\">[]</span>\n",
+ "benchmarks: <span style=\"font-weight: bold\">[]</span>\n",
  "image_name: together\n",
  "metadata_store:\n",
  " db_path: <span style=\"color: #800080; text-decoration-color: #800080\">/Users/ashwin/.llama/distributions/together/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">registry.db</span>\n",
@@ -508,7 +508,7 @@
  "- vector_io\n",
  "container_image: null\n",
  "datasets: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
- "eval_tasks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
+ "benchmarks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
  "image_name: together\n",
  "metadata_store:\n",
  " db_path: \u001b[35m/Users/ashwin/.llama/distributions/together/\u001b[0m\u001b[95mregistry.db\u001b[0m\n",
@@ -3419,22 +3419,22 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
- "id": "865fc5a8",
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install llama-stack-client==0.1.0"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
  "id": "44e05e16",
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " % Total % Received % Xferd Average Speed Time Time Time Current\n",
+ " Dload Upload Total Spent Left Speed\n",
+ "100 275k 100 275k 0 0 780k 0 --:--:-- --:--:-- --:--:-- 780k\n"
+ ]
+ }
+ ],
  "source": [
- "!wget https://raw.githubusercontent.com/meta-llama/llama-models/refs/heads/main/Llama_Repo.jpeg"
+ "!curl -O https://raw.githubusercontent.com/meta-llama/llama-models/refs/heads/main/Llama_Repo.jpeg"
  ]
  },
  {
@@ -3444,6 +3444,7 @@
  "metadata": {},
  "outputs": [],
  "source": [
+ "# NBVAL_SKIP\n",
  "from PIL import Image\n",
  "import matplotlib.pyplot as plt\n",
  "\n",
@@ -3580,6 +3581,7 @@
  " model=LLAMA32_11B_INSTRUCT,\n",
  " instructions=\"You are a helpful assistant\",\n",
  " enable_session_persistence=False,\n",
+ " toolgroups=[],\n",
  " )\n",
  "\n",
  " agent = Agent(client, agent_config)\n",
@@ -3630,7 +3632,7 @@
  "provenance": []
  },
  "kernelspec": {
- "display_name": "toolchain",
+ "display_name": "master",
  "language": "python",
  "name": "python3"
  },
@@ -3644,7 +3646,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.15"
+ "version": "3.10.16"
  },
  "widgets": {
  "application/vnd.jupyter.widget-state+json": {
@@ -370,7 +370,7 @@
  "- tool_runtime\n",
  "datasets: <span style=\"font-weight: bold\">[]</span>\n",
  "container_image: null\n",
- "eval_tasks: <span style=\"font-weight: bold\">[]</span>\n",
+ "benchmarks: <span style=\"font-weight: bold\">[]</span>\n",
  "image_name: together\n",
  "memory_banks: <span style=\"font-weight: bold\">[]</span>\n",
  "metadata_store:\n",
@@ -551,7 +551,7 @@
  "- tool_runtime\n",
  "datasets: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
  "container_image: null\n",
- "eval_tasks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
+ "benchmarks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
  "image_name: together\n",
  "memory_banks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
  "metadata_store:\n",
@@ -1,4 +1,4 @@
- The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/[<subdir>]/api/endpoints.py` using the `generate.py` utility.
+ The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/distribution/server/endpoints.py` using the `generate.py` utility.

  Please install the following packages before running the script:

@@ -6,4 +6,4 @@ Please install the following packages before running the script:
  pip install python-openapi json-strong-typing fire PyYAML llama-models
  ```

- Then simply run `sh run_openapi_generator.sh <OUTPUT_DIR>`
+ Then simply run `sh run_openapi_generator.sh`
@@ -16,18 +16,6 @@ from pathlib import Path
  import fire
  import ruamel.yaml as yaml

- from llama_models import schema_utils
-
- # We do some monkey-patching to ensure our definitions only use the minimal
- # (json_schema_type, webmethod) definitions from the llama_models package. For
- # generation though, we need the full definitions and implementations from the
- # (json-strong-typing) package.
-
- from .strong_typing.schema import json_schema_type, register_schema
-
- schema_utils.json_schema_type = json_schema_type
- schema_utils.register_schema = register_schema
-
  from llama_stack.apis.version import LLAMA_STACK_API_VERSION # noqa: E402
  from llama_stack.distribution.stack import LlamaStack # noqa: E402

@@ -10,9 +10,9 @@ import typing
  from dataclasses import make_dataclass
  from typing import Any, Dict, Set, Union

- from ..strong_typing.core import JsonType
+ from llama_stack.strong_typing.core import JsonType
- from ..strong_typing.docstring import Docstring, parse_type
+ from llama_stack.strong_typing.docstring import Docstring, parse_type
- from ..strong_typing.inspection import (
+ from llama_stack.strong_typing.inspection import (
  is_generic_list,
  is_type_optional,
  is_type_union,
@@ -20,15 +20,15 @@ from ..strong_typing.inspection import (
  unwrap_optional_type,
  unwrap_union_types,
  )
- from ..strong_typing.name import python_type_to_name
+ from llama_stack.strong_typing.name import python_type_to_name
- from ..strong_typing.schema import (
+ from llama_stack.strong_typing.schema import (
  get_schema_identifier,
  JsonSchemaGenerator,
  register_schema,
  Schema,
  SchemaOptions,
  )
- from ..strong_typing.serialization import json_dump_string, object_to_json
+ from llama_stack.strong_typing.serialization import json_dump_string, object_to_json

  from .operations import (
  EndpointOperation,
@@ -647,6 +647,7 @@ class Generator:
  description = "\n".join(
  filter(None, [doc_string.short_description, doc_string.long_description])
  )
+
  return Operation(
  tags=[op.defining_class.__name__],
  summary=None,
@@ -656,6 +657,7 @@ class Generator:
  requestBody=requestBody,
  responses=responses,
  callbacks=callbacks,
+ deprecated=True if "DEPRECATED" in op.func_name else None,
  security=[] if op.public else None,
  )

|
||||||
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from ..strong_typing.inspection import get_signature
|
from llama_stack.strong_typing.inspection import get_signature
|
||||||
|
|
||||||
|
|
||||||
def split_prefix(
|
def split_prefix(
|
||||||
|
|
|
@@ -9,7 +9,7 @@ import enum
  from dataclasses import dataclass
  from typing import Any, ClassVar, Dict, List, Optional, Union

- from ..strong_typing.schema import JsonType, Schema, StrictJsonType
+ from llama_stack.strong_typing.schema import JsonType, Schema, StrictJsonType

  URL = str

|
||||||
requestBody: Optional[RequestBody] = None
|
requestBody: Optional[RequestBody] = None
|
||||||
callbacks: Optional[Dict[str, "Callback"]] = None
|
callbacks: Optional[Dict[str, "Callback"]] = None
|
||||||
security: Optional[List["SecurityRequirement"]] = None
|
security: Optional[List["SecurityRequirement"]] = None
|
||||||
|
deprecated: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@@ -9,7 +9,7 @@ import typing
  from pathlib import Path
  from typing import TextIO

- from ..strong_typing.schema import object_to_json, StrictJsonType
+ from llama_stack.strong_typing.schema import object_to_json, StrictJsonType

  from .generator import Generator
  from .options import Options
@@ -41,14 +41,14 @@ system_message = {
  "content": SYSTEM_PROMPT_TEMPLATE,
  }

- client.eval_tasks.register(
+ client.benchmarks.register(
- eval_task_id="meta-reference::mmmu",
+ benchmark_id="meta-reference::mmmu",
  dataset_id=f"mmmu-{subset}-{split}",
  scoring_functions=["basic::regex_parser_multiple_choice_answer"],
  )

  response = client.eval.evaluate_rows(
- task_id="meta-reference::mmmu",
+ benchmark_id="meta-reference::mmmu",
  input_rows=eval_rows,
  scoring_functions=["basic::regex_parser_multiple_choice_answer"],
  task_config={
@@ -99,14 +99,14 @@ eval_rows = client.datasetio.get_rows_paginated(
  ```

  ```python
- client.eval_tasks.register(
+ client.benchmarks.register(
- eval_task_id="meta-reference::simpleqa",
+ benchmark_id="meta-reference::simpleqa",
  dataset_id=simpleqa_dataset_id,
  scoring_functions=["llm-as-judge::405b-simpleqa"],
  )

  response = client.eval.evaluate_rows(
- task_id="meta-reference::simpleqa",
+ benchmark_id="meta-reference::simpleqa",
  input_rows=eval_rows.rows,
  scoring_functions=["llm-as-judge::405b-simpleqa"],
  task_config={
@@ -156,7 +156,7 @@ agent_config = {
  }

  response = client.eval.evaluate_rows(
- task_id="meta-reference::simpleqa",
+ benchmark_id="meta-reference::simpleqa",
  input_rows=eval_rows.rows,
  scoring_functions=["llm-as-judge::405b-simpleqa"],
  task_config={
@@ -10,15 +10,15 @@ Here's how to set up basic evaluation:

  ```python
  # Create an evaluation task
- response = client.eval_tasks.register(
+ response = client.benchmarks.register(
- eval_task_id="my_eval",
+ benchmark_id="my_eval",
  dataset_id="my_dataset",
  scoring_functions=["accuracy", "relevance"],
  )

  # Run evaluation
  job = client.eval.run_eval(
- task_id="my_eval",
+ benchmark_id="my_eval",
  task_config={
  "type": "app",
  "eval_candidate": {"type": "agent", "config": agent_config},
@@ -26,5 +26,5 @@ job = client.eval.run_eval(
  )

  # Get results
- result = client.eval.job_result(task_id="my_eval", job_id=job.job_id)
+ result = client.eval.job_result(benchmark_id="my_eval", job_id=job.job_id)
  ```
@@ -5,7 +5,7 @@ The Llama Stack Evaluation flow allows you to run evaluations on your GenAI appl
  We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
  - `/datasetio` + `/datasets` API
  - `/scoring` + `/scoring_functions` API
- - `/eval` + `/eval_tasks` API
+ - `/eval` + `/benchmarks` API

  This guide goes over the sets of APIs and developer experience flow of using Llama Stack to run evaluations for different use cases. Checkout our Colab notebook on working examples with evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).

@@ -21,7 +21,7 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
  - **Scoring**: evaluate outputs of the system.
  - Associated with `ScoringFunction` resource. We provide a suite of out-of-the box scoring functions and also the ability for you to add custom evaluators. These scoring functions are the core part of defining an evaluation task to output evaluation metrics.
  - **Eval**: generate outputs (via Inference or Agents) and perform scoring.
- - Associated with `EvalTask` resource.
+ - Associated with `Benchmark` resource.


  Use the following decision tree to decide how to use LlamaStack Evaluation flow.
@@ -42,7 +42,7 @@ Some of these APIs are associated with a set of **Resources**. Here is the mappi
  - **Tool Runtime** is associated with `ToolGroup` resources.
  - **DatasetIO** is associated with `Dataset` resources.
  - **Scoring** is associated with `ScoringFunction` resources.
- - **Eval** is associated with `Model` and `EvalTask` resources.
+ - **Eval** is associated with `Model` and `Benchmark` resources.

  Furthermore, we allow these resources to be **federated** across multiple providers. For example, you may have some Llama models served by Fireworks while others are served by AWS Bedrock. Regardless, they will all work seamlessly with the same uniform Inference API provided by Llama Stack.

@@ -23,7 +23,8 @@ The main points to consider are:
  ```
  llama stack build -h

- usage: llama stack build [-h] [--config CONFIG] [--template TEMPLATE] [--list-templates | --no-list-templates] [--image-type {conda,container,venv}] [--image-name IMAGE_NAME]
+ usage: llama stack build [-h] [--config CONFIG] [--template TEMPLATE] [--list-templates]
+ [--image-type {conda,container,venv}] [--image-name IMAGE_NAME] [--print-deps-only]

  Build a Llama stack container

|
||||||
--config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml.
|
--config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml.
|
||||||
If this argument is not provided, you will be prompted to enter information interactively
|
If this argument is not provided, you will be prompted to enter information interactively
|
||||||
--template TEMPLATE Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates
|
--template TEMPLATE Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates
|
||||||
--list-templates, --no-list-templates
|
--list-templates Show the available templates for building a Llama Stack distribution
|
||||||
Show the available templates for building a Llama Stack distribution (default: False)
|
|
||||||
--image-type {conda,container,venv}
|
--image-type {conda,container,venv}
|
||||||
Image Type to use for the build. This can be either conda or container or venv. If not specified, will use the image type from the template config.
|
Image Type to use for the build. This can be either conda or container or venv. If not specified, will use the image type from the template config.
|
||||||
--image-name IMAGE_NAME
|
--image-name IMAGE_NAME
|
||||||
[for image-type=conda] Name of the conda environment to use for the build. If
|
[for image-type=conda] Name of the conda environment to use for the build. If
|
||||||
not specified, currently active Conda environment will be used. If no Conda
|
not specified, currently active Conda environment will be used. If no Conda
|
||||||
environment is active, you must specify a name.
|
environment is active, you must specify a name.
|
||||||
|
--print-deps-only Print the dependencies for the stack only, without building the stack
|
||||||
```
|
```
|
||||||
|
|
||||||
After this step is complete, a file named `<name>-build.yaml` and template file `<name>-run.yaml` will be generated and saved at the output file path specified at the end of the command.
|
After this step is complete, a file named `<name>-build.yaml` and template file `<name>-run.yaml` will be generated and saved at the output file path specified at the end of the command.
|
||||||
|
|
|
@@ -2,7 +2,7 @@
  ```{admonition} News
  :class: tip

- Llama Stack 0.1.2 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.2) for more details.
+ Llama Stack 0.1.3 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.3) for more details.
  ```

  # Llama Stack
@@ -64,7 +64,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
  ```

  ```bash
- $ llama-stack-client eval_tasks register \
+ $ llama-stack-client benchmarks register \
  --eval-task-id meta-reference-mmlu \
  --provider-id meta-reference \
  --dataset-id mmlu \
@@ -86,7 +86,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
  - Under the hood, it uses Llama Stack's `/providers` API to get information about the providers.

  - **API Resources**: Inspect Llama Stack API resources
- - This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `eval_tasks`, `shields`).
+ - This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `benchmarks`, `shields`).
  - Under the hood, it uses Llama Stack's `/<resources>/list` API to get information about each resources.
  - Please visit [Core Concepts](https://llama-stack.readthedocs.io/en/latest/concepts/index.html) for more details about the resources.

@@ -5,7 +5,7 @@ The Llama Stack Evaluation flow allows you to run evaluations on your GenAI appl
  We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
  - `/datasetio` + `/datasets` API
  - `/scoring` + `/scoring_functions` API
- - `/eval` + `/eval_tasks` API
+ - `/eval` + `/benchmarks` API

  This guide goes over the sets of APIs and developer experience flow of using Llama Stack to run evaluations for different use cases. Checkout our Colab notebook on working examples with evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).

@@ -21,7 +21,7 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
  - **Scoring**: evaluate outputs of the system.
  - Associated with `ScoringFunction` resource. We provide a suite of out-of-the box scoring functions and also the ability for you to add custom evaluators. These scoring functions are the core part of defining an evaluation task to output evaluation metrics.
  - **Eval**: generate outputs (via Inference or Agents) and perform scoring.
- - Associated with `EvalTask` resource.
+ - Associated with `Benchmark` resource.


  Use the following decision tree to decide how to use LlamaStack Evaluation flow.
@@ -77,14 +77,14 @@ system_message = {
  "content": SYSTEM_PROMPT_TEMPLATE,
  }

- client.eval_tasks.register(
+ client.benchmarks.register(
- eval_task_id="meta-reference::mmmu",
+ benchmark_id="meta-reference::mmmu",
  dataset_id=f"mmmu-{subset}-{split}",
  scoring_functions=["basic::regex_parser_multiple_choice_answer"],
  )

  response = client.eval.evaluate_rows(
- task_id="meta-reference::mmmu",
+ benchmark_id="meta-reference::mmmu",
  input_rows=eval_rows,
  scoring_functions=["basic::regex_parser_multiple_choice_answer"],
  task_config={
@@ -135,14 +135,14 @@ eval_rows = client.datasetio.get_rows_paginated(
  ```

  ```python
- client.eval_tasks.register(
+ client.benchmarks.register(
- eval_task_id="meta-reference::simpleqa",
+ benchmark_id="meta-reference::simpleqa",
  dataset_id=simpleqa_dataset_id,
  scoring_functions=["llm-as-judge::405b-simpleqa"],
  )

  response = client.eval.evaluate_rows(
- task_id="meta-reference::simpleqa",
+ benchmark_id="meta-reference::simpleqa",
  input_rows=eval_rows.rows,
  scoring_functions=["llm-as-judge::405b-simpleqa"],
  task_config={
@@ -192,7 +192,7 @@ agent_config = {
  }

  response = client.eval.evaluate_rows(
- task_id="meta-reference::simpleqa",
+ benchmark_id="meta-reference::simpleqa",
  input_rows=eval_rows.rows,
  scoring_functions=["llm-as-judge::405b-simpleqa"],
  task_config={
@@ -281,7 +281,7 @@ The following examples give the quick steps to start running evaluations using t

  #### Benchmark Evaluation CLI
  Usage: There are 2 inputs necessary for running a benchmark eval
- - `eval-task-id`: the identifier associated with the eval task. Each `EvalTask` is parametrized by
+ - `eval-task-id`: the identifier associated with the eval task. Each `Benchmark` is parametrized by
  - `dataset_id`: the identifier associated with the dataset.
  - `List[scoring_function_id]`: list of scoring function identifiers.
  - `eval-task-config`: specifies the configuration of the model / agent to evaluate on.
@@ -289,7 +289,7 @@ Usage: There are 2 inputs necessary for running a benchmark eval

  ```
  llama-stack-client eval run_benchmark <eval-task-id> \
- --eval-task-config ~/eval_task_config.json \
+ --eval-task-config ~/benchmark_config.json \
  --visualize
  ```

@@ -309,15 +309,15 @@ llama-stack-client eval run_scoring <scoring_fn_id_1> <scoring_fn_id_2> ... <sco
  --output-dir ./
  ```

- #### Defining EvalTaskConfig
+ #### Defining BenchmarkConfig
- The `EvalTaskConfig` are user specified config to define:
+ The `BenchmarkConfig` are user specified config to define:
  1. `EvalCandidate` to run generation on:
  - `ModelCandidate`: The model will be used for generation through LlamaStack /inference API.
  - `AgentCandidate`: The agentic system specified by AgentConfig will be used for generation through LlamaStack /agents API.
  2. Optionally scoring function params to allow customization of scoring function behaviour. This is useful to parameterize generic scoring functions such as LLMAsJudge with custom `judge_model` / `judge_prompt`.


- **Example Benchmark EvalTaskConfig**
+ **Example Benchmark BenchmarkConfig**
  ```json
  {
  "type": "benchmark",
@@ -335,7 +335,7 @@ The `EvalTaskConfig` are user specified config to define:
  }
  ```

- **Example Application EvalTaskConfig**
+ **Example Application BenchmarkConfig**
  ```json
  {
  "type": "app",
@@ -39,7 +39,7 @@ You should see a table like this:

  ```
  +----------------------------------+------------------------------------------+----------------+
- | Model Descriptor | Hugging Face Repo | Context Length |
+ | Model Descriptor(ID) | Hugging Face Repo | Context Length |
  +----------------------------------+------------------------------------------+----------------+
  | Llama3.1-8B | meta-llama/Llama-3.1-8B | 128K |
  +----------------------------------+------------------------------------------+----------------+
@@ -63,7 +63,7 @@ You should see a table like this:

  ```
  +----------------------------------+------------------------------------------+----------------+
- | Model Descriptor | Hugging Face Repo | Context Length |
+ | Model Descriptor(ID) | Hugging Face Repo | Context Length |
  +----------------------------------+------------------------------------------+----------------+
  | Llama3.1-8B | meta-llama/Llama-3.1-8B | 128K |
  +----------------------------------+------------------------------------------+----------------+
@@ -161,14 +161,14 @@ Options:

  ## Eval Task Management

- ### `llama-stack-client eval_tasks list`
+ ### `llama-stack-client benchmarks list`
  ```bash
- $ llama-stack-client eval_tasks list
+ $ llama-stack-client benchmarks list
  ```

- ### `llama-stack-client eval_tasks register`
+ ### `llama-stack-client benchmarks register`
  ```bash
- $ llama-stack-client eval_tasks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
+ $ llama-stack-client benchmarks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
  ```

  Options:
@@ -191,7 +191,7 @@ Options:
  - `--num-examples`: Optional. Number of examples to evaluate (useful for debugging)
  - `--visualize`: Optional flag. If set, visualizes evaluation results after completion

- Example eval_task_config.json:
+ Example benchmark_config.json:
  ```json
  {
  "type": "benchmark",
@@ -181,8 +181,8 @@ from llama_stack_client.types import EvaluateResponse, Job

  Methods:

- - <code title="post /v1/eval/tasks/{task_id}/evaluations">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">evaluate_rows</a>(task_id, \*\*<a href="src/llama_stack_client/types/eval_evaluate_rows_params.py">params</a>) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
+ - <code title="post /v1/eval/tasks/{benchmark_id}/evaluations">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">evaluate_rows</a>(benchmark_id, \*\*<a href="src/llama_stack_client/types/eval_evaluate_rows_params.py">params</a>) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
- - <code title="post /v1/eval/tasks/{task_id}/jobs">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">run_eval</a>(task_id, \*\*<a href="src/llama_stack_client/types/eval_run_eval_params.py">params</a>) -> <a href="./src/llama_stack_client/types/job.py">Job</a></code>
+ - <code title="post /v1/eval/tasks/{benchmark_id}/jobs">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">run_eval</a>(benchmark_id, \*\*<a href="src/llama_stack_client/types/eval_run_eval_params.py">params</a>) -> <a href="./src/llama_stack_client/types/job.py">Job</a></code>

  ### Jobs

@@ -194,9 +194,9 @@ from llama_stack_client.types.eval import JobStatusResponse

  Methods:

- - <code title="get /v1/eval/tasks/{task_id}/jobs/{job_id}/result">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">retrieve</a>(job_id, \*, task_id) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
+ - <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}/result">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">retrieve</a>(job_id, \*, benchmark_id) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
- - <code title="delete /v1/eval/tasks/{task_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">cancel</a>(job_id, \*, task_id) -> None</code>
+ - <code title="delete /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">cancel</a>(job_id, \*, benchmark_id) -> None</code>
- - <code title="get /v1/eval/tasks/{task_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">status</a>(job_id, \*, task_id) -> Optional[JobStatusResponse]</code>
+ - <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">status</a>(job_id, \*, benchmark_id) -> Optional[JobStatusResponse]</code>

  ## Inspect

|
||||||
- <code title="get /v1/scoring-functions">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">list</a>() -> <a href="./src/llama_stack_client/types/scoring_function_list_response.py">ScoringFunctionListResponse</a></code>
|
- <code title="get /v1/scoring-functions">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">list</a>() -> <a href="./src/llama_stack_client/types/scoring_function_list_response.py">ScoringFunctionListResponse</a></code>
|
||||||
- <code title="post /v1/scoring-functions">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">register</a>(\*\*<a href="src/llama_stack_client/types/scoring_function_register_params.py">params</a>) -> None</code>
|
- <code title="post /v1/scoring-functions">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">register</a>(\*\*<a href="src/llama_stack_client/types/scoring_function_register_params.py">params</a>) -> None</code>
|
||||||
|
|
||||||
## EvalTasks
|
## Benchmarks
|
||||||
|
|
||||||
Types:
|
Types:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from llama_stack_client.types import (
|
from llama_stack_client.types import (
|
||||||
EvalTask,
|
Benchmark,
|
||||||
ListEvalTasksResponse,
|
ListBenchmarksResponse,
|
||||||
EvalTaskListResponse,
|
BenchmarkListResponse,
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
Methods:
|
Methods:
|
||||||
|
|
||||||
- <code title="get /v1/eval-tasks/{eval_task_id}">client.eval_tasks.<a href="./src/llama_stack_client/resources/eval_tasks.py">retrieve</a>(eval_task_id) -> <a href="./src/llama_stack_client/types/eval_task.py">Optional[EvalTask]</a></code>
|
- <code title="get /v1/eval-tasks/{benchmark_id}">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">retrieve</a>(benchmark_id) -> <a href="./src/llama_stack_client/types/benchmark.py">Optional[Benchmark]</a></code>
|
||||||
- <code title="get /v1/eval-tasks">client.eval_tasks.<a href="./src/llama_stack_client/resources/eval_tasks.py">list</a>() -> <a href="./src/llama_stack_client/types/eval_task_list_response.py">EvalTaskListResponse</a></code>
|
- <code title="get /v1/eval-tasks">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">list</a>() -> <a href="./src/llama_stack_client/types/benchmark_list_response.py">BenchmarkListResponse</a></code>
|
||||||
- <code title="post /v1/eval-tasks">client.eval_tasks.<a href="./src/llama_stack_client/resources/eval_tasks.py">register</a>(\*\*<a href="src/llama_stack_client/types/eval_task_register_params.py">params</a>) -> None</code>
|
- <code title="post /v1/eval-tasks">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">register</a>(\*\*<a href="src/llama_stack_client/types/benchmark_register_params.py">params</a>) -> None</code>
|
||||||
|
|
|
@@ -19,7 +19,6 @@ from typing import (
  runtime_checkable,
  )

- from llama_models.schema_utils import json_schema_type, register_schema, webmethod
  from pydantic import BaseModel, ConfigDict, Field

  from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
@@ -38,6 +37,7 @@ from llama_stack.apis.inference import (
  from llama_stack.apis.safety import SafetyViolation
  from llama_stack.apis.tools import ToolDef
  from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+ from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


  class Attachment(BaseModel):
@@ -179,7 +179,7 @@ class AgentConfigCommon(BaseModel):
  class AgentConfig(AgentConfigCommon):
  model: str
  instructions: str
- enable_session_persistence: bool
+ enable_session_persistence: Optional[bool] = False
  response_format: Optional[ResponseFormat] = None


@@ -1,206 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- from typing import Optional
-
- from llama_models.llama3.api.datatypes import ToolPromptFormat
- from llama_models.llama3.api.tool_utils import ToolUtils
- from termcolor import cprint
-
- from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
- from llama_stack.apis.common.content_types import ToolCallParseStatus
- from llama_stack.apis.inference import ToolResponseMessage
- from llama_stack.providers.utils.inference.prompt_adapter import (
- interleaved_content_as_str,
- )
-
-
- class LogEvent:
- def __init__(
- self,
- role: Optional[str] = None,
- content: str = "",
- end: str = "\n",
- color="white",
- ):
- self.role = role
- self.content = content
- self.color = color
- self.end = "\n" if end is None else end
-
- def __str__(self):
- if self.role is not None:
- return f"{self.role}> {self.content}"
- else:
- return f"{self.content}"
-
- def print(self, flush=True):
- cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
-
-
- EventType = AgentTurnResponseEventType
-
-
- class EventLogger:
- async def log(
- self,
- event_generator,
- stream=True,
- tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
- ):
- previous_event_type = None
- previous_step_type = None
-
- async for chunk in event_generator:
- if not hasattr(chunk, "event"):
- # Need to check for custom tool first
- # since it does not produce event but instead
- # a Message
- if isinstance(chunk, ToolResponseMessage):
- yield (
- chunk,
- LogEvent(role="CustomTool", content=chunk.content, color="grey"),
- )
- continue
-
- event = chunk.event
- event_type = event.payload.event_type
- if event_type in {
- EventType.turn_start.value,
- EventType.turn_complete.value,
- }:
- # Currently not logging any turn realted info
- yield event, None
- continue
-
- step_type = event.payload.step_type
- # handle safety
- if step_type == StepType.shield_call and event_type == EventType.step_complete.value:
- violation = event.payload.step_details.violation
- if not violation:
- yield (
- event,
- LogEvent(role=step_type, content="No Violation", color="magenta"),
- )
- else:
- yield (
- event,
- LogEvent(
- role=step_type,
- content=f"{violation.metadata} {violation.user_message}",
- color="red",
- ),
- )
-
- # handle inference
- if step_type == StepType.inference:
- if stream:
- if event_type == EventType.step_start.value:
- # TODO: Currently this event is never received
- yield (
- event,
- LogEvent(role=step_type, content="", end="", color="yellow"),
- )
- elif event_type == EventType.step_progress.value:
- # HACK: if previous was not step/event was not inference's step_progress
- # this is the first time we are getting model inference response
- # aka equivalent to step_start for inference. Hence,
- # start with "Model>".
- if (
- previous_event_type != EventType.step_progress.value
- and previous_step_type != StepType.inference
- ):
- yield (
- event,
- LogEvent(role=step_type, content="", end="", color="yellow"),
- )
-
- delta = event.payload.delta
- if delta.type == "tool_call":
- if delta.parse_status == ToolCallParseStatus.succeeded:
- yield (
- event,
- LogEvent(
- role=None,
- content=delta.tool_call,
- end="",
- color="cyan",
- ),
- )
- else:
- yield (
- event,
- LogEvent(
- role=None,
- content=delta.text,
- end="",
- color="yellow",
- ),
- )
- else:
- # step_complete
- yield event, LogEvent(role=None, content="")
-
- else:
- # Not streaming
- if event_type == EventType.step_complete.value:
- response = event.payload.step_details.model_response
- if response.tool_calls:
- content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format)
- else:
- content = response.content
- yield (
- event,
- LogEvent(
- role=step_type,
- content=content,
- color="yellow",
- ),
- )
-
- # handle tool_execution
- if (
- step_type == StepType.tool_execution
- and
- # Only print tool calls and responses at the step_complete event
- event_type == EventType.step_complete.value
- ):
- details = event.payload.step_details
- for t in details.tool_calls:
- yield (
- event,
- LogEvent(
- role=step_type,
- content=f"Tool:{t.tool_name} Args:{t.arguments}",
- color="green",
- ),
- )
- for r in details.tool_responses:
- yield (
- event,
- LogEvent(
- role=step_type,
- content=f"Tool:{r.tool_name} Response:{r.content}",
- color="green",
- ),
- )
-
- if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value:
- details = event.payload.step_details
- inserted_context = interleaved_content_as_str(details.inserted_context)
- content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}"
-
- yield (
- event,
- LogEvent(
- role=step_type,
- content=content,
- color="cyan",
- ),
- )
-
- previous_event_type = event_type
- previous_step_type = step_type
@@ -6,7 +6,6 @@

  from typing import List, Optional, Protocol, runtime_checkable

- from llama_models.schema_utils import json_schema_type, webmethod
  from pydantic import BaseModel

  from llama_stack.apis.inference import (
@@ -21,6 +20,7 @@ from llama_stack.apis.inference import (
  ToolDefinition,
  ToolPromptFormat,
  )
+ from llama_stack.schema_utils import json_schema_type, webmethod


  @json_schema_type
@@ -4,4 +4,4 @@
  # This source code is licensed under the terms described in the LICENSE file in
  # the root directory of this source tree.

- from .eval_tasks import * # noqa: F401 F403
+ from .benchmarks import * # noqa: F401 F403
llama_stack/apis/benchmarks/benchmarks.py (new normal file, 86 lines)
@ -0,0 +1,86 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.apis.resource import Resource, ResourceType
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
|
class CommonBenchmarkFields(BaseModel):
|
||||||
|
dataset_id: str
|
||||||
|
scoring_functions: List[str]
|
||||||
|
metadata: Dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Metadata for this evaluation task",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Benchmark(CommonBenchmarkFields, Resource):
|
||||||
|
type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def benchmark_id(self) -> str:
|
||||||
|
return self.identifier
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_benchmark_id(self) -> str:
|
||||||
|
return self.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkInput(CommonBenchmarkFields, BaseModel):
|
||||||
|
benchmark_id: str
|
||||||
|
provider_id: Optional[str] = None
|
||||||
|
provider_benchmark_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ListBenchmarksResponse(BaseModel):
|
||||||
|
data: List[Benchmark]
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class Benchmarks(Protocol):
|
||||||
|
@webmethod(route="/eval/benchmarks", method="GET")
|
||||||
|
async def list_benchmarks(self) -> ListBenchmarksResponse: ...
|
||||||
|
|
||||||
|
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
|
||||||
|
async def get_benchmark(
|
||||||
|
self,
|
||||||
|
benchmark_id: str,
|
||||||
|
) -> Optional[Benchmark]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/eval/benchmarks", method="POST")
|
||||||
|
async def register_benchmark(
|
||||||
|
self,
|
||||||
|
benchmark_id: str,
|
||||||
|
dataset_id: str,
|
||||||
|
scoring_functions: List[str],
|
||||||
|
provider_benchmark_id: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@webmethod(route="/eval-tasks", method="GET")
|
||||||
|
async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse: ...
|
||||||
|
|
||||||
|
@webmethod(route="/eval-tasks/{eval_task_id}", method="GET")
|
||||||
|
async def DEPRECATED_get_eval_task(
|
||||||
|
self,
|
||||||
|
eval_task_id: str,
|
||||||
|
) -> Optional[Benchmark]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/eval-tasks", method="POST")
|
||||||
|
async def DEPRECATED_register_eval_task(
|
||||||
|
self,
|
||||||
|
eval_task_id: str,
|
||||||
|
dataset_id: str,
|
||||||
|
scoring_functions: List[str],
|
||||||
|
provider_benchmark_id: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None: ...
|
|
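For orientation, a minimal sketch of how a client could exercise the new benchmark routes over HTTP. The server address, benchmark id, dataset id, and scoring function name are illustrative assumptions, not part of this change:

# Sketch only: assumes a running Llama Stack server and an already-registered dataset.
import httpx

BASE = "http://localhost:8321"  # assumed server address

# Register a benchmark (maps to Benchmarks.register_benchmark)
httpx.post(
    f"{BASE}/eval/benchmarks",
    json={
        "benchmark_id": "my-mmlu-subset",          # hypothetical id
        "dataset_id": "mmlu-dataset",              # hypothetical dataset
        "scoring_functions": ["my-scoring-fn"],    # hypothetical scoring function
        "metadata": {},
    },
)

# List registered benchmarks (maps to Benchmarks.list_benchmarks)
print(httpx.get(f"{BASE}/eval/benchmarks").json())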
@ -7,10 +7,11 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, List, Literal, Optional, Union
|
from typing import Annotated, List, Literal, Optional, Union
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import ToolCall
|
|
||||||
from llama_models.schema_utils import json_schema_type, register_schema
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import ToolCall
|
||||||
|
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class URL(BaseModel):
|
class URL(BaseModel):
|
||||||
|
|
|
@ -7,10 +7,10 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -5,9 +5,10 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Job(BaseModel):
|
class Job(BaseModel):
|
||||||
|
|
|
@ -7,9 +7,10 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class PostTrainingMetric(BaseModel):
|
class PostTrainingMetric(BaseModel):
|
||||||
|
|
|
@ -6,10 +6,11 @@
|
||||||
|
|
||||||
from typing import Literal, Union
|
from typing import Literal, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, register_schema
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class StringType(BaseModel):
|
class StringType(BaseModel):
|
||||||
|
|
|
@ -6,10 +6,10 @@
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.datasets import Dataset
|
from llama_stack.apis.datasets import Dataset
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -6,12 +6,12 @@
|
||||||
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol
|
from typing import Any, Dict, List, Literal, Optional, Protocol
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.common.type_system import ParamType
|
from llama_stack.apis.common.type_system import ParamType
|
||||||
from llama_stack.apis.resource import Resource, ResourceType
|
from llama_stack.apis.resource import Resource, ResourceType
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
class CommonDatasetFields(BaseModel):
|
class CommonDatasetFields(BaseModel):
|
||||||
|
|
|
@@ -6,7 +6,7 @@

 from enum import Enum

-from llama_models.schema_utils import json_schema_type
+from llama_stack.schema_utils import json_schema_type


 @json_schema_type
@@ -28,7 +28,7 @@ class Api(Enum):
     vector_dbs = "vector_dbs"
     datasets = "datasets"
     scoring_functions = "scoring_functions"
-    eval_tasks = "eval_tasks"
+    benchmarks = "benchmarks"
     tool_groups = "tool_groups"

     # built-in API
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
@ -15,6 +14,7 @@ from llama_stack.apis.common.job_types import Job, JobStatus
|
||||||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||||
from llama_stack.apis.scoring import ScoringResult
|
from llama_stack.apis.scoring import ScoringResult
|
||||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||||
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@@ -38,19 +38,9 @@ EvalCandidate = register_schema(


 @json_schema_type
-class BenchmarkEvalTaskConfig(BaseModel):
+class BenchmarkConfig(BaseModel):
     type: Literal["benchmark"] = "benchmark"
     eval_candidate: EvalCandidate
-    num_examples: Optional[int] = Field(
-        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
-        default=None,
-    )
-
-
-@json_schema_type
-class AppEvalTaskConfig(BaseModel):
-    type: Literal["app"] = "app"
-    eval_candidate: EvalCandidate
     scoring_params: Dict[str, ScoringFnParams] = Field(
         description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
@@ -62,12 +52,6 @@ class AppEvalTaskConfig(BaseModel):
     # we could optinally add any specific dataset config here


-EvalTaskConfig = register_schema(
-    Annotated[Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")],
-    name="EvalTaskConfig",
-)
-
-
 @json_schema_type
 class EvaluateResponse(BaseModel):
     generations: List[Dict[str, Any]]
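A short sketch of constructing the renamed config. ModelCandidate is assumed to be one of the existing EvalCandidate members, and the model id is illustrative; SamplingParams is the type imported from llama_stack.apis.inference in the hunk above:

# Sketch only: assumes ModelCandidate is the model-based EvalCandidate variant.
from llama_stack.apis.eval import BenchmarkConfig, ModelCandidate
from llama_stack.apis.inference import SamplingParams

config = BenchmarkConfig(
    eval_candidate=ModelCandidate(
        model="meta-llama/Llama-3.2-3B-Instruct",  # hypothetical model id
        sampling_params=SamplingParams(),
    ),
    scoring_params={},  # per-scoring-function parameters can be supplied here
)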
@@ -76,27 +60,52 @@ class EvaluateResponse(BaseModel):


 class Eval(Protocol):
-    @webmethod(route="/eval/tasks/{task_id}/jobs", method="POST")
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
     async def run_eval(
+        self,
+        benchmark_id: str,
+        task_config: BenchmarkConfig,
+    ) -> Job: ...
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
+    async def evaluate_rows(
+        self,
+        benchmark_id: str,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        task_config: BenchmarkConfig,
+    ) -> EvaluateResponse: ...
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
+    async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: ...
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
+    async def job_cancel(self, benchmark_id: str, job_id: str) -> None: ...
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
+    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ...
+
+    @webmethod(route="/eval/tasks/{task_id}/jobs", method="POST")
+    async def DEPRECATED_run_eval(
         self,
         task_id: str,
-        task_config: EvalTaskConfig,
+        task_config: BenchmarkConfig,
     ) -> Job: ...

     @webmethod(route="/eval/tasks/{task_id}/evaluations", method="POST")
-    async def evaluate_rows(
+    async def DEPRECATED_evaluate_rows(
         self,
         task_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
-        task_config: EvalTaskConfig,
+        task_config: BenchmarkConfig,
     ) -> EvaluateResponse: ...

     @webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="GET")
-    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
+    async def DEPRECATED_job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...

     @webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="DELETE")
-    async def job_cancel(self, task_id: str, job_id: str) -> None: ...
+    async def DEPRECATED_job_cancel(self, task_id: str, job_id: str) -> None: ...

     @webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}/result", method="GET")
-    async def job_result(self, job_id: str, task_id: str) -> EvaluateResponse: ...
+    async def DEPRECATED_job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...
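To make the route rename concrete, a sketch of driving one run through the benchmark-scoped methods. The Eval provider instance (eval_impl) and the ids are assumptions; the method names and parameters follow the protocol above:

# Sketch only: eval_impl is an object implementing the Eval protocol.
job = await eval_impl.run_eval(benchmark_id="my-mmlu-subset", task_config=config)

status = await eval_impl.job_status(benchmark_id="my-mmlu-subset", job_id=job.job_id)
print(status)

result = await eval_impl.job_result(benchmark_id="my-mmlu-subset", job_id=job.job_id)
print(result.generations)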
@ -1,66 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from llama_stack.apis.resource import Resource, ResourceType
|
|
||||||
|
|
||||||
|
|
||||||
class CommonEvalTaskFields(BaseModel):
|
|
||||||
dataset_id: str
|
|
||||||
scoring_functions: List[str]
|
|
||||||
metadata: Dict[str, Any] = Field(
|
|
||||||
default_factory=dict,
|
|
||||||
description="Metadata for this evaluation task",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class EvalTask(CommonEvalTaskFields, Resource):
|
|
||||||
type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def eval_task_id(self) -> str:
|
|
||||||
return self.identifier
|
|
||||||
|
|
||||||
@property
|
|
||||||
def provider_eval_task_id(self) -> str:
|
|
||||||
return self.provider_resource_id
|
|
||||||
|
|
||||||
|
|
||||||
class EvalTaskInput(CommonEvalTaskFields, BaseModel):
|
|
||||||
eval_task_id: str
|
|
||||||
provider_id: Optional[str] = None
|
|
||||||
provider_eval_task_id: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ListEvalTasksResponse(BaseModel):
|
|
||||||
data: List[EvalTask]
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
|
||||||
class EvalTasks(Protocol):
|
|
||||||
@webmethod(route="/eval-tasks", method="GET")
|
|
||||||
async def list_eval_tasks(self) -> ListEvalTasksResponse: ...
|
|
||||||
|
|
||||||
@webmethod(route="/eval-tasks/{eval_task_id}", method="GET")
|
|
||||||
async def get_eval_task(
|
|
||||||
self,
|
|
||||||
eval_task_id: str,
|
|
||||||
) -> Optional[EvalTask]: ...
|
|
||||||
|
|
||||||
@webmethod(route="/eval-tasks", method="POST")
|
|
||||||
async def register_eval_task(
|
|
||||||
self,
|
|
||||||
eval_task_id: str,
|
|
||||||
dataset_id: str,
|
|
||||||
scoring_functions: List[str],
|
|
||||||
provider_eval_task_id: Optional[str] = None,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> None: ...
|
|
|
@ -17,7 +17,13 @@ from typing import (
|
||||||
runtime_checkable,
|
runtime_checkable,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import (
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
||||||
|
from llama_stack.apis.models import Model
|
||||||
|
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
|
||||||
|
from llama_stack.models.llama.datatypes import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
StopReason,
|
StopReason,
|
||||||
|
@ -25,14 +31,8 @@ from llama_models.llama3.api.datatypes import (
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
|
||||||
from llama_stack.apis.models import Model
|
|
||||||
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
|
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
|
|
||||||
class LogProbConfig(BaseModel):
|
class LogProbConfig(BaseModel):
|
||||||
|
@@ -182,10 +182,12 @@ class ToolChoice(Enum):

     :cvar auto: The model may use tools if it determines that is appropriate.
     :cvar required: The model must use tools.
+    :cvar none: The model must not use tools.
     """

     auto = "auto"
     required = "required"
+    none = "none"


 @json_schema_type
@@ -326,7 +328,7 @@ class SystemMessageBehavior(Enum):
 class ToolConfig(BaseModel):
     """Configuration for tool use.

-    :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+    :param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
     :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
         - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
         - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
@@ -337,9 +339,16 @@ class ToolConfig(BaseModel):
         '{{function_definitions}}' to indicate where the function definitions should be inserted.
     """

-    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
+    tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
-    system_message_behavior: SystemMessageBehavior = Field(default=SystemMessageBehavior.append)
+    system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append)
+
+    def model_post_init(self, __context: Any) -> None:
+        if isinstance(self.tool_choice, str):
+            try:
+                self.tool_choice = ToolChoice[self.tool_choice]
+            except KeyError:
+                pass


 # This is an internally used class
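A small sketch of what the widened tool_choice field accepts after this change; the tool name is illustrative and the import path is the API module shown above:

# Sketch only: demonstrates the three forms tool_choice can now take.
from llama_stack.apis.inference import ToolChoice, ToolConfig

ToolConfig(tool_choice=ToolChoice.none)   # disable tool use entirely
ToolConfig(tool_choice="required")        # coerced to ToolChoice.required in model_post_init
ToolConfig(tool_choice="get_weather")     # unknown enum name: kept as a string naming a specific tool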
@ -6,9 +6,10 @@
|
||||||
|
|
||||||
from typing import List, Protocol, runtime_checkable
|
from typing import List, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ProviderInfo(BaseModel):
|
class ProviderInfo(BaseModel):
|
||||||
|
|
|
@ -7,11 +7,11 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from llama_stack.apis.resource import Resource, ResourceType
|
from llama_stack.apis.resource import Resource, ResourceType
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
class CommonModelFields(BaseModel):
|
class CommonModelFields(BaseModel):
|
||||||
|
|
|
@ -8,13 +8,13 @@ from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.common.job_types import JobStatus
|
from llama_stack.apis.common.job_types import JobStatus
|
||||||
from llama_stack.apis.common.training_types import Checkpoint
|
from llama_stack.apis.common.training_types import Checkpoint
|
||||||
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@@ -15,7 +15,7 @@ class ResourceType(Enum):
     vector_db = "vector_db"
     dataset = "dataset"
     scoring_function = "scoring_function"
-    eval_task = "eval_task"
+    benchmark = "benchmark"
     tool = "tool"
     tool_group = "tool_group"
@ -7,12 +7,12 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
from llama_stack.apis.inference import Message
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -6,10 +6,10 @@
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
# mapping of metric to value
|
# mapping of metric to value
|
||||||
ScoringResultRow = Dict[str, Any]
|
ScoringResultRow = Dict[str, Any]
|
||||||
|
|
|
@ -16,12 +16,12 @@ from typing import (
|
||||||
runtime_checkable,
|
runtime_checkable,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import ParamType
|
from llama_stack.apis.common.type_system import ParamType
|
||||||
from llama_stack.apis.resource import Resource, ResourceType
|
from llama_stack.apis.resource import Resource, ResourceType
|
||||||
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
|
|
||||||
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
|
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
|
||||||
|
|
|
@ -6,11 +6,11 @@
|
||||||
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.resource import Resource, ResourceType
|
from llama_stack.apis.resource import Resource, ResourceType
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
class CommonShieldFields(BaseModel):
|
class CommonShieldFields(BaseModel):
|
||||||
|
|
|
@ -7,10 +7,10 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
from llama_stack.apis.inference import Message
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
class FilteringFunction(Enum):
|
class FilteringFunction(Enum):
|
||||||
|
|
|
@ -17,11 +17,12 @@ from typing import (
|
||||||
runtime_checkable,
|
runtime_checkable,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import Primitive
|
|
||||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import Primitive
|
||||||
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
# Add this constant near the top of the file, after the imports
|
# Add this constant near the top of the file, after the imports
|
||||||
DEFAULT_TTL_DAYS = 7
|
DEFAULT_TTL_DAYS = 7
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,12 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated, Protocol, runtime_checkable
|
from typing_extensions import Annotated, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -7,13 +7,13 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Protocol, runtime_checkable
|
from typing_extensions import Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||||
from llama_stack.apis.resource import Resource, ResourceType
|
from llama_stack.apis.resource import Resource, ResourceType
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
from .rag_tool import RAGToolRuntime
|
from .rag_tool import RAGToolRuntime
|
||||||
|
|
||||||
|
|
|
@ -6,11 +6,11 @@
|
||||||
|
|
||||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.resource import Resource, ResourceType
|
from llama_stack.apis.resource import Resource, ResourceType
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -10,12 +10,12 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.apis.inference import InterleavedContent
|
from llama_stack.apis.inference import InterleavedContent
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
class Chunk(BaseModel):
|
class Chunk(BaseModel):
|
||||||
|
|
|
@ -16,8 +16,6 @@ from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from llama_models.datatypes import Model
|
|
||||||
from llama_models.sku_list import LlamaDownloadInfo
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.progress import (
|
from rich.progress import (
|
||||||
|
@ -31,6 +29,8 @@ from rich.progress import (
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
from llama_stack.models.llama.datatypes import Model
|
||||||
|
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
|
||||||
|
|
||||||
|
|
||||||
class Download(Subcommand):
|
class Download(Subcommand):
|
||||||
|
@ -56,7 +56,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model-id",
|
"--model-id",
|
||||||
required=False,
|
required=False,
|
||||||
help="See `llama model list` or `llama model list --show-all` for the list of available models",
|
help="See `llama model list` or `llama model list --show-all` for the list of available models. Specify multiple model IDs with commas, e.g. --model-id Llama3.2-1B,Llama3.2-3B",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hf-token",
|
"--hf-token",
|
||||||
|
@ -83,8 +83,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
|
||||||
type=str,
|
type=str,
|
||||||
required=False,
|
required=False,
|
||||||
default="*.safetensors",
|
default="*.safetensors",
|
||||||
help="""
|
help="""For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
|
||||||
For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
|
|
||||||
safetensors files to avoid downloading duplicate weights.
|
safetensors files to avoid downloading duplicate weights.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
@ -454,7 +453,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||||
# Handle comma-separated model IDs
|
# Handle comma-separated model IDs
|
||||||
model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
|
model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
|
||||||
|
|
||||||
from llama_models.sku_list import llama_meta_net_info, resolve_model
|
from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model
|
||||||
|
|
||||||
from .model.safety_models import (
|
from .model.safety_models import (
|
||||||
prompt_guard_download_info,
|
prompt_guard_download_info,
|
||||||
|
|
|
@ -7,11 +7,11 @@
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from llama_models.sku_list import resolve_model
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
from llama_stack.cli.table import print_table
|
from llama_stack.cli.table import print_table
|
||||||
|
from llama_stack.models.llama.sku_list import resolve_model
|
||||||
|
|
||||||
|
|
||||||
class ModelDescribe(Subcommand):
|
class ModelDescribe(Subcommand):
|
||||||
|
@ -34,6 +34,7 @@ class ModelDescribe(Subcommand):
|
||||||
"--model-id",
|
"--model-id",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
|
help="See `llama model list` or `llama model list --show-all` for the list of available models",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
|
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
|
|
@ -6,10 +6,9 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from llama_models.sku_list import all_registered_models
|
|
||||||
|
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
from llama_stack.cli.table import print_table
|
from llama_stack.cli.table import print_table
|
||||||
|
from llama_stack.models.llama.sku_list import all_registered_models
|
||||||
|
|
||||||
|
|
||||||
class ModelList(Subcommand):
|
class ModelList(Subcommand):
|
||||||
|
@ -37,8 +36,8 @@ class ModelList(Subcommand):
|
||||||
from .safety_models import prompt_guard_model_sku
|
from .safety_models import prompt_guard_model_sku
|
||||||
|
|
||||||
headers = [
|
headers = [
|
||||||
"Model Descriptor",
|
"Model Descriptor(ID)",
|
||||||
"Model ID",
|
"Hugging Face Repo",
|
||||||
"Context Length",
|
"Context Length",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -8,9 +8,8 @@ import argparse
|
||||||
import textwrap
|
import textwrap
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
from llama_models.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
|
|
||||||
|
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
|
||||||
|
|
||||||
|
|
||||||
class ModelPromptFormat(Subcommand):
|
class ModelPromptFormat(Subcommand):
|
||||||
|
|
|
@ -6,11 +6,11 @@
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from llama_models.datatypes import CheckpointQuantizationFormat
|
|
||||||
from llama_models.llama3.api.datatypes import SamplingParams
|
|
||||||
from llama_models.sku_list import LlamaDownloadInfo
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
|
||||||
|
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
|
||||||
|
|
||||||
|
|
||||||
class PromptGuardModel(BaseModel):
|
class PromptGuardModel(BaseModel):
|
||||||
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
|
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
|
||||||
|
|
|
@ -15,7 +15,7 @@ class ModelVerifyDownload(Subcommand):
|
||||||
self.parser = subparsers.add_parser(
|
self.parser = subparsers.add_parser(
|
||||||
"verify-download",
|
"verify-download",
|
||||||
prog="llama model verify-download",
|
prog="llama model verify-download",
|
||||||
description="Verify the downloaded checkpoints' checksums",
|
description="Verify the downloaded checkpoints' checksums for models downloaded from Meta",
|
||||||
formatter_class=argparse.RawTextHelpFormatter,
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -38,9 +38,8 @@ class StackBuild(Subcommand):
|
||||||
|
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--list-templates",
|
"--list-templates",
|
||||||
type=bool,
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
action=argparse.BooleanOptionalAction,
|
|
||||||
help="Show the available templates for building a Llama Stack distribution",
|
help="Show the available templates for building a Llama Stack distribution",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -56,9 +55,8 @@ class StackBuild(Subcommand):
|
||||||
"--image-name",
|
"--image-name",
|
||||||
type=str,
|
type=str,
|
||||||
help=textwrap.dedent(
|
help=textwrap.dedent(
|
||||||
"""[for image-type=conda] Name of the conda environment to use for the build. If
|
"""[for image-type=conda|venv] Name of the conda or virtual environment to use for
|
||||||
not specified, currently active Conda environment will be used. If no Conda
|
the build. If not specified, currently active Conda environment will be used if found.
|
||||||
environment is active, you must specify a name.
|
|
||||||
"""
|
"""
|
||||||
),
|
),
|
||||||
default=None,
|
default=None,
|
||||||
|
|
|
@ -17,7 +17,7 @@ class StackConfigure(Subcommand):
|
||||||
self.parser = subparsers.add_parser(
|
self.parser = subparsers.add_parser(
|
||||||
"configure",
|
"configure",
|
||||||
prog="llama stack configure",
|
prog="llama stack configure",
|
||||||
description="configure a llama stack distribution",
|
description="Configure a llama stack distribution",
|
||||||
formatter_class=argparse.RawTextHelpFormatter,
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
)
|
)
|
||||||
self._add_arguments()
|
self._add_arguments()
|
||||||
|
|
|
@ -19,7 +19,7 @@ class StackRun(Subcommand):
|
||||||
self.parser = subparsers.add_parser(
|
self.parser = subparsers.add_parser(
|
||||||
"run",
|
"run",
|
||||||
prog="llama stack run",
|
prog="llama stack run",
|
||||||
description="""start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
|
description="""Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
|
||||||
formatter_class=argparse.RawTextHelpFormatter,
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
)
|
)
|
||||||
self._add_arguments()
|
self._add_arguments()
|
||||||
|
|
|
@ -4,75 +4,36 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import re
|
|
||||||
import textwrap
|
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
from termcolor import cprint
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
def strip_ansi_colors(text):
|
|
||||||
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
|
|
||||||
return ansi_escape.sub("", text)
|
|
||||||
|
|
||||||
|
|
||||||
def format_row(row, col_widths):
|
|
||||||
def wrap(text, width):
|
|
||||||
lines = []
|
|
||||||
for line in text.split("\n"):
|
|
||||||
if line.strip() == "":
|
|
||||||
lines.append("")
|
|
||||||
else:
|
|
||||||
lines.extend(textwrap.wrap(line, width, break_long_words=False, replace_whitespace=False))
|
|
||||||
return lines
|
|
||||||
|
|
||||||
wrapped = [wrap(item, width) for item, width in zip(row, col_widths)]
|
|
||||||
max_lines = max(len(subrow) for subrow in wrapped)
|
|
||||||
|
|
||||||
lines = []
|
|
||||||
for i in range(max_lines):
|
|
||||||
line = []
|
|
||||||
for cell_lines, width in zip(wrapped, col_widths):
|
|
||||||
value = cell_lines[i] if i < len(cell_lines) else ""
|
|
||||||
line.append(value + " " * (width - len(strip_ansi_colors(value))))
|
|
||||||
lines.append("| " + (" | ".join(line)) + " |")
|
|
||||||
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()):
|
def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()):
|
||||||
def itemlen(item):
|
# Convert rows and handle None values
|
||||||
return max([len(line) for line in strip_ansi_colors(item).split("\n")])
|
|
||||||
|
|
||||||
rows = [[x or "" for x in row] for row in rows]
|
rows = [[x or "" for x in row] for row in rows]
|
||||||
|
|
||||||
|
# Sort rows if sort_by is specified
|
||||||
if sort_by:
|
if sort_by:
|
||||||
rows.sort(key=lambda x: tuple(x[i] for i in sort_by))
|
rows.sort(key=lambda x: tuple(x[i] for i in sort_by))
|
||||||
|
|
||||||
if not headers:
|
# Create Rich table
|
||||||
col_widths = [max(itemlen(item) for item in col) for col in zip(*rows)]
|
table = Table(show_lines=separate_rows)
|
||||||
else:
|
|
||||||
col_widths = [
|
|
||||||
max(
|
|
||||||
itemlen(header),
|
|
||||||
max(itemlen(item) for item in col),
|
|
||||||
)
|
|
||||||
for header, col in zip(headers, zip(*rows))
|
|
||||||
]
|
|
||||||
col_widths = [min(w, 80) for w in col_widths]
|
|
||||||
|
|
||||||
header_line = "+".join("-" * (width + 2) for width in col_widths)
|
|
||||||
header_line = f"+{header_line}+"
|
|
||||||
|
|
||||||
|
# Add headers if provided
|
||||||
if headers:
|
if headers:
|
||||||
print(header_line)
|
for header in headers:
|
||||||
cprint(format_row(headers, col_widths), "white", attrs=["bold"])
|
table.add_column(header, style="bold white")
|
||||||
|
else:
|
||||||
|
# Add unnamed columns based on first row
|
||||||
|
for _ in range(len(rows[0]) if rows else 0):
|
||||||
|
table.add_column()
|
||||||
|
|
||||||
print(header_line)
|
# Add rows
|
||||||
for row in rows:
|
for row in rows:
|
||||||
print(format_row(row, col_widths))
|
table.add_row(*row)
|
||||||
if separate_rows:
|
|
||||||
print(header_line)
|
|
||||||
|
|
||||||
if not separate_rows:
|
# Print table
|
||||||
print(header_line)
|
console = Console()
|
||||||
|
console.print(table)
|
||||||
|
|
|
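For reference, a sketch of how the Rich-backed print_table is typically invoked; the row data is made up, and the module path is assumed from the CLI table helper being modified here:

# Sketch only: exercises the print_table signature shown in the diff above.
from llama_stack.cli.table import print_table

headers = ["Model Descriptor(ID)", "Hugging Face Repo", "Context Length"]
rows = [
    ["Llama3.2-1B", "meta-llama/Llama-3.2-1B", "128K"],
    ["Llama3.2-3B", "meta-llama/Llama-3.2-3B", "128K"],
]
print_table(rows, headers, separate_rows=True)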
@ -44,7 +44,7 @@ def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model-id",
|
"--model-id",
|
||||||
required=True,
|
required=True,
|
||||||
help="Model ID to verify",
|
help="Model ID to verify (only for models downloaded from Meta)",
|
||||||
)
|
)
|
||||||
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
|
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
|
||||||
|
|
||||||
|
|
|
@ -126,7 +126,6 @@ def build_image(
|
||||||
args = [
|
args = [
|
||||||
script,
|
script,
|
||||||
str(image_name),
|
str(image_name),
|
||||||
str(build_file_path),
|
|
||||||
" ".join(normal_deps),
|
" ".join(normal_deps),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -24,23 +24,21 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$#" -lt 3 ]; then
|
if [ "$#" -lt 3 ]; then
|
||||||
echo "Usage: $0 <distribution_type> <build_name> <build_file_path> <pip_dependencies> [<special_pip_deps>]" >&2
|
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||||
echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
|
echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
special_pip_deps="$4"
|
special_pip_deps="$3"
|
||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
build_name="$1"
|
build_name="$1"
|
||||||
env_name="llamastack-$build_name"
|
env_name="llamastack-$build_name"
|
||||||
build_file_path="$2"
|
pip_dependencies="$2"
|
||||||
pip_dependencies="$3"
|
|
||||||
|
|
||||||
# Define color codes
|
# Define color codes
|
||||||
RED='\033[0;31m'
|
RED='\033[0;31m'
|
||||||
GREEN='\033[0;32m'
|
|
||||||
NC='\033[0m' # No Color
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
# this is set if we actually create a new conda in which case we need to clean up
|
# this is set if we actually create a new conda in which case we need to clean up
|
||||||
|
@ -49,34 +47,63 @@ ENVNAME=""
|
||||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||||
source "$SCRIPT_DIR/common.sh"
|
source "$SCRIPT_DIR/common.sh"
|
||||||
|
|
||||||
|
# pre-run checks to make sure we can proceed with the installation
|
||||||
|
pre_run_checks() {
|
||||||
|
local env_name="$1"
|
||||||
|
|
||||||
|
if ! is_command_available uv; then
|
||||||
|
echo "uv is not installed, trying to install it."
|
||||||
|
if ! is_command_available pip; then
|
||||||
|
echo "pip is not installed, cannot automatically install 'uv'."
|
||||||
|
echo "Follow this link to install it:"
|
||||||
|
echo "https://docs.astral.sh/uv/getting-started/installation/"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
pip install uv
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# checking if an environment with the same name already exists
|
||||||
|
if [ -d "$env_name" ]; then
|
||||||
|
echo "Environment '$env_name' already exists, re-using it."
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
run() {
|
run() {
|
||||||
local env_name="$1"
|
local env_name="$1"
|
||||||
local pip_dependencies="$2"
|
local pip_dependencies="$2"
|
||||||
local special_pip_deps="$3"
|
local special_pip_deps="$3"
|
||||||
|
|
||||||
pip install uv
|
echo "Using virtual environment $env_name"
|
||||||
|
uv venv "$env_name"
|
||||||
|
# shellcheck source=/dev/null
|
||||||
|
source "$env_name/bin/activate"
|
||||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||||
# these packages are damaged in test-pypi, so install them first
|
# these packages are damaged in test-pypi, so install them first
|
||||||
uv pip install fastapi libcst
|
uv pip install fastapi libcst
|
||||||
|
# shellcheck disable=SC2086
|
||||||
|
# we are building a command line so word splitting is expected
|
||||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||||
llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION \
|
llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
|
||||||
$pip_dependencies
|
$pip_dependencies
|
||||||
if [ -n "$special_pip_deps" ]; then
|
if [ -n "$special_pip_deps" ]; then
|
||||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||||
for part in "${parts[@]}"; do
|
for part in "${parts[@]}"; do
|
||||||
echo "$part"
|
echo "$part"
|
||||||
|
# shellcheck disable=SC2086
|
||||||
|
# we are building a command line so word splitting is expected
|
||||||
uv pip install $part
|
uv pip install $part
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
else
|
else
|
||||||
# Re-installing llama-stack in the new conda environment
|
# Re-installing llama-stack in the new virtual environment
|
||||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
||||||
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
|
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
|
printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR"
|
||||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
|
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
|
||||||
else
|
else
|
||||||
uv pip install --no-cache-dir llama-stack
|
uv pip install --no-cache-dir llama-stack
|
||||||
|
@ -84,26 +111,31 @@ run() {
|
||||||
|
|
||||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||||
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
|
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
|
||||||
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
|
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_MODELS_DIR" >&2
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
|
printf "Installing from LLAMA_MODELS_DIR: %s\n" "$LLAMA_MODELS_DIR"
|
||||||
uv pip uninstall llama-models
|
uv pip uninstall llama-models
|
||||||
uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
|
uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Install pip dependencies
|
# Install pip dependencies
|
||||||
printf "Installing pip dependencies\n"
|
printf "Installing pip dependencies\n"
|
||||||
|
# shellcheck disable=SC2086
|
||||||
|
# we are building a command line so word splitting is expected
|
||||||
uv pip install $pip_dependencies
|
uv pip install $pip_dependencies
|
||||||
if [ -n "$special_pip_deps" ]; then
|
if [ -n "$special_pip_deps" ]; then
|
||||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||||
for part in "${parts[@]}"; do
|
for part in "${parts[@]}"; do
|
||||||
echo "$part"
|
echo "$part"
|
||||||
|
# shellcheck disable=SC2086
|
||||||
|
# we are building a command line so word splitting is expected
|
||||||
uv pip install $part
|
uv pip install $part
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pre_run_checks "$env_name"
|
||||||
run "$env_name" "$pip_dependencies" "$special_pip_deps"
|
run "$env_name" "$pip_dependencies" "$special_pip_deps"
|
||||||
|
|
|
@ -186,33 +186,3 @@ def extract_async_iterator_type(type_hint):
|
||||||
inner_args = get_args(arg)
|
inner_args = get_args(arg)
|
||||||
return inner_args[0]
|
return inner_args[0]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def example(model: str = None):
|
|
||||||
from llama_stack.apis.inference import Inference, UserMessage # noqa: F403
|
|
||||||
from llama_stack.apis.inference.event_logger import EventLogger
|
|
||||||
|
|
||||||
client_class = create_api_client_class(Inference)
|
|
||||||
client = client_class("http://localhost:5003")
|
|
||||||
|
|
||||||
if not model:
|
|
||||||
model = "Llama3.2-3B-Instruct"
|
|
||||||
|
|
||||||
message = UserMessage(content="hello world, write me a 2 sentence poem about the moon")
|
|
||||||
cprint(f"User>{message.content}", "green")
|
|
||||||
|
|
||||||
stream = True
|
|
||||||
iterator = await client.chat_completion(
|
|
||||||
model=model,
|
|
||||||
messages=[message],
|
|
||||||
stream=stream,
|
|
||||||
)
|
|
||||||
|
|
||||||
async for log in EventLogger().log(iterator):
|
|
||||||
log.print()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.run(example())
|
|
||||||
|
|
|
@@ -38,3 +38,8 @@ setup_cleanup_handlers() {

     conda deactivate
 }
+
+# check if a command is present
+is_command_available() {
+    command -v "$1" &>/dev/null
+}
@ -8,10 +8,10 @@ from typing import Annotated, Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Dataset, DatasetInput
|
from llama_stack.apis.datasets import Dataset, DatasetInput
|
||||||
from llama_stack.apis.eval import Eval
|
from llama_stack.apis.eval import Eval
|
||||||
from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput
|
|
||||||
from llama_stack.apis.inference import Inference
|
from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.apis.models import Model, ModelInput
|
from llama_stack.apis.models import Model, ModelInput
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
|
@ -37,7 +37,7 @@ RoutableObject = Union[
|
||||||
VectorDB,
|
VectorDB,
|
||||||
Dataset,
|
Dataset,
|
||||||
ScoringFn,
|
ScoringFn,
|
||||||
EvalTask,
|
Benchmark,
|
||||||
Tool,
|
Tool,
|
||||||
ToolGroup,
|
ToolGroup,
|
||||||
]
|
]
|
||||||
|
@ -50,7 +50,7 @@ RoutableObjectWithProvider = Annotated[
|
||||||
VectorDB,
|
VectorDB,
|
||||||
Dataset,
|
Dataset,
|
||||||
ScoringFn,
|
ScoringFn,
|
||||||
EvalTask,
|
Benchmark,
|
||||||
Tool,
|
Tool,
|
||||||
ToolGroup,
|
ToolGroup,
|
||||||
],
|
],
|
||||||
|
@ -173,7 +173,7 @@ a default SQLite store will be used.""",
|
||||||
vector_dbs: List[VectorDBInput] = Field(default_factory=list)
|
vector_dbs: List[VectorDBInput] = Field(default_factory=list)
|
||||||
datasets: List[DatasetInput] = Field(default_factory=list)
|
datasets: List[DatasetInput] = Field(default_factory=list)
|
||||||
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
|
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
|
||||||
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
|
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
|
||||||
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
||||||
|
|
||||||
server: ServerConfig = Field(
|
server: ServerConfig = Field(
|
||||||
|
|
|
@ -44,7 +44,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
||||||
router_api=Api.scoring,
|
router_api=Api.scoring,
|
||||||
),
|
),
|
||||||
AutoRoutedApiInfo(
|
AutoRoutedApiInfo(
|
||||||
routing_table_api=Api.eval_tasks,
|
routing_table_api=Api.benchmarks,
|
||||||
router_api=Api.eval,
|
router_api=Api.eval,
|
||||||
),
|
),
|
||||||
AutoRoutedApiInfo(
|
AutoRoutedApiInfo(
|
||||||
|
|
|
@ -13,7 +13,7 @@ import re
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional, TypeVar, get_args, get_origin
|
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -47,6 +47,8 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
start_trace,
|
start_trace,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@ -81,12 +83,13 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
origin = get_origin(annotation)
|
origin = get_origin(annotation)
|
||||||
|
|
||||||
if origin is list:
|
if origin is list:
|
||||||
item_type = get_args(annotation)[0]
|
item_type = get_args(annotation)[0]
|
||||||
try:
|
try:
|
||||||
return [convert_to_pydantic(item_type, item) for item in value]
|
return [convert_to_pydantic(item_type, item) for item in value]
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error converting list {value}")
|
logger.error(f"Error converting list {value} into {item_type}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
elif origin is dict:
|
elif origin is dict:
|
||||||
|
@ -94,17 +97,25 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
||||||
try:
|
try:
|
||||||
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
|
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error converting dict {value}")
|
logger.error(f"Error converting dict {value} into {val_type}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Handle Pydantic models and discriminated unions
|
# Handle Pydantic models and discriminated unions
|
||||||
return TypeAdapter(annotation).validate_python(value)
|
return TypeAdapter(annotation).validate_python(value)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
cprint(
|
# TODO: this is a workaround for having Union[str, AgentToolGroup] in API schema.
|
||||||
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
# We should get rid of any non-discriminated unions in the API schema.
|
||||||
"yellow",
|
if origin is Union:
|
||||||
)
|
for union_type in get_args(annotation):
|
||||||
|
try:
|
||||||
|
return convert_to_pydantic(union_type, value)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
logger.warning(
|
||||||
|
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
||||||
|
)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
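The union fallback above is easy to miss inside the diff, so here is a minimal, standalone sketch of the same pattern. `ToolGroupRef` and `convert_param` are illustrative names rather than the library's internals, and the snippet assumes pydantic v2 (the same major version the diff's `TypeAdapter` import implies).

```python
from typing import List, Union, get_args, get_origin

from pydantic import BaseModel, TypeAdapter


class ToolGroupRef(BaseModel):
    # Illustrative stand-in for a structured tool-group argument.
    name: str
    args: dict = {}


def convert_param(annotation, value):
    """Best-effort conversion of a JSON-ish value into the annotated type."""
    if get_origin(annotation) is list:
        item_type = get_args(annotation)[0]
        return [convert_param(item_type, item) for item in value]
    try:
        return TypeAdapter(annotation).validate_python(value)
    except Exception:
        # Fallback for unions the adapter rejects as a whole (e.g. the
        # non-discriminated Union[str, AgentToolGroup] mentioned in the TODO):
        # retry each member on its own before passing the raw value through.
        if get_origin(annotation) is Union:
            for member in get_args(annotation):
                try:
                    return convert_param(member, value)
                except Exception:
                    continue
        return value


converted = convert_param(
    List[Union[str, ToolGroupRef]],
    ["builtin::web_search", {"name": "builtin::rag", "args": {"top_k": 3}}],
)
print(converted)
# ['builtin::web_search', ToolGroupRef(name='builtin::rag', args={'top_k': 3})]
```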
|
@ -142,7 +153,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||||
|
|
||||||
for handler in root_logger.handlers[:]:
|
for handler in root_logger.handlers[:]:
|
||||||
root_logger.removeHandler(handler)
|
root_logger.removeHandler(handler)
|
||||||
print(f"Removed handler {handler.__class__.__name__} from root logger")
|
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||||
|
|
||||||
def request(self, *args, **kwargs):
|
def request(self, *args, **kwargs):
|
||||||
if kwargs.get("stream"):
|
if kwargs.get("stream"):
|
||||||
|
@ -231,7 +242,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
|
|
||||||
def _convert_path_to_regex(path: str) -> str:
|
def _convert_path_to_regex(path: str) -> str:
|
||||||
# Convert {param} to named capture groups
|
# Convert {param} to named capture groups
|
||||||
pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path)
|
# handle {param:path} as well which allows for forward slashes in the param value
|
||||||
|
pattern = re.sub(
|
||||||
|
r"{(\w+)(?::path)?}",
|
||||||
|
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
|
||||||
|
path,
|
||||||
|
)
|
||||||
|
|
||||||
return f"^{pattern}$"
|
return f"^{pattern}$"
|
||||||
|
|
||||||
for api, api_endpoints in endpoints.items():
|
for api, api_endpoints in endpoints.items():
|
||||||
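For context, this is what the widened pattern buys: `{param}` still stops at the next `/`, while `{param:path}` can swallow slashes. A small self-contained sketch of that behavior follows; the route template in it is made up for illustration and is not an actual Llama Stack endpoint.

```python
import re


def _repl(m: re.Match) -> str:
    # {param:path} may span '/' characters; plain {param} matches one segment.
    wildcard = ".+" if m.group(0).endswith(":path}") else "[^/]+"
    return f"(?P<{m.group(1)}>{wildcard})"


def convert_path_to_regex(path: str) -> str:
    # Convert a route template into an anchored regex with named capture groups.
    return "^" + re.sub(r"{(\w+)(?::path)?}", _repl, path) + "$"


route = convert_path_to_regex("/v1/files/{bucket}/{key:path}")
match = re.match(route, "/v1/files/models/llama-3/weights.bin")
print(match.groupdict())
# {'bucket': 'models', 'key': 'llama-3/weights.bin'}
```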
|
@ -415,4 +432,5 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
if param_name in body:
|
if param_name in body:
|
||||||
value = body.get(param_name)
|
value = body.get(param_name)
|
||||||
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
|
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
|
||||||
|
|
||||||
return converted_body
|
return converted_body
|
||||||
|
|
|
@ -9,10 +9,10 @@ import logging
|
||||||
from typing import Any, Dict, List, Set
|
from typing import Any, Dict, List, Set
|
||||||
|
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.apis.agents import Agents
|
||||||
|
from llama_stack.apis.benchmarks import Benchmarks
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
from llama_stack.apis.eval import Eval
|
from llama_stack.apis.eval import Eval
|
||||||
from llama_stack.apis.eval_tasks import EvalTasks
|
|
||||||
from llama_stack.apis.inference import Inference
|
from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.apis.inspect import Inspect
|
from llama_stack.apis.inspect import Inspect
|
||||||
from llama_stack.apis.models import Models
|
from llama_stack.apis.models import Models
|
||||||
|
@ -37,8 +37,8 @@ from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
from llama_stack.providers.datatypes import (
|
from llama_stack.providers.datatypes import (
|
||||||
Api,
|
Api,
|
||||||
|
BenchmarksProtocolPrivate,
|
||||||
DatasetsProtocolPrivate,
|
DatasetsProtocolPrivate,
|
||||||
EvalTasksProtocolPrivate,
|
|
||||||
InlineProviderSpec,
|
InlineProviderSpec,
|
||||||
ModelsProtocolPrivate,
|
ModelsProtocolPrivate,
|
||||||
ProviderSpec,
|
ProviderSpec,
|
||||||
|
@ -73,7 +73,7 @@ def api_protocol_map() -> Dict[Api, Any]:
|
||||||
Api.scoring: Scoring,
|
Api.scoring: Scoring,
|
||||||
Api.scoring_functions: ScoringFunctions,
|
Api.scoring_functions: ScoringFunctions,
|
||||||
Api.eval: Eval,
|
Api.eval: Eval,
|
||||||
Api.eval_tasks: EvalTasks,
|
Api.benchmarks: Benchmarks,
|
||||||
Api.post_training: PostTraining,
|
Api.post_training: PostTraining,
|
||||||
Api.tool_groups: ToolGroups,
|
Api.tool_groups: ToolGroups,
|
||||||
Api.tool_runtime: ToolRuntime,
|
Api.tool_runtime: ToolRuntime,
|
||||||
|
@ -92,7 +92,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
|
||||||
ScoringFunctions,
|
ScoringFunctions,
|
||||||
Api.scoring_functions,
|
Api.scoring_functions,
|
||||||
),
|
),
|
||||||
Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
|
Api.eval: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,8 +11,8 @@ from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||||
|
|
||||||
from .routing_tables import (
|
from .routing_tables import (
|
||||||
|
BenchmarksRoutingTable,
|
||||||
DatasetsRoutingTable,
|
DatasetsRoutingTable,
|
||||||
EvalTasksRoutingTable,
|
|
||||||
ModelsRoutingTable,
|
ModelsRoutingTable,
|
||||||
ScoringFunctionsRoutingTable,
|
ScoringFunctionsRoutingTable,
|
||||||
ShieldsRoutingTable,
|
ShieldsRoutingTable,
|
||||||
|
@ -33,7 +33,7 @@ async def get_routing_table_impl(
|
||||||
"shields": ShieldsRoutingTable,
|
"shields": ShieldsRoutingTable,
|
||||||
"datasets": DatasetsRoutingTable,
|
"datasets": DatasetsRoutingTable,
|
||||||
"scoring_functions": ScoringFunctionsRoutingTable,
|
"scoring_functions": ScoringFunctionsRoutingTable,
|
||||||
"eval_tasks": EvalTasksRoutingTable,
|
"benchmarks": BenchmarksRoutingTable,
|
||||||
"tool_groups": ToolGroupsRoutingTable,
|
"tool_groups": ToolGroupsRoutingTable,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,9 +9,8 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||||
from llama_stack.apis.eval import (
|
from llama_stack.apis.eval import (
|
||||||
AppEvalTaskConfig,
|
BenchmarkConfig,
|
||||||
Eval,
|
Eval,
|
||||||
EvalTaskConfig,
|
|
||||||
EvaluateResponse,
|
EvaluateResponse,
|
||||||
Job,
|
Job,
|
||||||
JobStatus,
|
JobStatus,
|
||||||
|
@ -129,7 +128,7 @@ class InferenceRouter(Inference):
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
tool_choice: Optional[ToolChoice] = None,
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
@ -141,20 +140,36 @@ class InferenceRouter(Inference):
|
||||||
if model.model_type == ModelType.embedding:
|
if model.model_type == ModelType.embedding:
|
||||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||||
if tool_config:
|
if tool_config:
|
||||||
if tool_choice != tool_config.tool_choice:
|
if tool_choice and tool_choice != tool_config.tool_choice:
|
||||||
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
||||||
if tool_prompt_format != tool_config.tool_prompt_format:
|
if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
|
||||||
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
|
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
|
||||||
else:
|
else:
|
||||||
tool_config = ToolConfig(
|
params = {}
|
||||||
tool_choice=tool_choice,
|
if tool_choice:
|
||||||
tool_prompt_format=tool_prompt_format,
|
params["tool_choice"] = tool_choice
|
||||||
)
|
if tool_prompt_format:
|
||||||
|
params["tool_prompt_format"] = tool_prompt_format
|
||||||
|
tool_config = ToolConfig(**params)
|
||||||
|
|
||||||
|
tools = tools or []
|
||||||
|
if tool_config.tool_choice == ToolChoice.none:
|
||||||
|
tools = []
|
||||||
|
elif tool_config.tool_choice == ToolChoice.auto:
|
||||||
|
pass
|
||||||
|
elif tool_config.tool_choice == ToolChoice.required:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# verify tool_choice is one of the tools
|
||||||
|
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
|
||||||
|
if tool_config.tool_choice not in tool_names:
|
||||||
|
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
|
||||||
|
|
||||||
params = dict(
|
params = dict(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
tools=tools or [],
|
tools=tools,
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
tool_prompt_format=tool_prompt_format,
|
tool_prompt_format=tool_prompt_format,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
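A condensed, dependency-free sketch of the reconciliation rules in this hunk. The real router works with llama_stack's `ToolConfig` and `ToolDefinition` types and extracts tool names from them; here tools are plain strings and the class names are stand-ins.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Tuple, Union


class ToolChoice(Enum):
    auto = "auto"
    required = "required"
    none = "none"


@dataclass
class ToolConfigSketch:
    # tool_choice may be a ToolChoice value or the name of a specific tool.
    tool_choice: Union[ToolChoice, str] = ToolChoice.auto
    tool_prompt_format: Optional[str] = None


def reconcile_tools(
    tools: Optional[List[str]],
    tool_choice: Optional[Union[ToolChoice, str]] = None,
    tool_prompt_format: Optional[str] = None,
    tool_config: Optional[ToolConfigSketch] = None,
) -> Tuple[List[str], ToolConfigSketch]:
    if tool_config:
        # Legacy per-call arguments must agree with the explicit config.
        if tool_choice and tool_choice != tool_config.tool_choice:
            raise ValueError("tool_choice and tool_config.tool_choice must match")
        if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
            raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
    else:
        # Only set the fields the caller actually provided; keep defaults otherwise.
        params = {}
        if tool_choice:
            params["tool_choice"] = tool_choice
        if tool_prompt_format:
            params["tool_prompt_format"] = tool_prompt_format
        tool_config = ToolConfigSketch(**params)

    tools = tools or []
    if tool_config.tool_choice == ToolChoice.none:
        tools = []  # explicit "no tools": drop whatever was passed
    elif tool_config.tool_choice in (ToolChoice.auto, ToolChoice.required):
        pass
    elif tool_config.tool_choice not in tools:
        # A specific tool was named; it must be among the provided tools.
        raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tools}")
    return tools, tool_config


print(reconcile_tools(["get_weather"], tool_choice=ToolChoice.none))
print(reconcile_tools(["get_weather"], tool_config=ToolConfigSketch(tool_choice="get_weather")))
```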
|
@ -347,23 +362,23 @@ class EvalRouter(Eval):
|
||||||
|
|
||||||
async def run_eval(
|
async def run_eval(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
benchmark_id: str,
|
||||||
task_config: AppEvalTaskConfig,
|
task_config: BenchmarkConfig,
|
||||||
) -> Job:
|
) -> Job:
|
||||||
return await self.routing_table.get_provider_impl(task_id).run_eval(
|
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
||||||
task_id=task_id,
|
benchmark_id=benchmark_id,
|
||||||
task_config=task_config,
|
task_config=task_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def evaluate_rows(
|
async def evaluate_rows(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
benchmark_id: str,
|
||||||
input_rows: List[Dict[str, Any]],
|
input_rows: List[Dict[str, Any]],
|
||||||
scoring_functions: List[str],
|
scoring_functions: List[str],
|
||||||
task_config: EvalTaskConfig,
|
task_config: BenchmarkConfig,
|
||||||
) -> EvaluateResponse:
|
) -> EvaluateResponse:
|
||||||
return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
|
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
||||||
task_id=task_id,
|
benchmark_id=benchmark_id,
|
||||||
input_rows=input_rows,
|
input_rows=input_rows,
|
||||||
scoring_functions=scoring_functions,
|
scoring_functions=scoring_functions,
|
||||||
task_config=task_config,
|
task_config=task_config,
|
||||||
|
@ -371,30 +386,72 @@ class EvalRouter(Eval):
|
||||||
|
|
||||||
async def job_status(
|
async def job_status(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
benchmark_id: str,
|
||||||
job_id: str,
|
job_id: str,
|
||||||
) -> Optional[JobStatus]:
|
) -> Optional[JobStatus]:
|
||||||
return await self.routing_table.get_provider_impl(task_id).job_status(task_id, job_id)
|
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||||
|
|
||||||
async def job_cancel(
|
async def job_cancel(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
benchmark_id: str,
|
||||||
job_id: str,
|
job_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.routing_table.get_provider_impl(task_id).job_cancel(
|
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
||||||
task_id,
|
benchmark_id,
|
||||||
job_id,
|
job_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def job_result(
|
async def job_result(
|
||||||
|
self,
|
||||||
|
benchmark_id: str,
|
||||||
|
job_id: str,
|
||||||
|
) -> EvaluateResponse:
|
||||||
|
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
||||||
|
benchmark_id,
|
||||||
|
job_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def DEPRECATED_run_eval(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
task_config: BenchmarkConfig,
|
||||||
|
) -> Job:
|
||||||
|
return await self.run_eval(benchmark_id=task_id, task_config=task_config)
|
||||||
|
|
||||||
|
async def DEPRECATED_evaluate_rows(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
input_rows: List[Dict[str, Any]],
|
||||||
|
scoring_functions: List[str],
|
||||||
|
task_config: BenchmarkConfig,
|
||||||
|
) -> EvaluateResponse:
|
||||||
|
return await self.evaluate_rows(
|
||||||
|
benchmark_id=task_id,
|
||||||
|
input_rows=input_rows,
|
||||||
|
scoring_functions=scoring_functions,
|
||||||
|
task_config=task_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def DEPRECATED_job_status(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
job_id: str,
|
||||||
|
) -> Optional[JobStatus]:
|
||||||
|
return await self.job_status(benchmark_id=task_id, job_id=job_id)
|
||||||
|
|
||||||
|
async def DEPRECATED_job_cancel(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
job_id: str,
|
||||||
|
) -> None:
|
||||||
|
return await self.job_cancel(benchmark_id=task_id, job_id=job_id)
|
||||||
|
|
||||||
|
async def DEPRECATED_job_result(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
job_id: str,
|
job_id: str,
|
||||||
) -> EvaluateResponse:
|
) -> EvaluateResponse:
|
||||||
return await self.routing_table.get_provider_impl(task_id).job_result(
|
return await self.job_result(benchmark_id=task_id, job_id=job_id)
|
||||||
task_id,
|
|
||||||
job_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolRuntimeRouter(ToolRuntime):
|
class ToolRuntimeRouter(ToolRuntime):
|
||||||
|
|
|
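The DEPRECATED_* methods are pure forwarding shims: the old `task_id` surface keeps working while new callers use `benchmark_id`. A tiny sketch of the pattern, with an illustrative class name and benchmark id:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


class EvalRouterSketch:
    async def run_eval(self, benchmark_id: str, task_config: dict) -> str:
        # Stand-in for looking up the provider registered for this benchmark.
        return f"job-for-{benchmark_id}"

    async def DEPRECATED_run_eval(self, task_id: str, task_config: dict) -> str:
        # Old clients still pass task_id; simply forward it to the renamed parameter.
        logger.warning("DEPRECATED: use run_eval(benchmark_id=...) instead")
        return await self.run_eval(benchmark_id=task_id, task_config=task_config)


async def main() -> None:
    router = EvalRouterSketch()
    print(await router.DEPRECATED_run_eval("meta-reference-mmlu", {"type": "benchmark"}))


asyncio.run(main())
```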
@ -4,14 +4,15 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.common.type_system import ParamType
|
from llama_stack.apis.common.type_system import ParamType
|
||||||
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
|
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
|
||||||
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
|
|
||||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.scoring_functions import (
|
from llama_stack.apis.scoring_functions import (
|
||||||
|
@ -38,6 +39,8 @@ from llama_stack.distribution.datatypes import (
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_impl_api(p: Any) -> Api:
|
def get_impl_api(p: Any) -> Api:
|
||||||
return p.__provider_spec__.api
|
return p.__provider_spec__.api
|
||||||
|
@ -60,7 +63,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
||||||
elif api == Api.scoring:
|
elif api == Api.scoring:
|
||||||
return await p.register_scoring_function(obj)
|
return await p.register_scoring_function(obj)
|
||||||
elif api == Api.eval:
|
elif api == Api.eval:
|
||||||
return await p.register_eval_task(obj)
|
return await p.register_benchmark(obj)
|
||||||
elif api == Api.tool_runtime:
|
elif api == Api.tool_runtime:
|
||||||
return await p.register_tool(obj)
|
return await p.register_tool(obj)
|
||||||
else:
|
else:
|
||||||
|
@ -121,7 +124,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
scoring_functions = await p.list_scoring_functions()
|
scoring_functions = await p.list_scoring_functions()
|
||||||
await add_objects(scoring_functions, pid, ScoringFn)
|
await add_objects(scoring_functions, pid, ScoringFn)
|
||||||
elif api == Api.eval:
|
elif api == Api.eval:
|
||||||
p.eval_task_store = self
|
p.benchmark_store = self
|
||||||
elif api == Api.tool_runtime:
|
elif api == Api.tool_runtime:
|
||||||
p.tool_store = self
|
p.tool_store = self
|
||||||
|
|
||||||
|
@ -141,8 +144,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
return ("DatasetIO", "dataset")
|
return ("DatasetIO", "dataset")
|
||||||
elif isinstance(self, ScoringFunctionsRoutingTable):
|
elif isinstance(self, ScoringFunctionsRoutingTable):
|
||||||
return ("Scoring", "scoring_function")
|
return ("Scoring", "scoring_function")
|
||||||
elif isinstance(self, EvalTasksRoutingTable):
|
elif isinstance(self, BenchmarksRoutingTable):
|
||||||
return ("Eval", "eval_task")
|
return ("Eval", "benchmark")
|
||||||
elif isinstance(self, ToolGroupsRoutingTable):
|
elif isinstance(self, ToolGroupsRoutingTable):
|
||||||
return ("Tools", "tool")
|
return ("Tools", "tool")
|
||||||
else:
|
else:
|
||||||
|
@ -428,20 +431,20 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
await self.register_object(scoring_fn)
|
await self.register_object(scoring_fn)
|
||||||
|
|
||||||
|
|
||||||
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
||||||
async def list_eval_tasks(self) -> ListEvalTasksResponse:
|
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
||||||
return ListEvalTasksResponse(data=await self.get_all_with_type("eval_task"))
|
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
|
||||||
|
|
||||||
async def get_eval_task(self, eval_task_id: str) -> Optional[EvalTask]:
|
async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]:
|
||||||
return await self.get_object_by_identifier("eval_task", eval_task_id)
|
return await self.get_object_by_identifier("benchmark", benchmark_id)
|
||||||
|
|
||||||
async def register_eval_task(
|
async def register_benchmark(
|
||||||
self,
|
self,
|
||||||
eval_task_id: str,
|
benchmark_id: str,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
scoring_functions: List[str],
|
scoring_functions: List[str],
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
provider_eval_task_id: Optional[str] = None,
|
provider_benchmark_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
|
@ -453,17 +456,46 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||||
)
|
)
|
||||||
if provider_eval_task_id is None:
|
if provider_benchmark_id is None:
|
||||||
provider_eval_task_id = eval_task_id
|
provider_benchmark_id = benchmark_id
|
||||||
eval_task = EvalTask(
|
benchmark = Benchmark(
|
||||||
identifier=eval_task_id,
|
identifier=benchmark_id,
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
scoring_functions=scoring_functions,
|
scoring_functions=scoring_functions,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
provider_resource_id=provider_eval_task_id,
|
provider_resource_id=provider_benchmark_id,
|
||||||
|
)
|
||||||
|
await self.register_object(benchmark)
|
||||||
|
|
||||||
|
async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse:
|
||||||
|
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
|
||||||
|
return await self.list_benchmarks()
|
||||||
|
|
||||||
|
async def DEPRECATED_get_eval_task(
|
||||||
|
self,
|
||||||
|
eval_task_id: str,
|
||||||
|
) -> Optional[Benchmark]:
|
||||||
|
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
|
||||||
|
return await self.get_benchmark(eval_task_id)
|
||||||
|
|
||||||
|
async def DEPRECATED_register_eval_task(
|
||||||
|
self,
|
||||||
|
eval_task_id: str,
|
||||||
|
dataset_id: str,
|
||||||
|
scoring_functions: List[str],
|
||||||
|
provider_benchmark_id: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
|
||||||
|
return await self.register_benchmark(
|
||||||
|
benchmark_id=eval_task_id,
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
scoring_functions=scoring_functions,
|
||||||
|
metadata=metadata,
|
||||||
|
provider_benchmark_id=provider_benchmark_id,
|
||||||
)
|
)
|
||||||
await self.register_object(eval_task)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
|
|
|
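The renamed registration path applies two small defaulting rules before the object is stored: missing metadata becomes `{}`, and `provider_benchmark_id` falls back to `benchmark_id`. A standalone sketch using a plain dataclass in place of the Benchmark resource; the ids and scoring-function name below are examples only.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class BenchmarkSketch:
    # Field names follow the Benchmark(...) construction in the hunk above.
    identifier: str
    dataset_id: str
    scoring_functions: List[str]
    metadata: Dict[str, Any] = field(default_factory=dict)
    provider_id: Optional[str] = None
    provider_resource_id: Optional[str] = None


def build_benchmark(
    benchmark_id: str,
    dataset_id: str,
    scoring_functions: List[str],
    metadata: Optional[Dict[str, Any]] = None,
    provider_benchmark_id: Optional[str] = None,
    provider_id: Optional[str] = None,
) -> BenchmarkSketch:
    # Same defaulting as register_benchmark before register_object() is called.
    if metadata is None:
        metadata = {}
    if provider_benchmark_id is None:
        provider_benchmark_id = benchmark_id
    return BenchmarkSketch(
        identifier=benchmark_id,
        dataset_id=dataset_id,
        scoring_functions=scoring_functions,
        metadata=metadata,
        provider_id=provider_id,
        provider_resource_id=provider_benchmark_id,
    )


print(build_benchmark("meta-reference-mmlu", "mmlu", ["multiple_choice_accuracy"]))
```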
@ -15,10 +15,10 @@ from termcolor import colored
|
||||||
|
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.apis.agents import Agents
|
||||||
from llama_stack.apis.batch_inference import BatchInference
|
from llama_stack.apis.batch_inference import BatchInference
|
||||||
|
from llama_stack.apis.benchmarks import Benchmarks
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
from llama_stack.apis.eval import Eval
|
from llama_stack.apis.eval import Eval
|
||||||
from llama_stack.apis.eval_tasks import EvalTasks
|
|
||||||
from llama_stack.apis.inference import Inference
|
from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.apis.inspect import Inspect
|
from llama_stack.apis.inspect import Inspect
|
||||||
from llama_stack.apis.models import Models
|
from llama_stack.apis.models import Models
|
||||||
|
@ -53,7 +53,7 @@ class LlamaStack(
|
||||||
PostTraining,
|
PostTraining,
|
||||||
VectorIO,
|
VectorIO,
|
||||||
Eval,
|
Eval,
|
||||||
EvalTasks,
|
Benchmarks,
|
||||||
Scoring,
|
Scoring,
|
||||||
ScoringFunctions,
|
ScoringFunctions,
|
||||||
DatasetIO,
|
DatasetIO,
|
||||||
|
@ -78,7 +78,7 @@ RESOURCES = [
|
||||||
"register_scoring_function",
|
"register_scoring_function",
|
||||||
"list_scoring_functions",
|
"list_scoring_functions",
|
||||||
),
|
),
|
||||||
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
|
("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
|
||||||
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
|
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ $ llama-stack-client datasets register \
|
||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ llama-stack-client eval_tasks register \
|
$ llama-stack-client benchmarks register \
|
||||||
--eval-task-id meta-reference-mmlu \
|
--eval-task-id meta-reference-mmlu \
|
||||||
--provider-id meta-reference \
|
--provider-id meta-reference \
|
||||||
--dataset-id mmlu \
|
--dataset-id mmlu \
|
||||||
|
|
|
@ -8,12 +8,12 @@ import streamlit as st
|
||||||
from modules.api import llama_stack_api
|
from modules.api import llama_stack_api
|
||||||
|
|
||||||
|
|
||||||
def eval_tasks():
|
def benchmarks():
|
||||||
# Eval Tasks Section
|
# Benchmarks Section
|
||||||
st.header("Eval Tasks")
|
st.header("Benchmarks")
|
||||||
|
|
||||||
eval_tasks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()}
|
benchmarks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.benchmarks.list()}
|
||||||
|
|
||||||
if len(eval_tasks_info) > 0:
|
if len(benchmarks_info) > 0:
|
||||||
selected_eval_task = st.selectbox("Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect")
|
selected_benchmark = st.selectbox("Select an eval task", list(benchmarks_info.keys()), key="benchmark_inspect")
|
||||||
st.json(eval_tasks_info[selected_eval_task], expanded=True)
|
st.json(benchmarks_info[selected_benchmark], expanded=True)
|
||||||
|
|
|
@ -4,8 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from page.distribution.benchmarks import benchmarks
|
||||||
from page.distribution.datasets import datasets
|
from page.distribution.datasets import datasets
|
||||||
from page.distribution.eval_tasks import eval_tasks
|
|
||||||
from page.distribution.models import models
|
from page.distribution.models import models
|
||||||
from page.distribution.scoring_functions import scoring_functions
|
from page.distribution.scoring_functions import scoring_functions
|
||||||
from page.distribution.shields import shields
|
from page.distribution.shields import shields
|
||||||
|
@ -20,7 +20,7 @@ def resources_page():
|
||||||
"Shields",
|
"Shields",
|
||||||
"Scoring Functions",
|
"Scoring Functions",
|
||||||
"Datasets",
|
"Datasets",
|
||||||
"Eval Tasks",
|
"Benchmarks",
|
||||||
]
|
]
|
||||||
icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
|
icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
|
||||||
selected_resource = option_menu(
|
selected_resource = option_menu(
|
||||||
|
@ -34,8 +34,8 @@ def resources_page():
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if selected_resource == "Eval Tasks":
|
if selected_resource == "Benchmarks":
|
||||||
eval_tasks()
|
benchmarks()
|
||||||
elif selected_resource == "Vector Databases":
|
elif selected_resource == "Vector Databases":
|
||||||
vector_dbs()
|
vector_dbs()
|
||||||
elif selected_resource == "Datasets":
|
elif selected_resource == "Datasets":
|
||||||
|
|
|
@ -11,28 +11,28 @@ import streamlit as st
|
||||||
from modules.api import llama_stack_api
|
from modules.api import llama_stack_api
|
||||||
|
|
||||||
|
|
||||||
def select_eval_task_1():
|
def select_benchmark_1():
|
||||||
# Select Eval Tasks
|
# Select Benchmarks
|
||||||
st.subheader("1. Choose An Eval Task")
|
st.subheader("1. Choose An Eval Task")
|
||||||
eval_tasks = llama_stack_api.client.eval_tasks.list()
|
benchmarks = llama_stack_api.client.benchmarks.list()
|
||||||
eval_tasks = {et.identifier: et for et in eval_tasks}
|
benchmarks = {et.identifier: et for et in benchmarks}
|
||||||
eval_tasks_names = list(eval_tasks.keys())
|
benchmarks_names = list(benchmarks.keys())
|
||||||
selected_eval_task = st.selectbox(
|
selected_benchmark = st.selectbox(
|
||||||
"Choose an eval task.",
|
"Choose an eval task.",
|
||||||
options=eval_tasks_names,
|
options=benchmarks_names,
|
||||||
help="Choose an eval task. Each eval task is parameterized by a dataset and a list of scoring functions.",
|
help="Choose an eval task. Each eval task is parameterized by a dataset and a list of scoring functions.",
|
||||||
)
|
)
|
||||||
with st.expander("View Eval Task"):
|
with st.expander("View Eval Task"):
|
||||||
st.json(eval_tasks[selected_eval_task], expanded=True)
|
st.json(benchmarks[selected_benchmark], expanded=True)
|
||||||
|
|
||||||
st.session_state["selected_eval_task"] = selected_eval_task
|
st.session_state["selected_benchmark"] = selected_benchmark
|
||||||
st.session_state["eval_tasks"] = eval_tasks
|
st.session_state["benchmarks"] = benchmarks
|
||||||
if st.button("Confirm", key="confirm_1"):
|
if st.button("Confirm", key="confirm_1"):
|
||||||
st.session_state["selected_eval_task_1_next"] = True
|
st.session_state["selected_benchmark_1_next"] = True
|
||||||
|
|
||||||
|
|
||||||
def define_eval_candidate_2():
|
def define_eval_candidate_2():
|
||||||
if not st.session_state.get("selected_eval_task_1_next", None):
|
if not st.session_state.get("selected_benchmark_1_next", None):
|
||||||
return
|
return
|
||||||
|
|
||||||
st.subheader("2. Define Eval Candidate")
|
st.subheader("2. Define Eval Candidate")
|
||||||
|
@ -161,11 +161,11 @@ def run_evaluation_3():
|
||||||
Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button.
|
Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button.
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
selected_eval_task = st.session_state["selected_eval_task"]
|
selected_benchmark = st.session_state["selected_benchmark"]
|
||||||
eval_tasks = st.session_state["eval_tasks"]
|
benchmarks = st.session_state["benchmarks"]
|
||||||
eval_candidate = st.session_state["eval_candidate"]
|
eval_candidate = st.session_state["eval_candidate"]
|
||||||
|
|
||||||
dataset_id = eval_tasks[selected_eval_task].dataset_id
|
dataset_id = benchmarks[selected_benchmark].dataset_id
|
||||||
rows = llama_stack_api.client.datasetio.get_rows_paginated(
|
rows = llama_stack_api.client.datasetio.get_rows_paginated(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
rows_in_page=-1,
|
rows_in_page=-1,
|
||||||
|
@ -180,16 +180,16 @@ def run_evaluation_3():
|
||||||
help="Number of examples from the dataset to evaluate. ",
|
help="Number of examples from the dataset to evaluate. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_task_config = {
|
benchmark_config = {
|
||||||
"type": "benchmark",
|
"type": "benchmark",
|
||||||
"eval_candidate": eval_candidate,
|
"eval_candidate": eval_candidate,
|
||||||
"scoring_params": {},
|
"scoring_params": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
with st.expander("View Evaluation Task", expanded=True):
|
with st.expander("View Evaluation Task", expanded=True):
|
||||||
st.json(eval_tasks[selected_eval_task], expanded=True)
|
st.json(benchmarks[selected_benchmark], expanded=True)
|
||||||
with st.expander("View Evaluation Task Configuration", expanded=True):
|
with st.expander("View Evaluation Task Configuration", expanded=True):
|
||||||
st.json(eval_task_config, expanded=True)
|
st.json(benchmark_config, expanded=True)
|
||||||
|
|
||||||
# Add run button and handle evaluation
|
# Add run button and handle evaluation
|
||||||
if st.button("Run Evaluation"):
|
if st.button("Run Evaluation"):
|
||||||
|
@ -209,10 +209,10 @@ def run_evaluation_3():
|
||||||
progress_bar.progress(progress, text=progress_text)
|
progress_bar.progress(progress, text=progress_text)
|
||||||
# Run evaluation for current row
|
# Run evaluation for current row
|
||||||
eval_res = llama_stack_api.client.eval.evaluate_rows(
|
eval_res = llama_stack_api.client.eval.evaluate_rows(
|
||||||
task_id=selected_eval_task,
|
benchmark_id=selected_benchmark,
|
||||||
input_rows=[r],
|
input_rows=[r],
|
||||||
scoring_functions=eval_tasks[selected_eval_task].scoring_functions,
|
scoring_functions=benchmarks[selected_benchmark].scoring_functions,
|
||||||
task_config=eval_task_config,
|
task_config=benchmark_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
for k in r.keys():
|
for k in r.keys():
|
||||||
|
@ -225,7 +225,7 @@ def run_evaluation_3():
|
||||||
output_res[k] = []
|
output_res[k] = []
|
||||||
output_res[k].append(eval_res.generations[0][k])
|
output_res[k].append(eval_res.generations[0][k])
|
||||||
|
|
||||||
for scoring_fn in eval_tasks[selected_eval_task].scoring_functions:
|
for scoring_fn in benchmarks[selected_benchmark].scoring_functions:
|
||||||
if scoring_fn not in output_res:
|
if scoring_fn not in output_res:
|
||||||
output_res[scoring_fn] = []
|
output_res[scoring_fn] = []
|
||||||
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
|
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
|
||||||
|
@ -245,7 +245,7 @@ def native_evaluation_page():
|
||||||
st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
|
st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
|
||||||
st.title("📊 Evaluations (Generation + Scoring)")
|
st.title("📊 Evaluations (Generation + Scoring)")
|
||||||
|
|
||||||
select_eval_task_1()
|
select_benchmark_1()
|
||||||
define_eval_candidate_2()
|
define_eval_candidate_2()
|
||||||
run_evaluation_3()
|
run_evaluation_3()
|
||||||
|
|
||||||
|
|
277
llama_stack/models/llama/datatypes.py
Normal file
277
llama_stack/models/llama/datatypes.py
Normal file
|
@ -0,0 +1,277 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
|
# import all for backwards compatibility
|
||||||
|
from llama_models.datatypes import * # noqa: F403
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||||
|
|
||||||
|
register_schema(ToolCall)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ToolParamDefinition(BaseModel):
|
||||||
|
param_type: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
required: Optional[bool] = True
|
||||||
|
default: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ToolDefinition(BaseModel):
|
||||||
|
tool_name: Union[BuiltinTool, str]
|
||||||
|
description: Optional[str] = None
|
||||||
|
parameters: Optional[Dict[str, ToolParamDefinition]] = None
|
||||||
|
|
||||||
|
@field_validator("tool_name", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_field(cls, v):
|
||||||
|
if isinstance(v, str):
|
||||||
|
try:
|
||||||
|
return BuiltinTool(v)
|
||||||
|
except ValueError:
|
||||||
|
return v
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class GreedySamplingStrategy(BaseModel):
|
||||||
|
type: Literal["greedy"] = "greedy"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class TopPSamplingStrategy(BaseModel):
|
||||||
|
type: Literal["top_p"] = "top_p"
|
||||||
|
temperature: Optional[float] = Field(..., gt=0.0)
|
||||||
|
top_p: Optional[float] = 0.95
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class TopKSamplingStrategy(BaseModel):
|
||||||
|
type: Literal["top_k"] = "top_k"
|
||||||
|
top_k: int = Field(..., ge=1)
|
||||||
|
|
||||||
|
|
||||||
|
SamplingStrategy = register_schema(
|
||||||
|
Annotated[
|
||||||
|
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||||
|
Field(discriminator="type"),
|
||||||
|
],
|
||||||
|
name="SamplingStrategy",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class SamplingParams(BaseModel):
|
||||||
|
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
|
||||||
|
|
||||||
|
max_tokens: Optional[int] = 0
|
||||||
|
repetition_penalty: Optional[float] = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointQuantizationFormat(Enum):
|
||||||
|
# default format
|
||||||
|
bf16 = "bf16"
|
||||||
|
|
||||||
|
# used for enabling fp8_rowwise inference, some weights are bf16
|
||||||
|
fp8_mixed = "fp8-mixed"
|
||||||
|
|
||||||
|
int8 = "int8"
|
||||||
|
|
||||||
|
int4 = "int4"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFamily(Enum):
|
||||||
|
llama2 = "llama2"
|
||||||
|
llama3 = "llama3"
|
||||||
|
llama3_1 = "llama3_1"
|
||||||
|
llama3_2 = "llama3_2"
|
||||||
|
llama3_3 = "llama3_3"
|
||||||
|
safety = "safety"
|
||||||
|
|
||||||
|
|
||||||
|
class CoreModelId(Enum):
|
||||||
|
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
|
||||||
|
|
||||||
|
# Llama 2 family
|
||||||
|
llama2_7b = "Llama-2-7b"
|
||||||
|
llama2_13b = "Llama-2-13b"
|
||||||
|
llama2_70b = "Llama-2-70b"
|
||||||
|
llama2_7b_chat = "Llama-2-7b-chat"
|
||||||
|
llama2_13b_chat = "Llama-2-13b-chat"
|
||||||
|
llama2_70b_chat = "Llama-2-70b-chat"
|
||||||
|
|
||||||
|
# Llama 3 family
|
||||||
|
llama3_8b = "Llama-3-8B"
|
||||||
|
llama3_70b = "Llama-3-70B"
|
||||||
|
llama3_8b_instruct = "Llama-3-8B-Instruct"
|
||||||
|
llama3_70b_instruct = "Llama-3-70B-Instruct"
|
||||||
|
|
||||||
|
# Llama 3.1 family
|
||||||
|
llama3_1_8b = "Llama3.1-8B"
|
||||||
|
llama3_1_70b = "Llama3.1-70B"
|
||||||
|
llama3_1_405b = "Llama3.1-405B"
|
||||||
|
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
|
||||||
|
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
|
||||||
|
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
|
||||||
|
|
||||||
|
# Llama 3.2 family
|
||||||
|
llama3_2_1b = "Llama3.2-1B"
|
||||||
|
llama3_2_3b = "Llama3.2-3B"
|
||||||
|
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
|
||||||
|
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
|
||||||
|
llama3_2_11b_vision = "Llama3.2-11B-Vision"
|
||||||
|
llama3_2_90b_vision = "Llama3.2-90B-Vision"
|
||||||
|
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
|
||||||
|
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
|
||||||
|
|
||||||
|
# Llama 3.3 family
|
||||||
|
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
|
||||||
|
|
||||||
|
# Safety models
|
||||||
|
llama_guard_3_8b = "Llama-Guard-3-8B"
|
||||||
|
llama_guard_2_8b = "Llama-Guard-2-8B"
|
||||||
|
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
|
||||||
|
llama_guard_3_1b = "Llama-Guard-3-1B"
|
||||||
|
|
||||||
|
|
||||||
|
def is_multimodal(model_id) -> bool:
|
||||||
|
if model_id in [
|
||||||
|
CoreModelId.llama3_2_11b_vision,
|
||||||
|
CoreModelId.llama3_2_90b_vision,
|
||||||
|
CoreModelId.llama3_2_11b_vision_instruct,
|
||||||
|
CoreModelId.llama3_2_90b_vision_instruct,
|
||||||
|
]:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def model_family(model_id) -> ModelFamily:
|
||||||
|
if model_id in [
|
||||||
|
CoreModelId.llama2_7b,
|
||||||
|
CoreModelId.llama2_13b,
|
||||||
|
CoreModelId.llama2_70b,
|
||||||
|
CoreModelId.llama2_7b_chat,
|
||||||
|
CoreModelId.llama2_13b_chat,
|
||||||
|
CoreModelId.llama2_70b_chat,
|
||||||
|
]:
|
||||||
|
return ModelFamily.llama2
|
||||||
|
elif model_id in [
|
||||||
|
CoreModelId.llama3_8b,
|
||||||
|
CoreModelId.llama3_70b,
|
||||||
|
CoreModelId.llama3_8b_instruct,
|
||||||
|
CoreModelId.llama3_70b_instruct,
|
||||||
|
]:
|
||||||
|
return ModelFamily.llama3
|
||||||
|
elif model_id in [
|
||||||
|
CoreModelId.llama3_1_8b,
|
||||||
|
CoreModelId.llama3_1_70b,
|
||||||
|
CoreModelId.llama3_1_405b,
|
||||||
|
CoreModelId.llama3_1_8b_instruct,
|
||||||
|
CoreModelId.llama3_1_70b_instruct,
|
||||||
|
CoreModelId.llama3_1_405b_instruct,
|
||||||
|
]:
|
||||||
|
return ModelFamily.llama3_1
|
||||||
|
elif model_id in [
|
||||||
|
CoreModelId.llama3_2_1b,
|
||||||
|
CoreModelId.llama3_2_3b,
|
||||||
|
CoreModelId.llama3_2_1b_instruct,
|
||||||
|
CoreModelId.llama3_2_3b_instruct,
|
||||||
|
CoreModelId.llama3_2_11b_vision,
|
||||||
|
CoreModelId.llama3_2_90b_vision,
|
||||||
|
CoreModelId.llama3_2_11b_vision_instruct,
|
||||||
|
CoreModelId.llama3_2_90b_vision_instruct,
|
||||||
|
]:
|
||||||
|
return ModelFamily.llama3_2
|
||||||
|
elif model_id in [
|
||||||
|
CoreModelId.llama3_3_70b_instruct,
|
||||||
|
]:
|
||||||
|
return ModelFamily.llama3_3
|
||||||
|
elif model_id in [
|
||||||
|
CoreModelId.llama_guard_3_8b,
|
||||||
|
CoreModelId.llama_guard_2_8b,
|
||||||
|
CoreModelId.llama_guard_3_11b_vision,
|
||||||
|
CoreModelId.llama_guard_3_1b,
|
||||||
|
]:
|
||||||
|
return ModelFamily.safety
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model family for {model_id}")
|
||||||
|
|
||||||
|
|
||||||
|
class Model(BaseModel):
|
||||||
|
core_model_id: CoreModelId
|
||||||
|
description: str
|
||||||
|
huggingface_repo: Optional[str] = None
|
||||||
|
recommended_sampling_params: Optional[SamplingParams] = None
|
||||||
|
arch_args: Dict[str, Any]
|
||||||
|
variant: str = ""
|
||||||
|
|
||||||
|
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||||
|
pth_file_count: int
|
||||||
|
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
# silence pydantic until we remove the `model_` fields
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_family(self) -> ModelFamily:
|
||||||
|
return model_family(self.core_model_id)
|
||||||
|
|
||||||
|
# The SKU is uniquely identified by (model_id, variant) combo
|
||||||
|
def descriptor(self, shorten_default_variant: bool = True) -> str:
|
||||||
|
if not self.variant:
|
||||||
|
return self.core_model_id.value
|
||||||
|
return f"{self.core_model_id.value}:{self.variant}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_instruct_model(self) -> bool:
|
||||||
|
return "instruct" in self.id.name
|
||||||
|
|
||||||
|
# Featured models are shown in the non-exhaustive model list
|
||||||
|
@property
|
||||||
|
def is_featured(self) -> bool:
|
||||||
|
return self.model_family in [
|
||||||
|
ModelFamily.llama3_1,
|
||||||
|
ModelFamily.llama3_2,
|
||||||
|
ModelFamily.llama3_3,
|
||||||
|
ModelFamily.safety,
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_seq_length(self) -> int:
|
||||||
|
if self.model_family == ModelFamily.llama2:
|
||||||
|
return 4096
|
||||||
|
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
|
||||||
|
return 4096
|
||||||
|
elif self.model_family == ModelFamily.llama3:
|
||||||
|
return 8192
|
||||||
|
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
|
||||||
|
return 131072
|
||||||
|
elif self.model_family == ModelFamily.llama3_2:
|
||||||
|
if self.quantization_format == CheckpointQuantizationFormat.int4:
|
||||||
|
return 8192
|
||||||
|
return 131072
|
||||||
|
elif self.core_model_id in [
|
||||||
|
CoreModelId.llama_guard_3_8b,
|
||||||
|
CoreModelId.llama_guard_3_11b_vision,
|
||||||
|
CoreModelId.llama_guard_3_1b,
|
||||||
|
]:
|
||||||
|
return 131072
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
|
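Of the new datatypes above, the SamplingStrategy discriminated union is the piece most callers touch. Below is a runnable mini-version that mirrors those models; they are redeclared here so the snippet does not require llama_stack to be installed, and it assumes pydantic v2.

```python
from typing import Literal, Optional, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class GreedySamplingStrategy(BaseModel):
    type: Literal["greedy"] = "greedy"


class TopPSamplingStrategy(BaseModel):
    type: Literal["top_p"] = "top_p"
    temperature: Optional[float] = Field(..., gt=0.0)
    top_p: Optional[float] = 0.95


class TopKSamplingStrategy(BaseModel):
    type: Literal["top_k"] = "top_k"
    top_k: int = Field(..., ge=1)


SamplingStrategy = Annotated[
    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
    Field(discriminator="type"),
]

# The "type" field selects the concrete model during validation.
strategy = TypeAdapter(SamplingStrategy).validate_python({"type": "top_p", "temperature": 0.7})
print(strategy)  # TopPSamplingStrategy(type='top_p', temperature=0.7, top_p=0.95)
```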
BIN
llama_stack/models/llama/llama3/dog.jpg
Normal file
BIN
llama_stack/models/llama/llama3/dog.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 39 KiB |
257
llama_stack/models/llama/llama3/interface.py
Normal file
257
llama_stack/models/llama/llama3/interface.py
Normal file
|
@ -0,0 +1,257 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from llama_models.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
|
RawMessage,
|
||||||
|
StopReason,
|
||||||
|
ToolCall,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
from termcolor import colored
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||||
|
|
||||||
|
from . import template_data
|
||||||
|
from .prompt_templates import (
|
||||||
|
BuiltinToolGenerator,
|
||||||
|
FunctionTagCustomToolGenerator,
|
||||||
|
JsonCustomToolGenerator,
|
||||||
|
SystemDefaultGenerator,
|
||||||
|
ToolResponseGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
|
THIS_DIR = Path(__file__).parent
|
||||||
|
|
||||||
|
|
||||||
|
class Template:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
role,
|
||||||
|
template_name,
|
||||||
|
data_provider=None,
|
||||||
|
notes=None,
|
||||||
|
):
|
||||||
|
self.role = role
|
||||||
|
self.template_name = template_name
|
||||||
|
self.data_provider = data_provider or ""
|
||||||
|
self._notes = notes or ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def notes(self):
|
||||||
|
default = "↵ represents newline"
|
||||||
|
notes = default
|
||||||
|
if self._notes:
|
||||||
|
notes += "\n"
|
||||||
|
notes += self._notes
|
||||||
|
return notes
|
||||||
|
|
||||||
|
|
||||||
|
TEMPLATES = [
|
||||||
|
Template(
|
||||||
|
"user",
|
||||||
|
"user-default",
|
||||||
|
"user_default",
|
||||||
|
),
|
||||||
|
Template(
|
||||||
|
"user",
|
||||||
|
"user-images",
|
||||||
|
"user_images",
|
||||||
|
),
|
||||||
|
Template("user", "user-interleaved-images", "user_interleaved_images"),
|
||||||
|
Template(
|
||||||
|
"assistant",
|
||||||
|
"assistant-builtin-tool-call",
|
||||||
|
"assistant_builtin_tool_call",
|
||||||
|
"Notice <|python_tag|>",
|
||||||
|
),
|
||||||
|
Template(
|
||||||
|
"assistant",
|
||||||
|
"assistant-custom-tool-call",
|
||||||
|
"assistant_custom_tool_call",
|
||||||
|
"Notice <function=...> format",
|
||||||
|
),
|
||||||
|
Template(
|
||||||
|
"assistant",
|
||||||
|
"assistant-default",
|
||||||
|
"assistant_default",
|
||||||
|
),
|
||||||
|
Template(
|
||||||
|
"system",
|
||||||
|
"system-builtin-and-custom-tools",
|
||||||
|
"system_message_builtin_and_custom_tools",
|
||||||
|
),
|
||||||
|
Template(
|
||||||
|
"system",
|
||||||
|
"system-builtin-tools-only",
|
||||||
|
"system_message_builtin_tools_only",
|
||||||
|
),
|
||||||
|
Template(
|
||||||
|
"system",
|
||||||
|
"system-custom-tools-only",
|
||||||
|
"system_message_custom_tools_only",
|
||||||
|
),
|
||||||
|
Template(
|
||||||
|
"system",
|
||||||
|
"system-default",
|
||||||
|
"system_default",
|
||||||
|
),
|
||||||
|
Template(
|
||||||
|
"tool",
|
||||||
|
"tool-success",
|
||||||
|
"tool_success",
|
||||||
|
"Note ipython header and [stdout]",
|
||||||
|
),
|
||||||
|
Template(
|
||||||
|
"tool",
|
||||||
|
"tool-failure",
|
||||||
|
"tool_failure",
|
||||||
|
"Note ipython header and [stderr]",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class LLama31Interface:
|
||||||
|
def __init__(self, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json):
|
||||||
|
self.tokenizer = Tokenizer.get_instance()
|
||||||
|
self.formatter = ChatFormat(self.tokenizer)
|
||||||
|
self.tool_prompt_format = tool_prompt_format
|
||||||
|
|
||||||
|
def get_tokens(self, messages: List[RawMessage]) -> List[int]:
|
||||||
|
model_input = self.formatter.encode_dialog_prompt(
|
||||||
|
messages,
|
||||||
|
self.tool_prompt_format,
|
||||||
|
)
|
||||||
|
return model_input.tokens
|
||||||
|
|
||||||
|
def tool_response_messages(self, *args, **kwargs):
|
||||||
|
template = ToolResponseGenerator().gen(*args, **kwargs)
|
||||||
|
return [
|
||||||
|
RawMessage(
|
||||||
|
role="tool",
|
||||||
|
content=template.render(),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def system_messages(
|
||||||
|
self,
|
||||||
|
builtin_tools: List[BuiltinTool],
|
||||||
|
custom_tools: List[ToolDefinition],
|
||||||
|
instruction: Optional[str] = None,
|
||||||
|
) -> List[RawMessage]:
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
default_gen = SystemDefaultGenerator()
|
||||||
|
default_template = default_gen.gen()
|
||||||
|
|
||||||
|
sys_content = ""
|
||||||
|
|
||||||
|
tool_template = None
|
||||||
|
if builtin_tools or custom_tools:
|
||||||
|
tool_gen = BuiltinToolGenerator()
|
||||||
|
tool_template = tool_gen.gen(builtin_tools + custom_tools)
|
||||||
|
|
||||||
|
sys_content += tool_template.render()
|
||||||
|
sys_content += "\n"
|
||||||
|
|
||||||
|
sys_content += default_template.render()
|
||||||
|
|
||||||
|
if instruction:
|
||||||
|
sys_content += "\n\n"
|
||||||
|
sys_content += instruction
|
||||||
|
|
||||||
|
sys_content += "\n"
|
||||||
|
messages.append(RawMessage(role="system", content=sys_content))
|
||||||
|
|
||||||
|
if custom_tools:
|
||||||
|
if self.tool_prompt_format == ToolPromptFormat.json:
|
||||||
|
tool_gen = JsonCustomToolGenerator()
|
||||||
|
elif self.tool_prompt_format == ToolPromptFormat.function_tag:
|
||||||
|
tool_gen = FunctionTagCustomToolGenerator()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Non supported ToolPromptFormat {self.tool_prompt_format}")
|
||||||
|
|
||||||
|
custom_template = tool_gen.gen(custom_tools)
|
||||||
|
messages.append(RawMessage(role="user", content=custom_template.render()))
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def assistant_response_messages(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
stop_reason: StopReason,
|
||||||
|
tool_call: Optional[ToolCall] = None,
|
||||||
|
) -> List[RawMessage]:
|
||||||
|
tool_calls = []
|
||||||
|
if tool_call:
|
||||||
|
tool_calls.append(tool_call)
|
||||||
|
return [
|
||||||
|
RawMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=content,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
stop_reason=stop_reason,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def user_message(self, content: str) -> List[RawMessage]:
|
||||||
|
return [RawMessage(role="user", content=content)]
|
||||||
|
|
||||||
|
def display_message_as_tokens(self, message: RawMessage) -> None:
|
||||||
|
"""Util to print tokenized string to shell"""
|
||||||
|
tokens = self.formatter.encode_message(message, self.tool_prompt_format)
|
||||||
|
on_colors = [
|
||||||
|
"on_red",
|
||||||
|
"on_green",
|
||||||
|
"on_yellow",
|
||||||
|
"on_blue",
|
||||||
|
"on_magenta",
|
||||||
|
"on_cyan",
|
||||||
|
]
|
||||||
|
for i, t in enumerate(tokens):
|
||||||
|
on_col = on_colors[i % len(on_colors)]
|
||||||
|
print(colored(self.tokenizer.decode([t]), "white", on_col), end="")
|
||||||
|
print("\n", end="")
|
||||||
|
|
||||||
|
|
||||||
|
def list_jinja_templates() -> List[Template]:
|
||||||
|
return TEMPLATES
|
||||||
|
|
||||||
|
|
||||||
|
def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
|
||||||
|
by_name = {t.template_name: t for t in TEMPLATES}
|
||||||
|
if name not in by_name:
|
||||||
|
raise ValueError(f"No template found for `{name}`")
|
||||||
|
|
||||||
|
template = by_name[name]
|
||||||
|
interface = LLama31Interface(tool_prompt_format)
|
||||||
|
|
||||||
|
data_func = getattr(template_data, template.data_provider)
|
||||||
|
if template.role == "system":
|
||||||
|
messages = interface.system_messages(**data_func())
|
||||||
|
elif template.role == "tool":
|
||||||
|
messages = interface.tool_response_messages(**data_func())
|
||||||
|
elif template.role == "assistant":
|
||||||
|
messages = interface.assistant_response_messages(**data_func())
|
||||||
|
elif template.role == "user":
|
||||||
|
messages = interface.user_message(**data_func())
|
||||||
|
|
||||||
|
tokens = interface.get_tokens(messages)
|
||||||
|
special_tokens = list(interface.tokenizer.special_tokens.values())
|
||||||
|
tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
|
||||||
|
return template, tokens
|
BIN  llama_stack/models/llama/llama3/pasta.jpeg  Normal file
Binary file not shown. After Width: | Height: | Size: 438 KiB
22  llama_stack/models/llama/llama3/prompt_templates/__init__.py  Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
from .base import PromptTemplate, PromptTemplateGeneratorBase # noqa: F401
|
||||||
|
from .system_prompts import ( # noqa: F401
|
||||||
|
BuiltinToolGenerator,
|
||||||
|
FunctionTagCustomToolGenerator,
|
||||||
|
JsonCustomToolGenerator,
|
||||||
|
PythonListCustomToolGenerator,
|
||||||
|
SystemDefaultGenerator,
|
||||||
|
)
|
||||||
|
from .tool_response import ToolResponseGenerator # noqa: F401
|
39  llama_stack/models/llama/llama3/prompt_templates/base.py  Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from jinja2 import Template
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PromptTemplate:
|
||||||
|
template: str
|
||||||
|
data: Dict[str, Any]
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
template = Template(self.template)
|
||||||
|
return template.render(self.data)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTemplateGeneratorBase:
|
||||||
|
"""
|
||||||
|
Base class for prompt template generators.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def gen(self, *args, **kwargs) -> PromptTemplate:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def data_examples(self) -> List[Any]:
|
||||||
|
raise NotImplementedError()
|
|
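As a quick illustration of the dataclass above (a sketch, not part of the diff): PromptTemplate simply wraps a Jinja2 template string plus the data used to render it.
# Minimal sketch using the PromptTemplate defined above.
pt = PromptTemplate(template="Hello {{ name }}!", data={"name": "Llama"})
assert pt.render() == "Hello Llama!"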
@ -0,0 +1,311 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from llama_models.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
ToolDefinition,
|
||||||
|
ToolParamDefinition,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .base import PromptTemplate, PromptTemplateGeneratorBase
|
||||||
|
|
||||||
|
|
||||||
|
class SystemDefaultGenerator(PromptTemplateGeneratorBase):
|
||||||
|
def gen(self, *args, **kwargs) -> PromptTemplate:
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Cutting Knowledge Date: December 2023
|
||||||
|
Today Date: {{ today }}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.lstrip("\n"),
|
||||||
|
{"today": datetime.now().strftime("%d %B %Y")},
|
||||||
|
)
|
||||||
|
|
||||||
|
def data_examples(self) -> List[Any]:
|
||||||
|
return [None]
|
||||||
|
|
||||||
|
|
||||||
|
class BuiltinToolGenerator(PromptTemplateGeneratorBase):
|
||||||
|
def _tool_breakdown(self, tools: List[ToolDefinition]):
|
||||||
|
builtin_tools, custom_tools = [], []
|
||||||
|
for dfn in tools:
|
||||||
|
if isinstance(dfn.tool_name, BuiltinTool):
|
||||||
|
builtin_tools.append(dfn)
|
||||||
|
else:
|
||||||
|
custom_tools.append(dfn)
|
||||||
|
|
||||||
|
return builtin_tools, custom_tools
|
||||||
|
|
||||||
|
def gen(self, tools: List[ToolDefinition]) -> PromptTemplate:
|
||||||
|
builtin_tools, custom_tools = self._tool_breakdown(tools)
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
{% if builtin_tools or custom_tools -%}
|
||||||
|
Environment: ipython
|
||||||
|
{% endif -%}
|
||||||
|
{% set builtin_tools = builtin_tools | reject('equalto', 'code_interpreter') | list -%}
|
||||||
|
{% if builtin_tools -%}
|
||||||
|
Tools: {{ builtin_tools | join(", ") | trim -}}
|
||||||
|
{% endif %}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.lstrip("\n"),
|
||||||
|
{
|
||||||
|
"builtin_tools": [t.tool_name.value for t in builtin_tools],
|
||||||
|
"custom_tools": custom_tools,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||||
|
return [
|
||||||
|
# builtin tools
|
||||||
|
[
|
||||||
|
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||||
|
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||||
|
ToolDefinition(tool_name=BuiltinTool.wolfram_alpha),
|
||||||
|
],
|
||||||
|
# only code interpreter
|
||||||
|
[
|
||||||
|
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
||||||
|
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Answer the user's question by making use of the following functions if needed.
|
||||||
|
If none of the function can be used, please say so.
|
||||||
|
Here is a list of functions in JSON format:
|
||||||
|
{% for t in custom_tools -%}
|
||||||
|
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||||
|
{%- set tname = t.tool_name -%}
|
||||||
|
{%- set tdesc = t.description -%}
|
||||||
|
{%- set tparams = t.parameters -%}
|
||||||
|
{%- set required_params = [] -%}
|
||||||
|
{%- for name, param in tparams.items() if param.required == true -%}
|
||||||
|
{%- set _ = required_params.append(name) -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "{{tname}}",
|
||||||
|
"description": "{{tdesc}}",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": [
|
||||||
|
{%- for name, param in tparams.items() %}
|
||||||
|
{
|
||||||
|
"{{name}}": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "{{param.description}}"
|
||||||
|
}
|
||||||
|
}{% if not loop.last %},{% endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
],
|
||||||
|
"required": {{ required_params | tojson }}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{% endfor %}
|
||||||
|
Return function calls in JSON format.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.lstrip("\n"),
|
||||||
|
{"custom_tools": [t.model_dump() for t in custom_tools]},
|
||||||
|
)
|
||||||
|
|
||||||
|
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
ToolDefinition(
|
||||||
|
tool_name="trending_songs",
|
||||||
|
description="Returns the trending songs on a Music site",
|
||||||
|
parameters={
|
||||||
|
"n": ToolParamDefinition(
|
||||||
|
param_type="int",
|
||||||
|
description="The number of songs to return",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"genre": ToolParamDefinition(
|
||||||
|
param_type="str",
|
||||||
|
description="The genre of the songs to return",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
||||||
|
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
You have access to the following functions:
|
||||||
|
|
||||||
|
{% for t in custom_tools %}
|
||||||
|
{#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||||
|
{%- set tname = t.tool_name -%}
|
||||||
|
{%- set tdesc = t.description -%}
|
||||||
|
{%- set modified_params = t.parameters.copy() -%}
|
||||||
|
{%- for key, value in modified_params.items() -%}
|
||||||
|
{%- if 'default' in value -%}
|
||||||
|
{%- set _ = value.pop('default', None) -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- set tparams = modified_params | tojson -%}
|
||||||
|
Use the function '{{ tname }}' to '{{ tdesc }}':
|
||||||
|
{"name": "{{tname}}", "description": "{{tdesc}}", "parameters": {{tparams}}}
|
||||||
|
|
||||||
|
{% endfor -%}
|
||||||
|
Think very carefully before calling functions.
|
||||||
|
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||||
|
|
||||||
|
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||||
|
|
||||||
|
Reminder:
|
||||||
|
- If looking for real time information use relevant functions before falling back to brave_search
|
||||||
|
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||||
|
- Required parameters MUST be specified
|
||||||
|
- Only call one function at a time
|
||||||
|
- Put the entire function call reply on one line
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.lstrip("\n"),
|
||||||
|
{"custom_tools": [t.model_dump() for t in custom_tools]},
|
||||||
|
)
|
||||||
|
|
||||||
|
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
ToolDefinition(
|
||||||
|
tool_name="trending_songs",
|
||||||
|
description="Returns the trending songs on a Music site",
|
||||||
|
parameters={
|
||||||
|
"n": ToolParamDefinition(
|
||||||
|
param_type="int",
|
||||||
|
description="The number of songs to return",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"genre": ToolParamDefinition(
|
||||||
|
param_type="str",
|
||||||
|
description="The genre of the songs to return",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
||||||
|
DEFAULT_PROMPT = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||||
|
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||||
|
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||||
|
also point it out. You should only return the function call in tools call sections.
|
||||||
|
|
||||||
|
{{ function_description }}
|
||||||
|
""".strip("\n")
|
||||||
|
)
|
||||||
|
|
||||||
|
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
|
||||||
|
system_prompt = system_prompt or self.DEFAULT_PROMPT
|
||||||
|
return PromptTemplate(
|
||||||
|
system_prompt,
|
||||||
|
{"function_description": self._gen_function_description(custom_tools)},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||||
|
You SHOULD NOT include any other text in the response.
|
||||||
|
|
||||||
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
[
|
||||||
|
{% for t in tools -%}
|
||||||
|
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||||
|
{%- set tname = t.tool_name -%}
|
||||||
|
{%- set tdesc = t.description -%}
|
||||||
|
{%- set tparams = t.parameters -%}
|
||||||
|
{%- set required_params = [] -%}
|
||||||
|
{%- for name, param in tparams.items() if param.required == true -%}
|
||||||
|
{%- set _ = required_params.append(name) -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{
|
||||||
|
"name": "{{tname}}",
|
||||||
|
"description": "{{tdesc}}",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": {{ required_params | tojson }},
|
||||||
|
"properties": {
|
||||||
|
{%- for name, param in tparams.items() %}
|
||||||
|
"{{name}}": {
|
||||||
|
"type": "{{param.param_type}}",
|
||||||
|
"description": "{{param.description}}"{% if param.default %},
|
||||||
|
"default": "{{param.default}}"{% endif %}
|
||||||
|
}{% if not loop.last %},{% endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}{% if not loop.last %},
|
||||||
|
{% endif -%}
|
||||||
|
{%- endfor %}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.strip("\n"),
|
||||||
|
{"tools": [t.model_dump() for t in custom_tools]},
|
||||||
|
).render()
|
||||||
|
|
||||||
|
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
ToolDefinition(
|
||||||
|
tool_name="get_weather",
|
||||||
|
description="Get weather info for places",
|
||||||
|
parameters={
|
||||||
|
"city": ToolParamDefinition(
|
||||||
|
param_type="string",
|
||||||
|
description="The name of the city to get the weather for",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"metric": ToolParamDefinition(
|
||||||
|
param_type="string",
|
||||||
|
description="The metric for weather. Options are: celsius, fahrenheit",
|
||||||
|
required=False,
|
||||||
|
default="celsius",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
]
|
|
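To make the generators above concrete, here is a hedged usage sketch; the class names are taken from this diff and the output shape is inferred from the templates.
# Sketch: render the system-prompt fragments defined above.
builtin_gen = BuiltinToolGenerator()
custom_gen = JsonCustomToolGenerator()

builtin_block = builtin_gen.gen(builtin_gen.data_examples()[0]).render()
custom_block = custom_gen.gen(custom_gen.data_examples()[0]).render()

# builtin_block starts with "Environment: ipython" and lists the non-code builtin tools;
# custom_block describes each custom tool in JSON for the model to call.
print(builtin_block)
print(custom_block)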
@ -0,0 +1,63 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .base import PromptTemplate, PromptTemplateGeneratorBase
|
||||||
|
|
||||||
|
|
||||||
|
class ToolResponseGenerator(PromptTemplateGeneratorBase):
|
||||||
|
def gen(
|
||||||
|
self,
|
||||||
|
status: str,
|
||||||
|
stdout: Optional[str] = None,
|
||||||
|
stderr: Optional[str] = None,
|
||||||
|
):
|
||||||
|
assert status in [
|
||||||
|
"success",
|
||||||
|
"failure",
|
||||||
|
], f"status must be 'success' or 'failure'; Got: {status}"
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
{% if status == "success" %}completed{% else %}failed{% endif %}
|
||||||
|
{%- if stdout %}
|
||||||
|
[stdout]{{ stdout }}[/stdout]
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if stderr %}
|
||||||
|
[stderr]{{ stderr }}[/stderr]
|
||||||
|
{%- endif -%}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.lstrip("\n"),
|
||||||
|
{
|
||||||
|
"status": status,
|
||||||
|
"stdout": stdout,
|
||||||
|
"stderr": stderr,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def data_examples(self):
|
||||||
|
return [
|
||||||
|
# success
|
||||||
|
{
|
||||||
|
"status": "success",
|
||||||
|
"stdout": '{"results":["something something"]}',
|
||||||
|
},
|
||||||
|
# failure
|
||||||
|
{
|
||||||
|
"status": "failure",
|
||||||
|
"stderr": "brave_search encounter an error: could not communicate with api.brave.com",
|
||||||
|
},
|
||||||
|
]
|
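For reference, a small sketch of how the ToolResponseGenerator above could be exercised (values here are illustrative):
# Sketch: render a successful and a failed tool response.
gen = ToolResponseGenerator()
ok = gen.gen(status="success", stdout='{"results": ["something something"]}')
err = gen.gen(status="failure", stderr="could not reach api.brave.com")
print(ok.render())   # "completed" followed by a [stdout]...[/stdout] block
print(err.render())  # "failed" followed by a [stderr]...[/stderr] block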
120  llama_stack/models/llama/llama3/template_data.py  Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
from llama_models.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
|
StopReason,
|
||||||
|
ToolCall,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .prompt_templates import (
|
||||||
|
BuiltinToolGenerator,
|
||||||
|
JsonCustomToolGenerator,
|
||||||
|
ToolResponseGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
|
INSTRUCTION = "You are a helpful assistant."
|
||||||
|
|
||||||
|
|
||||||
|
def system_message_builtin_tools_only():
|
||||||
|
return {
|
||||||
|
"builtin_tools": BuiltinToolGenerator().data_examples()[0],
|
||||||
|
"custom_tools": [],
|
||||||
|
"instruction": INSTRUCTION,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def system_message_builtin_code_only():
|
||||||
|
return {
|
||||||
|
"builtin_tools": BuiltinToolGenerator().data_examples()[1],
|
||||||
|
"custom_tools": [],
|
||||||
|
"instruction": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def system_message_custom_tools_only():
|
||||||
|
return {
|
||||||
|
"builtin_tools": [],
|
||||||
|
"custom_tools": JsonCustomToolGenerator().data_examples()[0],
|
||||||
|
"instruction": INSTRUCTION,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def system_message_builtin_and_custom_tools():
|
||||||
|
return {
|
||||||
|
"builtin_tools": BuiltinToolGenerator().data_examples()[0],
|
||||||
|
"custom_tools": JsonCustomToolGenerator().data_examples()[0],
|
||||||
|
"instruction": INSTRUCTION,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def system_default():
|
||||||
|
return {
|
||||||
|
"builtin_tools": [],
|
||||||
|
"custom_tools": [],
|
||||||
|
"instruction": INSTRUCTION,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def tool_success():
|
||||||
|
return ToolResponseGenerator().data_examples()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def tool_failure():
|
||||||
|
return ToolResponseGenerator().data_examples()[1]
|
||||||
|
|
||||||
|
|
||||||
|
def assistant_builtin_tool_call():
|
||||||
|
return {
|
||||||
|
"content": "",
|
||||||
|
"tool_call": ToolCall(
|
||||||
|
call_id="uuid",
|
||||||
|
tool_name=BuiltinTool.brave_search,
|
||||||
|
arguments={
|
||||||
|
"query": "Who won NBA in 2024?",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"stop_reason": StopReason.end_of_message,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def assistant_custom_tool_call():
|
||||||
|
return {
|
||||||
|
"content": "",
|
||||||
|
"tool_call": ToolCall(
|
||||||
|
call_id="uuid",
|
||||||
|
tool_name="trending_songs",
|
||||||
|
arguments={"country": "US", "n": 10},
|
||||||
|
),
|
||||||
|
"stop_reason": StopReason.end_of_turn,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def assistant_default():
|
||||||
|
return {
|
||||||
|
"content": "Hi, I am a helpful assistant. What can I help you with today?",
|
||||||
|
"tool_call": None,
|
||||||
|
"stop_reason": StopReason.end_of_turn,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def user_default():
|
||||||
|
return {"content": "Please tell me how to plan a trip to New York"}
|
||||||
|
|
||||||
|
|
||||||
|
def user_images():
|
||||||
|
return {"content": "<|image|><|image|>What do these images depict?"}
|
||||||
|
|
||||||
|
|
||||||
|
def user_interleaved_images():
|
||||||
|
return {"content": "<|image|>Describe the image in one sentence.<|image|>Write a haiku about these images"}
|
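These provider functions simply return keyword arguments; a brief sketch of how one plugs into the interface methods referenced earlier in this diff:
# Sketch: each provider returns kwargs for the corresponding interface method,
# e.g. interface.system_messages(**system_message_builtin_and_custom_tools()).
kwargs = system_message_builtin_and_custom_tools()
print(sorted(kwargs))  # ['builtin_tools', 'custom_tools', 'instruction']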
199  llama_stack/models/llama/llama3/test_system_prompts.py  Normal file
|
@ -0,0 +1,199 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
import unittest
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from .prompt_templates import (
|
||||||
|
BuiltinToolGenerator,
|
||||||
|
FunctionTagCustomToolGenerator,
|
||||||
|
JsonCustomToolGenerator,
|
||||||
|
PythonListCustomToolGenerator,
|
||||||
|
SystemDefaultGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTemplateTests(unittest.TestCase):
|
||||||
|
def check_generator_output(self, generator, expected_text):
|
||||||
|
example = generator.data_examples()[0]
|
||||||
|
|
||||||
|
pt = generator.gen(example)
|
||||||
|
text = pt.render()
|
||||||
|
# print(text) # debugging
|
||||||
|
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
||||||
|
|
||||||
|
def test_system_default(self):
|
||||||
|
generator = SystemDefaultGenerator()
|
||||||
|
today = datetime.now().strftime("%d %B %Y")
|
||||||
|
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||||
|
self.check_generator_output(generator, expected_text)
|
||||||
|
|
||||||
|
def test_system_builtin_only(self):
|
||||||
|
generator = BuiltinToolGenerator()
|
||||||
|
expected_text = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Environment: ipython
|
||||||
|
Tools: brave_search, wolfram_alpha
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||||
|
|
||||||
|
def test_system_custom_only(self):
|
||||||
|
self.maxDiff = None
|
||||||
|
generator = JsonCustomToolGenerator()
|
||||||
|
expected_text = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Answer the user's question by making use of the following functions if needed.
|
||||||
|
If none of the function can be used, please say so.
|
||||||
|
Here is a list of functions in JSON format:
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "trending_songs",
|
||||||
|
"description": "Returns the trending songs on a Music site",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": [
|
||||||
|
{
|
||||||
|
"n": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "The number of songs to return"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"genre": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "The genre of the songs to return"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"required": ["n"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Return function calls in JSON format.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||||
|
|
||||||
|
def test_system_custom_function_tag(self):
|
||||||
|
self.maxDiff = None
|
||||||
|
generator = FunctionTagCustomToolGenerator()
|
||||||
|
expected_text = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
You have access to the following functions:
|
||||||
|
|
||||||
|
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
|
||||||
|
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
|
||||||
|
|
||||||
|
Think very carefully before calling functions.
|
||||||
|
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||||
|
|
||||||
|
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||||
|
|
||||||
|
Reminder:
|
||||||
|
- If looking for real time information use relevant functions before falling back to brave_search
|
||||||
|
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||||
|
- Required parameters MUST be specified
|
||||||
|
- Only call one function at a time
|
||||||
|
- Put the entire function call reply on one line
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||||
|
|
||||||
|
def test_llama_3_2_system_zero_shot(self):
|
||||||
|
generator = PythonListCustomToolGenerator()
|
||||||
|
expected_text = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||||
|
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||||
|
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||||
|
also point it out. You should only return the function call in tools call sections.
|
||||||
|
|
||||||
|
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||||
|
You SHOULD NOT include any other text in the response.
|
||||||
|
|
||||||
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather info for places",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": ["city"],
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the city to get the weather for"
|
||||||
|
},
|
||||||
|
"metric": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||||
|
"default": "celsius"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||||
|
|
||||||
|
def test_llama_3_2_provided_system_prompt(self):
|
||||||
|
generator = PythonListCustomToolGenerator()
|
||||||
|
expected_text = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Overriding message.
|
||||||
|
|
||||||
|
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||||
|
You SHOULD NOT include any other text in the response.
|
||||||
|
|
||||||
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather info for places",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": ["city"],
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the city to get the weather for"
|
||||||
|
},
|
||||||
|
"metric": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||||
|
"default": "celsius"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]"""
|
||||||
|
)
|
||||||
|
user_system_prompt = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Overriding message.
|
||||||
|
|
||||||
|
{{ function_description }}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
example = generator.data_examples()[0]
|
||||||
|
|
||||||
|
pt = generator.gen(example, user_system_prompt)
|
||||||
|
text = pt.render()
|
||||||
|
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
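These are standard unittest cases; given the file path in this diff they can presumably be run with `python -m unittest llama_stack.models.llama.llama3.test_system_prompts -v`.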
12  llama_stack/models/llama/llama3_1/__init__.py  Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
259  llama_stack/models/llama/llama3_1/prompts.py  Normal file
|
@ -0,0 +1,259 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from llama_models.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
|
RawMessage,
|
||||||
|
StopReason,
|
||||||
|
ToolCall,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..prompt_format import (
|
||||||
|
# llama3_1_e2e_tool_call_dialog,
|
||||||
|
TextCompletionContent,
|
||||||
|
UseCase,
|
||||||
|
llama3_1_builtin_tool_call_dialog,
|
||||||
|
llama3_1_custom_tool_call_dialog,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def wolfram_alpha_response():
|
||||||
|
return textwrap.dedent(
|
||||||
|
"""
|
||||||
|
{
|
||||||
|
"queryresult": {
|
||||||
|
"success": true,
|
||||||
|
"inputstring": "100th decimal of pi",
|
||||||
|
"pods": [
|
||||||
|
{
|
||||||
|
"title": "Input interpretation",
|
||||||
|
"subpods": [
|
||||||
|
{
|
||||||
|
"title": "",
|
||||||
|
"plaintext": "100th digit | \u03c0"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Nearby digits",
|
||||||
|
"subpods": [
|
||||||
|
{
|
||||||
|
"title": "",
|
||||||
|
"plaintext": "...86208998628034825342117067982148086513282306647093..."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Result",
|
||||||
|
"primary": true,
|
||||||
|
"subpods": [
|
||||||
|
{
|
||||||
|
"title": "",
|
||||||
|
"plaintext": "7"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def usecases() -> List[UseCase | str]:
|
||||||
|
return [
|
||||||
|
textwrap.dedent(
|
||||||
|
"""
|
||||||
|
# Llama 3.1 - Prompt Formats
|
||||||
|
## Tokens
|
||||||
|
Here is a list of special tokens that are supported by Llama 3.1:
|
||||||
|
- `<|begin_of_text|>`: Specifies the start of the prompt
|
||||||
|
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
|
||||||
|
- `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
|
||||||
|
- `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and tool]
|
||||||
|
- `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
|
||||||
|
- `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
|
||||||
|
- at the end of a direct interaction between the model and the user
|
||||||
|
- at the end of multiple interactions between the model and any available tools
|
||||||
|
This token signals to the executor that the model has finished generating a response.
|
||||||
|
- `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
textwrap.dedent(
|
||||||
|
"""
|
||||||
|
There are 4 different roles that are supported by Llama 3.1
|
||||||
|
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
|
||||||
|
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
|
||||||
|
- `tool`: A new role introduced in Llama 3.1. This role is used to mark messages with the output of a tool call when sent back to the model from the executor. (The actual token used by the model for this role is "ipython".)
|
||||||
|
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Llama 3.1 Base Model",
|
||||||
|
description="Text completion for Llama 3.1 base model uses this format.",
|
||||||
|
dialogs=[TextCompletionContent(content="Color of sky is blue but sometimes can also be")],
|
||||||
|
notes="Note start special tag",
|
||||||
|
),
|
||||||
|
"## Llama 3.1 Instruct Model",
|
||||||
|
UseCase(
|
||||||
|
title="User and assistant conversation",
|
||||||
|
description="Here is a regular multi-turn user assistant conversation and how its formatted.",
|
||||||
|
dialogs=[
|
||||||
|
[
|
||||||
|
RawMessage(role="system", content="You are a helpful assistant"),
|
||||||
|
RawMessage(
|
||||||
|
role="user",
|
||||||
|
content="Answer who are you in the form of jeopardy?",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
],
|
||||||
|
notes="",
|
||||||
|
),
|
||||||
|
"## Tool Calling Formats",
|
||||||
|
textwrap.dedent(
|
||||||
|
"""
|
||||||
|
The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
|
||||||
|
- Brave Search: Tool call to perform web searches.
|
||||||
|
- Wolfram Alpha: Tool call to perform complex mathematical calculations.
|
||||||
|
- Code Interpreter: Enables the model to output python code.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Builtin Tool Calling",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Here is an example of a conversation using brave search
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[llama3_1_builtin_tool_call_dialog()],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
|
||||||
|
- The message body of the assistant response starts with a special tag <|python_tag|>
|
||||||
|
- As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
|
||||||
|
- The model tool call response is of the form `tool.call(query="...")` where tool is `brave_search` or `wolfram_alpha`
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Builtin Code Interpreter",
|
||||||
|
description="Here is an actual example of model responding with code",
|
||||||
|
dialogs=[
|
||||||
|
[
|
||||||
|
RawMessage(role="system", content="Environment: ipython"),
|
||||||
|
RawMessage(
|
||||||
|
role="user",
|
||||||
|
content="Write code to check if number is prime, use that to see if the number 7 is prime",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- Model starts with <|python_tag|> and continues writing python code that needs to be executed
|
||||||
|
- No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Built-in tools full interaction",
|
||||||
|
description="Here is a full interaction with the built-in tools including the tool response and the final assistant response.",
|
||||||
|
dialogs=[
|
||||||
|
[
|
||||||
|
RawMessage(
|
||||||
|
role="system",
|
||||||
|
content="Environment: ipython\nTools: brave_search, wolfram_alpha\n",
|
||||||
|
),
|
||||||
|
RawMessage(role="user", content="What is the 100th decimal of pi?"),
|
||||||
|
RawMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
stop_reason=StopReason.end_of_message,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
call_id="tool_call_id",
|
||||||
|
tool_name=BuiltinTool.wolfram_alpha,
|
||||||
|
arguments={"query": "100th decimal of pi"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
RawMessage(
|
||||||
|
role="tool",
|
||||||
|
content=wolfram_alpha_response(),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- Note the `<|python_tag|>` in the assistant response.
|
||||||
|
- Role is `tool` for the wolfram alpha response that is passed back to the model.
|
||||||
|
- Final message from assistant has <|eot_id|> tag.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"## Zero shot tool calling",
|
||||||
|
UseCase(
|
||||||
|
title="JSON based tool calling",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Llama models can now output custom tool calls from a single message to allow easier tool calling.
|
||||||
|
The following prompts provide an example of how custom tools can be called from the output of the model.
|
||||||
|
It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[llama3_1_custom_tool_call_dialog()],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- JSON format for providing tools needs name, description and parameters
|
||||||
|
- Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
|
||||||
|
- Instructions for tools added as a user message
|
||||||
|
- Only single tool calls are supported as of now
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
# FIXME: This is not working yet as expected
|
||||||
|
# UseCase(
|
||||||
|
# title="E2E tool call example",
|
||||||
|
# description=textwrap.dedent(
|
||||||
|
# """
|
||||||
|
# Here is an example showing the whole multi-step turn by taking custom tool outputs and passing back to the model.
|
||||||
|
# """
|
||||||
|
# ),
|
||||||
|
# dialogs=[
|
||||||
|
# llama3_1_e2e_tool_call_dialog(
|
||||||
|
# tool_prompt_format=ToolPromptFormat.function_tag
|
||||||
|
# )
|
||||||
|
# ],
|
||||||
|
# notes="",
|
||||||
|
# ),
|
||||||
|
"## Example of a user defined tool calling",
|
||||||
|
UseCase(
|
||||||
|
title="`<function>` based tool calling",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Here is an example of how you could also write custom instructions for the model to do zero shot tool calling.
|
||||||
|
In this example, we define a custom tool calling format using the `<function>` tag.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[llama3_1_custom_tool_call_dialog(ToolPromptFormat.function_tag)],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
|
||||||
|
- Instructions for tools added as a user message
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
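A hedged sketch of consuming the usecases() list above (attribute names are assumed from the constructor calls in this file):
# Sketch: strings are section prose, UseCase objects carry dialogs and notes.
for uc in usecases():
    if isinstance(uc, str):
        print(uc.strip().splitlines()[0])        # first line of the prose block
    else:
        print(uc.title, "-", len(uc.dialogs), "dialog(s)")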
12  llama_stack/models/llama/llama3_2/__init__.py  Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
235  llama_stack/models/llama/llama3_2/prompts_text.py  Normal file
|
@ -0,0 +1,235 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
import json
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
from llama_models.datatypes import (
|
||||||
|
RawMessage,
|
||||||
|
StopReason,
|
||||||
|
ToolCall,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..prompt_format import (
|
||||||
|
TextCompletionContent,
|
||||||
|
UseCase,
|
||||||
|
llama3_1_builtin_code_interpreter_dialog,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def user_tool_call():
|
||||||
|
content = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
|
||||||
|
Here is a list of functions in JSON format that you can invoke:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "get_user_info",
|
||||||
|
"description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": [
|
||||||
|
"user_id"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"user_id": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
|
||||||
|
},
|
||||||
|
"special": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Any special information or parameters that need to be considered while fetching user details.",
|
||||||
|
"default": "none"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]
|
||||||
|
|
||||||
|
NO other text MUST be included.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def system_tool_call():
|
||||||
|
content = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||||
|
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||||
|
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||||
|
also point it out. You should only return the function call in tools call sections.
|
||||||
|
|
||||||
|
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||||
|
You SHOULD NOT include any other text in the response.
|
||||||
|
|
||||||
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather info for places",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": [
|
||||||
|
"city"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the city to get the weather for"
|
||||||
|
},
|
||||||
|
"metric": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||||
|
"default": "celsius"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def usecases():
|
||||||
|
return [
|
||||||
|
UseCase(
|
||||||
|
title="User and assistant conversation",
|
||||||
|
description="Here is a regular multi-turn user assistant conversation and how its formatted.",
|
||||||
|
dialogs=[
|
||||||
|
[
|
||||||
|
RawMessage(role="system", content="You are a helpful assistant"),
|
||||||
|
RawMessage(role="user", content="Who are you?"),
|
||||||
|
]
|
||||||
|
],
|
||||||
|
notes="This format is unchanged from Llama3.1",
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Zero shot function calling",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
For Llama3.2 1B and 3B instruct models, we are introducing a new format for zero shot function calling.
|
||||||
|
This new format is designed to be more flexible and powerful than the previous format.
|
||||||
|
All available functions can be provided in the system message. A key difference is in the format of how the assistant responds with function calls.
|
||||||
|
It is pythonic in the form of `[func1(params_name=params_value, params_name2=params_value2...), func2(params)]` instead of the `json` or `<function>` tag that were defined in Llama3.1.
|
||||||
|
Here is an example for the same,
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[
|
||||||
|
# Zero shot tool calls as system message
|
||||||
|
[
|
||||||
|
RawMessage(role="system", content=system_tool_call()),
|
||||||
|
RawMessage(role="user", content="What is the weather in SF and Seattle?"),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- The output supports multiple tool calls natively
|
||||||
|
- JSON format for defining the functions in the system prompt is similar to Llama3.1
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Zero shot function calling with user message",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
While the default is to provide all function calls in a system message, in Llama3.2 text models you can also provide information for all the available tools in a user message.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[
|
||||||
|
# Zero shot tool call as user message
|
||||||
|
[
|
||||||
|
RawMessage(role="user", content=user_tool_call()),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- The tool call format for the model is the same whether your function calls are provided in the system or user message.
|
||||||
|
- While builtin tool calls end with a <|eom_id|>, notice the <|eot_id|> for zero shot tool calls.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Code Interpreter",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Code Interpreter continues to work in 3.2 text models similar to Llama 3.1 model family.
|
||||||
|
Here is an example,
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[llama3_1_builtin_code_interpreter_dialog()],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- Note `Environment: ipython` in the system prompt.
|
||||||
|
- Note that the response starts with `<|python_tag|>` and ends with `<|eom_id|>`
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Zero shot function calling E2E format",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Here is an example of the e2e cycle of tool calls with the model in a multi-step way.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[
|
||||||
|
[
|
||||||
|
RawMessage(role="system", content=system_tool_call()),
|
||||||
|
RawMessage(role="user", content="What is the weather in SF?"),
|
||||||
|
RawMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
stop_reason=StopReason.end_of_turn,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
call_id="cc",
|
||||||
|
tool_name="get_weather",
|
||||||
|
arguments={
|
||||||
|
"city": "San Francisco",
|
||||||
|
"metric": "celsius",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
RawMessage(
|
||||||
|
role="tool",
|
||||||
|
content=json.dumps("25 C"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- The output of the function call is provided back to the model as a tool response (in JSON format).
|
||||||
|
- Notice `<|start_header_id|>ipython<|end_header_id|>` as the header message preceding the tool response.
|
||||||
|
- The model finally summarizes the information from the tool response and returns the result to the user.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
tool_prompt_format=ToolPromptFormat.python_list,
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Prompt format for base models",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
For base models (Llama3.2-1B and Llama3.2-3B), the prompt format for a simple completion is as follows
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[
|
||||||
|
TextCompletionContent(content="The color of the sky is blue but sometimes it can also be"),
|
||||||
|
],
|
||||||
|
notes="Same as Llama3.1",
|
||||||
|
),
|
||||||
|
]
|
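To illustrate the pythonic zero-shot format described above, an assistant reply to the two-city weather question would look roughly like this (illustrative, not captured model output):
[get_weather(city="San Francisco", metric="celsius"), get_weather(city="Seattle", metric="celsius")]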
133  llama_stack/models/llama/llama3_2/prompts_vision.py  Normal file
|
@ -0,0 +1,133 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from llama_models.datatypes import (
|
||||||
|
RawMediaItem,
|
||||||
|
RawMessage,
|
||||||
|
RawTextItem,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..prompt_format import (
|
||||||
|
TextCompletionContent,
|
||||||
|
UseCase,
|
||||||
|
llama3_1_builtin_tool_call_dialog,
|
||||||
|
# llama3_1_builtin_tool_call_with_image_dialog,
|
||||||
|
llama3_2_user_assistant_conversation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def usecases():
|
||||||
|
this_dir = Path(__file__).parent.parent.resolve()
|
||||||
|
with open(this_dir / "scripts/resources/dog.jpg", "rb") as f:
|
||||||
|
img = f.read()
|
||||||
|
|
||||||
|
return [
|
||||||
|
llama3_2_user_assistant_conversation(),
|
||||||
|
UseCase(
|
||||||
|
title="User and assistant conversation with Images",
|
||||||
|
description="This example shows how to pass and image to the model as part of the messages.",
|
||||||
|
dialogs=[
|
||||||
|
[
|
||||||
|
RawMessage(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
RawMediaItem(data=img),
|
||||||
|
RawTextItem(text="Describe this image in two sentences"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- The `<|image|>` tag is used to indicate presence of the image
|
||||||
|
- The model isn't an early fusion model, so it doesn't actually translate an image into several tokens. Instead the cross-attention layers take input "on the side" from a vision encoder
|
||||||
|

|
||||||
|
- It's important to position the <|image|> tag appropriately in the prompt. The image will only attend to the subsequent text tokens
|
||||||
|
- The <|image|> tag is part of the user message body, implying that it should only come after the header `<|start_header_id|>{role}<|end_header_id|>` in the message body
|
||||||
|
- We recommend using a single image in one prompt
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Builtin and Zero Shot Tool Calling",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only.
|
||||||
|
Use `Environment: ipython` to enable tools.
|
||||||
|
Add `Tools: {{tool_name1}},{{tool_name2}}` for each of the builtin tools.
|
||||||
|
The same builtin tools as Llama3.1 are available,
|
||||||
|
- code_interpreter (for executing python code)
|
||||||
|
- brave_search (to search the web)
|
||||||
|
- wolfram_alpha (for querying wolfram alpha for mathematical questions)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
dialogs=[llama3_1_builtin_tool_call_dialog()],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- Note the `<|python_tag|>` before `brave_search` function call.
|
||||||
|
- The `<|eom_id|>` tag is used to indicate the end of the message.
|
||||||
|
- Similar to Llama3.1, code_interpreter is not explicitly mentioned but is enabled via `Environment: ipython`.
|
||||||
|
- Tool Calling does NOT work with images in the prompt as of now.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
# UseCase(
|
||||||
|
# title="Tool Calling for vision models",
|
||||||
|
# description=textwrap.dedent(
|
||||||
|
# """
|
||||||
|
# While Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only,
|
||||||
|
# they are not able to do tool calling when prompt contains image inputs (along with text).
|
||||||
|
# The recommended way would be to separate out the image understanding from the tool calling in successive prompts.
|
||||||
|
# Here is an example of how that could be done,
|
||||||
|
# """,
|
||||||
|
# ),
|
||||||
|
# dialogs=[llama3_1_builtin_tool_call_with_image_dialog()],
|
||||||
|
# notes=textwrap.dedent(
|
||||||
|
# """
|
||||||
|
# - Instead of a single prompt (image understanding + tool call), we split into two prompts to achieve the same result.
|
||||||
|
# """
|
||||||
|
# ),
|
||||||
|
# ),
|
||||||
|
UseCase(
|
||||||
|
title="Prompt format for base models",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), the prompt format for a simple completion is as follows
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[
|
||||||
|
TextCompletionContent(content="The color of the sky is blue but sometimes it can also be"),
|
||||||
|
],
|
||||||
|
notes="- Same as Llama3.1",
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Prompt format for base models with Image",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), here is an example of how the text completion format looks with an image,
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[
|
||||||
|
TextCompletionContent(
|
||||||
|
content=[
|
||||||
|
RawMediaItem(data=img),
|
||||||
|
RawTextItem(text="If I had to write a haiku for this one"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
notes="- Note the placement of the special tags <|begin_of_text|> and <|image|>",
|
||||||
|
),
|
||||||
|
]
|
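Note: the notes above describe where the `<|image|>` tag sits relative to the role header. As a rough illustration (a hand-written sketch based on those notes, not output from the diff above), a single-image user turn for the Llama3.2 vision instruct models renders approximately like this:

# Illustrative only: tag placement follows the notes above (header first, then <|image|>,
# then the text the image should attend to).
example_prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "<|image|>Describe this image in two sentences<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
print(example_prompt)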
258 llama_stack/models/llama/llama3_3/prompts.py Normal file
@@ -0,0 +1,258 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import textwrap
from typing import List

from llama_models.datatypes import (
    BuiltinTool,
    RawMessage,
    StopReason,
    ToolCall,
    ToolPromptFormat,
)

from ..prompt_format import (
    # llama3_1_e2e_tool_call_dialog,
    TextCompletionContent,
    UseCase,
    llama3_1_builtin_tool_call_dialog,
    llama3_1_custom_tool_call_dialog,
)


def wolfram_alpha_response():
    return textwrap.dedent(
        """
        {
            "queryresult": {
                "success": true,
                "inputstring": "100th decimal of pi",
                "pods": [
                    {
                        "title": "Input interpretation",
                        "subpods": [
                            {
                                "title": "",
                                "plaintext": "100th digit | \u03c0"
                            }
                        ]
                    },
                    {
                        "title": "Nearby digits",
                        "subpods": [
                            {
                                "title": "",
                                "plaintext": "...86208998628034825342117067982148086513282306647093..."
                            }
                        ]
                    },
                    {
                        "title": "Result",
                        "primary": true,
                        "subpods": [
                            {
                                "title": "",
                                "plaintext": "7"
                            }
                        ]
                    }
                ]
            }
        }
        """
    )


def usecases() -> List[UseCase | str]:
    return [
        textwrap.dedent(
            """
            # Llama 3.1 - Prompt Formats
            ## Tokens
            Here is a list of special tokens that are supported by Llama 3.1:
            - `<|begin_of_text|>`: Specifies the start of the prompt
            - `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
            - `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
            - `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and tool]
            - `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
            - `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
                - at the end of a direct interaction between the model and the user
                - at the end of multiple interactions between the model and any available tools
                This token signals to the executor that the model has finished generating a response.
            - `<|python_tag|>`: A special tag used in the model's response to signify a tool call.
            """
        ),
        textwrap.dedent(
            """
            There are 4 different roles that are supported by Llama 3.1
            - `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
            - `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
            - `tool`: A new role introduced in Llama 3.1. This role is used to mark messages with the output of a tool call when sent back to the model from the executor. (The actual token used by the model for this role is "ipython".)
            - `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts.
            """
        ),
        UseCase(
            title="Llama 3.1 Base Model",
            description="Text completion for Llama 3.1 base model uses this format.",
            dialogs=[TextCompletionContent(content="Color of sky is blue but sometimes can also be")],
            notes="Note the special start tag",
        ),
        "## Llama 3.1 Instruct Model",
        UseCase(
            title="User and assistant conversation",
            description="Here is a regular multi-turn user assistant conversation and how it's formatted.",
            dialogs=[
                [
                    RawMessage(role="system", content="You are a helpful assistant"),
                    RawMessage(
                        role="user",
                        content="Answer who are you in the form of jeopardy?",
                    ),
                ]
            ],
            notes="",
        ),
        "## Tool Calling Formats",
        textwrap.dedent(
            """
            The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
            - Brave Search: Tool call to perform web searches.
            - Wolfram Alpha: Tool call to perform complex mathematical calculations.
            - Code Interpreter: Enables the model to output python code.
            """
        ),
        UseCase(
            title="Builtin Tool Calling",
            description=textwrap.dedent(
                """
                Here is an example of a conversation using brave search
                """
            ),
            dialogs=[llama3_1_builtin_tool_call_dialog()],
            notes=textwrap.dedent(
                """
                - Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
                - The message body of the assistant response starts with a special tag <|python_tag|>
                - As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|>. The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
                - The model tool call response is of the form `tool.call(query="...")` where tool is `brave_search` or `wolfram_alpha`
                """
            ),
        ),
        UseCase(
            title="Builtin Code Interpreter",
            description="Here is an actual example of model responding with code",
            dialogs=[
                [
                    RawMessage(role="system", content="Environment: ipython"),
                    RawMessage(
                        role="user",
                        content="Write code to check if number is prime, use that to see if the number 7 is prime",
                    ),
                ],
            ],
            notes=textwrap.dedent(
                """
                - Model starts with <|python_tag|> and continues writing python code that needs to be executed
                - No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
                """
            ),
        ),
        UseCase(
            title="Built-in tools full interaction",
            description="Here is a full interaction with the built-in tools including the tool response and the final assistant response.",
            dialogs=[
                [
                    RawMessage(
                        role="system",
                        content="Environment: ipython\nTools: brave_search, wolfram_alpha\n",
                    ),
                    RawMessage(role="user", content="What is the 100th decimal of pi?"),
                    RawMessage(
                        content="",
                        stop_reason=StopReason.end_of_message,
                        tool_calls=[
                            ToolCall(
                                call_id="tool_call_id",
                                tool_name=BuiltinTool.wolfram_alpha,
                                arguments={"query": "100th decimal of pi"},
                            )
                        ],
                    ),
                    RawMessage(
                        role="tool",
                        content=wolfram_alpha_response(),
                    ),
                ],
            ],
            notes=textwrap.dedent(
                """
                - Note the `<|python_tag|>` in the assistant response.
                - Role is `tool` for the wolfram alpha response that is passed back to the model.
                - Final message from assistant has <|eot_id|> tag.
                """
            ),
        ),
        "## Zero shot tool calling",
        UseCase(
            title="JSON based tool calling",
            description=textwrap.dedent(
                """
                Llama models can now output custom tool calls from a single message to allow easier tool calling.
                The following prompts provide an example of how custom tools can be called from the output of the model.
                It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
                """
            ),
            dialogs=[llama3_1_custom_tool_call_dialog()],
            notes=textwrap.dedent(
                """
                - JSON format for providing tools needs name, description and parameters
                - Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
                - Instructions for tools added as a user message
                - Only single tool calls are supported as of now
                """
            ),
        ),
        # FIXME: This is not working yet as expected
        # UseCase(
        #     title="E2E tool call example",
        #     description=textwrap.dedent(
        #         """
        #         Here is an example showing the whole multi-step turn by taking custom tool outputs and passing back to the model.
        #         """
        #     ),
        #     dialogs=[
        #         llama3_1_e2e_tool_call_dialog(
        #             tool_prompt_format=ToolPromptFormat.function_tag
        #         )
        #     ],
        #     notes="",
        # ),
        "## Example of a user defined tool calling",
        UseCase(
            title="`<function>` based tool calling",
            description=textwrap.dedent(
                """
                Here is an example of how you could also write custom instructions for the model to do zero shot tool calling.
                In this example, we define a custom tool calling format using the `<function>` tag.
                """
            ),
            dialogs=[llama3_1_custom_tool_call_dialog(ToolPromptFormat.function_tag)],
            notes=textwrap.dedent(
                """
                - In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
                - Instructions for tools added as a user message
                """
            ),
        ),
    ]
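Note: the "Builtin Tool Calling" notes above state that the assistant reply begins with `<|python_tag|>`, takes the form `tool.call(query="...")`, and ends with `<|eom_id|>`. A minimal sketch of one such turn, assembled by hand from those pieces and the system prompt shown in this file (the literal strings are illustrative, not generated output):

# Hand-written sketch following the notes above; not actual model output.
system = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "Environment: ipython\nTools: brave_search, wolfram_alpha\n<|eot_id|>"
)
user = (
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "Search the web for the latest price of 1oz gold?<|eot_id|>"
)
# Response body starts with <|python_tag|>, calls tool.call(query="..."), and ends
# with <|eom_id|> because the model expects the tool result in a follow-up message.
assistant = (
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
    '<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>'
)
print(system + user + assistant)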
204 llama_stack/models/llama/prompt_format.py Normal file
@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import json
import textwrap
from pathlib import Path
from typing import List

from llama_models.datatypes import (
    RawContent,
    RawMediaItem,
    RawMessage,
    RawTextItem,
    StopReason,
    ToolCall,
    ToolPromptFormat,
)
from pydantic import BaseModel, Field

from .llama3.interface import LLama31Interface
from .llama3.template_data import (
    system_message_builtin_code_only,
    system_message_builtin_tools_only,
    system_message_custom_tools_only,
)


class TextCompletionContent(BaseModel):
    content: RawContent = ""


class UseCase(BaseModel):
    title: str = ""
    description: str = ""
    dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
    notes: str = ""
    tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json

    def md_format(self):
        section = textwrap.dedent(
            """
            ## {title}

            {description}

            {dialogs_text}
            {notes}

            """
        )
        return section.lstrip()

    def dialogs_to_text(self, generator) -> str:
        def _code_block(text):
            return f"```\n{text}\n```"

        text = ""
        for dialog in self.dialogs:
            if isinstance(dialog, str):
                text += dialog
                text += "\n\n"
                continue

            elif isinstance(dialog, TextCompletionContent):
                input_tokens, output_tokens = generator.text_completion_raw(
                    dialog.content,
                    max_gen_len=64,
                    temperature=0.1,
                    top_p=0.95,
                )
            else:
                input_tokens, output_tokens = generator.chat_completion_raw(
                    dialog,
                    max_gen_len=512,
                    temperature=0.0,
                    top_p=0.95,
                    tool_prompt_format=self.tool_prompt_format,
                )
            text += "##### Input Prompt Format\n"

            # FIXME: This is added to undo the hack in chat_formatter where
            # vision tokens are replaced with 128256.
            input_tokens = [generator.formatter.vision_token if t == 128256 else t for t in input_tokens]

            text += _code_block(generator.tokenizer.decode(input_tokens))
            # TODO: Figure out if "↵" needs to be added for newlines or end or some indication
            text += "\n\n"
            text += "##### Model Response Format\n"
            text += _code_block(generator.tokenizer.decode(output_tokens))
            text += "\n\n"

        return text

    def to_text(self, generator):
        section = self.md_format()
        dialogs_text = self.dialogs_to_text(generator)
        notes = f"##### Notes\n{self.notes}" if self.notes else ""
        section = section.format(
            title=self.title,
            description=self.description,
            dialogs_text=dialogs_text,
            notes=notes,
        )
        return section


def llama3_1_builtin_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
    interface = LLama31Interface(tool_prompt_format)

    messages = interface.system_messages(**system_message_builtin_tools_only())
    messages += interface.user_message(content="Search the web for the latest price of 1oz gold?")

    return messages


def llama3_1_builtin_code_interpreter_dialog(tool_prompt_format=ToolPromptFormat.json):
    interface = LLama31Interface(tool_prompt_format)

    messages = interface.system_messages(**system_message_builtin_code_only())
    messages += interface.user_message(
        content="Write code to check if number is prime. Use it to verify if number 7 is prime"
    )

    return messages


def llama3_1_builtin_tool_call_with_image_dialog(
    tool_prompt_format=ToolPromptFormat.json,
):
    this_dir = Path(__file__).parent
    with open(this_dir / "llama3/dog.jpg", "rb") as f:
        img = f.read()

    interface = LLama31Interface(tool_prompt_format)

    messages = interface.system_messages(**system_message_builtin_tools_only())
    messages += interface.user_message(content=[RawMediaItem(data=img), RawTextItem(text="What is this dog breed?")])
    messages += interface.assistant_response_messages(
        "Based on the description of the dog in the image, it appears to be a small breed dog, possibly a terrier mix",
        StopReason.end_of_turn,
    )
    messages += interface.user_message("Search the web for some food recommendations for the identified breed")
    return messages


def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
    interface = LLama31Interface(tool_prompt_format)

    messages = interface.system_messages(**system_message_custom_tools_only())
    messages += interface.user_message(content="Use tools to get latest trending songs")
    return messages


def llama3_1_e2e_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
    tool_response = json.dumps(["great song1", "awesome song2", "cool song3"])
    interface = LLama31Interface(tool_prompt_format)

    messages = interface.system_messages(**system_message_custom_tools_only())
    messages += interface.user_message(content="Use tools to get latest trending songs")
    messages.append(
        RawMessage(
            role="assistant",
            content="",
            stop_reason=StopReason.end_of_message,
            tool_calls=[
                ToolCall(
                    call_id="call_id",
                    tool_name="trending_songs",
                    arguments={"n": "10", "genre": "latest"},
                )
            ],
        ),
    )
    messages.append(
        RawMessage(
            role="assistant",
            content=tool_response,
        )
    )
    return messages


def llama3_2_user_assistant_conversation():
    return UseCase(
        title="User and assistant conversation",
        description="Here is a regular multi-turn user assistant conversation and how it's formatted.",
        dialogs=[
            [
                RawMessage(role="system", content="You are a helpful assistant"),
                RawMessage(role="user", content="Who are you?"),
            ]
        ],
        notes="This format is unchanged from Llama3.1",
    )
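Note: `UseCase.to_text` above only needs a `generator` object exposing `text_completion_raw`, `chat_completion_raw`, `tokenizer.decode`, and `formatter.vision_token`. A rough usage sketch follows; `render_report` is a hypothetical helper (not part of this diff) and the generator's construction is omitted:

def render_report(generator, cases) -> str:
    # `cases` is the output of a usecases() function (a list of UseCase | str);
    # `generator` is assumed to satisfy the interface used by dialogs_to_text above.
    parts = []
    for case in cases:
        if isinstance(case, str):
            parts.append(case)  # plain markdown sections pass through unchanged
        else:
            parts.append(case.to_text(generator))  # UseCase -> markdown with token dumps
    return "\n".join(parts)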
1000 llama_stack/models/llama/sku_list.py Normal file
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff.