mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
# What does this PR do? This PR brings back the ability to run a Stack without forcing the user to register resources. Forced registration is not just annoying but sometimes actually infeasible. For example, you may have a Stack which boots up with private providers for inference for models A and B. There is currently no way for the user to know which model is being served by these providers (and therefore no way to register it). How will this avoid users needing to do registration? In a follow-up diff, I will update the sample run.yaml files so they explicitly list the models served by the distributions. So when users do `llama stack build --template <...>` and run it, their distributions come up with the set of models they expect. For self-hosted distributions, it also gives us a place to explicitly list the models that need to be served to make the "complete" stack (including, e.g., safety). ## Test Plan Started ollama locally with two lightweight models: Llama3.2-3B-Instruct and Llama-Guard-3-1B. Updated all the tests including agents. Here are the tests I ran so far: ```bash pytest -s -v -m "fireworks and llama_3b" test_text_inference.py::TestInference \ --env FIREWORKS_API_KEY=... pytest -s -v -m "ollama and llama_3b" test_text_inference.py::TestInference pytest -s -v -m ollama test_safety.py pytest -s -v -m faiss test_memory.py pytest -s -v -m ollama test_agents.py \ --inference-model=Llama3.2-3B-Instruct --safety-model=Llama-Guard-3-1B ``` These test runs surfaced a few pre-existing bugs, which are fixed here as well.
157 lines
5.3 KiB
Python
157 lines
5.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import json
|
|
from typing import Dict, List, Optional, Protocol
|
|
|
|
import pydantic
|
|
|
|
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
|
|
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
|
|
|
from llama_stack.providers.utils.kvstore import (
|
|
KVStore,
|
|
kvstore_impl,
|
|
SqliteKVStoreConfig,
|
|
)
|
|
|
|
|
|
class DistributionRegistry(Protocol):
    """Structural interface for a registry of routable objects.

    Objects are looked up by identifier; a single identifier may map to
    several objects, one per provider.
    """

    async def get_all(self) -> List[RoutableObjectWithProvider]:
        """Return every object known to the registry."""

    async def initialize(self) -> None:
        """Perform any asynchronous setup the registry needs before use."""

    async def get(self, identifier: str) -> List[RoutableObjectWithProvider]:
        """Return the objects registered under ``identifier``."""

    def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]:
        """Synchronously return the objects held in memory for ``identifier``."""

    # The current data structure allows multiple objects with the same identifier but different providers.
    # This is not ideal - we should have a single object that can be served by multiple providers,
    # suggesting a data structure like (obj: Obj, providers: List[str]) rather than List[RoutableObjectWithProvider].
    # The current approach could lead to inconsistencies if the same logical object has different data across providers.
    async def register(self, obj: RoutableObjectWithProvider) -> bool:
        """Register ``obj`` and report whether it was newly added."""
|
|
|
|
|
|
# Template for the kvstore key under which a registry entry is stored; the
# placeholder is filled with the object's identifier. Formatting it with ""
# yields the common key prefix that the registries use for range scans.
KEY_FORMAT = "distributions:registry:v1::{}"
|
|
|
|
|
|
class DiskDistributionRegistry(DistributionRegistry):
    """Registry that persists routable objects in a key-value store.

    Each identifier maps to one kvstore entry (keyed via ``KEY_FORMAT``) whose
    value is a JSON array of JSON-encoded objects, one per provider.
    """

    def __init__(self, kvstore: KVStore):
        self.kvstore = kvstore

    async def initialize(self) -> None:
        """No-op: this registry holds no in-memory state to load."""
        pass

    def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]:
        """Always return an empty list; the disk registry does not have a cache."""
        return []

    async def get_all(self) -> List[RoutableObjectWithProvider]:
        """Return a flat list of every object stored under the registry prefix."""
        key_prefix = KEY_FORMAT.format("")
        end_key = KEY_FORMAT.format("\xff")
        # NOTE(review): assumes kvstore.range yields the matching *keys* --
        # confirm against the KVStore implementation in use.
        keys = await self.kvstore.range(key_prefix, end_key)

        all_objects: List[RoutableObjectWithProvider] = []
        for key in keys:
            # Strip the prefix instead of splitting on ":" so identifiers that
            # themselves contain ":" are recovered intact.
            identifier = key.removeprefix(key_prefix)
            # get() returns a list per identifier; extend keeps the result flat.
            all_objects.extend(await self.get(identifier))
        return all_objects

    async def get(self, identifier: str) -> List[RoutableObjectWithProvider]:
        """Load and deserialize the objects stored under ``identifier``.

        Returns an empty list when the kvstore has no (or an empty) value
        for the identifier.
        """
        json_str = await self.kvstore.get(KEY_FORMAT.format(identifier))
        if not json_str:
            return []

        # TypeAdapter.validate_python is the Pydantic v2 replacement for the
        # deprecated parse_obj_as.
        adapter = pydantic.TypeAdapter(RoutableObjectWithProvider)
        # The stored value is a JSON array whose items are JSON strings.
        return [
            adapter.validate_python(json.loads(obj_str))
            for obj_str in json.loads(json_str)
        ]

    async def register(self, obj: RoutableObjectWithProvider) -> bool:
        """Persist ``obj`` unless its provider is already registered for the identifier.

        Returns:
            True if the object was written, False if an object with the same
            ``provider_id`` already exists under ``obj.identifier``.
        """
        existing_objects = await self.get(obj.identifier)

        # Don't register if the object's provider_id already exists.
        if any(
            existing.provider_id == obj.provider_id for existing in existing_objects
        ):
            return False

        # Build a new list rather than appending in place: a subclass's get()
        # (e.g. CachedDiskDistributionRegistry) may hand back its live cache
        # list, which must not change before the write below has succeeded.
        updated_objects = [*existing_objects, obj]
        objects_json = [item.model_dump_json() for item in updated_objects]
        await self.kvstore.set(
            KEY_FORMAT.format(obj.identifier), json.dumps(objects_json)
        )
        return True
