mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 15:02:37 +00:00
Another round of simplification and clarity for models/shields/memory_banks stuff
This commit is contained in:
parent
73a0a34e39
commit
b55034c0de
27 changed files with 454 additions and 444 deletions
|
@ -6,7 +6,16 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -404,6 +413,7 @@ class AgentStepResponse(BaseModel):
|
|||
step: Step
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agents(Protocol):
|
||||
@webmethod(route="/agents/create")
|
||||
async def create_agent(
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
from typing import List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -47,6 +47,7 @@ class BatchChatCompletionResponse(BaseModel):
|
|||
completion_message_batch: List[CompletionMessage]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BatchInference(Protocol):
|
||||
@webmethod(route="/batch_inference/completion")
|
||||
async def batch_completion(
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from enum import Enum
|
||||
|
||||
from typing import List, Literal, Optional, Protocol, Union
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -177,6 +177,7 @@ class ModelStore(Protocol):
|
|||
def get_model(self, identifier: str) -> ModelDef: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Inference(Protocol):
|
||||
model_store: ModelStore
|
||||
|
||||
|
@ -214,6 +215,3 @@ class Inference(Protocol):
|
|||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse: ...
|
||||
|
||||
@webmethod(route="/inference/register_model")
|
||||
async def register_model(self, model: ModelDef) -> None: ...
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List, Protocol
|
||||
from typing import Dict, List, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
@ -29,6 +29,7 @@ class HealthInfo(BaseModel):
|
|||
# TODO: add a provider level status
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/providers/list", method="GET")
|
||||
async def list_providers(self) -> Dict[str, ProviderInfo]: ...
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import List, Optional, Protocol
|
||||
from typing import List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -42,6 +42,7 @@ class MemoryBankStore(Protocol):
|
|||
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Memory(Protocol):
|
||||
memory_bank_store: MemoryBankStore
|
||||
|
||||
|
@ -55,13 +56,6 @@ class Memory(Protocol):
|
|||
ttl_seconds: Optional[int] = None,
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/memory/update")
|
||||
async def update_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/memory/query")
|
||||
async def query_documents(
|
||||
self,
|
||||
|
@ -69,20 +63,3 @@ class Memory(Protocol):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse: ...
|
||||
|
||||
@webmethod(route="/memory/documents/get", method="GET")
|
||||
async def get_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
document_ids: List[str],
|
||||
) -> List[MemoryBankDocument]: ...
|
||||
|
||||
@webmethod(route="/memory/documents/delete", method="DELETE")
|
||||
async def delete_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
document_ids: List[str],
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/memory/register_memory_bank")
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional, Protocol, Union
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -22,7 +22,8 @@ class MemoryBankType(Enum):
|
|||
|
||||
class CommonDef(BaseModel):
|
||||
identifier: str
|
||||
provider_id: Optional[str] = None
|
||||
# Hack: move this out later
|
||||
provider_id: str = ""
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -58,13 +59,20 @@ MemoryBankDef = Annotated[
|
|||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
MemoryBankDefWithProvider = MemoryBankDef
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MemoryBanks(Protocol):
|
||||
@webmethod(route="/memory_banks/list", method="GET")
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
|
||||
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/memory_banks/get", method="GET")
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
|
||||
async def get_memory_bank(
|
||||
self, identifier: str
|
||||
) -> Optional[MemoryBankDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/memory_banks/register", method="POST")
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
||||
async def register_memory_bank(
|
||||
self, memory_bank: MemoryBankDefWithProvider
|
||||
) -> None: ...
|
||||
|
|
|
@ -4,34 +4,39 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelDef(BaseModel):
|
||||
identifier: str = Field(
|
||||
description="A unique identifier for the model type",
|
||||
description="A unique name for the model type",
|
||||
)
|
||||
llama_model: str = Field(
|
||||
description="Pointer to the core Llama family model",
|
||||
description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
|
||||
)
|
||||
provider_id: Optional[str] = Field(
|
||||
default=None, description="The provider instance which serves this model"
|
||||
metadata: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this model",
|
||||
)
|
||||
# For now, we are only supporting core llama models but as soon as finetuned
|
||||
# and other custom models (for example various quantizations) are allowed, there
|
||||
# will be more metadata fields here
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelDefWithProvider(ModelDef):
|
||||
provider_id: str = Field(
|
||||
description="The provider ID for this model",
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models/list", method="GET")
|
||||
async def list_models(self) -> List[ModelDef]: ...
|
||||
async def list_models(self) -> List[ModelDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/models/get", method="GET")
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/models/register", method="POST")
|
||||
async def register_model(self, model: ModelDef) -> None: ...
|
||||
async def register_model(self, model: ModelDefWithProvider) -> None: ...
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Protocol
|
||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
@ -42,6 +42,7 @@ class ShieldStore(Protocol):
|
|||
def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Safety(Protocol):
|
||||
shield_store: ShieldStore
|
||||
|
||||
|
@ -49,6 +50,3 @@ class Safety(Protocol):
|
|||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse: ...
|
||||
|
||||
@webmethod(route="/safety/register_shield")
|
||||
async def register_shield(self, shield: ShieldDef) -> None: ...
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -26,21 +26,26 @@ class ShieldDef(BaseModel):
|
|||
type: str = Field(
|
||||
description="The type of shield this is; the value is one of the ShieldType enum"
|
||||
)
|
||||
provider_id: Optional[str] = Field(
|
||||
default=None, description="The provider instance which serves this shield"
|
||||
)
|
||||
params: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional parameters needed for this shield",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldDefWithProvider(ShieldDef):
|
||||
provider_id: str = Field(
|
||||
description="The provider ID for this shield type",
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields/list", method="GET")
|
||||
async def list_shields(self) -> List[ShieldDef]: ...
|
||||
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/shields/get", method="GET")
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: ...
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/shields/register", method="POST")
|
||||
async def register_shield(self, shield: ShieldDef) -> None: ...
|
||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional, Protocol, Union
|
||||
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -123,6 +123,7 @@ Event = Annotated[
|
|||
]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Telemetry(Protocol):
|
||||
@webmethod(route="/telemetry/log_event")
|
||||
async def log_event(self, event: Event) -> None: ...
|
||||
|
|
|
@ -7,17 +7,7 @@ import textwrap
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_models.sku_list import (
|
||||
llama3_1_family,
|
||||
llama3_2_family,
|
||||
llama3_family,
|
||||
resolve_model,
|
||||
safety_models,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit.validation import Validator
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.distribution import (
|
||||
|
@ -33,11 +23,6 @@ from llama_stack.apis.shields import * # noqa: F403
|
|||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
|
||||
|
||||
ALLOWED_MODELS = (
|
||||
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
|
||||
)
|
||||
|
||||
|
||||
def configure_single_provider(
|
||||
registry: Dict[str, ProviderSpec], provider: Provider
|
||||
) -> Provider:
|
||||
|
@ -133,137 +118,10 @@ def configure_api_providers(
|
|||
|
||||
config.providers[api_str] = updated_providers
|
||||
|
||||
if is_nux:
|
||||
print(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
=========================================================================================
|
||||
Now let's configure the `objects` you will be serving via the stack. These are:
|
||||
|
||||
- Models: the Llama model SKUs you expect to inference (e.g., Llama3.2-1B-Instruct)
|
||||
- Shields: the safety models you expect to use for safety (e.g., Llama-Guard-3-1B)
|
||||
- Memory Banks: the memory banks you expect to use for memory (e.g., Vector stores)
|
||||
|
||||
This wizard will guide you through setting up one of each of these objects. You can
|
||||
always add more later by editing the run.yaml file.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
object_types = {
|
||||
"models": (ModelDef, configure_models, "inference"),
|
||||
"shields": (ShieldDef, configure_shields, "safety"),
|
||||
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
|
||||
}
|
||||
safety_providers = config.providers.get("safety", [])
|
||||
|
||||
for otype, (odef, config_method, api_str) in object_types.items():
|
||||
existing_objects = getattr(config, otype)
|
||||
|
||||
if existing_objects:
|
||||
cprint(
|
||||
f"{len(existing_objects)} {otype} exist. Skipping...",
|
||||
"blue",
|
||||
attrs=["bold"],
|
||||
)
|
||||
updated_objects = existing_objects
|
||||
else:
|
||||
providers = config.providers.get(api_str, [])
|
||||
if not providers:
|
||||
updated_objects = []
|
||||
else:
|
||||
# we are newly configuring this API
|
||||
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
|
||||
updated_objects = config_method(
|
||||
config.providers[api_str], safety_providers
|
||||
)
|
||||
|
||||
setattr(config, otype, updated_objects)
|
||||
print("")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
|
||||
if not safety_providers:
|
||||
return None
|
||||
|
||||
provider = safety_providers[0]
|
||||
assert provider.provider_type == "meta-reference"
|
||||
|
||||
cfg = provider.config["llama_guard_shield"]
|
||||
if not cfg:
|
||||
return None
|
||||
return cfg["model"]
|
||||
|
||||
|
||||
def configure_models(
|
||||
providers: List[Provider], safety_providers: List[Provider]
|
||||
) -> List[ModelDef]:
|
||||
model = prompt(
|
||||
"> Please enter the model you want to serve: ",
|
||||
default="Llama3.2-1B-Instruct",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: resolve_model(x) is not None,
|
||||
error_message="Model must be: {}".format(
|
||||
[x.descriptor() for x in ALLOWED_MODELS]
|
||||
),
|
||||
),
|
||||
)
|
||||
model = ModelDef(
|
||||
identifier=model,
|
||||
llama_model=model,
|
||||
provider_id=providers[0].provider_id,
|
||||
)
|
||||
|
||||
ret = [model]
|
||||
if llama_guard := get_llama_guard_model(safety_providers):
|
||||
ret.append(
|
||||
ModelDef(
|
||||
identifier=llama_guard,
|
||||
llama_model=llama_guard,
|
||||
provider_id=providers[0].provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def configure_shields(
|
||||
providers: List[Provider], safety_providers: List[Provider]
|
||||
) -> List[ShieldDef]:
|
||||
if get_llama_guard_model(safety_providers):
|
||||
return [
|
||||
ShieldDef(
|
||||
identifier="llama_guard",
|
||||
type="llama_guard",
|
||||
provider_id=providers[0].provider_id,
|
||||
params={},
|
||||
)
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def configure_memory_banks(
|
||||
providers: List[Provider], safety_providers: List[Provider]
|
||||
) -> List[MemoryBankDef]:
|
||||
bank_name = prompt(
|
||||
"> Please enter a name for your memory bank: ",
|
||||
default="my-memory-bank",
|
||||
)
|
||||
|
||||
return [
|
||||
VectorMemoryBankDef(
|
||||
identifier=bank_name,
|
||||
provider_id=providers[0].provider_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def upgrade_from_routing_table_to_registry(
|
||||
def upgrade_from_routing_table(
|
||||
config_dict: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
def get_providers(entries):
|
||||
|
@ -281,57 +139,12 @@ def upgrade_from_routing_table_to_registry(
|
|||
]
|
||||
|
||||
providers_by_api = {}
|
||||
models = []
|
||||
shields = []
|
||||
memory_banks = []
|
||||
|
||||
routing_table = config_dict.get("routing_table", {})
|
||||
for api_str, entries in routing_table.items():
|
||||
providers = get_providers(entries)
|
||||
providers_by_api[api_str] = providers
|
||||
|
||||
if api_str == "inference":
|
||||
for entry, provider in zip(entries, providers):
|
||||
key = entry["routing_key"]
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
for key in keys:
|
||||
models.append(
|
||||
ModelDef(
|
||||
identifier=key,
|
||||
provider_id=provider.provider_id,
|
||||
llama_model=key,
|
||||
)
|
||||
)
|
||||
elif api_str == "safety":
|
||||
for entry, provider in zip(entries, providers):
|
||||
key = entry["routing_key"]
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
for key in keys:
|
||||
shields.append(
|
||||
ShieldDef(
|
||||
identifier=key,
|
||||
type=ShieldType.llama_guard.value,
|
||||
provider_id=provider.provider_id,
|
||||
)
|
||||
)
|
||||
elif api_str == "memory":
|
||||
for entry, provider in zip(entries, providers):
|
||||
key = entry["routing_key"]
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
for key in keys:
|
||||
# we currently only support Vector memory banks so this is OK
|
||||
memory_banks.append(
|
||||
VectorMemoryBankDef(
|
||||
identifier=key,
|
||||
provider_id=provider.provider_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
)
|
||||
config_dict["models"] = models
|
||||
config_dict["shields"] = shields
|
||||
config_dict["memory_banks"] = memory_banks
|
||||
|
||||
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
|
||||
if provider_map:
|
||||
for api_str, provider in provider_map.items():
|
||||
|
@ -361,9 +174,9 @@ def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfi
|
|||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
return StackRunConfig(**config_dict)
|
||||
|
||||
if "models" not in config_dict:
|
||||
if "routing_table" in config_dict:
|
||||
print("Upgrading config...")
|
||||
config_dict = upgrade_from_routing_table_to_registry(config_dict)
|
||||
config_dict = upgrade_from_routing_table(config_dict)
|
||||
|
||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
config_dict["built_at"] = datetime.now().isoformat()
|
||||
|
|
|
@ -32,6 +32,12 @@ RoutableObject = Union[
|
|||
MemoryBankDef,
|
||||
]
|
||||
|
||||
RoutableObjectWithProvider = Union[
|
||||
ModelDefWithProvider,
|
||||
ShieldDef,
|
||||
MemoryBankDef,
|
||||
]
|
||||
|
||||
RoutedProtocol = Union[
|
||||
Inference,
|
||||
Safety,
|
||||
|
@ -63,7 +69,6 @@ class RoutingTableProviderSpec(ProviderSpec):
|
|||
docker_image: Optional[str] = None
|
||||
|
||||
router_api: Api
|
||||
registry: List[RoutableObject]
|
||||
module: str
|
||||
pip_packages: List[str] = Field(default_factory=list)
|
||||
|
||||
|
@ -121,25 +126,6 @@ can be instantiated multiple times (with different configs) if necessary.
|
|||
""",
|
||||
)
|
||||
|
||||
models: List[ModelDef] = Field(
|
||||
description="""
|
||||
List of model definitions to serve. This list may get extended by
|
||||
/models/register API calls at runtime.
|
||||
""",
|
||||
)
|
||||
shields: List[ShieldDef] = Field(
|
||||
description="""
|
||||
List of shield definitions to serve. This list may get extended by
|
||||
/shields/register API calls at runtime.
|
||||
""",
|
||||
)
|
||||
memory_banks: List[MemoryBankDef] = Field(
|
||||
description="""
|
||||
List of memory bank definitions to serve. This list may get extended by
|
||||
/memory_banks/register API calls at runtime.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class BuildConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||
|
|
|
@ -4,10 +4,22 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
|
@ -15,6 +27,28 @@ from llama_stack.distribution.distribution import (
|
|||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
|
||||
def api_protocol_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
Api.inspect: Inspect,
|
||||
Api.memory: Memory,
|
||||
Api.memory_banks: MemoryBanks,
|
||||
Api.models: Models,
|
||||
Api.safety: Safety,
|
||||
Api.shields: Shields,
|
||||
Api.telemetry: Telemetry,
|
||||
}
|
||||
|
||||
|
||||
def additional_protocols_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.inference: ModelsProtocolPrivate,
|
||||
Api.memory: MemoryBanksProtocolPrivate,
|
||||
Api.safety: ShieldsProtocolPrivate,
|
||||
}
|
||||
|
||||
|
||||
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
|
||||
class ProviderWithSpec(Provider):
|
||||
spec: ProviderSpec
|
||||
|
@ -73,17 +107,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
|
||||
available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
|
||||
|
||||
inner_deps = []
|
||||
registry = getattr(run_config, info.routing_table_api.value)
|
||||
for entry in registry:
|
||||
if entry.provider_id not in available_providers:
|
||||
raise ValueError(
|
||||
f"Provider `{entry.provider_id}` not found. Available providers: {list(available_providers.keys())}"
|
||||
)
|
||||
|
||||
provider = available_providers[entry.provider_id]
|
||||
inner_deps.extend(provider.spec.api_dependencies)
|
||||
|
||||
providers_with_specs[info.routing_table_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
|
@ -92,13 +115,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
spec=RoutingTableProviderSpec(
|
||||
api=info.routing_table_api,
|
||||
router_api=info.router_api,
|
||||
registry=registry,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=inner_deps,
|
||||
deps__=(
|
||||
[x.value for x in inner_deps]
|
||||
+ [f"inner-{info.router_api.value}"]
|
||||
),
|
||||
api_dependencies=[],
|
||||
deps__=([f"inner-{info.router_api.value}"]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
@ -212,6 +231,9 @@ async def instantiate_provider(
|
|||
deps: Dict[str, Any],
|
||||
inner_impls: Dict[str, Any],
|
||||
):
|
||||
protocols = api_protocol_map()
|
||||
additional_protocols = additional_protocols_map()
|
||||
|
||||
provider_spec = provider.spec
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
|
@ -234,7 +256,7 @@ async def instantiate_provider(
|
|||
method = "get_routing_table_impl"
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, provider_spec.registry, inner_impls, deps]
|
||||
args = [provider_spec.api, inner_impls, deps]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
|
@ -247,4 +269,55 @@ async def instantiate_provider(
|
|||
impl.__provider_id__ = provider.provider_id
|
||||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
check_protocol_compliance(impl, protocols[provider_spec.api])
|
||||
if (
|
||||
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
||||
and provider_spec.api in additional_protocols
|
||||
):
|
||||
additional_api = additional_protocols[provider_spec.api]
|
||||
check_protocol_compliance(impl, additional_api)
|
||||
|
||||
return impl
|
||||
|
||||
|
||||
def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
||||
missing_methods = []
|
||||
|
||||
mro = type(obj).__mro__
|
||||
for name, value in inspect.getmembers(protocol):
|
||||
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
|
||||
if not hasattr(obj, name):
|
||||
missing_methods.append((name, "missing"))
|
||||
elif not callable(getattr(obj, name)):
|
||||
missing_methods.append((name, "not_callable"))
|
||||
else:
|
||||
# Check if the method signatures are compatible
|
||||
obj_method = getattr(obj, name)
|
||||
proto_sig = inspect.signature(value)
|
||||
obj_sig = inspect.signature(obj_method)
|
||||
|
||||
proto_params = set(proto_sig.parameters)
|
||||
proto_params.discard("self")
|
||||
obj_params = set(obj_sig.parameters)
|
||||
obj_params.discard("self")
|
||||
if not (proto_params <= obj_params):
|
||||
print(
|
||||
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
|
||||
)
|
||||
missing_methods.append((name, "signature_mismatch"))
|
||||
else:
|
||||
# Check if the method is actually implemented in the class
|
||||
method_owner = next(
|
||||
(cls for cls in mro if name in cls.__dict__), None
|
||||
)
|
||||
if (
|
||||
method_owner is None
|
||||
or method_owner.__name__ == protocol.__name__
|
||||
):
|
||||
missing_methods.append((name, "not_actually_implemented"))
|
||||
|
||||
if missing_methods:
|
||||
raise ValueError(
|
||||
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from .routing_tables import (
|
||||
|
@ -16,7 +16,6 @@ from .routing_tables import (
|
|||
|
||||
async def get_routing_table_impl(
|
||||
api: Api,
|
||||
registry: List[RoutableObject],
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
_deps,
|
||||
) -> Any:
|
||||
|
@ -28,7 +27,7 @@ async def get_routing_table_impl(
|
|||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_tables[api.value](registry, impls_by_provider_id)
|
||||
impl = api_to_tables[api.value](impls_by_provider_id)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
|
@ -29,115 +29,145 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
|||
await p.register_memory_bank(obj)
|
||||
|
||||
|
||||
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
||||
|
||||
|
||||
# TODO: this routing table maintains state in memory purely. We need to
|
||||
# add persistence to it when we add dynamic registration of objects.
|
||||
class CommonRoutingTableImpl(RoutingTable):
|
||||
def __init__(
|
||||
self,
|
||||
registry: List[RoutableObject],
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
) -> None:
|
||||
for obj in registry:
|
||||
if obj.provider_id not in impls_by_provider_id:
|
||||
print(f"{impls_by_provider_id=}")
|
||||
raise ValueError(
|
||||
f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
|
||||
)
|
||||
|
||||
self.impls_by_provider_id = impls_by_provider_id
|
||||
self.registry = registry
|
||||
|
||||
for p in self.impls_by_provider_id.values():
|
||||
async def initialize(self) -> None:
|
||||
self.registry: Registry = {}
|
||||
|
||||
def add_objects(objs: List[RoutableObjectWithProvider]) -> None:
|
||||
for obj in objs:
|
||||
if obj.identifier not in self.registry:
|
||||
self.registry[obj.identifier] = []
|
||||
|
||||
self.registry[obj.identifier].append(obj)
|
||||
|
||||
for pid, p in self.impls_by_provider_id.items():
|
||||
api = get_impl_api(p)
|
||||
if api == Api.inference:
|
||||
p.model_store = self
|
||||
models = await p.list_models()
|
||||
add_objects(
|
||||
[ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models]
|
||||
)
|
||||
|
||||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
shields = await p.list_shields()
|
||||
add_objects(
|
||||
[
|
||||
ShieldDefWithProvider(**s.dict(), provider_id=pid)
|
||||
for s in shields
|
||||
]
|
||||
)
|
||||
|
||||
elif api == Api.memory:
|
||||
p.memory_bank_store = self
|
||||
memory_banks = await p.list_memory_banks()
|
||||
|
||||
self.routing_key_to_object = {}
|
||||
for obj in self.registry:
|
||||
self.routing_key_to_object[obj.identifier] = obj
|
||||
# do in-memory updates due to pesky Annotated unions
|
||||
for m in memory_banks:
|
||||
m.provider_id = pid
|
||||
|
||||
async def initialize(self) -> None:
|
||||
for obj in self.registry:
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
await register_object_with_provider(obj, p)
|
||||
add_objects(memory_banks)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str) -> Any:
|
||||
if routing_key not in self.routing_key_to_object:
|
||||
def get_provider_impl(
|
||||
self, routing_key: str, provider_id: Optional[str] = None
|
||||
) -> Any:
|
||||
if routing_key not in self.registry:
|
||||
raise ValueError(f"`{routing_key}` not registered")
|
||||
|
||||
obj = self.routing_key_to_object[routing_key]
|
||||
objs = self.registry[routing_key]
|
||||
for obj in objs:
|
||||
if not provider_id or provider_id == obj.provider_id:
|
||||
return self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||
|
||||
def get_object_by_identifier(
|
||||
self, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
objs = self.registry.get(identifier, [])
|
||||
if not objs:
|
||||
return None
|
||||
|
||||
# kind of ill-defined behavior here, but we'll just return the first one
|
||||
return objs[0]
|
||||
|
||||
async def register_object(self, obj: RoutableObjectWithProvider):
|
||||
entries = self.registry.get(obj.identifier, [])
|
||||
for entry in entries:
|
||||
if entry.provider_id == obj.provider_id:
|
||||
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
|
||||
return
|
||||
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
return self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
|
||||
for obj in self.registry:
|
||||
if obj.identifier == identifier:
|
||||
return obj
|
||||
return None
|
||||
|
||||
async def register_object(self, obj: RoutableObject):
|
||||
if obj.identifier in self.routing_key_to_object:
|
||||
print(f"`{obj.identifier}` is already registered")
|
||||
return
|
||||
|
||||
if not obj.provider_id:
|
||||
provider_ids = list(self.impls_by_provider_id.keys())
|
||||
if not provider_ids:
|
||||
raise ValueError("No providers found")
|
||||
|
||||
print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
|
||||
obj.provider_id = provider_ids[0]
|
||||
else:
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
await register_object_with_provider(obj, p)
|
||||
|
||||
self.routing_key_to_object[obj.identifier] = obj
|
||||
self.registry.append(obj)
|
||||
if obj.identifier not in self.registry:
|
||||
self.registry[obj.identifier] = []
|
||||
self.registry[obj.identifier].append(obj)
|
||||
|
||||
# TODO: persist this to a store
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
return self.registry
|
||||
async def list_models(self) -> List[ModelDefWithProvider]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
|
||||
return self.get_object_by_identifier(identifier)
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
async def register_model(self, model: ModelDefWithProvider) -> None:
|
||||
await self.register_object(model)
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return self.registry
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
||||
return self.get_object_by_identifier(shield_type)
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
|
||||
await self.register_object(shield)
|
||||
|
||||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
return self.registry
|
||||
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
|
||||
async def get_memory_bank(
|
||||
self, identifier: str
|
||||
) -> Optional[MemoryBankDefWithProvider]:
|
||||
return self.get_object_by_identifier(identifier)
|
||||
|
||||
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
|
||||
await self.register_object(bank)
|
||||
async def register_memory_bank(
|
||||
self, memory_bank: MemoryBankDefWithProvider
|
||||
) -> None:
|
||||
await self.register_object(memory_bank)
|
||||
|
|
|
@ -9,15 +9,7 @@ from typing import Dict, List
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.distribution.resolver import api_protocol_map
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
@ -31,18 +23,7 @@ class ApiEndpoint(BaseModel):
|
|||
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||
apis = {}
|
||||
|
||||
protocols = {
|
||||
Api.inference: Inference,
|
||||
Api.safety: Safety,
|
||||
Api.agents: Agents,
|
||||
Api.memory: Memory,
|
||||
Api.telemetry: Telemetry,
|
||||
Api.models: Models,
|
||||
Api.shields: Shields,
|
||||
Api.memory_banks: MemoryBanks,
|
||||
Api.inspect: Inspect,
|
||||
}
|
||||
|
||||
protocols = api_protocol_map()
|
||||
for api, protocol in protocols.items():
|
||||
endpoints = []
|
||||
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||
|
|
|
@ -121,3 +121,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
"stream": request.stream,
|
||||
**options,
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -15,6 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
|||
from ollama import AsyncClient
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
|
@ -35,7 +36,7 @@ OLLAMA_SUPPORTED_MODELS = {
|
|||
}
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference):
|
||||
class OllamaInferenceAdapter(Inference, Models):
|
||||
def __init__(self, url: str) -> None:
|
||||
self.url = url
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
|
|
@ -6,14 +6,18 @@
|
|||
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
|
@ -30,26 +34,47 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _HfAdapter(Inference):
|
||||
class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||
client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
self.huggingface_repo_to_llama_model_id = {
|
||||
model.huggingface_repo: model.descriptor()
|
||||
for model in all_registered_models()
|
||||
if model.huggingface_repo
|
||||
}
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
resolved_model = resolve_model(model.identifier)
|
||||
if resolved_model is None:
|
||||
raise ValueError(f"Unknown model: {model.identifier}")
|
||||
raise ValueError("Model registration is not supported for HuggingFace models")
|
||||
|
||||
if not resolved_model.huggingface_repo:
|
||||
raise ValueError(
|
||||
f"Model {model.identifier} does not have a HuggingFace repo"
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
repo = self.model_id
|
||||
identifier = self.huggingface_repo_to_llama_model_id[repo]
|
||||
return [
|
||||
ModelDef(
|
||||
identifier=identifier,
|
||||
llama_model=identifier,
|
||||
metadata={
|
||||
"huggingface_repo": repo,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
if self.model_id != resolved_model.huggingface_repo:
|
||||
raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
model = self.huggingface_repo_to_llama_model_id.get(self.model_id)
|
||||
if model != identifier:
|
||||
return None
|
||||
|
||||
return ModelDef(
|
||||
identifier=model,
|
||||
llama_model=model,
|
||||
metadata={
|
||||
"huggingface_repo": self.model_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
@ -145,6 +170,13 @@ class _HfAdapter(Inference):
|
|||
**options,
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
|
|
|
@ -134,3 +134,10 @@ class TogetherInferenceAdapter(
|
|||
"stream": request.stream,
|
||||
**get_sampling_options(request),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -10,6 +10,11 @@ from typing import Any, List, Optional, Protocol
|
|||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.memory_banks import MemoryBankDef
|
||||
|
||||
from llama_stack.apis.models import ModelDef
|
||||
from llama_stack.apis.shields import ShieldDef
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Api(Enum):
|
||||
|
@ -28,6 +33,30 @@ class Api(Enum):
|
|||
inspect = "inspect"
|
||||
|
||||
|
||||
class ModelsProtocolPrivate(Protocol):
|
||||
async def list_models(self) -> List[ModelDef]: ...
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None: ...
|
||||
|
||||
|
||||
class ShieldsProtocolPrivate(Protocol):
|
||||
async def list_shields(self) -> List[ShieldDef]: ...
|
||||
|
||||
async def get_shield(self, identifier: str) -> Optional[ShieldDef]: ...
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None: ...
|
||||
|
||||
|
||||
class MemoryBanksProtocolPrivate(Protocol):
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
|
||||
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
|
||||
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderSpec(BaseModel):
|
||||
api: Api
|
||||
|
|
|
@ -12,6 +12,7 @@ from llama_models.sku_list import resolve_model
|
|||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
)
|
||||
|
@ -24,7 +25,7 @@ from .model_parallel import LlamaModelParallelGenerator
|
|||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
class MetaReferenceInferenceImpl(Inference):
|
||||
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||
def __init__(self, config: MetaReferenceImplConfig) -> None:
|
||||
self.config = config
|
||||
model = resolve_model(config.model)
|
||||
|
@ -39,14 +40,38 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
self.generator.start()
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
if model.identifier != self.model.descriptor():
|
||||
raise RuntimeError(
|
||||
f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
|
||||
raise ValueError("Dynamic model registration is not supported")
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
return [
|
||||
ModelDef(
|
||||
identifier=self.model.descriptor(),
|
||||
llama_model=self.model.descriptor(),
|
||||
)
|
||||
]
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
if self.model.descriptor() != identifier:
|
||||
return None
|
||||
|
||||
return ModelDef(
|
||||
identifier=self.model.descriptor(),
|
||||
llama_model=self.model.descriptor(),
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.generator.stop()
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -255,3 +280,10 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -15,6 +15,8 @@ from numpy.typing import NDArray
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
BankWithIndex,
|
||||
|
@ -61,7 +63,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class FaissMemoryImpl(Memory):
|
||||
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, config: FaissImplConfig) -> None:
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
|
@ -83,6 +85,16 @@ class FaissMemoryImpl(Memory):
|
|||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
|
||||
banks = await self.list_memory_banks()
|
||||
for bank in banks:
|
||||
if bank.identifier == identifier:
|
||||
return bank
|
||||
return None
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
|
|
|
@ -28,7 +28,7 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
|||
# ```bash
|
||||
# PROVIDER_ID=<your_provider> \
|
||||
# PROVIDER_CONFIG=provider_config.yaml \
|
||||
# pytest -s llama_stack/providers/tests/memory/test_inference.py \
|
||||
# pytest -s llama_stack/providers/tests/inference/test_inference.py \
|
||||
# --tb=short --disable-warnings
|
||||
# ```
|
||||
|
||||
|
@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str):
|
|||
scope="session",
|
||||
params=[
|
||||
{"model": Llama_8B},
|
||||
{"model": Llama_3B},
|
||||
# {"model": Llama_3B},
|
||||
],
|
||||
ids=lambda d: d["model"],
|
||||
)
|
||||
|
@ -64,16 +64,11 @@ async def inference_settings(request):
|
|||
model = request.param["model"]
|
||||
impls = await resolve_impls_for_test(
|
||||
Api.inference,
|
||||
models=[
|
||||
ModelDef(
|
||||
identifier=model,
|
||||
llama_model=model,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return {
|
||||
"impl": impls[Api.inference],
|
||||
"models_impl": impls[Api.models],
|
||||
"common_params": {
|
||||
"model": model,
|
||||
"tool_choice": ToolChoice.auto,
|
||||
|
@ -108,6 +103,25 @@ def sample_tool_definition():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_list(inference_settings):
|
||||
params = inference_settings["common_params"]
|
||||
models_impl = inference_settings["models_impl"]
|
||||
response = await models_impl.list_models()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) >= 1
|
||||
assert all(isinstance(model, ModelDefWithProvider) for model in response)
|
||||
|
||||
model_def = None
|
||||
for model in response:
|
||||
if model.identifier == params["model"]:
|
||||
model_def = model
|
||||
break
|
||||
|
||||
assert model_def is not None
|
||||
assert model_def.identifier == params["model"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
|
||||
inference_impl = inference_settings["impl"]
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
@ -30,12 +31,14 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
|||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def memory_impl():
|
||||
async def memory_settings():
|
||||
impls = await resolve_impls_for_test(
|
||||
Api.memory,
|
||||
memory_banks=[],
|
||||
)
|
||||
return impls[Api.memory]
|
||||
return {
|
||||
"memory_impl": impls[Api.memory],
|
||||
"memory_banks_impl": impls[Api.memory_banks],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -64,23 +67,35 @@ def sample_documents():
|
|||
]
|
||||
|
||||
|
||||
async def register_memory_bank(memory_impl: Memory):
|
||||
async def register_memory_bank(banks_impl: MemoryBanks):
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
provider_id=os.environ["PROVIDER_ID"],
|
||||
)
|
||||
|
||||
await memory_impl.register_memory_bank(bank)
|
||||
await banks_impl.register_memory_bank(bank)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(memory_impl, sample_documents):
|
||||
async def test_banks_list(memory_settings):
|
||||
banks_impl = memory_settings["memory_banks_impl"]
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(memory_settings, sample_documents):
|
||||
memory_impl = memory_settings["memory_impl"]
|
||||
banks_impl = memory_settings["memory_banks_impl"]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
|
||||
await register_memory_bank(memory_impl)
|
||||
await register_memory_bank(banks_impl)
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
|
||||
query1 = "programming language"
|
||||
|
|
|
@ -18,9 +18,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
|
|||
|
||||
async def resolve_impls_for_test(
|
||||
api: Api,
|
||||
models: List[ModelDef] = None,
|
||||
memory_banks: List[MemoryBankDef] = None,
|
||||
shields: List[ShieldDef] = None,
|
||||
):
|
||||
if "PROVIDER_CONFIG" not in os.environ:
|
||||
raise ValueError(
|
||||
|
@ -47,45 +44,11 @@ async def resolve_impls_for_test(
|
|||
provider_id = provider["provider_id"]
|
||||
print(f"No provider ID specified, picking first `{provider_id}`")
|
||||
|
||||
models = models or []
|
||||
shields = shields or []
|
||||
memory_banks = memory_banks or []
|
||||
|
||||
models = [
|
||||
ModelDef(
|
||||
**{
|
||||
**m.dict(),
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
for m in models
|
||||
]
|
||||
shields = [
|
||||
ShieldDef(
|
||||
**{
|
||||
**s.dict(),
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
for s in shields
|
||||
]
|
||||
memory_banks = [
|
||||
MemoryBankDef(
|
||||
**{
|
||||
**m.dict(),
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
for m in memory_banks
|
||||
]
|
||||
run_config = dict(
|
||||
built_at=datetime.now(),
|
||||
image_name="test-fixture",
|
||||
apis=[api],
|
||||
providers={api.value: [Provider(**provider)]},
|
||||
models=models,
|
||||
memory_banks=memory_banks,
|
||||
shields=shields,
|
||||
)
|
||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||
impls = await resolve_impls_with_routing(run_config)
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
|
||||
|
||||
class ModelRegistryHelper:
|
||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||
|
||||
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
||||
self.stack_to_provider_models_map = stack_to_provider_models_map
|
||||
|
@ -33,3 +33,15 @@ class ModelRegistryHelper:
|
|||
raise ValueError(
|
||||
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
|
||||
)
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
models = []
|
||||
for llama_model, provider_model in self.stack_to_provider_models_map.items():
|
||||
models.append(ModelDef(identifier=llama_model, llama_model=llama_model))
|
||||
return models
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
if identifier not in self.stack_to_provider_models_map:
|
||||
return None
|
||||
|
||||
return ModelDef(identifier=identifier, llama_model=identifier)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue