Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

Merge branch 'main' into tgi-integration
commit 04f0b8fe11
38 changed files with 2157 additions and 548 deletions
@@ -248,51 +248,51 @@ llama stack list-distributions
 ```
 
 <pre style="font-family: monospace;">
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------+
-| Distribution ID                | Providers                             | Description                                                          |
+| Distribution Type              | Providers                             | Description                                                          |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 | local                          | {                                     | Use code from `llama_toolchain` itself to serve all llama stack APIs |
 |                                |   "inference": "meta-reference",      |                                                                      |
 |                                |   "memory": "meta-reference-faiss",   |                                                                      |
 |                                |   "safety": "meta-reference",         |                                                                      |
 |                                |   "agentic_system": "meta-reference"  |                                                                      |
 |                                | }                                     |                                                                      |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 | remote                         | {                                     | Point to remote services for all llama stack APIs                    |
 |                                |   "inference": "remote",              |                                                                      |
 |                                |   "safety": "remote",                 |                                                                      |
 |                                |   "agentic_system": "remote",         |                                                                      |
 |                                |   "memory": "remote"                  |                                                                      |
 |                                | }                                     |                                                                      |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 | local-ollama                   | {                                     | Like local, but use ollama for running LLM inference                 |
 |                                |   "inference": "remote::ollama",      |                                                                      |
 |                                |   "safety": "meta-reference",         |                                                                      |
 |                                |   "agentic_system": "meta-reference", |                                                                      |
 |                                |   "memory": "meta-reference-faiss"    |                                                                      |
 |                                | }                                     |                                                                      |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 | local-plus-fireworks-inference | {                                     | Use Fireworks.ai for running LLM inference                           |
 |                                |   "inference": "remote::fireworks",   |                                                                      |
 |                                |   "safety": "meta-reference",         |                                                                      |
 |                                |   "agentic_system": "meta-reference", |                                                                      |
 |                                |   "memory": "meta-reference-faiss"    |                                                                      |
 |                                | }                                     |                                                                      |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 | local-plus-together-inference  | {                                     | Use Together.ai for running LLM inference                            |
 |                                |   "inference": "remote::together",    |                                                                      |
 |                                |   "safety": "meta-reference",         |                                                                      |
 |                                |   "agentic_system": "meta-reference", |                                                                      |
 |                                |   "memory": "meta-reference-faiss"    |                                                                      |
 |                                | }                                     |                                                                      |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 | local-plus-tgi-inference       | {                                     | Use TGI (local or with [Hugging Face Inference Endpoints](https://   |
 |                                |   "inference": "remote::tgi",         | huggingface.co/inference-endpoints/dedicated)) for running LLM       |
 |                                |   "safety": "meta-reference",         | inference. When using HF Inference Endpoints, you must provide the   |
 |                                |   "agentic_system": "meta-reference", | name of the endpoint.                                                |
 |                                |   "memory": "meta-reference-faiss"    |                                                                      |
 |                                | }                                     |                                                                      |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 </pre>
 
 As you can see above, each “distribution” details the “providers” it is composed of. For example, `local` uses the “meta-reference” provider for inference, while `local-ollama` relies on a different provider (Ollama) for inference. Similarly, you can use Fireworks or Together.ai for running inference as well.
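To make the mapping concrete, the sketch below (not part of the commit) shows the `local-ollama` entry expressed as a `DistributionSpec`, mirroring the registry code further down in this diff; it assumes `Api`, `DistributionSpec`, and `remote_provider_type` are importable from `llama_toolchain.core.datatypes`.

```python
# Sketch only: mirrors the local-ollama DistributionSpec shown later in this
# diff. The import location of Api is an assumption.
from llama_toolchain.core.datatypes import Api, DistributionSpec, remote_provider_type

local_ollama = DistributionSpec(
    distribution_type="local-ollama",
    description="Like local, but use ollama for running LLM inference",
    providers={
        Api.inference: remote_provider_type("ollama"),  # resolves to "remote::ollama"
        Api.safety: "meta-reference",
        Api.agentic_system: "meta-reference",
        Api.memory: "meta-reference-faiss",
    },
)
```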
@@ -116,10 +116,47 @@ MemoryBankConfig = Annotated[
 ]
 
 
 @json_schema_type
+class MemoryQueryGenerator(Enum):
+    default = "default"
+    llm = "llm"
+    custom = "custom"
+
+
+class DefaultMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.default.value] = (
+        MemoryQueryGenerator.default.value
+    )
+    sep: str = " "
+
+
+class LLMMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
+    model: str
+    template: str
+
+
+class CustomMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
+
+
+MemoryQueryGeneratorConfig = Annotated[
+    Union[
+        DefaultMemoryQueryGeneratorConfig,
+        LLMMemoryQueryGeneratorConfig,
+        CustomMemoryQueryGeneratorConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
 class MemoryToolDefinition(ToolDefinitionCommon):
     type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
     memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+    # This config defines how a query is generated using the messages
+    # for memory bank retrieval.
+    query_generator_config: MemoryQueryGeneratorConfig = Field(
+        default=DefaultMemoryQueryGeneratorConfig()
+    )
     max_tokens_in_context: int = 4096
     max_chunks: int = 10
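As a usage sketch (not part of the commit), an agent's memory tool can opt into the LLM-based query generator through the discriminated union above; the model identifier and Jinja template below are illustrative placeholders, and the field set shown is not exhaustive.

```python
# Sketch only: the model name and template are made-up placeholders.
memory_tool = MemoryToolDefinition(
    query_generator_config=LLMMemoryQueryGeneratorConfig(
        model="Meta-Llama3.1-8B-Instruct",
        template="Summarize the user's request as one search query: {{ messages }}",
    ),
    max_tokens_in_context=4096,
    max_chunks=10,
)
```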
@@ -31,6 +31,7 @@ from llama_toolchain.tools.builtin import (
     SingleMessageBuiltinTool,
 )
 
+from .rag.context_retriever import generate_rag_query
 from .safety import SafetyException, ShieldRunnerMixin
 
 
@@ -664,7 +665,9 @@ class ChatAgent(ShieldRunnerMixin):
             # (i.e., no prior turns uploaded an Attachment)
             return None, []
 
-        query = " ".join(m.content for m in messages)
+        query = await generate_rag_query(
+            memory.query_generator_config, messages, inference_api=self.inference_api
+        )
         tasks = [
             self.memory_api.query_documents(
                 bank_id=bank_id,
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from jinja2 import Template
from llama_models.llama3.api import *  # noqa: F403


from llama_toolchain.agentic_system.api import (
    DefaultMemoryQueryGeneratorConfig,
    LLMMemoryQueryGeneratorConfig,
    MemoryQueryGenerator,
    MemoryQueryGeneratorConfig,
)
from termcolor import cprint  # noqa: F401
from llama_toolchain.inference.api import *  # noqa: F403


async def generate_rag_query(
    config: MemoryQueryGeneratorConfig,
    messages: List[Message],
    **kwargs,
):
    """
    Generates a query that will be used for
    retrieving relevant information from the memory bank.
    """
    if config.type == MemoryQueryGenerator.default.value:
        query = await default_rag_query_generator(config, messages, **kwargs)
    elif config.type == MemoryQueryGenerator.llm.value:
        query = await llm_rag_query_generator(config, messages, **kwargs)
    else:
        raise NotImplementedError(f"Unsupported memory query generator {config.type}")
    # cprint(f"Generated query >>>: {query}", color="green")
    return query


async def default_rag_query_generator(
    config: DefaultMemoryQueryGeneratorConfig,
    messages: List[Message],
    **kwargs,
):
    return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)


async def llm_rag_query_generator(
    config: LLMMemoryQueryGeneratorConfig,
    messages: List[Message],
    **kwargs,
):
    assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
    inference_api = kwargs["inference_api"]

    m_dict = {"messages": [m.model_dump() for m in messages]}

    template = Template(config.template)
    content = template.render(m_dict)

    model = config.model
    message = UserMessage(content=content)
    response = inference_api.chat_completion(
        ChatCompletionRequest(
            model=model,
            messages=[message],
            stream=False,
        )
    )

    async for chunk in response:
        query = chunk.completion_message.content

    return query
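For illustration (not part of the commit), the default strategy simply joins message contents with the configured separator; a minimal sketch, assuming the imports used by the module above are available:

```python
import asyncio

# Sketch only: DefaultMemoryQueryGeneratorConfig, UserMessage and
# default_rag_query_generator are the names defined in this diff.
config = DefaultMemoryQueryGeneratorConfig(sep=" ")
messages = [
    UserMessage(content="How do I configure pgvector?"),
    UserMessage(content="Show me the config fields."),
]
query = asyncio.run(default_rag_query_generator(config, messages))
# query == "How do I configure pgvector? Show me the config fields."
```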
@@ -13,7 +13,7 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.agentic_system,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
             pip_packages=[
                 "codeshield",
                 "matplotlib",
@@ -52,7 +52,7 @@ class StackBuild(Subcommand):
             BuildType,
         )
 
-        allowed_ids = [d.distribution_id for d in available_distribution_specs()]
+        allowed_ids = [d.distribution_type for d in available_distribution_specs()]
         self.parser.add_argument(
             "distribution",
             type=str,
@@ -101,7 +101,7 @@ class StackBuild(Subcommand):
             api_inputs.append(
                 ApiInput(
                     api=api,
-                    provider=provider_spec.provider_id,
+                    provider=provider_spec.provider_type,
                 )
             )
             docker_image = None
@@ -115,11 +115,11 @@ class StackBuild(Subcommand):
             self.parser.error(f"Could not find distribution {args.distribution}")
             return
 
-        for api, provider_id in dist.providers.items():
+        for api, provider_type in dist.providers.items():
             api_inputs.append(
                 ApiInput(
                     api=api,
-                    provider=provider_id,
+                    provider=provider_type,
                 )
             )
             docker_image = dist.docker_image
@@ -128,6 +128,6 @@ class StackBuild(Subcommand):
             api_inputs,
             build_type=BuildType(args.type),
             name=args.name,
-            distribution_id=args.distribution,
+            distribution_type=args.distribution,
            docker_image=docker_image,
        )
@@ -36,7 +36,7 @@ class StackConfigure(Subcommand):
         )
         from llama_toolchain.core.package import BuildType
 
-        allowed_ids = [d.distribution_id for d in available_distribution_specs()]
+        allowed_ids = [d.distribution_type for d in available_distribution_specs()]
         self.parser.add_argument(
             "distribution",
             type=str,
@@ -84,7 +84,7 @@ def configure_llama_distribution(config_file: Path) -> None:
 
     if config.providers:
         cprint(
-            f"Configuration already exists for {config.distribution_id}. Will overwrite...",
+            f"Configuration already exists for {config.distribution_type}. Will overwrite...",
             "yellow",
             attrs=["bold"],
         )
@@ -33,7 +33,7 @@ class StackListDistributions(Subcommand):
 
         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
         headers = [
-            "Distribution ID",
+            "Distribution Type",
             "Providers",
             "Description",
         ]
@@ -43,7 +43,7 @@ class StackListDistributions(Subcommand):
             providers = {k.value: v for k, v in spec.providers.items()}
             rows.append(
                 [
-                    spec.distribution_id,
+                    spec.distribution_type,
                     json.dumps(providers, indent=2),
                     spec.description,
                 ]
@@ -41,7 +41,7 @@ class StackListProviders(Subcommand):
 
         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
         headers = [
-            "Provider ID",
+            "Provider Type",
             "PIP Package Dependencies",
         ]
 
@@ -49,7 +49,7 @@ class StackListProviders(Subcommand):
         for spec in providers_for_api.values():
             rows.append(
                 [
-                    spec.provider_id,
+                    spec.provider_type,
                     ",".join(spec.pip_packages),
                 ]
             )
@@ -80,7 +80,7 @@ class StackRun(Subcommand):
         with open(config_file, "r") as f:
             config = PackageConfig(**yaml.safe_load(f))
 
-        if not config.distribution_id:
+        if not config.distribution_type:
             raise ValueError("Build config appears to be corrupt.")
 
         if config.docker_image:
@@ -20,12 +20,12 @@ fi
 set -euo pipefail
 
 if [ "$#" -ne 3 ]; then
-  echo "Usage: $0 <distribution_id> <build_name> <pip_dependencies>" >&2
-  echo "Example: $0 <distribution_id> mybuild 'numpy pandas scipy'" >&2
+  echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies>" >&2
+  echo "Example: $0 <distribution_type> mybuild 'numpy pandas scipy'" >&2
   exit 1
 fi
 
-distribution_id="$1"
+distribution_type="$1"
 build_name="$2"
 env_name="llamastack-$build_name"
 pip_dependencies="$3"
@@ -117,4 +117,4 @@ ensure_conda_env_python310 "$env_name" "$pip_dependencies"
 
 printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}\n"
 
-$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type conda_env
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_type --name "$build_name" --type conda_env
@@ -5,12 +5,12 @@ LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
 
 if [ "$#" -ne 4 ]; then
-  echo "Usage: $0 <distribution_id> <build_name> <docker_base> <pip_dependencies>
-  echo "Example: $0 distribution_id my-fastapi-app python:3.9-slim 'fastapi uvicorn'
+  echo "Usage: $0 <distribution_type> <build_name> <docker_base> <pip_dependencies>
+  echo "Example: $0 distribution_type my-fastapi-app python:3.9-slim 'fastapi uvicorn'
   exit 1
 fi
 
-distribution_id=$1
+distribution_type=$1
 build_name="$2"
 image_name="llamastack-$build_name"
 docker_base=$3
@@ -110,4 +110,4 @@ set +x
 printf "${GREEN}Succesfully setup Podman image. Configuring build...${NC}"
 echo "You can run it with: podman run -p 8000:8000 $image_name"
 
-$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type container
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_type --name "$build_name" --type container
@@ -21,14 +21,14 @@ def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
     for api_str, stub_config in existing_configs.items():
         api = Api(api_str)
         providers = all_providers[api]
-        provider_id = stub_config["provider_id"]
-        if provider_id not in providers:
+        provider_type = stub_config["provider_type"]
+        if provider_type not in providers:
             raise ValueError(
-                f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
+                f"Unknown provider `{provider_type}` is not available for API `{api_str}`"
             )
 
-        provider_spec = providers[provider_id]
-        cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
+        provider_spec = providers[provider_type]
+        cprint(f"Configuring API: {api_str} ({provider_type})", "white", attrs=["bold"])
         config_type = instantiate_class_type(provider_spec.config_class)
 
         try:
@@ -43,7 +43,7 @@ def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
         print("")
 
         provider_configs[api_str] = {
-            "provider_id": provider_id,
+            "provider_type": provider_type,
             **provider_config.dict(),
         }
 
@@ -31,7 +31,7 @@ class ApiEndpoint(BaseModel):
 @json_schema_type
 class ProviderSpec(BaseModel):
     api: Api
-    provider_id: str
+    provider_type: str
     config_class: str = Field(
         ...,
         description="Fully-qualified classname of the config for this provider",
@@ -100,7 +100,7 @@ class RemoteProviderConfig(BaseModel):
         return url.rstrip("/")
 
 
-def remote_provider_id(adapter_id: str) -> str:
+def remote_provider_type(adapter_id: str) -> str:
     return f"remote::{adapter_id}"
 
 
@@ -141,22 +141,22 @@ def remote_provider_spec(
         if adapter and adapter.config_class
         else "llama_toolchain.core.datatypes.RemoteProviderConfig"
     )
-    provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
+    provider_type = remote_provider_type(adapter.adapter_id) if adapter else "remote"
 
     return RemoteProviderSpec(
-        api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
+        api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
     )
 
 
 @json_schema_type
 class DistributionSpec(BaseModel):
-    distribution_id: str
+    distribution_type: str
     description: str
 
     docker_image: Optional[str] = None
     providers: Dict[Api, str] = Field(
         default_factory=dict,
-        description="Provider IDs for each of the APIs provided by this distribution",
+        description="Provider Types for each of the APIs provided by this distribution",
     )
 
 
@@ -171,7 +171,7 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p
 this could be just a hash
     """,
     )
-    distribution_id: Optional[str] = None
+    distribution_type: Optional[str] = None
 
     docker_image: Optional[str] = Field(
         default=None,
@@ -83,18 +83,18 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
 
 def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
     inference_providers_by_id = {
-        a.provider_id: a for a in available_inference_providers()
+        a.provider_type: a for a in available_inference_providers()
     }
-    safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
+    safety_providers_by_id = {a.provider_type: a for a in available_safety_providers()}
     agentic_system_providers_by_id = {
-        a.provider_id: a for a in available_agentic_system_providers()
+        a.provider_type: a for a in available_agentic_system_providers()
     }
 
     ret = {
         Api.inference: inference_providers_by_id,
         Api.safety: safety_providers_by_id,
         Api.agentic_system: agentic_system_providers_by_id,
-        Api.memory: {a.provider_id: a for a in available_memory_providers()},
+        Api.memory: {a.provider_type: a for a in available_memory_providers()},
     }
     for k, v in ret.items():
         v["remote"] = remote_provider_spec(k)
@@ -14,7 +14,7 @@ from .datatypes import *  # noqa: F403
 def available_distribution_specs() -> List[DistributionSpec]:
     return [
         DistributionSpec(
-            distribution_id="local",
+            distribution_type="local",
             description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
             providers={
                 Api.inference: "meta-reference",
@@ -24,35 +24,35 @@ def available_distribution_specs() -> List[DistributionSpec]:
             },
         ),
         DistributionSpec(
-            distribution_id="remote",
+            distribution_type="remote",
             description="Point to remote services for all llama stack APIs",
             providers={x: "remote" for x in Api},
         ),
         DistributionSpec(
-            distribution_id="local-ollama",
+            distribution_type="local-ollama",
             description="Like local, but use ollama for running LLM inference",
             providers={
-                Api.inference: remote_provider_id("ollama"),
+                Api.inference: remote_provider_type("ollama"),
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
             },
         ),
         DistributionSpec(
-            distribution_id="local-plus-fireworks-inference",
+            distribution_type="local-plus-fireworks-inference",
             description="Use Fireworks.ai for running LLM inference",
             providers={
-                Api.inference: remote_provider_id("fireworks"),
+                Api.inference: remote_provider_type("fireworks"),
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
             },
         ),
         DistributionSpec(
-            distribution_id="local-plus-together-inference",
+            distribution_type="local-plus-together-inference",
             description="Use Together.ai for running LLM inference",
             providers={
-                Api.inference: remote_provider_id("together"),
+                Api.inference: remote_provider_type("together"),
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
@@ -72,8 +72,8 @@ def available_distribution_specs() -> List[DistributionSpec]:
 
 
 @lru_cache()
-def resolve_distribution_spec(distribution_id: str) -> Optional[DistributionSpec]:
+def resolve_distribution_spec(distribution_type: str) -> Optional[DistributionSpec]:
     for spec in available_distribution_specs():
-        if spec.distribution_id == distribution_id:
+        if spec.distribution_type == distribution_type:
             return spec
     return None
@@ -46,13 +46,13 @@ def build_package(
     api_inputs: List[ApiInput],
     build_type: BuildType,
     name: str,
-    distribution_id: Optional[str] = None,
+    distribution_type: Optional[str] = None,
     docker_image: Optional[str] = None,
 ):
-    if not distribution_id:
-        distribution_id = "adhoc"
+    if not distribution_type:
+        distribution_type = "adhoc"
 
-    build_dir = BUILDS_BASE_DIR / distribution_id / build_type.descriptor()
+    build_dir = BUILDS_BASE_DIR / distribution_type / build_type.descriptor()
     os.makedirs(build_dir, exist_ok=True)
 
     package_name = name.replace("::", "-")
@@ -79,7 +79,7 @@ def build_package(
         if provider.docker_image:
             raise ValueError("A stack's dependencies cannot have a docker image")
 
-        stub_config[api.value] = {"provider_id": api_input.provider}
+        stub_config[api.value] = {"provider_type": api_input.provider}
 
     if package_file.exists():
         cprint(
@@ -92,7 +92,7 @@ def build_package(
             c.providers[api_str] = new_config
         else:
             existing_config = c.providers[api_str]
-            if existing_config["provider_id"] != new_config["provider_id"]:
+            if existing_config["provider_type"] != new_config["provider_type"]:
                 cprint(
                     f"Provider `{api_str}` has changed from `{existing_config}` to `{new_config}`",
                     color="yellow",
@@ -105,7 +105,7 @@ def build_package(
         providers=stub_config,
     )
 
-    c.distribution_id = distribution_id
+    c.distribution_type = distribution_type
     c.docker_image = package_name if build_type == BuildType.container else None
     c.conda_env = package_name if build_type == BuildType.conda_env else None
 
@@ -119,7 +119,7 @@ def build_package(
     )
     args = [
         script,
-        distribution_id,
+        distribution_type,
         package_name,
         package_deps.docker_image,
         " ".join(package_deps.pip_packages),
@@ -130,7 +130,7 @@ def build_package(
     )
     args = [
         script,
-        distribution_id,
+        distribution_type,
         package_name,
         " ".join(package_deps.pip_packages),
     ]
@@ -284,13 +284,13 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
     for api_str, provider_config in config["providers"].items():
         api = Api(api_str)
         providers = all_providers[api]
-        provider_id = provider_config["provider_id"]
-        if provider_id not in providers:
+        provider_type = provider_config["provider_type"]
+        if provider_type not in providers:
             raise ValueError(
-                f"Unknown provider `{provider_id}` is not available for API `{api}`"
+                f"Unknown provider `{provider_type}` is not available for API `{api}`"
             )
 
-        provider_specs[api] = providers[provider_id]
+        provider_specs[api] = providers[provider_type]
 
     impls = resolve_impls(provider_specs, config)
 
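For illustration (not part of the commit), after this rename both the stub config written during `llama stack build` and the run config read by the server key each API's provider entry on `provider_type`; a sketch of the expected shape, with placeholder values:

```python
# Sketch only: surrounding keys of the package config are not shown in this
# diff, and the url is a made-up placeholder.
config = {
    "providers": {
        "inference": {"provider_type": "remote::tgi", "url": "http://localhost:8080"},
        "safety": {"provider_type": "meta-reference"},
        "agentic_system": {"provider_type": "meta-reference"},
        "memory": {"provider_type": "meta-reference-faiss"},
    }
}
```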
@@ -13,7 +13,7 @@ def available_inference_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.inference,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
             pip_packages=[
                 "accelerate",
                 "blobfile",
llama_toolchain/memory/adapters/chroma/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_toolchain.core.datatypes import RemoteProviderConfig


async def get_adapter_impl(config: RemoteProviderConfig, _deps):
    from .chroma import ChromaMemoryAdapter

    impl = ChromaMemoryAdapter(config.url)
    await impl.initialize()
    return impl
llama_toolchain/memory/adapters/chroma/chroma.py (new file, 165 lines)
@@ -0,0 +1,165 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import uuid
from typing import List
from urllib.parse import urlparse

import chromadb
from numpy.typing import NDArray

from llama_toolchain.memory.api import *  # noqa: F403


from llama_toolchain.memory.common.vector_store import BankWithIndex, EmbeddingIndex


class ChromaIndex(EmbeddingIndex):
    def __init__(self, client: chromadb.AsyncHttpClient, collection):
        self.client = client
        self.collection = collection

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        assert len(chunks) == len(
            embeddings
        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"

        for i, chunk in enumerate(chunks):
            print(f"Adding chunk #{i} tokens={chunk.token_count}")

        await self.collection.add(
            documents=[chunk.json() for chunk in chunks],
            embeddings=embeddings,
            ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
        )

    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        results = await self.collection.query(
            query_embeddings=[embedding.tolist()],
            n_results=k,
            include=["documents", "distances"],
        )
        distances = results["distances"][0]
        documents = results["documents"][0]

        chunks = []
        scores = []
        for dist, doc in zip(distances, documents):
            try:
                doc = json.loads(doc)
                chunk = Chunk(**doc)
            except Exception:
                import traceback

                traceback.print_exc()
                print(f"Failed to parse document: {doc}")
                continue

            chunks.append(chunk)
            scores.append(1.0 / float(dist))

        return QueryDocumentsResponse(chunks=chunks, scores=scores)


class ChromaMemoryAdapter(Memory):
    def __init__(self, url: str) -> None:
        print(f"Initializing ChromaMemoryAdapter with url: {url}")
        url = url.rstrip("/")
        parsed = urlparse(url)

        if parsed.path and parsed.path != "/":
            raise ValueError("URL should not contain a path")

        self.host = parsed.hostname
        self.port = parsed.port

        self.client = None
        self.cache = {}

    async def initialize(self) -> None:
        try:
            print(f"Connecting to Chroma server at: {self.host}:{self.port}")
            self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
        except Exception as e:
            import traceback

            traceback.print_exc()
            raise RuntimeError("Could not connect to Chroma server") from e

    async def shutdown(self) -> None:
        pass

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        bank_id = str(uuid.uuid4())
        bank = MemoryBank(
            bank_id=bank_id,
            name=name,
            config=config,
            url=url,
        )
        collection = await self.client.create_collection(
            name=bank_id,
            metadata={"bank": bank.json()},
        )
        bank_index = BankWithIndex(
            bank=bank, index=ChromaIndex(self.client, collection)
        )
        self.cache[bank_id] = bank_index
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        bank_index = await self._get_and_cache_bank_index(bank_id)
        if bank_index is None:
            return None
        return bank_index.bank

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        if bank_id in self.cache:
            return self.cache[bank_id]

        collections = await self.client.list_collections()
        for collection in collections:
            if collection.name == bank_id:
                print(collection.metadata)
                bank = MemoryBank(**json.loads(collection.metadata["bank"]))
                index = BankWithIndex(
                    bank=bank,
                    index=ChromaIndex(self.client, collection),
                )
                self.cache[bank_id] = index
                return index

        return None

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        await index.insert_documents(documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        return await index.query_documents(query, params)
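As a usage sketch (not part of the commit), the adapter is exercised through the Memory API roughly as follows; the server URL is a placeholder and `bank_config` stands in for whichever `MemoryBankConfig` variant `llama_toolchain.memory.api` defines.

```python
# Sketch only: create a bank, index documents, then retrieve chunks for a query.
from llama_toolchain.memory.adapters.chroma.chroma import ChromaMemoryAdapter


async def demo(bank_config, documents, question):
    adapter = ChromaMemoryAdapter("http://localhost:8000")  # placeholder URL
    await adapter.initialize()
    bank = await adapter.create_memory_bank(name="docs", config=bank_config)
    await adapter.insert_documents(bank_id=bank.bank_id, documents=documents)
    return await adapter.query_documents(bank_id=bank.bank_id, query=question)
```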
llama_toolchain/memory/adapters/pgvector/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import PGVectorConfig


async def get_adapter_impl(config: PGVectorConfig, _deps):
    from .pgvector import PGVectorMemoryAdapter

    impl = PGVectorMemoryAdapter(config)
    await impl.initialize()
    return impl
llama_toolchain/memory/adapters/pgvector/config.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class PGVectorConfig(BaseModel):
    host: str = Field(default="localhost")
    port: int = Field(default=5432)
    db: str
    user: str
    password: str
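For illustration (not part of the commit), configuring this adapter only requires the database coordinates; the credentials below are placeholders.

```python
# Sketch only: placeholder credentials for a local Postgres instance with the
# pgvector extension installed; host/port default to localhost:5432 if omitted.
config = PGVectorConfig(db="llamastack", user="llama", password="change-me")
```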
llama_toolchain/memory/adapters/pgvector/pgvector.py (new file, 234 lines)
@@ -0,0 +1,234 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid

from typing import List, Tuple

import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import execute_values, Json
from pydantic import BaseModel
from llama_toolchain.memory.api import *  # noqa: F403


from llama_toolchain.memory.common.vector_store import (
    ALL_MINILM_L6_V2_DIMENSION,
    BankWithIndex,
    EmbeddingIndex,
)

from .config import PGVectorConfig


def check_extension_version(cur):
    cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
    result = cur.fetchone()
    return result[0] if result else None


def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
    query = sql.SQL(
        """
        INSERT INTO metadata_store (key, data)
        VALUES %s
        ON CONFLICT (key) DO UPDATE
        SET data = EXCLUDED.data
    """
    )

    values = [(key, Json(model.dict())) for key, model in keys_models]
    execute_values(cur, query, values, template="(%s, %s)")


def load_models(cur, keys: List[str], cls):
    query = "SELECT key, data FROM metadata_store"
    if keys:
        placeholders = ",".join(["%s"] * len(keys))
        query += f" WHERE key IN ({placeholders})"
        cur.execute(query, keys)
    else:
        cur.execute(query)

    rows = cur.fetchall()
    return [cls(**row["data"]) for row in rows]


class PGVectorIndex(EmbeddingIndex):
    def __init__(self, bank: MemoryBank, dimension: int, cursor):
        self.cursor = cursor
        self.table_name = f"vector_store_{bank.name}"

        self.cursor.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id TEXT PRIMARY KEY,
                document JSONB,
                embedding vector({dimension})
            )
        """
        )

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        assert len(chunks) == len(
            embeddings
        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"

        values = []
        for i, chunk in enumerate(chunks):
            print(f"Adding chunk #{i} tokens={chunk.token_count}")
            values.append(
                (
                    f"{chunk.document_id}:chunk-{i}",
                    Json(chunk.dict()),
                    embeddings[i].tolist(),
                )
            )

        query = sql.SQL(
            f"""
        INSERT INTO {self.table_name} (id, document, embedding)
        VALUES %s
        ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
    """
        )
        execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")

    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        self.cursor.execute(
            f"""
        SELECT document, embedding <-> %s::vector AS distance
        FROM {self.table_name}
        ORDER BY distance
        LIMIT %s
    """,
            (embedding.tolist(), k),
        )
        results = self.cursor.fetchall()

        chunks = []
        scores = []
        for doc, dist in results:
            chunks.append(Chunk(**doc))
            scores.append(1.0 / float(dist))

        return QueryDocumentsResponse(chunks=chunks, scores=scores)


class PGVectorMemoryAdapter(Memory):
    def __init__(self, config: PGVectorConfig) -> None:
        print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
        self.config = config
        self.cursor = None
        self.conn = None
        self.cache = {}

    async def initialize(self) -> None:
        try:
            self.conn = psycopg2.connect(
                host=self.config.host,
                port=self.config.port,
                database=self.config.db,
                user=self.config.user,
                password=self.config.password,
            )
            self.cursor = self.conn.cursor()

            version = check_extension_version(self.cursor)
            if version:
                print(f"Vector extension version: {version}")
            else:
                raise RuntimeError("Vector extension is not installed.")

            self.cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS metadata_store (
                    key TEXT PRIMARY KEY,
                    data JSONB
                )
            """
            )
        except Exception as e:
            import traceback

            traceback.print_exc()
            raise RuntimeError("Could not connect to PGVector database server") from e

    async def shutdown(self) -> None:
        pass

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        bank_id = str(uuid.uuid4())
        bank = MemoryBank(
            bank_id=bank_id,
            name=name,
            config=config,
            url=url,
        )
        upsert_models(
            self.cursor,
            [
                (bank.bank_id, bank),
            ],
        )
        index = BankWithIndex(
            bank=bank,
            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
        )
        self.cache[bank_id] = index
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        bank_index = await self._get_and_cache_bank_index(bank_id)
        if bank_index is None:
            return None
        return bank_index.bank

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        if bank_id in self.cache:
            return self.cache[bank_id]

        banks = load_models(self.cursor, [bank_id], MemoryBank)
        if not banks:
            return None

        bank = banks[0]
        index = BankWithIndex(
            bank=bank,
            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
        )
        self.cache[bank_id] = index
        return index

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        await index.insert_documents(documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        return await index.query_documents(query, params)
llama_toolchain/memory/common/vector_store.py (new file, 120 lines)
@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import httpx
import numpy as np
from numpy.typing import NDArray

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_toolchain.memory.api import *  # noqa: F403


ALL_MINILM_L6_V2_DIMENSION = 384

EMBEDDING_MODEL = None


def get_embedding_model() -> "SentenceTransformer":
    global EMBEDDING_MODEL

    if EMBEDDING_MODEL is None:
        print("Loading sentence transformer")

        from sentence_transformers import SentenceTransformer

        EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")

    return EMBEDDING_MODEL


async def content_from_doc(doc: MemoryBankDocument) -> str:
    if isinstance(doc.content, URL):
        async with httpx.AsyncClient() as client:
            r = await client.get(doc.content.uri)
            return r.text

    return interleaved_text_media_as_str(doc.content)


def make_overlapped_chunks(
    document_id: str, text: str, window_len: int, overlap_len: int
) -> List[Chunk]:
    tokenizer = Tokenizer.get_instance()
    tokens = tokenizer.encode(text, bos=False, eos=False)

    chunks = []
    for i in range(0, len(tokens), window_len - overlap_len):
        toks = tokens[i : i + window_len]
        chunk = tokenizer.decode(toks)
        chunks.append(
            Chunk(content=chunk, token_count=len(toks), document_id=document_id)
        )

    return chunks


class EmbeddingIndex(ABC):
    @abstractmethod
    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        raise NotImplementedError()

    @abstractmethod
    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        raise NotImplementedError()


@dataclass
class BankWithIndex:
    bank: MemoryBank
    index: EmbeddingIndex

    async def insert_documents(
        self,
        documents: List[MemoryBankDocument],
    ) -> None:
        model = get_embedding_model()
        for doc in documents:
            content = await content_from_doc(doc)
            chunks = make_overlapped_chunks(
                doc.document_id,
                content,
                self.bank.config.chunk_size_in_tokens,
                self.bank.config.overlap_size_in_tokens
                or (self.bank.config.chunk_size_in_tokens // 4),
            )
            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)

            await self.index.add_chunks(chunks, embeddings)

    async def query_documents(
        self,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        if params is None:
            params = {}
        k = params.get("max_chunks", 3)

        def _process(c) -> str:
            if isinstance(c, str):
                return c
            else:
                return "<media>"

        if isinstance(query, list):
            query_str = " ".join([_process(c) for c in query])
        else:
            query_str = _process(query)

        model = get_embedding_model()
        query_vector = model.encode([query_str])[0].astype(np.float32)
        return await self.index.query(query_vector, k)
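For illustration (not part of the commit), `make_overlapped_chunks` advances by `window_len - overlap_len` tokens per step, so consecutive chunks share `overlap_len` tokens; a self-contained sketch of the same stride logic over plain integers:

```python
# Sketch only: reproduces the chunking stride of make_overlapped_chunks on a
# plain list instead of tokenizer output.
def overlapped_windows(tokens, window_len=8, overlap_len=2):
    step = window_len - overlap_len
    return [tokens[i : i + window_len] for i in range(0, len(tokens), step)]


windows = overlapped_windows(list(range(20)))
# windows[0] == [0, 1, 2, 3, 4, 5, 6, 7]
# windows[1] == [6, 7, 8, 9, 10, 11, 12, 13]  # shares the last 2 tokens of windows[0]
```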
@@ -5,108 +5,45 @@
 # the root directory of this source tree.
 
 import uuid
-from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
 import faiss
-import httpx
 import numpy as np
+from numpy.typing import NDArray
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_toolchain.memory.api import *  # noqa: F403
+from llama_toolchain.memory.common.vector_store import (
+    ALL_MINILM_L6_V2_DIMENSION,
+    BankWithIndex,
+    EmbeddingIndex,
+)
 from .config import FaissImplConfig
 
 
-async def content_from_doc(doc: MemoryBankDocument) -> str:
-    if isinstance(doc.content, URL):
-        async with httpx.AsyncClient() as client:
-            r = await client.get(doc.content.uri)
-            return r.text
-
-    return interleaved_text_media_as_str(doc.content)
+class FaissIndex(EmbeddingIndex):
+    id_by_index: Dict[int, str]
+    chunk_by_index: Dict[int, str]
+
+    def __init__(self, dimension: int):
+        self.index = faiss.IndexFlatL2(dimension)
+        self.id_by_index = {}
+        self.chunk_by_index = {}
+
+    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        indexlen = len(self.id_by_index)
+        for i, chunk in enumerate(chunks):
+            self.chunk_by_index[indexlen + i] = chunk
+            print(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
+            self.id_by_index[indexlen + i] = chunk.document_id
+
+        self.index.add(np.array(embeddings).astype(np.float32))
+
+    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
 
-
-def make_overlapped_chunks(
-    text: str, window_len: int, overlap_len: int
-) -> List[Tuple[str, int]]:
-    tokenizer = Tokenizer.get_instance()
-    tokens = tokenizer.encode(text, bos=False, eos=False)
-
-    chunks = []
-    for i in range(0, len(tokens), window_len - overlap_len):
-        toks = tokens[i : i + window_len]
-        chunk = tokenizer.decode(toks)
-        chunks.append((chunk, len(toks)))
-
-    return chunks
-
-
-@dataclass
-class BankState:
-    bank: MemoryBank
-    index: Optional[faiss.IndexFlatL2] = None
-    doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
-    id_by_index: Dict[int, str] = field(default_factory=dict)
-    chunk_by_index: Dict[int, str] = field(default_factory=dict)
-
-    async def insert_documents(
-        self,
-        model: "SentenceTransformer",
-        documents: List[MemoryBankDocument],
-    ) -> None:
-        tokenizer = Tokenizer.get_instance()
-        chunk_size = self.bank.config.chunk_size_in_tokens
-
-        for doc in documents:
-            indexlen = len(self.id_by_index)
-            self.doc_by_id[doc.document_id] = doc
-
-            content = await content_from_doc(doc)
-            chunks = make_overlapped_chunks(
-                content,
-                self.bank.config.chunk_size_in_tokens,
-                self.bank.config.overlap_size_in_tokens
-                or (self.bank.config.chunk_size_in_tokens // 4),
-            )
-            embeddings = model.encode([x[0] for x in chunks]).astype(np.float32)
-            await self._ensure_index(embeddings.shape[1])
-
-            self.index.add(embeddings)
-            for i, chunk in enumerate(chunks):
-                self.chunk_by_index[indexlen + i] = Chunk(
-                    content=chunk[0],
-                    token_count=chunk[1],
-                    document_id=doc.document_id,
|
|
||||||
)
|
|
||||||
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
|
|
||||||
self.id_by_index[indexlen + i] = doc.document_id
|
|
||||||
|
|
||||||
async def query_documents(
|
|
||||||
self,
|
|
||||||
model: "SentenceTransformer",
|
|
||||||
query: InterleavedTextMedia,
|
|
||||||
params: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> QueryDocumentsResponse:
|
|
||||||
if params is None:
|
|
||||||
params = {}
|
|
||||||
k = params.get("max_chunks", 3)
|
|
||||||
|
|
||||||
def _process(c) -> str:
|
|
||||||
if isinstance(c, str):
|
|
||||||
return c
|
|
||||||
else:
|
|
||||||
return "<media>"
|
|
||||||
|
|
||||||
if isinstance(query, list):
|
|
||||||
query_str = " ".join([_process(c) for c in query])
|
|
||||||
else:
|
|
||||||
query_str = _process(query)
|
|
||||||
|
|
||||||
query_vector = model.encode([query_str])[0]
|
|
||||||
distances, indices = self.index.search(
|
distances, indices = self.index.search(
|
||||||
query_vector.reshape(1, -1).astype(np.float32), k
|
embedding.reshape(1, -1).astype(np.float32), k
|
||||||
)
|
)
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
|
@ -119,17 +56,11 @@ class BankState:
|
||||||
|
|
||||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
|
|
||||||
if self.index is None:
|
|
||||||
self.index = faiss.IndexFlatL2(dimension)
|
|
||||||
return self.index
|
|
||||||
|
|
||||||
|
|
||||||
class FaissMemoryImpl(Memory):
|
class FaissMemoryImpl(Memory):
|
||||||
def __init__(self, config: FaissImplConfig) -> None:
|
def __init__(self, config: FaissImplConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = None
|
self.cache = {}
|
||||||
self.states = {}
|
|
||||||
|
|
||||||
async def initialize(self) -> None: ...
|
async def initialize(self) -> None: ...
|
||||||
|
|
||||||
|
@ -153,14 +84,15 @@ class FaissMemoryImpl(Memory):
|
||||||
config=config,
|
config=config,
|
||||||
url=url,
|
url=url,
|
||||||
)
|
)
|
||||||
state = BankState(bank=bank)
|
index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
|
||||||
self.states[bank_id] = state
|
self.cache[bank_id] = index
|
||||||
return bank
|
return bank
|
||||||
|
|
||||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||||
if bank_id not in self.states:
|
index = self.cache.get(bank_id)
|
||||||
|
if index is None:
|
||||||
return None
|
return None
|
||||||
return self.states[bank_id].bank
|
return index.bank
|
||||||
|
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -168,10 +100,11 @@ class FaissMemoryImpl(Memory):
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
ttl_seconds: Optional[int] = None,
|
ttl_seconds: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert bank_id in self.states, f"Bank {bank_id} not found"
|
index = self.cache.get(bank_id)
|
||||||
state = self.states[bank_id]
|
if index is None:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
await state.insert_documents(self.get_model(), documents)
|
await index.insert_documents(documents)
|
||||||
|
|
||||||
async def query_documents(
|
async def query_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -179,16 +112,8 @@ class FaissMemoryImpl(Memory):
|
||||||
query: InterleavedTextMedia,
|
query: InterleavedTextMedia,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryDocumentsResponse:
|
) -> QueryDocumentsResponse:
|
||||||
assert bank_id in self.states, f"Bank {bank_id} not found"
|
index = self.cache.get(bank_id)
|
||||||
state = self.states[bank_id]
|
if index is None:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
return await state.query_documents(self.get_model(), query, params)
|
return await index.query_documents(query, params)
|
||||||
|
|
||||||
def get_model(self) -> "SentenceTransformer":
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
|
|
||||||
if self.model is None:
|
|
||||||
print("Loading sentence transformer")
|
|
||||||
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
|
||||||
|
|
||||||
return self.model
|
|
||||||
|
|
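The `FaissIndex` introduced above keeps a flat L2 index plus `id_by_index`/`chunk_by_index` dicts keyed by row position, because a faiss search returns integer row ids rather than chunk or document ids. A rough, self-contained sketch of the underlying faiss calls it relies on (random vectors, purely illustrative; 384 is assumed to be the `ALL_MINILM_L6_V2_DIMENSION` value):

```python
import faiss
import numpy as np

dimension = 384  # assumed all-MiniLM-L6-v2 embedding size
index = faiss.IndexFlatL2(dimension)

# Rows are appended in insertion order, so row i corresponds to the i-th chunk added;
# that is what id_by_index / chunk_by_index map back to.
vectors = np.random.rand(10, dimension).astype(np.float32)
index.add(vectors)

query = np.random.rand(1, dimension).astype(np.float32)
distances, indices = index.search(query, 3)  # 3 nearest rows by L2 distance
print(indices[0], distances[0])
```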
```diff
@@ -6,20 +6,38 @@
 
 from typing import List
 
-from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_toolchain.core.datatypes import *  # noqa: F403
+
+EMBEDDING_DEPS = [
+    "blobfile",
+    "sentence-transformers",
+]
 
 
 def available_memory_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.memory,
-            provider_id="meta-reference-faiss",
-            pip_packages=[
-                "blobfile",
-                "faiss-cpu",
-                "sentence-transformers",
-            ],
+            provider_type="meta-reference-faiss",
+            pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_toolchain.memory.meta_reference.faiss",
             config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
         ),
+        remote_provider_spec(
+            api=Api.memory,
+            adapter=AdapterSpec(
+                adapter_id="chromadb",
+                pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
+                module="llama_toolchain.memory.adapters.chroma",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.memory,
+            adapter=AdapterSpec(
+                adapter_id="pgvector",
+                pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
+                module="llama_toolchain.memory.adapters.pgvector",
+                config_class="llama_toolchain.memory.adapters.pgvector.PGVectorConfig",
+            ),
+        ),
     ]
```
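For illustration, a caller could enumerate this registry and print which packages each memory provider would install. The import path and the attribute names on the returned specs are assumptions inferred from the constructor keywords visible in this hunk, not confirmed by the diff:

```python
# Hypothetical sketch: assumes the providers module is importable at this path and
# that the spec objects expose the same fields passed to their constructors above.
from llama_toolchain.memory.providers import available_memory_providers

for spec in available_memory_providers():
    adapter = getattr(spec, "adapter", None)
    name = adapter.adapter_id if adapter is not None else spec.provider_type
    deps = getattr(adapter, "pip_packages", None) or getattr(spec, "pip_packages", [])
    print(f"{name}: {deps}")
```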
```diff
@@ -13,7 +13,7 @@ def available_safety_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.safety,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
             pip_packages=[
                 "accelerate",
                 "codeshield",
```
```diff
@@ -11,7 +11,7 @@ from llama_toolchain.evaluations.api import *  # noqa: F403
 from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.batch_inference.api import *  # noqa: F403
 from llama_toolchain.memory.api import *  # noqa: F403
-from llama_toolchain.observability.api import *  # noqa: F403
+from llama_toolchain.telemetry.api import *  # noqa: F403
 from llama_toolchain.post_training.api import *  # noqa: F403
 from llama_toolchain.reward_scoring.api import *  # noqa: F403
 from llama_toolchain.synthetic_data_generation.api import *  # noqa: F403
@@ -24,7 +24,7 @@ class LlamaStack(
     RewardScoring,
     SyntheticDataGeneration,
     Datasets,
-    Observability,
+    Telemetry,
     PostTraining,
     Memory,
     Evaluations,
```
```diff
@@ -134,7 +134,7 @@ class LogSearchRequest(BaseModel):
     filters: Optional[Dict[str, Any]] = None
 
 
-class Observability(Protocol):
+class Telemetry(Protocol):
     @webmethod(route="/experiments/create")
     def create_experiment(self, request: CreateExperimentRequest) -> Experiment: ...
```
```diff
@@ -2,7 +2,7 @@ blobfile
 fire
 httpx
 huggingface-hub
-llama-models
+llama-models>=0.0.13
 pydantic
 requests
 termcolor
```
File diff suppressed because it is too large (2 files)
```diff
@@ -35,7 +35,10 @@ from llama_toolchain.stack import LlamaStack
 
 
 # TODO: this should be fixed in the generator itself so it reads appropriate annotations
-STREAMING_ENDPOINTS = ["/agentic_system/turn/create"]
+STREAMING_ENDPOINTS = [
+    "/agentic_system/turn/create",
+    "/inference/chat_completion",
+]
 
 
 def patch_sse_stream_responses(spec: Specification):
```
```diff
@@ -468,12 +468,14 @@ class Generator:
         builder = ContentBuilder(self.schema_builder)
         first = next(iter(op.request_params))
         request_name, request_type = first
 
+        from dataclasses import make_dataclass
+
         if len(op.request_params) == 1 and "Request" in first[1].__name__:
             # TODO(ashwin): Undo the "Request" hack and this entire block eventually
-            request_name, request_type = first
+            request_name = first[1].__name__ + "Wrapper"
+            request_type = make_dataclass(request_name, op.request_params)
         else:
-            from dataclasses import make_dataclass
-
             op_name = "".join(word.capitalize() for word in op.name.split("_"))
             request_name = f"{op_name}Request"
             request_type = make_dataclass(request_name, op.request_params)
```
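The wrapper branch above synthesizes the request type itself with `dataclasses.make_dataclass`, fed by the operation's `(name, type)` parameter pairs. A standalone illustration of that stdlib call (the field names below are invented for the example, not taken from the generator):

```python
from dataclasses import fields, make_dataclass

# make_dataclass takes a class name plus (field_name, field_type) pairs,
# matching the shape of op.request_params in the generator.
request_params = [("bank_id", str), ("ttl_seconds", int)]
RequestWrapper = make_dataclass("InsertDocumentsRequestWrapper", request_params)

print([f.name for f in fields(RequestWrapper)])  # ['bank_id', 'ttl_seconds']
print(RequestWrapper(bank_id="demo", ttl_seconds=60))
```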
```diff
@@ -28,4 +28,4 @@ if [ ${#missing_packages[@]} -ne 0 ]; then
     exit 1
 fi
 
-PYTHONPATH=$PYTHONPATH:../.. python3 -m rfcs.openapi_generator.generate $*
+PYTHONPATH=$PYTHONPATH:../.. python -m rfcs.openapi_generator.generate $*
```