llama-stack-mirror/llama_stack/distribution/store/registry.py
Ashwin Bharambe 1a7490470a
[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

Third part:
- we need to make `tool_runtime.rag_tool.query_context()` and
`tool_runtime.rag_tool.insert_documents()` methods work smoothly with
complete type safety. To that end, we introduce a sub-resource path
`tool-runtime/rag-tool/` and make changes to the resolver to make things
work.
- the PR updates the agents implementation to call these typed APIs
directly for memory accesses rather than going through the complex, untyped
"invoke_tool" API. The code looks much nicer and simpler (as expected).
- there are still a number of hacks in the server resolver implementation;
we will live with some and fix others.

Note that we must also make sure the client SDKs can handle this
subresource complexity. Stainless supports subresources, so
this should be possible, but beware.

## Test Plan

Our RAG test is weak (it doesn't actually check the RAG output), but I
verified that the implementation works. I will work on fixing the RAG
test afterwards.

```bash
pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B
```
2025-01-22 10:04:16 -08:00

206 lines
7.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from contextlib import asynccontextmanager
from typing import Dict, List, Optional, Protocol, Tuple
import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class DistributionRegistry(Protocol):
    """Structural interface for a registry of routable objects.

    Objects are addressed by the pair ``(type, identifier)``. Both lookup
    methods take that pair, matching ``DiskDistributionRegistry`` and
    ``CachedDiskDistributionRegistry`` in this module.
    """

    async def get_all(self) -> List[RoutableObjectWithProvider]: ...

    async def initialize(self) -> None: ...

    async def get(
        self, type: str, identifier: str
    ) -> Optional[RoutableObjectWithProvider]: ...

    def get_cached(
        self, type: str, identifier: str
    ) -> Optional[RoutableObjectWithProvider]: ...

    async def update(
        self, obj: RoutableObjectWithProvider
    ) -> RoutableObjectWithProvider: ...

    # Returns True if the object was stored, False if registration was skipped.
    async def register(self, obj: RoutableObjectWithProvider) -> bool: ...

    async def delete(self, type: str, identifier: str) -> None: ...
# Namespace shared by every registry key in the KV store.
REGISTER_PREFIX = "distributions:registry"
# Schema version embedded in every key; range queries only cover this version.
KEY_VERSION = "v6"
# Full key template: "<prefix>:<version>::<type>:<identifier>".
KEY_FORMAT = REGISTER_PREFIX + ":" + KEY_VERSION + "::{type}:{identifier}"
def _get_registry_key_range() -> Tuple[str, str]:
    """Return the ``(start, end)`` bounds for a range query over the registry.

    Both bounds share the ``<prefix>:<version>`` stem; the end bound appends the
    character U+00FF. Keys built from KEY_FORMAT continue the stem with "::",
    which sorts below U+00FF. So they fall inside the range, presumably
    assuming the KV store compares keys lexicographically — TODO confirm per
    backend.
    """
    stem = ":".join((REGISTER_PREFIX, KEY_VERSION))
    upper = stem + "\xff"
    return (stem, upper)
def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
    """Deserialize raw registry JSON strings into RoutableObjectWithProvider objects.

    Args:
        values: JSON strings, e.g. as returned by a registry range query.

    Returns:
        The parsed objects, in the same order as ``values``.

    Raises:
        Whatever ``TypeAdapter.validate_json`` raises for an invalid value
        (pydantic's validation error); one bad value aborts the whole parse.
    """
    # Constructing a TypeAdapter builds a validation schema, so build it once
    # instead of once per value.
    adapter = pydantic.TypeAdapter(RoutableObjectWithProvider)
    return [adapter.validate_json(value) for value in values]
class DiskDistributionRegistry(DistributionRegistry):
    """Registry that stores each object as JSON in a KVStore, without caching.

    Every object lives under ``KEY_FORMAT.format(type=..., identifier=...)``,
    and every read goes straight to the KV store.
    """

    def __init__(self, kvstore: KVStore):
        self.kvstore = kvstore

    @staticmethod
    def _registry_key(obj_type: str, identifier: str) -> str:
        """Build the KV-store key for the object with this type and identifier."""
        return KEY_FORMAT.format(type=obj_type, identifier=identifier)

    async def initialize(self) -> None:
        """No-op: nothing to set up beyond the KVStore passed to __init__."""
        pass

    def get_cached(
        self, type: str, identifier: str
    ) -> Optional[RoutableObjectWithProvider]:
        """Not supported by this class.

        Raises:
            NotImplementedError: always, since the disk registry has no cache.
        """
        raise NotImplementedError("Disk registry does not have a cache")

    async def get_all(self) -> List[RoutableObjectWithProvider]:
        """Read and parse every entry stored under the current registry key version."""
        start_key, end_key = _get_registry_key_range()
        values = await self.kvstore.range(start_key, end_key)
        return _parse_registry_values(values)

    async def get(
        self, type: str, identifier: str
    ) -> Optional[RoutableObjectWithProvider]:
        """Read one object from the KV store.

        Returns:
            The parsed object, or None when the store returns a falsy value
            (e.g. None or an empty string) for the key.
        """
        json_str = await self.kvstore.get(self._registry_key(type, identifier))
        if not json_str:
            return None
        return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)

    async def update(
        self, obj: RoutableObjectWithProvider
    ) -> RoutableObjectWithProvider:
        """Unconditionally write ``obj`` to the KV store and return it."""
        await self.kvstore.set(
            self._registry_key(obj.type, obj.identifier),
            obj.model_dump_json(),
        )
        return obj

    async def register(self, obj: RoutableObjectWithProvider) -> bool:
        """Store ``obj`` unless the same provider already registered it.

        The existence check goes through ``self.get``, so a subclass that
        overrides ``get`` controls where the lookup is served from. An existing
        object from a *different* provider is overwritten.

        Returns:
            False if an object with the same key and provider_id already exists,
            otherwise True after writing ``obj``.
        """
        existing_obj = await self.get(obj.type, obj.identifier)
        # Don't register if this provider already owns an object with this key.
        if existing_obj and existing_obj.provider_id == obj.provider_id:
            return False
        await self.kvstore.set(
            self._registry_key(obj.type, obj.identifier),
            obj.model_dump_json(),
        )
        return True

    async def delete(self, type: str, identifier: str) -> None:
        """Delete the stored entry for ``(type, identifier)`` from the KV store."""
        await self.kvstore.delete(self._registry_key(type, identifier))
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
    """DiskDistributionRegistry with a write-through in-memory cache.

    The cache is filled from the KV store on first use, or by ``initialize()``.
    After that, ``get``/``get_all`` are served from memory. ``register``,
    ``update`` and ``delete`` write to the KV store first, then to the cache.
    """

    def __init__(self, kvstore: KVStore):
        super().__init__(kvstore)
        # (type, identifier) -> object, mirroring the entries on disk once loaded.
        self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {}
        self._initialized = False
        # Serializes the one-time load performed by _ensure_initialized().
        self._initialize_lock = asyncio.Lock()
        # Guards access to self.cache via _locked_cache().
        self._cache_lock = asyncio.Lock()

    @asynccontextmanager
    async def _locked_cache(self):
        """Context manager for safely accessing the cache with a lock."""
        async with self._cache_lock:
            yield self.cache

    async def _ensure_initialized(self):
        """Load all registry entries from disk into the cache exactly once."""
        if self._initialized:
            return
        async with self._initialize_lock:
            # Another task may have finished the load while we waited for the lock.
            if self._initialized:
                return
            start_key, end_key = _get_registry_key_range()
            values = await self.kvstore.range(start_key, end_key)
            objects = _parse_registry_values(values)
            async with self._locked_cache() as cache:
                for obj in objects:
                    cache_key = (obj.type, obj.identifier)
                    cache[cache_key] = obj
            self._initialized = True

    async def initialize(self) -> None:
        """Eagerly load the cache from disk; a no-op after the first successful load."""
        await self._ensure_initialized()

    def get_cached(
        self, type: str, identifier: str
    ) -> Optional[RoutableObjectWithProvider]:
        """Synchronously look up ``(type, identifier)`` in the cache.

        This does not trigger the initial load, so it may return None for
        stored objects until the cache has been populated (e.g. by initialize()).
        """
        return self.cache.get((type, identifier), None)

    async def get_all(self) -> List[RoutableObjectWithProvider]:
        """Return a new list of every cached object, loading the cache first if needed."""
        await self._ensure_initialized()
        async with self._locked_cache() as cache:
            return list(cache.values())

    async def get(
        self, type: str, identifier: str
    ) -> Optional[RoutableObjectWithProvider]:
        """Return the cached object for ``(type, identifier)``, or None if absent."""
        await self._ensure_initialized()
        cache_key = (type, identifier)
        async with self._locked_cache() as cache:
            return cache.get(cache_key, None)

    async def register(self, obj: RoutableObjectWithProvider) -> bool:
        """Register ``obj`` on disk and, if that succeeded, add it to the cache.

        Returns:
            The result of ``DiskDistributionRegistry.register``.
        """
        await self._ensure_initialized()
        success = await super().register(obj)
        if success:
            cache_key = (obj.type, obj.identifier)
            async with self._locked_cache() as cache:
                cache[cache_key] = obj
        return success

    async def update(
        self, obj: RoutableObjectWithProvider
    ) -> RoutableObjectWithProvider:
        """Write ``obj`` to disk, then to the cache, and return it."""
        # Wait for any in-flight initial load. That load reads disk, awaits, and
        # only then fills the cache, so an update landing mid-load would
        # otherwise be overwritten by its stale snapshot.
        await self._ensure_initialized()
        await super().update(obj)
        cache_key = (obj.type, obj.identifier)
        async with self._locked_cache() as cache:
            cache[cache_key] = obj
        return obj

    async def delete(self, type: str, identifier: str) -> None:
        """Delete ``(type, identifier)`` from disk, then drop it from the cache."""
        # Same reasoning as update(): without this, an in-flight initial load
        # could re-insert the entry from its stale snapshot after we delete it.
        await self._ensure_initialized()
        await super().delete(type, identifier)
        cache_key = (type, identifier)
        async with self._locked_cache() as cache:
            if cache_key in cache:
                del cache[cache_key]
async def create_dist_registry(
    metadata_store: Optional[KVStoreConfig],
    image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
    """Create the KV store for distribution metadata and an initialized registry on top of it.

    Args:
        metadata_store: KV store configuration to use. When falsy, a SQLite
            store at ``DISTRIBS_BASE_DIR / image_name / "kvstore.db"`` is used.
        image_name: Distribution image name; selects the default SQLite path.

    Returns:
        ``(registry, kvstore)`` where the registry has already been initialized.
    """
    store_config = metadata_store
    if not store_config:
        # No explicit store configured: fall back to a per-image SQLite database.
        sqlite_path = DISTRIBS_BASE_DIR / image_name / "kvstore.db"
        store_config = SqliteKVStoreConfig(db_path=sqlite_path.as_posix())
    kvstore = await kvstore_impl(store_config)

    registry = CachedDiskDistributionRegistry(kvstore)
    await registry.initialize()
    return registry, kvstore