diff --git a/docs/cli_reference.md b/docs/cli_reference.md
index 7457aa45e..905539d42 100644
--- a/docs/cli_reference.md
+++ b/docs/cli_reference.md
@@ -248,51 +248,51 @@ llama stack list-distributions
 ```
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| Distribution ID                | Providers                             | Description                                                                               |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local                          | {                                     | Use code from `llama_toolchain` itself to serve all llama stack APIs                      |
-|                                |   "inference": "meta-reference",      |                                                                                           |
-|                                |   "memory": "meta-reference-faiss",   |                                                                                           |
-|                                |   "safety": "meta-reference",         |                                                                                           |
-|                                |   "agentic_system": "meta-reference"  |                                                                                           |
-|                                | }                                     |                                                                                           |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| remote                         | {                                     | Point to remote services for all llama stack APIs                                         |
-|                                |   "inference": "remote",              |                                                                                           |
-|                                |   "safety": "remote",                 |                                                                                           |
-|                                |   "agentic_system": "remote",         |                                                                                           |
-|                                |   "memory": "remote"                  |                                                                                           |
-|                                | }                                     |                                                                                           |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local-ollama                   | {                                     | Like local, but use ollama for running LLM inference                                      |
-|                                |   "inference": "remote::ollama",      |                                                                                           |
-|                                |   "safety": "meta-reference",         |                                                                                           |
-|                                |   "agentic_system": "meta-reference", |                                                                                           |
-|                                |   "memory": "meta-reference-faiss"    |                                                                                           |
-|                                | }                                     |                                                                                           |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local-plus-fireworks-inference | {                                     | Use Fireworks.ai for running LLM inference                                                |
-|                                |   "inference": "remote::fireworks",   |                                                                                           |
-|                                |   "safety": "meta-reference",         |                                                                                           |
-|                                |   "agentic_system": "meta-reference", |                                                                                           |
-|                                |   "memory": "meta-reference-faiss"    |                                                                                           |
-|                                | }                                     |                                                                                           |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local-plus-together-inference  | {                                     | Use Together.ai for running LLM inference                                                 |
-|                                |   "inference": "remote::together",    |                                                                                           |
-|                                |   "safety": "meta-reference",         |                                                                                           |
-|                                |   "agentic_system": "meta-reference", |                                                                                           |
-|                                |   "memory": "meta-reference-faiss"    |                                                                                           |
-|                                | }                                     |                                                                                           |
-|--------------------------------|---------------------------------------|-------------------------------------------------------------------------------------------|
-| local-plus-tgi-inference       | {                                     | Use TGI (local or with [Hugging Face Inference Endpoints](https://huggingface.co/         |
-|                                |   "inference": "remote::tgi",         | inference-endpoints/dedicated)) for running LLM inference. When using HF Inference        |
-|                                |   "safety": "meta-reference",         | Endpoints, you must provide the name of the endpoint.                                     |
-|                                |   "agentic_system": "meta-reference", |                                                                                           |
-|                                |   "memory": "meta-reference-faiss"    |                                                                                           |
-|                                | }                                     |                                                                                           |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| Distribution Type              | Providers                             | Description                                                          |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local                          | {                                     | Use code from `llama_toolchain` itself to serve all llama stack APIs |
+|                                |   "inference": "meta-reference",      |                                                                      |
+|                                |   "memory": "meta-reference-faiss",   |                                                                      |
+|                                |   "safety": "meta-reference",         |                                                                      |
+|                                |   "agentic_system": "meta-reference"  |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| remote                         | {                                     | Point to remote services for all llama stack APIs                    |
+|                                |   "inference": "remote",              |                                                                      |
+|                                |   "safety": "remote",                 |                                                                      |
+|                                |   "agentic_system": "remote",         |                                                                      |
+|                                |   "memory": "remote"                  |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-ollama                   | {                                     | Like local, but use ollama for running LLM inference                 |
+|                                |   "inference": "remote::ollama",      |                                                                      |
+|                                |   "safety": "meta-reference",         |                                                                      |
+|                                |   "agentic_system": "meta-reference", |                                                                      |
+|                                |   "memory": "meta-reference-faiss"    |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-fireworks-inference | {                                     | Use Fireworks.ai for running LLM inference                           |
+|                                |   "inference": "remote::fireworks",   |                                                                      |
+|                                |   "safety": "meta-reference",         |                                                                      |
+|                                |   "agentic_system": "meta-reference", |                                                                      |
+|                                |   "memory": "meta-reference-faiss"    |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-together-inference  | {                                     | Use Together.ai for running LLM inference                            |
+|                                |   "inference": "remote::together",    |                                                                      |
+|                                |   "safety": "meta-reference",         |                                                                      |
+|                                |   "agentic_system": "meta-reference", |                                                                      |
+|                                |   "memory": "meta-reference-faiss"    |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-tgi-inference       | {                                     | Use TGI (local or with [Hugging Face Inference Endpoints](https://   |
+|                                |   "inference": "remote::tgi",         | huggingface.co/inference-endpoints/dedicated)) for running LLM       |
+|                                |   "safety": "meta-reference",         | inference. When using HF Inference Endpoints, you must provide the   |
+|                                |   "agentic_system": "meta-reference", | name of the endpoint.                                                |
+|                                |   "memory": "meta-reference-faiss"    |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 ```
 
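Each distribution type listed above can also be resolved programmatically; a minimal sketch using `resolve_distribution_spec` from `llama_toolchain/core/distribution_registry.py` (updated later in this patch):

```
from llama_toolchain.core.distribution_registry import resolve_distribution_spec

spec = resolve_distribution_spec("local-ollama")
print(spec.description)  # Like local, but use ollama for running LLM inference
print(spec.providers)    # {Api.inference: "remote::ollama", ...}
```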
 As you can see above, each “distribution” details the “providers” it is composed of. For example, `local` uses the “meta-reference” provider for inference, while `local-ollama` relies on a different provider (Ollama). Similarly, you can use Fireworks or Together.ai for running inference as well.
diff --git a/llama_toolchain/agentic_system/api/api.py b/llama_toolchain/agentic_system/api/api.py
index e3f417918..68ec980e6 100644
--- a/llama_toolchain/agentic_system/api/api.py
+++ b/llama_toolchain/agentic_system/api/api.py
@@ -116,10 +116,47 @@ MemoryBankConfig = Annotated[
 ]
 
-@json_schema_type
+class MemoryQueryGenerator(Enum):
+    default = "default"
+    llm = "llm"
+    custom = "custom"
+
+
+class DefaultMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.default.value] = (
+        MemoryQueryGenerator.default.value
+    )
+    sep: str = " "
+
+
+class LLMMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
+    model: str
+    template: str
+
+
+class CustomMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
+
+
+MemoryQueryGeneratorConfig = Annotated[
+    Union[
+        DefaultMemoryQueryGeneratorConfig,
+        LLMMemoryQueryGeneratorConfig,
+        CustomMemoryQueryGeneratorConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
 class MemoryToolDefinition(ToolDefinitionCommon):
     type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
     memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+    # This config defines how a query is generated using the messages
+    # for memory bank retrieval.
+    query_generator_config: MemoryQueryGeneratorConfig = Field(
+        default=DefaultMemoryQueryGeneratorConfig()
+    )
     max_tokens_in_context: int = 4096
     max_chunks: int = 10
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index ed3145b1e..4d38e0032 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -31,6 +31,7 @@ from llama_toolchain.tools.builtin import (
     SingleMessageBuiltinTool,
 )
 
+from .rag.context_retriever import generate_rag_query
 from .safety import SafetyException, ShieldRunnerMixin
 
@@ -664,7 +665,9 @@ class ChatAgent(ShieldRunnerMixin):
             # (i.e., no prior turns uploaded an Attachment)
             return None, []
 
-        query = " ".join(m.content for m in messages)
+        query = await generate_rag_query(
+            memory.query_generator_config, messages, inference_api=self.inference_api
+        )
         tasks = [
             self.memory_api.query_documents(
                 bank_id=bank_id,
diff --git a/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py b/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py
new file mode 100644
index 000000000..afcc6afd1
--- /dev/null
+++ b/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
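The new `MemoryQueryGeneratorConfig` above is a discriminated union: the `type` field selects which config class is instantiated when an agent definition is parsed. A minimal, self-contained sketch of that behavior, using stand-in class names (assumes pydantic v2):

```
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class DefaultConfig(BaseModel):
    type: Literal["default"] = "default"
    sep: str = " "


class LLMConfig(BaseModel):
    type: Literal["llm"] = "llm"
    model: str
    template: str


QueryGeneratorConfig = Annotated[
    Union[DefaultConfig, LLMConfig], Field(discriminator="type")
]


class MemoryTool(BaseModel):
    # Mirrors MemoryToolDefinition.query_generator_config above.
    query_generator_config: QueryGeneratorConfig = Field(default=DefaultConfig())


tool = MemoryTool.model_validate(
    {
        "query_generator_config": {
            "type": "llm",
            "model": "Meta-Llama3.1-8B-Instruct",
            "template": "{{ messages[-1]['content'] }}",
        }
    }
)
assert isinstance(tool.query_generator_config, LLMConfig)
```

For the `llm` generator, the new context retriever (the file that begins here) renders `template` against the turn's messages with Jinja2; a toy rendering, with an example template that is not a shipped default:

```
from jinja2 import Template

messages = [{"role": "user", "content": "What does the report say about pgvector?"}]
query = Template(
    "Condense into one search query: {{ messages[-1]['content'] }}"
).render({"messages": messages})
print(query)
```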
+ +from typing import List + +from jinja2 import Template +from llama_models.llama3.api import * # noqa: F403 + + +from llama_toolchain.agentic_system.api import ( + DefaultMemoryQueryGeneratorConfig, + LLMMemoryQueryGeneratorConfig, + MemoryQueryGenerator, + MemoryQueryGeneratorConfig, +) +from termcolor import cprint # noqa: F401 +from llama_toolchain.inference.api import * # noqa: F403 + + +async def generate_rag_query( + config: MemoryQueryGeneratorConfig, + messages: List[Message], + **kwargs, +): + """ + Generates a query that will be used for + retrieving relevant information from the memory bank. + """ + if config.type == MemoryQueryGenerator.default.value: + query = await default_rag_query_generator(config, messages, **kwargs) + elif config.type == MemoryQueryGenerator.llm.value: + query = await llm_rag_query_generator(config, messages, **kwargs) + else: + raise NotImplementedError(f"Unsupported memory query generator {config.type}") + # cprint(f"Generated query >>>: {query}", color="green") + return query + + +async def default_rag_query_generator( + config: DefaultMemoryQueryGeneratorConfig, + messages: List[Message], + **kwargs, +): + return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages) + + +async def llm_rag_query_generator( + config: LLMMemoryQueryGeneratorConfig, + messages: List[Message], + **kwargs, +): + assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api" + inference_api = kwargs["inference_api"] + + m_dict = {"messages": [m.model_dump() for m in messages]} + + template = Template(config.template) + content = template.render(m_dict) + + model = config.model + message = UserMessage(content=content) + response = inference_api.chat_completion( + ChatCompletionRequest( + model=model, + messages=[message], + stream=False, + ) + ) + + async for chunk in response: + query = chunk.completion_message.content + + return query diff --git a/llama_toolchain/agentic_system/providers.py b/llama_toolchain/agentic_system/providers.py index a722d9400..164df1a30 100644 --- a/llama_toolchain/agentic_system/providers.py +++ b/llama_toolchain/agentic_system/providers.py @@ -13,7 +13,7 @@ def available_agentic_system_providers() -> List[ProviderSpec]: return [ InlineProviderSpec( api=Api.agentic_system, - provider_id="meta-reference", + provider_type="meta-reference", pip_packages=[ "codeshield", "matplotlib", diff --git a/llama_toolchain/cli/stack/build.py b/llama_toolchain/cli/stack/build.py index c81a6d350..22bd4071f 100644 --- a/llama_toolchain/cli/stack/build.py +++ b/llama_toolchain/cli/stack/build.py @@ -52,7 +52,7 @@ class StackBuild(Subcommand): BuildType, ) - allowed_ids = [d.distribution_id for d in available_distribution_specs()] + allowed_ids = [d.distribution_type for d in available_distribution_specs()] self.parser.add_argument( "distribution", type=str, @@ -101,7 +101,7 @@ class StackBuild(Subcommand): api_inputs.append( ApiInput( api=api, - provider=provider_spec.provider_id, + provider=provider_spec.provider_type, ) ) docker_image = None @@ -115,11 +115,11 @@ class StackBuild(Subcommand): self.parser.error(f"Could not find distribution {args.distribution}") return - for api, provider_id in dist.providers.items(): + for api, provider_type in dist.providers.items(): api_inputs.append( ApiInput( api=api, - provider=provider_id, + provider=provider_type, ) ) docker_image = dist.docker_image @@ -128,6 +128,6 @@ class StackBuild(Subcommand): api_inputs, build_type=BuildType(args.type), name=args.name, - 
distribution_id=args.distribution, + distribution_type=args.distribution, docker_image=docker_image, ) diff --git a/llama_toolchain/cli/stack/configure.py b/llama_toolchain/cli/stack/configure.py index 70ff4a7f0..658380f4d 100644 --- a/llama_toolchain/cli/stack/configure.py +++ b/llama_toolchain/cli/stack/configure.py @@ -36,7 +36,7 @@ class StackConfigure(Subcommand): ) from llama_toolchain.core.package import BuildType - allowed_ids = [d.distribution_id for d in available_distribution_specs()] + allowed_ids = [d.distribution_type for d in available_distribution_specs()] self.parser.add_argument( "distribution", type=str, @@ -84,7 +84,7 @@ def configure_llama_distribution(config_file: Path) -> None: if config.providers: cprint( - f"Configuration already exists for {config.distribution_id}. Will overwrite...", + f"Configuration already exists for {config.distribution_type}. Will overwrite...", "yellow", attrs=["bold"], ) diff --git a/llama_toolchain/cli/stack/list_distributions.py b/llama_toolchain/cli/stack/list_distributions.py index c4d529157..557b8c33c 100644 --- a/llama_toolchain/cli/stack/list_distributions.py +++ b/llama_toolchain/cli/stack/list_distributions.py @@ -33,7 +33,7 @@ class StackListDistributions(Subcommand): # eventually, this should query a registry at llama.meta.com/llamastack/distributions headers = [ - "Distribution ID", + "Distribution Type", "Providers", "Description", ] @@ -43,7 +43,7 @@ class StackListDistributions(Subcommand): providers = {k.value: v for k, v in spec.providers.items()} rows.append( [ - spec.distribution_id, + spec.distribution_type, json.dumps(providers, indent=2), spec.description, ] diff --git a/llama_toolchain/cli/stack/list_providers.py b/llama_toolchain/cli/stack/list_providers.py index 29602d889..fdf4ab054 100644 --- a/llama_toolchain/cli/stack/list_providers.py +++ b/llama_toolchain/cli/stack/list_providers.py @@ -41,7 +41,7 @@ class StackListProviders(Subcommand): # eventually, this should query a registry at llama.meta.com/llamastack/distributions headers = [ - "Provider ID", + "Provider Type", "PIP Package Dependencies", ] @@ -49,7 +49,7 @@ class StackListProviders(Subcommand): for spec in providers_for_api.values(): rows.append( [ - spec.provider_id, + spec.provider_type, ",".join(spec.pip_packages), ] ) diff --git a/llama_toolchain/cli/stack/run.py b/llama_toolchain/cli/stack/run.py index 68853db35..1568ed820 100644 --- a/llama_toolchain/cli/stack/run.py +++ b/llama_toolchain/cli/stack/run.py @@ -80,7 +80,7 @@ class StackRun(Subcommand): with open(config_file, "r") as f: config = PackageConfig(**yaml.safe_load(f)) - if not config.distribution_id: + if not config.distribution_type: raise ValueError("Build config appears to be corrupt.") if config.docker_image: diff --git a/llama_toolchain/core/build_conda_env.sh b/llama_toolchain/core/build_conda_env.sh index 1e8c002f2..e5b1ca539 100755 --- a/llama_toolchain/core/build_conda_env.sh +++ b/llama_toolchain/core/build_conda_env.sh @@ -20,12 +20,12 @@ fi set -euo pipefail if [ "$#" -ne 3 ]; then - echo "Usage: $0 " >&2 - echo "Example: $0 mybuild 'numpy pandas scipy'" >&2 + echo "Usage: $0 " >&2 + echo "Example: $0 mybuild 'numpy pandas scipy'" >&2 exit 1 fi -distribution_id="$1" +distribution_type="$1" build_name="$2" env_name="llamastack-$build_name" pip_dependencies="$3" @@ -117,4 +117,4 @@ ensure_conda_env_python310 "$env_name" "$pip_dependencies" printf "${GREEN}Successfully setup conda environment. 
Configuring build...${NC}\n" -$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type conda_env +$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_type --name "$build_name" --type conda_env diff --git a/llama_toolchain/core/build_container.sh b/llama_toolchain/core/build_container.sh index ec2ca8a0c..e5349cd08 100755 --- a/llama_toolchain/core/build_container.sh +++ b/llama_toolchain/core/build_container.sh @@ -5,12 +5,12 @@ LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} if [ "$#" -ne 4 ]; then - echo "Usage: $0 - echo "Example: $0 distribution_id my-fastapi-app python:3.9-slim 'fastapi uvicorn' + echo "Usage: $0 + echo "Example: $0 distribution_type my-fastapi-app python:3.9-slim 'fastapi uvicorn' exit 1 fi -distribution_id=$1 +distribution_type=$1 build_name="$2" image_name="llamastack-$build_name" docker_base=$3 @@ -110,4 +110,4 @@ set +x printf "${GREEN}Succesfully setup Podman image. Configuring build...${NC}" echo "You can run it with: podman run -p 8000:8000 $image_name" -$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type container +$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_type --name "$build_name" --type container diff --git a/llama_toolchain/core/configure.py b/llama_toolchain/core/configure.py index 7f9aa0140..252358a52 100644 --- a/llama_toolchain/core/configure.py +++ b/llama_toolchain/core/configure.py @@ -21,14 +21,14 @@ def configure_api_providers(existing_configs: Dict[str, Any]) -> None: for api_str, stub_config in existing_configs.items(): api = Api(api_str) providers = all_providers[api] - provider_id = stub_config["provider_id"] - if provider_id not in providers: + provider_type = stub_config["provider_type"] + if provider_type not in providers: raise ValueError( - f"Unknown provider `{provider_id}` is not available for API `{api_str}`" + f"Unknown provider `{provider_type}` is not available for API `{api_str}`" ) - provider_spec = providers[provider_id] - cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"]) + provider_spec = providers[provider_type] + cprint(f"Configuring API: {api_str} ({provider_type})", "white", attrs=["bold"]) config_type = instantiate_class_type(provider_spec.config_class) try: @@ -43,7 +43,7 @@ def configure_api_providers(existing_configs: Dict[str, Any]) -> None: print("") provider_configs[api_str] = { - "provider_id": provider_id, + "provider_type": provider_type, **provider_config.dict(), } diff --git a/llama_toolchain/core/datatypes.py b/llama_toolchain/core/datatypes.py index cbdda51d4..138d20941 100644 --- a/llama_toolchain/core/datatypes.py +++ b/llama_toolchain/core/datatypes.py @@ -31,7 +31,7 @@ class ApiEndpoint(BaseModel): @json_schema_type class ProviderSpec(BaseModel): api: Api - provider_id: str + provider_type: str config_class: str = Field( ..., description="Fully-qualified classname of the config for this provider", @@ -100,7 +100,7 @@ class RemoteProviderConfig(BaseModel): return url.rstrip("/") -def remote_provider_id(adapter_id: str) -> str: +def remote_provider_type(adapter_id: str) -> str: return f"remote::{adapter_id}" @@ -141,22 +141,22 @@ def remote_provider_spec( if adapter and adapter.config_class else "llama_toolchain.core.datatypes.RemoteProviderConfig" ) - provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote" + provider_type = 
remote_provider_type(adapter.adapter_id) if adapter else "remote" return RemoteProviderSpec( - api=api, provider_id=provider_id, config_class=config_class, adapter=adapter + api=api, provider_type=provider_type, config_class=config_class, adapter=adapter ) @json_schema_type class DistributionSpec(BaseModel): - distribution_id: str + distribution_type: str description: str docker_image: Optional[str] = None providers: Dict[Api, str] = Field( default_factory=dict, - description="Provider IDs for each of the APIs provided by this distribution", + description="Provider Types for each of the APIs provided by this distribution", ) @@ -171,7 +171,7 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p this could be just a hash """, ) - distribution_id: Optional[str] = None + distribution_type: Optional[str] = None docker_image: Optional[str] = Field( default=None, diff --git a/llama_toolchain/core/distribution.py b/llama_toolchain/core/distribution.py index 4c50189c0..89e1d7793 100644 --- a/llama_toolchain/core/distribution.py +++ b/llama_toolchain/core/distribution.py @@ -83,18 +83,18 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]: def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]: inference_providers_by_id = { - a.provider_id: a for a in available_inference_providers() + a.provider_type: a for a in available_inference_providers() } - safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()} + safety_providers_by_id = {a.provider_type: a for a in available_safety_providers()} agentic_system_providers_by_id = { - a.provider_id: a for a in available_agentic_system_providers() + a.provider_type: a for a in available_agentic_system_providers() } ret = { Api.inference: inference_providers_by_id, Api.safety: safety_providers_by_id, Api.agentic_system: agentic_system_providers_by_id, - Api.memory: {a.provider_id: a for a in available_memory_providers()}, + Api.memory: {a.provider_type: a for a in available_memory_providers()}, } for k, v in ret.items(): v["remote"] = remote_provider_spec(k) diff --git a/llama_toolchain/core/distribution_registry.py b/llama_toolchain/core/distribution_registry.py index 9413e1374..855fc6300 100644 --- a/llama_toolchain/core/distribution_registry.py +++ b/llama_toolchain/core/distribution_registry.py @@ -14,7 +14,7 @@ from .datatypes import * # noqa: F403 def available_distribution_specs() -> List[DistributionSpec]: return [ DistributionSpec( - distribution_id="local", + distribution_type="local", description="Use code from `llama_toolchain` itself to serve all llama stack APIs", providers={ Api.inference: "meta-reference", @@ -24,35 +24,35 @@ def available_distribution_specs() -> List[DistributionSpec]: }, ), DistributionSpec( - distribution_id="remote", + distribution_type="remote", description="Point to remote services for all llama stack APIs", providers={x: "remote" for x in Api}, ), DistributionSpec( - distribution_id="local-ollama", + distribution_type="local-ollama", description="Like local, but use ollama for running LLM inference", providers={ - Api.inference: remote_provider_id("ollama"), + Api.inference: remote_provider_type("ollama"), Api.safety: "meta-reference", Api.agentic_system: "meta-reference", Api.memory: "meta-reference-faiss", }, ), DistributionSpec( - distribution_id="local-plus-fireworks-inference", + distribution_type="local-plus-fireworks-inference", description="Use Fireworks.ai for running LLM inference", providers={ - Api.inference: remote_provider_id("fireworks"), + 
Api.inference: remote_provider_type("fireworks"), Api.safety: "meta-reference", Api.agentic_system: "meta-reference", Api.memory: "meta-reference-faiss", }, ), DistributionSpec( - distribution_id="local-plus-together-inference", + distribution_type="local-plus-together-inference", description="Use Together.ai for running LLM inference", providers={ - Api.inference: remote_provider_id("together"), + Api.inference: remote_provider_type("together"), Api.safety: "meta-reference", Api.agentic_system: "meta-reference", Api.memory: "meta-reference-faiss", @@ -72,8 +72,8 @@ def available_distribution_specs() -> List[DistributionSpec]: @lru_cache() -def resolve_distribution_spec(distribution_id: str) -> Optional[DistributionSpec]: +def resolve_distribution_spec(distribution_type: str) -> Optional[DistributionSpec]: for spec in available_distribution_specs(): - if spec.distribution_id == distribution_id: + if spec.distribution_type == distribution_type: return spec return None diff --git a/llama_toolchain/core/package.py b/llama_toolchain/core/package.py index 72bd93152..ab4346a71 100644 --- a/llama_toolchain/core/package.py +++ b/llama_toolchain/core/package.py @@ -46,13 +46,13 @@ def build_package( api_inputs: List[ApiInput], build_type: BuildType, name: str, - distribution_id: Optional[str] = None, + distribution_type: Optional[str] = None, docker_image: Optional[str] = None, ): - if not distribution_id: - distribution_id = "adhoc" + if not distribution_type: + distribution_type = "adhoc" - build_dir = BUILDS_BASE_DIR / distribution_id / build_type.descriptor() + build_dir = BUILDS_BASE_DIR / distribution_type / build_type.descriptor() os.makedirs(build_dir, exist_ok=True) package_name = name.replace("::", "-") @@ -79,7 +79,7 @@ def build_package( if provider.docker_image: raise ValueError("A stack's dependencies cannot have a docker image") - stub_config[api.value] = {"provider_id": api_input.provider} + stub_config[api.value] = {"provider_type": api_input.provider} if package_file.exists(): cprint( @@ -92,7 +92,7 @@ def build_package( c.providers[api_str] = new_config else: existing_config = c.providers[api_str] - if existing_config["provider_id"] != new_config["provider_id"]: + if existing_config["provider_type"] != new_config["provider_type"]: cprint( f"Provider `{api_str}` has changed from `{existing_config}` to `{new_config}`", color="yellow", @@ -105,7 +105,7 @@ def build_package( providers=stub_config, ) - c.distribution_id = distribution_id + c.distribution_type = distribution_type c.docker_image = package_name if build_type == BuildType.container else None c.conda_env = package_name if build_type == BuildType.conda_env else None @@ -119,7 +119,7 @@ def build_package( ) args = [ script, - distribution_id, + distribution_type, package_name, package_deps.docker_image, " ".join(package_deps.pip_packages), @@ -130,7 +130,7 @@ def build_package( ) args = [ script, - distribution_id, + distribution_type, package_name, " ".join(package_deps.pip_packages), ] diff --git a/llama_toolchain/core/server.py b/llama_toolchain/core/server.py index 4de84b726..8c7ab10a7 100644 --- a/llama_toolchain/core/server.py +++ b/llama_toolchain/core/server.py @@ -284,13 +284,13 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): for api_str, provider_config in config["providers"].items(): api = Api(api_str) providers = all_providers[api] - provider_id = provider_config["provider_id"] - if provider_id not in providers: + provider_type = provider_config["provider_type"] + if provider_type not 
in providers: raise ValueError( - f"Unknown provider `{provider_id}` is not available for API `{api}`" + f"Unknown provider `{provider_type}` is not available for API `{api}`" ) - provider_specs[api] = providers[provider_id] + provider_specs[api] = providers[provider_type] impls = resolve_impls(provider_specs, config) diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py index e9f6b4072..f313de3fd 100644 --- a/llama_toolchain/inference/providers.py +++ b/llama_toolchain/inference/providers.py @@ -13,7 +13,7 @@ def available_inference_providers() -> List[ProviderSpec]: return [ InlineProviderSpec( api=Api.inference, - provider_id="meta-reference", + provider_type="meta-reference", pip_packages=[ "accelerate", "blobfile", diff --git a/llama_toolchain/memory/adapters/chroma/__init__.py b/llama_toolchain/memory/adapters/chroma/__init__.py new file mode 100644 index 000000000..c90a8e8ac --- /dev/null +++ b/llama_toolchain/memory/adapters/chroma/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_toolchain.core.datatypes import RemoteProviderConfig + + +async def get_adapter_impl(config: RemoteProviderConfig, _deps): + from .chroma import ChromaMemoryAdapter + + impl = ChromaMemoryAdapter(config.url) + await impl.initialize() + return impl diff --git a/llama_toolchain/memory/adapters/chroma/chroma.py b/llama_toolchain/memory/adapters/chroma/chroma.py new file mode 100644 index 000000000..f4952cd0e --- /dev/null +++ b/llama_toolchain/memory/adapters/chroma/chroma.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
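A hypothetical wiring sketch for the Chroma adapter above, used outside the stack server; the URL and bank id are placeholders, and it assumes `RemoteProviderConfig` accepts the adapter's base URL as `url`:

```
import asyncio

from llama_toolchain.core.datatypes import RemoteProviderConfig
from llama_toolchain.memory.adapters.chroma import get_adapter_impl


async def main() -> None:
    config = RemoteProviderConfig(url="http://localhost:8000")  # placeholder URL
    memory = await get_adapter_impl(config, {})  # no deps needed here
    bank = await memory.get_memory_bank("my-bank-id")  # placeholder bank id
    print(bank)  # None if the bank does not exist on the server


asyncio.run(main())
```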
+ +import json +import uuid +from typing import List +from urllib.parse import urlparse + +import chromadb +from numpy.typing import NDArray + +from llama_toolchain.memory.api import * # noqa: F403 + + +from llama_toolchain.memory.common.vector_store import BankWithIndex, EmbeddingIndex + + +class ChromaIndex(EmbeddingIndex): + def __init__(self, client: chromadb.AsyncHttpClient, collection): + self.client = client + self.collection = collection + + async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + assert len(chunks) == len( + embeddings + ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" + + for i, chunk in enumerate(chunks): + print(f"Adding chunk #{i} tokens={chunk.token_count}") + + await self.collection.add( + documents=[chunk.json() for chunk in chunks], + embeddings=embeddings, + ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)], + ) + + async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: + results = await self.collection.query( + query_embeddings=[embedding.tolist()], + n_results=k, + include=["documents", "distances"], + ) + distances = results["distances"][0] + documents = results["documents"][0] + + chunks = [] + scores = [] + for dist, doc in zip(distances, documents): + try: + doc = json.loads(doc) + chunk = Chunk(**doc) + except Exception: + import traceback + + traceback.print_exc() + print(f"Failed to parse document: {doc}") + continue + + chunks.append(chunk) + scores.append(1.0 / float(dist)) + + return QueryDocumentsResponse(chunks=chunks, scores=scores) + + +class ChromaMemoryAdapter(Memory): + def __init__(self, url: str) -> None: + print(f"Initializing ChromaMemoryAdapter with url: {url}") + url = url.rstrip("/") + parsed = urlparse(url) + + if parsed.path and parsed.path != "/": + raise ValueError("URL should not contain a path") + + self.host = parsed.hostname + self.port = parsed.port + + self.client = None + self.cache = {} + + async def initialize(self) -> None: + try: + print(f"Connecting to Chroma server at: {self.host}:{self.port}") + self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port) + except Exception as e: + import traceback + + traceback.print_exc() + raise RuntimeError("Could not connect to Chroma server") from e + + async def shutdown(self) -> None: + pass + + async def create_memory_bank( + self, + name: str, + config: MemoryBankConfig, + url: Optional[URL] = None, + ) -> MemoryBank: + bank_id = str(uuid.uuid4()) + bank = MemoryBank( + bank_id=bank_id, + name=name, + config=config, + url=url, + ) + collection = await self.client.create_collection( + name=bank_id, + metadata={"bank": bank.json()}, + ) + bank_index = BankWithIndex( + bank=bank, index=ChromaIndex(self.client, collection) + ) + self.cache[bank_id] = bank_index + return bank + + async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: + bank_index = await self._get_and_cache_bank_index(bank_id) + if bank_index is None: + return None + return bank_index.bank + + async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: + if bank_id in self.cache: + return self.cache[bank_id] + + collections = await self.client.list_collections() + for collection in collections: + if collection.name == bank_id: + print(collection.metadata) + bank = MemoryBank(**json.loads(collection.metadata["bank"])) + index = BankWithIndex( + bank=bank, + index=ChromaIndex(self.client, collection), + ) + self.cache[bank_id] = index + return index + + return None + + 
async def insert_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, + ) -> None: + index = await self._get_and_cache_bank_index(bank_id) + if not index: + raise ValueError(f"Bank {bank_id} not found") + + await index.insert_documents(documents) + + async def query_documents( + self, + bank_id: str, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + index = await self._get_and_cache_bank_index(bank_id) + if not index: + raise ValueError(f"Bank {bank_id} not found") + + return await index.query_documents(query, params) diff --git a/llama_toolchain/memory/adapters/pgvector/__init__.py b/llama_toolchain/memory/adapters/pgvector/__init__.py new file mode 100644 index 000000000..4ac30452f --- /dev/null +++ b/llama_toolchain/memory/adapters/pgvector/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import PGVectorConfig + + +async def get_adapter_impl(config: PGVectorConfig, _deps): + from .pgvector import PGVectorMemoryAdapter + + impl = PGVectorMemoryAdapter(config) + await impl.initialize() + return impl diff --git a/llama_toolchain/memory/adapters/pgvector/config.py b/llama_toolchain/memory/adapters/pgvector/config.py new file mode 100644 index 000000000..87b2f4a3b --- /dev/null +++ b/llama_toolchain/memory/adapters/pgvector/config.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class PGVectorConfig(BaseModel): + host: str = Field(default="localhost") + port: int = Field(default=5432) + db: str + user: str + password: str diff --git a/llama_toolchain/memory/adapters/pgvector/pgvector.py b/llama_toolchain/memory/adapters/pgvector/pgvector.py new file mode 100644 index 000000000..930d7720f --- /dev/null +++ b/llama_toolchain/memory/adapters/pgvector/pgvector.py @@ -0,0 +1,234 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
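`PGVectorConfig` above declares defaults for `host` and `port` and requires the remaining fields; a quick instantiation with placeholder credentials:

```
from llama_toolchain.memory.adapters.pgvector.config import PGVectorConfig

config = PGVectorConfig(db="llamastack", user="llama", password="s3cret")
assert (config.host, config.port) == ("localhost", 5432)  # declared defaults
```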
+ +import uuid + +from typing import List, Tuple + +import psycopg2 +from numpy.typing import NDArray +from psycopg2 import sql +from psycopg2.extras import execute_values, Json +from pydantic import BaseModel +from llama_toolchain.memory.api import * # noqa: F403 + + +from llama_toolchain.memory.common.vector_store import ( + ALL_MINILM_L6_V2_DIMENSION, + BankWithIndex, + EmbeddingIndex, +) + +from .config import PGVectorConfig + + +def check_extension_version(cur): + cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'") + result = cur.fetchone() + return result[0] if result else None + + +def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]): + query = sql.SQL( + """ + INSERT INTO metadata_store (key, data) + VALUES %s + ON CONFLICT (key) DO UPDATE + SET data = EXCLUDED.data + """ + ) + + values = [(key, Json(model.dict())) for key, model in keys_models] + execute_values(cur, query, values, template="(%s, %s)") + + +def load_models(cur, keys: List[str], cls): + query = "SELECT key, data FROM metadata_store" + if keys: + placeholders = ",".join(["%s"] * len(keys)) + query += f" WHERE key IN ({placeholders})" + cur.execute(query, keys) + else: + cur.execute(query) + + rows = cur.fetchall() + return [cls(**row["data"]) for row in rows] + + +class PGVectorIndex(EmbeddingIndex): + def __init__(self, bank: MemoryBank, dimension: int, cursor): + self.cursor = cursor + self.table_name = f"vector_store_{bank.name}" + + self.cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + id TEXT PRIMARY KEY, + document JSONB, + embedding vector({dimension}) + ) + """ + ) + + async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + assert len(chunks) == len( + embeddings + ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" + + values = [] + for i, chunk in enumerate(chunks): + print(f"Adding chunk #{i} tokens={chunk.token_count}") + values.append( + ( + f"{chunk.document_id}:chunk-{i}", + Json(chunk.dict()), + embeddings[i].tolist(), + ) + ) + + query = sql.SQL( + f""" + INSERT INTO {self.table_name} (id, document, embedding) + VALUES %s + ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document + """ + ) + execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)") + + async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: + self.cursor.execute( + f""" + SELECT document, embedding <-> %s::vector AS distance + FROM {self.table_name} + ORDER BY distance + LIMIT %s + """, + (embedding.tolist(), k), + ) + results = self.cursor.fetchall() + + chunks = [] + scores = [] + for doc, dist in results: + chunks.append(Chunk(**doc)) + scores.append(1.0 / float(dist)) + + return QueryDocumentsResponse(chunks=chunks, scores=scores) + + +class PGVectorMemoryAdapter(Memory): + def __init__(self, config: PGVectorConfig) -> None: + print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}") + self.config = config + self.cursor = None + self.conn = None + self.cache = {} + + async def initialize(self) -> None: + try: + self.conn = psycopg2.connect( + host=self.config.host, + port=self.config.port, + database=self.config.db, + user=self.config.user, + password=self.config.password, + ) + self.cursor = self.conn.cursor() + + version = check_extension_version(self.cursor) + if version: + print(f"Vector extension version: {version}") + else: + raise RuntimeError("Vector extension is not installed.") + + self.cursor.execute( + """ + CREATE TABLE 
IF NOT EXISTS metadata_store ( + key TEXT PRIMARY KEY, + data JSONB + ) + """ + ) + except Exception as e: + import traceback + + traceback.print_exc() + raise RuntimeError("Could not connect to PGVector database server") from e + + async def shutdown(self) -> None: + pass + + async def create_memory_bank( + self, + name: str, + config: MemoryBankConfig, + url: Optional[URL] = None, + ) -> MemoryBank: + bank_id = str(uuid.uuid4()) + bank = MemoryBank( + bank_id=bank_id, + name=name, + config=config, + url=url, + ) + upsert_models( + self.cursor, + [ + (bank.bank_id, bank), + ], + ) + index = BankWithIndex( + bank=bank, + index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), + ) + self.cache[bank_id] = index + return bank + + async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: + bank_index = await self._get_and_cache_bank_index(bank_id) + if bank_index is None: + return None + return bank_index.bank + + async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: + if bank_id in self.cache: + return self.cache[bank_id] + + banks = load_models(self.cursor, [bank_id], MemoryBank) + if not banks: + return None + + bank = banks[0] + index = BankWithIndex( + bank=bank, + index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), + ) + self.cache[bank_id] = index + return index + + async def insert_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, + ) -> None: + index = await self._get_and_cache_bank_index(bank_id) + if not index: + raise ValueError(f"Bank {bank_id} not found") + + await index.insert_documents(documents) + + async def query_documents( + self, + bank_id: str, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + index = await self._get_and_cache_bank_index(bank_id) + if not index: + raise ValueError(f"Bank {bank_id} not found") + + return await index.query_documents(query, params) diff --git a/llama_toolchain/memory/common/vector_store.py b/llama_toolchain/memory/common/vector_store.py new file mode 100644 index 000000000..154deea18 --- /dev/null +++ b/llama_toolchain/memory/common/vector_store.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
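Both new adapters convert the backend's raw vector distance into a retrieval score as `1 / distance`, so nearer chunks rank higher. A one-line sketch of the convention; note it is undefined at `distance == 0` (an exact match), which the adapter code above does not guard:

```
def distance_to_score(distance: float) -> float:
    # Convention used by ChromaIndex.query and PGVectorIndex.query above.
    return 1.0 / float(distance)


print(distance_to_score(0.25))  # 4.0
print(distance_to_score(1.0))   # 1.0
```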
+ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import httpx +import numpy as np +from numpy.typing import NDArray + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.tokenizer import Tokenizer + +from llama_toolchain.memory.api import * # noqa: F403 + + +ALL_MINILM_L6_V2_DIMENSION = 384 + +EMBEDDING_MODEL = None + + +def get_embedding_model() -> "SentenceTransformer": + global EMBEDDING_MODEL + + if EMBEDDING_MODEL is None: + print("Loading sentence transformer") + + from sentence_transformers import SentenceTransformer + + EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2") + + return EMBEDDING_MODEL + + +async def content_from_doc(doc: MemoryBankDocument) -> str: + if isinstance(doc.content, URL): + async with httpx.AsyncClient() as client: + r = await client.get(doc.content.uri) + return r.text + + return interleaved_text_media_as_str(doc.content) + + +def make_overlapped_chunks( + document_id: str, text: str, window_len: int, overlap_len: int +) -> List[Chunk]: + tokenizer = Tokenizer.get_instance() + tokens = tokenizer.encode(text, bos=False, eos=False) + + chunks = [] + for i in range(0, len(tokens), window_len - overlap_len): + toks = tokens[i : i + window_len] + chunk = tokenizer.decode(toks) + chunks.append( + Chunk(content=chunk, token_count=len(toks), document_id=document_id) + ) + + return chunks + + +class EmbeddingIndex(ABC): + @abstractmethod + async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + raise NotImplementedError() + + @abstractmethod + async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: + raise NotImplementedError() + + +@dataclass +class BankWithIndex: + bank: MemoryBank + index: EmbeddingIndex + + async def insert_documents( + self, + documents: List[MemoryBankDocument], + ) -> None: + model = get_embedding_model() + for doc in documents: + content = await content_from_doc(doc) + chunks = make_overlapped_chunks( + doc.document_id, + content, + self.bank.config.chunk_size_in_tokens, + self.bank.config.overlap_size_in_tokens + or (self.bank.config.chunk_size_in_tokens // 4), + ) + embeddings = model.encode([x.content for x in chunks]).astype(np.float32) + + await self.index.add_chunks(chunks, embeddings) + + async def query_documents( + self, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + if params is None: + params = {} + k = params.get("max_chunks", 3) + + def _process(c) -> str: + if isinstance(c, str): + return c + else: + return "" + + if isinstance(query, list): + query_str = " ".join([_process(c) for c in query]) + else: + query_str = _process(query) + + model = get_embedding_model() + query_vector = model.encode([query_str])[0].astype(np.float32) + return await self.index.query(query_vector, k) diff --git a/llama_toolchain/memory/meta_reference/faiss/faiss.py b/llama_toolchain/memory/meta_reference/faiss/faiss.py index 422674939..807aa208f 100644 --- a/llama_toolchain/memory/meta_reference/faiss/faiss.py +++ b/llama_toolchain/memory/meta_reference/faiss/faiss.py @@ -5,108 +5,45 @@ # the root directory of this source tree. 
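`make_overlapped_chunks` in the shared `vector_store.py` above slides a window of `window_len` tokens with stride `window_len - overlap_len`. A tiny sketch of the windowing arithmetic, with integers standing in for tokens:

```
def overlapped_windows(tokens, window_len, overlap_len):
    # Same stride as make_overlapped_chunks above.
    step = window_len - overlap_len
    return [tokens[i : i + window_len] for i in range(0, len(tokens), step)]


print(overlapped_windows(list(range(10)), window_len=4, overlap_len=1))
# [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9], [9]]
```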
import uuid -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple + +from typing import Any, Dict, List, Optional import faiss -import httpx import numpy as np +from numpy.typing import NDArray from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_models.llama3.api.tokenizer import Tokenizer from llama_toolchain.memory.api import * # noqa: F403 +from llama_toolchain.memory.common.vector_store import ( + ALL_MINILM_L6_V2_DIMENSION, + BankWithIndex, + EmbeddingIndex, +) from .config import FaissImplConfig -async def content_from_doc(doc: MemoryBankDocument) -> str: - if isinstance(doc.content, URL): - async with httpx.AsyncClient() as client: - r = await client.get(doc.content.uri) - return r.text +class FaissIndex(EmbeddingIndex): + id_by_index: Dict[int, str] + chunk_by_index: Dict[int, str] - return interleaved_text_media_as_str(doc.content) + def __init__(self, dimension: int): + self.index = faiss.IndexFlatL2(dimension) + self.id_by_index = {} + self.chunk_by_index = {} + async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + indexlen = len(self.id_by_index) + for i, chunk in enumerate(chunks): + self.chunk_by_index[indexlen + i] = chunk + print(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}") + self.id_by_index[indexlen + i] = chunk.document_id -def make_overlapped_chunks( - text: str, window_len: int, overlap_len: int -) -> List[Tuple[str, int]]: - tokenizer = Tokenizer.get_instance() - tokens = tokenizer.encode(text, bos=False, eos=False) + self.index.add(np.array(embeddings).astype(np.float32)) - chunks = [] - for i in range(0, len(tokens), window_len - overlap_len): - toks = tokens[i : i + window_len] - chunk = tokenizer.decode(toks) - chunks.append((chunk, len(toks))) - - return chunks - - -@dataclass -class BankState: - bank: MemoryBank - index: Optional[faiss.IndexFlatL2] = None - doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict) - id_by_index: Dict[int, str] = field(default_factory=dict) - chunk_by_index: Dict[int, str] = field(default_factory=dict) - - async def insert_documents( - self, - model: "SentenceTransformer", - documents: List[MemoryBankDocument], - ) -> None: - tokenizer = Tokenizer.get_instance() - chunk_size = self.bank.config.chunk_size_in_tokens - - for doc in documents: - indexlen = len(self.id_by_index) - self.doc_by_id[doc.document_id] = doc - - content = await content_from_doc(doc) - chunks = make_overlapped_chunks( - content, - self.bank.config.chunk_size_in_tokens, - self.bank.config.overlap_size_in_tokens - or (self.bank.config.chunk_size_in_tokens // 4), - ) - embeddings = model.encode([x[0] for x in chunks]).astype(np.float32) - await self._ensure_index(embeddings.shape[1]) - - self.index.add(embeddings) - for i, chunk in enumerate(chunks): - self.chunk_by_index[indexlen + i] = Chunk( - content=chunk[0], - token_count=chunk[1], - document_id=doc.document_id, - ) - print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}") - self.id_by_index[indexlen + i] = doc.document_id - - async def query_documents( - self, - model: "SentenceTransformer", - query: InterleavedTextMedia, - params: Optional[Dict[str, Any]] = None, - ) -> QueryDocumentsResponse: - if params is None: - params = {} - k = params.get("max_chunks", 3) - - def _process(c) -> str: - if isinstance(c, str): - return c - else: - return "" - - if isinstance(query, list): - query_str = " ".join([_process(c) for c in query]) - else: - query_str = _process(query) - - query_vector = 
model.encode([query_str])[0] + async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: distances, indices = self.index.search( - query_vector.reshape(1, -1).astype(np.float32), k + embedding.reshape(1, -1).astype(np.float32), k ) chunks = [] @@ -119,17 +56,11 @@ class BankState: return QueryDocumentsResponse(chunks=chunks, scores=scores) - async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2: - if self.index is None: - self.index = faiss.IndexFlatL2(dimension) - return self.index - class FaissMemoryImpl(Memory): def __init__(self, config: FaissImplConfig) -> None: self.config = config - self.model = None - self.states = {} + self.cache = {} async def initialize(self) -> None: ... @@ -153,14 +84,15 @@ class FaissMemoryImpl(Memory): config=config, url=url, ) - state = BankState(bank=bank) - self.states[bank_id] = state + index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)) + self.cache[bank_id] = index return bank async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: - if bank_id not in self.states: + index = self.cache.get(bank_id) + if index is None: return None - return self.states[bank_id].bank + return index.bank async def insert_documents( self, @@ -168,10 +100,11 @@ class FaissMemoryImpl(Memory): documents: List[MemoryBankDocument], ttl_seconds: Optional[int] = None, ) -> None: - assert bank_id in self.states, f"Bank {bank_id} not found" - state = self.states[bank_id] + index = self.cache.get(bank_id) + if index is None: + raise ValueError(f"Bank {bank_id} not found") - await state.insert_documents(self.get_model(), documents) + await index.insert_documents(documents) async def query_documents( self, @@ -179,16 +112,8 @@ class FaissMemoryImpl(Memory): query: InterleavedTextMedia, params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: - assert bank_id in self.states, f"Bank {bank_id} not found" - state = self.states[bank_id] + index = self.cache.get(bank_id) + if index is None: + raise ValueError(f"Bank {bank_id} not found") - return await state.query_documents(self.get_model(), query, params) - - def get_model(self) -> "SentenceTransformer": - from sentence_transformers import SentenceTransformer - - if self.model is None: - print("Loading sentence transformer") - self.model = SentenceTransformer("all-MiniLM-L6-v2") - - return self.model + return await index.query_documents(query, params) diff --git a/llama_toolchain/memory/providers.py b/llama_toolchain/memory/providers.py index f8675c344..cc113d132 100644 --- a/llama_toolchain/memory/providers.py +++ b/llama_toolchain/memory/providers.py @@ -6,20 +6,38 @@ from typing import List -from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec +from llama_toolchain.core.datatypes import * # noqa: F403 + +EMBEDDING_DEPS = [ + "blobfile", + "sentence-transformers", +] def available_memory_providers() -> List[ProviderSpec]: return [ InlineProviderSpec( api=Api.memory, - provider_id="meta-reference-faiss", - pip_packages=[ - "blobfile", - "faiss-cpu", - "sentence-transformers", - ], + provider_type="meta-reference-faiss", + pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], module="llama_toolchain.memory.meta_reference.faiss", config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig", ), + remote_provider_spec( + api=Api.memory, + adapter=AdapterSpec( + adapter_id="chromadb", + pip_packages=EMBEDDING_DEPS + ["chromadb-client"], + module="llama_toolchain.memory.adapters.chroma", + ), + ), + remote_provider_spec( + 
diff --git a/llama_toolchain/safety/providers.py b/llama_toolchain/safety/providers.py
index dfacf3f67..8471ab139 100644
--- a/llama_toolchain/safety/providers.py
+++ b/llama_toolchain/safety/providers.py
@@ -13,7 +13,7 @@ def available_safety_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.safety,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
             pip_packages=[
                 "accelerate",
                 "codeshield",
diff --git a/llama_toolchain/stack.py b/llama_toolchain/stack.py
index 6ec05896d..875bc5802 100644
--- a/llama_toolchain/stack.py
+++ b/llama_toolchain/stack.py
@@ -11,7 +11,7 @@ from llama_toolchain.evaluations.api import *  # noqa: F403
 from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.batch_inference.api import *  # noqa: F403
 from llama_toolchain.memory.api import *  # noqa: F403
-from llama_toolchain.observability.api import *  # noqa: F403
+from llama_toolchain.telemetry.api import *  # noqa: F403
 from llama_toolchain.post_training.api import *  # noqa: F403
 from llama_toolchain.reward_scoring.api import *  # noqa: F403
 from llama_toolchain.synthetic_data_generation.api import *  # noqa: F403
@@ -24,7 +24,7 @@ class LlamaStack(
     RewardScoring,
     SyntheticDataGeneration,
     Datasets,
-    Observability,
+    Telemetry,
     PostTraining,
     Memory,
     Evaluations,
diff --git a/llama_toolchain/observability/__init__.py b/llama_toolchain/telemetry/__init__.py
similarity index 100%
rename from llama_toolchain/observability/__init__.py
rename to llama_toolchain/telemetry/__init__.py
diff --git a/llama_toolchain/observability/api/__init__.py b/llama_toolchain/telemetry/api/__init__.py
similarity index 100%
rename from llama_toolchain/observability/api/__init__.py
rename to llama_toolchain/telemetry/api/__init__.py
diff --git a/llama_toolchain/observability/api/api.py b/llama_toolchain/telemetry/api/api.py
similarity index 99%
rename from llama_toolchain/observability/api/api.py
rename to llama_toolchain/telemetry/api/api.py
index 86a5cc703..ae784428b 100644
--- a/llama_toolchain/observability/api/api.py
+++ b/llama_toolchain/telemetry/api/api.py
@@ -134,7 +134,7 @@ class LogSearchRequest(BaseModel):
     filters: Optional[Dict[str, Any]] = None


-class Observability(Protocol):
+class Telemetry(Protocol):
     @webmethod(route="/experiments/create")
     def create_experiment(self, request: CreateExperimentRequest) -> Experiment: ...
diff --git a/requirements.txt b/requirements.txt
index bf61af71b..720f84b79 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ blobfile
 fire
 httpx
 huggingface-hub
-llama-models
+llama-models>=0.0.13
 pydantic
 requests
 termcolor
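Not part of the patch: a compatibility sketch for downstream code that imported the old module path. The new path and class name follow the rename above; the fallback branch is only an assumption about how one might support pre-rename toolchain releases.

```python
try:
    from llama_toolchain.telemetry.api import Telemetry  # post-rename path
except ImportError:
    # older llama_toolchain releases only (hypothetical shim)
    from llama_toolchain.observability.api import Observability as Telemetry
```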
diff --git a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html
index d417f02f3..f5f5fa154 100644
--- a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html
+++ b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html
@@ -21,7 +21,7 @@
     "info": {
         "title": "[DRAFT] Llama Stack Specification",
         "version": "0.0.1",
-        "description": "This is the specification of the llama stack that provides\n    a set of endpoints and their corresponding interfaces that are tailored to\n    best leverage Llama Models. The specification is still in draft and subject to change.\n    Generated at 2024-09-04 10:28:38.779789"
+        "description": "This is the specification of the llama stack that provides\n    a set of endpoints and their corresponding interfaces that are tailored to\n    best leverage Llama Models. The specification is still in draft and subject to change.\n    Generated at 2024-09-10 01:13:08.531639"
     },
     "servers": [
         {
@@ -51,7 +51,7 @@
                 "content": {
                     "application/json": {
                         "schema": {
-                            "$ref": "#/components/schemas/BatchChatCompletionRequest"
+                            "$ref": "#/components/schemas/BatchChatCompletionRequestWrapper"
                        }
                    }
                },
@@ -81,7 +81,7 @@
                 "content": {
                     "application/json": {
                         "schema": {
-                            "$ref": "#/components/schemas/BatchCompletionRequest"
+                            "$ref": "#/components/schemas/BatchCompletionRequestWrapper"
                        }
                    }
                },
@@ -141,7 +141,7 @@
                     "200": {
                         "description": "SSE-stream of these events.",
                         "content": {
-                            "application/json": {
+                            "text/event-stream": {
                                 "schema": {
                                     "$ref": "#/components/schemas/ChatCompletionResponseStreamChunk"
                                 }
@@ -157,7 +157,7 @@
                 "content": {
                     "application/json": {
                         "schema": {
-                            "$ref": "#/components/schemas/ChatCompletionRequest"
+                            "$ref": "#/components/schemas/ChatCompletionRequestWrapper"
                        }
                    }
                },
@@ -187,7 +187,7 @@
                 "content": {
                     "application/json": {
                         "schema": {
-                            "$ref": "#/components/schemas/CompletionRequest"
+                            "$ref": "#/components/schemas/CompletionRequestWrapper"
                        }
                    }
                },
@@ -277,7 +277,7 @@
                 "content": {
                     "application/json": {
                         "schema": {
-                            "$ref": "#/components/schemas/AgenticSystemTurnCreateRequest"
+                            "$ref": "#/components/schemas/AgenticSystemTurnCreateRequestWrapper"
                        }
                    }
                },
@@ -300,7 +300,7 @@
                 "content": {
                     "application/json": {
                         "schema": {
-                            "$ref": "#/components/schemas/CreateDatasetRequest"
+                            "$ref": "#/components/schemas/CreateDatasetRequestWrapper"
                        }
                    }
                },
@@ -323,14 +323,14 @@
                 }
             },
             "tags": [
-                "Observability"
+                "Telemetry"
             ],
             "parameters": [],
             "requestBody": {
                 "content": {
                     "application/json": {
                         "schema": {
-                            "$ref": "#/components/schemas/CreateExperimentRequest"
+                            "$ref": "#/components/schemas/CreateExperimentRequestWrapper"
                        }
                    }
                },
@@ -383,14 +383,14 @@
                 }
             },
             "tags": [
-                "Observability"
+                "Telemetry"
             ],
             "parameters": [],
             "requestBody": {
                 "content": {
                     "application/json": {
                         "schema": {
-                            "$ref": "#/components/schemas/CreateRunRequest"
+                            "$ref": "#/components/schemas/CreateRunRequestWrapper"
                        }
                    }
                },
@@ -572,7 +572,7 @@
                 "content": {
                     "application/json": {
                         "schema": {
-                            "$ref": "#/components/schemas/EvaluateQuestionAnsweringRequest"
+                            "$ref": "#/components/schemas/EvaluateQuestionAnsweringRequestWrapper"
                        }
                    }
                },
@@ -602,7 +602,7 @@
                 "content": {
                     "application/json": {
                         "schema": {
-                            "$ref": "#/components/schemas/EvaluateSummarizationRequest"
+                            "$ref": "#/components/schemas/EvaluateSummarizationRequestWrapper"
                        }
                    }
                },
@@ -632,7 +632,7 @@
                 "content": {
                     "application/json": {
                         "schema": {
-                            "$ref": "#/components/schemas/EvaluateTextGenerationRequest"
+                            "$ref": "#/components/schemas/EvaluateTextGenerationRequestWrapper"
                        }
                    }
                },
@@ -784,7 +784,7 @@
                 }
             },
             "tags": [
-                "Observability"
+                "Telemetry"
             ],
             "parameters": [
                 {
@@ -988,7 +988,7 @@
                 }
             },
             "tags": [
-                "Observability"
+                "Telemetry"
             ],
             "parameters": [
                 {
@@ -1017,14 +1017,14 @@
                 }
             },
             "tags": [
-                "Observability"
+                "Telemetry"
             ],
             "parameters": [],
             "requestBody": {
                 "content": {
                     "application/json": {
                         "schema": {
-                            "$ref": "#/components/schemas/LogSearchRequest"
+                            "$ref": "#/components/schemas/LogSearchRequestWrapper"
                        }
                    }
                },
@@ -1083,7 +1083,7 @@
                 }
             },
             "tags": [
-                "Observability"
+                "Telemetry"
             ],
             "parameters": [
                 {
"Telemetry" ], "parameters": [], "requestBody": { @@ -1272,7 +1272,7 @@ } }, "tags": [ - "Observability" + "Telemetry" ], "parameters": [] } @@ -1305,14 +1305,14 @@ } }, "tags": [ - "Observability" + "Telemetry" ], "parameters": [], "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/LogMessagesRequest" + "$ref": "#/components/schemas/LogMessagesRequestWrapper" } } }, @@ -1328,14 +1328,14 @@ } }, "tags": [ - "Observability" + "Telemetry" ], "parameters": [], "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/LogMetricsRequest" + "$ref": "#/components/schemas/LogMetricsRequestWrapper" } } }, @@ -1365,7 +1365,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/PostTrainingRLHFRequest" + "$ref": "#/components/schemas/PostTrainingRLHFRequestWrapper" } } }, @@ -1425,7 +1425,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/RewardScoringRequest" + "$ref": "#/components/schemas/RewardScoringRequestWrapper" } } }, @@ -1455,7 +1455,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/PostTrainingSFTRequest" + "$ref": "#/components/schemas/PostTrainingSFTRequestWrapper" } } }, @@ -1485,7 +1485,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/SyntheticDataGenerationRequest" + "$ref": "#/components/schemas/SyntheticDataGenerationRequestWrapper" } } }, @@ -1531,14 +1531,14 @@ } }, "tags": [ - "Observability" + "Telemetry" ], "parameters": [], "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/UpdateExperimentRequest" + "$ref": "#/components/schemas/UpdateExperimentRequestWrapper" } } }, @@ -1561,14 +1561,14 @@ } }, "tags": [ - "Observability" + "Telemetry" ], "parameters": [], "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/UpdateRunRequest" + "$ref": "#/components/schemas/UpdateRunRequestWrapper" } } }, @@ -1591,14 +1591,14 @@ } }, "tags": [ - "Observability" + "Telemetry" ], "parameters": [], "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/UploadArtifactRequest" + "$ref": "#/components/schemas/UploadArtifactRequestWrapper" } } }, @@ -2020,6 +2020,18 @@ "content" ] }, + "BatchChatCompletionRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/BatchChatCompletionRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "BatchChatCompletionResponse": { "type": "object", "properties": { @@ -2076,6 +2088,18 @@ "content_batch" ] }, + "BatchCompletionRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/BatchCompletionRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "BatchCompletionResponse": { "type": "object", "properties": { @@ -2174,6 +2198,18 @@ "messages" ] }, + "ChatCompletionRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/ChatCompletionRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "ChatCompletionResponseEvent": { "type": "object", "properties": { @@ -2316,6 +2352,18 @@ "content" ] }, + "CompletionRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/CompletionRequest" + } + }, + "additionalProperties": false, + "required": [ + 
"request" + ] + }, "CompletionResponseStreamChunk": { "type": "object", "properties": { @@ -2376,7 +2424,183 @@ "$ref": "#/components/schemas/FunctionCallToolDefinition" }, { - "$ref": "#/components/schemas/MemoryToolDefinition" + "type": "object", + "properties": { + "input_shields": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ShieldDefinition" + } + }, + "output_shields": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ShieldDefinition" + } + }, + "type": { + "type": "string", + "const": "memory" + }, + "memory_bank_configs": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "object", + "properties": { + "bank_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "vector" + } + }, + "additionalProperties": false, + "required": [ + "bank_id", + "type" + ] + }, + { + "type": "object", + "properties": { + "bank_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "keyvalue" + }, + "keys": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "required": [ + "bank_id", + "type", + "keys" + ] + }, + { + "type": "object", + "properties": { + "bank_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "keyword" + } + }, + "additionalProperties": false, + "required": [ + "bank_id", + "type" + ] + }, + { + "type": "object", + "properties": { + "bank_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "graph" + }, + "entities": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "required": [ + "bank_id", + "type", + "entities" + ] + } + ] + } + }, + "query_generator_config": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "default" + }, + "sep": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "sep" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "llm" + }, + "model": { + "type": "string" + }, + "template": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "model", + "template" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "custom" + } + }, + "additionalProperties": false, + "required": [ + "type" + ] + } + ] + }, + "max_tokens_in_context": { + "type": "integer" + }, + "max_chunks": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "type", + "memory_bank_configs", + "query_generator_config", + "max_tokens_in_context", + "max_chunks" + ] } ] } @@ -2513,129 +2737,6 @@ "parameters" ] }, - "MemoryToolDefinition": { - "type": "object", - "properties": { - "input_shields": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ShieldDefinition" - } - }, - "output_shields": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ShieldDefinition" - } - }, - "type": { - "type": "string", - "const": "memory" - }, - "memory_bank_configs": { - "type": "array", - "items": { - "oneOf": [ - { - "type": "object", - "properties": { - "bank_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "vector" - } - }, - "additionalProperties": false, - "required": [ - "bank_id", - "type" - ] - }, - { - "type": "object", - "properties": { - "bank_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "keyvalue" - }, - "keys": { - "type": "array", - "items": { - "type": "string" - 
@@ -2513,129 +2737,6 @@
                 "parameters"
             ]
         },
-        "MemoryToolDefinition": {
-            "type": "object",
-            "properties": {
-                "input_shields": {
-                    "type": "array",
-                    "items": {
-                        "$ref": "#/components/schemas/ShieldDefinition"
-                    }
-                },
-                "output_shields": {
-                    "type": "array",
-                    "items": {
-                        "$ref": "#/components/schemas/ShieldDefinition"
-                    }
-                },
-                "type": {
-                    "type": "string",
-                    "const": "memory"
-                },
-                "memory_bank_configs": {
-                    "type": "array",
-                    "items": {
-                        "oneOf": [
-                            {
-                                "type": "object",
-                                "properties": {
-                                    "bank_id": {
-                                        "type": "string"
-                                    },
-                                    "type": {
-                                        "type": "string",
-                                        "const": "vector"
-                                    }
-                                },
-                                "additionalProperties": false,
-                                "required": [
-                                    "bank_id",
-                                    "type"
-                                ]
-                            },
-                            {
-                                "type": "object",
-                                "properties": {
-                                    "bank_id": {
-                                        "type": "string"
-                                    },
-                                    "type": {
-                                        "type": "string",
-                                        "const": "keyvalue"
-                                    },
-                                    "keys": {
-                                        "type": "array",
-                                        "items": {
-                                            "type": "string"
-                                        }
-                                    }
-                                },
-                                "additionalProperties": false,
-                                "required": [
-                                    "bank_id",
-                                    "type",
-                                    "keys"
-                                ]
-                            },
-                            {
-                                "type": "object",
-                                "properties": {
-                                    "bank_id": {
-                                        "type": "string"
-                                    },
-                                    "type": {
-                                        "type": "string",
-                                        "const": "keyword"
-                                    }
-                                },
-                                "additionalProperties": false,
-                                "required": [
-                                    "bank_id",
-                                    "type"
-                                ]
-                            },
-                            {
-                                "type": "object",
-                                "properties": {
-                                    "bank_id": {
-                                        "type": "string"
-                                    },
-                                    "type": {
-                                        "type": "string",
-                                        "const": "graph"
-                                    },
-                                    "entities": {
-                                        "type": "array",
-                                        "items": {
-                                            "type": "string"
-                                        }
-                                    }
-                                },
-                                "additionalProperties": false,
-                                "required": [
-                                    "bank_id",
-                                    "type",
-                                    "entities"
-                                ]
-                            }
-                        ]
-                    }
-                },
-                "max_tokens_in_context": {
-                    "type": "integer"
-                },
-                "max_chunks": {
-                    "type": "integer"
-                }
-            },
-            "additionalProperties": false,
-            "required": [
-                "type",
-                "memory_bank_configs",
-                "max_tokens_in_context",
-                "max_chunks"
-            ]
-        },
         "OnViolationAction": {
             "type": "integer",
             "enum": [
@@ -2873,7 +2974,183 @@
                         "$ref": "#/components/schemas/FunctionCallToolDefinition"
                     },
                     {
-                        "$ref": "#/components/schemas/MemoryToolDefinition"
+                        "type": "object",
+                        "properties": {
+                            "input_shields": {
+                                "type": "array",
+                                "items": {
+                                    "$ref": "#/components/schemas/ShieldDefinition"
+                                }
+                            },
+                            "output_shields": {
+                                "type": "array",
+                                "items": {
+                                    "$ref": "#/components/schemas/ShieldDefinition"
+                                }
+                            },
+                            "type": {
+                                "type": "string",
+                                "const": "memory"
+                            },
+                            "memory_bank_configs": {
+                                "type": "array",
+                                "items": {
+                                    "oneOf": [
+                                        {
+                                            "type": "object",
+                                            "properties": {
+                                                "bank_id": {
+                                                    "type": "string"
+                                                },
+                                                "type": {
+                                                    "type": "string",
+                                                    "const": "vector"
+                                                }
+                                            },
+                                            "additionalProperties": false,
+                                            "required": [
+                                                "bank_id",
+                                                "type"
+                                            ]
+                                        },
+                                        {
+                                            "type": "object",
+                                            "properties": {
+                                                "bank_id": {
+                                                    "type": "string"
+                                                },
+                                                "type": {
+                                                    "type": "string",
+                                                    "const": "keyvalue"
+                                                },
+                                                "keys": {
+                                                    "type": "array",
+                                                    "items": {
+                                                        "type": "string"
+                                                    }
+                                                }
+                                            },
+                                            "additionalProperties": false,
+                                            "required": [
+                                                "bank_id",
+                                                "type",
+                                                "keys"
+                                            ]
+                                        },
+                                        {
+                                            "type": "object",
+                                            "properties": {
+                                                "bank_id": {
+                                                    "type": "string"
+                                                },
+                                                "type": {
+                                                    "type": "string",
+                                                    "const": "keyword"
+                                                }
+                                            },
+                                            "additionalProperties": false,
+                                            "required": [
+                                                "bank_id",
+                                                "type"
+                                            ]
+                                        },
+                                        {
+                                            "type": "object",
+                                            "properties": {
+                                                "bank_id": {
+                                                    "type": "string"
+                                                },
+                                                "type": {
+                                                    "type": "string",
+                                                    "const": "graph"
+                                                },
+                                                "entities": {
+                                                    "type": "array",
+                                                    "items": {
+                                                        "type": "string"
+                                                    }
+                                                }
+                                            },
+                                            "additionalProperties": false,
+                                            "required": [
+                                                "bank_id",
+                                                "type",
+                                                "entities"
+                                            ]
+                                        }
+                                    ]
+                                }
+                            },
+                            "query_generator_config": {
+                                "oneOf": [
+                                    {
+                                        "type": "object",
+                                        "properties": {
+                                            "type": {
+                                                "type": "string",
+                                                "const": "default"
+                                            },
+                                            "sep": {
+                                                "type": "string"
+                                            }
+                                        },
+                                        "additionalProperties": false,
+                                        "required": [
+                                            "type",
+                                            "sep"
+                                        ]
+                                    },
+                                    {
+                                        "type": "object",
+                                        "properties": {
+                                            "type": {
+                                                "type": "string",
+                                                "const": "llm"
+                                            },
+                                            "model": {
+                                                "type": "string"
+                                            },
+                                            "template": {
+                                                "type": "string"
+                                            }
+                                        },
+                                        "additionalProperties": false,
+                                        "required": [
+                                            "type",
+                                            "model",
+                                            "template"
+                                        ]
+                                    },
+                                    {
+                                        "type": "object",
+                                        "properties": {
+                                            "type": {
+                                                "type": "string",
+                                                "const": "custom"
+                                            }
+                                        },
+                                        "additionalProperties": false,
+                                        "required": [
+                                            "type"
+                                        ]
+                                    }
+                                ]
+                            },
+                            "max_tokens_in_context": {
+                                "type": "integer"
+                            },
+                            "max_chunks": {
+                                "type": "integer"
+                            }
+                        },
+                        "additionalProperties": false,
+                        "required": [
+                            "type",
+                            "memory_bank_configs",
+                            "query_generator_config",
+                            "max_tokens_in_context",
+                            "max_chunks"
+                        ]
                     }
                 ]
             }
"AgenticSystemTurnCreateRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/AgenticSystemTurnCreateRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "AgenticSystemTurnResponseEvent": { "type": "object", "properties": { @@ -3523,6 +3812,18 @@ "json" ] }, + "CreateDatasetRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/CreateDatasetRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "CreateExperimentRequest": { "type": "object", "properties": { @@ -3560,6 +3861,18 @@ "name" ] }, + "CreateExperimentRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/CreateExperimentRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "Experiment": { "type": "object", "properties": { @@ -3832,6 +4145,18 @@ "experiment_id" ] }, + "CreateRunRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/CreateRunRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "Run": { "type": "object", "properties": { @@ -4044,6 +4369,18 @@ ], "title": "Request to evaluate question answering." }, + "EvaluateQuestionAnsweringRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/EvaluateQuestionAnsweringRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "EvaluationJob": { "type": "object", "properties": { @@ -4092,6 +4429,18 @@ ], "title": "Request to evaluate summarization." }, + "EvaluateSummarizationRequestWrapper": { + "type": "object", + "properties": { + "request": { + "$ref": "#/components/schemas/EvaluateSummarizationRequest" + } + }, + "additionalProperties": false, + "required": [ + "request" + ] + }, "EvaluateTextGenerationRequest": { "type": "object", "properties": { @@ -4129,6 +4478,18 @@ ], "title": "Request to evaluate text generation." 
@@ -4129,6 +4478,18 @@
                 ],
                 "title": "Request to evaluate text generation."
             },
+            "EvaluateTextGenerationRequestWrapper": {
+                "type": "object",
+                "properties": {
+                    "request": {
+                        "$ref": "#/components/schemas/EvaluateTextGenerationRequest"
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "request"
+                ]
+            },
             "GetAgenticSystemSessionRequest": {
                 "type": "object",
                 "properties": {
@@ -4414,6 +4775,18 @@
                     "query"
                 ]
             },
+            "LogSearchRequestWrapper": {
+                "type": "object",
+                "properties": {
+                    "request": {
+                        "$ref": "#/components/schemas/LogSearchRequest"
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "request"
+                ]
+            },
             "Log": {
                 "type": "object",
                 "properties": {
@@ -4673,6 +5046,18 @@
                     "logs"
                 ]
             },
+            "LogMessagesRequestWrapper": {
+                "type": "object",
+                "properties": {
+                    "request": {
+                        "$ref": "#/components/schemas/LogMessagesRequest"
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "request"
+                ]
+            },
             "LogMetricsRequest": {
                 "type": "object",
                 "properties": {
@@ -4692,6 +5077,18 @@
                     "metrics"
                 ]
             },
+            "LogMetricsRequestWrapper": {
+                "type": "object",
+                "properties": {
+                    "request": {
+                        "$ref": "#/components/schemas/LogMetricsRequest"
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "request"
+                ]
+            },
             "DPOAlignmentConfig": {
                 "type": "object",
                 "properties": {
@@ -4880,6 +5277,18 @@
                     "fsdp_cpu_offload"
                 ]
            },
+            "PostTrainingRLHFRequestWrapper": {
+                "type": "object",
+                "properties": {
+                    "request": {
+                        "$ref": "#/components/schemas/PostTrainingRLHFRequest"
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "request"
+                ]
+            },
             "QueryDocumentsRequest": {
                 "type": "object",
                 "properties": {
@@ -5048,6 +5457,18 @@
                 ],
                 "title": "Request to score a reward function. A list of prompts and a list of responses per prompt."
             },
+            "RewardScoringRequestWrapper": {
+                "type": "object",
+                "properties": {
+                    "request": {
+                        "$ref": "#/components/schemas/RewardScoringRequest"
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "request"
+                ]
+            },
             "RewardScoringResponse": {
                 "type": "object",
                 "properties": {
@@ -5333,6 +5754,18 @@
                     "alpha"
                 ]
             },
+            "PostTrainingSFTRequestWrapper": {
+                "type": "object",
+                "properties": {
+                    "request": {
+                        "$ref": "#/components/schemas/PostTrainingSFTRequest"
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "request"
+                ]
+            },
             "SyntheticDataGenerationRequest": {
                 "type": "object",
                 "properties": {
@@ -5378,6 +5811,18 @@
                 ],
                 "title": "Request to generate synthetic data. A small batch of prompts and a filtering function"
             },
+            "SyntheticDataGenerationRequestWrapper": {
+                "type": "object",
+                "properties": {
+                    "request": {
+                        "$ref": "#/components/schemas/SyntheticDataGenerationRequest"
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "request"
+                ]
+            },
             "SyntheticDataGenerationResponse": {
                 "type": "object",
                 "properties": {
@@ -5478,6 +5923,18 @@
                     "experiment_id"
                 ]
             },
+            "UpdateExperimentRequestWrapper": {
+                "type": "object",
+                "properties": {
+                    "request": {
+                        "$ref": "#/components/schemas/UpdateExperimentRequest"
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "request"
+                ]
+            },
             "UpdateRunRequest": {
                 "type": "object",
                 "properties": {
@@ -5522,6 +5979,18 @@
                     "run_id"
                 ]
             },
+            "UpdateRunRequestWrapper": {
+                "type": "object",
+                "properties": {
+                    "request": {
+                        "$ref": "#/components/schemas/UpdateRunRequest"
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "request"
+                ]
+            },
             "UploadArtifactRequest": {
                 "type": "object",
                 "properties": {
@@ -5571,6 +6040,18 @@
                     "artifact_type",
                     "content"
                 ]
+            },
+            "UploadArtifactRequestWrapper": {
+                "type": "object",
+                "properties": {
+                    "request": {
+                        "$ref": "#/components/schemas/UploadArtifactRequest"
+                    }
+                },
+                "additionalProperties": false,
+                "required": [
+                    "request"
+                ]
             }
         },
         "responses": {}
@@ -5582,10 +6063,10 @@
     ],
     "tags": [
         {
-            "name": "Observability"
+            "name": "SyntheticDataGeneration"
         },
         {
-            "name": "Inference"
+            "name": "RewardScoring"
         },
         {
             "name": "Datasets"
@@ -5597,19 +6078,19 @@
             "name": "AgenticSystem"
         },
         {
-            "name": "Evaluations"
-        },
-        {
-            "name": "SyntheticDataGeneration"
-        },
-        {
-            "name": "RewardScoring"
+            "name": "BatchInference"
         },
         {
             "name": "PostTraining"
         },
         {
-            "name": "BatchInference"
+            "name": "Evaluations"
+        },
+        {
+            "name": "Telemetry"
+        },
+        {
+            "name": "Inference"
         },
         {
             "name": "BatchChatCompletionRequest",
@@ -5667,6 +6148,10 @@
             "name": "UserMessage",
             "description": ""
         },
+        {
+            "name": "BatchChatCompletionRequestWrapper",
+            "description": ""
+        },
         {
             "name": "BatchChatCompletionResponse",
             "description": ""
@@ -5675,6 +6160,10 @@
             "name": "BatchCompletionRequest",
             "description": ""
         },
+        {
+            "name": "BatchCompletionRequestWrapper",
+            "description": ""
+        },
         {
             "name": "BatchCompletionResponse",
             "description": ""
@@ -5691,6 +6180,10 @@
             "name": "ChatCompletionRequest",
             "description": ""
         },
+        {
+            "name": "ChatCompletionRequestWrapper",
+            "description": ""
+        },
         {
             "name": "ChatCompletionResponseEvent",
             "description": "Chat completion response event.\n\n"
@@ -5719,6 +6212,10 @@
             "name": "CompletionRequest",
             "description": ""
         },
+        {
+            "name": "CompletionRequestWrapper",
+            "description": ""
+        },
         {
             "name": "CompletionResponseStreamChunk",
             "description": "streamed completion response.\n\n"
@@ -5743,10 +6240,6 @@
             "name": "FunctionCallToolDefinition",
             "description": ""
         },
-        {
-            "name": "MemoryToolDefinition",
-            "description": ""
-        },
         {
             "name": "OnViolationAction",
             "description": ""
@@ -5799,6 +6292,10 @@
             "name": "Attachment",
             "description": ""
         },
+        {
+            "name": "AgenticSystemTurnCreateRequestWrapper",
+            "description": ""
+        },
         {
             "name": "AgenticSystemTurnResponseEvent",
             "description": "Streamed agent execution response.\n\n"
@@ -5867,10 +6364,18 @@
             "name": "TrainEvalDatasetColumnType",
             "description": ""
         },
+        {
+            "name": "CreateDatasetRequestWrapper",
+            "description": ""
+        },
         {
             "name": "CreateExperimentRequest",
             "description": ""
         },
+        {
+            "name": "CreateExperimentRequestWrapper",
+            "description": ""
+        },
         {
             "name": "Experiment",
             "description": ""
"description": "" }, + { + "name": "CreateRunRequestWrapper", + "description": "" + }, { "name": "Run", "description": "" @@ -5931,6 +6440,10 @@ "name": "EvaluateQuestionAnsweringRequest", "description": "Request to evaluate question answering.\n\n" }, + { + "name": "EvaluateQuestionAnsweringRequestWrapper", + "description": "" + }, { "name": "EvaluationJob", "description": "" @@ -5939,10 +6452,18 @@ "name": "EvaluateSummarizationRequest", "description": "Request to evaluate summarization.\n\n" }, + { + "name": "EvaluateSummarizationRequestWrapper", + "description": "" + }, { "name": "EvaluateTextGenerationRequest", "description": "Request to evaluate text generation.\n\n" }, + { + "name": "EvaluateTextGenerationRequestWrapper", + "description": "" + }, { "name": "GetAgenticSystemSessionRequest", "description": "" @@ -5987,6 +6508,10 @@ "name": "LogSearchRequest", "description": "" }, + { + "name": "LogSearchRequestWrapper", + "description": "" + }, { "name": "Log", "description": "" @@ -6027,10 +6552,18 @@ "name": "LogMessagesRequest", "description": "" }, + { + "name": "LogMessagesRequestWrapper", + "description": "" + }, { "name": "LogMetricsRequest", "description": "" }, + { + "name": "LogMetricsRequestWrapper", + "description": "" + }, { "name": "DPOAlignmentConfig", "description": "" @@ -6051,6 +6584,10 @@ "name": "TrainingConfig", "description": "" }, + { + "name": "PostTrainingRLHFRequestWrapper", + "description": "" + }, { "name": "QueryDocumentsRequest", "description": "" @@ -6067,6 +6604,10 @@ "name": "RewardScoringRequest", "description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n" }, + { + "name": "RewardScoringRequestWrapper", + "description": "" + }, { "name": "RewardScoringResponse", "description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n" @@ -6099,10 +6640,18 @@ "name": "QLoraFinetuningConfig", "description": "" }, + { + "name": "PostTrainingSFTRequestWrapper", + "description": "" + }, { "name": "SyntheticDataGenerationRequest", "description": "Request to generate synthetic data. A small batch of prompts and a filtering function\n\n" }, + { + "name": "SyntheticDataGenerationRequestWrapper", + "description": "" + }, { "name": "SyntheticDataGenerationResponse", "description": "Response from the synthetic data generation. 
@@ -6115,13 +6664,25 @@
             "name": "UpdateExperimentRequest",
             "description": ""
         },
+        {
+            "name": "UpdateExperimentRequestWrapper",
+            "description": ""
+        },
         {
             "name": "UpdateRunRequest",
             "description": ""
         },
+        {
+            "name": "UpdateRunRequestWrapper",
+            "description": ""
+        },
         {
             "name": "UploadArtifactRequest",
             "description": ""
+        },
+        {
+            "name": "UploadArtifactRequestWrapper",
+            "description": ""
         }
     ],
     "x-tagGroups": [
@@ -6134,10 +6695,10 @@
                 "Evaluations",
                 "Inference",
                 "Memory",
-                "Observability",
                 "PostTraining",
                 "RewardScoring",
-                "SyntheticDataGeneration"
+                "SyntheticDataGeneration",
+                "Telemetry"
             ]
         },
         {
@@ -6148,6 +6709,7 @@
                 "AgenticSystemSessionCreateResponse",
                 "AgenticSystemStepResponse",
                 "AgenticSystemTurnCreateRequest",
+                "AgenticSystemTurnCreateRequestWrapper",
                 "AgenticSystemTurnResponseEvent",
                 "AgenticSystemTurnResponseStepCompletePayload",
                 "AgenticSystemTurnResponseStepProgressPayload",
@@ -6159,8 +6721,10 @@
                 "ArtifactType",
                 "Attachment",
                 "BatchChatCompletionRequest",
+                "BatchChatCompletionRequestWrapper",
                 "BatchChatCompletionResponse",
                 "BatchCompletionRequest",
+                "BatchCompletionRequestWrapper",
                 "BatchCompletionResponse",
                 "BraveSearchToolDefinition",
                 "BuiltinShield",
@@ -6168,6 +6732,7 @@
                 "CancelEvaluationJobRequest",
                 "CancelTrainingJobRequest",
                 "ChatCompletionRequest",
+                "ChatCompletionRequestWrapper",
                 "ChatCompletionResponseEvent",
                 "ChatCompletionResponseEventType",
                 "ChatCompletionResponseStreamChunk",
@@ -6175,13 +6740,17 @@
                 "CodeInterpreterToolDefinition",
                 "CompletionMessage",
                 "CompletionRequest",
+                "CompletionRequestWrapper",
                 "CompletionResponseStreamChunk",
                 "CreateAgenticSystemRequest",
                 "CreateAgenticSystemSessionRequest",
                 "CreateDatasetRequest",
+                "CreateDatasetRequestWrapper",
                 "CreateExperimentRequest",
+                "CreateExperimentRequestWrapper",
                 "CreateMemoryBankRequest",
                 "CreateRunRequest",
+                "CreateRunRequestWrapper",
                 "DPOAlignmentConfig",
                 "DeleteAgenticSystemRequest",
                 "DeleteAgenticSystemSessionRequest",
@@ -6193,8 +6762,11 @@
                 "EmbeddingsRequest",
                 "EmbeddingsResponse",
                 "EvaluateQuestionAnsweringRequest",
+                "EvaluateQuestionAnsweringRequestWrapper",
                 "EvaluateSummarizationRequest",
+                "EvaluateSummarizationRequestWrapper",
                 "EvaluateTextGenerationRequest",
+                "EvaluateTextGenerationRequestWrapper",
                 "EvaluationJob",
                 "EvaluationJobArtifactsResponse",
                 "EvaluationJobLogStream",
@@ -6210,13 +6782,15 @@
                 "ListArtifactsRequest",
                 "Log",
                 "LogMessagesRequest",
+                "LogMessagesRequestWrapper",
                 "LogMetricsRequest",
+                "LogMetricsRequestWrapper",
                 "LogSearchRequest",
+                "LogSearchRequestWrapper",
                 "LoraFinetuningConfig",
                 "MemoryBank",
                 "MemoryBankDocument",
                 "MemoryRetrievalStep",
-                "MemoryToolDefinition",
                 "Metric",
                 "OnViolationAction",
                 "OptimizerConfig",
@@ -6227,7 +6801,9 @@
                 "PostTrainingJobStatus",
                 "PostTrainingJobStatusResponse",
                 "PostTrainingRLHFRequest",
+                "PostTrainingRLHFRequestWrapper",
                 "PostTrainingSFTRequest",
+                "PostTrainingSFTRequestWrapper",
                 "QLoraFinetuningConfig",
                 "QueryDocumentsRequest",
                 "QueryDocumentsResponse",
@@ -6235,6 +6811,7 @@
                 "RestAPIExecutionConfig",
                 "RestAPIMethod",
                 "RewardScoringRequest",
+                "RewardScoringRequestWrapper",
                 "RewardScoringResponse",
                 "Run",
                 "SamplingParams",
@@ -6247,6 +6824,7 @@
                 "ShieldResponse",
                 "StopReason",
                 "SyntheticDataGenerationRequest",
+                "SyntheticDataGenerationRequestWrapper",
                 "SyntheticDataGenerationResponse",
                 "SystemMessage",
                 "TokenLogProbs",
@@ -6267,8 +6845,11 @@
                 "URL",
                 "UpdateDocumentsRequest",
                 "UpdateExperimentRequest",
+                "UpdateExperimentRequestWrapper",
                 "UpdateRunRequest",
+                "UpdateRunRequestWrapper",
                 "UploadArtifactRequest",
+                "UploadArtifactRequestWrapper",
                 "UserMessage",
                 "WolframAlphaToolDefinition"
             ]
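Not part of the patch: a sketch of an agent tool config that satisfies the inline "memory" tool schema introduced above, which now also requires `query_generator_config`. The bank IDs and numeric limits are placeholder values.

```python
# Matches the inline memory-tool schema from the spec diff above.
memory_tool = {
    "type": "memory",
    "memory_bank_configs": [
        {"bank_id": "my-vector-bank", "type": "vector"},
        {"bank_id": "my-graph-bank", "type": "graph", "entities": ["person", "place"]},
    ],
    # default generator joins query text with a separator; "llm" and "custom"
    # variants are also allowed by the schema
    "query_generator_config": {"type": "default", "sep": " "},
    "max_tokens_in_context": 4096,
    "max_chunks": 10,
}
```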
"UpdateRunRequestWrapper", "UploadArtifactRequest", + "UploadArtifactRequestWrapper", "UserMessage", "WolframAlphaToolDefinition" ] diff --git a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml index 7560f3c88..3c3475fff 100644 --- a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml +++ b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml @@ -30,7 +30,123 @@ components: - $ref: '#/components/schemas/PhotogenToolDefinition' - $ref: '#/components/schemas/CodeInterpreterToolDefinition' - $ref: '#/components/schemas/FunctionCallToolDefinition' - - $ref: '#/components/schemas/MemoryToolDefinition' + - additionalProperties: false + properties: + input_shields: + items: + $ref: '#/components/schemas/ShieldDefinition' + type: array + max_chunks: + type: integer + max_tokens_in_context: + type: integer + memory_bank_configs: + items: + oneOf: + - additionalProperties: false + properties: + bank_id: + type: string + type: + const: vector + type: string + required: + - bank_id + - type + type: object + - additionalProperties: false + properties: + bank_id: + type: string + keys: + items: + type: string + type: array + type: + const: keyvalue + type: string + required: + - bank_id + - type + - keys + type: object + - additionalProperties: false + properties: + bank_id: + type: string + type: + const: keyword + type: string + required: + - bank_id + - type + type: object + - additionalProperties: false + properties: + bank_id: + type: string + entities: + items: + type: string + type: array + type: + const: graph + type: string + required: + - bank_id + - type + - entities + type: object + type: array + output_shields: + items: + $ref: '#/components/schemas/ShieldDefinition' + type: array + query_generator_config: + oneOf: + - additionalProperties: false + properties: + sep: + type: string + type: + const: default + type: string + required: + - type + - sep + type: object + - additionalProperties: false + properties: + model: + type: string + template: + type: string + type: + const: llm + type: string + required: + - type + - model + - template + type: object + - additionalProperties: false + properties: + type: + const: custom + type: string + required: + - type + type: object + type: + const: memory + type: string + required: + - type + - memory_bank_configs + - query_generator_config + - max_tokens_in_context + - max_chunks + type: object type: array required: - model @@ -107,13 +223,137 @@ components: - $ref: '#/components/schemas/PhotogenToolDefinition' - $ref: '#/components/schemas/CodeInterpreterToolDefinition' - $ref: '#/components/schemas/FunctionCallToolDefinition' - - $ref: '#/components/schemas/MemoryToolDefinition' + - additionalProperties: false + properties: + input_shields: + items: + $ref: '#/components/schemas/ShieldDefinition' + type: array + max_chunks: + type: integer + max_tokens_in_context: + type: integer + memory_bank_configs: + items: + oneOf: + - additionalProperties: false + properties: + bank_id: + type: string + type: + const: vector + type: string + required: + - bank_id + - type + type: object + - additionalProperties: false + properties: + bank_id: + type: string + keys: + items: + type: string + type: array + type: + const: keyvalue + type: string + required: + - bank_id + - type + - keys + type: object + - additionalProperties: false + properties: + bank_id: + type: string + type: + const: keyword + type: string + required: + - bank_id + - type + type: object + - additionalProperties: false + 
@@ -107,13 +223,137 @@ components:
         - $ref: '#/components/schemas/PhotogenToolDefinition'
         - $ref: '#/components/schemas/CodeInterpreterToolDefinition'
         - $ref: '#/components/schemas/FunctionCallToolDefinition'
-        - $ref: '#/components/schemas/MemoryToolDefinition'
+        - additionalProperties: false
+          properties:
+            input_shields:
+              items:
+                $ref: '#/components/schemas/ShieldDefinition'
+              type: array
+            max_chunks:
+              type: integer
+            max_tokens_in_context:
+              type: integer
+            memory_bank_configs:
+              items:
+                oneOf:
+                - additionalProperties: false
+                  properties:
+                    bank_id:
+                      type: string
+                    type:
+                      const: vector
+                      type: string
+                  required:
+                  - bank_id
+                  - type
+                  type: object
+                - additionalProperties: false
+                  properties:
+                    bank_id:
+                      type: string
+                    keys:
+                      items:
+                        type: string
+                      type: array
+                    type:
+                      const: keyvalue
+                      type: string
+                  required:
+                  - bank_id
+                  - type
+                  - keys
+                  type: object
+                - additionalProperties: false
+                  properties:
+                    bank_id:
+                      type: string
+                    type:
+                      const: keyword
+                      type: string
+                  required:
+                  - bank_id
+                  - type
+                  type: object
+                - additionalProperties: false
+                  properties:
+                    bank_id:
+                      type: string
+                    entities:
+                      items:
+                        type: string
+                      type: array
+                    type:
+                      const: graph
+                      type: string
+                  required:
+                  - bank_id
+                  - type
+                  - entities
+                  type: object
+              type: array
+            output_shields:
+              items:
+                $ref: '#/components/schemas/ShieldDefinition'
+              type: array
+            query_generator_config:
+              oneOf:
+              - additionalProperties: false
+                properties:
+                  sep:
+                    type: string
+                  type:
+                    const: default
+                    type: string
+                required:
+                - type
+                - sep
+                type: object
+              - additionalProperties: false
+                properties:
+                  model:
+                    type: string
+                  template:
+                    type: string
+                  type:
+                    const: llm
+                    type: string
+                required:
+                - type
+                - model
+                - template
+                type: object
+              - additionalProperties: false
+                properties:
+                  type:
+                    const: custom
+                    type: string
+                required:
+                - type
+                type: object
+            type:
+              const: memory
+              type: string
+          required:
+          - type
+          - memory_bank_configs
+          - query_generator_config
+          - max_tokens_in_context
+          - max_chunks
+          type: object
         type: array
     required:
     - agent_id
     - session_id
     - messages
     type: object
+  AgenticSystemTurnCreateRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/AgenticSystemTurnCreateRequest'
+    required:
+    - request
+    type: object
   AgenticSystemTurnResponseEvent:
     additionalProperties: false
     properties:
@@ -334,6 +574,14 @@ components:
     - model
     - messages_batch
     type: object
+  BatchChatCompletionRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/BatchChatCompletionRequest'
+    required:
+    - request
+    type: object
   BatchChatCompletionResponse:
     additionalProperties: false
     properties:
@@ -369,6 +617,14 @@ components:
     - model
    - content_batch
     type: object
+  BatchCompletionRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/BatchCompletionRequest'
+    required:
+    - request
+    type: object
   BatchCompletionResponse:
     additionalProperties: false
     properties:
@@ -464,6 +720,14 @@ components:
     - model
     - messages
     type: object
+  ChatCompletionRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/ChatCompletionRequest'
+    required:
+    - request
+    type: object
   ChatCompletionResponseEvent:
     additionalProperties: false
     properties:
@@ -572,6 +836,14 @@ components:
     - model
     - content
     type: object
+  CompletionRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/CompletionRequest'
+    required:
+    - request
+    type: object
   CompletionResponseStreamChunk:
     additionalProperties: false
     properties:
@@ -618,6 +890,14 @@ components:
     - dataset
     title: Request to create a dataset.
     type: object
+  CreateDatasetRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/CreateDatasetRequest'
+    required:
+    - request
+    type: object
   CreateExperimentRequest:
     additionalProperties: false
     properties:
@@ -636,6 +916,14 @@ components:
     required:
     - name
     type: object
+  CreateExperimentRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/CreateExperimentRequest'
+    required:
+    - request
+    type: object
   CreateMemoryBankRequest:
     additionalProperties: false
     properties:
@@ -707,6 +995,14 @@ components:
     required:
     - experiment_id
     type: object
+  CreateRunRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/CreateRunRequest'
+    required:
+    - request
+    type: object
   DPOAlignmentConfig:
     additionalProperties: false
     properties:
@@ -872,6 +1168,14 @@ components:
     - metrics
     title: Request to evaluate question answering.
     type: object
+  EvaluateQuestionAnsweringRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/EvaluateQuestionAnsweringRequest'
+    required:
+    - request
+    type: object
   EvaluateSummarizationRequest:
     additionalProperties: false
     properties:
@@ -898,6 +1202,14 @@ components:
     - metrics
     title: Request to evaluate summarization.
     type: object
+  EvaluateSummarizationRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/EvaluateSummarizationRequest'
+    required:
+    - request
+    type: object
   EvaluateTextGenerationRequest:
     additionalProperties: false
     properties:
@@ -925,6 +1237,14 @@ components:
     - metrics
     title: Request to evaluate text generation.
     type: object
+  EvaluateTextGenerationRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/EvaluateTextGenerationRequest'
+    required:
+    - request
+    type: object
   EvaluationJob:
     additionalProperties: false
     properties:
@@ -1138,6 +1458,14 @@ components:
     required:
     - logs
     type: object
+  LogMessagesRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/LogMessagesRequest'
+    required:
+    - request
+    type: object
   LogMetricsRequest:
     additionalProperties: false
     properties:
@@ -1151,6 +1479,14 @@ components:
     - run_id
     - metrics
     type: object
+  LogMetricsRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/LogMetricsRequest'
+    required:
+    - request
+    type: object
   LogSearchRequest:
     additionalProperties: false
     properties:
@@ -1169,6 +1505,14 @@ components:
     required:
     - query
     type: object
+  LogSearchRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/LogSearchRequest'
+    required:
+    - request
+    type: object
   LoraFinetuningConfig:
     additionalProperties: false
     properties:
@@ -1310,88 +1654,6 @@ components:
     - memory_bank_ids
     - inserted_context
     type: object
-  MemoryToolDefinition:
-    additionalProperties: false
-    properties:
-      input_shields:
-        items:
-          $ref: '#/components/schemas/ShieldDefinition'
-        type: array
-      max_chunks:
-        type: integer
-      max_tokens_in_context:
-        type: integer
-      memory_bank_configs:
-        items:
-          oneOf:
-          - additionalProperties: false
-            properties:
-              bank_id:
-                type: string
-              type:
-                const: vector
-                type: string
-            required:
-            - bank_id
-            - type
-            type: object
-          - additionalProperties: false
-            properties:
-              bank_id:
-                type: string
-              keys:
-                items:
-                  type: string
-                type: array
-              type:
-                const: keyvalue
-                type: string
-            required:
-            - bank_id
-            - type
-            - keys
-            type: object
-          - additionalProperties: false
-            properties:
-              bank_id:
-                type: string
-              type:
-                const: keyword
-                type: string
-            required:
-            - bank_id
-            - type
-            type: object
-          - additionalProperties: false
-            properties:
-              bank_id:
-                type: string
-              entities:
-                items:
-                  type: string
-                type: array
-              type:
-                const: graph
-                type: string
-            required:
-            - bank_id
-            - type
-            - entities
-            type: object
-        type: array
-      output_shields:
-        items:
-          $ref: '#/components/schemas/ShieldDefinition'
-        type: array
-      type:
-        const: memory
-        type: string
-    required:
-    - type
-    - memory_bank_configs
-    - max_tokens_in_context
-    - max_chunks
-    type: object
   Metric:
     additionalProperties: false
     properties:
@@ -1591,6 +1853,14 @@ components:
     - logger_config
     title: Request to finetune a model.
     type: object
+  PostTrainingRLHFRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/PostTrainingRLHFRequest'
+    required:
+    - request
+    type: object
   PostTrainingSFTRequest:
     additionalProperties: false
     properties:
@@ -1646,6 +1916,14 @@ components:
     - logger_config
     title: Request to finetune a model.
     type: object
+  PostTrainingSFTRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/PostTrainingSFTRequest'
+    required:
+    - request
+    type: object
   QLoraFinetuningConfig:
     additionalProperties: false
     properties:
@@ -1773,6 +2051,14 @@ components:
     title: Request to score a reward function. A list of prompts and a list of responses
       per prompt.
     type: object
+  RewardScoringRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/RewardScoringRequest'
+    required:
+    - request
+    type: object
   RewardScoringResponse:
     additionalProperties: false
     properties:
@@ -1995,6 +2281,14 @@ components:
     title: Request to generate synthetic data. A small batch of prompts and a filtering
       function
     type: object
+  SyntheticDataGenerationRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/SyntheticDataGenerationRequest'
+    required:
+    - request
+    type: object
   SyntheticDataGenerationResponse:
     additionalProperties: false
     properties:
@@ -2363,6 +2657,14 @@ components:
     required:
    - experiment_id
     type: object
+  UpdateExperimentRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/UpdateExperimentRequest'
+    required:
+    - request
+    type: object
   UpdateRunRequest:
     additionalProperties: false
     properties:
@@ -2386,6 +2688,14 @@ components:
     required:
     - run_id
     type: object
+  UpdateRunRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/UpdateRunRequest'
+    required:
+    - request
+    type: object
   UploadArtifactRequest:
     additionalProperties: false
     properties:
@@ -2414,6 +2724,14 @@ components:
     - artifact_type
     - content
     type: object
+  UploadArtifactRequestWrapper:
+    additionalProperties: false
+    properties:
+      request:
+        $ref: '#/components/schemas/UploadArtifactRequest'
+    required:
+    - request
+    type: object
   UserMessage:
     additionalProperties: false
     properties:
@@ -2459,7 +2777,7 @@ info:
   description: "This is the specification of the llama stack that provides\n    \
     \ a set of endpoints and their corresponding interfaces that are tailored\
     \ to\n    best leverage Llama Models. The specification is still in\
-    \ draft and subject to change.\n    Generated at 2024-09-04 10:28:38.779789"
+    \ draft and subject to change.\n    Generated at 2024-09-10 01:13:08.531639"
   title: '[DRAFT] Llama Stack Specification'
   version: 0.0.1
 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -2591,7 +2909,7 @@ paths:
       content:
         application/json:
          schema:
-            $ref: '#/components/schemas/AgenticSystemTurnCreateRequest'
+            $ref: '#/components/schemas/AgenticSystemTurnCreateRequestWrapper'
       required: true
     responses:
       '200':
@@ -2640,7 +2958,7 @@ paths:
                 $ref: '#/components/schemas/Artifact'
           description: OK
       tags:
-      - Observability
+      - Telemetry
   /batch_inference/chat_completion:
     post:
       parameters: []
@@ -2648,7 +2966,7 @@ paths:
       content:
         application/json:
           schema:
-            $ref: '#/components/schemas/BatchChatCompletionRequest'
+            $ref: '#/components/schemas/BatchChatCompletionRequestWrapper'
       required: true
     responses:
       '200':
@@ -2666,7 +2984,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/BatchCompletionRequest'
+            $ref: '#/components/schemas/BatchCompletionRequestWrapper'
       required: true
     responses:
       '200':
@@ -2684,7 +3002,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/CreateDatasetRequest'
+            $ref: '#/components/schemas/CreateDatasetRequestWrapper'
       required: true
     responses:
       '200':
@@ -2806,7 +3124,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/EvaluateQuestionAnsweringRequest'
+            $ref: '#/components/schemas/EvaluateQuestionAnsweringRequestWrapper'
       required: true
     responses:
       '200':
@@ -2824,7 +3142,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/EvaluateSummarizationRequest'
+            $ref: '#/components/schemas/EvaluateSummarizationRequestWrapper'
       required: true
     responses:
       '200':
@@ -2842,7 +3160,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/EvaluateTextGenerationRequest'
+            $ref: '#/components/schemas/EvaluateTextGenerationRequestWrapper'
       required: true
     responses:
       '200':
@@ -2870,7 +3188,7 @@ paths:
                 $ref: '#/components/schemas/Artifact'
           description: OK
       tags:
-      - Observability
+      - Telemetry
   /experiments/artifacts/upload:
     post:
       parameters: []
@@ -2878,7 +3196,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/UploadArtifactRequest'
+            $ref: '#/components/schemas/UploadArtifactRequestWrapper'
       required: true
     responses:
       '200':
@@ -2888,7 +3206,7 @@ paths:
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/Artifact'
          description: OK
       tags:
-      - Observability
+      - Telemetry
   /experiments/create:
     post:
       parameters: []
@@ -2896,7 +3214,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/CreateExperimentRequest'
+            $ref: '#/components/schemas/CreateExperimentRequestWrapper'
       required: true
     responses:
       '200':
@@ -2906,7 +3224,7 @@ paths:
              schema:
                $ref: '#/components/schemas/Experiment'
          description: OK
       tags:
-      - Observability
+      - Telemetry
   /experiments/create_run:
     post:
       parameters: []
@@ -2914,7 +3232,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/CreateRunRequest'
+            $ref: '#/components/schemas/CreateRunRequestWrapper'
       required: true
     responses:
       '200':
@@ -2924,7 +3242,7 @@ paths:
              schema:
                $ref: '#/components/schemas/Run'
          description: OK
       tags:
-      - Observability
+      - Telemetry
   /experiments/get:
     get:
       parameters:
@@ -2941,7 +3259,7 @@ paths:
              schema:
                $ref: '#/components/schemas/Experiment'
          description: OK
       tags:
-      - Observability
+      - Telemetry
   /experiments/list:
     get:
       parameters: []
@@ -2953,7 +3271,7 @@ paths:
              schema:
                $ref: '#/components/schemas/Experiment'
          description: OK
       tags:
-      - Observability
+      - Telemetry
   /experiments/update:
     post:
       parameters: []
@@ -2961,7 +3279,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/UpdateExperimentRequest'
+            $ref: '#/components/schemas/UpdateExperimentRequestWrapper'
       required: true
     responses:
       '200':
@@ -2971,7 +3289,7 @@ paths:
              schema:
                $ref: '#/components/schemas/Experiment'
          description: OK
       tags:
-      - Observability
+      - Telemetry
   /inference/chat_completion:
     post:
       parameters: []
@@ -2979,12 +3297,12 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/ChatCompletionRequest'
+            $ref: '#/components/schemas/ChatCompletionRequestWrapper'
       required: true
     responses:
       '200':
         content:
-          application/json:
+          text/event-stream:
            schema:
              $ref: '#/components/schemas/ChatCompletionResponseStreamChunk'
        description: SSE-stream of these events.
@@ -2997,7 +3315,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/CompletionRequest'
+            $ref: '#/components/schemas/CompletionRequestWrapper'
       required: true
     responses:
       '200':
@@ -3033,7 +3351,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/LogSearchRequest'
+            $ref: '#/components/schemas/LogSearchRequestWrapper'
       required: true
     responses:
       '200':
@@ -3043,7 +3361,7 @@ paths:
              schema:
                $ref: '#/components/schemas/Log'
          description: OK
       tags:
-      - Observability
+      - Telemetry
   /logging/log_messages:
     post:
       parameters: []
@@ -3051,13 +3369,13 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/LogMessagesRequest'
+            $ref: '#/components/schemas/LogMessagesRequestWrapper'
       required: true
     responses:
       '200':
         description: OK
       tags:
-      - Observability
+      - Telemetry
   /memory_bank/documents/delete:
     post:
       parameters: []
@@ -3292,7 +3610,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/PostTrainingRLHFRequest'
+            $ref: '#/components/schemas/PostTrainingRLHFRequestWrapper'
       required: true
     responses:
       '200':
@@ -3310,7 +3628,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/PostTrainingSFTRequest'
+            $ref: '#/components/schemas/PostTrainingSFTRequestWrapper'
       required: true
     responses:
       '200':
@@ -3328,7 +3646,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/RewardScoringRequest'
+            $ref: '#/components/schemas/RewardScoringRequestWrapper'
       required: true
     responses:
       '200':
@@ -3346,13 +3664,13 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/LogMetricsRequest'
+            $ref: '#/components/schemas/LogMetricsRequestWrapper'
       required: true
     responses:
       '200':
         description: OK
       tags:
-      - Observability
+      - Telemetry
   /runs/metrics:
     get:
       parameters:
@@ -3369,7 +3687,7 @@ paths:
              schema:
                $ref: '#/components/schemas/Metric'
          description: OK
       tags:
-      - Observability
+      - Telemetry
   /runs/update:
     post:
       parameters: []
@@ -3377,7 +3695,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/UpdateRunRequest'
+            $ref: '#/components/schemas/UpdateRunRequestWrapper'
       required: true
     responses:
       '200':
@@ -3387,7 +3705,7 @@ paths:
              schema:
                $ref: '#/components/schemas/Run'
          description: OK
       tags:
-      - Observability
+      - Telemetry
   /synthetic_data_generation/generate:
     post:
       parameters: []
@@ -3395,7 +3713,7 @@ paths:
      content:
        application/json:
          schema:
-            $ref: '#/components/schemas/SyntheticDataGenerationRequest'
+            $ref: '#/components/schemas/SyntheticDataGenerationRequestWrapper'
       required: true
     responses:
       '200':
@@ -3411,16 +3729,16 @@ security:
 servers:
 - url: http://any-hosted-llama-stack.com
 tags:
-- name: Observability
-- name: Inference
+- name: SyntheticDataGeneration
+- name: RewardScoring
 - name: Datasets
 - name: Memory
 - name: AgenticSystem
-- name: Evaluations
-- name: SyntheticDataGeneration
-- name: RewardScoring
-- name: PostTraining
 - name: BatchInference
+- name: PostTraining
+- name: Evaluations
+- name: Telemetry
+- name: Inference
 - description:
   name: BatchChatCompletionRequest
@@ -3463,12 +3781,18 @@ tags:
   name: ToolResponseMessage
 - description:
   name: UserMessage
+- description:
+  name: BatchChatCompletionRequestWrapper
 - description:
   name: BatchChatCompletionResponse
 - description:
   name: BatchCompletionRequest
+- description:
+  name: BatchCompletionRequestWrapper
 - description:
   name: BatchCompletionResponse
@@ -3481,6 +3805,9 @@ tags:
 - description:
   name: ChatCompletionRequest
+- description:
+  name: ChatCompletionRequestWrapper
 - description: 'Chat completion response event.
@@ -3506,6 +3833,9 @@ tags:
 - description:
   name: CompletionRequest
+- description:
+  name: CompletionRequestWrapper
 - description: 'streamed completion response.
@@ -3525,9 +3855,6 @@ tags:
 - description:
   name: FunctionCallToolDefinition
-- description:
-  name: MemoryToolDefinition
 - description:
   name: OnViolationAction
@@ -3564,6 +3891,9 @@ tags:
   name: AgenticSystemTurnCreateRequest
 - description:
   name: Attachment
+- description:
+  name: AgenticSystemTurnCreateRequestWrapper
 - description: 'Streamed agent execution response.
@@ -3620,9 +3950,15 @@ tags:
 - description:
   name: TrainEvalDatasetColumnType
+- description:
+  name: CreateDatasetRequestWrapper
 - description:
   name: CreateExperimentRequest
+- description:
+  name: CreateExperimentRequestWrapper
 - description:
   name: Experiment
 - description:
   name: CreateRunRequest
+- description:
+  name: CreateRunRequestWrapper
 - description:
   name: Run
 - description: '
   name: EvaluateQuestionAnsweringRequest
+- description:
+  name: EvaluateQuestionAnsweringRequestWrapper
 - description:
   name: EvaluationJob
 - description: 'Request to evaluate summarization.
@@ -3678,12 +4020,18 @@ tags:
     '
   name: EvaluateSummarizationRequest
+- description:
+  name: EvaluateSummarizationRequestWrapper
 - description: 'Request to evaluate text generation.

     '
   name: EvaluateTextGenerationRequest
+- description:
+  name: EvaluateTextGenerationRequestWrapper
 - description:
   name: GetAgenticSystemSessionRequest
@@ -3720,6 +4068,9 @@ tags:
 - description:
   name: LogSearchRequest
+- description:
+  name: LogSearchRequestWrapper
 - description:
   name: Log
 - description:
@@ -3756,9 +4107,15 @@ tags:
 - description:
   name: LogMessagesRequest
+- description:
+  name: LogMessagesRequestWrapper
 - description:
   name: LogMetricsRequest
+- description:
+  name: LogMetricsRequestWrapper
 - description:
   name: DPOAlignmentConfig
@@ -3774,6 +4131,9 @@ tags:
   name: RLHFAlgorithm
 - description:
   name: TrainingConfig
+- description:
+  name: PostTrainingRLHFRequestWrapper
 - description:
   name: QueryDocumentsRequest
@@ -3789,6 +4149,9 @@ tags:
     '
   name: RewardScoringRequest
+- description:
+  name: RewardScoringRequestWrapper
 - description: 'Response from the reward scoring. Batch of (prompt, response, score)
     tuples that pass the threshold.
@@ -3817,6 +4180,9 @@ tags:
 - description:
   name: QLoraFinetuningConfig
+- description:
+  name: PostTrainingSFTRequestWrapper
 - description: 'Request to generate synthetic data. A small batch of prompts and
     a filtering function
@@ -3824,6 +4190,9 @@ tags:
     '
   name: SyntheticDataGenerationRequest
+- description:
+  name: SyntheticDataGenerationRequestWrapper
 - description: 'Response from the synthetic data generation. Batch of (prompt, response,
     score) tuples that pass the threshold.

     '
   name: SyntheticDataGenerationResponse
@@ -3837,12 +4206,21 @@ tags:
 - description:
   name: UpdateExperimentRequest
+- description:
+  name: UpdateExperimentRequestWrapper
 - description:
   name: UpdateRunRequest
+- description:
+  name: UpdateRunRequestWrapper
 - description:
   name: UploadArtifactRequest
+- description:
+  name: UploadArtifactRequestWrapper
 x-tagGroups:
 - name: Operations
   tags:
@@ -3852,10 +4230,10 @@ x-tagGroups:
   - Evaluations
   - Inference
   - Memory
-  - Observability
   - PostTraining
   - RewardScoring
   - SyntheticDataGeneration
+  - Telemetry
 - name: Types
   tags:
   - AgentConfig
@@ -3863,6 +4241,7 @@ x-tagGroups:
   - AgenticSystemSessionCreateResponse
   - AgenticSystemStepResponse
   - AgenticSystemTurnCreateRequest
+  - AgenticSystemTurnCreateRequestWrapper
   - AgenticSystemTurnResponseEvent
   - AgenticSystemTurnResponseStepCompletePayload
   - AgenticSystemTurnResponseStepProgressPayload
@@ -3874,8 +4253,10 @@ x-tagGroups:
   - ArtifactType
   - Attachment
   - BatchChatCompletionRequest
+  - BatchChatCompletionRequestWrapper
   - BatchChatCompletionResponse
   - BatchCompletionRequest
+  - BatchCompletionRequestWrapper
   - BatchCompletionResponse
   - BraveSearchToolDefinition
   - BuiltinShield
@@ -3883,6 +4264,7 @@ x-tagGroups:
   - CancelEvaluationJobRequest
   - CancelTrainingJobRequest
   - ChatCompletionRequest
+  - ChatCompletionRequestWrapper
   - ChatCompletionResponseEvent
   - ChatCompletionResponseEventType
   - ChatCompletionResponseStreamChunk
@@ -3890,13 +4272,17 @@ x-tagGroups:
   - CodeInterpreterToolDefinition
   - CompletionMessage
   - CompletionRequest
+  - CompletionRequestWrapper
   - CompletionResponseStreamChunk
   - CreateAgenticSystemRequest
   - CreateAgenticSystemSessionRequest
   - CreateDatasetRequest
+  - CreateDatasetRequestWrapper
   - CreateExperimentRequest
+  - CreateExperimentRequestWrapper
   - CreateMemoryBankRequest
   - CreateRunRequest
+  - CreateRunRequestWrapper
   - DPOAlignmentConfig
   - DeleteAgenticSystemRequest
   - DeleteAgenticSystemSessionRequest
@@ -3908,8 +4294,11 @@ x-tagGroups:
   - EmbeddingsRequest
   - EmbeddingsResponse
   - EvaluateQuestionAnsweringRequest
+  - EvaluateQuestionAnsweringRequestWrapper
   - EvaluateSummarizationRequest
+  - EvaluateSummarizationRequestWrapper
   - EvaluateTextGenerationRequest
+  - EvaluateTextGenerationRequestWrapper
   - EvaluationJob
   - EvaluationJobArtifactsResponse
   - EvaluationJobLogStream
@@ -3925,13 +4314,15 @@ x-tagGroups:
   - ListArtifactsRequest
   - Log
   - LogMessagesRequest
+  - LogMessagesRequestWrapper
   - LogMetricsRequest
+  - LogMetricsRequestWrapper
   - LogSearchRequest
+  - LogSearchRequestWrapper
   - LoraFinetuningConfig
   - MemoryBank
   - MemoryBankDocument
   - MemoryRetrievalStep
-  - MemoryToolDefinition
   - Metric
   - OnViolationAction
   - OptimizerConfig
@@ -3942,7 +4333,9 @@ x-tagGroups:
   - PostTrainingJobStatus
   - PostTrainingJobStatusResponse
   - PostTrainingRLHFRequest
+  - PostTrainingRLHFRequestWrapper
   - PostTrainingSFTRequest
+  - PostTrainingSFTRequestWrapper
   - QLoraFinetuningConfig
   - QueryDocumentsRequest
   - QueryDocumentsResponse
@@ -3950,6 +4343,7 @@ x-tagGroups:
   - RestAPIExecutionConfig
   - RestAPIMethod
   - RewardScoringRequest
+  - RewardScoringRequestWrapper
   - RewardScoringResponse
   - Run
   - SamplingParams
@@ -3962,6 +4356,7 @@ x-tagGroups:
   - ShieldResponse
   - StopReason
   - SyntheticDataGenerationRequest
+  - SyntheticDataGenerationRequestWrapper
   - SyntheticDataGenerationResponse
   - SystemMessage
   - TokenLogProbs
@@ -3982,7 +4377,10 @@ x-tagGroups:
   - URL
   - UpdateDocumentsRequest
   - UpdateExperimentRequest
+  - UpdateExperimentRequestWrapper
   - UpdateRunRequest
+  - UpdateRunRequestWrapper
   - UploadArtifactRequest
+  - UploadArtifactRequestWrapper
   - UserMessage
   - WolframAlphaToolDefinition
diff --git a/rfcs/openapi_generator/generate.py b/rfcs/openapi_generator/generate.py
index ab9774e70..279389a47 100644
--- a/rfcs/openapi_generator/generate.py
+++ b/rfcs/openapi_generator/generate.py
@@ -35,7 +35,10 @@ from llama_toolchain.stack import LlamaStack


 # TODO: this should be fixed in the generator itself so it reads appropriate annotations
-STREAMING_ENDPOINTS = ["/agentic_system/turn/create"]
+STREAMING_ENDPOINTS = [
+    "/agentic_system/turn/create",
+    "/inference/chat_completion",
+]


 def patch_sse_stream_responses(spec: Specification):
diff --git a/rfcs/openapi_generator/pyopenapi/generator.py b/rfcs/openapi_generator/pyopenapi/generator.py
index e1450074b..a711d9f68 100644
--- a/rfcs/openapi_generator/pyopenapi/generator.py
+++ b/rfcs/openapi_generator/pyopenapi/generator.py
@@ -468,12 +468,14 @@ class Generator:
             builder = ContentBuilder(self.schema_builder)
             first = next(iter(op.request_params))
             request_name, request_type = first
+
+            from dataclasses import make_dataclass
+
             if len(op.request_params) == 1 and "Request" in first[1].__name__:
                 # TODO(ashwin): Undo the "Request" hack and this entire block eventually
-                request_name, request_type = first
+                request_name = first[1].__name__ + "Wrapper"
+                request_type = make_dataclass(request_name, op.request_params)
             else:
-                from dataclasses import make_dataclass
-
                 op_name = "".join(word.capitalize() for word in op.name.split("_"))
                 request_name = f"{op_name}Request"
                 request_type = make_dataclass(request_name, op.request_params)
diff --git a/rfcs/openapi_generator/run_openapi_generator.sh b/rfcs/openapi_generator/run_openapi_generator.sh
index cf4265ae5..1b2f979cc 100755
--- a/rfcs/openapi_generator/run_openapi_generator.sh
+++ b/rfcs/openapi_generator/run_openapi_generator.sh
@@ -28,4 +28,4 @@ if [ ${#missing_packages[@]} -ne 0 ]; then
     exit 1
 fi

-PYTHONPATH=$PYTHONPATH:../.. python3 -m rfcs.openapi_generator.generate $*
+PYTHONPATH=$PYTHONPATH:../.. python -m rfcs.openapi_generator.generate $*
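Not part of the patch: a client-side sketch combining two changes above. Request bodies are now nested under a top-level "request" key (the generated *RequestWrapper schemas), and /inference/chat_completion is declared as an SSE stream (text/event-stream). The server URL, model id, and exact "data: " event framing are assumptions for illustration.

```python
import json

import httpx

payload = {
    # wrapper object required by ChatCompletionRequestWrapper
    "request": {
        "model": "Meta-Llama3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
}

with httpx.stream(
    "POST", "http://localhost:5000/inference/chat_completion", json=payload
) as response:
    for line in response.iter_lines():
        if line.startswith("data: "):
            # each event carries a ChatCompletionResponseStreamChunk
            print(json.loads(line[len("data: "):]))
```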