llama-stack-mirror/llama_stack/distribution/store/registry.py
Ashwin Bharambe 983d6ce2df
Remove the "ShieldType" concept (#430)
# What does this PR do?

This PR kills the notion of "ShieldType". The impetus for this is the
realization:

> Why is keyword llama-guard appearing so many times everywhere,
sometimes with hyphens, sometimes with underscores?

Now that we have a notion of "provider-specific resource identifiers" and
"user-specific aliases" for them, and this already works for models
("Llama3.1-8B-Instruct" <> "fireworks/llama-3pv1-..."), we can apply the
same rules to Shields.

So each Safety provider can define its own identifiers for the shields it
has registered. Bedrock already does this correctly; we simply generalize
the approach to Llama Guard, Prompt Guard, etc.

For Llama Guard, we further simplify by just adopting the underlying
model name itself as the identifier! No confusion necessary.

While doing this, I noticed a bug in our DistributionRegistry: identifiers
were not scoped by object type. This is now fixed.

## Feature/Issue validation/testing/test plan

Ran (inference, safety, memory, agents) tests with ollama and fireworks
providers.
2024-11-12 12:37:24 -08:00

169 lines
5.8 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Dict, List, Optional, Protocol, Tuple
import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import (
KVStore,
kvstore_impl,
SqliteKVStoreConfig,
)
class DistributionRegistry(Protocol):
    """Structural interface for a registry of routable objects.

    Objects are looked up by the pair ``(type, identifier)``, and each pair
    maps to a list of objects (one per provider).
    """

    async def get_all(self) -> List[RoutableObjectWithProvider]: ...

    async def initialize(self) -> None: ...

    async def get(
        self, type: str, identifier: str
    ) -> List[RoutableObjectWithProvider]: ...

    def get_cached(
        self, type: str, identifier: str
    ) -> List[RoutableObjectWithProvider]: ...

    # The current data structure allows multiple objects with the same identifier but different providers.
    # This is not ideal - we should have a single object that can be served by multiple providers,
    # suggesting a data structure like (obj: Obj, providers: List[str]) rather than List[RoutableObjectWithProvider].
    # The current approach could lead to inconsistencies if the same logical object has different data across providers.
    async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
KEY_VERSION = "v1"
# Every registry entry is stored under this prefix; the remainder of the key
# is "<type>:<identifier>".
_KEY_PREFIX = f"distributions:registry:{KEY_VERSION}::"
KEY_FORMAT = _KEY_PREFIX + "{type}:{identifier}"


def _registry_key_range() -> Tuple[str, str]:
    """Return inclusive (start, end) bounds covering every registry key.

    The bounds are derived from the bare prefix. Formatting KEY_FORMAT with
    empty fields would give bounds ending in ':::' (plus a trailing 0xFF for
    the end). Because ':' sorts before letters, real keys such as
    '...v1::model:x' would sort after that end bound and be skipped.
    NOTE(review): assumes KVStore.range compares keys lexicographically and
    inclusively, as the 0xFF sentinel implies.
    """
    return _KEY_PREFIX, _KEY_PREFIX + "\xff"


def _parse_registry_key(key: str) -> Tuple[str, str]:
    """Split a registry key into its ``(type, identifier)`` parts.

    Only the first ':' after the prefix separates type from identifier, so an
    identifier that itself contains ':' is preserved intact. Splitting on every
    ':' would break such identifiers.
    """
    _, _, rest = key.partition(_KEY_PREFIX)
    obj_type, _, identifier = rest.partition(":")
    return obj_type, identifier


class DiskDistributionRegistry(DistributionRegistry):
    """Registry that persists objects in a KVStore, with no in-memory cache.

    Each ``(type, identifier)`` pair is stored under one key. Its value is a
    JSON list of per-object JSON strings, one entry per provider.
    """

    def __init__(self, kvstore: KVStore):
        self.kvstore = kvstore

    async def initialize(self) -> None:
        pass

    def get_cached(
        self, type: str, identifier: str
    ) -> List[RoutableObjectWithProvider]:
        # Disk registry does not have a cache
        return []

    async def _scan_keys(self) -> List[Tuple[str, str]]:
        """Return the ``(type, identifier)`` pair of every stored registry key."""
        start_key, end_key = _registry_key_range()
        # NOTE(review): the results of KVStore.range() are treated as keys
        # here; confirm the kvstore implementation does not return values.
        keys = await self.kvstore.range(start_key, end_key)
        return [_parse_registry_key(key) for key in keys]

    async def get_all(self) -> List[RoutableObjectWithProvider]:
        """Return every registered object, flattened across all keys."""
        objects: List[RoutableObjectWithProvider] = []
        for obj_type, identifier in await self._scan_keys():
            objects.extend(await self.get(obj_type, identifier))
        return objects

    async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
        """Load the objects registered under ``(type, identifier)``.

        Returns an empty list when nothing is stored for the key.
        """
        json_str = await self.kvstore.get(
            KEY_FORMAT.format(type=type, identifier=identifier)
        )
        if not json_str:
            return []
        # TypeAdapter.validate_python is the pydantic v2 replacement for the
        # deprecated parse_obj_as. The stored value is a JSON list of JSON
        # strings (see register), hence the two-level decode.
        adapter = pydantic.TypeAdapter(RoutableObjectWithProvider)
        return [
            adapter.validate_python(json.loads(obj_str))
            for obj_str in json.loads(json_str)
        ]

    async def register(self, obj: RoutableObjectWithProvider) -> bool:
        """Persist ``obj`` unless its provider already has an entry for the key.

        Returns:
            True if the object was written. False if an object with the same
            ``provider_id`` is already registered under ``(type, identifier)``.
        """
        existing_objects = await self.get(obj.type, obj.identifier)
        # Don't register if the object's provider_id already exists.
        if any(existing.provider_id == obj.provider_id for existing in existing_objects):
            return False
        # Build a new list instead of appending in place. A subclass's get()
        # may return its cached list, which must not change before (or
        # without) a successful write.
        updated_objects = [*existing_objects, obj]
        await self.kvstore.set(
            KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
            json.dumps([o.model_dump_json() for o in updated_objects]),
        )
        return True


class CachedDiskDistributionRegistry(DiskDistributionRegistry):
    """DiskDistributionRegistry with an in-memory ``(type, identifier)`` cache.

    The cache is filled from disk by initialize() and on get() misses. It is
    updated after each successful register().
    """

    def __init__(self, kvstore: KVStore):
        super().__init__(kvstore)
        self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {}

    async def initialize(self) -> None:
        """Load every persisted registry entry into the cache."""
        for obj_type, identifier in await self._scan_keys():
            objects = await super().get(obj_type, identifier)
            if objects:
                self.cache[obj_type, identifier] = objects

    def get_cached(
        self, type: str, identifier: str
    ) -> List[RoutableObjectWithProvider]:
        """Return the cached objects for the key without touching disk."""
        return self.cache.get((type, identifier), [])

    async def get_all(self) -> List[RoutableObjectWithProvider]:
        """Return every cached object, flattened across all keys."""
        return [item for sublist in self.cache.values() for item in sublist]

    async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
        """Return objects for the key, reading from disk only on a cache miss."""
        cachekey = (type, identifier)
        if cachekey in self.cache:
            return self.cache[cachekey]
        objects = await super().get(type, identifier)
        # Only non-empty results are cached, so a later registration is not
        # hidden behind a cached empty list.
        if objects:
            self.cache[cachekey] = objects
        return objects

    async def register(self, obj: RoutableObjectWithProvider) -> bool:
        """Register ``obj`` on disk first, then mirror it into the cache."""
        success = await super().register(obj)
        if success:
            cached = self.cache.setdefault((obj.type, obj.identifier), [])
            if not any(c.provider_id == obj.provider_id for c in cached):
                cached.append(obj)
        return success
async def create_dist_registry(
    metadata_store: Optional[KVStoreConfig],
    image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
    """Build the distribution registry together with its backing KV store.

    Args:
        metadata_store: Explicit KV store configuration. When falsy, a SQLite
            store at ``DISTRIBS_BASE_DIR / image_name / "kvstore.db"`` is used.
        image_name: Distribution image name; used only to derive the default
            SQLite database path.

    Returns:
        A ``(registry, kvstore)`` pair in which the registry wraps that store.
    """
    store_config = metadata_store
    if not store_config:
        # No explicit configuration: fall back to a per-image SQLite database.
        default_db = DISTRIBS_BASE_DIR / image_name / "kvstore.db"
        store_config = SqliteKVStoreConfig(db_path=default_db.as_posix())
    dist_kvstore = await kvstore_impl(store_config)
    registry = CachedDiskDistributionRegistry(dist_kvstore)
    return registry, dist_kvstore