diff --git a/docs/cli_reference.md b/docs/cli_reference.md
index 7457aa45e..905539d42 100644
--- a/docs/cli_reference.md
+++ b/docs/cli_reference.md
@@ -248,51 +248,51 @@ llama stack list-distributions
```
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| Distribution ID | Providers | Description |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local | { | Use code from `llama_toolchain` itself to serve all llama stack APIs |
-| | "inference": "meta-reference", | |
-| | "memory": "meta-reference-faiss", | |
-| | "safety": "meta-reference", | |
-| | "agentic_system": "meta-reference" | |
-| | } | |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| remote | { | Point to remote services for all llama stack APIs |
-| | "inference": "remote", | |
-| | "safety": "remote", | |
-| | "agentic_system": "remote", | |
-| | "memory": "remote" | |
-| | } | |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local-ollama | { | Like local, but use ollama for running LLM inference |
-| | "inference": "remote::ollama", | |
-| | "safety": "meta-reference", | |
-| | "agentic_system": "meta-reference", | |
-| | "memory": "meta-reference-faiss" | |
-| | } | |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local-plus-fireworks-inference | { | Use Fireworks.ai for running LLM inference |
-| | "inference": "remote::fireworks", | |
-| | "safety": "meta-reference", | |
-| | "agentic_system": "meta-reference", | |
-| | "memory": "meta-reference-faiss" | |
-| | } | |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local-plus-together-inference | { | Use Together.ai for running LLM inference |
-| | "inference": "remote::together", | |
-| | "safety": "meta-reference", | |
-| | "agentic_system": "meta-reference", | |
-| | "memory": "meta-reference-faiss" | |
-| | } | |
-|--------------------------------|---------------------------------------|-------------------------------------------------------------------------------------------|
-| local-plus-tgi-inference | { | Use TGI (local or with [Hugging Face Inference Endpoints](https://huggingface.co/ |
-| | "inference": "remote::tgi", | inference-endpoints/dedicated)) for running LLM inference. When using HF Inference |
-| | "safety": "meta-reference", | Endpoints, you must provide the name of the endpoint. |
-| | "agentic_system": "meta-reference", | |
-| | "memory": "meta-reference-faiss" | |
-| | } | |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| Distribution Type | Providers | Description |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local | { | Use code from `llama_toolchain` itself to serve all llama stack APIs |
+| | "inference": "meta-reference", | |
+| | "memory": "meta-reference-faiss", | |
+| | "safety": "meta-reference", | |
+| | "agentic_system": "meta-reference" | |
+| | } | |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| remote | { | Point to remote services for all llama stack APIs |
+| | "inference": "remote", | |
+| | "safety": "remote", | |
+| | "agentic_system": "remote", | |
+| | "memory": "remote" | |
+| | } | |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-ollama | { | Like local, but use ollama for running LLM inference |
+| | "inference": "remote::ollama", | |
+| | "safety": "meta-reference", | |
+| | "agentic_system": "meta-reference", | |
+| | "memory": "meta-reference-faiss" | |
+| | } | |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-fireworks-inference | { | Use Fireworks.ai for running LLM inference |
+| | "inference": "remote::fireworks", | |
+| | "safety": "meta-reference", | |
+| | "agentic_system": "meta-reference", | |
+| | "memory": "meta-reference-faiss" | |
+| | } | |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-together-inference | { | Use Together.ai for running LLM inference |
+| | "inference": "remote::together", | |
+| | "safety": "meta-reference", | |
+| | "agentic_system": "meta-reference", | |
+| | "memory": "meta-reference-faiss" | |
+| | } | |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-tgi-inference | { | Use TGI (local or with [Hugging Face Inference Endpoints](https:// |
+| | "inference": "remote::tgi", | huggingface.co/inference-endpoints/dedicated)) for running LLM |
+| | "safety": "meta-reference", | inference. When using HF Inference Endpoints, you must provide the |
+| | "agentic_system": "meta-reference", | name of the endpoint. |
+| | "memory": "meta-reference-faiss" | |
+| | } | |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
As you can see above, each “distribution” details the “providers” it is composed of. For example, `local` uses the “meta-reference” provider for inference, while `local-ollama` relies on a different provider (Ollama) for inference. Similarly, you can use Fireworks.ai or Together.ai for running inference.
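
To make that mapping concrete, here is a minimal sketch of how a distribution spec ties each API to a provider type after this rename; it mirrors the `local-ollama` entry in `distribution_registry.py` further down in this diff, so only the final `print` is new.

```python
# Sketch: a DistributionSpec maps each Api to a provider *type* string
# (mirrors the local-ollama entry in distribution_registry.py below).
from llama_toolchain.core.datatypes import (
    Api,
    DistributionSpec,
    remote_provider_type,
)

spec = DistributionSpec(
    distribution_type="local-ollama",
    description="Like local, but use ollama for running LLM inference",
    providers={
        Api.inference: remote_provider_type("ollama"),  # "remote::ollama"
        Api.safety: "meta-reference",
        Api.agentic_system: "meta-reference",
        Api.memory: "meta-reference-faiss",
    },
)
print(spec.providers[Api.inference])  # -> remote::ollama
```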
diff --git a/llama_toolchain/agentic_system/api/api.py b/llama_toolchain/agentic_system/api/api.py
index e3f417918..68ec980e6 100644
--- a/llama_toolchain/agentic_system/api/api.py
+++ b/llama_toolchain/agentic_system/api/api.py
@@ -116,10 +116,47 @@ MemoryBankConfig = Annotated[
]
-@json_schema_type
+class MemoryQueryGenerator(Enum):
+ default = "default"
+ llm = "llm"
+ custom = "custom"
+
+
+class DefaultMemoryQueryGeneratorConfig(BaseModel):
+ type: Literal[MemoryQueryGenerator.default.value] = (
+ MemoryQueryGenerator.default.value
+ )
+ sep: str = " "
+
+
+class LLMMemoryQueryGeneratorConfig(BaseModel):
+ type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
+ model: str
+ template: str
+
+
+class CustomMemoryQueryGeneratorConfig(BaseModel):
+ type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
+
+
+MemoryQueryGeneratorConfig = Annotated[
+ Union[
+ DefaultMemoryQueryGeneratorConfig,
+ LLMMemoryQueryGeneratorConfig,
+ CustomMemoryQueryGeneratorConfig,
+ ],
+ Field(discriminator="type"),
+]
+
+
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+ # This config defines how a query is generated using the messages
+ # for memory bank retrieval.
+ query_generator_config: MemoryQueryGeneratorConfig = Field(
+ default=DefaultMemoryQueryGeneratorConfig()
+ )
max_tokens_in_context: int = 4096
max_chunks: int = 10
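
A hedged sketch of how an agent opts into the new LLM-based query generation: the model id and template below are placeholders, not values from this diff, and the template is rendered against a `{"messages": [...]}` dict of dumped messages (see `rag/context_retriever.py` below).

```python
# Sketch: configuring a memory tool to generate retrieval queries with an
# LLM. Model id and template are placeholders, not values from this diff.
from llama_toolchain.agentic_system.api import (
    LLMMemoryQueryGeneratorConfig,
    MemoryToolDefinition,
)

memory_tool = MemoryToolDefinition(
    memory_bank_configs=[],
    query_generator_config=LLMMemoryQueryGeneratorConfig(
        model="Meta-Llama3.1-8B-Instruct",  # placeholder model id
        # Jinja2 template; rendered with {"messages": [m.model_dump() ...]}
        template="Rewrite as a search query: {{ messages[-1].content }}",
    ),
    max_tokens_in_context=4096,
    max_chunks=10,
)
```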
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index ed3145b1e..4d38e0032 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -31,6 +31,7 @@ from llama_toolchain.tools.builtin import (
SingleMessageBuiltinTool,
)
+from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin
@@ -664,7 +665,9 @@ class ChatAgent(ShieldRunnerMixin):
# (i.e., no prior turns uploaded an Attachment)
return None, []
- query = " ".join(m.content for m in messages)
+ query = await generate_rag_query(
+ memory.query_generator_config, messages, inference_api=self.inference_api
+ )
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
diff --git a/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py b/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py
new file mode 100644
index 000000000..afcc6afd1
--- /dev/null
+++ b/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from jinja2 import Template
+from llama_models.llama3.api import * # noqa: F403
+
+
+from llama_toolchain.agentic_system.api import (
+ DefaultMemoryQueryGeneratorConfig,
+ LLMMemoryQueryGeneratorConfig,
+ MemoryQueryGenerator,
+ MemoryQueryGeneratorConfig,
+)
+from termcolor import cprint # noqa: F401
+from llama_toolchain.inference.api import * # noqa: F403
+
+
+async def generate_rag_query(
+ config: MemoryQueryGeneratorConfig,
+ messages: List[Message],
+ **kwargs,
+):
+ """
+ Generates a query that will be used for
+ retrieving relevant information from the memory bank.
+ """
+ if config.type == MemoryQueryGenerator.default.value:
+ query = await default_rag_query_generator(config, messages, **kwargs)
+ elif config.type == MemoryQueryGenerator.llm.value:
+ query = await llm_rag_query_generator(config, messages, **kwargs)
+ else:
+ raise NotImplementedError(f"Unsupported memory query generator {config.type}")
+ # cprint(f"Generated query >>>: {query}", color="green")
+ return query
+
+
+async def default_rag_query_generator(
+ config: DefaultMemoryQueryGeneratorConfig,
+ messages: List[Message],
+ **kwargs,
+):
+ return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)
+
+
+async def llm_rag_query_generator(
+ config: LLMMemoryQueryGeneratorConfig,
+ messages: List[Message],
+ **kwargs,
+):
+ assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
+ inference_api = kwargs["inference_api"]
+
+ m_dict = {"messages": [m.model_dump() for m in messages]}
+
+ template = Template(config.template)
+ content = template.render(m_dict)
+
+ model = config.model
+ message = UserMessage(content=content)
+ response = inference_api.chat_completion(
+ ChatCompletionRequest(
+ model=model,
+ messages=[message],
+ stream=False,
+ )
+ )
+
+ async for chunk in response:
+ query = chunk.completion_message.content
+
+ return query
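
For the default path, a small usage sketch (assuming `UserMessage` is importable from `llama_models.llama3.api`, as the star imports above suggest): messages are simply flattened into one query string joined by `sep`, matching the inline code this file replaces.

```python
import asyncio

from llama_models.llama3.api import UserMessage  # assumed import path
from llama_toolchain.agentic_system.api import DefaultMemoryQueryGeneratorConfig
from llama_toolchain.agentic_system.meta_reference.rag.context_retriever import (
    generate_rag_query,
)

async def demo():
    config = DefaultMemoryQueryGeneratorConfig(sep=" ")
    messages = [
        UserMessage(content="What is pgvector?"),
        UserMessage(content="How do I enable it?"),
    ]
    # Joins message contents with `sep`, same as the old inline code path.
    query = await generate_rag_query(config, messages)
    print(query)  # What is pgvector? How do I enable it?

asyncio.run(demo())
```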
diff --git a/llama_toolchain/agentic_system/providers.py b/llama_toolchain/agentic_system/providers.py
index a722d9400..164df1a30 100644
--- a/llama_toolchain/agentic_system/providers.py
+++ b/llama_toolchain/agentic_system/providers.py
@@ -13,7 +13,7 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.agentic_system,
- provider_id="meta-reference",
+ provider_type="meta-reference",
pip_packages=[
"codeshield",
"matplotlib",
diff --git a/llama_toolchain/cli/stack/build.py b/llama_toolchain/cli/stack/build.py
index c81a6d350..22bd4071f 100644
--- a/llama_toolchain/cli/stack/build.py
+++ b/llama_toolchain/cli/stack/build.py
@@ -52,7 +52,7 @@ class StackBuild(Subcommand):
BuildType,
)
- allowed_ids = [d.distribution_id for d in available_distribution_specs()]
+ allowed_ids = [d.distribution_type for d in available_distribution_specs()]
self.parser.add_argument(
"distribution",
type=str,
@@ -101,7 +101,7 @@ class StackBuild(Subcommand):
api_inputs.append(
ApiInput(
api=api,
- provider=provider_spec.provider_id,
+ provider=provider_spec.provider_type,
)
)
docker_image = None
@@ -115,11 +115,11 @@ class StackBuild(Subcommand):
self.parser.error(f"Could not find distribution {args.distribution}")
return
- for api, provider_id in dist.providers.items():
+ for api, provider_type in dist.providers.items():
api_inputs.append(
ApiInput(
api=api,
- provider=provider_id,
+ provider=provider_type,
)
)
docker_image = dist.docker_image
@@ -128,6 +128,6 @@ class StackBuild(Subcommand):
api_inputs,
build_type=BuildType(args.type),
name=args.name,
- distribution_id=args.distribution,
+ distribution_type=args.distribution,
docker_image=docker_image,
)
diff --git a/llama_toolchain/cli/stack/configure.py b/llama_toolchain/cli/stack/configure.py
index 70ff4a7f0..658380f4d 100644
--- a/llama_toolchain/cli/stack/configure.py
+++ b/llama_toolchain/cli/stack/configure.py
@@ -36,7 +36,7 @@ class StackConfigure(Subcommand):
)
from llama_toolchain.core.package import BuildType
- allowed_ids = [d.distribution_id for d in available_distribution_specs()]
+ allowed_ids = [d.distribution_type for d in available_distribution_specs()]
self.parser.add_argument(
"distribution",
type=str,
@@ -84,7 +84,7 @@ def configure_llama_distribution(config_file: Path) -> None:
if config.providers:
cprint(
- f"Configuration already exists for {config.distribution_id}. Will overwrite...",
+ f"Configuration already exists for {config.distribution_type}. Will overwrite...",
"yellow",
attrs=["bold"],
)
diff --git a/llama_toolchain/cli/stack/list_distributions.py b/llama_toolchain/cli/stack/list_distributions.py
index c4d529157..557b8c33c 100644
--- a/llama_toolchain/cli/stack/list_distributions.py
+++ b/llama_toolchain/cli/stack/list_distributions.py
@@ -33,7 +33,7 @@ class StackListDistributions(Subcommand):
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
- "Distribution ID",
+ "Distribution Type",
"Providers",
"Description",
]
@@ -43,7 +43,7 @@ class StackListDistributions(Subcommand):
providers = {k.value: v for k, v in spec.providers.items()}
rows.append(
[
- spec.distribution_id,
+ spec.distribution_type,
json.dumps(providers, indent=2),
spec.description,
]
diff --git a/llama_toolchain/cli/stack/list_providers.py b/llama_toolchain/cli/stack/list_providers.py
index 29602d889..fdf4ab054 100644
--- a/llama_toolchain/cli/stack/list_providers.py
+++ b/llama_toolchain/cli/stack/list_providers.py
@@ -41,7 +41,7 @@ class StackListProviders(Subcommand):
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
- "Provider ID",
+ "Provider Type",
"PIP Package Dependencies",
]
@@ -49,7 +49,7 @@ class StackListProviders(Subcommand):
for spec in providers_for_api.values():
rows.append(
[
- spec.provider_id,
+ spec.provider_type,
",".join(spec.pip_packages),
]
)
diff --git a/llama_toolchain/cli/stack/run.py b/llama_toolchain/cli/stack/run.py
index 68853db35..1568ed820 100644
--- a/llama_toolchain/cli/stack/run.py
+++ b/llama_toolchain/cli/stack/run.py
@@ -80,7 +80,7 @@ class StackRun(Subcommand):
with open(config_file, "r") as f:
config = PackageConfig(**yaml.safe_load(f))
- if not config.distribution_id:
+ if not config.distribution_type:
raise ValueError("Build config appears to be corrupt.")
if config.docker_image:
diff --git a/llama_toolchain/core/build_conda_env.sh b/llama_toolchain/core/build_conda_env.sh
index 1e8c002f2..e5b1ca539 100755
--- a/llama_toolchain/core/build_conda_env.sh
+++ b/llama_toolchain/core/build_conda_env.sh
@@ -20,12 +20,12 @@ fi
set -euo pipefail
if [ "$#" -ne 3 ]; then
- echo "Usage: $0 " >&2
- echo "Example: $0 mybuild 'numpy pandas scipy'" >&2
+ echo "Usage: $0 " >&2
+ echo "Example: $0 mybuild 'numpy pandas scipy'" >&2
exit 1
fi
-distribution_id="$1"
+distribution_type="$1"
build_name="$2"
env_name="llamastack-$build_name"
pip_dependencies="$3"
@@ -117,4 +117,4 @@ ensure_conda_env_python310 "$env_name" "$pip_dependencies"
printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}\n"
-$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type conda_env
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_type --name "$build_name" --type conda_env
diff --git a/llama_toolchain/core/build_container.sh b/llama_toolchain/core/build_container.sh
index ec2ca8a0c..e5349cd08 100755
--- a/llama_toolchain/core/build_container.sh
+++ b/llama_toolchain/core/build_container.sh
@@ -5,12 +5,12 @@ LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ "$#" -ne 4 ]; then
- echo "Usage: $0
- echo "Example: $0 distribution_id my-fastapi-app python:3.9-slim 'fastapi uvicorn'
+ echo "Usage: $0
+ echo "Example: $0 distribution_type my-fastapi-app python:3.9-slim 'fastapi uvicorn'
exit 1
fi
-distribution_id=$1
+distribution_type=$1
build_name="$2"
image_name="llamastack-$build_name"
docker_base=$3
@@ -110,4 +110,4 @@ set +x
printf "${GREEN}Succesfully setup Podman image. Configuring build...${NC}"
echo "You can run it with: podman run -p 8000:8000 $image_name"
-$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type container
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_type --name "$build_name" --type container
diff --git a/llama_toolchain/core/configure.py b/llama_toolchain/core/configure.py
index 7f9aa0140..252358a52 100644
--- a/llama_toolchain/core/configure.py
+++ b/llama_toolchain/core/configure.py
@@ -21,14 +21,14 @@ def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
for api_str, stub_config in existing_configs.items():
api = Api(api_str)
providers = all_providers[api]
- provider_id = stub_config["provider_id"]
- if provider_id not in providers:
+ provider_type = stub_config["provider_type"]
+ if provider_type not in providers:
raise ValueError(
- f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
+ f"Unknown provider `{provider_type}` is not available for API `{api_str}`"
)
- provider_spec = providers[provider_id]
- cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
+ provider_spec = providers[provider_type]
+ cprint(f"Configuring API: {api_str} ({provider_type})", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
try:
@@ -43,7 +43,7 @@ def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
print("")
provider_configs[api_str] = {
- "provider_id": provider_id,
+ "provider_type": provider_type,
**provider_config.dict(),
}
diff --git a/llama_toolchain/core/datatypes.py b/llama_toolchain/core/datatypes.py
index cbdda51d4..138d20941 100644
--- a/llama_toolchain/core/datatypes.py
+++ b/llama_toolchain/core/datatypes.py
@@ -31,7 +31,7 @@ class ApiEndpoint(BaseModel):
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
- provider_id: str
+ provider_type: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
@@ -100,7 +100,7 @@ class RemoteProviderConfig(BaseModel):
return url.rstrip("/")
-def remote_provider_id(adapter_id: str) -> str:
+def remote_provider_type(adapter_id: str) -> str:
return f"remote::{adapter_id}"
@@ -141,22 +141,22 @@ def remote_provider_spec(
if adapter and adapter.config_class
else "llama_toolchain.core.datatypes.RemoteProviderConfig"
)
- provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
+ provider_type = remote_provider_type(adapter.adapter_id) if adapter else "remote"
return RemoteProviderSpec(
- api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
+ api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
)
@json_schema_type
class DistributionSpec(BaseModel):
- distribution_id: str
+ distribution_type: str
description: str
docker_image: Optional[str] = None
providers: Dict[Api, str] = Field(
default_factory=dict,
- description="Provider IDs for each of the APIs provided by this distribution",
+ description="Provider Types for each of the APIs provided by this distribution",
)
@@ -171,7 +171,7 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p
this could be just a hash
""",
)
- distribution_id: Optional[str] = None
+ distribution_type: Optional[str] = None
docker_image: Optional[str] = Field(
default=None,
diff --git a/llama_toolchain/core/distribution.py b/llama_toolchain/core/distribution.py
index 4c50189c0..89e1d7793 100644
--- a/llama_toolchain/core/distribution.py
+++ b/llama_toolchain/core/distribution.py
@@ -83,18 +83,18 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
inference_providers_by_id = {
- a.provider_id: a for a in available_inference_providers()
+ a.provider_type: a for a in available_inference_providers()
}
- safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
+ safety_providers_by_id = {a.provider_type: a for a in available_safety_providers()}
agentic_system_providers_by_id = {
- a.provider_id: a for a in available_agentic_system_providers()
+ a.provider_type: a for a in available_agentic_system_providers()
}
ret = {
Api.inference: inference_providers_by_id,
Api.safety: safety_providers_by_id,
Api.agentic_system: agentic_system_providers_by_id,
- Api.memory: {a.provider_id: a for a in available_memory_providers()},
+ Api.memory: {a.provider_type: a for a in available_memory_providers()},
}
for k, v in ret.items():
v["remote"] = remote_provider_spec(k)
diff --git a/llama_toolchain/core/distribution_registry.py b/llama_toolchain/core/distribution_registry.py
index 9413e1374..855fc6300 100644
--- a/llama_toolchain/core/distribution_registry.py
+++ b/llama_toolchain/core/distribution_registry.py
@@ -14,7 +14,7 @@ from .datatypes import * # noqa: F403
def available_distribution_specs() -> List[DistributionSpec]:
return [
DistributionSpec(
- distribution_id="local",
+ distribution_type="local",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
providers={
Api.inference: "meta-reference",
@@ -24,35 +24,35 @@ def available_distribution_specs() -> List[DistributionSpec]:
},
),
DistributionSpec(
- distribution_id="remote",
+ distribution_type="remote",
description="Point to remote services for all llama stack APIs",
providers={x: "remote" for x in Api},
),
DistributionSpec(
- distribution_id="local-ollama",
+ distribution_type="local-ollama",
description="Like local, but use ollama for running LLM inference",
providers={
- Api.inference: remote_provider_id("ollama"),
+ Api.inference: remote_provider_type("ollama"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
},
),
DistributionSpec(
- distribution_id="local-plus-fireworks-inference",
+ distribution_type="local-plus-fireworks-inference",
description="Use Fireworks.ai for running LLM inference",
providers={
- Api.inference: remote_provider_id("fireworks"),
+ Api.inference: remote_provider_type("fireworks"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
},
),
DistributionSpec(
- distribution_id="local-plus-together-inference",
+ distribution_type="local-plus-together-inference",
description="Use Together.ai for running LLM inference",
providers={
- Api.inference: remote_provider_id("together"),
+ Api.inference: remote_provider_type("together"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
@@ -72,8 +72,8 @@ def available_distribution_specs() -> List[DistributionSpec]:
@lru_cache()
-def resolve_distribution_spec(distribution_id: str) -> Optional[DistributionSpec]:
+def resolve_distribution_spec(distribution_type: str) -> Optional[DistributionSpec]:
for spec in available_distribution_specs():
- if spec.distribution_id == distribution_id:
+ if spec.distribution_type == distribution_type:
return spec
return None
diff --git a/llama_toolchain/core/package.py b/llama_toolchain/core/package.py
index 72bd93152..ab4346a71 100644
--- a/llama_toolchain/core/package.py
+++ b/llama_toolchain/core/package.py
@@ -46,13 +46,13 @@ def build_package(
api_inputs: List[ApiInput],
build_type: BuildType,
name: str,
- distribution_id: Optional[str] = None,
+ distribution_type: Optional[str] = None,
docker_image: Optional[str] = None,
):
- if not distribution_id:
- distribution_id = "adhoc"
+ if not distribution_type:
+ distribution_type = "adhoc"
- build_dir = BUILDS_BASE_DIR / distribution_id / build_type.descriptor()
+ build_dir = BUILDS_BASE_DIR / distribution_type / build_type.descriptor()
os.makedirs(build_dir, exist_ok=True)
package_name = name.replace("::", "-")
@@ -79,7 +79,7 @@ def build_package(
if provider.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image")
- stub_config[api.value] = {"provider_id": api_input.provider}
+ stub_config[api.value] = {"provider_type": api_input.provider}
if package_file.exists():
cprint(
@@ -92,7 +92,7 @@ def build_package(
c.providers[api_str] = new_config
else:
existing_config = c.providers[api_str]
- if existing_config["provider_id"] != new_config["provider_id"]:
+ if existing_config["provider_type"] != new_config["provider_type"]:
cprint(
f"Provider `{api_str}` has changed from `{existing_config}` to `{new_config}`",
color="yellow",
@@ -105,7 +105,7 @@ def build_package(
providers=stub_config,
)
- c.distribution_id = distribution_id
+ c.distribution_type = distribution_type
c.docker_image = package_name if build_type == BuildType.container else None
c.conda_env = package_name if build_type == BuildType.conda_env else None
@@ -119,7 +119,7 @@ def build_package(
)
args = [
script,
- distribution_id,
+ distribution_type,
package_name,
package_deps.docker_image,
" ".join(package_deps.pip_packages),
@@ -130,7 +130,7 @@ def build_package(
)
args = [
script,
- distribution_id,
+ distribution_type,
package_name,
" ".join(package_deps.pip_packages),
]
diff --git a/llama_toolchain/core/server.py b/llama_toolchain/core/server.py
index 4de84b726..8c7ab10a7 100644
--- a/llama_toolchain/core/server.py
+++ b/llama_toolchain/core/server.py
@@ -284,13 +284,13 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
for api_str, provider_config in config["providers"].items():
api = Api(api_str)
providers = all_providers[api]
- provider_id = provider_config["provider_id"]
- if provider_id not in providers:
+ provider_type = provider_config["provider_type"]
+ if provider_type not in providers:
raise ValueError(
- f"Unknown provider `{provider_id}` is not available for API `{api}`"
+ f"Unknown provider `{provider_type}` is not available for API `{api}`"
)
- provider_specs[api] = providers[provider_id]
+ provider_specs[api] = providers[provider_type]
impls = resolve_impls(provider_specs, config)
diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py
index e9f6b4072..f313de3fd 100644
--- a/llama_toolchain/inference/providers.py
+++ b/llama_toolchain/inference/providers.py
@@ -13,7 +13,7 @@ def available_inference_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.inference,
- provider_id="meta-reference",
+ provider_type="meta-reference",
pip_packages=[
"accelerate",
"blobfile",
diff --git a/llama_toolchain/memory/adapters/chroma/__init__.py b/llama_toolchain/memory/adapters/chroma/__init__.py
new file mode 100644
index 000000000..c90a8e8ac
--- /dev/null
+++ b/llama_toolchain/memory/adapters/chroma/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_toolchain.core.datatypes import RemoteProviderConfig
+
+
+async def get_adapter_impl(config: RemoteProviderConfig, _deps):
+ from .chroma import ChromaMemoryAdapter
+
+ impl = ChromaMemoryAdapter(config.url)
+ await impl.initialize()
+ return impl
diff --git a/llama_toolchain/memory/adapters/chroma/chroma.py b/llama_toolchain/memory/adapters/chroma/chroma.py
new file mode 100644
index 000000000..f4952cd0e
--- /dev/null
+++ b/llama_toolchain/memory/adapters/chroma/chroma.py
@@ -0,0 +1,165 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+import uuid
+from typing import List
+from urllib.parse import urlparse
+
+import chromadb
+from numpy.typing import NDArray
+
+from llama_toolchain.memory.api import * # noqa: F403
+
+
+from llama_toolchain.memory.common.vector_store import BankWithIndex, EmbeddingIndex
+
+
+class ChromaIndex(EmbeddingIndex):
+ def __init__(self, client: chromadb.AsyncHttpClient, collection):
+ self.client = client
+ self.collection = collection
+
+ async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+ assert len(chunks) == len(
+ embeddings
+ ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+
+ for i, chunk in enumerate(chunks):
+ print(f"Adding chunk #{i} tokens={chunk.token_count}")
+
+ await self.collection.add(
+ documents=[chunk.json() for chunk in chunks],
+ embeddings=embeddings,
+ ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
+ )
+
+ async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
+ results = await self.collection.query(
+ query_embeddings=[embedding.tolist()],
+ n_results=k,
+ include=["documents", "distances"],
+ )
+ distances = results["distances"][0]
+ documents = results["documents"][0]
+
+ chunks = []
+ scores = []
+ for dist, doc in zip(distances, documents):
+ try:
+ doc = json.loads(doc)
+ chunk = Chunk(**doc)
+ except Exception:
+ import traceback
+
+ traceback.print_exc()
+ print(f"Failed to parse document: {doc}")
+ continue
+
+ chunks.append(chunk)
+ scores.append(1.0 / float(dist))
+
+ return QueryDocumentsResponse(chunks=chunks, scores=scores)
+
+
+class ChromaMemoryAdapter(Memory):
+ def __init__(self, url: str) -> None:
+ print(f"Initializing ChromaMemoryAdapter with url: {url}")
+ url = url.rstrip("/")
+ parsed = urlparse(url)
+
+ if parsed.path and parsed.path != "/":
+ raise ValueError("URL should not contain a path")
+
+ self.host = parsed.hostname
+ self.port = parsed.port
+
+ self.client = None
+ self.cache = {}
+
+ async def initialize(self) -> None:
+ try:
+ print(f"Connecting to Chroma server at: {self.host}:{self.port}")
+ self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
+ except Exception as e:
+ import traceback
+
+ traceback.print_exc()
+ raise RuntimeError("Could not connect to Chroma server") from e
+
+ async def shutdown(self) -> None:
+ pass
+
+ async def create_memory_bank(
+ self,
+ name: str,
+ config: MemoryBankConfig,
+ url: Optional[URL] = None,
+ ) -> MemoryBank:
+ bank_id = str(uuid.uuid4())
+ bank = MemoryBank(
+ bank_id=bank_id,
+ name=name,
+ config=config,
+ url=url,
+ )
+ collection = await self.client.create_collection(
+ name=bank_id,
+ metadata={"bank": bank.json()},
+ )
+ bank_index = BankWithIndex(
+ bank=bank, index=ChromaIndex(self.client, collection)
+ )
+ self.cache[bank_id] = bank_index
+ return bank
+
+ async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
+ bank_index = await self._get_and_cache_bank_index(bank_id)
+ if bank_index is None:
+ return None
+ return bank_index.bank
+
+ async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
+ if bank_id in self.cache:
+ return self.cache[bank_id]
+
+ collections = await self.client.list_collections()
+ for collection in collections:
+ if collection.name == bank_id:
+ print(collection.metadata)
+ bank = MemoryBank(**json.loads(collection.metadata["bank"]))
+ index = BankWithIndex(
+ bank=bank,
+ index=ChromaIndex(self.client, collection),
+ )
+ self.cache[bank_id] = index
+ return index
+
+ return None
+
+ async def insert_documents(
+ self,
+ bank_id: str,
+ documents: List[MemoryBankDocument],
+ ttl_seconds: Optional[int] = None,
+ ) -> None:
+ index = await self._get_and_cache_bank_index(bank_id)
+ if not index:
+ raise ValueError(f"Bank {bank_id} not found")
+
+ await index.insert_documents(documents)
+
+ async def query_documents(
+ self,
+ bank_id: str,
+ query: InterleavedTextMedia,
+ params: Optional[Dict[str, Any]] = None,
+ ) -> QueryDocumentsResponse:
+ index = await self._get_and_cache_bank_index(bank_id)
+ if not index:
+ raise ValueError(f"Bank {bank_id} not found")
+
+ return await index.query_documents(query, params)
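
A minimal wiring sketch for this adapter, assuming a Chroma server is already running; the URL and bank id are placeholders, and note the adapter rejects URLs that carry a path component.

```python
import asyncio

from llama_toolchain.core.datatypes import RemoteProviderConfig
from llama_toolchain.memory.adapters.chroma import get_adapter_impl

async def main():
    # Placeholder URL; initialize() connects via chromadb.AsyncHttpClient.
    config = RemoteProviderConfig(url="http://localhost:6000")
    memory = await get_adapter_impl(config, _deps={})

    bank_id = "replace-with-a-real-bank-uuid"  # placeholder
    bank = await memory.get_memory_bank(bank_id)
    if bank is not None:
        response = await memory.query_documents(bank_id, "what is pgvector?")
        for chunk, score in zip(response.chunks, response.scores):
            print(score, chunk.content[:80])

asyncio.run(main())
```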
diff --git a/llama_toolchain/memory/adapters/pgvector/__init__.py b/llama_toolchain/memory/adapters/pgvector/__init__.py
new file mode 100644
index 000000000..4ac30452f
--- /dev/null
+++ b/llama_toolchain/memory/adapters/pgvector/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import PGVectorConfig
+
+
+async def get_adapter_impl(config: PGVectorConfig, _deps):
+ from .pgvector import PGVectorMemoryAdapter
+
+ impl = PGVectorMemoryAdapter(config)
+ await impl.initialize()
+ return impl
diff --git a/llama_toolchain/memory/adapters/pgvector/config.py b/llama_toolchain/memory/adapters/pgvector/config.py
new file mode 100644
index 000000000..87b2f4a3b
--- /dev/null
+++ b/llama_toolchain/memory/adapters/pgvector/config.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class PGVectorConfig(BaseModel):
+ host: str = Field(default="localhost")
+ port: int = Field(default=5432)
+ db: str
+ user: str
+ password: str
diff --git a/llama_toolchain/memory/adapters/pgvector/pgvector.py b/llama_toolchain/memory/adapters/pgvector/pgvector.py
new file mode 100644
index 000000000..930d7720f
--- /dev/null
+++ b/llama_toolchain/memory/adapters/pgvector/pgvector.py
@@ -0,0 +1,234 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import uuid
+
+from typing import List, Tuple
+
+import psycopg2
+from numpy.typing import NDArray
+from psycopg2 import sql
+from psycopg2.extras import execute_values, Json
+from pydantic import BaseModel
+from llama_toolchain.memory.api import * # noqa: F403
+
+
+from llama_toolchain.memory.common.vector_store import (
+ ALL_MINILM_L6_V2_DIMENSION,
+ BankWithIndex,
+ EmbeddingIndex,
+)
+
+from .config import PGVectorConfig
+
+
+def check_extension_version(cur):
+ cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
+ result = cur.fetchone()
+ return result[0] if result else None
+
+
+def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
+ query = sql.SQL(
+ """
+ INSERT INTO metadata_store (key, data)
+ VALUES %s
+ ON CONFLICT (key) DO UPDATE
+ SET data = EXCLUDED.data
+ """
+ )
+
+ values = [(key, Json(model.dict())) for key, model in keys_models]
+ execute_values(cur, query, values, template="(%s, %s)")
+
+
+def load_models(cur, keys: List[str], cls):
+ query = "SELECT key, data FROM metadata_store"
+ if keys:
+ placeholders = ",".join(["%s"] * len(keys))
+ query += f" WHERE key IN ({placeholders})"
+ cur.execute(query, keys)
+ else:
+ cur.execute(query)
+
+ rows = cur.fetchall()
+ return [cls(**row["data"]) for row in rows]
+
+
+class PGVectorIndex(EmbeddingIndex):
+ def __init__(self, bank: MemoryBank, dimension: int, cursor):
+ self.cursor = cursor
+ self.table_name = f"vector_store_{bank.name}"
+
+ self.cursor.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS {self.table_name} (
+ id TEXT PRIMARY KEY,
+ document JSONB,
+ embedding vector({dimension})
+ )
+ """
+ )
+
+ async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+ assert len(chunks) == len(
+ embeddings
+ ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+
+ values = []
+ for i, chunk in enumerate(chunks):
+ print(f"Adding chunk #{i} tokens={chunk.token_count}")
+ values.append(
+ (
+ f"{chunk.document_id}:chunk-{i}",
+ Json(chunk.dict()),
+ embeddings[i].tolist(),
+ )
+ )
+
+ query = sql.SQL(
+ f"""
+ INSERT INTO {self.table_name} (id, document, embedding)
+ VALUES %s
+ ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
+ """
+ )
+ execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
+
+ async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
+ self.cursor.execute(
+ f"""
+ SELECT document, embedding <-> %s::vector AS distance
+ FROM {self.table_name}
+ ORDER BY distance
+ LIMIT %s
+ """,
+ (embedding.tolist(), k),
+ )
+ results = self.cursor.fetchall()
+
+ chunks = []
+ scores = []
+ for doc, dist in results:
+ chunks.append(Chunk(**doc))
+ scores.append(1.0 / float(dist))
+
+ return QueryDocumentsResponse(chunks=chunks, scores=scores)
+
+
+class PGVectorMemoryAdapter(Memory):
+ def __init__(self, config: PGVectorConfig) -> None:
+ print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
+ self.config = config
+ self.cursor = None
+ self.conn = None
+ self.cache = {}
+
+ async def initialize(self) -> None:
+ try:
+ self.conn = psycopg2.connect(
+ host=self.config.host,
+ port=self.config.port,
+ database=self.config.db,
+ user=self.config.user,
+ password=self.config.password,
+ )
+ self.cursor = self.conn.cursor()
+
+ version = check_extension_version(self.cursor)
+ if version:
+ print(f"Vector extension version: {version}")
+ else:
+ raise RuntimeError("Vector extension is not installed.")
+
+ self.cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS metadata_store (
+ key TEXT PRIMARY KEY,
+ data JSONB
+ )
+ """
+ )
+ except Exception as e:
+ import traceback
+
+ traceback.print_exc()
+ raise RuntimeError("Could not connect to PGVector database server") from e
+
+ async def shutdown(self) -> None:
+ pass
+
+ async def create_memory_bank(
+ self,
+ name: str,
+ config: MemoryBankConfig,
+ url: Optional[URL] = None,
+ ) -> MemoryBank:
+ bank_id = str(uuid.uuid4())
+ bank = MemoryBank(
+ bank_id=bank_id,
+ name=name,
+ config=config,
+ url=url,
+ )
+ upsert_models(
+ self.cursor,
+ [
+ (bank.bank_id, bank),
+ ],
+ )
+ index = BankWithIndex(
+ bank=bank,
+ index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
+ )
+ self.cache[bank_id] = index
+ return bank
+
+ async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
+ bank_index = await self._get_and_cache_bank_index(bank_id)
+ if bank_index is None:
+ return None
+ return bank_index.bank
+
+ async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
+ if bank_id in self.cache:
+ return self.cache[bank_id]
+
+ banks = load_models(self.cursor, [bank_id], MemoryBank)
+ if not banks:
+ return None
+
+ bank = banks[0]
+ index = BankWithIndex(
+ bank=bank,
+ index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
+ )
+ self.cache[bank_id] = index
+ return index
+
+ async def insert_documents(
+ self,
+ bank_id: str,
+ documents: List[MemoryBankDocument],
+ ttl_seconds: Optional[int] = None,
+ ) -> None:
+ index = await self._get_and_cache_bank_index(bank_id)
+ if not index:
+ raise ValueError(f"Bank {bank_id} not found")
+
+ await index.insert_documents(documents)
+
+ async def query_documents(
+ self,
+ bank_id: str,
+ query: InterleavedTextMedia,
+ params: Optional[Dict[str, Any]] = None,
+ ) -> QueryDocumentsResponse:
+ index = await self._get_and_cache_bank_index(bank_id)
+ if not index:
+ raise ValueError(f"Bank {bank_id} not found")
+
+ return await index.query_documents(query, params)
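
And a corresponding sketch for the pgvector adapter; the credentials are placeholders, and the target database must already have the `vector` extension installed (`CREATE EXTENSION vector`), which `initialize()` verifies via `pg_extension` before creating `metadata_store`.

```python
import asyncio

from llama_toolchain.memory.adapters.pgvector import get_adapter_impl, PGVectorConfig

async def main():
    # Placeholder credentials; initialize() raises if the `vector`
    # extension is missing.
    config = PGVectorConfig(
        host="localhost",
        port=5432,
        db="llamastack",
        user="postgres",
        password="postgres",  # placeholder
    )
    memory = await get_adapter_impl(config, _deps={})

asyncio.run(main())
```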
diff --git a/llama_toolchain/memory/common/vector_store.py b/llama_toolchain/memory/common/vector_store.py
new file mode 100644
index 000000000..154deea18
--- /dev/null
+++ b/llama_toolchain/memory/common/vector_store.py
@@ -0,0 +1,120 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import httpx
+import numpy as np
+from numpy.typing import NDArray
+
+from llama_models.llama3.api.datatypes import * # noqa: F403
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from llama_toolchain.memory.api import * # noqa: F403
+
+
+ALL_MINILM_L6_V2_DIMENSION = 384
+
+EMBEDDING_MODEL = None
+
+
+def get_embedding_model() -> "SentenceTransformer":
+ global EMBEDDING_MODEL
+
+ if EMBEDDING_MODEL is None:
+ print("Loading sentence transformer")
+
+ from sentence_transformers import SentenceTransformer
+
+ EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
+
+ return EMBEDDING_MODEL
+
+
+async def content_from_doc(doc: MemoryBankDocument) -> str:
+ if isinstance(doc.content, URL):
+ async with httpx.AsyncClient() as client:
+ r = await client.get(doc.content.uri)
+ return r.text
+
+ return interleaved_text_media_as_str(doc.content)
+
+
+def make_overlapped_chunks(
+ document_id: str, text: str, window_len: int, overlap_len: int
+) -> List[Chunk]:
+ tokenizer = Tokenizer.get_instance()
+ tokens = tokenizer.encode(text, bos=False, eos=False)
+
+ chunks = []
+ for i in range(0, len(tokens), window_len - overlap_len):
+ toks = tokens[i : i + window_len]
+ chunk = tokenizer.decode(toks)
+ chunks.append(
+ Chunk(content=chunk, token_count=len(toks), document_id=document_id)
+ )
+
+ return chunks
+
+
+class EmbeddingIndex(ABC):
+ @abstractmethod
+ async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+ raise NotImplementedError()
+
+ @abstractmethod
+ async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
+ raise NotImplementedError()
+
+
+@dataclass
+class BankWithIndex:
+ bank: MemoryBank
+ index: EmbeddingIndex
+
+ async def insert_documents(
+ self,
+ documents: List[MemoryBankDocument],
+ ) -> None:
+ model = get_embedding_model()
+ for doc in documents:
+ content = await content_from_doc(doc)
+ chunks = make_overlapped_chunks(
+ doc.document_id,
+ content,
+ self.bank.config.chunk_size_in_tokens,
+ self.bank.config.overlap_size_in_tokens
+ or (self.bank.config.chunk_size_in_tokens // 4),
+ )
+ embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
+
+ await self.index.add_chunks(chunks, embeddings)
+
+ async def query_documents(
+ self,
+ query: InterleavedTextMedia,
+ params: Optional[Dict[str, Any]] = None,
+ ) -> QueryDocumentsResponse:
+ if params is None:
+ params = {}
+ k = params.get("max_chunks", 3)
+
+ def _process(c) -> str:
+ if isinstance(c, str):
+ return c
+ else:
+ return ""
+
+ if isinstance(query, list):
+ query_str = " ".join([_process(c) for c in query])
+ else:
+ query_str = _process(query)
+
+ model = get_embedding_model()
+ query_vector = model.encode([query_str])[0].astype(np.float32)
+ return await self.index.query(query_vector, k)
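
A quick worked sketch of the overlapped-chunking arithmetic in `make_overlapped_chunks`, using a whitespace "tokenizer" stand-in (the real code uses the Llama 3 `Tokenizer`): the stride is `window_len - overlap_len`, so consecutive chunks share `overlap_len` tokens, and `BankWithIndex.insert_documents` defaults the overlap to a quarter of the chunk size when `overlap_size_in_tokens` is unset.

```python
# Stand-in for make_overlapped_chunks, token ids replaced by words.
def overlapped_chunks(tokens, window_len, overlap_len):
    stride = window_len - overlap_len
    return [tokens[i : i + window_len] for i in range(0, len(tokens), stride)]

tokens = "one two three four five six seven eight".split()
for c in overlapped_chunks(tokens, window_len=4, overlap_len=1):
    print(c)
# ['one', 'two', 'three', 'four']
# ['four', 'five', 'six', 'seven']
# ['seven', 'eight']
```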
diff --git a/llama_toolchain/memory/meta_reference/faiss/faiss.py b/llama_toolchain/memory/meta_reference/faiss/faiss.py
index 422674939..807aa208f 100644
--- a/llama_toolchain/memory/meta_reference/faiss/faiss.py
+++ b/llama_toolchain/memory/meta_reference/faiss/faiss.py
@@ -5,108 +5,45 @@
# the root directory of this source tree.
import uuid
-from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Tuple
+
+from typing import Any, Dict, List, Optional
import faiss
-import httpx
import numpy as np
+from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_models.llama3.api.tokenizer import Tokenizer
from llama_toolchain.memory.api import * # noqa: F403
+from llama_toolchain.memory.common.vector_store import (
+ ALL_MINILM_L6_V2_DIMENSION,
+ BankWithIndex,
+ EmbeddingIndex,
+)
from .config import FaissImplConfig
-async def content_from_doc(doc: MemoryBankDocument) -> str:
- if isinstance(doc.content, URL):
- async with httpx.AsyncClient() as client:
- r = await client.get(doc.content.uri)
- return r.text
+class FaissIndex(EmbeddingIndex):
+ id_by_index: Dict[int, str]
+    chunk_by_index: Dict[int, Chunk]
- return interleaved_text_media_as_str(doc.content)
+ def __init__(self, dimension: int):
+ self.index = faiss.IndexFlatL2(dimension)
+ self.id_by_index = {}
+ self.chunk_by_index = {}
+ async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+ indexlen = len(self.id_by_index)
+ for i, chunk in enumerate(chunks):
+ self.chunk_by_index[indexlen + i] = chunk
+ print(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
+ self.id_by_index[indexlen + i] = chunk.document_id
-def make_overlapped_chunks(
- text: str, window_len: int, overlap_len: int
-) -> List[Tuple[str, int]]:
- tokenizer = Tokenizer.get_instance()
- tokens = tokenizer.encode(text, bos=False, eos=False)
+ self.index.add(np.array(embeddings).astype(np.float32))
- chunks = []
- for i in range(0, len(tokens), window_len - overlap_len):
- toks = tokens[i : i + window_len]
- chunk = tokenizer.decode(toks)
- chunks.append((chunk, len(toks)))
-
- return chunks
-
-
-@dataclass
-class BankState:
- bank: MemoryBank
- index: Optional[faiss.IndexFlatL2] = None
- doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
- id_by_index: Dict[int, str] = field(default_factory=dict)
- chunk_by_index: Dict[int, str] = field(default_factory=dict)
-
- async def insert_documents(
- self,
- model: "SentenceTransformer",
- documents: List[MemoryBankDocument],
- ) -> None:
- tokenizer = Tokenizer.get_instance()
- chunk_size = self.bank.config.chunk_size_in_tokens
-
- for doc in documents:
- indexlen = len(self.id_by_index)
- self.doc_by_id[doc.document_id] = doc
-
- content = await content_from_doc(doc)
- chunks = make_overlapped_chunks(
- content,
- self.bank.config.chunk_size_in_tokens,
- self.bank.config.overlap_size_in_tokens
- or (self.bank.config.chunk_size_in_tokens // 4),
- )
- embeddings = model.encode([x[0] for x in chunks]).astype(np.float32)
- await self._ensure_index(embeddings.shape[1])
-
- self.index.add(embeddings)
- for i, chunk in enumerate(chunks):
- self.chunk_by_index[indexlen + i] = Chunk(
- content=chunk[0],
- token_count=chunk[1],
- document_id=doc.document_id,
- )
- print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
- self.id_by_index[indexlen + i] = doc.document_id
-
- async def query_documents(
- self,
- model: "SentenceTransformer",
- query: InterleavedTextMedia,
- params: Optional[Dict[str, Any]] = None,
- ) -> QueryDocumentsResponse:
- if params is None:
- params = {}
- k = params.get("max_chunks", 3)
-
- def _process(c) -> str:
- if isinstance(c, str):
- return c
- else:
- return ""
-
- if isinstance(query, list):
- query_str = " ".join([_process(c) for c in query])
- else:
- query_str = _process(query)
-
- query_vector = model.encode([query_str])[0]
+ async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
distances, indices = self.index.search(
- query_vector.reshape(1, -1).astype(np.float32), k
+ embedding.reshape(1, -1).astype(np.float32), k
)
chunks = []
@@ -119,17 +56,11 @@ class BankState:
return QueryDocumentsResponse(chunks=chunks, scores=scores)
- async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
- if self.index is None:
- self.index = faiss.IndexFlatL2(dimension)
- return self.index
-
class FaissMemoryImpl(Memory):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
- self.model = None
- self.states = {}
+ self.cache = {}
async def initialize(self) -> None: ...
@@ -153,14 +84,15 @@ class FaissMemoryImpl(Memory):
config=config,
url=url,
)
- state = BankState(bank=bank)
- self.states[bank_id] = state
+ index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
+ self.cache[bank_id] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
- if bank_id not in self.states:
+ index = self.cache.get(bank_id)
+ if index is None:
return None
- return self.states[bank_id].bank
+ return index.bank
async def insert_documents(
self,
@@ -168,10 +100,11 @@ class FaissMemoryImpl(Memory):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
- assert bank_id in self.states, f"Bank {bank_id} not found"
- state = self.states[bank_id]
+ index = self.cache.get(bank_id)
+ if index is None:
+ raise ValueError(f"Bank {bank_id} not found")
- await state.insert_documents(self.get_model(), documents)
+ await index.insert_documents(documents)
async def query_documents(
self,
@@ -179,16 +112,8 @@ class FaissMemoryImpl(Memory):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
- assert bank_id in self.states, f"Bank {bank_id} not found"
- state = self.states[bank_id]
+ index = self.cache.get(bank_id)
+ if index is None:
+ raise ValueError(f"Bank {bank_id} not found")
- return await state.query_documents(self.get_model(), query, params)
-
- def get_model(self) -> "SentenceTransformer":
- from sentence_transformers import SentenceTransformer
-
- if self.model is None:
- print("Loading sentence transformer")
- self.model = SentenceTransformer("all-MiniLM-L6-v2")
-
- return self.model
+ return await index.query_documents(query, params)
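
Since `FaissIndex` now wraps a bare `IndexFlatL2`, here is a short sketch of the add/search mechanics it relies on; 384 is `ALL_MINILM_L6_V2_DIMENSION` (the all-MiniLM-L6-v2 embedding size), and the random vectors below are stand-ins for real embeddings.

```python
import faiss
import numpy as np

dim = 384  # ALL_MINILM_L6_V2_DIMENSION
index = faiss.IndexFlatL2(dim)

embeddings = np.random.rand(10, dim).astype(np.float32)
index.add(embeddings)  # rows get sequential ids 0..9, as FaissIndex assumes

query = np.random.rand(1, dim).astype(np.float32)
distances, indices = index.search(query, 3)  # top-3 by L2 distance
print(indices[0], distances[0])
```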
diff --git a/llama_toolchain/memory/providers.py b/llama_toolchain/memory/providers.py
index f8675c344..cc113d132 100644
--- a/llama_toolchain/memory/providers.py
+++ b/llama_toolchain/memory/providers.py
@@ -6,20 +6,38 @@
from typing import List
-from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_toolchain.core.datatypes import * # noqa: F403
+
+EMBEDDING_DEPS = [
+ "blobfile",
+ "sentence-transformers",
+]
def available_memory_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.memory,
- provider_id="meta-reference-faiss",
- pip_packages=[
- "blobfile",
- "faiss-cpu",
- "sentence-transformers",
- ],
+ provider_type="meta-reference-faiss",
+ pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_toolchain.memory.meta_reference.faiss",
config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
),
+ remote_provider_spec(
+ api=Api.memory,
+ adapter=AdapterSpec(
+ adapter_id="chromadb",
+ pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
+ module="llama_toolchain.memory.adapters.chroma",
+ ),
+ ),
+ remote_provider_spec(
+ api=Api.memory,
+ adapter=AdapterSpec(
+ adapter_id="pgvector",
+ pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
+ module="llama_toolchain.memory.adapters.pgvector",
+ config_class="llama_toolchain.memory.adapters.pgvector.PGVectorConfig",
+ ),
+ ),
]
diff --git a/llama_toolchain/safety/providers.py b/llama_toolchain/safety/providers.py
index dfacf3f67..8471ab139 100644
--- a/llama_toolchain/safety/providers.py
+++ b/llama_toolchain/safety/providers.py
@@ -13,7 +13,7 @@ def available_safety_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.safety,
- provider_id="meta-reference",
+ provider_type="meta-reference",
pip_packages=[
"accelerate",
"codeshield",
diff --git a/llama_toolchain/stack.py b/llama_toolchain/stack.py
index 6ec05896d..875bc5802 100644
--- a/llama_toolchain/stack.py
+++ b/llama_toolchain/stack.py
@@ -11,7 +11,7 @@ from llama_toolchain.evaluations.api import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.batch_inference.api import * # noqa: F403
from llama_toolchain.memory.api import * # noqa: F403
-from llama_toolchain.observability.api import * # noqa: F403
+from llama_toolchain.telemetry.api import * # noqa: F403
from llama_toolchain.post_training.api import * # noqa: F403
from llama_toolchain.reward_scoring.api import * # noqa: F403
from llama_toolchain.synthetic_data_generation.api import * # noqa: F403
@@ -24,7 +24,7 @@ class LlamaStack(
RewardScoring,
SyntheticDataGeneration,
Datasets,
- Observability,
+ Telemetry,
PostTraining,
Memory,
Evaluations,
diff --git a/llama_toolchain/observability/__init__.py b/llama_toolchain/telemetry/__init__.py
similarity index 100%
rename from llama_toolchain/observability/__init__.py
rename to llama_toolchain/telemetry/__init__.py
diff --git a/llama_toolchain/observability/api/__init__.py b/llama_toolchain/telemetry/api/__init__.py
similarity index 100%
rename from llama_toolchain/observability/api/__init__.py
rename to llama_toolchain/telemetry/api/__init__.py
diff --git a/llama_toolchain/observability/api/api.py b/llama_toolchain/telemetry/api/api.py
similarity index 99%
rename from llama_toolchain/observability/api/api.py
rename to llama_toolchain/telemetry/api/api.py
index 86a5cc703..ae784428b 100644
--- a/llama_toolchain/observability/api/api.py
+++ b/llama_toolchain/telemetry/api/api.py
@@ -134,7 +134,7 @@ class LogSearchRequest(BaseModel):
filters: Optional[Dict[str, Any]] = None
-class Observability(Protocol):
+class Telemetry(Protocol):
@webmethod(route="/experiments/create")
def create_experiment(self, request: CreateExperimentRequest) -> Experiment: ...
diff --git a/requirements.txt b/requirements.txt
index bf61af71b..720f84b79 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ blobfile
fire
httpx
huggingface-hub
-llama-models
+llama-models>=0.0.13
pydantic
requests
termcolor
diff --git a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html
index d417f02f3..f5f5fa154 100644
--- a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html
+++ b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
- "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-04 10:28:38.779789"
+ "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-10 01:13:08.531639"
},
"servers": [
{
@@ -51,7 +51,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/BatchChatCompletionRequest"
+ "$ref": "#/components/schemas/BatchChatCompletionRequestWrapper"
}
}
},
@@ -81,7 +81,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/BatchCompletionRequest"
+ "$ref": "#/components/schemas/BatchCompletionRequestWrapper"
}
}
},
@@ -141,7 +141,7 @@
"200": {
"description": "SSE-stream of these events.",
"content": {
- "application/json": {
+ "text/event-stream": {
"schema": {
"$ref": "#/components/schemas/ChatCompletionResponseStreamChunk"
}
@@ -157,7 +157,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/ChatCompletionRequest"
+ "$ref": "#/components/schemas/ChatCompletionRequestWrapper"
}
}
},
@@ -187,7 +187,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/CompletionRequest"
+ "$ref": "#/components/schemas/CompletionRequestWrapper"
}
}
},
@@ -277,7 +277,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/AgenticSystemTurnCreateRequest"
+ "$ref": "#/components/schemas/AgenticSystemTurnCreateRequestWrapper"
}
}
},
@@ -300,7 +300,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/CreateDatasetRequest"
+ "$ref": "#/components/schemas/CreateDatasetRequestWrapper"
}
}
},
@@ -323,14 +323,14 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/CreateExperimentRequest"
+ "$ref": "#/components/schemas/CreateExperimentRequestWrapper"
}
}
},
@@ -383,14 +383,14 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/CreateRunRequest"
+ "$ref": "#/components/schemas/CreateRunRequestWrapper"
}
}
},
@@ -572,7 +572,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/EvaluateQuestionAnsweringRequest"
+ "$ref": "#/components/schemas/EvaluateQuestionAnsweringRequestWrapper"
}
}
},
@@ -602,7 +602,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/EvaluateSummarizationRequest"
+ "$ref": "#/components/schemas/EvaluateSummarizationRequestWrapper"
}
}
},
@@ -632,7 +632,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/EvaluateTextGenerationRequest"
+ "$ref": "#/components/schemas/EvaluateTextGenerationRequestWrapper"
}
}
},
@@ -784,7 +784,7 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": [
{
@@ -988,7 +988,7 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": [
{
@@ -1017,14 +1017,14 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/LogSearchRequest"
+ "$ref": "#/components/schemas/LogSearchRequestWrapper"
}
}
},
@@ -1083,7 +1083,7 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": [
{
@@ -1242,7 +1242,7 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": [],
"requestBody": {
@@ -1272,7 +1272,7 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": []
}
@@ -1305,14 +1305,14 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/LogMessagesRequest"
+ "$ref": "#/components/schemas/LogMessagesRequestWrapper"
}
}
},
@@ -1328,14 +1328,14 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/LogMetricsRequest"
+ "$ref": "#/components/schemas/LogMetricsRequestWrapper"
}
}
},
@@ -1365,7 +1365,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/PostTrainingRLHFRequest"
+ "$ref": "#/components/schemas/PostTrainingRLHFRequestWrapper"
}
}
},
@@ -1425,7 +1425,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/RewardScoringRequest"
+ "$ref": "#/components/schemas/RewardScoringRequestWrapper"
}
}
},
@@ -1455,7 +1455,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/PostTrainingSFTRequest"
+ "$ref": "#/components/schemas/PostTrainingSFTRequestWrapper"
}
}
},
@@ -1485,7 +1485,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/SyntheticDataGenerationRequest"
+ "$ref": "#/components/schemas/SyntheticDataGenerationRequestWrapper"
}
}
},
@@ -1531,14 +1531,14 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/UpdateExperimentRequest"
+ "$ref": "#/components/schemas/UpdateExperimentRequestWrapper"
}
}
},
@@ -1561,14 +1561,14 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/UpdateRunRequest"
+ "$ref": "#/components/schemas/UpdateRunRequestWrapper"
}
}
},
@@ -1591,14 +1591,14 @@
}
},
"tags": [
- "Observability"
+ "Telemetry"
],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/UploadArtifactRequest"
+ "$ref": "#/components/schemas/UploadArtifactRequestWrapper"
}
}
},
@@ -2020,6 +2020,18 @@
"content"
]
},
+ "BatchChatCompletionRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/BatchChatCompletionRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"BatchChatCompletionResponse": {
"type": "object",
"properties": {
@@ -2076,6 +2088,18 @@
"content_batch"
]
},
+ "BatchCompletionRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/BatchCompletionRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"BatchCompletionResponse": {
"type": "object",
"properties": {
@@ -2174,6 +2198,18 @@
"messages"
]
},
+ "ChatCompletionRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/ChatCompletionRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"ChatCompletionResponseEvent": {
"type": "object",
"properties": {
@@ -2316,6 +2352,18 @@
"content"
]
},
+ "CompletionRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/CompletionRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"CompletionResponseStreamChunk": {
"type": "object",
"properties": {
@@ -2376,7 +2424,183 @@
"$ref": "#/components/schemas/FunctionCallToolDefinition"
},
{
- "$ref": "#/components/schemas/MemoryToolDefinition"
+ "type": "object",
+ "properties": {
+ "input_shields": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ShieldDefinition"
+ }
+ },
+ "output_shields": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ShieldDefinition"
+ }
+ },
+ "type": {
+ "type": "string",
+ "const": "memory"
+ },
+ "memory_bank_configs": {
+ "type": "array",
+ "items": {
+ "oneOf": [
+ {
+ "type": "object",
+ "properties": {
+ "bank_id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "vector"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "bank_id",
+ "type"
+ ]
+ },
+ {
+ "type": "object",
+ "properties": {
+ "bank_id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "keyvalue"
+ },
+ "keys": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "bank_id",
+ "type",
+ "keys"
+ ]
+ },
+ {
+ "type": "object",
+ "properties": {
+ "bank_id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "keyword"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "bank_id",
+ "type"
+ ]
+ },
+ {
+ "type": "object",
+ "properties": {
+ "bank_id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "graph"
+ },
+ "entities": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "bank_id",
+ "type",
+ "entities"
+ ]
+ }
+ ]
+ }
+ },
+ "query_generator_config": {
+ "oneOf": [
+ {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "default"
+ },
+ "sep": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "sep"
+ ]
+ },
+ {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "llm"
+ },
+ "model": {
+ "type": "string"
+ },
+ "template": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "model",
+ "template"
+ ]
+ },
+ {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "custom"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type"
+ ]
+ }
+ ]
+ },
+ "max_tokens_in_context": {
+ "type": "integer"
+ },
+ "max_chunks": {
+ "type": "integer"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "memory_bank_configs",
+ "query_generator_config",
+ "max_tokens_in_context",
+ "max_chunks"
+ ]
}
]
}
@@ -2513,129 +2737,6 @@
"parameters"
]
},
- "MemoryToolDefinition": {
- "type": "object",
- "properties": {
- "input_shields": {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/ShieldDefinition"
- }
- },
- "output_shields": {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/ShieldDefinition"
- }
- },
- "type": {
- "type": "string",
- "const": "memory"
- },
- "memory_bank_configs": {
- "type": "array",
- "items": {
- "oneOf": [
- {
- "type": "object",
- "properties": {
- "bank_id": {
- "type": "string"
- },
- "type": {
- "type": "string",
- "const": "vector"
- }
- },
- "additionalProperties": false,
- "required": [
- "bank_id",
- "type"
- ]
- },
- {
- "type": "object",
- "properties": {
- "bank_id": {
- "type": "string"
- },
- "type": {
- "type": "string",
- "const": "keyvalue"
- },
- "keys": {
- "type": "array",
- "items": {
- "type": "string"
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "bank_id",
- "type",
- "keys"
- ]
- },
- {
- "type": "object",
- "properties": {
- "bank_id": {
- "type": "string"
- },
- "type": {
- "type": "string",
- "const": "keyword"
- }
- },
- "additionalProperties": false,
- "required": [
- "bank_id",
- "type"
- ]
- },
- {
- "type": "object",
- "properties": {
- "bank_id": {
- "type": "string"
- },
- "type": {
- "type": "string",
- "const": "graph"
- },
- "entities": {
- "type": "array",
- "items": {
- "type": "string"
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "bank_id",
- "type",
- "entities"
- ]
- }
- ]
- }
- },
- "max_tokens_in_context": {
- "type": "integer"
- },
- "max_chunks": {
- "type": "integer"
- }
- },
- "additionalProperties": false,
- "required": [
- "type",
- "memory_bank_configs",
- "max_tokens_in_context",
- "max_chunks"
- ]
- },
"OnViolationAction": {
"type": "integer",
"enum": [
@@ -2873,7 +2974,183 @@
"$ref": "#/components/schemas/FunctionCallToolDefinition"
},
{
- "$ref": "#/components/schemas/MemoryToolDefinition"
+ "type": "object",
+ "properties": {
+ "input_shields": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ShieldDefinition"
+ }
+ },
+ "output_shields": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ShieldDefinition"
+ }
+ },
+ "type": {
+ "type": "string",
+ "const": "memory"
+ },
+ "memory_bank_configs": {
+ "type": "array",
+ "items": {
+ "oneOf": [
+ {
+ "type": "object",
+ "properties": {
+ "bank_id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "vector"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "bank_id",
+ "type"
+ ]
+ },
+ {
+ "type": "object",
+ "properties": {
+ "bank_id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "keyvalue"
+ },
+ "keys": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "bank_id",
+ "type",
+ "keys"
+ ]
+ },
+ {
+ "type": "object",
+ "properties": {
+ "bank_id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "keyword"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "bank_id",
+ "type"
+ ]
+ },
+ {
+ "type": "object",
+ "properties": {
+ "bank_id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "graph"
+ },
+ "entities": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "bank_id",
+ "type",
+ "entities"
+ ]
+ }
+ ]
+ }
+ },
+ "query_generator_config": {
+ "oneOf": [
+ {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "default"
+ },
+ "sep": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "sep"
+ ]
+ },
+ {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "llm"
+ },
+ "model": {
+ "type": "string"
+ },
+ "template": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "model",
+ "template"
+ ]
+ },
+ {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "custom"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type"
+ ]
+ }
+ ]
+ },
+ "max_tokens_in_context": {
+ "type": "integer"
+ },
+ "max_chunks": {
+ "type": "integer"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "memory_bank_configs",
+ "query_generator_config",
+ "max_tokens_in_context",
+ "max_chunks"
+ ]
}
]
}
@@ -2952,6 +3229,18 @@
"mime_type"
]
},
+ "AgenticSystemTurnCreateRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/AgenticSystemTurnCreateRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"AgenticSystemTurnResponseEvent": {
"type": "object",
"properties": {
@@ -3523,6 +3812,18 @@
"json"
]
},
+ "CreateDatasetRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/CreateDatasetRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"CreateExperimentRequest": {
"type": "object",
"properties": {
@@ -3560,6 +3861,18 @@
"name"
]
},
+ "CreateExperimentRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/CreateExperimentRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"Experiment": {
"type": "object",
"properties": {
@@ -3832,6 +4145,18 @@
"experiment_id"
]
},
+ "CreateRunRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/CreateRunRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"Run": {
"type": "object",
"properties": {
@@ -4044,6 +4369,18 @@
],
"title": "Request to evaluate question answering."
},
+ "EvaluateQuestionAnsweringRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/EvaluateQuestionAnsweringRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"EvaluationJob": {
"type": "object",
"properties": {
@@ -4092,6 +4429,18 @@
],
"title": "Request to evaluate summarization."
},
+ "EvaluateSummarizationRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/EvaluateSummarizationRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"EvaluateTextGenerationRequest": {
"type": "object",
"properties": {
@@ -4129,6 +4478,18 @@
],
"title": "Request to evaluate text generation."
},
+ "EvaluateTextGenerationRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/EvaluateTextGenerationRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"GetAgenticSystemSessionRequest": {
"type": "object",
"properties": {
@@ -4414,6 +4775,18 @@
"query"
]
},
+ "LogSearchRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/LogSearchRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"Log": {
"type": "object",
"properties": {
@@ -4673,6 +5046,18 @@
"logs"
]
},
+ "LogMessagesRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/LogMessagesRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"LogMetricsRequest": {
"type": "object",
"properties": {
@@ -4692,6 +5077,18 @@
"metrics"
]
},
+ "LogMetricsRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/LogMetricsRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"DPOAlignmentConfig": {
"type": "object",
"properties": {
@@ -4880,6 +5277,18 @@
"fsdp_cpu_offload"
]
},
+ "PostTrainingRLHFRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/PostTrainingRLHFRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"QueryDocumentsRequest": {
"type": "object",
"properties": {
@@ -5048,6 +5457,18 @@
],
"title": "Request to score a reward function. A list of prompts and a list of responses per prompt."
},
+ "RewardScoringRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/RewardScoringRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"RewardScoringResponse": {
"type": "object",
"properties": {
@@ -5333,6 +5754,18 @@
"alpha"
]
},
+ "PostTrainingSFTRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/PostTrainingSFTRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"SyntheticDataGenerationRequest": {
"type": "object",
"properties": {
@@ -5378,6 +5811,18 @@
],
"title": "Request to generate synthetic data. A small batch of prompts and a filtering function"
},
+ "SyntheticDataGenerationRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/SyntheticDataGenerationRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"SyntheticDataGenerationResponse": {
"type": "object",
"properties": {
@@ -5478,6 +5923,18 @@
"experiment_id"
]
},
+ "UpdateExperimentRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/UpdateExperimentRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"UpdateRunRequest": {
"type": "object",
"properties": {
@@ -5522,6 +5979,18 @@
"run_id"
]
},
+ "UpdateRunRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/UpdateRunRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
+ },
"UploadArtifactRequest": {
"type": "object",
"properties": {
@@ -5571,6 +6040,18 @@
"artifact_type",
"content"
]
+ },
+ "UploadArtifactRequestWrapper": {
+ "type": "object",
+ "properties": {
+ "request": {
+ "$ref": "#/components/schemas/UploadArtifactRequest"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "request"
+ ]
}
},
"responses": {}
@@ -5582,10 +6063,10 @@
],
"tags": [
{
- "name": "Observability"
+ "name": "SyntheticDataGeneration"
},
{
- "name": "Inference"
+ "name": "RewardScoring"
},
{
"name": "Datasets"
@@ -5597,19 +6078,19 @@
"name": "AgenticSystem"
},
{
- "name": "Evaluations"
- },
- {
- "name": "SyntheticDataGeneration"
- },
- {
- "name": "RewardScoring"
+ "name": "BatchInference"
},
{
"name": "PostTraining"
},
{
- "name": "BatchInference"
+ "name": "Evaluations"
+ },
+ {
+ "name": "Telemetry"
+ },
+ {
+ "name": "Inference"
},
{
"name": "BatchChatCompletionRequest",
@@ -5667,6 +6148,10 @@
"name": "UserMessage",
"description": ""
},
+ {
+ "name": "BatchChatCompletionRequestWrapper",
+ "description": ""
+ },
{
"name": "BatchChatCompletionResponse",
"description": ""
@@ -5675,6 +6160,10 @@
"name": "BatchCompletionRequest",
"description": ""
},
+ {
+ "name": "BatchCompletionRequestWrapper",
+ "description": ""
+ },
{
"name": "BatchCompletionResponse",
"description": ""
@@ -5691,6 +6180,10 @@
"name": "ChatCompletionRequest",
"description": ""
},
+ {
+ "name": "ChatCompletionRequestWrapper",
+ "description": ""
+ },
{
"name": "ChatCompletionResponseEvent",
"description": "Chat completion response event.\n\n"
@@ -5719,6 +6212,10 @@
"name": "CompletionRequest",
"description": ""
},
+ {
+ "name": "CompletionRequestWrapper",
+ "description": ""
+ },
{
"name": "CompletionResponseStreamChunk",
"description": "streamed completion response.\n\n"
@@ -5743,10 +6240,6 @@
"name": "FunctionCallToolDefinition",
"description": ""
},
- {
- "name": "MemoryToolDefinition",
- "description": ""
- },
{
"name": "OnViolationAction",
"description": ""
@@ -5799,6 +6292,10 @@
"name": "Attachment",
"description": ""
},
+ {
+ "name": "AgenticSystemTurnCreateRequestWrapper",
+ "description": ""
+ },
{
"name": "AgenticSystemTurnResponseEvent",
"description": "Streamed agent execution response.\n\n"
@@ -5867,10 +6364,18 @@
"name": "TrainEvalDatasetColumnType",
"description": ""
},
+ {
+ "name": "CreateDatasetRequestWrapper",
+ "description": ""
+ },
{
"name": "CreateExperimentRequest",
"description": ""
},
+ {
+ "name": "CreateExperimentRequestWrapper",
+ "description": ""
+ },
{
"name": "Experiment",
"description": ""
@@ -5891,6 +6396,10 @@
"name": "CreateRunRequest",
"description": ""
},
+ {
+ "name": "CreateRunRequestWrapper",
+ "description": ""
+ },
{
"name": "Run",
"description": ""
@@ -5931,6 +6440,10 @@
"name": "EvaluateQuestionAnsweringRequest",
"description": "Request to evaluate question answering.\n\n"
},
+ {
+ "name": "EvaluateQuestionAnsweringRequestWrapper",
+ "description": ""
+ },
{
"name": "EvaluationJob",
"description": ""
@@ -5939,10 +6452,18 @@
"name": "EvaluateSummarizationRequest",
"description": "Request to evaluate summarization.\n\n"
},
+ {
+ "name": "EvaluateSummarizationRequestWrapper",
+ "description": ""
+ },
{
"name": "EvaluateTextGenerationRequest",
"description": "Request to evaluate text generation.\n\n"
},
+ {
+ "name": "EvaluateTextGenerationRequestWrapper",
+ "description": ""
+ },
{
"name": "GetAgenticSystemSessionRequest",
"description": ""
@@ -5987,6 +6508,10 @@
"name": "LogSearchRequest",
"description": ""
},
+ {
+ "name": "LogSearchRequestWrapper",
+ "description": ""
+ },
{
"name": "Log",
"description": ""
@@ -6027,10 +6552,18 @@
"name": "LogMessagesRequest",
"description": ""
},
+ {
+ "name": "LogMessagesRequestWrapper",
+ "description": ""
+ },
{
"name": "LogMetricsRequest",
"description": ""
},
+ {
+ "name": "LogMetricsRequestWrapper",
+ "description": ""
+ },
{
"name": "DPOAlignmentConfig",
"description": ""
@@ -6051,6 +6584,10 @@
"name": "TrainingConfig",
"description": ""
},
+ {
+ "name": "PostTrainingRLHFRequestWrapper",
+ "description": ""
+ },
{
"name": "QueryDocumentsRequest",
"description": ""
@@ -6067,6 +6604,10 @@
"name": "RewardScoringRequest",
"description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n"
},
+ {
+ "name": "RewardScoringRequestWrapper",
+ "description": ""
+ },
{
"name": "RewardScoringResponse",
"description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n"
@@ -6099,10 +6640,18 @@
"name": "QLoraFinetuningConfig",
"description": ""
},
+ {
+ "name": "PostTrainingSFTRequestWrapper",
+ "description": ""
+ },
{
"name": "SyntheticDataGenerationRequest",
"description": "Request to generate synthetic data. A small batch of prompts and a filtering function\n\n"
},
+ {
+ "name": "SyntheticDataGenerationRequestWrapper",
+ "description": ""
+ },
{
"name": "SyntheticDataGenerationResponse",
"description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.\n\n"
@@ -6115,13 +6664,25 @@
"name": "UpdateExperimentRequest",
"description": ""
},
+ {
+ "name": "UpdateExperimentRequestWrapper",
+ "description": ""
+ },
{
"name": "UpdateRunRequest",
"description": ""
},
+ {
+ "name": "UpdateRunRequestWrapper",
+ "description": ""
+ },
{
"name": "UploadArtifactRequest",
"description": ""
+ },
+ {
+ "name": "UploadArtifactRequestWrapper",
+ "description": ""
}
],
"x-tagGroups": [
@@ -6134,10 +6695,10 @@
"Evaluations",
"Inference",
"Memory",
- "Observability",
"PostTraining",
"RewardScoring",
- "SyntheticDataGeneration"
+ "SyntheticDataGeneration",
+ "Telemetry"
]
},
{
@@ -6148,6 +6709,7 @@
"AgenticSystemSessionCreateResponse",
"AgenticSystemStepResponse",
"AgenticSystemTurnCreateRequest",
+ "AgenticSystemTurnCreateRequestWrapper",
"AgenticSystemTurnResponseEvent",
"AgenticSystemTurnResponseStepCompletePayload",
"AgenticSystemTurnResponseStepProgressPayload",
@@ -6159,8 +6721,10 @@
"ArtifactType",
"Attachment",
"BatchChatCompletionRequest",
+ "BatchChatCompletionRequestWrapper",
"BatchChatCompletionResponse",
"BatchCompletionRequest",
+ "BatchCompletionRequestWrapper",
"BatchCompletionResponse",
"BraveSearchToolDefinition",
"BuiltinShield",
@@ -6168,6 +6732,7 @@
"CancelEvaluationJobRequest",
"CancelTrainingJobRequest",
"ChatCompletionRequest",
+ "ChatCompletionRequestWrapper",
"ChatCompletionResponseEvent",
"ChatCompletionResponseEventType",
"ChatCompletionResponseStreamChunk",
@@ -6175,13 +6740,17 @@
"CodeInterpreterToolDefinition",
"CompletionMessage",
"CompletionRequest",
+ "CompletionRequestWrapper",
"CompletionResponseStreamChunk",
"CreateAgenticSystemRequest",
"CreateAgenticSystemSessionRequest",
"CreateDatasetRequest",
+ "CreateDatasetRequestWrapper",
"CreateExperimentRequest",
+ "CreateExperimentRequestWrapper",
"CreateMemoryBankRequest",
"CreateRunRequest",
+ "CreateRunRequestWrapper",
"DPOAlignmentConfig",
"DeleteAgenticSystemRequest",
"DeleteAgenticSystemSessionRequest",
@@ -6193,8 +6762,11 @@
"EmbeddingsRequest",
"EmbeddingsResponse",
"EvaluateQuestionAnsweringRequest",
+ "EvaluateQuestionAnsweringRequestWrapper",
"EvaluateSummarizationRequest",
+ "EvaluateSummarizationRequestWrapper",
"EvaluateTextGenerationRequest",
+ "EvaluateTextGenerationRequestWrapper",
"EvaluationJob",
"EvaluationJobArtifactsResponse",
"EvaluationJobLogStream",
@@ -6210,13 +6782,15 @@
"ListArtifactsRequest",
"Log",
"LogMessagesRequest",
+ "LogMessagesRequestWrapper",
"LogMetricsRequest",
+ "LogMetricsRequestWrapper",
"LogSearchRequest",
+ "LogSearchRequestWrapper",
"LoraFinetuningConfig",
"MemoryBank",
"MemoryBankDocument",
"MemoryRetrievalStep",
- "MemoryToolDefinition",
"Metric",
"OnViolationAction",
"OptimizerConfig",
@@ -6227,7 +6801,9 @@
"PostTrainingJobStatus",
"PostTrainingJobStatusResponse",
"PostTrainingRLHFRequest",
+ "PostTrainingRLHFRequestWrapper",
"PostTrainingSFTRequest",
+ "PostTrainingSFTRequestWrapper",
"QLoraFinetuningConfig",
"QueryDocumentsRequest",
"QueryDocumentsResponse",
@@ -6235,6 +6811,7 @@
"RestAPIExecutionConfig",
"RestAPIMethod",
"RewardScoringRequest",
+ "RewardScoringRequestWrapper",
"RewardScoringResponse",
"Run",
"SamplingParams",
@@ -6247,6 +6824,7 @@
"ShieldResponse",
"StopReason",
"SyntheticDataGenerationRequest",
+ "SyntheticDataGenerationRequestWrapper",
"SyntheticDataGenerationResponse",
"SystemMessage",
"TokenLogProbs",
@@ -6267,8 +6845,11 @@
"URL",
"UpdateDocumentsRequest",
"UpdateExperimentRequest",
+ "UpdateExperimentRequestWrapper",
"UpdateRunRequest",
+ "UpdateRunRequestWrapper",
"UploadArtifactRequest",
+ "UploadArtifactRequestWrapper",
"UserMessage",
"WolframAlphaToolDefinition"
]
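
Three changes run through the HTML spec above: every single-parameter request body is re-pointed at a new `*RequestWrapper` schema that nests the old payload under a required `request` key; the streaming `chat_completion` response is declared as `text/event-stream` rather than `application/json`; and `MemoryToolDefinition` is inlined at both of its use sites, gaining a required `query_generator_config` field. For clients, the wrapper change looks like this (illustrative payload; the model id is made up):

```python
# Illustrative only: request bodies now nest under a required "request" key.
import json

old_body = {
    "model": "llama-example",  # made-up id; any ChatCompletionRequest fields go here
    "messages": [{"role": "user", "content": "hello"}],
}
new_body = {"request": old_body}  # shape of ChatCompletionRequestWrapper
print(json.dumps(new_body, indent=2))
```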
diff --git a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml
index 7560f3c88..3c3475fff 100644
--- a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml
+++ b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml
@@ -30,7 +30,123 @@ components:
- $ref: '#/components/schemas/PhotogenToolDefinition'
- $ref: '#/components/schemas/CodeInterpreterToolDefinition'
- $ref: '#/components/schemas/FunctionCallToolDefinition'
- - $ref: '#/components/schemas/MemoryToolDefinition'
+ - additionalProperties: false
+ properties:
+ input_shields:
+ items:
+ $ref: '#/components/schemas/ShieldDefinition'
+ type: array
+ max_chunks:
+ type: integer
+ max_tokens_in_context:
+ type: integer
+ memory_bank_configs:
+ items:
+ oneOf:
+ - additionalProperties: false
+ properties:
+ bank_id:
+ type: string
+ type:
+ const: vector
+ type: string
+ required:
+ - bank_id
+ - type
+ type: object
+ - additionalProperties: false
+ properties:
+ bank_id:
+ type: string
+ keys:
+ items:
+ type: string
+ type: array
+ type:
+ const: keyvalue
+ type: string
+ required:
+ - bank_id
+ - type
+ - keys
+ type: object
+ - additionalProperties: false
+ properties:
+ bank_id:
+ type: string
+ type:
+ const: keyword
+ type: string
+ required:
+ - bank_id
+ - type
+ type: object
+ - additionalProperties: false
+ properties:
+ bank_id:
+ type: string
+ entities:
+ items:
+ type: string
+ type: array
+ type:
+ const: graph
+ type: string
+ required:
+ - bank_id
+ - type
+ - entities
+ type: object
+ type: array
+ output_shields:
+ items:
+ $ref: '#/components/schemas/ShieldDefinition'
+ type: array
+ query_generator_config:
+ oneOf:
+ - additionalProperties: false
+ properties:
+ sep:
+ type: string
+ type:
+ const: default
+ type: string
+ required:
+ - type
+ - sep
+ type: object
+ - additionalProperties: false
+ properties:
+ model:
+ type: string
+ template:
+ type: string
+ type:
+ const: llm
+ type: string
+ required:
+ - type
+ - model
+ - template
+ type: object
+ - additionalProperties: false
+ properties:
+ type:
+ const: custom
+ type: string
+ required:
+ - type
+ type: object
+ type:
+ const: memory
+ type: string
+ required:
+ - type
+ - memory_bank_configs
+ - query_generator_config
+ - max_tokens_in_context
+ - max_chunks
+ type: object
type: array
required:
- model
@@ -107,13 +223,137 @@ components:
- $ref: '#/components/schemas/PhotogenToolDefinition'
- $ref: '#/components/schemas/CodeInterpreterToolDefinition'
- $ref: '#/components/schemas/FunctionCallToolDefinition'
- - $ref: '#/components/schemas/MemoryToolDefinition'
+ - additionalProperties: false
+ properties:
+ input_shields:
+ items:
+ $ref: '#/components/schemas/ShieldDefinition'
+ type: array
+ max_chunks:
+ type: integer
+ max_tokens_in_context:
+ type: integer
+ memory_bank_configs:
+ items:
+ oneOf:
+ - additionalProperties: false
+ properties:
+ bank_id:
+ type: string
+ type:
+ const: vector
+ type: string
+ required:
+ - bank_id
+ - type
+ type: object
+ - additionalProperties: false
+ properties:
+ bank_id:
+ type: string
+ keys:
+ items:
+ type: string
+ type: array
+ type:
+ const: keyvalue
+ type: string
+ required:
+ - bank_id
+ - type
+ - keys
+ type: object
+ - additionalProperties: false
+ properties:
+ bank_id:
+ type: string
+ type:
+ const: keyword
+ type: string
+ required:
+ - bank_id
+ - type
+ type: object
+ - additionalProperties: false
+ properties:
+ bank_id:
+ type: string
+ entities:
+ items:
+ type: string
+ type: array
+ type:
+ const: graph
+ type: string
+ required:
+ - bank_id
+ - type
+ - entities
+ type: object
+ type: array
+ output_shields:
+ items:
+ $ref: '#/components/schemas/ShieldDefinition'
+ type: array
+ query_generator_config:
+ oneOf:
+ - additionalProperties: false
+ properties:
+ sep:
+ type: string
+ type:
+ const: default
+ type: string
+ required:
+ - type
+ - sep
+ type: object
+ - additionalProperties: false
+ properties:
+ model:
+ type: string
+ template:
+ type: string
+ type:
+ const: llm
+ type: string
+ required:
+ - type
+ - model
+ - template
+ type: object
+ - additionalProperties: false
+ properties:
+ type:
+ const: custom
+ type: string
+ required:
+ - type
+ type: object
+ type:
+ const: memory
+ type: string
+ required:
+ - type
+ - memory_bank_configs
+ - query_generator_config
+ - max_tokens_in_context
+ - max_chunks
+ type: object
type: array
required:
- agent_id
- session_id
- messages
type: object
+ AgenticSystemTurnCreateRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/AgenticSystemTurnCreateRequest'
+ required:
+ - request
+ type: object
AgenticSystemTurnResponseEvent:
additionalProperties: false
properties:
@@ -334,6 +574,14 @@ components:
- model
- messages_batch
type: object
+ BatchChatCompletionRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/BatchChatCompletionRequest'
+ required:
+ - request
+ type: object
BatchChatCompletionResponse:
additionalProperties: false
properties:
@@ -369,6 +617,14 @@ components:
- model
- content_batch
type: object
+ BatchCompletionRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/BatchCompletionRequest'
+ required:
+ - request
+ type: object
BatchCompletionResponse:
additionalProperties: false
properties:
@@ -464,6 +720,14 @@ components:
- model
- messages
type: object
+ ChatCompletionRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/ChatCompletionRequest'
+ required:
+ - request
+ type: object
ChatCompletionResponseEvent:
additionalProperties: false
properties:
@@ -572,6 +836,14 @@ components:
- model
- content
type: object
+ CompletionRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/CompletionRequest'
+ required:
+ - request
+ type: object
CompletionResponseStreamChunk:
additionalProperties: false
properties:
@@ -618,6 +890,14 @@ components:
- dataset
title: Request to create a dataset.
type: object
+ CreateDatasetRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/CreateDatasetRequest'
+ required:
+ - request
+ type: object
CreateExperimentRequest:
additionalProperties: false
properties:
@@ -636,6 +916,14 @@ components:
required:
- name
type: object
+ CreateExperimentRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/CreateExperimentRequest'
+ required:
+ - request
+ type: object
CreateMemoryBankRequest:
additionalProperties: false
properties:
@@ -707,6 +995,14 @@ components:
required:
- experiment_id
type: object
+ CreateRunRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/CreateRunRequest'
+ required:
+ - request
+ type: object
DPOAlignmentConfig:
additionalProperties: false
properties:
@@ -872,6 +1168,14 @@ components:
- metrics
title: Request to evaluate question answering.
type: object
+ EvaluateQuestionAnsweringRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/EvaluateQuestionAnsweringRequest'
+ required:
+ - request
+ type: object
EvaluateSummarizationRequest:
additionalProperties: false
properties:
@@ -898,6 +1202,14 @@ components:
- metrics
title: Request to evaluate summarization.
type: object
+ EvaluateSummarizationRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/EvaluateSummarizationRequest'
+ required:
+ - request
+ type: object
EvaluateTextGenerationRequest:
additionalProperties: false
properties:
@@ -925,6 +1237,14 @@ components:
- metrics
title: Request to evaluate text generation.
type: object
+ EvaluateTextGenerationRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/EvaluateTextGenerationRequest'
+ required:
+ - request
+ type: object
EvaluationJob:
additionalProperties: false
properties:
@@ -1138,6 +1458,14 @@ components:
required:
- logs
type: object
+ LogMessagesRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/LogMessagesRequest'
+ required:
+ - request
+ type: object
LogMetricsRequest:
additionalProperties: false
properties:
@@ -1151,6 +1479,14 @@ components:
- run_id
- metrics
type: object
+ LogMetricsRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/LogMetricsRequest'
+ required:
+ - request
+ type: object
LogSearchRequest:
additionalProperties: false
properties:
@@ -1169,6 +1505,14 @@ components:
required:
- query
type: object
+ LogSearchRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/LogSearchRequest'
+ required:
+ - request
+ type: object
LoraFinetuningConfig:
additionalProperties: false
properties:
@@ -1310,88 +1654,6 @@ components:
- memory_bank_ids
- inserted_context
type: object
- MemoryToolDefinition:
- additionalProperties: false
- properties:
- input_shields:
- items:
- $ref: '#/components/schemas/ShieldDefinition'
- type: array
- max_chunks:
- type: integer
- max_tokens_in_context:
- type: integer
- memory_bank_configs:
- items:
- oneOf:
- - additionalProperties: false
- properties:
- bank_id:
- type: string
- type:
- const: vector
- type: string
- required:
- - bank_id
- - type
- type: object
- - additionalProperties: false
- properties:
- bank_id:
- type: string
- keys:
- items:
- type: string
- type: array
- type:
- const: keyvalue
- type: string
- required:
- - bank_id
- - type
- - keys
- type: object
- - additionalProperties: false
- properties:
- bank_id:
- type: string
- type:
- const: keyword
- type: string
- required:
- - bank_id
- - type
- type: object
- - additionalProperties: false
- properties:
- bank_id:
- type: string
- entities:
- items:
- type: string
- type: array
- type:
- const: graph
- type: string
- required:
- - bank_id
- - type
- - entities
- type: object
- type: array
- output_shields:
- items:
- $ref: '#/components/schemas/ShieldDefinition'
- type: array
- type:
- const: memory
- type: string
- required:
- - type
- - memory_bank_configs
- - max_tokens_in_context
- - max_chunks
- type: object
Metric:
additionalProperties: false
properties:
@@ -1591,6 +1853,14 @@ components:
- logger_config
title: Request to finetune a model.
type: object
+ PostTrainingRLHFRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/PostTrainingRLHFRequest'
+ required:
+ - request
+ type: object
PostTrainingSFTRequest:
additionalProperties: false
properties:
@@ -1646,6 +1916,14 @@ components:
- logger_config
title: Request to finetune a model.
type: object
+ PostTrainingSFTRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/PostTrainingSFTRequest'
+ required:
+ - request
+ type: object
QLoraFinetuningConfig:
additionalProperties: false
properties:
@@ -1773,6 +2051,14 @@ components:
title: Request to score a reward function. A list of prompts and a list of responses
per prompt.
type: object
+ RewardScoringRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/RewardScoringRequest'
+ required:
+ - request
+ type: object
RewardScoringResponse:
additionalProperties: false
properties:
@@ -1995,6 +2281,14 @@ components:
title: Request to generate synthetic data. A small batch of prompts and a filtering
function
type: object
+ SyntheticDataGenerationRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/SyntheticDataGenerationRequest'
+ required:
+ - request
+ type: object
SyntheticDataGenerationResponse:
additionalProperties: false
properties:
@@ -2363,6 +2657,14 @@ components:
required:
- experiment_id
type: object
+ UpdateExperimentRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/UpdateExperimentRequest'
+ required:
+ - request
+ type: object
UpdateRunRequest:
additionalProperties: false
properties:
@@ -2386,6 +2688,14 @@ components:
required:
- run_id
type: object
+ UpdateRunRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/UpdateRunRequest'
+ required:
+ - request
+ type: object
UploadArtifactRequest:
additionalProperties: false
properties:
@@ -2414,6 +2724,14 @@ components:
- artifact_type
- content
type: object
+ UploadArtifactRequestWrapper:
+ additionalProperties: false
+ properties:
+ request:
+ $ref: '#/components/schemas/UploadArtifactRequest'
+ required:
+ - request
+ type: object
UserMessage:
additionalProperties: false
properties:
@@ -2459,7 +2777,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
- \ draft and subject to change.\n Generated at 2024-09-04 10:28:38.779789"
+ \ draft and subject to change.\n Generated at 2024-09-10 01:13:08.531639"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -2591,7 +2909,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/AgenticSystemTurnCreateRequest'
+ $ref: '#/components/schemas/AgenticSystemTurnCreateRequestWrapper'
required: true
responses:
'200':
@@ -2640,7 +2958,7 @@ paths:
$ref: '#/components/schemas/Artifact'
description: OK
tags:
- - Observability
+ - Telemetry
/batch_inference/chat_completion:
post:
parameters: []
@@ -2648,7 +2966,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/BatchChatCompletionRequest'
+ $ref: '#/components/schemas/BatchChatCompletionRequestWrapper'
required: true
responses:
'200':
@@ -2666,7 +2984,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/BatchCompletionRequest'
+ $ref: '#/components/schemas/BatchCompletionRequestWrapper'
required: true
responses:
'200':
@@ -2684,7 +3002,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/CreateDatasetRequest'
+ $ref: '#/components/schemas/CreateDatasetRequestWrapper'
required: true
responses:
'200':
@@ -2806,7 +3124,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/EvaluateQuestionAnsweringRequest'
+ $ref: '#/components/schemas/EvaluateQuestionAnsweringRequestWrapper'
required: true
responses:
'200':
@@ -2824,7 +3142,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/EvaluateSummarizationRequest'
+ $ref: '#/components/schemas/EvaluateSummarizationRequestWrapper'
required: true
responses:
'200':
@@ -2842,7 +3160,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/EvaluateTextGenerationRequest'
+ $ref: '#/components/schemas/EvaluateTextGenerationRequestWrapper'
required: true
responses:
'200':
@@ -2870,7 +3188,7 @@ paths:
$ref: '#/components/schemas/Artifact'
description: OK
tags:
- - Observability
+ - Telemetry
/experiments/artifacts/upload:
post:
parameters: []
@@ -2878,7 +3196,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/UploadArtifactRequest'
+ $ref: '#/components/schemas/UploadArtifactRequestWrapper'
required: true
responses:
'200':
@@ -2888,7 +3206,7 @@ paths:
$ref: '#/components/schemas/Artifact'
description: OK
tags:
- - Observability
+ - Telemetry
/experiments/create:
post:
parameters: []
@@ -2896,7 +3214,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/CreateExperimentRequest'
+ $ref: '#/components/schemas/CreateExperimentRequestWrapper'
required: true
responses:
'200':
@@ -2906,7 +3224,7 @@ paths:
$ref: '#/components/schemas/Experiment'
description: OK
tags:
- - Observability
+ - Telemetry
/experiments/create_run:
post:
parameters: []
@@ -2914,7 +3232,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/CreateRunRequest'
+ $ref: '#/components/schemas/CreateRunRequestWrapper'
required: true
responses:
'200':
@@ -2924,7 +3242,7 @@ paths:
$ref: '#/components/schemas/Run'
description: OK
tags:
- - Observability
+ - Telemetry
/experiments/get:
get:
parameters:
@@ -2941,7 +3259,7 @@ paths:
$ref: '#/components/schemas/Experiment'
description: OK
tags:
- - Observability
+ - Telemetry
/experiments/list:
get:
parameters: []
@@ -2953,7 +3271,7 @@ paths:
$ref: '#/components/schemas/Experiment'
description: OK
tags:
- - Observability
+ - Telemetry
/experiments/update:
post:
parameters: []
@@ -2961,7 +3279,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/UpdateExperimentRequest'
+ $ref: '#/components/schemas/UpdateExperimentRequestWrapper'
required: true
responses:
'200':
@@ -2971,7 +3289,7 @@ paths:
$ref: '#/components/schemas/Experiment'
description: OK
tags:
- - Observability
+ - Telemetry
/inference/chat_completion:
post:
parameters: []
@@ -2979,12 +3297,12 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/ChatCompletionRequest'
+ $ref: '#/components/schemas/ChatCompletionRequestWrapper'
required: true
responses:
'200':
content:
- application/json:
+ text/event-stream:
schema:
$ref: '#/components/schemas/ChatCompletionResponseStreamChunk'
description: SSE-stream of these events.
@@ -2997,7 +3315,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/CompletionRequest'
+ $ref: '#/components/schemas/CompletionRequestWrapper'
required: true
responses:
'200':
@@ -3033,7 +3351,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/LogSearchRequest'
+ $ref: '#/components/schemas/LogSearchRequestWrapper'
required: true
responses:
'200':
@@ -3043,7 +3361,7 @@ paths:
$ref: '#/components/schemas/Log'
description: OK
tags:
- - Observability
+ - Telemetry
/logging/log_messages:
post:
parameters: []
@@ -3051,13 +3369,13 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/LogMessagesRequest'
+ $ref: '#/components/schemas/LogMessagesRequestWrapper'
required: true
responses:
'200':
description: OK
tags:
- - Observability
+ - Telemetry
/memory_bank/documents/delete:
post:
parameters: []
@@ -3292,7 +3610,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/PostTrainingRLHFRequest'
+ $ref: '#/components/schemas/PostTrainingRLHFRequestWrapper'
required: true
responses:
'200':
@@ -3310,7 +3628,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/PostTrainingSFTRequest'
+ $ref: '#/components/schemas/PostTrainingSFTRequestWrapper'
required: true
responses:
'200':
@@ -3328,7 +3646,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/RewardScoringRequest'
+ $ref: '#/components/schemas/RewardScoringRequestWrapper'
required: true
responses:
'200':
@@ -3346,13 +3664,13 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/LogMetricsRequest'
+ $ref: '#/components/schemas/LogMetricsRequestWrapper'
required: true
responses:
'200':
description: OK
tags:
- - Observability
+ - Telemetry
/runs/metrics:
get:
parameters:
@@ -3369,7 +3687,7 @@ paths:
$ref: '#/components/schemas/Metric'
description: OK
tags:
- - Observability
+ - Telemetry
/runs/update:
post:
parameters: []
@@ -3377,7 +3695,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/UpdateRunRequest'
+ $ref: '#/components/schemas/UpdateRunRequestWrapper'
required: true
responses:
'200':
@@ -3387,7 +3705,7 @@ paths:
$ref: '#/components/schemas/Run'
description: OK
tags:
- - Observability
+ - Telemetry
/synthetic_data_generation/generate:
post:
parameters: []
@@ -3395,7 +3713,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/SyntheticDataGenerationRequest'
+ $ref: '#/components/schemas/SyntheticDataGenerationRequestWrapper'
required: true
responses:
'200':
@@ -3411,16 +3729,16 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
-- name: Observability
-- name: Inference
+- name: SyntheticDataGeneration
+- name: RewardScoring
- name: Datasets
- name: Memory
- name: AgenticSystem
-- name: Evaluations
-- name: SyntheticDataGeneration
-- name: RewardScoring
-- name: PostTraining
- name: BatchInference
+- name: PostTraining
+- name: Evaluations
+- name: Telemetry
+- name: Inference
- description:
name: BatchChatCompletionRequest
@@ -3463,12 +3781,18 @@ tags:
name: ToolResponseMessage
- description:
name: UserMessage
+- description:
+ name: BatchChatCompletionRequestWrapper
- description:
name: BatchChatCompletionResponse
- description:
name: BatchCompletionRequest
+- description:
+ name: BatchCompletionRequestWrapper
- description:
name: BatchCompletionResponse
@@ -3481,6 +3805,9 @@ tags:
- description:
name: ChatCompletionRequest
+- description:
+ name: ChatCompletionRequestWrapper
- description: 'Chat completion response event.
@@ -3506,6 +3833,9 @@ tags:
- description:
name: CompletionRequest
+- description:
+ name: CompletionRequestWrapper
- description: 'streamed completion response.
@@ -3525,9 +3855,6 @@ tags:
- description:
name: FunctionCallToolDefinition
-- description:
- name: MemoryToolDefinition
- description:
name: OnViolationAction
@@ -3564,6 +3891,9 @@ tags:
name: AgenticSystemTurnCreateRequest
- description:
name: Attachment
+- description:
+ name: AgenticSystemTurnCreateRequestWrapper
- description: 'Streamed agent execution response.
@@ -3620,9 +3950,15 @@ tags:
- description:
name: TrainEvalDatasetColumnType
+- description:
+ name: CreateDatasetRequestWrapper
- description:
name: CreateExperimentRequest
+- description:
+ name: CreateExperimentRequestWrapper
- description:
name: Experiment
- description:
name: CreateRunRequest
+- description:
+ name: CreateRunRequestWrapper
- description:
name: Run
- description: 'Request to evaluate question answering.

    '
name: EvaluateQuestionAnsweringRequest
+- description:
+ name: EvaluateQuestionAnsweringRequestWrapper
- description:
name: EvaluationJob
- description: 'Request to evaluate summarization.
@@ -3678,12 +4020,18 @@ tags:
'
name: EvaluateSummarizationRequest
+- description:
+ name: EvaluateSummarizationRequestWrapper
- description: 'Request to evaluate text generation.
'
name: EvaluateTextGenerationRequest
+- description:
+ name: EvaluateTextGenerationRequestWrapper
- description:
name: GetAgenticSystemSessionRequest
@@ -3720,6 +4068,9 @@ tags:
- description:
name: LogSearchRequest
+- description:
+ name: LogSearchRequestWrapper
- description:
name: Log
- description:
@@ -3756,9 +4107,15 @@ tags:
- description:
name: LogMessagesRequest
+- description:
+ name: LogMessagesRequestWrapper
- description:
name: LogMetricsRequest
+- description:
+ name: LogMetricsRequestWrapper
- description:
name: DPOAlignmentConfig
@@ -3774,6 +4131,9 @@ tags:
name: RLHFAlgorithm
- description:
name: TrainingConfig
+- description:
+ name: PostTrainingRLHFRequestWrapper
- description:
name: QueryDocumentsRequest
@@ -3789,6 +4149,9 @@ tags:
'
name: RewardScoringRequest
+- description:
+ name: RewardScoringRequestWrapper
- description: 'Response from the reward scoring. Batch of (prompt, response, score)
tuples that pass the threshold.
@@ -3817,6 +4180,9 @@ tags:
- description:
name: QLoraFinetuningConfig
+- description:
+ name: PostTrainingSFTRequestWrapper
- description: 'Request to generate synthetic data. A small batch of prompts and a
filtering function
@@ -3824,6 +4190,9 @@ tags:
'
name: SyntheticDataGenerationRequest
+- description:
+ name: SyntheticDataGenerationRequestWrapper
- description: 'Response from the synthetic data generation. Batch of (prompt, response,
score) tuples that pass the threshold.
@@ -3837,12 +4206,21 @@ tags:
- description:
name: UpdateExperimentRequest
+- description:
+ name: UpdateExperimentRequestWrapper
- description:
name: UpdateRunRequest
+- description:
+ name: UpdateRunRequestWrapper
- description:
name: UploadArtifactRequest
+- description:
+ name: UploadArtifactRequestWrapper
x-tagGroups:
- name: Operations
tags:
@@ -3852,10 +4230,10 @@ x-tagGroups:
- Evaluations
- Inference
- Memory
- - Observability
- PostTraining
- RewardScoring
- SyntheticDataGeneration
+ - Telemetry
- name: Types
tags:
- AgentConfig
@@ -3863,6 +4241,7 @@ x-tagGroups:
- AgenticSystemSessionCreateResponse
- AgenticSystemStepResponse
- AgenticSystemTurnCreateRequest
+ - AgenticSystemTurnCreateRequestWrapper
- AgenticSystemTurnResponseEvent
- AgenticSystemTurnResponseStepCompletePayload
- AgenticSystemTurnResponseStepProgressPayload
@@ -3874,8 +4253,10 @@ x-tagGroups:
- ArtifactType
- Attachment
- BatchChatCompletionRequest
+ - BatchChatCompletionRequestWrapper
- BatchChatCompletionResponse
- BatchCompletionRequest
+ - BatchCompletionRequestWrapper
- BatchCompletionResponse
- BraveSearchToolDefinition
- BuiltinShield
@@ -3883,6 +4264,7 @@ x-tagGroups:
- CancelEvaluationJobRequest
- CancelTrainingJobRequest
- ChatCompletionRequest
+ - ChatCompletionRequestWrapper
- ChatCompletionResponseEvent
- ChatCompletionResponseEventType
- ChatCompletionResponseStreamChunk
@@ -3890,13 +4272,17 @@ x-tagGroups:
- CodeInterpreterToolDefinition
- CompletionMessage
- CompletionRequest
+ - CompletionRequestWrapper
- CompletionResponseStreamChunk
- CreateAgenticSystemRequest
- CreateAgenticSystemSessionRequest
- CreateDatasetRequest
+ - CreateDatasetRequestWrapper
- CreateExperimentRequest
+ - CreateExperimentRequestWrapper
- CreateMemoryBankRequest
- CreateRunRequest
+ - CreateRunRequestWrapper
- DPOAlignmentConfig
- DeleteAgenticSystemRequest
- DeleteAgenticSystemSessionRequest
@@ -3908,8 +4294,11 @@ x-tagGroups:
- EmbeddingsRequest
- EmbeddingsResponse
- EvaluateQuestionAnsweringRequest
+ - EvaluateQuestionAnsweringRequestWrapper
- EvaluateSummarizationRequest
+ - EvaluateSummarizationRequestWrapper
- EvaluateTextGenerationRequest
+ - EvaluateTextGenerationRequestWrapper
- EvaluationJob
- EvaluationJobArtifactsResponse
- EvaluationJobLogStream
@@ -3925,13 +4314,15 @@ x-tagGroups:
- ListArtifactsRequest
- Log
- LogMessagesRequest
+ - LogMessagesRequestWrapper
- LogMetricsRequest
+ - LogMetricsRequestWrapper
- LogSearchRequest
+ - LogSearchRequestWrapper
- LoraFinetuningConfig
- MemoryBank
- MemoryBankDocument
- MemoryRetrievalStep
- - MemoryToolDefinition
- Metric
- OnViolationAction
- OptimizerConfig
@@ -3942,7 +4333,9 @@ x-tagGroups:
- PostTrainingJobStatus
- PostTrainingJobStatusResponse
- PostTrainingRLHFRequest
+ - PostTrainingRLHFRequestWrapper
- PostTrainingSFTRequest
+ - PostTrainingSFTRequestWrapper
- QLoraFinetuningConfig
- QueryDocumentsRequest
- QueryDocumentsResponse
@@ -3950,6 +4343,7 @@ x-tagGroups:
- RestAPIExecutionConfig
- RestAPIMethod
- RewardScoringRequest
+ - RewardScoringRequestWrapper
- RewardScoringResponse
- Run
- SamplingParams
@@ -3962,6 +4356,7 @@ x-tagGroups:
- ShieldResponse
- StopReason
- SyntheticDataGenerationRequest
+ - SyntheticDataGenerationRequestWrapper
- SyntheticDataGenerationResponse
- SystemMessage
- TokenLogProbs
@@ -3982,7 +4377,10 @@ x-tagGroups:
- URL
- UpdateDocumentsRequest
- UpdateExperimentRequest
+ - UpdateExperimentRequestWrapper
- UpdateRunRequest
+ - UpdateRunRequestWrapper
- UploadArtifactRequest
+ - UploadArtifactRequestWrapper
- UserMessage
- WolframAlphaToolDefinition
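
The YAML spec mirrors the HTML changes one-for-one. The practical consequence of the `text/event-stream` declaration is that clients should consume `/inference/chat_completion` as an SSE stream. A hedged sketch using `httpx` (already in `requirements.txt`); the host, port, and payload are illustrative, not taken from the spec:

```python
# Hedged sketch of consuming the SSE stream; host and payload are illustrative.
import httpx

body = {"request": {"model": "llama-example", "messages": [{"role": "user", "content": "hi"}]}}
with httpx.stream("POST", "http://localhost:5000/inference/chat_completion", json=body) as resp:
    for line in resp.iter_lines():
        if line.startswith("data: "):
            print(line[len("data: "):])  # one ChatCompletionResponseStreamChunk per event
```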
diff --git a/rfcs/openapi_generator/generate.py b/rfcs/openapi_generator/generate.py
index ab9774e70..279389a47 100644
--- a/rfcs/openapi_generator/generate.py
+++ b/rfcs/openapi_generator/generate.py
@@ -35,7 +35,10 @@ from llama_toolchain.stack import LlamaStack
# TODO: this should be fixed in the generator itself so it reads appropriate annotations
-STREAMING_ENDPOINTS = ["/agentic_system/turn/create"]
+STREAMING_ENDPOINTS = [
+ "/agentic_system/turn/create",
+ "/inference/chat_completion",
+]
def patch_sse_stream_responses(spec: Specification):
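
`patch_sse_stream_responses` consumes this list to rewrite the generated spec after the fact (the TODO notes the generator should eventually read annotations instead). A plausible sketch of the patching step, not the actual implementation, operating on a plain dict rather than the generator's `Specification` type:

```python
# Plausible sketch (assumed behavior) of how the endpoint list is consumed:
# swap each streaming endpoint's 200-response content type over to SSE.
STREAMING_ENDPOINTS = ["/agentic_system/turn/create", "/inference/chat_completion"]


def patch_sse_stream_responses(spec: dict) -> None:
    for path in STREAMING_ENDPOINTS:
        ok = spec["paths"][path]["post"]["responses"]["200"]
        ok["content"] = {"text/event-stream": ok["content"].pop("application/json")}
```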
diff --git a/rfcs/openapi_generator/pyopenapi/generator.py b/rfcs/openapi_generator/pyopenapi/generator.py
index e1450074b..a711d9f68 100644
--- a/rfcs/openapi_generator/pyopenapi/generator.py
+++ b/rfcs/openapi_generator/pyopenapi/generator.py
@@ -468,12 +468,14 @@ class Generator:
builder = ContentBuilder(self.schema_builder)
first = next(iter(op.request_params))
request_name, request_type = first
+
+ from dataclasses import make_dataclass
+
if len(op.request_params) == 1 and "Request" in first[1].__name__:
# TODO(ashwin): Undo the "Request" hack and this entire block eventually
- request_name, request_type = first
+ request_name = first[1].__name__ + "Wrapper"
+ request_type = make_dataclass(request_name, op.request_params)
else:
- from dataclasses import make_dataclass
-
op_name = "".join(word.capitalize() for word in op.name.split("_"))
request_name = f"{op_name}Request"
request_type = make_dataclass(request_name, op.request_params)
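
This generator change is where the `*Wrapper` schemas in both spec files come from: when an operation has a single parameter whose type name contains `Request`, the parameter type is no longer used directly as the body schema; instead `make_dataclass` synthesizes a `<Name>Wrapper` dataclass with one `request` field. A standalone illustration of that trick (`ChatCompletionRequest` is a stand-in for the real pydantic model):

```python
# Standalone illustration of the make_dataclass trick used above.
from dataclasses import fields, make_dataclass


class ChatCompletionRequest:
    ...


request_params = [("request", ChatCompletionRequest)]  # shape of op.request_params
Wrapper = make_dataclass(ChatCompletionRequest.__name__ + "Wrapper", request_params)
print(Wrapper.__name__, [f.name for f in fields(Wrapper)])
# -> ChatCompletionRequestWrapper ['request']
```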
diff --git a/rfcs/openapi_generator/run_openapi_generator.sh b/rfcs/openapi_generator/run_openapi_generator.sh
index cf4265ae5..1b2f979cc 100755
--- a/rfcs/openapi_generator/run_openapi_generator.sh
+++ b/rfcs/openapi_generator/run_openapi_generator.sh
@@ -28,4 +28,4 @@ if [ ${#missing_packages[@]} -ne 0 ]; then
exit 1
fi
-PYTHONPATH=$PYTHONPATH:../.. python3 -m rfcs.openapi_generator.generate $*
+PYTHONPATH=$PYTHONPATH:../.. python -m rfcs.openapi_generator.generate $*