Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

Commit 04f0b8fe11: Merge branch 'main' into tgi-integration
38 changed files with 2157 additions and 548 deletions
@@ -248,51 +248,51 @@ llama stack list-distributions
 ```
 <pre style="font-family: monospace;">
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| Distribution ID                | Providers                             | Description                                                                                 |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local                          | {                                     | Use code from `llama_toolchain` itself to serve all llama stack APIs                        |
-|                                |   "inference": "meta-reference",      |                                                                                             |
-|                                |   "memory": "meta-reference-faiss",   |                                                                                             |
-|                                |   "safety": "meta-reference",         |                                                                                             |
-|                                |   "agentic_system": "meta-reference"  |                                                                                             |
-|                                | }                                     |                                                                                             |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| remote                         | {                                     | Point to remote services for all llama stack APIs                                          |
-|                                |   "inference": "remote",              |                                                                                             |
-|                                |   "safety": "remote",                 |                                                                                             |
-|                                |   "agentic_system": "remote",         |                                                                                             |
-|                                |   "memory": "remote"                  |                                                                                             |
-|                                | }                                     |                                                                                             |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local-ollama                   | {                                     | Like local, but use ollama for running LLM inference                                       |
-|                                |   "inference": "remote::ollama",      |                                                                                             |
-|                                |   "safety": "meta-reference",         |                                                                                             |
-|                                |   "agentic_system": "meta-reference", |                                                                                             |
-|                                |   "memory": "meta-reference-faiss"    |                                                                                             |
-|                                | }                                     |                                                                                             |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local-plus-fireworks-inference | {                                     | Use Fireworks.ai for running LLM inference                                                 |
-|                                |   "inference": "remote::fireworks",   |                                                                                             |
-|                                |   "safety": "meta-reference",         |                                                                                             |
-|                                |   "agentic_system": "meta-reference", |                                                                                             |
-|                                |   "memory": "meta-reference-faiss"    |                                                                                             |
-|                                | }                                     |                                                                                             |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local-plus-together-inference  | {                                     | Use Together.ai for running LLM inference                                                  |
-|                                |   "inference": "remote::together",    |                                                                                             |
-|                                |   "safety": "meta-reference",         |                                                                                             |
-|                                |   "agentic_system": "meta-reference", |                                                                                             |
-|                                |   "memory": "meta-reference-faiss"    |                                                                                             |
-|                                | }                                     |                                                                                             |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
-| local-plus-tgi-inference       | {                                     | Use TGI (local or with [Hugging Face Inference Endpoints](https://huggingface.co/          |
-|                                |   "inference": "remote::tgi",         | inference-endpoints/dedicated)) for running LLM inference. When using HF Inference         |
-|                                |   "safety": "meta-reference",         | Endpoints, you must provide the name of the endpoint.                                      |
-|                                |   "agentic_system": "meta-reference", |                                                                                             |
-|                                |   "memory": "meta-reference-faiss"    |                                                                                             |
-|                                | }                                     |                                                                                             |
-+--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| Distribution Type              | Providers                             | Description                                                          |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local                          | {                                     | Use code from `llama_toolchain` itself to serve all llama stack APIs |
+|                                |   "inference": "meta-reference",      |                                                                      |
+|                                |   "memory": "meta-reference-faiss",   |                                                                      |
+|                                |   "safety": "meta-reference",         |                                                                      |
+|                                |   "agentic_system": "meta-reference"  |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| remote                         | {                                     | Point to remote services for all llama stack APIs                    |
+|                                |   "inference": "remote",              |                                                                      |
+|                                |   "safety": "remote",                 |                                                                      |
+|                                |   "agentic_system": "remote",         |                                                                      |
+|                                |   "memory": "remote"                  |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-ollama                   | {                                     | Like local, but use ollama for running LLM inference                 |
+|                                |   "inference": "remote::ollama",      |                                                                      |
+|                                |   "safety": "meta-reference",         |                                                                      |
+|                                |   "agentic_system": "meta-reference", |                                                                      |
+|                                |   "memory": "meta-reference-faiss"    |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-fireworks-inference | {                                     | Use Fireworks.ai for running LLM inference                           |
+|                                |   "inference": "remote::fireworks",   |                                                                      |
+|                                |   "safety": "meta-reference",         |                                                                      |
+|                                |   "agentic_system": "meta-reference", |                                                                      |
+|                                |   "memory": "meta-reference-faiss"    |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-together-inference  | {                                     | Use Together.ai for running LLM inference                            |
+|                                |   "inference": "remote::together",    |                                                                      |
+|                                |   "safety": "meta-reference",         |                                                                      |
+|                                |   "agentic_system": "meta-reference", |                                                                      |
+|                                |   "memory": "meta-reference-faiss"    |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-tgi-inference       | {                                     | Use TGI (local or with [Hugging Face Inference Endpoints](https://   |
+|                                |   "inference": "remote::tgi",         | huggingface.co/inference-endpoints/dedicated)) for running LLM       |
+|                                |   "safety": "meta-reference",         | inference. When using HF Inference Endpoints, you must provide the   |
+|                                |   "agentic_system": "meta-reference", | name of the endpoint.                                                |
+|                                |   "memory": "meta-reference-faiss"    |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 </pre>
 
As you can see above, each “distribution” details the “providers” it is composed of. For example, `local` uses the “meta-reference” provider for inference, while `local-ollama` relies on a different provider (Ollama) for inference. Similarly, you can use Fireworks.ai or Together.ai for running inference as well.
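For orientation, each row above corresponds to a `DistributionSpec` entry in the registry shown later in this diff. A minimal sketch (using the post-rename `distribution_type` field):

```python
# Sketch of the "local-ollama" distribution from the registry in this diff:
# Ollama serves inference remotely, everything else runs in-process.
DistributionSpec(
    distribution_type="local-ollama",
    description="Like local, but use ollama for running LLM inference",
    providers={
        Api.inference: remote_provider_type("ollama"),  # resolves to "remote::ollama"
        Api.safety: "meta-reference",
        Api.agentic_system: "meta-reference",
        Api.memory: "meta-reference-faiss",
    },
)
```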
@@ -116,10 +116,47 @@ MemoryBankConfig = Annotated[
 ]
 
 
+@json_schema_type
+class MemoryQueryGenerator(Enum):
+    default = "default"
+    llm = "llm"
+    custom = "custom"
+
+
+class DefaultMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.default.value] = (
+        MemoryQueryGenerator.default.value
+    )
+    sep: str = " "
+
+
+class LLMMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
+    model: str
+    template: str
+
+
+class CustomMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
+
+
+MemoryQueryGeneratorConfig = Annotated[
+    Union[
+        DefaultMemoryQueryGeneratorConfig,
+        LLMMemoryQueryGeneratorConfig,
+        CustomMemoryQueryGeneratorConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
 class MemoryToolDefinition(ToolDefinitionCommon):
     type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
     memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+    # This config defines how a query is generated using the messages
+    # for memory bank retrieval.
+    query_generator_config: MemoryQueryGeneratorConfig = Field(
+        default=DefaultMemoryQueryGeneratorConfig()
+    )
     max_tokens_in_context: int = 4096
     max_chunks: int = 10
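To make the new fields concrete, here is a hedged sketch of configuring the memory tool with the LLM-based query generator; the model name and Jinja template are illustrative placeholders, not values from this diff:

```python
# Hypothetical configuration: have an LLM rewrite the conversation into a
# retrieval query instead of using the default separator-join behavior.
memory_tool = MemoryToolDefinition(
    memory_bank_configs=[],  # banks get attached elsewhere
    query_generator_config=LLMMemoryQueryGeneratorConfig(
        model="Meta-Llama3.1-8B-Instruct",  # placeholder model identifier
        template="Summarize this conversation into a search query: {{ messages }}",
    ),
    max_tokens_in_context=4096,
    max_chunks=5,
)
```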
@@ -31,6 +31,7 @@ from llama_toolchain.tools.builtin import (
     SingleMessageBuiltinTool,
 )
 
+from .rag.context_retriever import generate_rag_query
 from .safety import SafetyException, ShieldRunnerMixin
@@ -664,7 +665,9 @@ class ChatAgent(ShieldRunnerMixin):
             # (i.e., no prior turns uploaded an Attachment)
             return None, []
 
-        query = " ".join(m.content for m in messages)
+        query = await generate_rag_query(
+            memory.query_generator_config, messages, inference_api=self.inference_api
+        )
         tasks = [
             self.memory_api.query_documents(
                 bank_id=bank_id,
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from jinja2 import Template
+from llama_models.llama3.api import *  # noqa: F403
+
+
+from llama_toolchain.agentic_system.api import (
+    DefaultMemoryQueryGeneratorConfig,
+    LLMMemoryQueryGeneratorConfig,
+    MemoryQueryGenerator,
+    MemoryQueryGeneratorConfig,
+)
+from termcolor import cprint  # noqa: F401
+from llama_toolchain.inference.api import *  # noqa: F403
+
+
+async def generate_rag_query(
+    config: MemoryQueryGeneratorConfig,
+    messages: List[Message],
+    **kwargs,
+):
+    """
+    Generates a query that will be used for
+    retrieving relevant information from the memory bank.
+    """
+    if config.type == MemoryQueryGenerator.default.value:
+        query = await default_rag_query_generator(config, messages, **kwargs)
+    elif config.type == MemoryQueryGenerator.llm.value:
+        query = await llm_rag_query_generator(config, messages, **kwargs)
+    else:
+        raise NotImplementedError(f"Unsupported memory query generator {config.type}")
+    # cprint(f"Generated query >>>: {query}", color="green")
+    return query
+
+
+async def default_rag_query_generator(
+    config: DefaultMemoryQueryGeneratorConfig,
+    messages: List[Message],
+    **kwargs,
+):
+    return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)
+
+
+async def llm_rag_query_generator(
+    config: LLMMemoryQueryGeneratorConfig,
+    messages: List[Message],
+    **kwargs,
+):
+    assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
+    inference_api = kwargs["inference_api"]
+
+    m_dict = {"messages": [m.model_dump() for m in messages]}
+
+    template = Template(config.template)
+    content = template.render(m_dict)
+
+    model = config.model
+    message = UserMessage(content=content)
+    response = inference_api.chat_completion(
+        ChatCompletionRequest(
+            model=model,
+            messages=[message],
+            stream=False,
+        )
+    )
+
+    async for chunk in response:
+        query = chunk.completion_message.content
+
+    return query
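A hedged usage sketch (an enclosing async context plus `messages` and `inference_api` are assumed); with the default config this reduces to joining message contents with `config.sep`:

```python
# With DefaultMemoryQueryGeneratorConfig, generate_rag_query simply joins the
# message contents; inference_api is only consulted by the LLM-based generator.
query = await generate_rag_query(
    DefaultMemoryQueryGeneratorConfig(sep=" "),
    messages,
    inference_api=inference_api,
)
```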
@@ -13,7 +13,7 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.agentic_system,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
             pip_packages=[
                 "codeshield",
                 "matplotlib",
@@ -52,7 +52,7 @@ class StackBuild(Subcommand):
             BuildType,
         )
 
-        allowed_ids = [d.distribution_id for d in available_distribution_specs()]
+        allowed_ids = [d.distribution_type for d in available_distribution_specs()]
         self.parser.add_argument(
             "distribution",
             type=str,
@@ -101,7 +101,7 @@ class StackBuild(Subcommand):
                 api_inputs.append(
                     ApiInput(
                         api=api,
-                        provider=provider_spec.provider_id,
+                        provider=provider_spec.provider_type,
                     )
                 )
             docker_image = None
@@ -115,11 +115,11 @@ class StackBuild(Subcommand):
             self.parser.error(f"Could not find distribution {args.distribution}")
             return
 
-        for api, provider_id in dist.providers.items():
+        for api, provider_type in dist.providers.items():
             api_inputs.append(
                 ApiInput(
                     api=api,
-                    provider=provider_id,
+                    provider=provider_type,
                 )
             )
         docker_image = dist.docker_image
@@ -128,6 +128,6 @@ class StackBuild(Subcommand):
             api_inputs,
             build_type=BuildType(args.type),
             name=args.name,
-            distribution_id=args.distribution,
+            distribution_type=args.distribution,
             docker_image=docker_image,
         )
@@ -36,7 +36,7 @@ class StackConfigure(Subcommand):
         )
         from llama_toolchain.core.package import BuildType
 
-        allowed_ids = [d.distribution_id for d in available_distribution_specs()]
+        allowed_ids = [d.distribution_type for d in available_distribution_specs()]
         self.parser.add_argument(
             "distribution",
             type=str,
@@ -84,7 +84,7 @@ def configure_llama_distribution(config_file: Path) -> None:
 
     if config.providers:
         cprint(
-            f"Configuration already exists for {config.distribution_id}. Will overwrite...",
+            f"Configuration already exists for {config.distribution_type}. Will overwrite...",
             "yellow",
             attrs=["bold"],
         )
@@ -33,7 +33,7 @@ class StackListDistributions(Subcommand):
 
         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
         headers = [
-            "Distribution ID",
+            "Distribution Type",
             "Providers",
             "Description",
         ]
@@ -43,7 +43,7 @@ class StackListDistributions(Subcommand):
             providers = {k.value: v for k, v in spec.providers.items()}
             rows.append(
                 [
-                    spec.distribution_id,
+                    spec.distribution_type,
                     json.dumps(providers, indent=2),
                     spec.description,
                 ]
@@ -41,7 +41,7 @@ class StackListProviders(Subcommand):
 
         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
         headers = [
-            "Provider ID",
+            "Provider Type",
             "PIP Package Dependencies",
         ]
@@ -49,7 +49,7 @@ class StackListProviders(Subcommand):
         for spec in providers_for_api.values():
             rows.append(
                 [
-                    spec.provider_id,
+                    spec.provider_type,
                     ",".join(spec.pip_packages),
                 ]
             )
@@ -80,7 +80,7 @@ class StackRun(Subcommand):
         with open(config_file, "r") as f:
             config = PackageConfig(**yaml.safe_load(f))
 
-        if not config.distribution_id:
+        if not config.distribution_type:
             raise ValueError("Build config appears to be corrupt.")
 
         if config.docker_image:
@@ -20,12 +20,12 @@ fi
 set -euo pipefail
 
 if [ "$#" -ne 3 ]; then
-  echo "Usage: $0 <distribution_id> <build_name> <pip_dependencies>" >&2
-  echo "Example: $0 <distribution_id> mybuild 'numpy pandas scipy'" >&2
+  echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies>" >&2
+  echo "Example: $0 <distribution_type> mybuild 'numpy pandas scipy'" >&2
   exit 1
 fi
 
-distribution_id="$1"
+distribution_type="$1"
 build_name="$2"
 env_name="llamastack-$build_name"
 pip_dependencies="$3"
@@ -117,4 +117,4 @@ ensure_conda_env_python310 "$env_name" "$pip_dependencies"
 
 printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}\n"
 
-$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type conda_env
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_type --name "$build_name" --type conda_env
@@ -5,12 +5,12 @@ LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
 
 if [ "$#" -ne 4 ]; then
-  echo "Usage: $0 <distribution_id> <build_name> <docker_base> <pip_dependencies>"
-  echo "Example: $0 distribution_id my-fastapi-app python:3.9-slim 'fastapi uvicorn'"
+  echo "Usage: $0 <distribution_type> <build_name> <docker_base> <pip_dependencies>"
+  echo "Example: $0 distribution_type my-fastapi-app python:3.9-slim 'fastapi uvicorn'"
   exit 1
 fi
 
-distribution_id=$1
+distribution_type=$1
 build_name="$2"
 image_name="llamastack-$build_name"
 docker_base=$3
@@ -110,4 +110,4 @@ set +x
 printf "${GREEN}Successfully setup Podman image. Configuring build...${NC}"
 echo "You can run it with: podman run -p 8000:8000 $image_name"
 
-$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type container
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_type --name "$build_name" --type container
@@ -21,14 +21,14 @@ def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
     for api_str, stub_config in existing_configs.items():
         api = Api(api_str)
         providers = all_providers[api]
-        provider_id = stub_config["provider_id"]
-        if provider_id not in providers:
+        provider_type = stub_config["provider_type"]
+        if provider_type not in providers:
             raise ValueError(
-                f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
+                f"Unknown provider `{provider_type}` is not available for API `{api_str}`"
             )
 
-        provider_spec = providers[provider_id]
-        cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
+        provider_spec = providers[provider_type]
+        cprint(f"Configuring API: {api_str} ({provider_type})", "white", attrs=["bold"])
         config_type = instantiate_class_type(provider_spec.config_class)
 
         try:
@@ -43,7 +43,7 @@ def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
         print("")
 
         provider_configs[api_str] = {
-            "provider_id": provider_id,
+            "provider_type": provider_type,
             **provider_config.dict(),
         }
@@ -31,7 +31,7 @@ class ApiEndpoint(BaseModel):
 @json_schema_type
 class ProviderSpec(BaseModel):
     api: Api
-    provider_id: str
+    provider_type: str
     config_class: str = Field(
         ...,
         description="Fully-qualified classname of the config for this provider",
@@ -100,7 +100,7 @@ class RemoteProviderConfig(BaseModel):
         return url.rstrip("/")
 
 
-def remote_provider_id(adapter_id: str) -> str:
+def remote_provider_type(adapter_id: str) -> str:
     return f"remote::{adapter_id}"
@@ -141,22 +141,22 @@ def remote_provider_spec(
         if adapter and adapter.config_class
         else "llama_toolchain.core.datatypes.RemoteProviderConfig"
     )
-    provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
+    provider_type = remote_provider_type(adapter.adapter_id) if adapter else "remote"
 
     return RemoteProviderSpec(
-        api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
+        api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
     )
 
 
 @json_schema_type
 class DistributionSpec(BaseModel):
-    distribution_id: str
+    distribution_type: str
     description: str
 
     docker_image: Optional[str] = None
     providers: Dict[Api, str] = Field(
         default_factory=dict,
-        description="Provider IDs for each of the APIs provided by this distribution",
+        description="Provider Types for each of the APIs provided by this distribution",
     )
@@ -171,7 +171,7 @@ Reference to the distribution this package refers to. For unregistered (adhoc) packages,
 this could be just a hash
     """,
     )
-    distribution_id: Optional[str] = None
+    distribution_type: Optional[str] = None
 
     docker_image: Optional[str] = Field(
         default=None,
@@ -83,18 +83,18 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
 
 def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
     inference_providers_by_id = {
-        a.provider_id: a for a in available_inference_providers()
+        a.provider_type: a for a in available_inference_providers()
     }
-    safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
+    safety_providers_by_id = {a.provider_type: a for a in available_safety_providers()}
     agentic_system_providers_by_id = {
-        a.provider_id: a for a in available_agentic_system_providers()
+        a.provider_type: a for a in available_agentic_system_providers()
     }
 
     ret = {
         Api.inference: inference_providers_by_id,
         Api.safety: safety_providers_by_id,
         Api.agentic_system: agentic_system_providers_by_id,
-        Api.memory: {a.provider_id: a for a in available_memory_providers()},
+        Api.memory: {a.provider_type: a for a in available_memory_providers()},
     }
     for k, v in ret.items():
         v["remote"] = remote_provider_spec(k)
@@ -14,7 +14,7 @@ from .datatypes import *  # noqa: F403
 def available_distribution_specs() -> List[DistributionSpec]:
     return [
         DistributionSpec(
-            distribution_id="local",
+            distribution_type="local",
             description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
             providers={
                 Api.inference: "meta-reference",
@@ -24,35 +24,35 @@ def available_distribution_specs() -> List[DistributionSpec]:
             },
         ),
         DistributionSpec(
-            distribution_id="remote",
+            distribution_type="remote",
             description="Point to remote services for all llama stack APIs",
             providers={x: "remote" for x in Api},
         ),
         DistributionSpec(
-            distribution_id="local-ollama",
+            distribution_type="local-ollama",
             description="Like local, but use ollama for running LLM inference",
             providers={
-                Api.inference: remote_provider_id("ollama"),
+                Api.inference: remote_provider_type("ollama"),
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
             },
         ),
         DistributionSpec(
-            distribution_id="local-plus-fireworks-inference",
+            distribution_type="local-plus-fireworks-inference",
             description="Use Fireworks.ai for running LLM inference",
             providers={
-                Api.inference: remote_provider_id("fireworks"),
+                Api.inference: remote_provider_type("fireworks"),
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
             },
         ),
         DistributionSpec(
-            distribution_id="local-plus-together-inference",
+            distribution_type="local-plus-together-inference",
             description="Use Together.ai for running LLM inference",
             providers={
-                Api.inference: remote_provider_id("together"),
+                Api.inference: remote_provider_type("together"),
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
@@ -72,8 +72,8 @@ def available_distribution_specs() -> List[DistributionSpec]:
 
 
 @lru_cache()
-def resolve_distribution_spec(distribution_id: str) -> Optional[DistributionSpec]:
+def resolve_distribution_spec(distribution_type: str) -> Optional[DistributionSpec]:
     for spec in available_distribution_specs():
-        if spec.distribution_id == distribution_id:
+        if spec.distribution_type == distribution_type:
             return spec
     return None
@@ -46,13 +46,13 @@ def build_package(
     api_inputs: List[ApiInput],
     build_type: BuildType,
     name: str,
-    distribution_id: Optional[str] = None,
+    distribution_type: Optional[str] = None,
     docker_image: Optional[str] = None,
 ):
-    if not distribution_id:
-        distribution_id = "adhoc"
+    if not distribution_type:
+        distribution_type = "adhoc"
 
-    build_dir = BUILDS_BASE_DIR / distribution_id / build_type.descriptor()
+    build_dir = BUILDS_BASE_DIR / distribution_type / build_type.descriptor()
     os.makedirs(build_dir, exist_ok=True)
 
     package_name = name.replace("::", "-")
@@ -79,7 +79,7 @@ def build_package(
         if provider.docker_image:
             raise ValueError("A stack's dependencies cannot have a docker image")
 
-        stub_config[api.value] = {"provider_id": api_input.provider}
+        stub_config[api.value] = {"provider_type": api_input.provider}
 
     if package_file.exists():
         cprint(
@@ -92,7 +92,7 @@ def build_package(
                 c.providers[api_str] = new_config
             else:
                 existing_config = c.providers[api_str]
-                if existing_config["provider_id"] != new_config["provider_id"]:
+                if existing_config["provider_type"] != new_config["provider_type"]:
                     cprint(
                         f"Provider `{api_str}` has changed from `{existing_config}` to `{new_config}`",
                         color="yellow",
@@ -105,7 +105,7 @@ def build_package(
             providers=stub_config,
         )
 
-        c.distribution_id = distribution_id
+        c.distribution_type = distribution_type
         c.docker_image = package_name if build_type == BuildType.container else None
         c.conda_env = package_name if build_type == BuildType.conda_env else None
@@ -119,7 +119,7 @@ def build_package(
         )
         args = [
             script,
-            distribution_id,
+            distribution_type,
             package_name,
            package_deps.docker_image,
            " ".join(package_deps.pip_packages),
@@ -130,7 +130,7 @@ def build_package(
         )
         args = [
             script,
-            distribution_id,
+            distribution_type,
             package_name,
             " ".join(package_deps.pip_packages),
         ]
@@ -284,13 +284,13 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
     for api_str, provider_config in config["providers"].items():
         api = Api(api_str)
         providers = all_providers[api]
-        provider_id = provider_config["provider_id"]
-        if provider_id not in providers:
+        provider_type = provider_config["provider_type"]
+        if provider_type not in providers:
             raise ValueError(
-                f"Unknown provider `{provider_id}` is not available for API `{api}`"
+                f"Unknown provider `{provider_type}` is not available for API `{api}`"
             )
-        provider_specs[api] = providers[provider_id]
+        provider_specs[api] = providers[provider_type]
 
     impls = resolve_impls(provider_specs, config)
@@ -13,7 +13,7 @@ def available_inference_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.inference,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
             pip_packages=[
                 "accelerate",
                 "blobfile",
llama_toolchain/memory/adapters/chroma/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_toolchain.core.datatypes import RemoteProviderConfig
+
+
+async def get_adapter_impl(config: RemoteProviderConfig, _deps):
+    from .chroma import ChromaMemoryAdapter
+
+    impl = ChromaMemoryAdapter(config.url)
+    await impl.initialize()
+    return impl
llama_toolchain/memory/adapters/chroma/chroma.py (new file, 165 lines)
@@ -0,0 +1,165 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+import uuid
+from typing import List
+from urllib.parse import urlparse
+
+import chromadb
+from numpy.typing import NDArray
+
+from llama_toolchain.memory.api import *  # noqa: F403
+
+
+from llama_toolchain.memory.common.vector_store import BankWithIndex, EmbeddingIndex
+
+
+class ChromaIndex(EmbeddingIndex):
+    def __init__(self, client: chromadb.AsyncHttpClient, collection):
+        self.client = client
+        self.collection = collection
+
+    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        assert len(chunks) == len(
+            embeddings
+        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+
+        for i, chunk in enumerate(chunks):
+            print(f"Adding chunk #{i} tokens={chunk.token_count}")
+
+        await self.collection.add(
+            documents=[chunk.json() for chunk in chunks],
+            embeddings=embeddings,
+            ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
+        )
+
+    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
+        results = await self.collection.query(
+            query_embeddings=[embedding.tolist()],
+            n_results=k,
+            include=["documents", "distances"],
+        )
+        distances = results["distances"][0]
+        documents = results["documents"][0]
+
+        chunks = []
+        scores = []
+        for dist, doc in zip(distances, documents):
+            try:
+                doc = json.loads(doc)
+                chunk = Chunk(**doc)
+            except Exception:
+                import traceback
+
+                traceback.print_exc()
+                print(f"Failed to parse document: {doc}")
+                continue
+
+            chunks.append(chunk)
+            scores.append(1.0 / float(dist))
+
+        return QueryDocumentsResponse(chunks=chunks, scores=scores)
+
+
+class ChromaMemoryAdapter(Memory):
+    def __init__(self, url: str) -> None:
+        print(f"Initializing ChromaMemoryAdapter with url: {url}")
+        url = url.rstrip("/")
+        parsed = urlparse(url)
+
+        if parsed.path and parsed.path != "/":
+            raise ValueError("URL should not contain a path")
+
+        self.host = parsed.hostname
+        self.port = parsed.port
+
+        self.client = None
+        self.cache = {}
+
+    async def initialize(self) -> None:
+        try:
+            print(f"Connecting to Chroma server at: {self.host}:{self.port}")
+            self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
+        except Exception as e:
+            import traceback
+
+            traceback.print_exc()
+            raise RuntimeError("Could not connect to Chroma server") from e
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def create_memory_bank(
+        self,
+        name: str,
+        config: MemoryBankConfig,
+        url: Optional[URL] = None,
+    ) -> MemoryBank:
+        bank_id = str(uuid.uuid4())
+        bank = MemoryBank(
+            bank_id=bank_id,
+            name=name,
+            config=config,
+            url=url,
+        )
+        collection = await self.client.create_collection(
+            name=bank_id,
+            metadata={"bank": bank.json()},
+        )
+        bank_index = BankWithIndex(
+            bank=bank, index=ChromaIndex(self.client, collection)
+        )
+        self.cache[bank_id] = bank_index
+        return bank
+
+    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
+        bank_index = await self._get_and_cache_bank_index(bank_id)
+        if bank_index is None:
+            return None
+        return bank_index.bank
+
+    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
+        if bank_id in self.cache:
+            return self.cache[bank_id]
+
+        collections = await self.client.list_collections()
+        for collection in collections:
+            if collection.name == bank_id:
+                print(collection.metadata)
+                bank = MemoryBank(**json.loads(collection.metadata["bank"]))
+                index = BankWithIndex(
+                    bank=bank,
+                    index=ChromaIndex(self.client, collection),
+                )
+                self.cache[bank_id] = index
+                return index
+
+        return None
+
+    async def insert_documents(
+        self,
+        bank_id: str,
+        documents: List[MemoryBankDocument],
+        ttl_seconds: Optional[int] = None,
+    ) -> None:
+        index = await self._get_and_cache_bank_index(bank_id)
+        if not index:
+            raise ValueError(f"Bank {bank_id} not found")
+
+        await index.insert_documents(documents)
+
+    async def query_documents(
+        self,
+        bank_id: str,
+        query: InterleavedTextMedia,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> QueryDocumentsResponse:
+        index = await self._get_and_cache_bank_index(bank_id)
+        if not index:
+            raise ValueError(f"Bank {bank_id} not found")
+
+        return await index.query_documents(query, params)
llama_toolchain/memory/adapters/pgvector/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import PGVectorConfig
+
+
+async def get_adapter_impl(config: PGVectorConfig, _deps):
+    from .pgvector import PGVectorMemoryAdapter
+
+    impl = PGVectorMemoryAdapter(config)
+    await impl.initialize()
+    return impl
llama_toolchain/memory/adapters/pgvector/config.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class PGVectorConfig(BaseModel):
+    host: str = Field(default="localhost")
+    port: int = Field(default=5432)
+    db: str
+    user: str
+    password: str
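A minimal sketch of wiring the adapter up directly, mirroring `get_adapter_impl` above; the connection values are placeholders and assume a running Postgres with the pgvector extension installed:

```python
import asyncio

from llama_toolchain.memory.adapters.pgvector.config import PGVectorConfig
from llama_toolchain.memory.adapters.pgvector.pgvector import PGVectorMemoryAdapter

# host/port fall back to localhost:5432; credentials here are placeholders
config = PGVectorConfig(db="llamastack", user="postgres", password="secret")
adapter = PGVectorMemoryAdapter(config)
asyncio.run(adapter.initialize())
```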
llama_toolchain/memory/adapters/pgvector/pgvector.py (new file, 234 lines)
@@ -0,0 +1,234 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import uuid
+
+from typing import List, Tuple
+
+import psycopg2
+from numpy.typing import NDArray
+from psycopg2 import sql
+from psycopg2.extras import execute_values, Json
+from pydantic import BaseModel
+from llama_toolchain.memory.api import *  # noqa: F403
+
+
+from llama_toolchain.memory.common.vector_store import (
+    ALL_MINILM_L6_V2_DIMENSION,
+    BankWithIndex,
+    EmbeddingIndex,
+)
+
+from .config import PGVectorConfig
+
+
+def check_extension_version(cur):
+    cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
+    result = cur.fetchone()
+    return result[0] if result else None
+
+
+def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
+    query = sql.SQL(
+        """
+        INSERT INTO metadata_store (key, data)
+        VALUES %s
+        ON CONFLICT (key) DO UPDATE
+        SET data = EXCLUDED.data
+    """
+    )
+
+    values = [(key, Json(model.dict())) for key, model in keys_models]
+    execute_values(cur, query, values, template="(%s, %s)")
+
+
+def load_models(cur, keys: List[str], cls):
+    query = "SELECT key, data FROM metadata_store"
+    if keys:
+        placeholders = ",".join(["%s"] * len(keys))
+        query += f" WHERE key IN ({placeholders})"
+        cur.execute(query, keys)
+    else:
+        cur.execute(query)
+
+    rows = cur.fetchall()
+    return [cls(**row["data"]) for row in rows]
+
+
+class PGVectorIndex(EmbeddingIndex):
+    def __init__(self, bank: MemoryBank, dimension: int, cursor):
+        self.cursor = cursor
+        self.table_name = f"vector_store_{bank.name}"
+
+        self.cursor.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS {self.table_name} (
+                id TEXT PRIMARY KEY,
+                document JSONB,
+                embedding vector({dimension})
+            )
+        """
+        )
+
+    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        assert len(chunks) == len(
+            embeddings
+        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+
+        values = []
+        for i, chunk in enumerate(chunks):
+            print(f"Adding chunk #{i} tokens={chunk.token_count}")
+            values.append(
+                (
+                    f"{chunk.document_id}:chunk-{i}",
+                    Json(chunk.dict()),
+                    embeddings[i].tolist(),
+                )
+            )
+
+        query = sql.SQL(
+            f"""
+        INSERT INTO {self.table_name} (id, document, embedding)
+        VALUES %s
+        ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
+    """
+        )
+        execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
+
+    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
+        self.cursor.execute(
+            f"""
+        SELECT document, embedding <-> %s::vector AS distance
+        FROM {self.table_name}
+        ORDER BY distance
+        LIMIT %s
+    """,
+            (embedding.tolist(), k),
+        )
+        results = self.cursor.fetchall()
+
+        chunks = []
+        scores = []
+        for doc, dist in results:
+            chunks.append(Chunk(**doc))
+            scores.append(1.0 / float(dist))
+
+        return QueryDocumentsResponse(chunks=chunks, scores=scores)
+
+
+class PGVectorMemoryAdapter(Memory):
+    def __init__(self, config: PGVectorConfig) -> None:
+        print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
+        self.config = config
+        self.cursor = None
+        self.conn = None
+        self.cache = {}
+
+    async def initialize(self) -> None:
+        try:
+            self.conn = psycopg2.connect(
+                host=self.config.host,
+                port=self.config.port,
+                database=self.config.db,
+                user=self.config.user,
+                password=self.config.password,
+            )
+            self.cursor = self.conn.cursor()
+
+            version = check_extension_version(self.cursor)
+            if version:
+                print(f"Vector extension version: {version}")
+            else:
+                raise RuntimeError("Vector extension is not installed.")
+
+            self.cursor.execute(
+                """
+                CREATE TABLE IF NOT EXISTS metadata_store (
+                    key TEXT PRIMARY KEY,
+                    data JSONB
+                )
+            """
+            )
+        except Exception as e:
+            import traceback
+
+            traceback.print_exc()
+            raise RuntimeError("Could not connect to PGVector database server") from e
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def create_memory_bank(
+        self,
+        name: str,
+        config: MemoryBankConfig,
+        url: Optional[URL] = None,
+    ) -> MemoryBank:
+        bank_id = str(uuid.uuid4())
+        bank = MemoryBank(
+            bank_id=bank_id,
+            name=name,
+            config=config,
+            url=url,
+        )
+        upsert_models(
+            self.cursor,
+            [
+                (bank.bank_id, bank),
+            ],
+        )
+        index = BankWithIndex(
+            bank=bank,
+            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
+        )
+        self.cache[bank_id] = index
+        return bank
+
+    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
+        bank_index = await self._get_and_cache_bank_index(bank_id)
+        if bank_index is None:
+            return None
+        return bank_index.bank
+
+    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
+        if bank_id in self.cache:
+            return self.cache[bank_id]
+
+        banks = load_models(self.cursor, [bank_id], MemoryBank)
+        if not banks:
+            return None
+
+        bank = banks[0]
+        index = BankWithIndex(
+            bank=bank,
+            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
+        )
+        self.cache[bank_id] = index
+        return index
+
+    async def insert_documents(
+        self,
+        bank_id: str,
+        documents: List[MemoryBankDocument],
+        ttl_seconds: Optional[int] = None,
+    ) -> None:
+        index = await self._get_and_cache_bank_index(bank_id)
+        if not index:
+            raise ValueError(f"Bank {bank_id} not found")
+
+        await index.insert_documents(documents)
+
+    async def query_documents(
+        self,
+        bank_id: str,
+        query: InterleavedTextMedia,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> QueryDocumentsResponse:
+        index = await self._get_and_cache_bank_index(bank_id)
+        if not index:
+            raise ValueError(f"Bank {bank_id} not found")
+
+        return await index.query_documents(query, params)
llama_toolchain/memory/common/vector_store.py (new file, 120 lines)
@@ -0,0 +1,120 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import httpx
+import numpy as np
+from numpy.typing import NDArray
+
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from llama_toolchain.memory.api import *  # noqa: F403
+
+
+ALL_MINILM_L6_V2_DIMENSION = 384
+
+EMBEDDING_MODEL = None
+
+
+def get_embedding_model() -> "SentenceTransformer":
+    global EMBEDDING_MODEL
+
+    if EMBEDDING_MODEL is None:
+        print("Loading sentence transformer")
+
+        from sentence_transformers import SentenceTransformer
+
+        EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
+
+    return EMBEDDING_MODEL
+
+
+async def content_from_doc(doc: MemoryBankDocument) -> str:
+    if isinstance(doc.content, URL):
+        async with httpx.AsyncClient() as client:
+            r = await client.get(doc.content.uri)
+            return r.text
+
+    return interleaved_text_media_as_str(doc.content)
+
+
+def make_overlapped_chunks(
+    document_id: str, text: str, window_len: int, overlap_len: int
+) -> List[Chunk]:
+    tokenizer = Tokenizer.get_instance()
+    tokens = tokenizer.encode(text, bos=False, eos=False)
+
+    chunks = []
+    for i in range(0, len(tokens), window_len - overlap_len):
+        toks = tokens[i : i + window_len]
+        chunk = tokenizer.decode(toks)
+        chunks.append(
+            Chunk(content=chunk, token_count=len(toks), document_id=document_id)
+        )
+
+    return chunks
+
+
+class EmbeddingIndex(ABC):
+    @abstractmethod
+    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        raise NotImplementedError()
+
+    @abstractmethod
+    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
+        raise NotImplementedError()
+
+
+@dataclass
+class BankWithIndex:
+    bank: MemoryBank
+    index: EmbeddingIndex
+
+    async def insert_documents(
+        self,
+        documents: List[MemoryBankDocument],
+    ) -> None:
+        model = get_embedding_model()
+        for doc in documents:
+            content = await content_from_doc(doc)
+            chunks = make_overlapped_chunks(
+                doc.document_id,
+                content,
+                self.bank.config.chunk_size_in_tokens,
+                self.bank.config.overlap_size_in_tokens
+                or (self.bank.config.chunk_size_in_tokens // 4),
+            )
+            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
+
+            await self.index.add_chunks(chunks, embeddings)
+
+    async def query_documents(
+        self,
+        query: InterleavedTextMedia,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> QueryDocumentsResponse:
+        if params is None:
+            params = {}
+        k = params.get("max_chunks", 3)
+
+        def _process(c) -> str:
+            if isinstance(c, str):
+                return c
+            else:
+                return "<media>"
+
+        if isinstance(query, list):
+            query_str = " ".join([_process(c) for c in query])
+        else:
+            query_str = _process(query)
+
+        model = get_embedding_model()
+        query_vector = model.encode([query_str])[0].astype(np.float32)
+        return await self.index.query(query_vector, k)
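A worked example of the chunking stride (values are illustrative): the loop advances by `window_len - overlap_len` tokens, so consecutive chunks share `overlap_len` tokens:

```python
# With window_len=512 and overlap_len=64 the stride is 448 tokens, so a
# 1000-token document yields chunks starting at token offsets 0, 448, and 896.
chunks = make_overlapped_chunks(
    document_id="doc-1",
    text=long_text,  # assumed to be loaded already
    window_len=512,
    overlap_len=64,
)
```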
@@ -5,108 +5,45 @@
 # the root directory of this source tree.
 
 import uuid
-from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
 import faiss
-import httpx
 import numpy as np
 from numpy.typing import NDArray
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_toolchain.memory.api import *  # noqa: F403
+from llama_toolchain.memory.common.vector_store import (
+    ALL_MINILM_L6_V2_DIMENSION,
+    BankWithIndex,
+    EmbeddingIndex,
+)
 from .config import FaissImplConfig
 
 
-async def content_from_doc(doc: MemoryBankDocument) -> str:
-    if isinstance(doc.content, URL):
-        async with httpx.AsyncClient() as client:
-            r = await client.get(doc.content.uri)
-            return r.text
+class FaissIndex(EmbeddingIndex):
+    id_by_index: Dict[int, str]
+    chunk_by_index: Dict[int, str]
 
-    return interleaved_text_media_as_str(doc.content)
+    def __init__(self, dimension: int):
+        self.index = faiss.IndexFlatL2(dimension)
+        self.id_by_index = {}
+        self.chunk_by_index = {}
 
+    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        indexlen = len(self.id_by_index)
+        for i, chunk in enumerate(chunks):
+            self.chunk_by_index[indexlen + i] = chunk
+            print(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
+            self.id_by_index[indexlen + i] = chunk.document_id
 
-def make_overlapped_chunks(
-    text: str, window_len: int, overlap_len: int
-) -> List[Tuple[str, int]]:
-    tokenizer = Tokenizer.get_instance()
-    tokens = tokenizer.encode(text, bos=False, eos=False)
+        self.index.add(np.array(embeddings).astype(np.float32))
 
-    chunks = []
-    for i in range(0, len(tokens), window_len - overlap_len):
-        toks = tokens[i : i + window_len]
-        chunk = tokenizer.decode(toks)
-        chunks.append((chunk, len(toks)))
-
-    return chunks
-
-
-@dataclass
-class BankState:
-    bank: MemoryBank
-    index: Optional[faiss.IndexFlatL2] = None
-    doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
-    id_by_index: Dict[int, str] = field(default_factory=dict)
-    chunk_by_index: Dict[int, str] = field(default_factory=dict)
-
-    async def insert_documents(
-        self,
-        model: "SentenceTransformer",
-        documents: List[MemoryBankDocument],
-    ) -> None:
-        tokenizer = Tokenizer.get_instance()
-        chunk_size = self.bank.config.chunk_size_in_tokens
-
-        for doc in documents:
-            indexlen = len(self.id_by_index)
-            self.doc_by_id[doc.document_id] = doc
-
-            content = await content_from_doc(doc)
-            chunks = make_overlapped_chunks(
-                content,
-                self.bank.config.chunk_size_in_tokens,
-                self.bank.config.overlap_size_in_tokens
-                or (self.bank.config.chunk_size_in_tokens // 4),
-            )
-            embeddings = model.encode([x[0] for x in chunks]).astype(np.float32)
-            await self._ensure_index(embeddings.shape[1])
-
-            self.index.add(embeddings)
-            for i, chunk in enumerate(chunks):
-                self.chunk_by_index[indexlen + i] = Chunk(
-                    content=chunk[0],
-                    token_count=chunk[1],
-                    document_id=doc.document_id,
-                )
-                print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
-                self.id_by_index[indexlen + i] = doc.document_id
-
-    async def query_documents(
-        self,
-        model: "SentenceTransformer",
-        query: InterleavedTextMedia,
-        params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
-        if params is None:
-            params = {}
-        k = params.get("max_chunks", 3)
-
-        def _process(c) -> str:
-            if isinstance(c, str):
-                return c
-            else:
-                return "<media>"
-
-        if isinstance(query, list):
-            query_str = " ".join([_process(c) for c in query])
-        else:
-            query_str = _process(query)
-
-        query_vector = model.encode([query_str])[0]
+    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
         distances, indices = self.index.search(
-            query_vector.reshape(1, -1).astype(np.float32), k
+            embedding.reshape(1, -1).astype(np.float32), k
         )
 
         chunks = []
@@ -119,17 +56,11 @@ class BankState:
 
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
-    async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
-        if self.index is None:
-            self.index = faiss.IndexFlatL2(dimension)
-        return self.index
-
 
 class FaissMemoryImpl(Memory):
     def __init__(self, config: FaissImplConfig) -> None:
         self.config = config
-        self.model = None
-        self.states = {}
+        self.cache = {}
 
     async def initialize(self) -> None: ...
 
@@ -153,14 +84,15 @@ class FaissMemoryImpl(Memory):
             config=config,
             url=url,
         )
-        state = BankState(bank=bank)
-        self.states[bank_id] = state
+        index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
+        self.cache[bank_id] = index
         return bank
 
     async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        if bank_id not in self.states:
+        index = self.cache.get(bank_id)
+        if index is None:
             return None
-        return self.states[bank_id].bank
+        return index.bank
 
     async def insert_documents(
         self,
@@ -168,10 +100,11 @@ class FaissMemoryImpl(Memory):
         documents: List[MemoryBankDocument],
         ttl_seconds: Optional[int] = None,
    ) -> None:
-        assert bank_id in self.states, f"Bank {bank_id} not found"
-        state = self.states[bank_id]
+        index = self.cache.get(bank_id)
+        if index is None:
+            raise ValueError(f"Bank {bank_id} not found")
 
-        await state.insert_documents(self.get_model(), documents)
+        await index.insert_documents(documents)
 
     async def query_documents(
         self,
@@ -179,16 +112,8 @@ class FaissMemoryImpl(Memory):
         query: InterleavedTextMedia,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
-        assert bank_id in self.states, f"Bank {bank_id} not found"
-        state = self.states[bank_id]
+        index = self.cache.get(bank_id)
+        if index is None:
+            raise ValueError(f"Bank {bank_id} not found")
 
-        return await state.query_documents(self.get_model(), query, params)
-
-    def get_model(self) -> "SentenceTransformer":
-        from sentence_transformers import SentenceTransformer
-
-        if self.model is None:
-            print("Loading sentence transformer")
-            self.model = SentenceTransformer("all-MiniLM-L6-v2")
-
-        return self.model
+        return await index.query_documents(query, params)
@@ -6,20 +6,38 @@
 
 from typing import List
 
-from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_toolchain.core.datatypes import *  # noqa: F403
+
+EMBEDDING_DEPS = [
+    "blobfile",
+    "sentence-transformers",
+]
 
 
 def available_memory_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.memory,
-            provider_id="meta-reference-faiss",
-            pip_packages=[
-                "blobfile",
-                "faiss-cpu",
-                "sentence-transformers",
-            ],
+            provider_type="meta-reference-faiss",
+            pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_toolchain.memory.meta_reference.faiss",
             config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
         ),
+        remote_provider_spec(
+            api=Api.memory,
+            adapter=AdapterSpec(
+                adapter_id="chromadb",
+                pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
+                module="llama_toolchain.memory.adapters.chroma",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.memory,
+            adapter=AdapterSpec(
+                adapter_id="pgvector",
+                pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
+                module="llama_toolchain.memory.adapters.pgvector",
+                config_class="llama_toolchain.memory.adapters.pgvector.PGVectorConfig",
+            ),
+        ),
     ]
@@ -13,7 +13,7 @@ def available_safety_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.safety,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
             pip_packages=[
                 "accelerate",
                 "codeshield",
@@ -11,7 +11,7 @@ from llama_toolchain.evaluations.api import *  # noqa: F403
 from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.batch_inference.api import *  # noqa: F403
 from llama_toolchain.memory.api import *  # noqa: F403
-from llama_toolchain.observability.api import *  # noqa: F403
+from llama_toolchain.telemetry.api import *  # noqa: F403
 from llama_toolchain.post_training.api import *  # noqa: F403
 from llama_toolchain.reward_scoring.api import *  # noqa: F403
 from llama_toolchain.synthetic_data_generation.api import *  # noqa: F403
@@ -24,7 +24,7 @@ class LlamaStack(
     RewardScoring,
     SyntheticDataGeneration,
     Datasets,
-    Observability,
+    Telemetry,
     PostTraining,
     Memory,
     Evaluations,
@@ -134,7 +134,7 @@ class LogSearchRequest(BaseModel):
     filters: Optional[Dict[str, Any]] = None
 
 
-class Observability(Protocol):
+class Telemetry(Protocol):
     @webmethod(route="/experiments/create")
     def create_experiment(self, request: CreateExperimentRequest) -> Experiment: ...
@@ -2,7 +2,7 @@ blobfile
 fire
 httpx
 huggingface-hub
-llama-models
+llama-models>=0.0.13
 pydantic
 requests
 termcolor
File diff suppressed because it is too large.
@@ -35,7 +35,10 @@ from llama_toolchain.stack import LlamaStack
 
 
 # TODO: this should be fixed in the generator itself so it reads appropriate annotations
-STREAMING_ENDPOINTS = ["/agentic_system/turn/create"]
+STREAMING_ENDPOINTS = [
+    "/agentic_system/turn/create",
+    "/inference/chat_completion",
+]
 
 
 def patch_sse_stream_responses(spec: Specification):
@@ -468,12 +468,14 @@ class Generator:
         builder = ContentBuilder(self.schema_builder)
         first = next(iter(op.request_params))
-        request_name, request_type = first
+
+        from dataclasses import make_dataclass
 
         if len(op.request_params) == 1 and "Request" in first[1].__name__:
             # TODO(ashwin): Undo the "Request" hack and this entire block eventually
+            request_name, request_type = first
             request_name = first[1].__name__ + "Wrapper"
             request_type = make_dataclass(request_name, op.request_params)
         else:
-            from dataclasses import make_dataclass
-
             op_name = "".join(word.capitalize() for word in op.name.split("_"))
             request_name = f"{op_name}Request"
             request_type = make_dataclass(request_name, op.request_params)
@@ -28,4 +28,4 @@ if [ ${#missing_packages[@]} -ne 0 ]; then
   exit 1
 fi
 
-PYTHONPATH=$PYTHONPATH:../.. python3 -m rfcs.openapi_generator.generate $*
+PYTHONPATH=$PYTHONPATH:../.. python -m rfcs.openapi_generator.generate $*