mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Another round of simplification and clarity for models/shields/memory_banks stuff
This commit is contained in:
parent
73a0a34e39
commit
b55034c0de
27 changed files with 454 additions and 444 deletions
|
@ -6,7 +6,16 @@
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Protocol,
|
||||||
|
runtime_checkable,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
@ -404,6 +413,7 @@ class AgentStepResponse(BaseModel):
|
||||||
step: Step
|
step: Step
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class Agents(Protocol):
|
class Agents(Protocol):
|
||||||
@webmethod(route="/agents/create")
|
@webmethod(route="/agents/create")
|
||||||
async def create_agent(
|
async def create_agent(
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Optional, Protocol
|
from typing import List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
@ -47,6 +47,7 @@ class BatchChatCompletionResponse(BaseModel):
|
||||||
completion_message_batch: List[CompletionMessage]
|
completion_message_batch: List[CompletionMessage]
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class BatchInference(Protocol):
|
class BatchInference(Protocol):
|
||||||
@webmethod(route="/batch_inference/completion")
|
@webmethod(route="/batch_inference/completion")
|
||||||
async def batch_completion(
|
async def batch_completion(
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from typing import List, Literal, Optional, Protocol, Union
|
from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
@ -177,6 +177,7 @@ class ModelStore(Protocol):
|
||||||
def get_model(self, identifier: str) -> ModelDef: ...
|
def get_model(self, identifier: str) -> ModelDef: ...
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class Inference(Protocol):
|
class Inference(Protocol):
|
||||||
model_store: ModelStore
|
model_store: ModelStore
|
||||||
|
|
||||||
|
@ -214,6 +215,3 @@ class Inference(Protocol):
|
||||||
model: str,
|
model: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse: ...
|
) -> EmbeddingsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/inference/register_model")
|
|
||||||
async def register_model(self, model: ModelDef) -> None: ...
|
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Dict, List, Protocol
|
from typing import Dict, List, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -29,6 +29,7 @@ class HealthInfo(BaseModel):
|
||||||
# TODO: add a provider level status
|
# TODO: add a provider level status
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class Inspect(Protocol):
|
class Inspect(Protocol):
|
||||||
@webmethod(route="/providers/list", method="GET")
|
@webmethod(route="/providers/list", method="GET")
|
||||||
async def list_providers(self) -> Dict[str, ProviderInfo]: ...
|
async def list_providers(self) -> Dict[str, ProviderInfo]: ...
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from typing import List, Optional, Protocol
|
from typing import List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
@ -42,6 +42,7 @@ class MemoryBankStore(Protocol):
|
||||||
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
|
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class Memory(Protocol):
|
class Memory(Protocol):
|
||||||
memory_bank_store: MemoryBankStore
|
memory_bank_store: MemoryBankStore
|
||||||
|
|
||||||
|
@ -55,13 +56,6 @@ class Memory(Protocol):
|
||||||
ttl_seconds: Optional[int] = None,
|
ttl_seconds: Optional[int] = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/memory/update")
|
|
||||||
async def update_documents(
|
|
||||||
self,
|
|
||||||
bank_id: str,
|
|
||||||
documents: List[MemoryBankDocument],
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/memory/query")
|
@webmethod(route="/memory/query")
|
||||||
async def query_documents(
|
async def query_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -69,20 +63,3 @@ class Memory(Protocol):
|
||||||
query: InterleavedTextMedia,
|
query: InterleavedTextMedia,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryDocumentsResponse: ...
|
) -> QueryDocumentsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/memory/documents/get", method="GET")
|
|
||||||
async def get_documents(
|
|
||||||
self,
|
|
||||||
bank_id: str,
|
|
||||||
document_ids: List[str],
|
|
||||||
) -> List[MemoryBankDocument]: ...
|
|
||||||
|
|
||||||
@webmethod(route="/memory/documents/delete", method="DELETE")
|
|
||||||
async def delete_documents(
|
|
||||||
self,
|
|
||||||
bank_id: str,
|
|
||||||
document_ids: List[str],
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/memory/register_memory_bank")
|
|
||||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Literal, Optional, Protocol, Union
|
from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
@ -22,7 +22,8 @@ class MemoryBankType(Enum):
|
||||||
|
|
||||||
class CommonDef(BaseModel):
|
class CommonDef(BaseModel):
|
||||||
identifier: str
|
identifier: str
|
||||||
provider_id: Optional[str] = None
|
# Hack: move this out later
|
||||||
|
provider_id: str = ""
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -58,13 +59,20 @@ MemoryBankDef = Annotated[
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
MemoryBankDefWithProvider = MemoryBankDef
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class MemoryBanks(Protocol):
|
class MemoryBanks(Protocol):
|
||||||
@webmethod(route="/memory_banks/list", method="GET")
|
@webmethod(route="/memory_banks/list", method="GET")
|
||||||
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
|
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/get", method="GET")
|
@webmethod(route="/memory_banks/get", method="GET")
|
||||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
|
async def get_memory_bank(
|
||||||
|
self, identifier: str
|
||||||
|
) -> Optional[MemoryBankDefWithProvider]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/register", method="POST")
|
@webmethod(route="/memory_banks/register", method="POST")
|
||||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
async def register_memory_bank(
|
||||||
|
self, memory_bank: MemoryBankDefWithProvider
|
||||||
|
) -> None: ...
|
||||||
|
|
|
@ -4,34 +4,39 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Optional, Protocol
|
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ModelDef(BaseModel):
|
class ModelDef(BaseModel):
|
||||||
identifier: str = Field(
|
identifier: str = Field(
|
||||||
description="A unique identifier for the model type",
|
description="A unique name for the model type",
|
||||||
)
|
)
|
||||||
llama_model: str = Field(
|
llama_model: str = Field(
|
||||||
description="Pointer to the core Llama family model",
|
description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
|
||||||
)
|
)
|
||||||
provider_id: Optional[str] = Field(
|
metadata: Dict[str, Any] = Field(
|
||||||
default=None, description="The provider instance which serves this model"
|
default_factory=dict,
|
||||||
|
description="Any additional metadata for this model",
|
||||||
)
|
)
|
||||||
# For now, we are only supporting core llama models but as soon as finetuned
|
|
||||||
# and other custom models (for example various quantizations) are allowed, there
|
|
||||||
# will be more metadata fields here
|
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ModelDefWithProvider(ModelDef):
|
||||||
|
provider_id: str = Field(
|
||||||
|
description="The provider ID for this model",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class Models(Protocol):
|
class Models(Protocol):
|
||||||
@webmethod(route="/models/list", method="GET")
|
@webmethod(route="/models/list", method="GET")
|
||||||
async def list_models(self) -> List[ModelDef]: ...
|
async def list_models(self) -> List[ModelDefWithProvider]: ...
|
||||||
|
|
||||||
@webmethod(route="/models/get", method="GET")
|
@webmethod(route="/models/get", method="GET")
|
||||||
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
|
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: ...
|
||||||
|
|
||||||
@webmethod(route="/models/register", method="POST")
|
@webmethod(route="/models/register", method="POST")
|
||||||
async def register_model(self, model: ModelDef) -> None: ...
|
async def register_model(self, model: ModelDefWithProvider) -> None: ...
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Protocol
|
from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -42,6 +42,7 @@ class ShieldStore(Protocol):
|
||||||
def get_shield(self, identifier: str) -> ShieldDef: ...
|
def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class Safety(Protocol):
|
class Safety(Protocol):
|
||||||
shield_store: ShieldStore
|
shield_store: ShieldStore
|
||||||
|
|
||||||
|
@ -49,6 +50,3 @@ class Safety(Protocol):
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
) -> RunShieldResponse: ...
|
) -> RunShieldResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/safety/register_shield")
|
|
||||||
async def register_shield(self, shield: ShieldDef) -> None: ...
|
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Protocol
|
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
@ -26,21 +26,26 @@ class ShieldDef(BaseModel):
|
||||||
type: str = Field(
|
type: str = Field(
|
||||||
description="The type of shield this is; the value is one of the ShieldType enum"
|
description="The type of shield this is; the value is one of the ShieldType enum"
|
||||||
)
|
)
|
||||||
provider_id: Optional[str] = Field(
|
|
||||||
default=None, description="The provider instance which serves this shield"
|
|
||||||
)
|
|
||||||
params: Dict[str, Any] = Field(
|
params: Dict[str, Any] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Any additional parameters needed for this shield",
|
description="Any additional parameters needed for this shield",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ShieldDefWithProvider(ShieldDef):
|
||||||
|
provider_id: str = Field(
|
||||||
|
description="The provider ID for this shield type",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class Shields(Protocol):
|
class Shields(Protocol):
|
||||||
@webmethod(route="/shields/list", method="GET")
|
@webmethod(route="/shields/list", method="GET")
|
||||||
async def list_shields(self) -> List[ShieldDef]: ...
|
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
|
||||||
|
|
||||||
@webmethod(route="/shields/get", method="GET")
|
@webmethod(route="/shields/get", method="GET")
|
||||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: ...
|
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ...
|
||||||
|
|
||||||
@webmethod(route="/shields/register", method="POST")
|
@webmethod(route="/shields/register", method="POST")
|
||||||
async def register_shield(self, shield: ShieldDef) -> None: ...
|
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, Literal, Optional, Protocol, Union
|
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
@ -123,6 +123,7 @@ Event = Annotated[
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class Telemetry(Protocol):
|
class Telemetry(Protocol):
|
||||||
@webmethod(route="/telemetry/log_event")
|
@webmethod(route="/telemetry/log_event")
|
||||||
async def log_event(self, event: Event) -> None: ...
|
async def log_event(self, event: Event) -> None: ...
|
||||||
|
|
|
@ -7,17 +7,7 @@ import textwrap
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_models.sku_list import (
|
|
||||||
llama3_1_family,
|
|
||||||
llama3_2_family,
|
|
||||||
llama3_family,
|
|
||||||
resolve_model,
|
|
||||||
safety_models,
|
|
||||||
)
|
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from prompt_toolkit import prompt
|
|
||||||
from prompt_toolkit.validation import Validator
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import (
|
||||||
|
@ -33,11 +23,6 @@ from llama_stack.apis.shields import * # noqa: F403
|
||||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
ALLOWED_MODELS = (
|
|
||||||
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def configure_single_provider(
|
def configure_single_provider(
|
||||||
registry: Dict[str, ProviderSpec], provider: Provider
|
registry: Dict[str, ProviderSpec], provider: Provider
|
||||||
) -> Provider:
|
) -> Provider:
|
||||||
|
@ -133,137 +118,10 @@ def configure_api_providers(
|
||||||
|
|
||||||
config.providers[api_str] = updated_providers
|
config.providers[api_str] = updated_providers
|
||||||
|
|
||||||
if is_nux:
|
|
||||||
print(
|
|
||||||
textwrap.dedent(
|
|
||||||
"""
|
|
||||||
=========================================================================================
|
|
||||||
Now let's configure the `objects` you will be serving via the stack. These are:
|
|
||||||
|
|
||||||
- Models: the Llama model SKUs you expect to inference (e.g., Llama3.2-1B-Instruct)
|
|
||||||
- Shields: the safety models you expect to use for safety (e.g., Llama-Guard-3-1B)
|
|
||||||
- Memory Banks: the memory banks you expect to use for memory (e.g., Vector stores)
|
|
||||||
|
|
||||||
This wizard will guide you through setting up one of each of these objects. You can
|
|
||||||
always add more later by editing the run.yaml file.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
object_types = {
|
|
||||||
"models": (ModelDef, configure_models, "inference"),
|
|
||||||
"shields": (ShieldDef, configure_shields, "safety"),
|
|
||||||
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
|
|
||||||
}
|
|
||||||
safety_providers = config.providers.get("safety", [])
|
|
||||||
|
|
||||||
for otype, (odef, config_method, api_str) in object_types.items():
|
|
||||||
existing_objects = getattr(config, otype)
|
|
||||||
|
|
||||||
if existing_objects:
|
|
||||||
cprint(
|
|
||||||
f"{len(existing_objects)} {otype} exist. Skipping...",
|
|
||||||
"blue",
|
|
||||||
attrs=["bold"],
|
|
||||||
)
|
|
||||||
updated_objects = existing_objects
|
|
||||||
else:
|
|
||||||
providers = config.providers.get(api_str, [])
|
|
||||||
if not providers:
|
|
||||||
updated_objects = []
|
|
||||||
else:
|
|
||||||
# we are newly configuring this API
|
|
||||||
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
|
|
||||||
updated_objects = config_method(
|
|
||||||
config.providers[api_str], safety_providers
|
|
||||||
)
|
|
||||||
|
|
||||||
setattr(config, otype, updated_objects)
|
|
||||||
print("")
|
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
|
def upgrade_from_routing_table(
|
||||||
if not safety_providers:
|
|
||||||
return None
|
|
||||||
|
|
||||||
provider = safety_providers[0]
|
|
||||||
assert provider.provider_type == "meta-reference"
|
|
||||||
|
|
||||||
cfg = provider.config["llama_guard_shield"]
|
|
||||||
if not cfg:
|
|
||||||
return None
|
|
||||||
return cfg["model"]
|
|
||||||
|
|
||||||
|
|
||||||
def configure_models(
|
|
||||||
providers: List[Provider], safety_providers: List[Provider]
|
|
||||||
) -> List[ModelDef]:
|
|
||||||
model = prompt(
|
|
||||||
"> Please enter the model you want to serve: ",
|
|
||||||
default="Llama3.2-1B-Instruct",
|
|
||||||
validator=Validator.from_callable(
|
|
||||||
lambda x: resolve_model(x) is not None,
|
|
||||||
error_message="Model must be: {}".format(
|
|
||||||
[x.descriptor() for x in ALLOWED_MODELS]
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
model = ModelDef(
|
|
||||||
identifier=model,
|
|
||||||
llama_model=model,
|
|
||||||
provider_id=providers[0].provider_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
ret = [model]
|
|
||||||
if llama_guard := get_llama_guard_model(safety_providers):
|
|
||||||
ret.append(
|
|
||||||
ModelDef(
|
|
||||||
identifier=llama_guard,
|
|
||||||
llama_model=llama_guard,
|
|
||||||
provider_id=providers[0].provider_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def configure_shields(
|
|
||||||
providers: List[Provider], safety_providers: List[Provider]
|
|
||||||
) -> List[ShieldDef]:
|
|
||||||
if get_llama_guard_model(safety_providers):
|
|
||||||
return [
|
|
||||||
ShieldDef(
|
|
||||||
identifier="llama_guard",
|
|
||||||
type="llama_guard",
|
|
||||||
provider_id=providers[0].provider_id,
|
|
||||||
params={},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def configure_memory_banks(
|
|
||||||
providers: List[Provider], safety_providers: List[Provider]
|
|
||||||
) -> List[MemoryBankDef]:
|
|
||||||
bank_name = prompt(
|
|
||||||
"> Please enter a name for your memory bank: ",
|
|
||||||
default="my-memory-bank",
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
VectorMemoryBankDef(
|
|
||||||
identifier=bank_name,
|
|
||||||
provider_id=providers[0].provider_id,
|
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
|
||||||
chunk_size_in_tokens=512,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade_from_routing_table_to_registry(
|
|
||||||
config_dict: Dict[str, Any],
|
config_dict: Dict[str, Any],
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
def get_providers(entries):
|
def get_providers(entries):
|
||||||
|
@ -281,57 +139,12 @@ def upgrade_from_routing_table_to_registry(
|
||||||
]
|
]
|
||||||
|
|
||||||
providers_by_api = {}
|
providers_by_api = {}
|
||||||
models = []
|
|
||||||
shields = []
|
|
||||||
memory_banks = []
|
|
||||||
|
|
||||||
routing_table = config_dict.get("routing_table", {})
|
routing_table = config_dict.get("routing_table", {})
|
||||||
for api_str, entries in routing_table.items():
|
for api_str, entries in routing_table.items():
|
||||||
providers = get_providers(entries)
|
providers = get_providers(entries)
|
||||||
providers_by_api[api_str] = providers
|
providers_by_api[api_str] = providers
|
||||||
|
|
||||||
if api_str == "inference":
|
|
||||||
for entry, provider in zip(entries, providers):
|
|
||||||
key = entry["routing_key"]
|
|
||||||
keys = key if isinstance(key, list) else [key]
|
|
||||||
for key in keys:
|
|
||||||
models.append(
|
|
||||||
ModelDef(
|
|
||||||
identifier=key,
|
|
||||||
provider_id=provider.provider_id,
|
|
||||||
llama_model=key,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif api_str == "safety":
|
|
||||||
for entry, provider in zip(entries, providers):
|
|
||||||
key = entry["routing_key"]
|
|
||||||
keys = key if isinstance(key, list) else [key]
|
|
||||||
for key in keys:
|
|
||||||
shields.append(
|
|
||||||
ShieldDef(
|
|
||||||
identifier=key,
|
|
||||||
type=ShieldType.llama_guard.value,
|
|
||||||
provider_id=provider.provider_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif api_str == "memory":
|
|
||||||
for entry, provider in zip(entries, providers):
|
|
||||||
key = entry["routing_key"]
|
|
||||||
keys = key if isinstance(key, list) else [key]
|
|
||||||
for key in keys:
|
|
||||||
# we currently only support Vector memory banks so this is OK
|
|
||||||
memory_banks.append(
|
|
||||||
VectorMemoryBankDef(
|
|
||||||
identifier=key,
|
|
||||||
provider_id=provider.provider_id,
|
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
|
||||||
chunk_size_in_tokens=512,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
config_dict["models"] = models
|
|
||||||
config_dict["shields"] = shields
|
|
||||||
config_dict["memory_banks"] = memory_banks
|
|
||||||
|
|
||||||
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
|
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
|
||||||
if provider_map:
|
if provider_map:
|
||||||
for api_str, provider in provider_map.items():
|
for api_str, provider in provider_map.items():
|
||||||
|
@ -361,9 +174,9 @@ def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfi
|
||||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||||
return StackRunConfig(**config_dict)
|
return StackRunConfig(**config_dict)
|
||||||
|
|
||||||
if "models" not in config_dict:
|
if "routing_table" in config_dict:
|
||||||
print("Upgrading config...")
|
print("Upgrading config...")
|
||||||
config_dict = upgrade_from_routing_table_to_registry(config_dict)
|
config_dict = upgrade_from_routing_table(config_dict)
|
||||||
|
|
||||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||||
config_dict["built_at"] = datetime.now().isoformat()
|
config_dict["built_at"] = datetime.now().isoformat()
|
||||||
|
|
|
@ -32,6 +32,12 @@ RoutableObject = Union[
|
||||||
MemoryBankDef,
|
MemoryBankDef,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
RoutableObjectWithProvider = Union[
|
||||||
|
ModelDefWithProvider,
|
||||||
|
ShieldDef,
|
||||||
|
MemoryBankDef,
|
||||||
|
]
|
||||||
|
|
||||||
RoutedProtocol = Union[
|
RoutedProtocol = Union[
|
||||||
Inference,
|
Inference,
|
||||||
Safety,
|
Safety,
|
||||||
|
@ -63,7 +69,6 @@ class RoutingTableProviderSpec(ProviderSpec):
|
||||||
docker_image: Optional[str] = None
|
docker_image: Optional[str] = None
|
||||||
|
|
||||||
router_api: Api
|
router_api: Api
|
||||||
registry: List[RoutableObject]
|
|
||||||
module: str
|
module: str
|
||||||
pip_packages: List[str] = Field(default_factory=list)
|
pip_packages: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@ -121,25 +126,6 @@ can be instantiated multiple times (with different configs) if necessary.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
models: List[ModelDef] = Field(
|
|
||||||
description="""
|
|
||||||
List of model definitions to serve. This list may get extended by
|
|
||||||
/models/register API calls at runtime.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
shields: List[ShieldDef] = Field(
|
|
||||||
description="""
|
|
||||||
List of shield definitions to serve. This list may get extended by
|
|
||||||
/shields/register API calls at runtime.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
memory_banks: List[MemoryBankDef] = Field(
|
|
||||||
description="""
|
|
||||||
List of memory bank definitions to serve. This list may get extended by
|
|
||||||
/memory_banks/register API calls at runtime.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BuildConfig(BaseModel):
|
class BuildConfig(BaseModel):
|
||||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||||
|
|
|
@ -4,10 +4,22 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import importlib
|
import importlib
|
||||||
|
import inspect
|
||||||
|
|
||||||
from typing import Any, Dict, List, Set
|
from typing import Any, Dict, List, Set
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.apis.agents import Agents
|
||||||
|
from llama_stack.apis.inference import Inference
|
||||||
|
from llama_stack.apis.inspect import Inspect
|
||||||
|
from llama_stack.apis.memory import Memory
|
||||||
|
from llama_stack.apis.memory_banks import MemoryBanks
|
||||||
|
from llama_stack.apis.models import Models
|
||||||
|
from llama_stack.apis.safety import Safety
|
||||||
|
from llama_stack.apis.shields import Shields
|
||||||
|
from llama_stack.apis.telemetry import Telemetry
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import (
|
||||||
builtin_automatically_routed_apis,
|
builtin_automatically_routed_apis,
|
||||||
get_provider_registry,
|
get_provider_registry,
|
||||||
|
@ -15,6 +27,28 @@ from llama_stack.distribution.distribution import (
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
|
||||||
|
def api_protocol_map() -> Dict[Api, Any]:
|
||||||
|
return {
|
||||||
|
Api.agents: Agents,
|
||||||
|
Api.inference: Inference,
|
||||||
|
Api.inspect: Inspect,
|
||||||
|
Api.memory: Memory,
|
||||||
|
Api.memory_banks: MemoryBanks,
|
||||||
|
Api.models: Models,
|
||||||
|
Api.safety: Safety,
|
||||||
|
Api.shields: Shields,
|
||||||
|
Api.telemetry: Telemetry,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def additional_protocols_map() -> Dict[Api, Any]:
|
||||||
|
return {
|
||||||
|
Api.inference: ModelsProtocolPrivate,
|
||||||
|
Api.memory: MemoryBanksProtocolPrivate,
|
||||||
|
Api.safety: ShieldsProtocolPrivate,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
|
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
|
||||||
class ProviderWithSpec(Provider):
|
class ProviderWithSpec(Provider):
|
||||||
spec: ProviderSpec
|
spec: ProviderSpec
|
||||||
|
@ -73,17 +107,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
||||||
|
|
||||||
available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
|
available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
|
||||||
|
|
||||||
inner_deps = []
|
|
||||||
registry = getattr(run_config, info.routing_table_api.value)
|
|
||||||
for entry in registry:
|
|
||||||
if entry.provider_id not in available_providers:
|
|
||||||
raise ValueError(
|
|
||||||
f"Provider `{entry.provider_id}` not found. Available providers: {list(available_providers.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
provider = available_providers[entry.provider_id]
|
|
||||||
inner_deps.extend(provider.spec.api_dependencies)
|
|
||||||
|
|
||||||
providers_with_specs[info.routing_table_api.value] = {
|
providers_with_specs[info.routing_table_api.value] = {
|
||||||
"__builtin__": ProviderWithSpec(
|
"__builtin__": ProviderWithSpec(
|
||||||
provider_id="__builtin__",
|
provider_id="__builtin__",
|
||||||
|
@ -92,13 +115,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
||||||
spec=RoutingTableProviderSpec(
|
spec=RoutingTableProviderSpec(
|
||||||
api=info.routing_table_api,
|
api=info.routing_table_api,
|
||||||
router_api=info.router_api,
|
router_api=info.router_api,
|
||||||
registry=registry,
|
|
||||||
module="llama_stack.distribution.routers",
|
module="llama_stack.distribution.routers",
|
||||||
api_dependencies=inner_deps,
|
api_dependencies=[],
|
||||||
deps__=(
|
deps__=([f"inner-{info.router_api.value}"]),
|
||||||
[x.value for x in inner_deps]
|
|
||||||
+ [f"inner-{info.router_api.value}"]
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -212,6 +231,9 @@ async def instantiate_provider(
|
||||||
deps: Dict[str, Any],
|
deps: Dict[str, Any],
|
||||||
inner_impls: Dict[str, Any],
|
inner_impls: Dict[str, Any],
|
||||||
):
|
):
|
||||||
|
protocols = api_protocol_map()
|
||||||
|
additional_protocols = additional_protocols_map()
|
||||||
|
|
||||||
provider_spec = provider.spec
|
provider_spec = provider.spec
|
||||||
module = importlib.import_module(provider_spec.module)
|
module = importlib.import_module(provider_spec.module)
|
||||||
|
|
||||||
|
@ -234,7 +256,7 @@ async def instantiate_provider(
|
||||||
method = "get_routing_table_impl"
|
method = "get_routing_table_impl"
|
||||||
|
|
||||||
config = None
|
config = None
|
||||||
args = [provider_spec.api, provider_spec.registry, inner_impls, deps]
|
args = [provider_spec.api, inner_impls, deps]
|
||||||
else:
|
else:
|
||||||
method = "get_provider_impl"
|
method = "get_provider_impl"
|
||||||
|
|
||||||
|
@ -247,4 +269,55 @@ async def instantiate_provider(
|
||||||
impl.__provider_id__ = provider.provider_id
|
impl.__provider_id__ = provider.provider_id
|
||||||
impl.__provider_spec__ = provider_spec
|
impl.__provider_spec__ = provider_spec
|
||||||
impl.__provider_config__ = config
|
impl.__provider_config__ = config
|
||||||
|
|
||||||
|
check_protocol_compliance(impl, protocols[provider_spec.api])
|
||||||
|
if (
|
||||||
|
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
||||||
|
and provider_spec.api in additional_protocols
|
||||||
|
):
|
||||||
|
additional_api = additional_protocols[provider_spec.api]
|
||||||
|
check_protocol_compliance(impl, additional_api)
|
||||||
|
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
||||||
|
def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
||||||
|
missing_methods = []
|
||||||
|
|
||||||
|
mro = type(obj).__mro__
|
||||||
|
for name, value in inspect.getmembers(protocol):
|
||||||
|
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
|
||||||
|
if not hasattr(obj, name):
|
||||||
|
missing_methods.append((name, "missing"))
|
||||||
|
elif not callable(getattr(obj, name)):
|
||||||
|
missing_methods.append((name, "not_callable"))
|
||||||
|
else:
|
||||||
|
# Check if the method signatures are compatible
|
||||||
|
obj_method = getattr(obj, name)
|
||||||
|
proto_sig = inspect.signature(value)
|
||||||
|
obj_sig = inspect.signature(obj_method)
|
||||||
|
|
||||||
|
proto_params = set(proto_sig.parameters)
|
||||||
|
proto_params.discard("self")
|
||||||
|
obj_params = set(obj_sig.parameters)
|
||||||
|
obj_params.discard("self")
|
||||||
|
if not (proto_params <= obj_params):
|
||||||
|
print(
|
||||||
|
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
|
||||||
|
)
|
||||||
|
missing_methods.append((name, "signature_mismatch"))
|
||||||
|
else:
|
||||||
|
# Check if the method is actually implemented in the class
|
||||||
|
method_owner = next(
|
||||||
|
(cls for cls in mro if name in cls.__dict__), None
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
method_owner is None
|
||||||
|
or method_owner.__name__ == protocol.__name__
|
||||||
|
):
|
||||||
|
missing_methods.append((name, "not_actually_implemented"))
|
||||||
|
|
||||||
|
if missing_methods:
|
||||||
|
raise ValueError(
|
||||||
|
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
|
||||||
|
)
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, List
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from .routing_tables import (
|
from .routing_tables import (
|
||||||
|
@ -16,7 +16,6 @@ from .routing_tables import (
|
||||||
|
|
||||||
async def get_routing_table_impl(
|
async def get_routing_table_impl(
|
||||||
api: Api,
|
api: Api,
|
||||||
registry: List[RoutableObject],
|
|
||||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||||
_deps,
|
_deps,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
@ -28,7 +27,7 @@ async def get_routing_table_impl(
|
||||||
if api.value not in api_to_tables:
|
if api.value not in api_to_tables:
|
||||||
raise ValueError(f"API {api.value} not found in router map")
|
raise ValueError(f"API {api.value} not found in router map")
|
||||||
|
|
||||||
impl = api_to_tables[api.value](registry, impls_by_provider_id)
|
impl = api_to_tables[api.value](impls_by_provider_id)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
@ -29,115 +29,145 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
await p.register_memory_bank(obj)
|
await p.register_memory_bank(obj)
|
||||||
|
|
||||||
|
|
||||||
|
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
||||||
|
|
||||||
|
|
||||||
# TODO: this routing table maintains state in memory purely. We need to
|
# TODO: this routing table maintains state in memory purely. We need to
|
||||||
# add persistence to it when we add dynamic registration of objects.
|
# add persistence to it when we add dynamic registration of objects.
|
||||||
class CommonRoutingTableImpl(RoutingTable):
|
class CommonRoutingTableImpl(RoutingTable):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
registry: List[RoutableObject],
|
|
||||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||||
) -> None:
|
) -> None:
|
||||||
for obj in registry:
|
|
||||||
if obj.provider_id not in impls_by_provider_id:
|
|
||||||
print(f"{impls_by_provider_id=}")
|
|
||||||
raise ValueError(
|
|
||||||
f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.impls_by_provider_id = impls_by_provider_id
|
self.impls_by_provider_id = impls_by_provider_id
|
||||||
self.registry = registry
|
|
||||||
|
|
||||||
for p in self.impls_by_provider_id.values():
|
async def initialize(self) -> None:
|
||||||
|
self.registry: Registry = {}
|
||||||
|
|
||||||
|
def add_objects(objs: List[RoutableObjectWithProvider]) -> None:
|
||||||
|
for obj in objs:
|
||||||
|
if obj.identifier not in self.registry:
|
||||||
|
self.registry[obj.identifier] = []
|
||||||
|
|
||||||
|
self.registry[obj.identifier].append(obj)
|
||||||
|
|
||||||
|
for pid, p in self.impls_by_provider_id.items():
|
||||||
api = get_impl_api(p)
|
api = get_impl_api(p)
|
||||||
if api == Api.inference:
|
if api == Api.inference:
|
||||||
p.model_store = self
|
p.model_store = self
|
||||||
|
models = await p.list_models()
|
||||||
|
add_objects(
|
||||||
|
[ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models]
|
||||||
|
)
|
||||||
|
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
p.shield_store = self
|
p.shield_store = self
|
||||||
|
shields = await p.list_shields()
|
||||||
|
add_objects(
|
||||||
|
[
|
||||||
|
ShieldDefWithProvider(**s.dict(), provider_id=pid)
|
||||||
|
for s in shields
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
elif api == Api.memory:
|
elif api == Api.memory:
|
||||||
p.memory_bank_store = self
|
p.memory_bank_store = self
|
||||||
|
memory_banks = await p.list_memory_banks()
|
||||||
|
|
||||||
self.routing_key_to_object = {}
|
# do in-memory updates due to pesky Annotated unions
|
||||||
for obj in self.registry:
|
for m in memory_banks:
|
||||||
self.routing_key_to_object[obj.identifier] = obj
|
m.provider_id = pid
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
add_objects(memory_banks)
|
||||||
for obj in self.registry:
|
|
||||||
p = self.impls_by_provider_id[obj.provider_id]
|
|
||||||
await register_object_with_provider(obj, p)
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
for p in self.impls_by_provider_id.values():
|
for p in self.impls_by_provider_id.values():
|
||||||
await p.shutdown()
|
await p.shutdown()
|
||||||
|
|
||||||
def get_provider_impl(self, routing_key: str) -> Any:
|
def get_provider_impl(
|
||||||
if routing_key not in self.routing_key_to_object:
|
self, routing_key: str, provider_id: Optional[str] = None
|
||||||
|
) -> Any:
|
||||||
|
if routing_key not in self.registry:
|
||||||
raise ValueError(f"`{routing_key}` not registered")
|
raise ValueError(f"`{routing_key}` not registered")
|
||||||
|
|
||||||
obj = self.routing_key_to_object[routing_key]
|
objs = self.registry[routing_key]
|
||||||
if obj.provider_id not in self.impls_by_provider_id:
|
for obj in objs:
|
||||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
if not provider_id or provider_id == obj.provider_id:
|
||||||
|
|
||||||
return self.impls_by_provider_id[obj.provider_id]
|
return self.impls_by_provider_id[obj.provider_id]
|
||||||
|
|
||||||
def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
|
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||||
for obj in self.registry:
|
|
||||||
if obj.identifier == identifier:
|
def get_object_by_identifier(
|
||||||
return obj
|
self, identifier: str
|
||||||
|
) -> Optional[RoutableObjectWithProvider]:
|
||||||
|
objs = self.registry.get(identifier, [])
|
||||||
|
if not objs:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def register_object(self, obj: RoutableObject):
|
# kind of ill-defined behavior here, but we'll just return the first one
|
||||||
if obj.identifier in self.routing_key_to_object:
|
return objs[0]
|
||||||
print(f"`{obj.identifier}` is already registered")
|
|
||||||
|
async def register_object(self, obj: RoutableObjectWithProvider):
|
||||||
|
entries = self.registry.get(obj.identifier, [])
|
||||||
|
for entry in entries:
|
||||||
|
if entry.provider_id == obj.provider_id:
|
||||||
|
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not obj.provider_id:
|
|
||||||
provider_ids = list(self.impls_by_provider_id.keys())
|
|
||||||
if not provider_ids:
|
|
||||||
raise ValueError("No providers found")
|
|
||||||
|
|
||||||
print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
|
|
||||||
obj.provider_id = provider_ids[0]
|
|
||||||
else:
|
|
||||||
if obj.provider_id not in self.impls_by_provider_id:
|
if obj.provider_id not in self.impls_by_provider_id:
|
||||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||||
|
|
||||||
p = self.impls_by_provider_id[obj.provider_id]
|
p = self.impls_by_provider_id[obj.provider_id]
|
||||||
await register_object_with_provider(obj, p)
|
await register_object_with_provider(obj, p)
|
||||||
|
|
||||||
self.routing_key_to_object[obj.identifier] = obj
|
if obj.identifier not in self.registry:
|
||||||
self.registry.append(obj)
|
self.registry[obj.identifier] = []
|
||||||
|
self.registry[obj.identifier].append(obj)
|
||||||
|
|
||||||
# TODO: persist this to a store
|
# TODO: persist this to a store
|
||||||
|
|
||||||
|
|
||||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
async def list_models(self) -> List[ModelDef]:
|
async def list_models(self) -> List[ModelDefWithProvider]:
|
||||||
return self.registry
|
objects = []
|
||||||
|
for objs in self.registry.values():
|
||||||
|
objects.extend(objs)
|
||||||
|
return objects
|
||||||
|
|
||||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
|
||||||
return self.get_object_by_identifier(identifier)
|
return self.get_object_by_identifier(identifier)
|
||||||
|
|
||||||
async def register_model(self, model: ModelDef) -> None:
|
async def register_model(self, model: ModelDefWithProvider) -> None:
|
||||||
await self.register_object(model)
|
await self.register_object(model)
|
||||||
|
|
||||||
|
|
||||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
async def list_shields(self) -> List[ShieldDef]:
|
async def list_shields(self) -> List[ShieldDef]:
|
||||||
return self.registry
|
objects = []
|
||||||
|
for objs in self.registry.values():
|
||||||
|
objects.extend(objs)
|
||||||
|
return objects
|
||||||
|
|
||||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
|
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
||||||
return self.get_object_by_identifier(shield_type)
|
return self.get_object_by_identifier(shield_type)
|
||||||
|
|
||||||
async def register_shield(self, shield: ShieldDef) -> None:
|
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
|
||||||
await self.register_object(shield)
|
await self.register_object(shield)
|
||||||
|
|
||||||
|
|
||||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
||||||
return self.registry
|
objects = []
|
||||||
|
for objs in self.registry.values():
|
||||||
|
objects.extend(objs)
|
||||||
|
return objects
|
||||||
|
|
||||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
|
async def get_memory_bank(
|
||||||
|
self, identifier: str
|
||||||
|
) -> Optional[MemoryBankDefWithProvider]:
|
||||||
return self.get_object_by_identifier(identifier)
|
return self.get_object_by_identifier(identifier)
|
||||||
|
|
||||||
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
|
async def register_memory_bank(
|
||||||
await self.register_object(bank)
|
self, memory_bank: MemoryBankDefWithProvider
|
||||||
|
) -> None:
|
||||||
|
await self.register_object(memory_bank)
|
||||||
|
|
|
@ -9,15 +9,7 @@ from typing import Dict, List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.distribution.resolver import api_protocol_map
|
||||||
from llama_stack.apis.inference import Inference
|
|
||||||
from llama_stack.apis.inspect import Inspect
|
|
||||||
from llama_stack.apis.memory import Memory
|
|
||||||
from llama_stack.apis.memory_banks import MemoryBanks
|
|
||||||
from llama_stack.apis.models import Models
|
|
||||||
from llama_stack.apis.safety import Safety
|
|
||||||
from llama_stack.apis.shields import Shields
|
|
||||||
from llama_stack.apis.telemetry import Telemetry
|
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
@ -31,18 +23,7 @@ class ApiEndpoint(BaseModel):
|
||||||
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||||
apis = {}
|
apis = {}
|
||||||
|
|
||||||
protocols = {
|
protocols = api_protocol_map()
|
||||||
Api.inference: Inference,
|
|
||||||
Api.safety: Safety,
|
|
||||||
Api.agents: Agents,
|
|
||||||
Api.memory: Memory,
|
|
||||||
Api.telemetry: Telemetry,
|
|
||||||
Api.models: Models,
|
|
||||||
Api.shields: Shields,
|
|
||||||
Api.memory_banks: MemoryBanks,
|
|
||||||
Api.inspect: Inspect,
|
|
||||||
}
|
|
||||||
|
|
||||||
for api, protocol in protocols.items():
|
for api, protocol in protocols.items():
|
||||||
endpoints = []
|
endpoints = []
|
||||||
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||||
|
|
|
@ -121,3 +121,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
"stream": request.stream,
|
"stream": request.stream,
|
||||||
**options,
|
**options,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def embeddings(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
contents: List[InterleavedTextMedia],
|
||||||
|
) -> EmbeddingsResponse:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
|
@ -15,6 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
from ollama import AsyncClient
|
from ollama import AsyncClient
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
from llama_stack.apis.models import * # noqa: F403
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
get_sampling_options,
|
get_sampling_options,
|
||||||
OpenAICompatCompletionChoice,
|
OpenAICompatCompletionChoice,
|
||||||
|
@ -35,7 +36,7 @@ OLLAMA_SUPPORTED_MODELS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class OllamaInferenceAdapter(Inference):
|
class OllamaInferenceAdapter(Inference, Models):
|
||||||
def __init__(self, url: str) -> None:
|
def __init__(self, url: str) -> None:
|
||||||
self.url = url
|
self.url = url
|
||||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||||
|
|
|
@ -6,14 +6,18 @@
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator, List, Optional
|
||||||
|
|
||||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import all_registered_models
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
from llama_stack.apis.models import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
get_sampling_options,
|
get_sampling_options,
|
||||||
OpenAICompatCompletionChoice,
|
OpenAICompatCompletionChoice,
|
||||||
|
@ -30,26 +34,47 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class _HfAdapter(Inference):
|
class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
client: AsyncInferenceClient
|
client: AsyncInferenceClient
|
||||||
max_tokens: int
|
max_tokens: int
|
||||||
model_id: str
|
model_id: str
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||||
|
self.huggingface_repo_to_llama_model_id = {
|
||||||
|
model.huggingface_repo: model.descriptor()
|
||||||
|
for model in all_registered_models()
|
||||||
|
if model.huggingface_repo
|
||||||
|
}
|
||||||
|
|
||||||
async def register_model(self, model: ModelDef) -> None:
|
async def register_model(self, model: ModelDef) -> None:
|
||||||
resolved_model = resolve_model(model.identifier)
|
raise ValueError("Model registration is not supported for HuggingFace models")
|
||||||
if resolved_model is None:
|
|
||||||
raise ValueError(f"Unknown model: {model.identifier}")
|
|
||||||
|
|
||||||
if not resolved_model.huggingface_repo:
|
async def list_models(self) -> List[ModelDef]:
|
||||||
raise ValueError(
|
repo = self.model_id
|
||||||
f"Model {model.identifier} does not have a HuggingFace repo"
|
identifier = self.huggingface_repo_to_llama_model_id[repo]
|
||||||
|
return [
|
||||||
|
ModelDef(
|
||||||
|
identifier=identifier,
|
||||||
|
llama_model=identifier,
|
||||||
|
metadata={
|
||||||
|
"huggingface_repo": repo,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
]
|
||||||
|
|
||||||
if self.model_id != resolved_model.huggingface_repo:
|
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||||
raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
|
model = self.huggingface_repo_to_llama_model_id.get(self.model_id)
|
||||||
|
if model != identifier:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return ModelDef(
|
||||||
|
identifier=model,
|
||||||
|
llama_model=model,
|
||||||
|
metadata={
|
||||||
|
"huggingface_repo": self.model_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -145,6 +170,13 @@ class _HfAdapter(Inference):
|
||||||
**options,
|
**options,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def embeddings(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
contents: List[InterleavedTextMedia],
|
||||||
|
) -> EmbeddingsResponse:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class TGIAdapter(_HfAdapter):
|
class TGIAdapter(_HfAdapter):
|
||||||
async def initialize(self, config: TGIImplConfig) -> None:
|
async def initialize(self, config: TGIImplConfig) -> None:
|
||||||
|
|
|
@ -134,3 +134,10 @@ class TogetherInferenceAdapter(
|
||||||
"stream": request.stream,
|
"stream": request.stream,
|
||||||
**get_sampling_options(request),
|
**get_sampling_options(request),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def embeddings(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
contents: List[InterleavedTextMedia],
|
||||||
|
) -> EmbeddingsResponse:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
|
@ -10,6 +10,11 @@ from typing import Any, List, Optional, Protocol
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.apis.memory_banks import MemoryBankDef
|
||||||
|
|
||||||
|
from llama_stack.apis.models import ModelDef
|
||||||
|
from llama_stack.apis.shields import ShieldDef
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Api(Enum):
|
class Api(Enum):
|
||||||
|
@ -28,6 +33,30 @@ class Api(Enum):
|
||||||
inspect = "inspect"
|
inspect = "inspect"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelsProtocolPrivate(Protocol):
|
||||||
|
async def list_models(self) -> List[ModelDef]: ...
|
||||||
|
|
||||||
|
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
|
||||||
|
|
||||||
|
async def register_model(self, model: ModelDef) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class ShieldsProtocolPrivate(Protocol):
|
||||||
|
async def list_shields(self) -> List[ShieldDef]: ...
|
||||||
|
|
||||||
|
async def get_shield(self, identifier: str) -> Optional[ShieldDef]: ...
|
||||||
|
|
||||||
|
async def register_shield(self, shield: ShieldDef) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryBanksProtocolPrivate(Protocol):
|
||||||
|
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
|
||||||
|
|
||||||
|
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
|
||||||
|
|
||||||
|
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ProviderSpec(BaseModel):
|
class ProviderSpec(BaseModel):
|
||||||
api: Api
|
api: Api
|
||||||
|
|
|
@ -12,6 +12,7 @@ from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
chat_completion_request_to_messages,
|
chat_completion_request_to_messages,
|
||||||
)
|
)
|
||||||
|
@ -24,7 +25,7 @@ from .model_parallel import LlamaModelParallelGenerator
|
||||||
SEMAPHORE = asyncio.Semaphore(1)
|
SEMAPHORE = asyncio.Semaphore(1)
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceInferenceImpl(Inference):
|
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
def __init__(self, config: MetaReferenceImplConfig) -> None:
|
def __init__(self, config: MetaReferenceImplConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
model = resolve_model(config.model)
|
model = resolve_model(config.model)
|
||||||
|
@ -39,14 +40,38 @@ class MetaReferenceInferenceImpl(Inference):
|
||||||
self.generator.start()
|
self.generator.start()
|
||||||
|
|
||||||
async def register_model(self, model: ModelDef) -> None:
|
async def register_model(self, model: ModelDef) -> None:
|
||||||
if model.identifier != self.model.descriptor():
|
raise ValueError("Dynamic model registration is not supported")
|
||||||
raise RuntimeError(
|
|
||||||
f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
|
async def list_models(self) -> List[ModelDef]:
|
||||||
|
return [
|
||||||
|
ModelDef(
|
||||||
|
identifier=self.model.descriptor(),
|
||||||
|
llama_model=self.model.descriptor(),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||||
|
if self.model.descriptor() != identifier:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return ModelDef(
|
||||||
|
identifier=self.model.descriptor(),
|
||||||
|
llama_model=self.model.descriptor(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
self.generator.stop()
|
self.generator.stop()
|
||||||
|
|
||||||
|
def completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
content: InterleavedTextMedia,
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def chat_completion(
|
def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@ -255,3 +280,10 @@ class MetaReferenceInferenceImpl(Inference):
|
||||||
stop_reason=stop_reason,
|
stop_reason=stop_reason,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def embeddings(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
contents: List[InterleavedTextMedia],
|
||||||
|
) -> EmbeddingsResponse:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
|
@ -15,6 +15,8 @@ from numpy.typing import NDArray
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||||
|
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
ALL_MINILM_L6_V2_DIMENSION,
|
ALL_MINILM_L6_V2_DIMENSION,
|
||||||
BankWithIndex,
|
BankWithIndex,
|
||||||
|
@ -61,7 +63,7 @@ class FaissIndex(EmbeddingIndex):
|
||||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
class FaissMemoryImpl(Memory):
|
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||||
def __init__(self, config: FaissImplConfig) -> None:
|
def __init__(self, config: FaissImplConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
@ -83,6 +85,16 @@ class FaissMemoryImpl(Memory):
|
||||||
)
|
)
|
||||||
self.cache[memory_bank.identifier] = index
|
self.cache[memory_bank.identifier] = index
|
||||||
|
|
||||||
|
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
|
||||||
|
banks = await self.list_memory_banks()
|
||||||
|
for bank in banks:
|
||||||
|
if bank.identifier == identifier:
|
||||||
|
return bank
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||||
|
return [i.bank for i in self.cache.values()]
|
||||||
|
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
|
|
|
@ -28,7 +28,7 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||||
# ```bash
|
# ```bash
|
||||||
# PROVIDER_ID=<your_provider> \
|
# PROVIDER_ID=<your_provider> \
|
||||||
# PROVIDER_CONFIG=provider_config.yaml \
|
# PROVIDER_CONFIG=provider_config.yaml \
|
||||||
# pytest -s llama_stack/providers/tests/memory/test_inference.py \
|
# pytest -s llama_stack/providers/tests/inference/test_inference.py \
|
||||||
# --tb=short --disable-warnings
|
# --tb=short --disable-warnings
|
||||||
# ```
|
# ```
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str):
|
||||||
scope="session",
|
scope="session",
|
||||||
params=[
|
params=[
|
||||||
{"model": Llama_8B},
|
{"model": Llama_8B},
|
||||||
{"model": Llama_3B},
|
# {"model": Llama_3B},
|
||||||
],
|
],
|
||||||
ids=lambda d: d["model"],
|
ids=lambda d: d["model"],
|
||||||
)
|
)
|
||||||
|
@ -64,16 +64,11 @@ async def inference_settings(request):
|
||||||
model = request.param["model"]
|
model = request.param["model"]
|
||||||
impls = await resolve_impls_for_test(
|
impls = await resolve_impls_for_test(
|
||||||
Api.inference,
|
Api.inference,
|
||||||
models=[
|
|
||||||
ModelDef(
|
|
||||||
identifier=model,
|
|
||||||
llama_model=model,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"impl": impls[Api.inference],
|
"impl": impls[Api.inference],
|
||||||
|
"models_impl": impls[Api.models],
|
||||||
"common_params": {
|
"common_params": {
|
||||||
"model": model,
|
"model": model,
|
||||||
"tool_choice": ToolChoice.auto,
|
"tool_choice": ToolChoice.auto,
|
||||||
|
@ -108,6 +103,25 @@ def sample_tool_definition():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_list(inference_settings):
|
||||||
|
params = inference_settings["common_params"]
|
||||||
|
models_impl = inference_settings["models_impl"]
|
||||||
|
response = await models_impl.list_models()
|
||||||
|
assert isinstance(response, list)
|
||||||
|
assert len(response) >= 1
|
||||||
|
assert all(isinstance(model, ModelDefWithProvider) for model in response)
|
||||||
|
|
||||||
|
model_def = None
|
||||||
|
for model in response:
|
||||||
|
if model.identifier == params["model"]:
|
||||||
|
model_def = model
|
||||||
|
break
|
||||||
|
|
||||||
|
assert model_def is not None
|
||||||
|
assert model_def.identifier == params["model"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
|
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
|
||||||
inference_impl = inference_settings["impl"]
|
inference_impl = inference_settings["impl"]
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
@ -30,12 +31,14 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def memory_impl():
|
async def memory_settings():
|
||||||
impls = await resolve_impls_for_test(
|
impls = await resolve_impls_for_test(
|
||||||
Api.memory,
|
Api.memory,
|
||||||
memory_banks=[],
|
|
||||||
)
|
)
|
||||||
return impls[Api.memory]
|
return {
|
||||||
|
"memory_impl": impls[Api.memory],
|
||||||
|
"memory_banks_impl": impls[Api.memory_banks],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -64,23 +67,35 @@ def sample_documents():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def register_memory_bank(memory_impl: Memory):
|
async def register_memory_bank(banks_impl: MemoryBanks):
|
||||||
bank = VectorMemoryBankDef(
|
bank = VectorMemoryBankDef(
|
||||||
identifier="test_bank",
|
identifier="test_bank",
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
chunk_size_in_tokens=512,
|
chunk_size_in_tokens=512,
|
||||||
overlap_size_in_tokens=64,
|
overlap_size_in_tokens=64,
|
||||||
|
provider_id=os.environ["PROVIDER_ID"],
|
||||||
)
|
)
|
||||||
|
|
||||||
await memory_impl.register_memory_bank(bank)
|
await banks_impl.register_memory_bank(bank)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_documents(memory_impl, sample_documents):
|
async def test_banks_list(memory_settings):
|
||||||
|
banks_impl = memory_settings["memory_banks_impl"]
|
||||||
|
response = await banks_impl.list_memory_banks()
|
||||||
|
assert isinstance(response, list)
|
||||||
|
assert len(response) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_documents(memory_settings, sample_documents):
|
||||||
|
memory_impl = memory_settings["memory_impl"]
|
||||||
|
banks_impl = memory_settings["memory_banks_impl"]
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||||
|
|
||||||
await register_memory_bank(memory_impl)
|
await register_memory_bank(banks_impl)
|
||||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||||
|
|
||||||
query1 = "programming language"
|
query1 = "programming language"
|
||||||
|
|
|
@ -18,9 +18,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
|
||||||
|
|
||||||
async def resolve_impls_for_test(
|
async def resolve_impls_for_test(
|
||||||
api: Api,
|
api: Api,
|
||||||
models: List[ModelDef] = None,
|
|
||||||
memory_banks: List[MemoryBankDef] = None,
|
|
||||||
shields: List[ShieldDef] = None,
|
|
||||||
):
|
):
|
||||||
if "PROVIDER_CONFIG" not in os.environ:
|
if "PROVIDER_CONFIG" not in os.environ:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -47,45 +44,11 @@ async def resolve_impls_for_test(
|
||||||
provider_id = provider["provider_id"]
|
provider_id = provider["provider_id"]
|
||||||
print(f"No provider ID specified, picking first `{provider_id}`")
|
print(f"No provider ID specified, picking first `{provider_id}`")
|
||||||
|
|
||||||
models = models or []
|
|
||||||
shields = shields or []
|
|
||||||
memory_banks = memory_banks or []
|
|
||||||
|
|
||||||
models = [
|
|
||||||
ModelDef(
|
|
||||||
**{
|
|
||||||
**m.dict(),
|
|
||||||
"provider_id": provider_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for m in models
|
|
||||||
]
|
|
||||||
shields = [
|
|
||||||
ShieldDef(
|
|
||||||
**{
|
|
||||||
**s.dict(),
|
|
||||||
"provider_id": provider_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for s in shields
|
|
||||||
]
|
|
||||||
memory_banks = [
|
|
||||||
MemoryBankDef(
|
|
||||||
**{
|
|
||||||
**m.dict(),
|
|
||||||
"provider_id": provider_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for m in memory_banks
|
|
||||||
]
|
|
||||||
run_config = dict(
|
run_config = dict(
|
||||||
built_at=datetime.now(),
|
built_at=datetime.now(),
|
||||||
image_name="test-fixture",
|
image_name="test-fixture",
|
||||||
apis=[api],
|
apis=[api],
|
||||||
providers={api.value: [Provider(**provider)]},
|
providers={api.value: [Provider(**provider)]},
|
||||||
models=models,
|
|
||||||
memory_banks=memory_banks,
|
|
||||||
shields=shields,
|
|
||||||
)
|
)
|
||||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||||
impls = await resolve_impls_with_routing(run_config)
|
impls = await resolve_impls_with_routing(run_config)
|
||||||
|
|
|
@ -4,14 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from llama_stack.apis.models import * # noqa: F403
|
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||||
|
|
||||||
|
|
||||||
class ModelRegistryHelper:
|
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
|
|
||||||
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
||||||
self.stack_to_provider_models_map = stack_to_provider_models_map
|
self.stack_to_provider_models_map = stack_to_provider_models_map
|
||||||
|
@ -33,3 +33,15 @@ class ModelRegistryHelper:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
|
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def list_models(self) -> List[ModelDef]:
|
||||||
|
models = []
|
||||||
|
for llama_model, provider_model in self.stack_to_provider_models_map.items():
|
||||||
|
models.append(ModelDef(identifier=llama_model, llama_model=llama_model))
|
||||||
|
return models
|
||||||
|
|
||||||
|
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||||
|
if identifier not in self.stack_to_provider_models_map:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return ModelDef(identifier=identifier, llama_model=identifier)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue