diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000..429abb494 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,5 @@ +# Each line is a file pattern followed by one or more owners. + +# These owners will be the default owners for everything in +# the repo. Unless a later match takes precedence, +* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 9a84eb3b9..d063dc129 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -295,13 +295,18 @@ As you can see above, each “distribution” details the “providers” it is Let's imagine you are working with a 8B-Instruct model. The following command will build a package (in the form of a Conda environment) _and_ configure it. As part of the configuration, you will be asked for some inputs (model_id, max_seq_len, etc.) Since we are working with a 8B model, we will name our build `8b-instruct` to help us remember the config. ``` -llama stack build local --name 8b-instruct +llama stack build ``` -Once it runs successfully , you should see some outputs in the form: +Once it runs, you will be prompted to enter build name and optional arguments, and should see some outputs in the form: ``` -$ llama stack build local --name 8b-instruct +$ llama stack build +Enter value for name (required): 8b-instruct +Enter value for distribution (default: local) (required): local +Enter value for api_providers (optional): +Enter value for image_type (default: conda) (required): + .... .... Successfully installed cfgv-3.4.0 distlib-0.3.8 identify-2.6.0 libcst-1.4.0 llama_toolchain-0.0.2 moreorless-0.4.0 nodeenv-1.9.1 pre-commit-3.8.0 stdlibs-2024.5.15 toml-0.10.2 tomlkit-0.13.0 trailrunner-1.4.0 ufmt-2.7.0 usort-1.0.8 virtualenv-20.26.3 @@ -312,17 +317,41 @@ Successfully setup conda environment. Configuring build... ... YAML configuration has been written to ~/.llama/builds/local/conda/8b-instruct.yaml +Target `8b-test` built with configuration at /home/xiyan/.llama/builds/local/conda/8b-test.yaml +Build spec configuration saved at /home/xiyan/.llama/distributions/local/conda/8b-test-build.yaml ``` + +You can re-build package based on build config +``` +$ cat ~/.llama/distributions/local/conda/8b-instruct-build.yaml +name: 8b-instruct +distribution: local +api_providers: null +image_type: conda + +$ llama stack build --config ~/.llama/distributions/local/conda/8b-instruct-build.yaml + +Successfully setup conda environment. Configuring build... + +... +... + +YAML configuration has been written to ~/.llama/builds/local/conda/8b-instruct.yaml +Target `8b-instruct` built with configuration at ~/.llama/builds/local/conda/8b-instruct.yaml +Build spec configuration saved at ~/.llama/distributions/local/conda/8b-instruct-build.yaml +``` + ### Step 3.3: Configure a distribution You can re-configure this distribution by running: ``` -llama stack configure local --name 8b-instruct +llama stack configure ~/.llama/builds/local/conda/8b-instruct.yaml ``` Here is an example run of how the CLI will guide you to fill the configuration + ``` -$ llama stack configure local --name 8b-instruct +$ llama stack configure ~/.llama/builds/local/conda/8b-instruct.yaml Configuring API: inference (meta-reference) Enter value for model (required): Meta-Llama3.1-8B-Instruct @@ -363,12 +392,12 @@ Now let’s start Llama Stack Distribution Server. You need the YAML configuration file which was written out at the end by the `llama stack build` step. ``` -llama stack run local --name 8b-instruct --port 5000 +llama stack run ~/.llama/builds/local/conda/8b-instruct.yaml --port 5000 ``` You should see the Stack server start and print the APIs that it is supporting, ``` -$ llama stack run local --name 8b-instruct --port 5000 +$ llama stack run ~/.llama/builds/local/conda/8b-instruct.yaml --port 5000 > initializing model parallel with size 1 > initializing ddp with size 1 diff --git a/llama_toolchain/agentic_system/api/api.py b/llama_toolchain/agentic_system/api/api.py index 68ec980e6..b8be54861 100644 --- a/llama_toolchain/agentic_system/api/api.py +++ b/llama_toolchain/agentic_system/api/api.py @@ -41,11 +41,19 @@ class ToolDefinitionCommon(BaseModel): output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list) +class SearchEngineType(Enum): + bing = "bing" + brave = "brave" + + @json_schema_type -class BraveSearchToolDefinition(ToolDefinitionCommon): +class SearchToolDefinition(ToolDefinitionCommon): + # NOTE: brave_search is just a placeholder since model always uses + # brave_search as tool call name type: Literal[AgenticSystemTool.brave_search.value] = ( AgenticSystemTool.brave_search.value ) + engine: SearchEngineType = SearchEngineType.brave remote_execution: Optional[RestAPIExecutionConfig] = None @@ -163,7 +171,7 @@ class MemoryToolDefinition(ToolDefinitionCommon): AgenticSystemToolDefinition = Annotated[ Union[ - BraveSearchToolDefinition, + SearchToolDefinition, WolframAlphaToolDefinition, PhotogenToolDefinition, CodeInterpreterToolDefinition, diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index fadb78182..b47e402f0 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -134,7 +134,7 @@ async def run_main(host: str, port: int): api = AgenticSystemClient(f"http://{host}:{port}") tool_definitions = [ - BraveSearchToolDefinition(), + SearchToolDefinition(engine=SearchEngineType.bing), WolframAlphaToolDefinition(), CodeInterpreterToolDefinition(), ] diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py index 4d38e0032..36c3d19e8 100644 --- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py +++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py @@ -710,7 +710,7 @@ class ChatAgent(ShieldRunnerMixin): def _get_tools(self) -> List[ToolDefinition]: ret = [] for t in self.agent_config.tools: - if isinstance(t, BraveSearchToolDefinition): + if isinstance(t, SearchToolDefinition): ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search)) elif isinstance(t, WolframAlphaToolDefinition): ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha)) diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py index 09fbfdde5..9caa3a75b 100644 --- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py +++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py @@ -15,9 +15,9 @@ from llama_toolchain.memory.api import Memory from llama_toolchain.safety.api import Safety from llama_toolchain.agentic_system.api import * # noqa: F403 from llama_toolchain.tools.builtin import ( - BraveSearchTool, CodeInterpreterTool, PhotogenTool, + SearchTool, WolframAlphaTool, ) from llama_toolchain.tools.safety import with_safety @@ -62,17 +62,19 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem): if not key: raise ValueError("Wolfram API key not defined in config") tool = WolframAlphaTool(key) - elif isinstance(tool_defn, BraveSearchToolDefinition): - key = self.config.brave_search_api_key + elif isinstance(tool_defn, SearchToolDefinition): + key = None + if tool_defn.engine == SearchEngineType.brave: + key = self.config.brave_search_api_key + elif tool_defn.engine == SearchEngineType.bing: + key = self.config.bing_search_api_key if not key: - raise ValueError("Brave API key not defined in config") - tool = BraveSearchTool(key) + raise ValueError("API key not defined in config") + tool = SearchTool(tool_defn.engine, key) elif isinstance(tool_defn, CodeInterpreterToolDefinition): tool = CodeInterpreterTool() elif isinstance(tool_defn, PhotogenToolDefinition): - tool = PhotogenTool( - dump_dir=tempfile.mkdtemp(), - ) + tool = PhotogenTool(dump_dir=tempfile.mkdtemp()) else: continue diff --git a/llama_toolchain/agentic_system/meta_reference/config.py b/llama_toolchain/agentic_system/meta_reference/config.py index 367ab17a5..f1a92f2e7 100644 --- a/llama_toolchain/agentic_system/meta_reference/config.py +++ b/llama_toolchain/agentic_system/meta_reference/config.py @@ -11,4 +11,5 @@ from pydantic import BaseModel class MetaReferenceImplConfig(BaseModel): brave_search_api_key: Optional[str] = None + bing_search_api_key: Optional[str] = None wolfram_api_key: Optional[str] = None diff --git a/llama_toolchain/cli/stack/build.py b/llama_toolchain/cli/stack/build.py index 22bd4071f..d2d7df6d0 100644 --- a/llama_toolchain/cli/stack/build.py +++ b/llama_toolchain/cli/stack/build.py @@ -8,6 +8,7 @@ import argparse from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.core.datatypes import * # noqa: F403 +import yaml def parse_api_provider_tuples( @@ -47,55 +48,45 @@ class StackBuild(Subcommand): self.parser.set_defaults(func=self._run_stack_build_command) def _add_arguments(self): - from llama_toolchain.core.distribution_registry import available_distribution_specs - from llama_toolchain.core.package import ( - BuildType, + from llama_toolchain.core.distribution_registry import ( + available_distribution_specs, ) + from llama_toolchain.core.package import ImageType allowed_ids = [d.distribution_type for d in available_distribution_specs()] self.parser.add_argument( - "distribution", + "--config", type=str, - help="Distribution to build (either \"adhoc\" OR one of: {})".format(allowed_ids), - ) - self.parser.add_argument( - "api_providers", - nargs='?', - help="Comma separated list of (api=provider) tuples", + help="Path to a config file to use for the build", ) - self.parser.add_argument( - "--name", - type=str, - help="Name of the build target (image, conda env)", - required=True, - ) - self.parser.add_argument( - "--type", - type=str, - default="conda_env", - choices=[v.value for v in BuildType], - ) + def _run_stack_build_command_from_build_config( + self, build_config: BuildConfig + ) -> None: + import json + import os - def _run_stack_build_command(self, args: argparse.Namespace) -> None: + from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR + from llama_toolchain.common.serialize import EnumEncoder from llama_toolchain.core.distribution_registry import resolve_distribution_spec - from llama_toolchain.core.package import ( - ApiInput, - BuildType, - build_package, - ) + from llama_toolchain.core.package import ApiInput, build_package, ImageType + from termcolor import cprint api_inputs = [] - if args.distribution == "adhoc": - if not args.api_providers: - self.parser.error("You must specify API providers with (api=provider,...) for building an adhoc distribution") + if build_config.distribution == "adhoc": + if not build_config.api_providers: + self.parser.error( + "You must specify API providers with (api=provider,...) for building an adhoc distribution" + ) return - parsed = parse_api_provider_tuples(args.api_providers, self.parser) + parsed = parse_api_provider_tuples(build_config.api_providers, self.parser) for api, provider_spec in parsed.items(): for dep in provider_spec.api_dependencies: if dep not in parsed: - self.parser.error(f"API {api} needs dependency {dep} provided also") + self.parser.error( + f"API {api} needs dependency {dep} provided also" + ) return api_inputs.append( @@ -106,13 +97,17 @@ class StackBuild(Subcommand): ) docker_image = None else: - if args.api_providers: - self.parser.error("You cannot specify API providers for pre-registered distributions") + if build_config.api_providers: + self.parser.error( + "You cannot specify API providers for pre-registered distributions" + ) return - dist = resolve_distribution_spec(args.distribution) + dist = resolve_distribution_spec(build_config.distribution) if dist is None: - self.parser.error(f"Could not find distribution {args.distribution}") + self.parser.error( + f"Could not find distribution {build_config.distribution}" + ) return for api, provider_type in dist.providers.items(): @@ -126,8 +121,41 @@ class StackBuild(Subcommand): build_package( api_inputs, - build_type=BuildType(args.type), - name=args.name, - distribution_type=args.distribution, + image_type=ImageType(build_config.image_type), + name=build_config.name, + distribution_type=build_config.distribution, docker_image=docker_image, ) + + # save build.yaml spec for building same distribution again + build_dir = ( + DISTRIBS_BASE_DIR / build_config.distribution / build_config.image_type + ) + os.makedirs(build_dir, exist_ok=True) + build_file_path = build_dir / f"{build_config.name}-build.yaml" + + with open(build_file_path, "w") as f: + to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder)) + f.write(yaml.dump(to_write, sort_keys=False)) + + cprint( + f"Build spec configuration saved at {str(build_file_path)}", + color="green", + ) + + def _run_stack_build_command(self, args: argparse.Namespace) -> None: + from llama_toolchain.common.prompt_for_config import prompt_for_config + from llama_toolchain.core.dynamic import instantiate_class_type + + if args.config: + with open(args.config, "r") as f: + try: + build_config = BuildConfig(**yaml.safe_load(f)) + except Exception as e: + self.parser.error(f"Could not parse config file {args.config}: {e}") + return + self._run_stack_build_command_from_build_config(build_config) + return + + build_config = prompt_for_config(BuildConfig, None) + self._run_stack_build_command_from_build_config(build_config) diff --git a/llama_toolchain/cli/stack/configure.py b/llama_toolchain/cli/stack/configure.py index 658380f4d..2edeae7bc 100644 --- a/llama_toolchain/cli/stack/configure.py +++ b/llama_toolchain/cli/stack/configure.py @@ -9,10 +9,10 @@ import json from pathlib import Path import yaml -from termcolor import cprint from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR +from termcolor import cprint from llama_toolchain.core.datatypes import * # noqa: F403 @@ -34,38 +34,19 @@ class StackConfigure(Subcommand): from llama_toolchain.core.distribution_registry import ( available_distribution_specs, ) - from llama_toolchain.core.package import BuildType + from llama_toolchain.core.package import ImageType allowed_ids = [d.distribution_type for d in available_distribution_specs()] self.parser.add_argument( - "distribution", + "config", type=str, - help='Distribution ("adhoc" or one of: {})'.format(allowed_ids), - ) - self.parser.add_argument( - "--name", - type=str, - help="Name of the build", - required=True, - ) - self.parser.add_argument( - "--type", - type=str, - default="conda_env", - choices=[v.value for v in BuildType], + help="Path to the package config file (e.g. ~/.llama/builds///.yaml)", ) def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None: - from llama_toolchain.core.package import BuildType + from llama_toolchain.core.package import ImageType - build_type = BuildType(args.type) - name = args.name - config_file = ( - BUILDS_BASE_DIR - / args.distribution - / build_type.descriptor() - / f"{name}.yaml" - ) + config_file = Path(args.config) if not config_file.exists(): self.parser.error( f"Could not find {config_file}. Please run `llama stack build` first" diff --git a/llama_toolchain/cli/stack/run.py b/llama_toolchain/cli/stack/run.py index 1568ed820..d040cb1f7 100644 --- a/llama_toolchain/cli/stack/run.py +++ b/llama_toolchain/cli/stack/run.py @@ -29,24 +29,12 @@ class StackRun(Subcommand): self.parser.set_defaults(func=self._run_stack_run_cmd) def _add_arguments(self): - from llama_toolchain.core.package import BuildType + from llama_toolchain.core.package import ImageType self.parser.add_argument( - "distribution", + "config", type=str, - help="Distribution whose build you want to start", - ) - self.parser.add_argument( - "--name", - type=str, - help="Name of the build you want to start", - required=True, - ) - self.parser.add_argument( - "--type", - type=str, - default="conda_env", - choices=[v.value for v in BuildType], + help="Path to config file to use for the run", ) self.parser.add_argument( "--port", @@ -63,12 +51,13 @@ class StackRun(Subcommand): def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: from llama_toolchain.common.exec import run_with_pty - from llama_toolchain.core.package import BuildType + from llama_toolchain.core.package import ImageType - build_type = BuildType(args.type) - build_dir = BUILDS_BASE_DIR / args.distribution / build_type.descriptor() - path = build_dir / f"{args.name}.yaml" + if not args.config: + self.parser.error("Must specify a config file to run") + return + path = args.config config_file = Path(path) if not config_file.exists(): diff --git a/llama_toolchain/common/deployment_types.py b/llama_toolchain/common/deployment_types.py index 8b67eff0d..af05aaae4 100644 --- a/llama_toolchain/common/deployment_types.py +++ b/llama_toolchain/common/deployment_types.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Dict, Optional +from typing import Any, Dict, Optional from llama_models.llama3.api.datatypes import URL @@ -26,6 +26,6 @@ class RestAPIMethod(Enum): class RestAPIExecutionConfig(BaseModel): url: URL method: RestAPIMethod - params: Optional[Dict[str, str]] = None - headers: Optional[Dict[str, str]] = None - body: Optional[Dict[str, str]] = None + params: Optional[Dict[str, Any]] = None + headers: Optional[Dict[str, Any]] = None + body: Optional[Dict[str, Any]] = None diff --git a/llama_toolchain/core/build_conda_env.sh b/llama_toolchain/core/build_conda_env.sh index e5b1ca539..664f8ad02 100755 --- a/llama_toolchain/core/build_conda_env.sh +++ b/llama_toolchain/core/build_conda_env.sh @@ -19,7 +19,7 @@ fi set -euo pipefail -if [ "$#" -ne 3 ]; then +if [ "$#" -ne 4 ]; then echo "Usage: $0 " >&2 echo "Example: $0 mybuild 'numpy pandas scipy'" >&2 exit 1 @@ -28,7 +28,8 @@ fi distribution_type="$1" build_name="$2" env_name="llamastack-$build_name" -pip_dependencies="$3" +config_file="$3" +pip_dependencies="$4" # Define color codes RED='\033[0;31m' @@ -117,4 +118,4 @@ ensure_conda_env_python310 "$env_name" "$pip_dependencies" printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}\n" -$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_type --name "$build_name" --type conda_env +$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $config_file diff --git a/llama_toolchain/core/build_container.sh b/llama_toolchain/core/build_container.sh index e5349cd08..3d2fc466f 100755 --- a/llama_toolchain/core/build_container.sh +++ b/llama_toolchain/core/build_container.sh @@ -4,7 +4,7 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-} LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} -if [ "$#" -ne 4 ]; then +if [ "$#" -ne 5 ]; then echo "Usage: $0 echo "Example: $0 distribution_type my-fastapi-app python:3.9-slim 'fastapi uvicorn' exit 1 @@ -14,7 +14,8 @@ distribution_type=$1 build_name="$2" image_name="llamastack-$build_name" docker_base=$3 -pip_dependencies=$4 +config_file=$4 +pip_dependencies=$5 # Define color codes RED='\033[0;31m' @@ -110,4 +111,4 @@ set +x printf "${GREEN}Succesfully setup Podman image. Configuring build...${NC}" echo "You can run it with: podman run -p 8000:8000 $image_name" -$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_type --name "$build_name" --type container +$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $config_file diff --git a/llama_toolchain/core/datatypes.py b/llama_toolchain/core/datatypes.py index 0d946027c..cd9fc9dcf 100644 --- a/llama_toolchain/core/datatypes.py +++ b/llama_toolchain/core/datatypes.py @@ -189,3 +189,19 @@ Provider configurations for each of the APIs provided by this package. This incl the dependencies of these providers as well. """, ) + + +@json_schema_type +class BuildConfig(BaseModel): + name: str + distribution: str = Field( + default="local", description="Type of distribution to build (adhoc | {})" + ) + api_providers: Optional[str] = Field( + default_factory=list, + description="List of API provider names to build", + ) + image_type: str = Field( + default="conda", + description="Type of package to build (conda | container)", + ) diff --git a/llama_toolchain/core/package.py b/llama_toolchain/core/package.py index ab4346a71..6e9eb048d 100644 --- a/llama_toolchain/core/package.py +++ b/llama_toolchain/core/package.py @@ -12,24 +12,21 @@ from typing import List, Optional import pkg_resources import yaml -from pydantic import BaseModel - -from termcolor import cprint from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR from llama_toolchain.common.exec import run_with_pty from llama_toolchain.common.serialize import EnumEncoder +from pydantic import BaseModel + +from termcolor import cprint from llama_toolchain.core.datatypes import * # noqa: F403 from llama_toolchain.core.distribution import api_providers, SERVER_DEPENDENCIES -class BuildType(Enum): - container = "container" - conda_env = "conda_env" - - def descriptor(self) -> str: - return "docker" if self == self.container else "conda" +class ImageType(Enum): + docker = "docker" + conda = "conda" class Dependencies(BaseModel): @@ -44,7 +41,7 @@ class ApiInput(BaseModel): def build_package( api_inputs: List[ApiInput], - build_type: BuildType, + image_type: ImageType, name: str, distribution_type: Optional[str] = None, docker_image: Optional[str] = None, @@ -52,7 +49,7 @@ def build_package( if not distribution_type: distribution_type = "adhoc" - build_dir = BUILDS_BASE_DIR / distribution_type / build_type.descriptor() + build_dir = BUILDS_BASE_DIR / distribution_type / image_type.value os.makedirs(build_dir, exist_ok=True) package_name = name.replace("::", "-") @@ -106,14 +103,14 @@ def build_package( ) c.distribution_type = distribution_type - c.docker_image = package_name if build_type == BuildType.container else None - c.conda_env = package_name if build_type == BuildType.conda_env else None + c.docker_image = package_name if image_type == ImageType.docker else None + c.conda_env = package_name if image_type == ImageType.conda else None with open(package_file, "w") as f: to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder)) f.write(yaml.dump(to_write, sort_keys=False)) - if build_type == BuildType.container: + if image_type == ImageType.docker: script = pkg_resources.resource_filename( "llama_toolchain", "core/build_container.sh" ) @@ -122,6 +119,7 @@ def build_package( distribution_type, package_name, package_deps.docker_image, + str(package_file), " ".join(package_deps.pip_packages), ] else: @@ -132,6 +130,7 @@ def build_package( script, distribution_type, package_name, + str(package_file), " ".join(package_deps.pip_packages), ] diff --git a/llama_toolchain/stack.py b/llama_toolchain/stack.py index 875bc5802..1e2976ab3 100644 --- a/llama_toolchain/stack.py +++ b/llama_toolchain/stack.py @@ -15,6 +15,7 @@ from llama_toolchain.telemetry.api import * # noqa: F403 from llama_toolchain.post_training.api import * # noqa: F403 from llama_toolchain.reward_scoring.api import * # noqa: F403 from llama_toolchain.synthetic_data_generation.api import * # noqa: F403 +from llama_toolchain.safety.api import * # noqa: F403 class LlamaStack( @@ -22,6 +23,7 @@ class LlamaStack( BatchInference, AgenticSystem, RewardScoring, + Safety, SyntheticDataGeneration, Datasets, Telemetry, diff --git a/llama_toolchain/tools/builtin.py b/llama_toolchain/tools/builtin.py index 3a53e2e26..56fda3723 100644 --- a/llama_toolchain/tools/builtin.py +++ b/llama_toolchain/tools/builtin.py @@ -83,14 +83,72 @@ class PhotogenTool(SingleMessageBuiltinTool): raise NotImplementedError() -class BraveSearchTool(SingleMessageBuiltinTool): - def __init__(self, api_key: str) -> None: +class SearchTool(SingleMessageBuiltinTool): + def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None: self.api_key = api_key + if engine == SearchEngineType.bing: + self.engine = BingSearch(api_key, **kwargs) + elif engine == SearchEngineType.brave: + self.engine = BraveSearch(api_key, **kwargs) + else: + raise ValueError(f"Unknown search engine: {engine}") def get_name(self) -> str: return BuiltinTool.brave_search.value async def run_impl(self, query: str) -> str: + return await self.engine.search(query) + + +class BingSearch: + def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None: + self.api_key = api_key + self.top_k = top_k + + async def search(self, query: str) -> str: + url = "https://api.bing.microsoft.com/v7.0/search" + headers = { + "Ocp-Apim-Subscription-Key": self.api_key, + } + params = { + "count": self.top_k, + "textDecorations": True, + "textFormat": "HTML", + "q": query, + } + + response = requests.get(url=url, params=params, headers=headers) + response.raise_for_status() + clean = self._clean_response(response.json()) + return json.dumps(clean) + + def _clean_response(self, search_response): + clean_response = [] + query = search_response["queryContext"]["originalQuery"] + if "webPages" in search_response: + pages = search_response["webPages"]["value"] + for p in pages: + selected_keys = {"name", "url", "snippet"} + clean_response.append( + {k: v for k, v in p.items() if k in selected_keys} + ) + if "news" in search_response: + clean_news = [] + news = search_response["news"]["value"] + for n in news: + selected_keys = {"name", "url", "description"} + clean_news.append({k: v for k, v in n.items() if k in selected_keys}) + + clean_response.append(clean_news) + + return {"query": query, "top_k": clean_response} + + +class BraveSearch: + def __init__(self, api_key: str) -> None: + self.api_key = api_key + + async def search(self, query: str) -> str: url = "https://api.search.brave.com/res/v1/web/search" headers = { "X-Subscription-Token": self.api_key, diff --git a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html index 38efd4ca3..211290ce1 100644 --- a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html +++ b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-07 15:23:29.488676" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-10 16:42:15.870336" }, "servers": [ { @@ -51,7 +51,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/BatchChatCompletionRequest" + "$ref": "#/components/schemas/BatchChatCompletionRequestWrapper" } } }, @@ -81,7 +81,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/BatchCompletionRequest" + "$ref": "#/components/schemas/BatchCompletionRequestWrapper" } } }, @@ -141,7 +141,7 @@ "200": { "description": "SSE-stream of these events.", "content": { - "application/json": { + "text/event-stream": { "schema": { "$ref": "#/components/schemas/ChatCompletionResponseStreamChunk" } @@ -157,7 +157,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ChatCompletionRequest" + "$ref": "#/components/schemas/ChatCompletionRequestWrapper" } } }, @@ -187,7 +187,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CompletionRequest" + "$ref": "#/components/schemas/CompletionRequestWrapper" } } }, @@ -277,7 +277,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/AgenticSystemTurnCreateRequest" + "$ref": "#/components/schemas/AgenticSystemTurnCreateRequestWrapper" } } }, @@ -300,7 +300,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateDatasetRequest" + "$ref": "#/components/schemas/CreateDatasetRequestWrapper" } } }, @@ -330,7 +330,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateExperimentRequest" + "$ref": "#/components/schemas/CreateExperimentRequestWrapper" } } }, @@ -390,7 +390,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateRunRequest" + "$ref": "#/components/schemas/CreateRunRequestWrapper" } } }, @@ -572,7 +572,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/EvaluateQuestionAnsweringRequest" + "$ref": "#/components/schemas/EvaluateQuestionAnsweringRequestWrapper" } } }, @@ -602,7 +602,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/EvaluateSummarizationRequest" + "$ref": "#/components/schemas/EvaluateSummarizationRequestWrapper" } } }, @@ -632,7 +632,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/EvaluateTextGenerationRequest" + "$ref": "#/components/schemas/EvaluateTextGenerationRequestWrapper" } } }, @@ -1024,7 +1024,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/LogSearchRequest" + "$ref": "#/components/schemas/LogSearchRequestWrapper" } } }, @@ -1312,7 +1312,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/LogMessagesRequest" + "$ref": "#/components/schemas/LogMessagesRequestWrapper" } } }, @@ -1335,7 +1335,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/LogMetricsRequest" + "$ref": "#/components/schemas/LogMetricsRequestWrapper" } } }, @@ -1365,7 +1365,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/PostTrainingRLHFRequest" + "$ref": "#/components/schemas/PostTrainingRLHFRequestWrapper" } } }, @@ -1425,7 +1425,37 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/RewardScoringRequest" + "$ref": "#/components/schemas/RewardScoringRequestWrapper" + } + } + }, + "required": true + } + } + }, + "/safety/run_shields": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunShieldResponse" + } + } + } + } + }, + "tags": [ + "Safety" + ], + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunShieldRequestWrapper" } } }, @@ -1455,7 +1485,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/PostTrainingSFTRequest" + "$ref": "#/components/schemas/PostTrainingSFTRequestWrapper" } } }, @@ -1485,7 +1515,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/SyntheticDataGenerationRequest" + "$ref": "#/components/schemas/SyntheticDataGenerationRequestWrapper" } } }, @@ -1538,7 +1568,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/UpdateExperimentRequest" + "$ref": "#/components/schemas/UpdateExperimentRequestWrapper" } } }, @@ -1568,7 +1598,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/UpdateRunRequest" + "$ref": "#/components/schemas/UpdateRunRequestWrapper" } } }, @@ -1598,7 +1628,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/UploadArtifactRequest" + "$ref": "#/components/schemas/UploadArtifactRequestWrapper" } } }, @@ -2020,6 +2050,18 @@ "content" ] }, + "BatchChatCompletionRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/BatchChatCompletionRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "BatchChatCompletionResponse": { "type": "object", "properties": { @@ -2076,6 +2118,18 @@ "content_batch" ] }, + "BatchCompletionRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/BatchCompletionRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "BatchCompletionResponse": { "type": "object", "properties": { @@ -2174,6 +2228,18 @@ "messages" ] }, + "ChatCompletionRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/ChatCompletionRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "ChatCompletionResponseEvent": { "type": "object", "properties": { @@ -2316,6 +2382,18 @@ "content" ] }, + "CompletionRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/CompletionRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "CompletionResponseStreamChunk": { "type": "object", "properties": { @@ -2361,7 +2439,7 @@ "items": { "oneOf": [ { - "$ref": "#/components/schemas/BraveSearchToolDefinition" + "$ref": "#/components/schemas/SearchToolDefinition" }, { "$ref": "#/components/schemas/WolframAlphaToolDefinition" @@ -2576,34 +2654,6 @@ "instructions" ] }, - "BraveSearchToolDefinition": { - "type": "object", - "properties": { - "input_shields": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ShieldDefinition" - } - }, - "output_shields": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ShieldDefinition" - } - }, - "type": { - "type": "string", - "const": "brave_search" - }, - "remote_execution": { - "$ref": "#/components/schemas/RestAPIExecutionConfig" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, "BuiltinShield": { "type": "string", "enum": [ @@ -2737,19 +2787,76 @@ "params": { "type": "object", "additionalProperties": { - "type": "string" + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] } }, "headers": { "type": "object", "additionalProperties": { - "type": "string" + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] } }, "body": { "type": "object", "additionalProperties": { - "type": "string" + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] } } }, @@ -2768,6 +2875,42 @@ "DELETE" ] }, + "SearchToolDefinition": { + "type": "object", + "properties": { + "input_shields": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ShieldDefinition" + } + }, + "output_shields": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ShieldDefinition" + } + }, + "type": { + "type": "string", + "const": "brave_search" + }, + "engine": { + "type": "string", + "enum": [ + "bing", + "brave" + ] + }, + "remote_execution": { + "$ref": "#/components/schemas/RestAPIExecutionConfig" + } + }, + "additionalProperties": false, + "required": [ + "type", + "engine" + ] + }, "ShieldDefinition": { "type": "object", "properties": { @@ -2911,7 +3054,7 @@ "items": { "oneOf": [ { - "$ref": "#/components/schemas/BraveSearchToolDefinition" + "$ref": "#/components/schemas/SearchToolDefinition" }, { "$ref": "#/components/schemas/WolframAlphaToolDefinition" @@ -3181,6 +3324,18 @@ "mime_type" ] }, + "AgenticSystemTurnCreateRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/AgenticSystemTurnCreateRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "AgenticSystemTurnResponseEvent": { "type": "object", "properties": { @@ -3752,6 +3907,18 @@ "json" ] }, + "CreateDatasetRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/CreateDatasetRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "CreateExperimentRequest": { "type": "object", "properties": { @@ -3789,6 +3956,18 @@ "name" ] }, + "CreateExperimentRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/CreateExperimentRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "Experiment": { "type": "object", "properties": { @@ -4061,6 +4240,18 @@ "experiment_id" ] }, + "CreateRunRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/CreateRunRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "Run": { "type": "object", "properties": { @@ -4273,6 +4464,18 @@ ], "title": "Request to evaluate question answering." }, + "EvaluateQuestionAnsweringRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/EvaluateQuestionAnsweringRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "EvaluationJob": { "type": "object", "properties": { @@ -4321,6 +4524,18 @@ ], "title": "Request to evaluate summarization." }, + "EvaluateSummarizationRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/EvaluateSummarizationRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "EvaluateTextGenerationRequest": { "type": "object", "properties": { @@ -4358,6 +4573,18 @@ ], "title": "Request to evaluate text generation." }, + "EvaluateTextGenerationRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/EvaluateTextGenerationRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "GetAgenticSystemSessionRequest": { "type": "object", "properties": { @@ -4643,6 +4870,18 @@ "query" ] }, + "LogSearchRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/LogSearchRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "Log": { "type": "object", "properties": { @@ -4902,6 +5141,18 @@ "logs" ] }, + "LogMessagesRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/LogMessagesRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "LogMetricsRequest": { "type": "object", "properties": { @@ -4921,6 +5172,18 @@ "metrics" ] }, + "LogMetricsRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/LogMetricsRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "DPOAlignmentConfig": { "type": "object", "properties": { @@ -5109,6 +5372,18 @@ "fsdp_cpu_offload" ] }, + "PostTrainingRLHFRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/PostTrainingRLHFRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "QueryDocumentsRequest": { "type": "object", "properties": { @@ -5277,6 +5552,18 @@ ], "title": "Request to score a reward function. A list of prompts and a list of responses per prompt." }, + "RewardScoringRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/RewardScoringRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "RewardScoringResponse": { "type": "object", "properties": { @@ -5357,6 +5644,68 @@ "score" ] }, + "RunShieldRequest": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "items": { + "oneOf": [ + { + "$ref": "#/components/schemas/UserMessage" + }, + { + "$ref": "#/components/schemas/SystemMessage" + }, + { + "$ref": "#/components/schemas/ToolResponseMessage" + }, + { + "$ref": "#/components/schemas/CompletionMessage" + } + ] + } + }, + "shields": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ShieldDefinition" + } + } + }, + "additionalProperties": false, + "required": [ + "messages", + "shields" + ] + }, + "RunShieldRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/RunShieldRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, + "RunShieldResponse": { + "type": "object", + "properties": { + "responses": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ShieldResponse" + } + } + }, + "additionalProperties": false, + "required": [ + "responses" + ] + }, "DoraFinetuningConfig": { "type": "object", "properties": { @@ -5562,6 +5911,18 @@ "alpha" ] }, + "PostTrainingSFTRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/PostTrainingSFTRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "SyntheticDataGenerationRequest": { "type": "object", "properties": { @@ -5607,6 +5968,18 @@ ], "title": "Request to generate synthetic data. A small batch of prompts and a filtering function" }, + "SyntheticDataGenerationRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/SyntheticDataGenerationRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "SyntheticDataGenerationResponse": { "type": "object", "properties": { @@ -5707,6 +6080,18 @@ "experiment_id" ] }, + "UpdateExperimentRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/UpdateExperimentRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "UpdateRunRequest": { "type": "object", "properties": { @@ -5751,6 +6136,18 @@ "run_id" ] }, + "UpdateRunRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/UpdateRunRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "UploadArtifactRequest": { "type": "object", "properties": { @@ -5800,6 +6197,18 @@ "artifact_type", "content" ] + }, + "UploadArtifactRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/UploadArtifactRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] } }, "responses": {} @@ -5810,26 +6219,17 @@ } ], "tags": [ - { - "name": "Telemetry" - }, - { - "name": "Evaluations" - }, - { - "name": "AgenticSystem" - }, - { - "name": "Inference" - }, { "name": "BatchInference" }, { - "name": "PostTraining" + "name": "Safety" }, { - "name": "Datasets" + "name": "Telemetry" + }, + { + "name": "RewardScoring" }, { "name": "Memory" @@ -5838,7 +6238,19 @@ "name": "SyntheticDataGeneration" }, { - "name": "RewardScoring" + "name": "Inference" + }, + { + "name": "Evaluations" + }, + { + "name": "PostTraining" + }, + { + "name": "Datasets" + }, + { + "name": "AgenticSystem" }, { "name": "BatchChatCompletionRequest", @@ -5896,6 +6308,10 @@ "name": "UserMessage", "description": "" }, + { + "name": "BatchChatCompletionRequestWrapper", + "description": "" + }, { "name": "BatchChatCompletionResponse", "description": "" @@ -5904,6 +6320,10 @@ "name": "BatchCompletionRequest", "description": "" }, + { + "name": "BatchCompletionRequestWrapper", + "description": "" + }, { "name": "BatchCompletionResponse", "description": "" @@ -5920,6 +6340,10 @@ "name": "ChatCompletionRequest", "description": "" }, + { + "name": "ChatCompletionRequestWrapper", + "description": "" + }, { "name": "ChatCompletionResponseEvent", "description": "Chat completion response event.\n\n" @@ -5948,6 +6372,10 @@ "name": "CompletionRequest", "description": "" }, + { + "name": "CompletionRequestWrapper", + "description": "" + }, { "name": "CompletionResponseStreamChunk", "description": "streamed completion response.\n\n" @@ -5956,10 +6384,6 @@ "name": "AgentConfig", "description": "" }, - { - "name": "BraveSearchToolDefinition", - "description": "" - }, { "name": "BuiltinShield", "description": "" @@ -5988,6 +6412,10 @@ "name": "RestAPIMethod", "description": "" }, + { + "name": "SearchToolDefinition", + "description": "" + }, { "name": "ShieldDefinition", "description": "" @@ -6024,6 +6452,10 @@ "name": "Attachment", "description": "" }, + { + "name": "AgenticSystemTurnCreateRequestWrapper", + "description": "" + }, { "name": "AgenticSystemTurnResponseEvent", "description": "Streamed agent execution response.\n\n" @@ -6092,10 +6524,18 @@ "name": "TrainEvalDatasetColumnType", "description": "" }, + { + "name": "CreateDatasetRequestWrapper", + "description": "" + }, { "name": "CreateExperimentRequest", "description": "" }, + { + "name": "CreateExperimentRequestWrapper", + "description": "" + }, { "name": "Experiment", "description": "" @@ -6116,6 +6556,10 @@ "name": "CreateRunRequest", "description": "" }, + { + "name": "CreateRunRequestWrapper", + "description": "" + }, { "name": "Run", "description": "" @@ -6156,6 +6600,10 @@ "name": "EvaluateQuestionAnsweringRequest", "description": "Request to evaluate question answering.\n\n" }, + { + "name": "EvaluateQuestionAnsweringRequestWrapper", + "description": "" + }, { "name": "EvaluationJob", "description": "" @@ -6164,10 +6612,18 @@ "name": "EvaluateSummarizationRequest", "description": "Request to evaluate summarization.\n\n" }, + { + "name": "EvaluateSummarizationRequestWrapper", + "description": "" + }, { "name": "EvaluateTextGenerationRequest", "description": "Request to evaluate text generation.\n\n" }, + { + "name": "EvaluateTextGenerationRequestWrapper", + "description": "" + }, { "name": "GetAgenticSystemSessionRequest", "description": "" @@ -6212,6 +6668,10 @@ "name": "LogSearchRequest", "description": "" }, + { + "name": "LogSearchRequestWrapper", + "description": "" + }, { "name": "Log", "description": "" @@ -6252,10 +6712,18 @@ "name": "LogMessagesRequest", "description": "" }, + { + "name": "LogMessagesRequestWrapper", + "description": "" + }, { "name": "LogMetricsRequest", "description": "" }, + { + "name": "LogMetricsRequestWrapper", + "description": "" + }, { "name": "DPOAlignmentConfig", "description": "" @@ -6276,6 +6744,10 @@ "name": "TrainingConfig", "description": "" }, + { + "name": "PostTrainingRLHFRequestWrapper", + "description": "" + }, { "name": "QueryDocumentsRequest", "description": "" @@ -6292,6 +6764,10 @@ "name": "RewardScoringRequest", "description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n" }, + { + "name": "RewardScoringRequestWrapper", + "description": "" + }, { "name": "RewardScoringResponse", "description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n" @@ -6304,6 +6780,18 @@ "name": "ScoredMessage", "description": "" }, + { + "name": "RunShieldRequest", + "description": "" + }, + { + "name": "RunShieldRequestWrapper", + "description": "" + }, + { + "name": "RunShieldResponse", + "description": "" + }, { "name": "DoraFinetuningConfig", "description": "" @@ -6324,10 +6812,18 @@ "name": "QLoraFinetuningConfig", "description": "" }, + { + "name": "PostTrainingSFTRequestWrapper", + "description": "" + }, { "name": "SyntheticDataGenerationRequest", "description": "Request to generate synthetic data. A small batch of prompts and a filtering function\n\n" }, + { + "name": "SyntheticDataGenerationRequestWrapper", + "description": "" + }, { "name": "SyntheticDataGenerationResponse", "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.\n\n" @@ -6340,13 +6836,25 @@ "name": "UpdateExperimentRequest", "description": "" }, + { + "name": "UpdateExperimentRequestWrapper", + "description": "" + }, { "name": "UpdateRunRequest", "description": "" }, + { + "name": "UpdateRunRequestWrapper", + "description": "" + }, { "name": "UploadArtifactRequest", "description": "" + }, + { + "name": "UploadArtifactRequestWrapper", + "description": "" } ], "x-tagGroups": [ @@ -6361,6 +6869,7 @@ "Memory", "PostTraining", "RewardScoring", + "Safety", "SyntheticDataGeneration", "Telemetry" ] @@ -6373,6 +6882,7 @@ "AgenticSystemSessionCreateResponse", "AgenticSystemStepResponse", "AgenticSystemTurnCreateRequest", + "AgenticSystemTurnCreateRequestWrapper", "AgenticSystemTurnResponseEvent", "AgenticSystemTurnResponseStepCompletePayload", "AgenticSystemTurnResponseStepProgressPayload", @@ -6384,15 +6894,17 @@ "ArtifactType", "Attachment", "BatchChatCompletionRequest", + "BatchChatCompletionRequestWrapper", "BatchChatCompletionResponse", "BatchCompletionRequest", + "BatchCompletionRequestWrapper", "BatchCompletionResponse", - "BraveSearchToolDefinition", "BuiltinShield", "BuiltinTool", "CancelEvaluationJobRequest", "CancelTrainingJobRequest", "ChatCompletionRequest", + "ChatCompletionRequestWrapper", "ChatCompletionResponseEvent", "ChatCompletionResponseEventType", "ChatCompletionResponseStreamChunk", @@ -6400,13 +6912,17 @@ "CodeInterpreterToolDefinition", "CompletionMessage", "CompletionRequest", + "CompletionRequestWrapper", "CompletionResponseStreamChunk", "CreateAgenticSystemRequest", "CreateAgenticSystemSessionRequest", "CreateDatasetRequest", + "CreateDatasetRequestWrapper", "CreateExperimentRequest", + "CreateExperimentRequestWrapper", "CreateMemoryBankRequest", "CreateRunRequest", + "CreateRunRequestWrapper", "DPOAlignmentConfig", "DeleteAgenticSystemRequest", "DeleteAgenticSystemSessionRequest", @@ -6418,8 +6934,11 @@ "EmbeddingsRequest", "EmbeddingsResponse", "EvaluateQuestionAnsweringRequest", + "EvaluateQuestionAnsweringRequestWrapper", "EvaluateSummarizationRequest", + "EvaluateSummarizationRequestWrapper", "EvaluateTextGenerationRequest", + "EvaluateTextGenerationRequestWrapper", "EvaluationJob", "EvaluationJobArtifactsResponse", "EvaluationJobLogStream", @@ -6435,8 +6954,11 @@ "ListArtifactsRequest", "Log", "LogMessagesRequest", + "LogMessagesRequestWrapper", "LogMetricsRequest", + "LogMetricsRequestWrapper", "LogSearchRequest", + "LogSearchRequestWrapper", "LoraFinetuningConfig", "MemoryBank", "MemoryBankDocument", @@ -6451,7 +6973,9 @@ "PostTrainingJobStatus", "PostTrainingJobStatusResponse", "PostTrainingRLHFRequest", + "PostTrainingRLHFRequestWrapper", "PostTrainingSFTRequest", + "PostTrainingSFTRequestWrapper", "QLoraFinetuningConfig", "QueryDocumentsRequest", "QueryDocumentsResponse", @@ -6459,18 +6983,24 @@ "RestAPIExecutionConfig", "RestAPIMethod", "RewardScoringRequest", + "RewardScoringRequestWrapper", "RewardScoringResponse", "Run", + "RunShieldRequest", + "RunShieldRequestWrapper", + "RunShieldResponse", "SamplingParams", "SamplingStrategy", "ScoredDialogGenerations", "ScoredMessage", + "SearchToolDefinition", "Session", "ShieldCallStep", "ShieldDefinition", "ShieldResponse", "StopReason", "SyntheticDataGenerationRequest", + "SyntheticDataGenerationRequestWrapper", "SyntheticDataGenerationResponse", "SystemMessage", "TokenLogProbs", @@ -6491,8 +7021,11 @@ "URL", "UpdateDocumentsRequest", "UpdateExperimentRequest", + "UpdateExperimentRequestWrapper", "UpdateRunRequest", + "UpdateRunRequestWrapper", "UploadArtifactRequest", + "UploadArtifactRequestWrapper", "UserMessage", "WolframAlphaToolDefinition" ] diff --git a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml index 11f092549..322645813 100644 --- a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml +++ b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml @@ -25,7 +25,7 @@ components: tools: items: oneOf: - - $ref: '#/components/schemas/BraveSearchToolDefinition' + - $ref: '#/components/schemas/SearchToolDefinition' - $ref: '#/components/schemas/WolframAlphaToolDefinition' - $ref: '#/components/schemas/PhotogenToolDefinition' - $ref: '#/components/schemas/CodeInterpreterToolDefinition' @@ -218,7 +218,7 @@ components: tools: items: oneOf: - - $ref: '#/components/schemas/BraveSearchToolDefinition' + - $ref: '#/components/schemas/SearchToolDefinition' - $ref: '#/components/schemas/WolframAlphaToolDefinition' - $ref: '#/components/schemas/PhotogenToolDefinition' - $ref: '#/components/schemas/CodeInterpreterToolDefinition' @@ -346,6 +346,14 @@ components: - session_id - messages type: object + AgenticSystemTurnCreateRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/AgenticSystemTurnCreateRequest' + required: + - request + type: object AgenticSystemTurnResponseEvent: additionalProperties: false properties: @@ -566,6 +574,14 @@ components: - model - messages_batch type: object + BatchChatCompletionRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/BatchChatCompletionRequest' + required: + - request + type: object BatchChatCompletionResponse: additionalProperties: false properties: @@ -601,6 +617,14 @@ components: - model - content_batch type: object + BatchCompletionRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/BatchCompletionRequest' + required: + - request + type: object BatchCompletionResponse: additionalProperties: false properties: @@ -611,25 +635,6 @@ components: required: - completion_message_batch type: object - BraveSearchToolDefinition: - additionalProperties: false - properties: - input_shields: - items: - $ref: '#/components/schemas/ShieldDefinition' - type: array - output_shields: - items: - $ref: '#/components/schemas/ShieldDefinition' - type: array - remote_execution: - $ref: '#/components/schemas/RestAPIExecutionConfig' - type: - const: brave_search - type: string - required: - - type - type: object BuiltinShield: enum: - llama_guard @@ -696,6 +701,14 @@ components: - model - messages type: object + ChatCompletionRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/ChatCompletionRequest' + required: + - request + type: object ChatCompletionResponseEvent: additionalProperties: false properties: @@ -804,6 +817,14 @@ components: - model - content type: object + CompletionRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/CompletionRequest' + required: + - request + type: object CompletionResponseStreamChunk: additionalProperties: false properties: @@ -850,6 +871,14 @@ components: - dataset title: Request to create a dataset. type: object + CreateDatasetRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/CreateDatasetRequest' + required: + - request + type: object CreateExperimentRequest: additionalProperties: false properties: @@ -868,6 +897,14 @@ components: required: - name type: object + CreateExperimentRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/CreateExperimentRequest' + required: + - request + type: object CreateMemoryBankRequest: additionalProperties: false properties: @@ -939,6 +976,14 @@ components: required: - experiment_id type: object + CreateRunRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/CreateRunRequest' + required: + - request + type: object DPOAlignmentConfig: additionalProperties: false properties: @@ -1104,6 +1149,14 @@ components: - metrics title: Request to evaluate question answering. type: object + EvaluateQuestionAnsweringRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/EvaluateQuestionAnsweringRequest' + required: + - request + type: object EvaluateSummarizationRequest: additionalProperties: false properties: @@ -1130,6 +1183,14 @@ components: - metrics title: Request to evaluate summarization. type: object + EvaluateSummarizationRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/EvaluateSummarizationRequest' + required: + - request + type: object EvaluateTextGenerationRequest: additionalProperties: false properties: @@ -1157,6 +1218,14 @@ components: - metrics title: Request to evaluate text generation. type: object + EvaluateTextGenerationRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/EvaluateTextGenerationRequest' + required: + - request + type: object EvaluationJob: additionalProperties: false properties: @@ -1370,6 +1439,14 @@ components: required: - logs type: object + LogMessagesRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/LogMessagesRequest' + required: + - request + type: object LogMetricsRequest: additionalProperties: false properties: @@ -1383,6 +1460,14 @@ components: - run_id - metrics type: object + LogMetricsRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/LogMetricsRequest' + required: + - request + type: object LogSearchRequest: additionalProperties: false properties: @@ -1401,6 +1486,14 @@ components: required: - query type: object + LogSearchRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/LogSearchRequest' + required: + - request + type: object LoraFinetuningConfig: additionalProperties: false properties: @@ -1741,6 +1834,14 @@ components: - logger_config title: Request to finetune a model. type: object + PostTrainingRLHFRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/PostTrainingRLHFRequest' + required: + - request + type: object PostTrainingSFTRequest: additionalProperties: false properties: @@ -1796,6 +1897,14 @@ components: - logger_config title: Request to finetune a model. type: object + PostTrainingSFTRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/PostTrainingSFTRequest' + required: + - request + type: object QLoraFinetuningConfig: additionalProperties: false properties: @@ -1883,17 +1992,35 @@ components: properties: body: additionalProperties: - type: string + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object type: object headers: additionalProperties: - type: string + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object type: object method: $ref: '#/components/schemas/RestAPIMethod' params: additionalProperties: - type: string + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object type: object url: $ref: '#/components/schemas/URL' @@ -1923,6 +2050,14 @@ components: title: Request to score a reward function. A list of prompts and a list of responses per prompt. type: object + RewardScoringRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/RewardScoringRequest' + required: + - request + type: object RewardScoringResponse: additionalProperties: false properties: @@ -1967,6 +2102,43 @@ components: - started_at - metadata type: object + RunShieldRequest: + additionalProperties: false + properties: + messages: + items: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/SystemMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + - $ref: '#/components/schemas/CompletionMessage' + type: array + shields: + items: + $ref: '#/components/schemas/ShieldDefinition' + type: array + required: + - messages + - shields + type: object + RunShieldRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/RunShieldRequest' + required: + - request + type: object + RunShieldResponse: + additionalProperties: false + properties: + responses: + items: + $ref: '#/components/schemas/ShieldResponse' + type: array + required: + - responses + type: object SamplingParams: additionalProperties: false properties: @@ -2025,6 +2197,31 @@ components: - message - score type: object + SearchToolDefinition: + additionalProperties: false + properties: + engine: + enum: + - bing + - brave + type: string + input_shields: + items: + $ref: '#/components/schemas/ShieldDefinition' + type: array + output_shields: + items: + $ref: '#/components/schemas/ShieldDefinition' + type: array + remote_execution: + $ref: '#/components/schemas/RestAPIExecutionConfig' + type: + const: brave_search + type: string + required: + - type + - engine + type: object Session: additionalProperties: false properties: @@ -2145,6 +2342,14 @@ components: title: Request to generate synthetic data. A small batch of prompts and a filtering function type: object + SyntheticDataGenerationRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/SyntheticDataGenerationRequest' + required: + - request + type: object SyntheticDataGenerationResponse: additionalProperties: false properties: @@ -2513,6 +2718,14 @@ components: required: - experiment_id type: object + UpdateExperimentRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/UpdateExperimentRequest' + required: + - request + type: object UpdateRunRequest: additionalProperties: false properties: @@ -2536,6 +2749,14 @@ components: required: - run_id type: object + UpdateRunRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/UpdateRunRequest' + required: + - request + type: object UploadArtifactRequest: additionalProperties: false properties: @@ -2564,6 +2785,14 @@ components: - artifact_type - content type: object + UploadArtifactRequestWrapper: + additionalProperties: false + properties: + request: + $ref: '#/components/schemas/UploadArtifactRequest' + required: + - request + type: object UserMessage: additionalProperties: false properties: @@ -2609,7 +2838,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-09-07 15:23:29.488676" + \ draft and subject to change.\n Generated at 2024-09-10 16:42:15.870336" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -2741,7 +2970,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/AgenticSystemTurnCreateRequest' + $ref: '#/components/schemas/AgenticSystemTurnCreateRequestWrapper' required: true responses: '200': @@ -2798,7 +3027,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/BatchChatCompletionRequest' + $ref: '#/components/schemas/BatchChatCompletionRequestWrapper' required: true responses: '200': @@ -2816,7 +3045,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/BatchCompletionRequest' + $ref: '#/components/schemas/BatchCompletionRequestWrapper' required: true responses: '200': @@ -2834,7 +3063,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/CreateDatasetRequest' + $ref: '#/components/schemas/CreateDatasetRequestWrapper' required: true responses: '200': @@ -2956,7 +3185,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/EvaluateQuestionAnsweringRequest' + $ref: '#/components/schemas/EvaluateQuestionAnsweringRequestWrapper' required: true responses: '200': @@ -2974,7 +3203,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/EvaluateSummarizationRequest' + $ref: '#/components/schemas/EvaluateSummarizationRequestWrapper' required: true responses: '200': @@ -2992,7 +3221,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/EvaluateTextGenerationRequest' + $ref: '#/components/schemas/EvaluateTextGenerationRequestWrapper' required: true responses: '200': @@ -3028,7 +3257,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/UploadArtifactRequest' + $ref: '#/components/schemas/UploadArtifactRequestWrapper' required: true responses: '200': @@ -3046,7 +3275,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/CreateExperimentRequest' + $ref: '#/components/schemas/CreateExperimentRequestWrapper' required: true responses: '200': @@ -3064,7 +3293,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/CreateRunRequest' + $ref: '#/components/schemas/CreateRunRequestWrapper' required: true responses: '200': @@ -3111,7 +3340,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/UpdateExperimentRequest' + $ref: '#/components/schemas/UpdateExperimentRequestWrapper' required: true responses: '200': @@ -3129,12 +3358,12 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/ChatCompletionRequest' + $ref: '#/components/schemas/ChatCompletionRequestWrapper' required: true responses: '200': content: - application/json: + text/event-stream: schema: $ref: '#/components/schemas/ChatCompletionResponseStreamChunk' description: SSE-stream of these events. @@ -3147,7 +3376,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/CompletionRequest' + $ref: '#/components/schemas/CompletionRequestWrapper' required: true responses: '200': @@ -3183,7 +3412,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/LogSearchRequest' + $ref: '#/components/schemas/LogSearchRequestWrapper' required: true responses: '200': @@ -3201,7 +3430,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/LogMessagesRequest' + $ref: '#/components/schemas/LogMessagesRequestWrapper' required: true responses: '200': @@ -3442,7 +3671,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/PostTrainingRLHFRequest' + $ref: '#/components/schemas/PostTrainingRLHFRequestWrapper' required: true responses: '200': @@ -3460,7 +3689,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/PostTrainingSFTRequest' + $ref: '#/components/schemas/PostTrainingSFTRequestWrapper' required: true responses: '200': @@ -3478,7 +3707,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/RewardScoringRequest' + $ref: '#/components/schemas/RewardScoringRequestWrapper' required: true responses: '200': @@ -3496,7 +3725,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/LogMetricsRequest' + $ref: '#/components/schemas/LogMetricsRequestWrapper' required: true responses: '200': @@ -3527,7 +3756,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/UpdateRunRequest' + $ref: '#/components/schemas/UpdateRunRequestWrapper' required: true responses: '200': @@ -3538,6 +3767,24 @@ paths: description: OK tags: - Telemetry + /safety/run_shields: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RunShieldRequestWrapper' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/RunShieldResponse' + description: OK + tags: + - Safety /synthetic_data_generation/generate: post: parameters: [] @@ -3545,7 +3792,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/SyntheticDataGenerationRequest' + $ref: '#/components/schemas/SyntheticDataGenerationRequestWrapper' required: true responses: '200': @@ -3561,16 +3808,17 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Telemetry -- name: Evaluations -- name: AgenticSystem -- name: Inference - name: BatchInference -- name: PostTraining -- name: Datasets +- name: Safety +- name: Telemetry +- name: RewardScoring - name: Memory - name: SyntheticDataGeneration -- name: RewardScoring +- name: Inference +- name: Evaluations +- name: PostTraining +- name: Datasets +- name: AgenticSystem - description: name: BatchChatCompletionRequest @@ -3613,12 +3861,18 @@ tags: name: ToolResponseMessage - description: name: UserMessage +- description: + name: BatchChatCompletionRequestWrapper - description: name: BatchChatCompletionResponse - description: name: BatchCompletionRequest +- description: + name: BatchCompletionRequestWrapper - description: name: BatchCompletionResponse @@ -3631,6 +3885,9 @@ tags: - description: name: ChatCompletionRequest +- description: + name: ChatCompletionRequestWrapper - description: 'Chat completion response event. @@ -3656,6 +3913,9 @@ tags: - description: name: CompletionRequest +- description: + name: CompletionRequestWrapper - description: 'streamed completion response. @@ -3664,9 +3924,6 @@ tags: name: CompletionResponseStreamChunk - description: name: AgentConfig -- description: - name: BraveSearchToolDefinition - description: name: BuiltinShield - description: name: RestAPIMethod +- description: + name: SearchToolDefinition - description: name: ShieldDefinition @@ -3711,6 +3971,9 @@ tags: name: AgenticSystemTurnCreateRequest - description: name: Attachment +- description: + name: AgenticSystemTurnCreateRequestWrapper - description: 'Streamed agent execution response. @@ -3767,9 +4030,15 @@ tags: - description: name: TrainEvalDatasetColumnType +- description: + name: CreateDatasetRequestWrapper - description: name: CreateExperimentRequest +- description: + name: CreateExperimentRequestWrapper - description: name: Experiment - description: name: CreateRunRequest +- description: + name: CreateRunRequestWrapper - description: name: Run - description: ' name: EvaluateQuestionAnsweringRequest +- description: + name: EvaluateQuestionAnsweringRequestWrapper - description: name: EvaluationJob - description: 'Request to evaluate summarization. @@ -3825,12 +4100,18 @@ tags: ' name: EvaluateSummarizationRequest +- description: + name: EvaluateSummarizationRequestWrapper - description: 'Request to evaluate text generation. ' name: EvaluateTextGenerationRequest +- description: + name: EvaluateTextGenerationRequestWrapper - description: name: GetAgenticSystemSessionRequest @@ -3867,6 +4148,9 @@ tags: - description: name: LogSearchRequest +- description: + name: LogSearchRequestWrapper - description: name: Log - description: @@ -3903,9 +4187,15 @@ tags: - description: name: LogMessagesRequest +- description: + name: LogMessagesRequestWrapper - description: name: LogMetricsRequest +- description: + name: LogMetricsRequestWrapper - description: name: DPOAlignmentConfig @@ -3921,6 +4211,9 @@ tags: name: RLHFAlgorithm - description: name: TrainingConfig +- description: + name: PostTrainingRLHFRequestWrapper - description: name: QueryDocumentsRequest @@ -3936,6 +4229,9 @@ tags: ' name: RewardScoringRequest +- description: + name: RewardScoringRequestWrapper - description: 'Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold. @@ -3947,6 +4243,15 @@ tags: name: ScoredDialogGenerations - description: name: ScoredMessage +- description: + name: RunShieldRequest +- description: + name: RunShieldRequestWrapper +- description: + name: RunShieldResponse - description: name: DoraFinetuningConfig @@ -3964,6 +4269,9 @@ tags: - description: name: QLoraFinetuningConfig +- description: + name: PostTrainingSFTRequestWrapper - description: 'Request to generate synthetic data. A small batch of prompts and a filtering function @@ -3971,6 +4279,9 @@ tags: ' name: SyntheticDataGenerationRequest +- description: + name: SyntheticDataGenerationRequestWrapper - description: 'Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold. @@ -3984,12 +4295,21 @@ tags: - description: name: UpdateExperimentRequest +- description: + name: UpdateExperimentRequestWrapper - description: name: UpdateRunRequest +- description: + name: UpdateRunRequestWrapper - description: name: UploadArtifactRequest +- description: + name: UploadArtifactRequestWrapper x-tagGroups: - name: Operations tags: @@ -4001,6 +4321,7 @@ x-tagGroups: - Memory - PostTraining - RewardScoring + - Safety - SyntheticDataGeneration - Telemetry - name: Types @@ -4010,6 +4331,7 @@ x-tagGroups: - AgenticSystemSessionCreateResponse - AgenticSystemStepResponse - AgenticSystemTurnCreateRequest + - AgenticSystemTurnCreateRequestWrapper - AgenticSystemTurnResponseEvent - AgenticSystemTurnResponseStepCompletePayload - AgenticSystemTurnResponseStepProgressPayload @@ -4021,15 +4343,17 @@ x-tagGroups: - ArtifactType - Attachment - BatchChatCompletionRequest + - BatchChatCompletionRequestWrapper - BatchChatCompletionResponse - BatchCompletionRequest + - BatchCompletionRequestWrapper - BatchCompletionResponse - - BraveSearchToolDefinition - BuiltinShield - BuiltinTool - CancelEvaluationJobRequest - CancelTrainingJobRequest - ChatCompletionRequest + - ChatCompletionRequestWrapper - ChatCompletionResponseEvent - ChatCompletionResponseEventType - ChatCompletionResponseStreamChunk @@ -4037,13 +4361,17 @@ x-tagGroups: - CodeInterpreterToolDefinition - CompletionMessage - CompletionRequest + - CompletionRequestWrapper - CompletionResponseStreamChunk - CreateAgenticSystemRequest - CreateAgenticSystemSessionRequest - CreateDatasetRequest + - CreateDatasetRequestWrapper - CreateExperimentRequest + - CreateExperimentRequestWrapper - CreateMemoryBankRequest - CreateRunRequest + - CreateRunRequestWrapper - DPOAlignmentConfig - DeleteAgenticSystemRequest - DeleteAgenticSystemSessionRequest @@ -4055,8 +4383,11 @@ x-tagGroups: - EmbeddingsRequest - EmbeddingsResponse - EvaluateQuestionAnsweringRequest + - EvaluateQuestionAnsweringRequestWrapper - EvaluateSummarizationRequest + - EvaluateSummarizationRequestWrapper - EvaluateTextGenerationRequest + - EvaluateTextGenerationRequestWrapper - EvaluationJob - EvaluationJobArtifactsResponse - EvaluationJobLogStream @@ -4072,8 +4403,11 @@ x-tagGroups: - ListArtifactsRequest - Log - LogMessagesRequest + - LogMessagesRequestWrapper - LogMetricsRequest + - LogMetricsRequestWrapper - LogSearchRequest + - LogSearchRequestWrapper - LoraFinetuningConfig - MemoryBank - MemoryBankDocument @@ -4088,7 +4422,9 @@ x-tagGroups: - PostTrainingJobStatus - PostTrainingJobStatusResponse - PostTrainingRLHFRequest + - PostTrainingRLHFRequestWrapper - PostTrainingSFTRequest + - PostTrainingSFTRequestWrapper - QLoraFinetuningConfig - QueryDocumentsRequest - QueryDocumentsResponse @@ -4096,18 +4432,24 @@ x-tagGroups: - RestAPIExecutionConfig - RestAPIMethod - RewardScoringRequest + - RewardScoringRequestWrapper - RewardScoringResponse - Run + - RunShieldRequest + - RunShieldRequestWrapper + - RunShieldResponse - SamplingParams - SamplingStrategy - ScoredDialogGenerations - ScoredMessage + - SearchToolDefinition - Session - ShieldCallStep - ShieldDefinition - ShieldResponse - StopReason - SyntheticDataGenerationRequest + - SyntheticDataGenerationRequestWrapper - SyntheticDataGenerationResponse - SystemMessage - TokenLogProbs @@ -4128,7 +4470,10 @@ x-tagGroups: - URL - UpdateDocumentsRequest - UpdateExperimentRequest + - UpdateExperimentRequestWrapper - UpdateRunRequest + - UpdateRunRequestWrapper - UploadArtifactRequest + - UploadArtifactRequestWrapper - UserMessage - WolframAlphaToolDefinition diff --git a/rfcs/openapi_generator/generate.py b/rfcs/openapi_generator/generate.py index ab9774e70..279389a47 100644 --- a/rfcs/openapi_generator/generate.py +++ b/rfcs/openapi_generator/generate.py @@ -35,7 +35,10 @@ from llama_toolchain.stack import LlamaStack # TODO: this should be fixed in the generator itself so it reads appropriate annotations -STREAMING_ENDPOINTS = ["/agentic_system/turn/create"] +STREAMING_ENDPOINTS = [ + "/agentic_system/turn/create", + "/inference/chat_completion", +] def patch_sse_stream_responses(spec: Specification): diff --git a/rfcs/openapi_generator/pyopenapi/generator.py b/rfcs/openapi_generator/pyopenapi/generator.py index e1450074b..a711d9f68 100644 --- a/rfcs/openapi_generator/pyopenapi/generator.py +++ b/rfcs/openapi_generator/pyopenapi/generator.py @@ -468,12 +468,14 @@ class Generator: builder = ContentBuilder(self.schema_builder) first = next(iter(op.request_params)) request_name, request_type = first + + from dataclasses import make_dataclass + if len(op.request_params) == 1 and "Request" in first[1].__name__: # TODO(ashwin): Undo the "Request" hack and this entire block eventually - request_name, request_type = first + request_name = first[1].__name__ + "Wrapper" + request_type = make_dataclass(request_name, op.request_params) else: - from dataclasses import make_dataclass - op_name = "".join(word.capitalize() for word in op.name.split("_")) request_name = f"{op_name}Request" request_type = make_dataclass(request_name, op.request_params)