Another round of simplification and clarity for models/shields/memory_banks stuff

Ashwin Bharambe 2024-10-09 19:19:26 -07:00
parent 73a0a34e39
commit b55034c0de
27 changed files with 454 additions and 444 deletions

View file

@@ -6,7 +6,16 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
@@ -404,6 +413,7 @@ class AgentStepResponse(BaseModel):
step: Step
@runtime_checkable
class Agents(Protocol):
@webmethod(route="/agents/create")
async def create_agent(

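A note on the `@runtime_checkable` additions repeated across the API files in this commit: the decorator is what makes `isinstance()` checks against a Protocol legal at runtime. A minimal sketch, assuming a simplified stand-in for the real class; the check is purely name-based, which is why the resolver below also grows a separate `check_protocol_compliance()` that compares signatures.

    from typing import Protocol, runtime_checkable

    @runtime_checkable
    class AgentsLike(Protocol):  # hypothetical stand-in for the real Agents protocol
        async def create_agent(self, agent_config) -> None: ...

    class Impl:
        async def create_agent(self, agent_config) -> None:
            return None

    # Without @runtime_checkable, isinstance() against a Protocol raises TypeError.
    assert isinstance(Impl(), AgentsLike)  # checks method names only, not signatures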
View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from typing import List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
@@ -47,6 +47,7 @@ class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
@runtime_checkable
class BatchInference(Protocol):
@webmethod(route="/batch_inference/completion")
async def batch_completion(

View file

@@ -6,7 +6,7 @@
from enum import Enum
from typing import List, Literal, Optional, Protocol, Union
from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
from llama_models.schema_utils import json_schema_type, webmethod
@@ -177,6 +177,7 @@ class ModelStore(Protocol):
def get_model(self, identifier: str) -> ModelDef: ...
@runtime_checkable
class Inference(Protocol):
model_store: ModelStore
@@ -214,6 +215,3 @@ class Inference(Protocol):
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ...
@webmethod(route="/inference/register_model")
async def register_model(self, model: ModelDef) -> None: ...

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List, Protocol
from typing import Dict, List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
@@ -29,6 +29,7 @@ class HealthInfo(BaseModel):
# TODO: add a provider level status
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/providers/list", method="GET")
async def list_providers(self) -> Dict[str, ProviderInfo]: ...

View file

@@ -8,7 +8,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from typing import List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
@@ -42,6 +42,7 @@ class MemoryBankStore(Protocol):
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
@runtime_checkable
class Memory(Protocol):
memory_bank_store: MemoryBankStore
@@ -55,13 +56,6 @@ class Memory(Protocol):
ttl_seconds: Optional[int] = None,
) -> None: ...
@webmethod(route="/memory/update")
async def update_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory/query")
async def query_documents(
self,
@@ -69,20 +63,3 @@ class Memory(Protocol):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@webmethod(route="/memory/documents/get", method="GET")
async def get_documents(
self,
bank_id: str,
document_ids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory/documents/delete", method="DELETE")
async def delete_documents(
self,
bank_id: str,
document_ids: List[str],
) -> None: ...
@webmethod(route="/memory/register_memory_bank")
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...

View file

@@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import List, Literal, Optional, Protocol, Union
from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
@@ -22,7 +22,8 @@ class MemoryBankType(Enum):
class CommonDef(BaseModel):
identifier: str
provider_id: Optional[str] = None
# Hack: move this out later
provider_id: str = ""
@json_schema_type
@@ -58,13 +59,20 @@ MemoryBankDef = Annotated[
Field(discriminator="type"),
]
MemoryBankDefWithProvider = MemoryBankDef
@runtime_checkable
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: ...
@webmethod(route="/memory_banks/get", method="GET")
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
async def get_memory_bank(
self, identifier: str
) -> Optional[MemoryBankDefWithProvider]: ...
@webmethod(route="/memory_banks/register", method="POST")
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider
) -> None: ...

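Unlike `ModelDefWithProvider` and `ShieldDefWithProvider` elsewhere in this commit, `MemoryBankDefWithProvider` is a plain alias: `MemoryBankDef` is an `Annotated` discriminated union, which cannot be subclassed the same way, so `provider_id` sits on `CommonDef` (the "Hack: move this out later" above) and is patched in after construction. A hedged sketch; the provider id and import path are assumptions:

    from llama_stack.apis.memory_banks import VectorMemoryBankDef  # assumed path

    bank = VectorMemoryBankDef(
        identifier="my-memory-bank",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
    )
    # In-memory update, mirroring the routing table's workaround for the union type.
    bank.provider_id = "faiss"  # assumption: id of a registered memory provider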
View file

@@ -4,34 +4,39 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
@json_schema_type
class ModelDef(BaseModel):
identifier: str = Field(
description="A unique identifier for the model type",
description="A unique name for the model type",
)
llama_model: str = Field(
description="Pointer to the core Llama family model",
description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
)
provider_id: Optional[str] = Field(
default=None, description="The provider instance which serves this model"
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this model",
)
# For now, we are only supporting core llama models but as soon as finetuned
# and other custom models (for example various quantizations) are allowed, there
# will be more metadata fields here
@json_schema_type
class ModelDefWithProvider(ModelDef):
provider_id: str = Field(
description="The provider ID for this model",
)
@runtime_checkable
class Models(Protocol):
@webmethod(route="/models/list", method="GET")
async def list_models(self) -> List[ModelDef]: ...
async def list_models(self) -> List[ModelDefWithProvider]: ...
@webmethod(route="/models/get", method="GET")
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: ...
@webmethod(route="/models/register", method="POST")
async def register_model(self, model: ModelDef) -> None: ...
async def register_model(self, model: ModelDefWithProvider) -> None: ...

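The `*WithProvider` split introduced here separates what a provider reports (a bare `ModelDef`) from what the stack serves (the same object tagged with its owning provider). A sketch of the conversion, mirroring what the routing table does at startup; the identifiers and provider id are illustrative:

    from llama_stack.apis.models import ModelDef, ModelDefWithProvider

    model = ModelDef(
        identifier="Llama3.2-1B-Instruct",
        llama_model="Llama3.2-1B-Instruct",
        metadata={},
    )
    # Same pattern as the routing table: copy the fields, then tag the provider.
    served = ModelDefWithProvider(**model.dict(), provider_id="meta-reference")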
View file

@@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Protocol
from typing import Any, Dict, List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
@@ -42,6 +42,7 @@ class ShieldStore(Protocol):
def get_shield(self, identifier: str) -> ShieldDef: ...
@runtime_checkable
class Safety(Protocol):
shield_store: ShieldStore
@@ -49,6 +50,3 @@ class Safety(Protocol):
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ...
@webmethod(route="/safety/register_shield")
async def register_shield(self, shield: ShieldDef) -> None: ...

View file

@@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
@@ -26,21 +26,26 @@ class ShieldDef(BaseModel):
type: str = Field(
description="The type of shield this is; the value is one of the ShieldType enum"
)
provider_id: Optional[str] = Field(
default=None, description="The provider instance which serves this shield"
)
params: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional parameters needed for this shield",
)
@json_schema_type
class ShieldDefWithProvider(ShieldDef):
provider_id: str = Field(
description="The provider ID for this shield type",
)
@runtime_checkable
class Shields(Protocol):
@webmethod(route="/shields/list", method="GET")
async def list_shields(self) -> List[ShieldDef]: ...
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
@webmethod(route="/shields/get", method="GET")
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: ...
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ...
@webmethod(route="/shields/register", method="POST")
async def register_shield(self, shield: ShieldDef) -> None: ...
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...

View file

@@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Literal, Optional, Protocol, Union
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
@@ -123,6 +123,7 @@ Event = Annotated[
]
@runtime_checkable
class Telemetry(Protocol):
@webmethod(route="/telemetry/log_event")
async def log_event(self, event: Event) -> None: ...

View file

@@ -7,17 +7,7 @@ import textwrap
from typing import Any
from llama_models.sku_list import (
llama3_1_family,
llama3_2_family,
llama3_family,
resolve_model,
safety_models,
)
from llama_stack.distribution.datatypes import * # noqa: F403
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
from termcolor import cprint
from llama_stack.distribution.distribution import (
@@ -33,11 +23,6 @@ from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
ALLOWED_MODELS = (
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
)
def configure_single_provider(
registry: Dict[str, ProviderSpec], provider: Provider
) -> Provider:
@@ -133,137 +118,10 @@ def configure_api_providers(
config.providers[api_str] = updated_providers
if is_nux:
print(
textwrap.dedent(
"""
=========================================================================================
Now let's configure the `objects` you will be serving via the stack. These are:
- Models: the Llama model SKUs you expect to inference (e.g., Llama3.2-1B-Instruct)
- Shields: the safety models you expect to use for safety (e.g., Llama-Guard-3-1B)
- Memory Banks: the memory banks you expect to use for memory (e.g., Vector stores)
This wizard will guide you through setting up one of each of these objects. You can
always add more later by editing the run.yaml file.
"""
)
)
object_types = {
"models": (ModelDef, configure_models, "inference"),
"shields": (ShieldDef, configure_shields, "safety"),
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
}
safety_providers = config.providers.get("safety", [])
for otype, (odef, config_method, api_str) in object_types.items():
existing_objects = getattr(config, otype)
if existing_objects:
cprint(
f"{len(existing_objects)} {otype} exist. Skipping...",
"blue",
attrs=["bold"],
)
updated_objects = existing_objects
else:
providers = config.providers.get(api_str, [])
if not providers:
updated_objects = []
else:
# we are newly configuring this API
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
updated_objects = config_method(
config.providers[api_str], safety_providers
)
setattr(config, otype, updated_objects)
print("")
return config
def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
if not safety_providers:
return None
provider = safety_providers[0]
assert provider.provider_type == "meta-reference"
cfg = provider.config["llama_guard_shield"]
if not cfg:
return None
return cfg["model"]
def configure_models(
providers: List[Provider], safety_providers: List[Provider]
) -> List[ModelDef]:
model = prompt(
"> Please enter the model you want to serve: ",
default="Llama3.2-1B-Instruct",
validator=Validator.from_callable(
lambda x: resolve_model(x) is not None,
error_message="Model must be: {}".format(
[x.descriptor() for x in ALLOWED_MODELS]
),
),
)
model = ModelDef(
identifier=model,
llama_model=model,
provider_id=providers[0].provider_id,
)
ret = [model]
if llama_guard := get_llama_guard_model(safety_providers):
ret.append(
ModelDef(
identifier=llama_guard,
llama_model=llama_guard,
provider_id=providers[0].provider_id,
)
)
return ret
def configure_shields(
providers: List[Provider], safety_providers: List[Provider]
) -> List[ShieldDef]:
if get_llama_guard_model(safety_providers):
return [
ShieldDef(
identifier="llama_guard",
type="llama_guard",
provider_id=providers[0].provider_id,
params={},
)
]
return []
def configure_memory_banks(
providers: List[Provider], safety_providers: List[Provider]
) -> List[MemoryBankDef]:
bank_name = prompt(
"> Please enter a name for your memory bank: ",
default="my-memory-bank",
)
return [
VectorMemoryBankDef(
identifier=bank_name,
provider_id=providers[0].provider_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
)
]
def upgrade_from_routing_table_to_registry(
def upgrade_from_routing_table(
config_dict: Dict[str, Any],
) -> Dict[str, Any]:
def get_providers(entries):
@@ -281,57 +139,12 @@ def upgrade_from_routing_table_to_registry(
]
providers_by_api = {}
models = []
shields = []
memory_banks = []
routing_table = config_dict.get("routing_table", {})
for api_str, entries in routing_table.items():
providers = get_providers(entries)
providers_by_api[api_str] = providers
if api_str == "inference":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
models.append(
ModelDef(
identifier=key,
provider_id=provider.provider_id,
llama_model=key,
)
)
elif api_str == "safety":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
shields.append(
ShieldDef(
identifier=key,
type=ShieldType.llama_guard.value,
provider_id=provider.provider_id,
)
)
elif api_str == "memory":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
# we currently only support Vector memory banks so this is OK
memory_banks.append(
VectorMemoryBankDef(
identifier=key,
provider_id=provider.provider_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
)
)
config_dict["models"] = models
config_dict["shields"] = shields
config_dict["memory_banks"] = memory_banks
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
if provider_map:
for api_str, provider in provider_map.items():
@@ -361,9 +174,9 @@ def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfi
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**config_dict)
if "models" not in config_dict:
if "routing_table" in config_dict:
print("Upgrading config...")
config_dict = upgrade_from_routing_table_to_registry(config_dict)
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
config_dict["built_at"] = datetime.now().isoformat()

View file

@@ -32,6 +32,12 @@ RoutableObject = Union[
MemoryBankDef,
]
RoutableObjectWithProvider = Union[
ModelDefWithProvider,
ShieldDef,
MemoryBankDef,
]
RoutedProtocol = Union[
Inference,
Safety,
@@ -63,7 +69,6 @@ class RoutingTableProviderSpec(ProviderSpec):
docker_image: Optional[str] = None
router_api: Api
registry: List[RoutableObject]
module: str
pip_packages: List[str] = Field(default_factory=list)
@@ -121,25 +126,6 @@ can be instantiated multiple times (with different configs) if necessary.
""",
)
models: List[ModelDef] = Field(
description="""
List of model definitions to serve. This list may get extended by
/models/register API calls at runtime.
""",
)
shields: List[ShieldDef] = Field(
description="""
List of shield definitions to serve. This list may get extended by
/shields/register API calls at runtime.
""",
)
memory_banks: List[MemoryBankDef] = Field(
description="""
List of memory bank definitions to serve. This list may get extended by
/memory_banks/register API calls at runtime.
""",
)
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION

View file

@@ -4,10 +4,22 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import inspect
from typing import Any, Dict, List, Set
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
@@ -15,6 +27,28 @@ from llama_stack.distribution.distribution import (
from llama_stack.distribution.utils.dynamic import instantiate_class_type
def api_protocol_map() -> Dict[Api, Any]:
return {
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
Api.memory: Memory,
Api.memory_banks: MemoryBanks,
Api.models: Models,
Api.safety: Safety,
Api.shields: Shields,
Api.telemetry: Telemetry,
}
def additional_protocols_map() -> Dict[Api, Any]:
return {
Api.inference: ModelsProtocolPrivate,
Api.memory: MemoryBanksProtocolPrivate,
Api.safety: ShieldsProtocolPrivate,
}
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
class ProviderWithSpec(Provider):
spec: ProviderSpec
@@ -73,17 +107,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
inner_deps = []
registry = getattr(run_config, info.routing_table_api.value)
for entry in registry:
if entry.provider_id not in available_providers:
raise ValueError(
f"Provider `{entry.provider_id}` not found. Available providers: {list(available_providers.keys())}"
)
provider = available_providers[entry.provider_id]
inner_deps.extend(provider.spec.api_dependencies)
providers_with_specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__builtin__",
@@ -92,13 +115,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
spec=RoutingTableProviderSpec(
api=info.routing_table_api,
router_api=info.router_api,
registry=registry,
module="llama_stack.distribution.routers",
api_dependencies=inner_deps,
deps__=(
[x.value for x in inner_deps]
+ [f"inner-{info.router_api.value}"]
),
api_dependencies=[],
deps__=([f"inner-{info.router_api.value}"]),
),
)
}
@@ -212,6 +231,9 @@ async def instantiate_provider(
deps: Dict[str, Any],
inner_impls: Dict[str, Any],
):
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
provider_spec = provider.spec
module = importlib.import_module(provider_spec.module)
@@ -234,7 +256,7 @@ async def instantiate_provider(
method = "get_routing_table_impl"
config = None
args = [provider_spec.api, provider_spec.registry, inner_impls, deps]
args = [provider_spec.api, inner_impls, deps]
else:
method = "get_provider_impl"
@@ -247,4 +269,55 @@
impl.__provider_id__ = provider.provider_id
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
check_protocol_compliance(impl, protocols[provider_spec.api])
if (
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
additional_api = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
return impl
def check_protocol_compliance(obj: Any, protocol: Any) -> None:
missing_methods = []
mro = type(obj).__mro__
for name, value in inspect.getmembers(protocol):
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
if not hasattr(obj, name):
missing_methods.append((name, "missing"))
elif not callable(getattr(obj, name)):
missing_methods.append((name, "not_callable"))
else:
# Check if the method signatures are compatible
obj_method = getattr(obj, name)
proto_sig = inspect.signature(value)
obj_sig = inspect.signature(obj_method)
proto_params = set(proto_sig.parameters)
proto_params.discard("self")
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
print(
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
)
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
method_owner = next(
(cls for cls in mro if name in cls.__dict__), None
)
if (
method_owner is None
or method_owner.__name__ == protocol.__name__
):
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:
raise ValueError(
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
)

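The MRO walk in `check_protocol_compliance` above guards against a subtle failure mode: a class that merely inherits from the protocol satisfies `hasattr()` through the protocol's own `...` stubs without implementing anything. A hedged sketch of the case being caught, assuming `embeddings` carries a `@webmethod` marker as in the inference API:

    class Incomplete(Inference):  # inherits the protocol's stub bodies
        pass

    obj = Incomplete()
    assert hasattr(obj, "embeddings")  # True, yet it is only the inherited stub
    # The resolver locates the defining class in type(obj).__mro__; when that owner
    # is the protocol itself, the method is flagged "not_actually_implemented".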
View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List
from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
from .routing_tables import (
@@ -16,7 +16,6 @@ from .routing_tables import (
async def get_routing_table_impl(
api: Api,
registry: List[RoutableObject],
impls_by_provider_id: Dict[str, RoutedProtocol],
_deps,
) -> Any:
@@ -28,7 +27,7 @@
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](registry, impls_by_provider_id)
impl = api_to_tables[api.value](impls_by_provider_id)
await impl.initialize()
return impl

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
@@ -29,115 +29,145 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
await p.register_memory_bank(obj)
Registry = Dict[str, List[RoutableObjectWithProvider]]
# TODO: this routing table maintains state in memory purely. We need to
# add persistence to it when we add dynamic registration of objects.
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
registry: List[RoutableObject],
impls_by_provider_id: Dict[str, RoutedProtocol],
) -> None:
for obj in registry:
if obj.provider_id not in impls_by_provider_id:
print(f"{impls_by_provider_id=}")
raise ValueError(
f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
)
self.impls_by_provider_id = impls_by_provider_id
self.registry = registry
for p in self.impls_by_provider_id.values():
async def initialize(self) -> None:
self.registry: Registry = {}
def add_objects(objs: List[RoutableObjectWithProvider]) -> None:
for obj in objs:
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
self.registry[obj.identifier].append(obj)
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
models = await p.list_models()
add_objects(
[ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models]
)
elif api == Api.safety:
p.shield_store = self
shields = await p.list_shields()
add_objects(
[
ShieldDefWithProvider(**s.dict(), provider_id=pid)
for s in shields
]
)
elif api == Api.memory:
p.memory_bank_store = self
memory_banks = await p.list_memory_banks()
self.routing_key_to_object = {}
for obj in self.registry:
self.routing_key_to_object[obj.identifier] = obj
# do in-memory updates due to pesky Annotated unions
for m in memory_banks:
m.provider_id = pid
async def initialize(self) -> None:
for obj in self.registry:
p = self.impls_by_provider_id[obj.provider_id]
await register_object_with_provider(obj, p)
add_objects(memory_banks)
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str) -> Any:
if routing_key not in self.routing_key_to_object:
def get_provider_impl(
self, routing_key: str, provider_id: Optional[str] = None
) -> Any:
if routing_key not in self.registry:
raise ValueError(f"`{routing_key}` not registered")
obj = self.routing_key_to_object[routing_key]
objs = self.registry[routing_key]
for obj in objs:
if not provider_id or provider_id == obj.provider_id:
return self.impls_by_provider_id[obj.provider_id]
raise ValueError(f"Provider not found for `{routing_key}`")
def get_object_by_identifier(
self, identifier: str
) -> Optional[RoutableObjectWithProvider]:
objs = self.registry.get(identifier, [])
if not objs:
return None
# kind of ill-defined behavior here, but we'll just return the first one
return objs[0]
async def register_object(self, obj: RoutableObjectWithProvider):
entries = self.registry.get(obj.identifier, [])
for entry in entries:
if entry.provider_id == obj.provider_id:
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
return
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
return self.impls_by_provider_id[obj.provider_id]
def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
for obj in self.registry:
if obj.identifier == identifier:
return obj
return None
async def register_object(self, obj: RoutableObject):
if obj.identifier in self.routing_key_to_object:
print(f"`{obj.identifier}` is already registered")
return
if not obj.provider_id:
provider_ids = list(self.impls_by_provider_id.keys())
if not provider_ids:
raise ValueError("No providers found")
print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
obj.provider_id = provider_ids[0]
else:
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
await register_object_with_provider(obj, p)
self.routing_key_to_object[obj.identifier] = obj
self.registry.append(obj)
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
self.registry[obj.identifier].append(obj)
# TODO: persist this to a store
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelDef]:
return self.registry
async def list_models(self) -> List[ModelDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
async def get_model(self, identifier: str) -> Optional[ModelDef]:
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
return self.get_object_by_identifier(identifier)
async def register_model(self, model: ModelDef) -> None:
async def register_model(self, model: ModelDefWithProvider) -> None:
await self.register_object(model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldDef]:
return self.registry
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
return self.get_object_by_identifier(shield_type)
async def register_shield(self, shield: ShieldDef) -> None:
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
await self.register_object(shield)
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankDef]:
return self.registry
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
async def get_memory_bank(
self, identifier: str
) -> Optional[MemoryBankDefWithProvider]:
return self.get_object_by_identifier(identifier)
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
await self.register_object(bank)
async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider
) -> None:
await self.register_object(memory_bank)

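The registry is now keyed by identifier and holds one entry per provider serving that object, so the same identifier can be backed by several providers at once. A sketch of the resulting shape and lookup behavior; the provider ids are illustrative:

    registry: Registry = {
        "Llama3.2-1B-Instruct": [
            ModelDefWithProvider(
                identifier="Llama3.2-1B-Instruct",
                llama_model="Llama3.2-1B-Instruct",
                provider_id="meta-reference",
            ),
            ModelDefWithProvider(
                identifier="Llama3.2-1B-Instruct",
                llama_model="Llama3.2-1B-Instruct",
                provider_id="ollama",
            ),
        ],
    }
    # get_provider_impl("Llama3.2-1B-Instruct") returns the first match;
    # get_provider_impl("Llama3.2-1B-Instruct", provider_id="ollama") pins one.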
View file

@@ -9,15 +9,7 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api
@@ -31,18 +23,7 @@ class ApiEndpoint(BaseModel):
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
Api.models: Models,
Api.shields: Shields,
Api.memory_banks: MemoryBanks,
Api.inspect: Inspect,
}
protocols = api_protocol_map()
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)

View file

@@ -121,3 +121,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
"stream": request.stream,
**options,
}
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -15,6 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
@@ -35,7 +36,7 @@ OLLAMA_SUPPORTED_MODELS = {
}
class OllamaInferenceAdapter(Inference):
class OllamaInferenceAdapter(Inference, Models):
def __init__(self, url: str) -> None:
self.url = url
self.formatter = ChatFormat(Tokenizer.get_instance())

View file

@@ -6,14 +6,18 @@
import logging
from typing import AsyncGenerator
from typing import AsyncGenerator, List, Optional
from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_models.sku_list import all_registered_models
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
@@ -30,26 +34,47 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
logger = logging.getLogger(__name__)
class _HfAdapter(Inference):
class _HfAdapter(Inference, ModelsProtocolPrivate):
client: AsyncInferenceClient
max_tokens: int
model_id: str
def __init__(self) -> None:
self.formatter = ChatFormat(Tokenizer.get_instance())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor()
for model in all_registered_models()
if model.huggingface_repo
}
async def register_model(self, model: ModelDef) -> None:
resolved_model = resolve_model(model.identifier)
if resolved_model is None:
raise ValueError(f"Unknown model: {model.identifier}")
raise ValueError("Model registration is not supported for HuggingFace models")
if not resolved_model.huggingface_repo:
raise ValueError(
f"Model {model.identifier} does not have a HuggingFace repo"
async def list_models(self) -> List[ModelDef]:
repo = self.model_id
identifier = self.huggingface_repo_to_llama_model_id[repo]
return [
ModelDef(
identifier=identifier,
llama_model=identifier,
metadata={
"huggingface_repo": repo,
},
)
]
if self.model_id != resolved_model.huggingface_repo:
raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
async def get_model(self, identifier: str) -> Optional[ModelDef]:
model = self.huggingface_repo_to_llama_model_id.get(self.model_id)
if model != identifier:
return None
return ModelDef(
identifier=model,
llama_model=model,
metadata={
"huggingface_repo": self.model_id,
},
)
async def shutdown(self) -> None:
pass
@@ -145,6 +170,13 @@ class _HfAdapter(Inference):
**options,
)
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:

View file

@@ -134,3 +134,10 @@ class TogetherInferenceAdapter(
"stream": request.stream,
**get_sampling_options(request),
}
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -10,6 +10,11 @@ from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis.memory_banks import MemoryBankDef
from llama_stack.apis.models import ModelDef
from llama_stack.apis.shields import ShieldDef
@json_schema_type
class Api(Enum):
@@ -28,6 +33,30 @@ class Api(Enum):
inspect = "inspect"
class ModelsProtocolPrivate(Protocol):
async def list_models(self) -> List[ModelDef]: ...
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
async def register_model(self, model: ModelDef) -> None: ...
class ShieldsProtocolPrivate(Protocol):
async def list_shields(self) -> List[ShieldDef]: ...
async def get_shield(self, identifier: str) -> Optional[ShieldDef]: ...
async def register_shield(self, shield: ShieldDef) -> None: ...
class MemoryBanksProtocolPrivate(Protocol):
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
@json_schema_type
class ProviderSpec(BaseModel):
api: Api

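These private protocols replace the public `/register` webmethods removed from `Inference`, `Safety`, and `Memory` in this commit: a provider now tells the routing table what it serves rather than the other way around. A hedged sketch of a minimal single-model provider; the class itself is hypothetical:

    from typing import List, Optional

    from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate

    class SingleModelProvider(ModelsProtocolPrivate):  # hypothetical example
        def __init__(self, model_id: str) -> None:
            self.model_id = model_id

        async def list_models(self) -> List[ModelDef]:
            return [ModelDef(identifier=self.model_id, llama_model=self.model_id)]

        async def get_model(self, identifier: str) -> Optional[ModelDef]:
            if identifier != self.model_id:
                return None
            return ModelDef(identifier=identifier, llama_model=identifier)

        async def register_model(self, model: ModelDef) -> None:
            raise ValueError("dynamic model registration is not supported")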
View file

@@ -12,6 +12,7 @@ from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
@@ -24,7 +25,7 @@ from .model_parallel import LlamaModelParallelGenerator
SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference):
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
def __init__(self, config: MetaReferenceImplConfig) -> None:
self.config = config
model = resolve_model(config.model)
@@ -39,14 +40,38 @@ class MetaReferenceInferenceImpl(Inference):
self.generator.start()
async def register_model(self, model: ModelDef) -> None:
if model.identifier != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
raise ValueError("Dynamic model registration is not supported")
async def list_models(self) -> List[ModelDef]:
return [
ModelDef(
identifier=self.model.descriptor(),
llama_model=self.model.descriptor(),
)
]
async def get_model(self, identifier: str) -> Optional[ModelDef]:
if self.model.descriptor() != identifier:
return None
return ModelDef(
identifier=self.model.descriptor(),
llama_model=self.model.descriptor(),
)
async def shutdown(self) -> None:
self.generator.stop()
def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
def chat_completion(
self,
model: str,
@@ -255,3 +280,10 @@ class MetaReferenceInferenceImpl(Inference):
stop_reason=stop_reason,
)
)
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -15,6 +15,8 @@ from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
@@ -61,7 +63,7 @@ class FaissIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class FaissMemoryImpl(Memory):
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.cache = {}
@@ -83,6 +85,16 @@ class FaissMemoryImpl(Memory):
)
self.cache[memory_bank.identifier] = index
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
banks = await self.list_memory_banks()
for bank in banks:
if bank.identifier == identifier:
return bank
return None
async def list_memory_banks(self) -> List[MemoryBankDef]:
return [i.bank for i in self.cache.values()]
async def insert_documents(
self,
bank_id: str,

View file

@@ -28,7 +28,7 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/memory/test_inference.py \
# pytest -s llama_stack/providers/tests/inference/test_inference.py \
# --tb=short --disable-warnings
# ```
@@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str):
scope="session",
params=[
{"model": Llama_8B},
{"model": Llama_3B},
# {"model": Llama_3B},
],
ids=lambda d: d["model"],
)
@@ -64,16 +64,11 @@ async def inference_settings(request):
model = request.param["model"]
impls = await resolve_impls_for_test(
Api.inference,
models=[
ModelDef(
identifier=model,
llama_model=model,
)
],
)
return {
"impl": impls[Api.inference],
"models_impl": impls[Api.models],
"common_params": {
"model": model,
"tool_choice": ToolChoice.auto,
@@ -108,6 +103,25 @@ def sample_tool_definition():
)
@pytest.mark.asyncio
async def test_model_list(inference_settings):
params = inference_settings["common_params"]
models_impl = inference_settings["models_impl"]
response = await models_impl.list_models()
assert isinstance(response, list)
assert len(response) >= 1
assert all(isinstance(model, ModelDefWithProvider) for model in response)
model_def = None
for model in response:
if model.identifier == params["model"]:
model_def = model
break
assert model_def is not None
assert model_def.identifier == params["model"]
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]

View file

@@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
@@ -30,12 +31,14 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
@pytest_asyncio.fixture(scope="session")
async def memory_impl():
async def memory_settings():
impls = await resolve_impls_for_test(
Api.memory,
memory_banks=[],
)
return impls[Api.memory]
return {
"memory_impl": impls[Api.memory],
"memory_banks_impl": impls[Api.memory_banks],
}
@pytest.fixture
@@ -64,23 +67,35 @@ def sample_documents():
]
async def register_memory_bank(memory_impl: Memory):
async def register_memory_bank(banks_impl: MemoryBanks):
bank = VectorMemoryBankDef(
identifier="test_bank",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
provider_id=os.environ["PROVIDER_ID"],
)
await memory_impl.register_memory_bank(bank)
await banks_impl.register_memory_bank(bank)
@pytest.mark.asyncio
async def test_query_documents(memory_impl, sample_documents):
async def test_banks_list(memory_settings):
banks_impl = memory_settings["memory_banks_impl"]
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert len(response) == 0
@pytest.mark.asyncio
async def test_query_documents(memory_settings, sample_documents):
memory_impl = memory_settings["memory_impl"]
banks_impl = memory_settings["memory_banks_impl"]
with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents)
await register_memory_bank(memory_impl)
await register_memory_bank(banks_impl)
await memory_impl.insert_documents("test_bank", sample_documents)
query1 = "programming language"

View file

@@ -18,9 +18,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
async def resolve_impls_for_test(
api: Api,
models: List[ModelDef] = None,
memory_banks: List[MemoryBankDef] = None,
shields: List[ShieldDef] = None,
):
if "PROVIDER_CONFIG" not in os.environ:
raise ValueError(
@@ -47,45 +44,11 @@ async def resolve_impls_for_test(
provider_id = provider["provider_id"]
print(f"No provider ID specified, picking first `{provider_id}`")
models = models or []
shields = shields or []
memory_banks = memory_banks or []
models = [
ModelDef(
**{
**m.dict(),
"provider_id": provider_id,
}
)
for m in models
]
shields = [
ShieldDef(
**{
**s.dict(),
"provider_id": provider_id,
}
)
for s in shields
]
memory_banks = [
MemoryBankDef(
**{
**m.dict(),
"provider_id": provider_id,
}
)
for m in memory_banks
]
run_config = dict(
built_at=datetime.now(),
image_name="test-fixture",
apis=[api],
providers={api.value: [Provider(**provider)]},
models=models,
memory_banks=memory_banks,
shields=shields,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls_with_routing(run_config)

View file

@@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Dict, List, Optional
from llama_models.sku_list import resolve_model
from llama_stack.apis.models import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
class ModelRegistryHelper:
class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
self.stack_to_provider_models_map = stack_to_provider_models_map
@@ -33,3 +33,15 @@ class ModelRegistryHelper:
raise ValueError(
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
)
async def list_models(self) -> List[ModelDef]:
models = []
for llama_model, provider_model in self.stack_to_provider_models_map.items():
models.append(ModelDef(identifier=llama_model, llama_model=llama_model))
return models
async def get_model(self, identifier: str) -> Optional[ModelDef]:
if identifier not in self.stack_to_provider_models_map:
return None
return ModelDef(identifier=identifier, llama_model=identifier)
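Usage sketch for the updated helper: an adapter constructs it with its stack-to-provider model map and thereby satisfies the private models protocol for free. The mapping below is illustrative, not any real adapter's:

    helper = ModelRegistryHelper(
        stack_to_provider_models_map={
            "Llama3.1-8B-Instruct": "provider-model-name",  # assumption
        }
    )
    # await helper.list_models() -> [ModelDef(identifier="Llama3.1-8B-Instruct", ...)]
    # await helper.get_model("unknown-model") -> None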