|
|
|
|
|
|
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
    """Disk-backed registry fronted by an in-memory, per-identifier cache.

    Reads are served from ``self.cache`` when possible; writes go to the
    kvstore first and are mirrored into the cache only on success.
    """

    def __init__(self, kvstore: KVStore):
        super().__init__(kvstore)
        # identifier -> objects registered under that identifier
        self.cache: Dict[str, List[RoutableObjectWithProvider]] = {}

    async def initialize(self) -> None:
        """Populate the cache from every registry entry in the kvstore."""
        key_prefix = KEY_FORMAT.format("")
        end_key = KEY_FORMAT.format("\xff")

        # NOTE(review): assumes kvstore.range yields the matching *keys* --
        # confirm against the KVStore implementation in use.
        keys = await self.kvstore.range(key_prefix, end_key)

        for key in keys:
            # Strip the prefix instead of splitting on ":" so identifiers that
            # themselves contain ":" are cached under their full name.
            identifier = key.removeprefix(key_prefix)
            # Read straight from disk, bypassing this class's cached get().
            objects = await super().get(identifier)
            if objects:
                self.cache[identifier] = objects

    def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]:
        """Return the cached objects for ``identifier`` without touching disk."""
        return self.cache.get(identifier, [])

    async def get_all(self) -> List[RoutableObjectWithProvider]:
        """Return a flat list of every object currently in the cache."""
        return [item for sublist in self.cache.values() for item in sublist]

    async def get(self, identifier: str) -> List[RoutableObjectWithProvider]:
        """Return objects for ``identifier``, reading from disk on a cache miss.

        Non-empty disk results are stored in the cache; empty results are not.
        """
        if identifier in self.cache:
            return self.cache[identifier]

        objects = await super().get(identifier)
        if objects:
            self.cache[identifier] = objects

        return objects

    async def register(self, obj: RoutableObjectWithProvider) -> bool:
        """Register ``obj`` on disk and, if that succeeds, mirror it into the cache.

        Returns:
            The result of the disk registration: True if ``obj`` was written,
            False if its provider was already registered for the identifier.
        """
        # First update disk
        success = await super().register(obj)
        if not success:
            return success

        # Then update cache. The provider check guards against a duplicate
        # entry in case the object is already present in the cached list.
        cached_objects = self.cache.setdefault(obj.identifier, [])
        already_cached = any(
            cached_obj.provider_id == obj.provider_id for cached_obj in cached_objects
        )
        if not already_cached:
            cached_objects.append(obj)

        return success
|
|
|
|
|
|
async def create_dist_registry(
    metadata_store: Optional[KVStoreConfig],
    image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
    """Build the distribution registry together with its backing kvstore.

    Args:
        metadata_store: KVStore configuration to use. When falsy, a SQLite
            store at ``DISTRIBS_BASE_DIR / image_name / "kvstore.db"`` is
            used instead.
        image_name: Distribution image name; only used to derive the default
            SQLite path.

    Returns:
        A ``(registry, kvstore)`` tuple. The registry's ``initialize()`` is
        not called here.
    """
    # Pick the kvstore config for distribution metadata, falling back to a
    # per-image SQLite database when none was supplied.
    kvstore_config = metadata_store or SqliteKVStoreConfig(
        db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
    )
    dist_kvstore = await kvstore_impl(kvstore_config)
    registry = CachedDiskDistributionRegistry(dist_kvstore)
    return registry, dist_kvstore
|