Merge branch 'main' into tgi-integration

Celina Hanouti 2024-09-12 15:31:07 +02:00
commit 04f0b8fe11
38 changed files with 2157 additions and 548 deletions

View file

@ -248,51 +248,51 @@ llama stack list-distributions
```
<pre style="font-family: monospace;">
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
| Distribution Type              | Providers                             | Description                                                          |
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
| local                          | {                                     | Use code from `llama_toolchain` itself to serve all llama stack APIs |
|                                |   "inference": "meta-reference",      |                                                                      |
|                                |   "memory": "meta-reference-faiss",   |                                                                      |
|                                |   "safety": "meta-reference",         |                                                                      |
|                                |   "agentic_system": "meta-reference"  |                                                                      |
|                                | }                                     |                                                                      |
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
| remote                         | {                                     | Point to remote services for all llama stack APIs                    |
|                                |   "inference": "remote",              |                                                                      |
|                                |   "safety": "remote",                 |                                                                      |
|                                |   "agentic_system": "remote",         |                                                                      |
|                                |   "memory": "remote"                  |                                                                      |
|                                | }                                     |                                                                      |
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
| local-ollama                   | {                                     | Like local, but use ollama for running LLM inference                 |
|                                |   "inference": "remote::ollama",      |                                                                      |
|                                |   "safety": "meta-reference",         |                                                                      |
|                                |   "agentic_system": "meta-reference", |                                                                      |
|                                |   "memory": "meta-reference-faiss"    |                                                                      |
|                                | }                                     |                                                                      |
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
| local-plus-fireworks-inference | {                                     | Use Fireworks.ai for running LLM inference                           |
|                                |   "inference": "remote::fireworks",   |                                                                      |
|                                |   "safety": "meta-reference",         |                                                                      |
|                                |   "agentic_system": "meta-reference", |                                                                      |
|                                |   "memory": "meta-reference-faiss"    |                                                                      |
|                                | }                                     |                                                                      |
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
| local-plus-together-inference  | {                                     | Use Together.ai for running LLM inference                            |
|                                |   "inference": "remote::together",    |                                                                      |
|                                |   "safety": "meta-reference",         |                                                                      |
|                                |   "agentic_system": "meta-reference", |                                                                      |
|                                |   "memory": "meta-reference-faiss"    |                                                                      |
|                                | }                                     |                                                                      |
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
| local-plus-tgi-inference       | {                                     | Use TGI (local or with [Hugging Face Inference Endpoints](https://   |
|                                |   "inference": "remote::tgi",         | huggingface.co/inference-endpoints/dedicated)) for running LLM       |
|                                |   "safety": "meta-reference",         | inference. When using HF Inference Endpoints, you must provide the   |
|                                |   "agentic_system": "meta-reference", | name of the endpoint.                                                |
|                                |   "memory": "meta-reference-faiss"    |                                                                      |
|                                | }                                     |                                                                      |
+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
</pre>
As you can see above, each “distribution” details the “providers” it is composed of. For example, `local` uses the “meta-reference” provider for inference, while `local-ollama` relies on a different provider (Ollama) for it. Likewise, the `local-plus-fireworks-inference`, `local-plus-together-inference`, and `local-plus-tgi-inference` distributions run inference through Fireworks.ai, Together.ai, and TGI, respectively.
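For illustration, a distribution spec can also be looked up programmatically via the registry code further down in this diff (a sketch; the import path is an assumption, since the diff does not show the registry module's filename):

```python
from llama_toolchain.core.distribution_registry import (  # assumed module path
    resolve_distribution_spec,
)

# "local-ollama" is one of the distribution types listed above.
spec = resolve_distribution_spec("local-ollama")
if spec is not None:
    print(spec.description)  # "Like local, but use ollama for running LLM inference"
    print(spec.providers)    # maps Api.inference -> "remote::ollama", etc.
```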

View file

@ -116,10 +116,47 @@ MemoryBankConfig = Annotated[
]


@json_schema_type
class MemoryQueryGenerator(Enum):
    default = "default"
    llm = "llm"
    custom = "custom"


class DefaultMemoryQueryGeneratorConfig(BaseModel):
    type: Literal[MemoryQueryGenerator.default.value] = (
        MemoryQueryGenerator.default.value
    )
    sep: str = " "


class LLMMemoryQueryGeneratorConfig(BaseModel):
    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
    model: str
    template: str


class CustomMemoryQueryGeneratorConfig(BaseModel):
    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value


MemoryQueryGeneratorConfig = Annotated[
    Union[
        DefaultMemoryQueryGeneratorConfig,
        LLMMemoryQueryGeneratorConfig,
        CustomMemoryQueryGeneratorConfig,
    ],
    Field(discriminator="type"),
]


class MemoryToolDefinition(ToolDefinitionCommon):
    type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
    # This config defines how a query is generated using the messages
    # for memory bank retrieval.
    query_generator_config: MemoryQueryGeneratorConfig = Field(
        default=DefaultMemoryQueryGeneratorConfig()
    )
    max_tokens_in_context: int = 4096
    max_chunks: int = 10
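For illustration, an agent tool could opt into LLM-based query generation like this (a sketch against the types above; the model identifier and template are placeholders):

```python
memory_tool = MemoryToolDefinition(
    memory_bank_configs=[],
    query_generator_config=LLMMemoryQueryGeneratorConfig(
        model="Meta-Llama3.1-8B-Instruct",  # placeholder model identifier
        # Placeholder Jinja2 template; llm_rag_query_generator (added later in
        # this diff) renders it with {"messages": [...]}.
        template="{{ messages[-1].content }}",
    ),
)
```

Because the union is discriminated on `type`, omitting the field falls back to `DefaultMemoryQueryGeneratorConfig`, which simply joins message contents with `sep`.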

View file

@ -31,6 +31,7 @@ from llama_toolchain.tools.builtin import (
    SingleMessageBuiltinTool,
)

from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin
@ -664,7 +665,9 @@ class ChatAgent(ShieldRunnerMixin):
            # (i.e., no prior turns uploaded an Attachment)
            return None, []

-        query = " ".join(m.content for m in messages)
+        query = await generate_rag_query(
+            memory.query_generator_config, messages, inference_api=self.inference_api
+        )
        tasks = [
            self.memory_api.query_documents(
                bank_id=bank_id,

View file

@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from jinja2 import Template
from llama_models.llama3.api import *  # noqa: F403

from llama_toolchain.agentic_system.api import (
    DefaultMemoryQueryGeneratorConfig,
    LLMMemoryQueryGeneratorConfig,
    MemoryQueryGenerator,
    MemoryQueryGeneratorConfig,
)
from termcolor import cprint  # noqa: F401
from llama_toolchain.inference.api import *  # noqa: F403


async def generate_rag_query(
    config: MemoryQueryGeneratorConfig,
    messages: List[Message],
    **kwargs,
):
    """
    Generates a query that will be used for
    retrieving relevant information from the memory bank.
    """
    if config.type == MemoryQueryGenerator.default.value:
        query = await default_rag_query_generator(config, messages, **kwargs)
    elif config.type == MemoryQueryGenerator.llm.value:
        query = await llm_rag_query_generator(config, messages, **kwargs)
    else:
        raise NotImplementedError(f"Unsupported memory query generator {config.type}")
    # cprint(f"Generated query >>>: {query}", color="green")
    return query


async def default_rag_query_generator(
    config: DefaultMemoryQueryGeneratorConfig,
    messages: List[Message],
    **kwargs,
):
    return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)


async def llm_rag_query_generator(
    config: LLMMemoryQueryGeneratorConfig,
    messages: List[Message],
    **kwargs,
):
    assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
    inference_api = kwargs["inference_api"]

    m_dict = {"messages": [m.model_dump() for m in messages]}

    template = Template(config.template)
    content = template.render(m_dict)

    model = config.model
    message = UserMessage(content=content)
    response = inference_api.chat_completion(
        ChatCompletionRequest(
            model=model,
            messages=[message],
            stream=False,
        )
    )

    async for chunk in response:
        query = chunk.completion_message.content

    return query
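A minimal usage sketch of the dispatch above (run inside an async function; the default generator needs no inference API, and `UserMessage` comes in via the wildcard imports):

```python
config = DefaultMemoryQueryGeneratorConfig(sep=" ")
messages = [UserMessage(content="How do I configure pgvector?")]

# Joins the (single) message content with config.sep:
query = await generate_rag_query(config, messages)
assert query == "How do I configure pgvector?"
```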

View file

@ -13,7 +13,7 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.agentic_system,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
            pip_packages=[
                "codeshield",
                "matplotlib",

View file

@ -52,7 +52,7 @@ class StackBuild(Subcommand):
            BuildType,
        )

-        allowed_ids = [d.distribution_id for d in available_distribution_specs()]
+        allowed_ids = [d.distribution_type for d in available_distribution_specs()]
        self.parser.add_argument(
            "distribution",
            type=str,
@ -101,7 +101,7 @@ class StackBuild(Subcommand):
                api_inputs.append(
                    ApiInput(
                        api=api,
-                        provider=provider_spec.provider_id,
+                        provider=provider_spec.provider_type,
                    )
                )
            docker_image = None
@ -115,11 +115,11 @@ class StackBuild(Subcommand):
                self.parser.error(f"Could not find distribution {args.distribution}")
                return

-            for api, provider_id in dist.providers.items():
+            for api, provider_type in dist.providers.items():
                api_inputs.append(
                    ApiInput(
                        api=api,
-                        provider=provider_id,
+                        provider=provider_type,
                    )
                )
            docker_image = dist.docker_image
@ -128,6 +128,6 @@ class StackBuild(Subcommand):
            api_inputs,
            build_type=BuildType(args.type),
            name=args.name,
-            distribution_id=args.distribution,
+            distribution_type=args.distribution,
            docker_image=docker_image,
        )

View file

@ -36,7 +36,7 @@ class StackConfigure(Subcommand):
        )
        from llama_toolchain.core.package import BuildType

-        allowed_ids = [d.distribution_id for d in available_distribution_specs()]
+        allowed_ids = [d.distribution_type for d in available_distribution_specs()]
        self.parser.add_argument(
            "distribution",
            type=str,
@ -84,7 +84,7 @@ def configure_llama_distribution(config_file: Path) -> None:

    if config.providers:
        cprint(
-            f"Configuration already exists for {config.distribution_id}. Will overwrite...",
+            f"Configuration already exists for {config.distribution_type}. Will overwrite...",
            "yellow",
            attrs=["bold"],
        )

View file

@ -33,7 +33,7 @@ class StackListDistributions(Subcommand):
        # eventually, this should query a registry at llama.meta.com/llamastack/distributions
        headers = [
-            "Distribution ID",
+            "Distribution Type",
            "Providers",
            "Description",
        ]
@ -43,7 +43,7 @@ class StackListDistributions(Subcommand):
            providers = {k.value: v for k, v in spec.providers.items()}
            rows.append(
                [
-                    spec.distribution_id,
+                    spec.distribution_type,
                    json.dumps(providers, indent=2),
                    spec.description,
                ]

View file

@ -41,7 +41,7 @@ class StackListProviders(Subcommand):
        # eventually, this should query a registry at llama.meta.com/llamastack/distributions
        headers = [
-            "Provider ID",
+            "Provider Type",
            "PIP Package Dependencies",
        ]
@ -49,7 +49,7 @@ class StackListProviders(Subcommand):
        for spec in providers_for_api.values():
            rows.append(
                [
-                    spec.provider_id,
+                    spec.provider_type,
                    ",".join(spec.pip_packages),
                ]
            )

View file

@ -80,7 +80,7 @@ class StackRun(Subcommand):
        with open(config_file, "r") as f:
            config = PackageConfig(**yaml.safe_load(f))

-        if not config.distribution_id:
+        if not config.distribution_type:
            raise ValueError("Build config appears to be corrupt.")

        if config.docker_image:

View file

@ -20,12 +20,12 @@ fi
set -euo pipefail

if [ "$#" -ne 3 ]; then
-  echo "Usage: $0 <distribution_id> <build_name> <pip_dependencies>" >&2
-  echo "Example: $0 <distribution_id> mybuild 'numpy pandas scipy'" >&2
+  echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies>" >&2
+  echo "Example: $0 <distribution_type> mybuild 'numpy pandas scipy'" >&2
  exit 1
fi

-distribution_id="$1"
+distribution_type="$1"
build_name="$2"
env_name="llamastack-$build_name"
pip_dependencies="$3"
@ -117,4 +117,4 @@ ensure_conda_env_python310 "$env_name" "$pip_dependencies"

printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}\n"

-$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type conda_env
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_type --name "$build_name" --type conda_env

View file

@ -5,12 +5,12 @@ LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}

if [ "$#" -ne 4 ]; then
-  echo "Usage: $0 <distribution_id> <build_name> <docker_base> <pip_dependencies>"
-  echo "Example: $0 distribution_id my-fastapi-app python:3.9-slim 'fastapi uvicorn'"
+  echo "Usage: $0 <distribution_type> <build_name> <docker_base> <pip_dependencies>"
+  echo "Example: $0 distribution_type my-fastapi-app python:3.9-slim 'fastapi uvicorn'"
  exit 1
fi

-distribution_id=$1
+distribution_type=$1
build_name="$2"
image_name="llamastack-$build_name"
docker_base=$3
@ -110,4 +110,4 @@ set +x

printf "${GREEN}Successfully setup Podman image. Configuring build...${NC}"
echo "You can run it with: podman run -p 8000:8000 $image_name"

-$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type container
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_type --name "$build_name" --type container

View file

@ -21,14 +21,14 @@ def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
    for api_str, stub_config in existing_configs.items():
        api = Api(api_str)
        providers = all_providers[api]

-        provider_id = stub_config["provider_id"]
-        if provider_id not in providers:
+        provider_type = stub_config["provider_type"]
+        if provider_type not in providers:
            raise ValueError(
-                f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
+                f"Unknown provider `{provider_type}` is not available for API `{api_str}`"
            )

-        provider_spec = providers[provider_id]
-        cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
+        provider_spec = providers[provider_type]
+        cprint(f"Configuring API: {api_str} ({provider_type})", "white", attrs=["bold"])
        config_type = instantiate_class_type(provider_spec.config_class)
        try:
@ -43,7 +43,7 @@ def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
        print("")

        provider_configs[api_str] = {
-            "provider_id": provider_id,
+            "provider_type": provider_type,
            **provider_config.dict(),
        }

View file

@ -31,7 +31,7 @@ class ApiEndpoint(BaseModel):
@json_schema_type
class ProviderSpec(BaseModel):
    api: Api
-    provider_id: str
+    provider_type: str
    config_class: str = Field(
        ...,
        description="Fully-qualified classname of the config for this provider",
@ -100,7 +100,7 @@ class RemoteProviderConfig(BaseModel):
        return url.rstrip("/")


-def remote_provider_id(adapter_id: str) -> str:
+def remote_provider_type(adapter_id: str) -> str:
    return f"remote::{adapter_id}"
@ -141,22 +141,22 @@ def remote_provider_spec(
        if adapter and adapter.config_class
        else "llama_toolchain.core.datatypes.RemoteProviderConfig"
    )
-    provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
+    provider_type = remote_provider_type(adapter.adapter_id) if adapter else "remote"
    return RemoteProviderSpec(
-        api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
+        api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
    )


@json_schema_type
class DistributionSpec(BaseModel):
-    distribution_id: str
+    distribution_type: str
    description: str
    docker_image: Optional[str] = None
    providers: Dict[Api, str] = Field(
        default_factory=dict,
-        description="Provider IDs for each of the APIs provided by this distribution",
+        description="Provider Types for each of the APIs provided by this distribution",
    )
@ -171,7 +171,7 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p
        this could be just a hash
        """,
    )
-    distribution_id: Optional[str] = None
+    distribution_type: Optional[str] = None

    docker_image: Optional[str] = Field(
        default=None,

View file

@ -83,18 +83,18 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
    inference_providers_by_id = {
-        a.provider_id: a for a in available_inference_providers()
+        a.provider_type: a for a in available_inference_providers()
    }
-    safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
+    safety_providers_by_id = {a.provider_type: a for a in available_safety_providers()}
    agentic_system_providers_by_id = {
-        a.provider_id: a for a in available_agentic_system_providers()
+        a.provider_type: a for a in available_agentic_system_providers()
    }

    ret = {
        Api.inference: inference_providers_by_id,
        Api.safety: safety_providers_by_id,
        Api.agentic_system: agentic_system_providers_by_id,
-        Api.memory: {a.provider_id: a for a in available_memory_providers()},
+        Api.memory: {a.provider_type: a for a in available_memory_providers()},
    }

    for k, v in ret.items():
        v["remote"] = remote_provider_spec(k)

View file

@ -14,7 +14,7 @@ from .datatypes import * # noqa: F403
def available_distribution_specs() -> List[DistributionSpec]:
    return [
        DistributionSpec(
-            distribution_id="local",
+            distribution_type="local",
            description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
            providers={
                Api.inference: "meta-reference",
@ -24,35 +24,35 @@ def available_distribution_specs() -> List[DistributionSpec]:
            },
        ),
        DistributionSpec(
-            distribution_id="remote",
+            distribution_type="remote",
            description="Point to remote services for all llama stack APIs",
            providers={x: "remote" for x in Api},
        ),
        DistributionSpec(
-            distribution_id="local-ollama",
+            distribution_type="local-ollama",
            description="Like local, but use ollama for running LLM inference",
            providers={
-                Api.inference: remote_provider_id("ollama"),
+                Api.inference: remote_provider_type("ollama"),
                Api.safety: "meta-reference",
                Api.agentic_system: "meta-reference",
                Api.memory: "meta-reference-faiss",
            },
        ),
        DistributionSpec(
-            distribution_id="local-plus-fireworks-inference",
+            distribution_type="local-plus-fireworks-inference",
            description="Use Fireworks.ai for running LLM inference",
            providers={
-                Api.inference: remote_provider_id("fireworks"),
+                Api.inference: remote_provider_type("fireworks"),
                Api.safety: "meta-reference",
                Api.agentic_system: "meta-reference",
                Api.memory: "meta-reference-faiss",
            },
        ),
        DistributionSpec(
-            distribution_id="local-plus-together-inference",
+            distribution_type="local-plus-together-inference",
            description="Use Together.ai for running LLM inference",
            providers={
-                Api.inference: remote_provider_id("together"),
+                Api.inference: remote_provider_type("together"),
                Api.safety: "meta-reference",
                Api.agentic_system: "meta-reference",
                Api.memory: "meta-reference-faiss",
@ -72,8 +72,8 @@ def available_distribution_specs() -> List[DistributionSpec]:


@lru_cache()
-def resolve_distribution_spec(distribution_id: str) -> Optional[DistributionSpec]:
+def resolve_distribution_spec(distribution_type: str) -> Optional[DistributionSpec]:
    for spec in available_distribution_specs():
-        if spec.distribution_id == distribution_id:
+        if spec.distribution_type == distribution_type:
            return spec
    return None

View file

@ -46,13 +46,13 @@ def build_package(
    api_inputs: List[ApiInput],
    build_type: BuildType,
    name: str,
-    distribution_id: Optional[str] = None,
+    distribution_type: Optional[str] = None,
    docker_image: Optional[str] = None,
):
-    if not distribution_id:
-        distribution_id = "adhoc"
+    if not distribution_type:
+        distribution_type = "adhoc"

-    build_dir = BUILDS_BASE_DIR / distribution_id / build_type.descriptor()
+    build_dir = BUILDS_BASE_DIR / distribution_type / build_type.descriptor()
    os.makedirs(build_dir, exist_ok=True)

    package_name = name.replace("::", "-")
@ -79,7 +79,7 @@ def build_package(
            if provider.docker_image:
                raise ValueError("A stack's dependencies cannot have a docker image")

-        stub_config[api.value] = {"provider_id": api_input.provider}
+        stub_config[api.value] = {"provider_type": api_input.provider}

    if package_file.exists():
        cprint(
@ -92,7 +92,7 @@ def build_package(
                c.providers[api_str] = new_config
            else:
                existing_config = c.providers[api_str]
-                if existing_config["provider_id"] != new_config["provider_id"]:
+                if existing_config["provider_type"] != new_config["provider_type"]:
                    cprint(
                        f"Provider `{api_str}` has changed from `{existing_config}` to `{new_config}`",
                        color="yellow",
@ -105,7 +105,7 @@ def build_package(
            providers=stub_config,
        )

-    c.distribution_id = distribution_id
+    c.distribution_type = distribution_type
    c.docker_image = package_name if build_type == BuildType.container else None
    c.conda_env = package_name if build_type == BuildType.conda_env else None
@ -119,7 +119,7 @@ def build_package(
        )
        args = [
            script,
-            distribution_id,
+            distribution_type,
            package_name,
            package_deps.docker_image,
            " ".join(package_deps.pip_packages),
@ -130,7 +130,7 @@ def build_package(
        )
        args = [
            script,
-            distribution_id,
+            distribution_type,
            package_name,
            " ".join(package_deps.pip_packages),
        ]

View file

@ -284,13 +284,13 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
    for api_str, provider_config in config["providers"].items():
        api = Api(api_str)
        providers = all_providers[api]

-        provider_id = provider_config["provider_id"]
-        if provider_id not in providers:
+        provider_type = provider_config["provider_type"]
+        if provider_type not in providers:
            raise ValueError(
-                f"Unknown provider `{provider_id}` is not available for API `{api}`"
+                f"Unknown provider `{provider_type}` is not available for API `{api}`"
            )
-        provider_specs[api] = providers[provider_id]
+        provider_specs[api] = providers[provider_type]

    impls = resolve_impls(provider_specs, config)

View file

@ -13,7 +13,7 @@ def available_inference_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.inference,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
            pip_packages=[
                "accelerate",
                "blobfile",

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_toolchain.core.datatypes import RemoteProviderConfig


async def get_adapter_impl(config: RemoteProviderConfig, _deps):
    from .chroma import ChromaMemoryAdapter

    impl = ChromaMemoryAdapter(config.url)
    await impl.initialize()
    return impl
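For context, the stack core calls this entry point when wiring up the `remote::chromadb` provider. A hand-driven sketch (the URL is a placeholder; `RemoteProviderConfig` is assumed to accept a `url` field, consistent with the `config.url` access above):

```python
import asyncio

from llama_toolchain.core.datatypes import RemoteProviderConfig


async def main():
    config = RemoteProviderConfig(url="http://localhost:6000")  # placeholder
    memory = await get_adapter_impl(config, {})
    print(type(memory).__name__)  # ChromaMemoryAdapter


asyncio.run(main())
```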

View file

@ -0,0 +1,165 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import uuid
from typing import List
from urllib.parse import urlparse

import chromadb
from numpy.typing import NDArray

from llama_toolchain.memory.api import *  # noqa: F403
from llama_toolchain.memory.common.vector_store import BankWithIndex, EmbeddingIndex


class ChromaIndex(EmbeddingIndex):
    def __init__(self, client: chromadb.AsyncHttpClient, collection):
        self.client = client
        self.collection = collection

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        assert len(chunks) == len(
            embeddings
        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"

        for i, chunk in enumerate(chunks):
            print(f"Adding chunk #{i} tokens={chunk.token_count}")

        await self.collection.add(
            documents=[chunk.json() for chunk in chunks],
            embeddings=embeddings,
            ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
        )

    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        results = await self.collection.query(
            query_embeddings=[embedding.tolist()],
            n_results=k,
            include=["documents", "distances"],
        )
        distances = results["distances"][0]
        documents = results["documents"][0]

        chunks = []
        scores = []
        for dist, doc in zip(distances, documents):
            try:
                doc = json.loads(doc)
                chunk = Chunk(**doc)
            except Exception:
                import traceback

                traceback.print_exc()
                print(f"Failed to parse document: {doc}")
                continue

            chunks.append(chunk)
            scores.append(1.0 / float(dist))

        return QueryDocumentsResponse(chunks=chunks, scores=scores)


class ChromaMemoryAdapter(Memory):
    def __init__(self, url: str) -> None:
        print(f"Initializing ChromaMemoryAdapter with url: {url}")
        url = url.rstrip("/")
        parsed = urlparse(url)

        if parsed.path and parsed.path != "/":
            raise ValueError("URL should not contain a path")

        self.host = parsed.hostname
        self.port = parsed.port

        self.client = None
        self.cache = {}

    async def initialize(self) -> None:
        try:
            print(f"Connecting to Chroma server at: {self.host}:{self.port}")
            self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
        except Exception as e:
            import traceback

            traceback.print_exc()
            raise RuntimeError("Could not connect to Chroma server") from e

    async def shutdown(self) -> None:
        pass

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        bank_id = str(uuid.uuid4())
        bank = MemoryBank(
            bank_id=bank_id,
            name=name,
            config=config,
            url=url,
        )
        collection = await self.client.create_collection(
            name=bank_id,
            metadata={"bank": bank.json()},
        )
        bank_index = BankWithIndex(
            bank=bank, index=ChromaIndex(self.client, collection)
        )
        self.cache[bank_id] = bank_index
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        bank_index = await self._get_and_cache_bank_index(bank_id)
        if bank_index is None:
            return None
        return bank_index.bank

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        if bank_id in self.cache:
            return self.cache[bank_id]

        collections = await self.client.list_collections()
        for collection in collections:
            if collection.name == bank_id:
                print(collection.metadata)
                bank = MemoryBank(**json.loads(collection.metadata["bank"]))
                index = BankWithIndex(
                    bank=bank,
                    index=ChromaIndex(self.client, collection),
                )
                self.cache[bank_id] = index
                return index

        return None

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        await index.insert_documents(documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        return await index.query_documents(query, params)

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import PGVectorConfig


async def get_adapter_impl(config: PGVectorConfig, _deps):
    from .pgvector import PGVectorMemoryAdapter

    impl = PGVectorMemoryAdapter(config)
    await impl.initialize()
    return impl

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class PGVectorConfig(BaseModel):
    host: str = Field(default="localhost")
    port: int = Field(default=5432)
    db: str
    user: str
    password: str
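For illustration, only the fields without defaults need to be supplied; `host` and `port` fall back to `localhost:5432` (values below are placeholders):

```python
config = PGVectorConfig(
    db="llamastack",  # placeholder database name
    user="postgres",  # placeholder credentials
    password="secret",
)
```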

View file

@ -0,0 +1,234 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from typing import List, Tuple

import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import execute_values, Json
from pydantic import BaseModel

from llama_toolchain.memory.api import *  # noqa: F403
from llama_toolchain.memory.common.vector_store import (
    ALL_MINILM_L6_V2_DIMENSION,
    BankWithIndex,
    EmbeddingIndex,
)

from .config import PGVectorConfig


def check_extension_version(cur):
    cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
    result = cur.fetchone()
    return result[0] if result else None


def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
    query = sql.SQL(
        """
        INSERT INTO metadata_store (key, data)
        VALUES %s
        ON CONFLICT (key) DO UPDATE
        SET data = EXCLUDED.data
        """
    )
    values = [(key, Json(model.dict())) for key, model in keys_models]
    execute_values(cur, query, values, template="(%s, %s)")


def load_models(cur, keys: List[str], cls):
    query = "SELECT key, data FROM metadata_store"
    if keys:
        placeholders = ",".join(["%s"] * len(keys))
        query += f" WHERE key IN ({placeholders})"
        cur.execute(query, keys)
    else:
        cur.execute(query)

    rows = cur.fetchall()
    return [cls(**row["data"]) for row in rows]


class PGVectorIndex(EmbeddingIndex):
    def __init__(self, bank: MemoryBank, dimension: int, cursor):
        self.cursor = cursor
        self.table_name = f"vector_store_{bank.name}"

        self.cursor.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id TEXT PRIMARY KEY,
                document JSONB,
                embedding vector({dimension})
            )
            """
        )

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        assert len(chunks) == len(
            embeddings
        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"

        values = []
        for i, chunk in enumerate(chunks):
            print(f"Adding chunk #{i} tokens={chunk.token_count}")
            values.append(
                (
                    f"{chunk.document_id}:chunk-{i}",
                    Json(chunk.dict()),
                    embeddings[i].tolist(),
                )
            )

        query = sql.SQL(
            f"""
            INSERT INTO {self.table_name} (id, document, embedding)
            VALUES %s
            ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
            """
        )
        execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")

    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        self.cursor.execute(
            f"""
            SELECT document, embedding <-> %s::vector AS distance
            FROM {self.table_name}
            ORDER BY distance
            LIMIT %s
            """,
            (embedding.tolist(), k),
        )
        results = self.cursor.fetchall()

        chunks = []
        scores = []
        for doc, dist in results:
            chunks.append(Chunk(**doc))
            scores.append(1.0 / float(dist))

        return QueryDocumentsResponse(chunks=chunks, scores=scores)


class PGVectorMemoryAdapter(Memory):
    def __init__(self, config: PGVectorConfig) -> None:
        print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
        self.config = config
        self.cursor = None
        self.conn = None
        self.cache = {}

    async def initialize(self) -> None:
        try:
            self.conn = psycopg2.connect(
                host=self.config.host,
                port=self.config.port,
                database=self.config.db,
                user=self.config.user,
                password=self.config.password,
            )
            self.cursor = self.conn.cursor()

            version = check_extension_version(self.cursor)
            if version:
                print(f"Vector extension version: {version}")
            else:
                raise RuntimeError("Vector extension is not installed.")

            self.cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS metadata_store (
                    key TEXT PRIMARY KEY,
                    data JSONB
                )
                """
            )
        except Exception as e:
            import traceback

            traceback.print_exc()
            raise RuntimeError("Could not connect to PGVector database server") from e

    async def shutdown(self) -> None:
        pass

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        bank_id = str(uuid.uuid4())
        bank = MemoryBank(
            bank_id=bank_id,
            name=name,
            config=config,
            url=url,
        )
        upsert_models(
            self.cursor,
            [
                (bank.bank_id, bank),
            ],
        )
        index = BankWithIndex(
            bank=bank,
            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
        )
        self.cache[bank_id] = index
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        bank_index = await self._get_and_cache_bank_index(bank_id)
        if bank_index is None:
            return None
        return bank_index.bank

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        if bank_id in self.cache:
            return self.cache[bank_id]

        banks = load_models(self.cursor, [bank_id], MemoryBank)
        if not banks:
            return None

        bank = banks[0]
        index = BankWithIndex(
            bank=bank,
            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
        )
        self.cache[bank_id] = index
        return index

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        await index.insert_documents(documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        return await index.query_documents(query, params)

View file

@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import httpx
import numpy as np
from numpy.typing import NDArray

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_toolchain.memory.api import *  # noqa: F403

ALL_MINILM_L6_V2_DIMENSION = 384

EMBEDDING_MODEL = None


def get_embedding_model() -> "SentenceTransformer":
    global EMBEDDING_MODEL

    if EMBEDDING_MODEL is None:
        print("Loading sentence transformer")

        from sentence_transformers import SentenceTransformer

        EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")

    return EMBEDDING_MODEL


async def content_from_doc(doc: MemoryBankDocument) -> str:
    if isinstance(doc.content, URL):
        async with httpx.AsyncClient() as client:
            r = await client.get(doc.content.uri)
            return r.text

    return interleaved_text_media_as_str(doc.content)


def make_overlapped_chunks(
    document_id: str, text: str, window_len: int, overlap_len: int
) -> List[Chunk]:
    tokenizer = Tokenizer.get_instance()
    tokens = tokenizer.encode(text, bos=False, eos=False)

    chunks = []
    for i in range(0, len(tokens), window_len - overlap_len):
        toks = tokens[i : i + window_len]
        chunk = tokenizer.decode(toks)
        chunks.append(
            Chunk(content=chunk, token_count=len(toks), document_id=document_id)
        )

    return chunks


class EmbeddingIndex(ABC):
    @abstractmethod
    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        raise NotImplementedError()

    @abstractmethod
    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        raise NotImplementedError()


@dataclass
class BankWithIndex:
    bank: MemoryBank
    index: EmbeddingIndex

    async def insert_documents(
        self,
        documents: List[MemoryBankDocument],
    ) -> None:
        model = get_embedding_model()
        for doc in documents:
            content = await content_from_doc(doc)
            chunks = make_overlapped_chunks(
                doc.document_id,
                content,
                self.bank.config.chunk_size_in_tokens,
                self.bank.config.overlap_size_in_tokens
                or (self.bank.config.chunk_size_in_tokens // 4),
            )
            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)

            await self.index.add_chunks(chunks, embeddings)

    async def query_documents(
        self,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        if params is None:
            params = {}
        k = params.get("max_chunks", 3)

        def _process(c) -> str:
            if isinstance(c, str):
                return c
            else:
                return "<media>"

        if isinstance(query, list):
            query_str = " ".join([_process(c) for c in query])
        else:
            query_str = _process(query)

        model = get_embedding_model()
        query_vector = model.encode([query_str])[0].astype(np.float32)
        return await self.index.query(query_vector, k)
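As a quick sanity check on the chunking arithmetic above: a new chunk starts every `window_len - overlap_len` tokens, so consecutive chunks share `overlap_len` tokens (illustrative numbers):

```python
text = open("example.txt").read()  # any long document

# window_len=512, overlap_len=64 -> a chunk starts every 448 tokens:
# tokens[0:512], tokens[448:960], tokens[896:1408], ...
chunks = make_overlapped_chunks("doc-1", text, window_len=512, overlap_len=64)
```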

View file

@ -5,108 +5,45 @@
# the root directory of this source tree.

import uuid
from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional

import faiss
-import httpx
import numpy as np
from numpy.typing import NDArray

from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_models.llama3.api.tokenizer import Tokenizer

from llama_toolchain.memory.api import *  # noqa: F403
+from llama_toolchain.memory.common.vector_store import (
+    ALL_MINILM_L6_V2_DIMENSION,
+    BankWithIndex,
+    EmbeddingIndex,
+)

from .config import FaissImplConfig


-async def content_from_doc(doc: MemoryBankDocument) -> str:
-    if isinstance(doc.content, URL):
-        async with httpx.AsyncClient() as client:
-            r = await client.get(doc.content.uri)
-            return r.text
-
-    return interleaved_text_media_as_str(doc.content)
-
-
-def make_overlapped_chunks(
-    text: str, window_len: int, overlap_len: int
-) -> List[Tuple[str, int]]:
-    tokenizer = Tokenizer.get_instance()
-    tokens = tokenizer.encode(text, bos=False, eos=False)
-
-    chunks = []
-    for i in range(0, len(tokens), window_len - overlap_len):
-        toks = tokens[i : i + window_len]
-        chunk = tokenizer.decode(toks)
-        chunks.append((chunk, len(toks)))
-
-    return chunks
-
-
-@dataclass
-class BankState:
-    bank: MemoryBank
-    index: Optional[faiss.IndexFlatL2] = None
-    doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
-    id_by_index: Dict[int, str] = field(default_factory=dict)
-    chunk_by_index: Dict[int, str] = field(default_factory=dict)
+class FaissIndex(EmbeddingIndex):
+    id_by_index: Dict[int, str]
+    chunk_by_index: Dict[int, str]

-    async def insert_documents(
-        self,
-        model: "SentenceTransformer",
-        documents: List[MemoryBankDocument],
-    ) -> None:
-        tokenizer = Tokenizer.get_instance()
-        chunk_size = self.bank.config.chunk_size_in_tokens
-        for doc in documents:
-            indexlen = len(self.id_by_index)
-            self.doc_by_id[doc.document_id] = doc
-            content = await content_from_doc(doc)
-            chunks = make_overlapped_chunks(
-                content,
-                self.bank.config.chunk_size_in_tokens,
-                self.bank.config.overlap_size_in_tokens
-                or (self.bank.config.chunk_size_in_tokens // 4),
-            )
-            embeddings = model.encode([x[0] for x in chunks]).astype(np.float32)
-            await self._ensure_index(embeddings.shape[1])
-            self.index.add(embeddings)
-            for i, chunk in enumerate(chunks):
-                self.chunk_by_index[indexlen + i] = Chunk(
-                    content=chunk[0],
-                    token_count=chunk[1],
-                    document_id=doc.document_id,
-                )
-                print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
-                self.id_by_index[indexlen + i] = doc.document_id
+    def __init__(self, dimension: int):
+        self.index = faiss.IndexFlatL2(dimension)
+        self.id_by_index = {}
+        self.chunk_by_index = {}

-    async def query_documents(
-        self,
-        model: "SentenceTransformer",
-        query: InterleavedTextMedia,
-        params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
-        if params is None:
-            params = {}
-        k = params.get("max_chunks", 3)
-
-        def _process(c) -> str:
-            if isinstance(c, str):
-                return c
-            else:
-                return "<media>"
-
-        if isinstance(query, list):
-            query_str = " ".join([_process(c) for c in query])
-        else:
-            query_str = _process(query)
-
-        query_vector = model.encode([query_str])[0]
+    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        indexlen = len(self.id_by_index)
+        for i, chunk in enumerate(chunks):
+            self.chunk_by_index[indexlen + i] = chunk
+            print(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
+            self.id_by_index[indexlen + i] = chunk.document_id
+
+        self.index.add(np.array(embeddings).astype(np.float32))
+
+    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        distances, indices = self.index.search(
-            query_vector.reshape(1, -1).astype(np.float32), k
+            embedding.reshape(1, -1).astype(np.float32), k
        )

        chunks = []
@ -119,17 +56,11 @@ class BankState:
        return QueryDocumentsResponse(chunks=chunks, scores=scores)

-    async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
-        if self.index is None:
-            self.index = faiss.IndexFlatL2(dimension)
-        return self.index
-

class FaissMemoryImpl(Memory):
    def __init__(self, config: FaissImplConfig) -> None:
        self.config = config
-        self.model = None
-        self.states = {}
+        self.cache = {}

    async def initialize(self) -> None: ...
@ -153,14 +84,15 @@ class FaissMemoryImpl(Memory):
            config=config,
            url=url,
        )

-        state = BankState(bank=bank)
-        self.states[bank_id] = state
+        index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
+        self.cache[bank_id] = index
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        if bank_id not in self.states:
+        index = self.cache.get(bank_id)
+        if index is None:
            return None
-        return self.states[bank_id].bank
+        return index.bank

    async def insert_documents(
        self,
@ -168,10 +100,11 @@ class FaissMemoryImpl(Memory):
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
-        assert bank_id in self.states, f"Bank {bank_id} not found"
-        state = self.states[bank_id]
+        index = self.cache.get(bank_id)
+        if index is None:
+            raise ValueError(f"Bank {bank_id} not found")

-        await state.insert_documents(self.get_model(), documents)
+        await index.insert_documents(documents)

    async def query_documents(
        self,
@ -179,16 +112,8 @@ class FaissMemoryImpl(Memory):
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
-        assert bank_id in self.states, f"Bank {bank_id} not found"
-        state = self.states[bank_id]
+        index = self.cache.get(bank_id)
+        if index is None:
+            raise ValueError(f"Bank {bank_id} not found")

-        return await state.query_documents(self.get_model(), query, params)
-
-    def get_model(self) -> "SentenceTransformer":
-        from sentence_transformers import SentenceTransformer
-
-        if self.model is None:
-            print("Loading sentence transformer")
-            self.model = SentenceTransformer("all-MiniLM-L6-v2")
-        return self.model
+        return await index.query_documents(query, params)

View file

@ -6,20 +6,38 @@
from typing import List

-from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_toolchain.core.datatypes import *  # noqa: F403

+EMBEDDING_DEPS = [
+    "blobfile",
+    "sentence-transformers",
+]


def available_memory_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.memory,
-            provider_id="meta-reference-faiss",
-            pip_packages=[
-                "blobfile",
-                "faiss-cpu",
-                "sentence-transformers",
-            ],
+            provider_type="meta-reference-faiss",
+            pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
            module="llama_toolchain.memory.meta_reference.faiss",
            config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
        ),
+        remote_provider_spec(
+            api=Api.memory,
+            adapter=AdapterSpec(
+                adapter_id="chromadb",
+                pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
+                module="llama_toolchain.memory.adapters.chroma",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.memory,
+            adapter=AdapterSpec(
+                adapter_id="pgvector",
+                pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
+                module="llama_toolchain.memory.adapters.pgvector",
+                config_class="llama_toolchain.memory.adapters.pgvector.PGVectorConfig",
+            ),
+        ),
    ]
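For reference, `remote_provider_spec` (updated earlier in this diff) derives the provider type from the adapter id, so the two adapters above register as `remote::chromadb` and `remote::pgvector`:

```python
spec = remote_provider_spec(
    api=Api.memory,
    adapter=AdapterSpec(
        adapter_id="pgvector",
        pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
        module="llama_toolchain.memory.adapters.pgvector",
    ),
)
assert spec.provider_type == "remote::pgvector"
```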

View file

@ -13,7 +13,7 @@ def available_safety_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.safety,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
            pip_packages=[
                "accelerate",
                "codeshield",

View file

@ -11,7 +11,7 @@ from llama_toolchain.evaluations.api import * # noqa: F403
from llama_toolchain.inference.api import *  # noqa: F403
from llama_toolchain.batch_inference.api import *  # noqa: F403
from llama_toolchain.memory.api import *  # noqa: F403
-from llama_toolchain.observability.api import *  # noqa: F403
+from llama_toolchain.telemetry.api import *  # noqa: F403
from llama_toolchain.post_training.api import *  # noqa: F403
from llama_toolchain.reward_scoring.api import *  # noqa: F403
from llama_toolchain.synthetic_data_generation.api import *  # noqa: F403
@ -24,7 +24,7 @@ class LlamaStack(
    RewardScoring,
    SyntheticDataGeneration,
    Datasets,
-    Observability,
+    Telemetry,
    PostTraining,
    Memory,
    Evaluations,

View file

@ -134,7 +134,7 @@ class LogSearchRequest(BaseModel):
    filters: Optional[Dict[str, Any]] = None


-class Observability(Protocol):
+class Telemetry(Protocol):
    @webmethod(route="/experiments/create")
    def create_experiment(self, request: CreateExperimentRequest) -> Experiment: ...

View file

@ -2,7 +2,7 @@ blobfile
fire
httpx
huggingface-hub
-llama-models
+llama-models>=0.0.13
pydantic
requests
termcolor

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@ -35,7 +35,10 @@ from llama_toolchain.stack import LlamaStack
# TODO: this should be fixed in the generator itself so it reads appropriate annotations

-STREAMING_ENDPOINTS = ["/agentic_system/turn/create"]
+STREAMING_ENDPOINTS = [
+    "/agentic_system/turn/create",
+    "/inference/chat_completion",
+]


def patch_sse_stream_responses(spec: Specification):

View file

@ -468,12 +468,14 @@ class Generator:
            builder = ContentBuilder(self.schema_builder)
            first = next(iter(op.request_params))
            request_name, request_type = first

-            if len(op.request_params) == 1 and "Request" in first[1].__name__:
-                # TODO(ashwin): Undo the "Request" hack and this entire block eventually
-                request_name, request_type = first
-            else:
-                from dataclasses import make_dataclass
+            from dataclasses import make_dataclass
+
+            if len(op.request_params) == 1 and "Request" in first[1].__name__:
+                # TODO(ashwin): Undo the "Request" hack and this entire block eventually
+                request_name = first[1].__name__ + "Wrapper"
+                request_type = make_dataclass(request_name, op.request_params)
+            else:
                op_name = "".join(word.capitalize() for word in op.name.split("_"))
                request_name = f"{op_name}Request"
                request_type = make_dataclass(request_name, op.request_params)

View file

@ -28,4 +28,4 @@ if [ ${#missing_packages[@]} -ne 0 ]; then
  exit 1
fi

-PYTHONPATH=$PYTHONPATH:../.. python3 -m rfcs.openapi_generator.generate $*
+PYTHONPATH=$PYTHONPATH:../.. python -m rfcs.openapi_generator.generate $*