From b94e6c0bd40e1d27777a422b2ae4dba68fd32bf7 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 12 Nov 2024 12:25:43 -0800 Subject: [PATCH] Fix Registry so it scopes keys by object types --- .github/PULL_REQUEST_TEMPLATE.md | 3 +- docs/_deprecating_soon.ipynb | 4 +- llama_stack/apis/safety/client.py | 4 +- .../distribution/routers/routing_tables.py | 29 ++++---- llama_stack/distribution/store/registry.py | 66 +++++++++++-------- .../providers/tests/agents/fixtures.py | 2 +- .../providers/tests/safety/fixtures.py | 2 + 7 files changed, 63 insertions(+), 47 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 79701d926..fb02dd136 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -4,7 +4,8 @@ In short, provide a summary of what this PR does and why. Usually, the relevant - [ ] Addresses issue (#issue) -## Feature/Issue validation/testing/test plan + +## Test Plan Please describe: - tests you ran to verify your changes with result summaries. diff --git a/docs/_deprecating_soon.ipynb b/docs/_deprecating_soon.ipynb index 343005962..7fa4034ce 100644 --- a/docs/_deprecating_soon.ipynb +++ b/docs/_deprecating_soon.ipynb @@ -180,8 +180,8 @@ " tools=tools,\n", " tool_choice=\"auto\",\n", " tool_prompt_format=\"json\",\n", - " input_shields=[\"llama_guard\"],\n", - " output_shields=[\"llama_guard\"],\n", + " input_shields=[\"Llama-Guard-3-1B\"],\n", + " output_shields=[\"Llama-Guard-3-1B\"],\n", " enable_session_persistence=True,\n", " )\n", "\n", diff --git a/llama_stack/apis/safety/client.py b/llama_stack/apis/safety/client.py index 96168fedd..d7d4bc981 100644 --- a/llama_stack/apis/safety/client.py +++ b/llama_stack/apis/safety/client.py @@ -27,7 +27,7 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety: def encodable_dict(d: BaseModel): - return json.loads(d.json()) + return json.loads(d.model_dump_json()) class SafetyClient(Safety): @@ -80,7 +80,7 @@ async def run_main(host: str, port: int, image_path: str = None): ) cprint(f"User>{message.content}", "green") response = await client.run_shield( - shield_id="llama_guard", + shield_id="Llama-Guard-3-1B", messages=[message], ) print(response) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 0ca798641..d6fb5d662 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -136,17 +136,18 @@ class CommonRoutingTableImpl(RoutingTable): else: raise ValueError("Unknown routing table type") + apiname, objtype = apiname_object() + # Get objects from disk registry - objects = self.dist_registry.get_cached(routing_key) + objects = self.dist_registry.get_cached(objtype, routing_key) if not objects: - apiname, objname = apiname_object() provider_ids = list(self.impls_by_provider_id.keys()) if len(provider_ids) > 1: provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" else: provider_ids_str = f"provider: `{provider_ids[0]}`" raise ValueError( - f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}." + f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." ) for obj in objects: @@ -156,19 +157,19 @@ class CommonRoutingTableImpl(RoutingTable): raise ValueError(f"Provider not found for `{routing_key}`") async def get_object_by_identifier( - self, identifier: str + self, type: str, identifier: str ) -> Optional[RoutableObjectWithProvider]: # Get from disk registry - objects = await self.dist_registry.get(identifier) + objects = await self.dist_registry.get(type, identifier) if not objects: return None - # kind of ill-defined behavior here, but we'll just return the first one + assert len(objects) == 1 return objects[0] async def register_object(self, obj: RoutableObjectWithProvider): # Get existing objects from registry - existing_objects = await self.dist_registry.get(obj.identifier) + existing_objects = await self.dist_registry.get(obj.type, obj.identifier) # Check for existing registration for existing_obj in existing_objects: @@ -200,7 +201,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): return await self.get_all_with_type("model") async def get_model(self, identifier: str) -> Optional[Model]: - return await self.get_object_by_identifier(identifier) + return await self.get_object_by_identifier("model", identifier) async def register_model( self, @@ -236,7 +237,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): return await self.get_all_with_type(ResourceType.shield.value) async def get_shield(self, identifier: str) -> Optional[Shield]: - return await self.get_object_by_identifier(identifier) + return await self.get_object_by_identifier("shield", identifier) async def register_shield( self, @@ -272,7 +273,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): return await self.get_all_with_type(ResourceType.memory_bank.value) async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: - return await self.get_object_by_identifier(memory_bank_id) + return await self.get_object_by_identifier("memory_bank", memory_bank_id) async def register_memory_bank( self, @@ -310,7 +311,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): return await self.get_all_with_type("dataset") async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: - return await self.get_object_by_identifier(dataset_id) + return await self.get_object_by_identifier("dataset", dataset_id) async def register_dataset( self, @@ -346,10 +347,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> List[ScoringFn]: - return await self.get_all_with_type(ResourceType.scoring_function.value) + return await self.get_all_with_type("scoring_function") async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: - return await self.get_object_by_identifier(scoring_fn_id) + return await self.get_object_by_identifier("scoring_function", scoring_fn_id) async def register_scoring_function( self, @@ -387,7 +388,7 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks): return await self.get_all_with_type("eval_task") async def get_eval_task(self, name: str) -> Optional[EvalTask]: - return await self.get_object_by_identifier(name) + return await self.get_object_by_identifier("eval_task", name) async def register_eval_task( self, diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 6115ea1b3..d837c4375 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import json -from typing import Dict, List, Optional, Protocol +from typing import Dict, List, Optional, Protocol, Tuple import pydantic @@ -35,7 +35,8 @@ class DistributionRegistry(Protocol): async def register(self, obj: RoutableObjectWithProvider) -> bool: ... -KEY_FORMAT = "distributions:registry:v1::{}" +KEY_VERSION = "v1" +KEY_FORMAT = f"distributions:registry:{KEY_VERSION}::" + "{type}:{identifier}" class DiskDistributionRegistry(DistributionRegistry): @@ -45,18 +46,24 @@ class DiskDistributionRegistry(DistributionRegistry): async def initialize(self) -> None: pass - def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: + def get_cached( + self, type: str, identifier: str + ) -> List[RoutableObjectWithProvider]: # Disk registry does not have a cache return [] async def get_all(self) -> List[RoutableObjectWithProvider]: - start_key = KEY_FORMAT.format("") - end_key = KEY_FORMAT.format("\xff") + start_key = KEY_FORMAT.format(type="", identifier="") + end_key = KEY_FORMAT.format(type="", identifier="\xff") keys = await self.kvstore.range(start_key, end_key) - return [await self.get(key.split(":")[-1]) for key in keys] - async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: - json_str = await self.kvstore.get(KEY_FORMAT.format(identifier)) + tuples = [(key.split(":")[-2], key.split(":")[-1]) for key in keys] + return [await self.get(type, identifier) for type, identifier in tuples] + + async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: + json_str = await self.kvstore.get( + KEY_FORMAT.format(type=type, identifier=identifier) + ) if not json_str: return [] @@ -70,7 +77,7 @@ class DiskDistributionRegistry(DistributionRegistry): ] async def register(self, obj: RoutableObjectWithProvider) -> bool: - existing_objects = await self.get(obj.identifier) + existing_objects = await self.get(obj.type, obj.identifier) # dont register if the object's providerid already exists for eobj in existing_objects: if eobj.provider_id == obj.provider_id: @@ -82,7 +89,8 @@ class DiskDistributionRegistry(DistributionRegistry): obj.model_dump_json() for obj in existing_objects ] # Fixed variable name await self.kvstore.set( - KEY_FORMAT.format(obj.identifier), json.dumps(objects_json) + KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), + json.dumps(objects_json), ) return True @@ -90,33 +98,36 @@ class DiskDistributionRegistry(DistributionRegistry): class CachedDiskDistributionRegistry(DiskDistributionRegistry): def __init__(self, kvstore: KVStore): super().__init__(kvstore) - self.cache: Dict[str, List[RoutableObjectWithProvider]] = {} + self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {} async def initialize(self) -> None: - start_key = KEY_FORMAT.format("") - end_key = KEY_FORMAT.format("\xff") + start_key = KEY_FORMAT.format(type="", identifier="") + end_key = KEY_FORMAT.format(type="", identifier="\xff") keys = await self.kvstore.range(start_key, end_key) for key in keys: - identifier = key.split(":")[-1] - objects = await super().get(identifier) + type, identifier = key.split(":")[-2:] + objects = await super().get(type, identifier) if objects: - self.cache[identifier] = objects + self.cache[type, identifier] = objects - def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: - return self.cache.get(identifier, []) + def get_cached( + self, type: str, identifier: str + ) -> List[RoutableObjectWithProvider]: + return self.cache.get((type, identifier), []) async def get_all(self) -> List[RoutableObjectWithProvider]: return [item for sublist in self.cache.values() for item in sublist] - async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: - if identifier in self.cache: - return self.cache[identifier] + async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: + cachekey = (type, identifier) + if cachekey in self.cache: + return self.cache[cachekey] - objects = await super().get(identifier) + objects = await super().get(type, identifier) if objects: - self.cache[identifier] = objects + self.cache[cachekey] = objects return objects @@ -126,16 +137,17 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): if success: # Then update cache - if obj.identifier not in self.cache: - self.cache[obj.identifier] = [] + cachekey = (obj.type, obj.identifier) + if cachekey not in self.cache: + self.cache[cachekey] = [] # Check if provider already exists in cache - for cached_obj in self.cache[obj.identifier]: + for cached_obj in self.cache[cachekey]: if cached_obj.provider_id == obj.provider_id: return success # If not, update cache - self.cache[obj.identifier].append(obj) + self.cache[cachekey].append(obj) return success diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 64f493b88..db157174f 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -44,7 +44,7 @@ def agents_meta_reference() -> ProviderFixture: providers=[ Provider( provider_id="meta-reference", - provider_type="meta-reference", + provider_type="inline::meta-reference", config=MetaReferenceAgentsImplConfig( # TODO: make this an in-memory store persistence_store=SqliteKVStoreConfig( diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 437c8fa54..b73c2d798 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -101,6 +101,8 @@ async def safety_stack(inference_model, safety_model, request): shield_provider_type = safety_fixture.providers[0].provider_type shield_input = get_shield_to_register(shield_provider_type, safety_model) + print(f"inference_model: {inference_model}") + print(f"shield_input = {shield_input}") impls = await resolve_impls_for_test_v2( [Api.safety, Api.shields, Api.inference], providers,