Another round of simplification and clarity for models/shields/memory_banks stuff

This commit is contained in:
Ashwin Bharambe 2024-10-09 19:19:26 -07:00
parent 73a0a34e39
commit b55034c0de
27 changed files with 454 additions and 444 deletions

View file

@ -7,17 +7,7 @@ import textwrap
from typing import Any
from llama_models.sku_list import (
llama3_1_family,
llama3_2_family,
llama3_family,
resolve_model,
safety_models,
)
from llama_stack.distribution.datatypes import * # noqa: F403
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
from termcolor import cprint
from llama_stack.distribution.distribution import (
@ -33,11 +23,6 @@ from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
ALLOWED_MODELS = (
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
)
def configure_single_provider(
registry: Dict[str, ProviderSpec], provider: Provider
) -> Provider:
@ -133,137 +118,10 @@ def configure_api_providers(
config.providers[api_str] = updated_providers
if is_nux:
print(
textwrap.dedent(
"""
=========================================================================================
Now let's configure the `objects` you will be serving via the stack. These are:
- Models: the Llama model SKUs you expect to inference (e.g., Llama3.2-1B-Instruct)
- Shields: the safety models you expect to use for safety (e.g., Llama-Guard-3-1B)
- Memory Banks: the memory banks you expect to use for memory (e.g., Vector stores)
This wizard will guide you through setting up one of each of these objects. You can
always add more later by editing the run.yaml file.
"""
)
)
object_types = {
"models": (ModelDef, configure_models, "inference"),
"shields": (ShieldDef, configure_shields, "safety"),
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
}
safety_providers = config.providers.get("safety", [])
for otype, (odef, config_method, api_str) in object_types.items():
existing_objects = getattr(config, otype)
if existing_objects:
cprint(
f"{len(existing_objects)} {otype} exist. Skipping...",
"blue",
attrs=["bold"],
)
updated_objects = existing_objects
else:
providers = config.providers.get(api_str, [])
if not providers:
updated_objects = []
else:
# we are newly configuring this API
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
updated_objects = config_method(
config.providers[api_str], safety_providers
)
setattr(config, otype, updated_objects)
print("")
return config
def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
    """Return the Llama Guard model named by the first safety provider, if any.

    Only the first entry of ``safety_providers`` is consulted, and it is
    expected to be a ``meta-reference`` provider.

    Args:
        safety_providers: Providers configured for the safety API; may be empty.

    Returns:
        The ``model`` value of the provider's ``llama_guard_shield`` config
        section, or ``None`` when there are no safety providers or that
        section is missing or empty.

    Raises:
        AssertionError: If the first provider is not of type ``meta-reference``.
    """
    if not safety_providers:
        return None

    provider = safety_providers[0]
    assert provider.provider_type == "meta-reference", (
        f"Expected a `meta-reference` safety provider, got `{provider.provider_type}`"
    )

    # Use `.get` so a config without a `llama_guard_shield` section is treated
    # the same as an empty one instead of raising KeyError.
    cfg = provider.config.get("llama_guard_shield")
    if not cfg:
        return None
    return cfg["model"]
def configure_models(
    providers: List[Provider], safety_providers: List[Provider]
) -> List[ModelDef]:
    """Interactively build the list of model definitions to serve.

    Prompts for one model name (validated through ``resolve_model``) and
    assigns it to the first inference provider. When the safety providers
    name a Llama Guard model, a second definition for that model is appended,
    also on the first inference provider.

    Args:
        providers: Inference providers; the first one receives every model.
        safety_providers: Safety providers, inspected for a Llama Guard model.

    Returns:
        One or two ``ModelDef`` entries, the user's model first.
    """
    allowed_descriptors = [sku.descriptor() for sku in ALLOWED_MODELS]
    model_validator = Validator.from_callable(
        lambda name: resolve_model(name) is not None,
        error_message="Model must be: {}".format(allowed_descriptors),
    )
    chosen_model = prompt(
        "> Please enter the model you want to serve: ",
        default="Llama3.2-1B-Instruct",
        validator=model_validator,
    )

    serving_provider_id = providers[0].provider_id
    model_defs = [
        ModelDef(
            identifier=chosen_model,
            llama_model=chosen_model,
            provider_id=serving_provider_id,
        )
    ]

    # The safety shield needs its own model served by the inference provider.
    llama_guard = get_llama_guard_model(safety_providers)
    if llama_guard:
        model_defs.append(
            ModelDef(
                identifier=llama_guard,
                llama_model=llama_guard,
                provider_id=serving_provider_id,
            )
        )
    return model_defs
def configure_shields(
    providers: List[Provider], safety_providers: List[Provider]
) -> List[ShieldDef]:
    """Build the default shield list for the configured safety providers.

    Args:
        providers: Providers for the safety API; the first one owns the shield.
        safety_providers: Safety providers, inspected for a Llama Guard model.

    Returns:
        A single ``llama_guard`` shield definition when a Llama Guard model is
        configured, otherwise an empty list.
    """
    if not get_llama_guard_model(safety_providers):
        return []

    llama_guard_shield = ShieldDef(
        identifier="llama_guard",
        type="llama_guard",
        provider_id=providers[0].provider_id,
        params={},
    )
    return [llama_guard_shield]
def configure_memory_banks(
    providers: List[Provider], safety_providers: List[Provider]
) -> List[MemoryBankDef]:
    """Interactively build a single vector memory bank definition.

    Args:
        providers: Memory providers; the first one receives the bank.
        safety_providers: Unused here; accepted so that all ``configure_*``
            helpers share the same signature.

    Returns:
        A one-element list holding a ``VectorMemoryBankDef`` with the name the
        user entered.
    """
    chosen_name = prompt(
        "> Please enter a name for your memory bank: ",
        default="my-memory-bank",
    )
    vector_bank = VectorMemoryBankDef(
        identifier=chosen_name,
        provider_id=providers[0].provider_id,
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
    )
    return [vector_bank]
def upgrade_from_routing_table_to_registry(
def upgrade_from_routing_table(
config_dict: Dict[str, Any],
) -> Dict[str, Any]:
def get_providers(entries):
@ -281,57 +139,12 @@ def upgrade_from_routing_table_to_registry(
]
providers_by_api = {}
models = []
shields = []
memory_banks = []
routing_table = config_dict.get("routing_table", {})
for api_str, entries in routing_table.items():
providers = get_providers(entries)
providers_by_api[api_str] = providers
if api_str == "inference":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
models.append(
ModelDef(
identifier=key,
provider_id=provider.provider_id,
llama_model=key,
)
)
elif api_str == "safety":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
shields.append(
ShieldDef(
identifier=key,
type=ShieldType.llama_guard.value,
provider_id=provider.provider_id,
)
)
elif api_str == "memory":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
# we currently only support Vector memory banks so this is OK
memory_banks.append(
VectorMemoryBankDef(
identifier=key,
provider_id=provider.provider_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
)
)
config_dict["models"] = models
config_dict["shields"] = shields
config_dict["memory_banks"] = memory_banks
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
if provider_map:
for api_str, provider in provider_map.items():
@ -361,9 +174,9 @@ def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfi
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**config_dict)
if "models" not in config_dict:
if "routing_table" in config_dict:
print("Upgrading config...")
config_dict = upgrade_from_routing_table_to_registry(config_dict)
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
config_dict["built_at"] = datetime.now().isoformat()

View file

@ -32,6 +32,12 @@ RoutableObject = Union[
MemoryBankDef,
]
RoutableObjectWithProvider = Union[
ModelDefWithProvider,
ShieldDef,
MemoryBankDef,
]
RoutedProtocol = Union[
Inference,
Safety,
@ -63,7 +69,6 @@ class RoutingTableProviderSpec(ProviderSpec):
docker_image: Optional[str] = None
router_api: Api
registry: List[RoutableObject]
module: str
pip_packages: List[str] = Field(default_factory=list)
@ -121,25 +126,6 @@ can be instantiated multiple times (with different configs) if necessary.
""",
)
models: List[ModelDef] = Field(
description="""
List of model definitions to serve. This list may get extended by
/models/register API calls at runtime.
""",
)
shields: List[ShieldDef] = Field(
description="""
List of shield definitions to serve. This list may get extended by
/shields/register API calls at runtime.
""",
)
memory_banks: List[MemoryBankDef] = Field(
description="""
List of memory bank definitions to serve. This list may get extended by
/memory_banks/register API calls at runtime.
""",
)
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION

View file

@ -4,10 +4,22 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import inspect
from typing import Any, Dict, List, Set
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
@ -15,6 +27,28 @@ from llama_stack.distribution.distribution import (
from llama_stack.distribution.utils.dynamic import instantiate_class_type
def api_protocol_map() -> Dict[Api, Any]:
    """Return the protocol class that each public API must implement.

    A fresh dictionary is built on every call, so callers may mutate the
    result freely.
    """
    api_protocol_pairs = [
        (Api.agents, Agents),
        (Api.inference, Inference),
        (Api.inspect, Inspect),
        (Api.memory, Memory),
        (Api.memory_banks, MemoryBanks),
        (Api.models, Models),
        (Api.safety, Safety),
        (Api.shields, Shields),
        (Api.telemetry, Telemetry),
    ]
    return dict(api_protocol_pairs)
def additional_protocols_map() -> Dict[Api, Any]:
    """Return the extra, provider-facing protocol required for some APIs.

    Only the inference, memory and safety APIs have an entry; any other API
    has no additional protocol.
    """
    extra_protocols: Dict[Api, Any] = {}
    extra_protocols[Api.inference] = ModelsProtocolPrivate
    extra_protocols[Api.memory] = MemoryBanksProtocolPrivate
    extra_protocols[Api.safety] = ShieldsProtocolPrivate
    return extra_protocols
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
class ProviderWithSpec(Provider):
spec: ProviderSpec
@ -73,17 +107,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
inner_deps = []
registry = getattr(run_config, info.routing_table_api.value)
for entry in registry:
if entry.provider_id not in available_providers:
raise ValueError(
f"Provider `{entry.provider_id}` not found. Available providers: {list(available_providers.keys())}"
)
provider = available_providers[entry.provider_id]
inner_deps.extend(provider.spec.api_dependencies)
providers_with_specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__builtin__",
@ -92,13 +115,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
spec=RoutingTableProviderSpec(
api=info.routing_table_api,
router_api=info.router_api,
registry=registry,
module="llama_stack.distribution.routers",
api_dependencies=inner_deps,
deps__=(
[x.value for x in inner_deps]
+ [f"inner-{info.router_api.value}"]
),
api_dependencies=[],
deps__=([f"inner-{info.router_api.value}"]),
),
)
}
@ -212,6 +231,9 @@ async def instantiate_provider(
deps: Dict[str, Any],
inner_impls: Dict[str, Any],
):
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
provider_spec = provider.spec
module = importlib.import_module(provider_spec.module)
@ -234,7 +256,7 @@ async def instantiate_provider(
method = "get_routing_table_impl"
config = None
args = [provider_spec.api, provider_spec.registry, inner_impls, deps]
args = [provider_spec.api, inner_impls, deps]
else:
method = "get_provider_impl"
@ -247,4 +269,55 @@ async def instantiate_provider(
impl.__provider_id__ = provider.provider_id
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
check_protocol_compliance(impl, protocols[provider_spec.api])
if (
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
additional_api = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
return impl
def check_protocol_compliance(obj: Any, protocol: Any) -> None:
    """Verify that ``obj`` implements every web method declared by ``protocol``.

    A protocol member is checked when it is a plain function carrying a
    ``__webmethod__`` attribute. For each such member, ``obj`` must:

    * have an attribute with that name (else ``"missing"``),
    * have it be callable (else ``"not_callable"``),
    * accept at least the protocol's parameter names, ignoring ``self``
      (else ``"signature_mismatch"``),
    * inherit it from a class other than one named like the protocol itself
      (else ``"not_actually_implemented"``).

    Args:
        obj: The provider implementation instance to check.
        protocol: The protocol class describing the required methods.

    Raises:
        ValueError: If any required method fails one of the checks above. The
            message lists ``(method_name, reason)`` pairs.
    """
    missing_methods = []
    mro = type(obj).__mro__

    for name, value in inspect.getmembers(protocol):
        if not (inspect.isfunction(value) and hasattr(value, "__webmethod__")):
            continue

        if not hasattr(obj, name):
            missing_methods.append((name, "missing"))
            continue

        obj_method = getattr(obj, name)
        if not callable(obj_method):
            missing_methods.append((name, "not_callable"))
            continue

        # Compare parameter names only; `self` is irrelevant because the
        # protocol side is an unbound function while the object side is bound.
        proto_params = set(inspect.signature(value).parameters)
        proto_params.discard("self")
        obj_params = set(inspect.signature(obj_method).parameters)
        obj_params.discard("self")
        if not (proto_params <= obj_params):
            print(
                f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
            )
            missing_methods.append((name, "signature_mismatch"))
            continue

        # A method merely inherited from the protocol class is only a stub,
        # so find the first class in the MRO that actually defines it.
        method_owner = next((cls for cls in mro if name in cls.__dict__), None)
        if method_owner is None or method_owner.__name__ == protocol.__name__:
            missing_methods.append((name, "not_actually_implemented"))

    if missing_methods:
        # Fall back gracefully when the provider bookkeeping attributes have
        # not been attached, so the real problem is reported instead of an
        # AttributeError raised while formatting the message.
        provider_id = getattr(obj, "__provider_id__", type(obj).__name__)
        provider_spec = getattr(obj, "__provider_spec__", None)
        api = getattr(provider_spec, "api", "unknown api")
        raise ValueError(
            f"Provider `{provider_id} ({api})` does not implement the following methods:\n{missing_methods}"
        )

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List
from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
from .routing_tables import (
@ -16,7 +16,6 @@ from .routing_tables import (
async def get_routing_table_impl(
api: Api,
registry: List[RoutableObject],
impls_by_provider_id: Dict[str, RoutedProtocol],
_deps,
) -> Any:
@ -28,7 +27,7 @@ async def get_routing_table_impl(
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](registry, impls_by_provider_id)
impl = api_to_tables[api.value](impls_by_provider_id)
await impl.initialize()
return impl

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
@ -29,115 +29,145 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
await p.register_memory_bank(obj)
Registry = Dict[str, List[RoutableObjectWithProvider]]
# TODO: this routing table keeps its state purely in memory. We need to
# add persistence to it when we add dynamic registration of objects.
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
registry: List[RoutableObject],
impls_by_provider_id: Dict[str, RoutedProtocol],
) -> None:
for obj in registry:
if obj.provider_id not in impls_by_provider_id:
print(f"{impls_by_provider_id=}")
raise ValueError(
f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
)
self.impls_by_provider_id = impls_by_provider_id
self.registry = registry
for p in self.impls_by_provider_id.values():
async def initialize(self) -> None:
self.registry: Registry = {}
def add_objects(objs: List[RoutableObjectWithProvider]) -> None:
for obj in objs:
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
self.registry[obj.identifier].append(obj)
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
models = await p.list_models()
add_objects(
[ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models]
)
elif api == Api.safety:
p.shield_store = self
shields = await p.list_shields()
add_objects(
[
ShieldDefWithProvider(**s.dict(), provider_id=pid)
for s in shields
]
)
elif api == Api.memory:
p.memory_bank_store = self
memory_banks = await p.list_memory_banks()
self.routing_key_to_object = {}
for obj in self.registry:
self.routing_key_to_object[obj.identifier] = obj
# do in-memory updates due to pesky Annotated unions
for m in memory_banks:
m.provider_id = pid
async def initialize(self) -> None:
for obj in self.registry:
p = self.impls_by_provider_id[obj.provider_id]
await register_object_with_provider(obj, p)
add_objects(memory_banks)
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str) -> Any:
if routing_key not in self.routing_key_to_object:
def get_provider_impl(
self, routing_key: str, provider_id: Optional[str] = None
) -> Any:
if routing_key not in self.registry:
raise ValueError(f"`{routing_key}` not registered")
obj = self.routing_key_to_object[routing_key]
objs = self.registry[routing_key]
for obj in objs:
if not provider_id or provider_id == obj.provider_id:
return self.impls_by_provider_id[obj.provider_id]
raise ValueError(f"Provider not found for `{routing_key}`")
def get_object_by_identifier(
self, identifier: str
) -> Optional[RoutableObjectWithProvider]:
objs = self.registry.get(identifier, [])
if not objs:
return None
# kind of ill-defined behavior here, but we'll just return the first one
return objs[0]
async def register_object(self, obj: RoutableObjectWithProvider):
entries = self.registry.get(obj.identifier, [])
for entry in entries:
if entry.provider_id == obj.provider_id:
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
return
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
return self.impls_by_provider_id[obj.provider_id]
def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
for obj in self.registry:
if obj.identifier == identifier:
return obj
return None
async def register_object(self, obj: RoutableObject):
if obj.identifier in self.routing_key_to_object:
print(f"`{obj.identifier}` is already registered")
return
if not obj.provider_id:
provider_ids = list(self.impls_by_provider_id.keys())
if not provider_ids:
raise ValueError("No providers found")
print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
obj.provider_id = provider_ids[0]
else:
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
await register_object_with_provider(obj, p)
self.routing_key_to_object[obj.identifier] = obj
self.registry.append(obj)
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
self.registry[obj.identifier].append(obj)
# TODO: persist this to a store
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelDef]:
return self.registry
async def list_models(self) -> List[ModelDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
async def get_model(self, identifier: str) -> Optional[ModelDef]:
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
return self.get_object_by_identifier(identifier)
async def register_model(self, model: ModelDef) -> None:
async def register_model(self, model: ModelDefWithProvider) -> None:
await self.register_object(model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldDef]:
return self.registry
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
return self.get_object_by_identifier(shield_type)
async def register_shield(self, shield: ShieldDef) -> None:
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
await self.register_object(shield)
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankDef]:
return self.registry
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
async def get_memory_bank(
self, identifier: str
) -> Optional[MemoryBankDefWithProvider]:
return self.get_object_by_identifier(identifier)
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
await self.register_object(bank)
async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider
) -> None:
await self.register_object(memory_bank)

View file

@ -9,15 +9,7 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api
@ -31,18 +23,7 @@ class ApiEndpoint(BaseModel):
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
Api.models: Models,
Api.shields: Shields,
Api.memory_banks: MemoryBanks,
Api.inspect: Inspect,
}
protocols = api_protocol_map()
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)