forked from phoenix-oss/llama-stack-mirror
Remove "routing_table" and "routing_key" concepts for the user (#201)
This PR makes several core changes to the developer experience surrounding Llama Stack. Background: PR #92 introduced the notion of "routing" to the Llama Stack. It introduces three object types: (1) models, (2) shields and (3) memory banks. Each of these objects can be associated with a distinct provider. So you can get model A to be inferenced locally while model B, C can be inference remotely (e.g.) However, this had a few drawbacks: you could not address the provider instances -- i.e., if you configured "meta-reference" with a given model, you could not assign an identifier to this instance which you could re-use later. the above meant that you could not register a "routing_key" (e.g. model) dynamically and say "please use this existing provider I have already configured" for a new model. the terms "routing_table" and "routing_key" were exposed directly to the user. in my view, this is way too much overhead for a new user (which almost everyone is.) people come to the stack wanting to do ML and encounter a completely unexpected term. What this PR does: This PR structures the run config with only a single prominent key: - providers Providers are instances of configured provider types. Here's an example which shows two instances of the remote::tgi provider which are serving two different models. providers: inference: - provider_id: foo provider_type: remote::tgi config: { ... } - provider_id: bar provider_type: remote::tgi config: { ... } Secondly, the PR adds dynamic registration of { models | shields | memory_banks } to the API surface. The distribution still acts like a "routing table" (as previously) except that it asks the backing providers for a listing of these objects. For example it asks a TGI or Ollama inference adapter what models it is serving. Only the models that are being actually served can be requested by the user for inference. Otherwise, the Stack server will throw an error. When dynamically registering these objects, you can use the provider IDs shown above. Info about providers can be obtained using the Api.inspect set of endpoints (/providers, /routes, etc.) The above examples shows the correspondence between inference providers and models registry items. Things work similarly for the safety <=> shields and memory <=> memory_banks pairs. Registry: This PR also makes it so that Providers need to implement additional methods for registering and listing objects. For example, each Inference provider is now expected to implement the ModelsProtocolPrivate protocol (naming is not great!) which consists of two methods register_model list_models The goal is to inform the provider that a certain model needs to be supported so the provider can make any relevant backend changes if needed (or throw an error if the model cannot be supported.) There are many other cleanups included some of which are detailed in a follow-up comment.
This commit is contained in:
parent
8c3010553f
commit
6bb57e72a7
93 changed files with 4697 additions and 4457 deletions
|
@ -3,189 +3,182 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import textwrap
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_models.sku_list import (
|
||||
llama3_1_family,
|
||||
llama3_2_family,
|
||||
llama3_family,
|
||||
resolve_model,
|
||||
safety_models,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit.validation import Validator
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.memory.memory import MemoryBankType
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
stack_apis,
|
||||
)
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.providers.impls.meta_reference.safety.config import (
|
||||
MetaReferenceShieldType,
|
||||
)
|
||||
|
||||
|
||||
ALLOWED_MODELS = (
|
||||
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
|
||||
)
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
|
||||
|
||||
def make_routing_entry_type(config_class: Any):
|
||||
class BaseModelWithConfig(BaseModel):
|
||||
routing_key: str
|
||||
config: config_class
|
||||
def configure_single_provider(
|
||||
registry: Dict[str, ProviderSpec], provider: Provider
|
||||
) -> Provider:
|
||||
provider_spec = registry[provider.provider_type]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
if provider.config:
|
||||
existing = config_type(**provider.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
|
||||
return BaseModelWithConfig
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
return Provider(
|
||||
provider_id=provider.provider_id,
|
||||
provider_type=provider.provider_type,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
|
||||
|
||||
def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
|
||||
"""Get corresponding builtin APIs given provider backed APIs"""
|
||||
res = []
|
||||
for inf in builtin_automatically_routed_apis():
|
||||
if inf.router_api.value in provider_backed_apis:
|
||||
res.append(inf.routing_table_api.value)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
# TODO: make sure we can deal with existing configuration values correctly
|
||||
# instead of just overwriting them
|
||||
def configure_api_providers(
|
||||
config: StackRunConfig, spec: DistributionSpec
|
||||
config: StackRunConfig, build_spec: DistributionSpec
|
||||
) -> StackRunConfig:
|
||||
apis = config.apis_to_serve or list(spec.providers.keys())
|
||||
# append the bulitin routing APIs
|
||||
apis += get_builtin_apis(apis)
|
||||
is_nux = len(config.providers) == 0
|
||||
|
||||
router_api2builtin_api = {
|
||||
inf.router_api.value: inf.routing_table_api.value
|
||||
for inf in builtin_automatically_routed_apis()
|
||||
}
|
||||
if is_nux:
|
||||
print(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
Llama Stack is composed of several APIs working together. For each API served by the Stack,
|
||||
we need to configure the providers (implementations) you want to use for these APIs.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
|
||||
provider_registry = get_provider_registry()
|
||||
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
|
||||
|
||||
apis = [v.value for v in stack_apis()]
|
||||
all_providers = get_provider_registry()
|
||||
if config.apis:
|
||||
apis_to_serve = config.apis
|
||||
else:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
|
||||
# configure simple case for with non-routing providers to api_providers
|
||||
for api_str in spec.providers.keys():
|
||||
if api_str not in apis:
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
if api in builtin_apis:
|
||||
continue
|
||||
if api not in provider_registry:
|
||||
raise ValueError(f"Unknown API `{api_str}`")
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||
api = Api(api_str)
|
||||
|
||||
p = spec.providers[api_str]
|
||||
cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
|
||||
|
||||
if isinstance(p, list):
|
||||
existing_providers = config.providers.get(api_str, [])
|
||||
if existing_providers:
|
||||
cprint(
|
||||
f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
|
||||
"yellow",
|
||||
f"Re-configuring existing providers for API `{api_str}`...",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
p = p[0]
|
||||
|
||||
provider_spec = all_providers[api][p]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
provider_config = config.api_providers.get(api_str)
|
||||
if provider_config:
|
||||
existing = config_type(**provider_config.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
|
||||
if api_str in router_api2builtin_api:
|
||||
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table
|
||||
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
|
||||
routing_entries = []
|
||||
if api_str == "inference":
|
||||
if hasattr(cfg, "model"):
|
||||
routing_key = cfg.model
|
||||
else:
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported model your provider has for inference: ",
|
||||
default="Llama3.1-8B-Instruct",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: resolve_model(x) is not None,
|
||||
error_message="Model must be: {}".format(
|
||||
[x.descriptor() for x in ALLOWED_MODELS]
|
||||
),
|
||||
),
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
updated_providers = []
|
||||
for p in existing_providers:
|
||||
print(f"> Configuring provider `({p.provider_type})`")
|
||||
updated_providers.append(
|
||||
configure_single_provider(provider_registry[api], p)
|
||||
)
|
||||
|
||||
if api_str == "safety":
|
||||
# TODO: add support for other safety providers, and simplify safety provider config
|
||||
if p == "meta-reference":
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=[s.value for s in MetaReferenceShieldType],
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
cprint(
|
||||
f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
|
||||
if api_str == "memory":
|
||||
bank_types = list([x.value for x in MemoryBankType])
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported memory bank type your provider has for memory: ",
|
||||
default="vector",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: x in bank_types,
|
||||
error_message="Invalid provider, please enter one of the following: {}".format(
|
||||
bank_types
|
||||
),
|
||||
),
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
|
||||
config.routing_table[api_str] = routing_entries
|
||||
config.api_providers[api_str] = PlaceholderProviderConfig(
|
||||
providers=p if isinstance(p, list) else [p]
|
||||
)
|
||||
print("")
|
||||
else:
|
||||
config.api_providers[api_str] = GenericProviderConfig(
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
# we are newly configuring this API
|
||||
plist = build_spec.providers.get(api_str, [])
|
||||
plist = plist if isinstance(plist, list) else [plist]
|
||||
|
||||
print("")
|
||||
if not plist:
|
||||
raise ValueError(f"No provider configured for API {api_str}?")
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||
updated_providers = []
|
||||
for i, provider_type in enumerate(plist):
|
||||
print(f"> Configuring provider `({provider_type})`")
|
||||
updated_providers.append(
|
||||
configure_single_provider(
|
||||
provider_registry[api],
|
||||
Provider(
|
||||
provider_id=(
|
||||
f"{provider_type}-{i:02d}"
|
||||
if len(plist) > 1
|
||||
else provider_type
|
||||
),
|
||||
provider_type=provider_type,
|
||||
config={},
|
||||
),
|
||||
)
|
||||
)
|
||||
print("")
|
||||
|
||||
config.providers[api_str] = updated_providers
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def upgrade_from_routing_table(
|
||||
config_dict: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
def get_providers(entries):
|
||||
return [
|
||||
Provider(
|
||||
provider_id=(
|
||||
f"{entry['provider_type']}-{i:02d}"
|
||||
if len(entries) > 1
|
||||
else entry["provider_type"]
|
||||
),
|
||||
provider_type=entry["provider_type"],
|
||||
config=entry["config"],
|
||||
)
|
||||
for i, entry in enumerate(entries)
|
||||
]
|
||||
|
||||
providers_by_api = {}
|
||||
|
||||
routing_table = config_dict.get("routing_table", {})
|
||||
for api_str, entries in routing_table.items():
|
||||
providers = get_providers(entries)
|
||||
providers_by_api[api_str] = providers
|
||||
|
||||
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
|
||||
if provider_map:
|
||||
for api_str, provider in provider_map.items():
|
||||
if isinstance(provider, dict) and "provider_type" in provider:
|
||||
providers_by_api[api_str] = [
|
||||
Provider(
|
||||
provider_id=f"{provider['provider_type']}",
|
||||
provider_type=provider["provider_type"],
|
||||
config=provider["config"],
|
||||
)
|
||||
]
|
||||
|
||||
config_dict["providers"] = providers_by_api
|
||||
|
||||
config_dict.pop("routing_table", None)
|
||||
config_dict.pop("api_providers", None)
|
||||
config_dict.pop("provider_map", None)
|
||||
|
||||
config_dict["apis"] = config_dict["apis_to_serve"]
|
||||
config_dict.pop("apis_to_serve", None)
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
|
||||
version = config_dict.get("version", None)
|
||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
return StackRunConfig(**config_dict)
|
||||
|
||||
if "routing_table" in config_dict:
|
||||
print("Upgrading config...")
|
||||
config_dict = upgrade_from_routing_table(config_dict)
|
||||
|
||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
config_dict["built_at"] = datetime.now().isoformat()
|
||||
|
||||
return StackRunConfig(**config_dict)
|
||||
|
|
|
@ -11,28 +11,38 @@ from typing import Dict, List, Optional, Union
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
||||
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
||||
|
||||
|
||||
RoutingKey = Union[str, List[str]]
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
RoutableObject = Union[
|
||||
ModelDef,
|
||||
ShieldDef,
|
||||
MemoryBankDef,
|
||||
]
|
||||
|
||||
RoutableObjectWithProvider = Union[
|
||||
ModelDefWithProvider,
|
||||
ShieldDefWithProvider,
|
||||
MemoryBankDefWithProvider,
|
||||
]
|
||||
|
||||
class RoutableProviderConfig(GenericProviderConfig):
|
||||
routing_key: RoutingKey
|
||||
|
||||
|
||||
class PlaceholderProviderConfig(BaseModel):
|
||||
"""Placeholder provider config for API whose provider are defined in routing_table"""
|
||||
|
||||
providers: List[str]
|
||||
RoutedProtocol = Union[
|
||||
Inference,
|
||||
Safety,
|
||||
Memory,
|
||||
]
|
||||
|
||||
|
||||
# Example: /inference, /safety
|
||||
|
@ -53,18 +63,16 @@ class AutoRoutedProviderSpec(ProviderSpec):
|
|||
|
||||
|
||||
# Example: /models, /shields
|
||||
@json_schema_type
|
||||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
provider_type: str = "routing_table"
|
||||
config_class: str = ""
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
inner_specs: List[ProviderSpec]
|
||||
router_api: Api
|
||||
module: str
|
||||
pip_packages: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DistributionSpec(BaseModel):
|
||||
description: Optional[str] = Field(
|
||||
default="",
|
||||
|
@ -80,7 +88,12 @@ in the runtime configuration to help route to the correct provider.""",
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Provider(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
built_at: datetime
|
||||
|
@ -100,36 +113,20 @@ this could be just a hash
|
|||
default=None,
|
||||
description="Reference to the conda environment if this package refers to a conda environment",
|
||||
)
|
||||
apis_to_serve: List[str] = Field(
|
||||
apis: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="""
|
||||
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
|
||||
)
|
||||
|
||||
api_providers: Dict[
|
||||
str, Union[GenericProviderConfig, PlaceholderProviderConfig]
|
||||
] = Field(
|
||||
providers: Dict[str, List[Provider]] = Field(
|
||||
description="""
|
||||
Provider configurations for each of the APIs provided by this package.
|
||||
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
|
||||
can be instantiated multiple times (with different configs) if necessary.
|
||||
""",
|
||||
)
|
||||
routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
|
||||
E.g. The following is a ProviderRoutingEntry for models:
|
||||
- routing_key: Llama3.1-8B-Instruct
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
model: Llama3.1-8B-Instruct
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BuildConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||
name: str
|
||||
|
|
|
@ -6,45 +6,58 @@
|
|||
|
||||
from typing import Dict, List
|
||||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def is_passthrough(spec: ProviderSpec) -> bool:
|
||||
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
impl = DistributionInspectImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class DistributionInspectImpl(Inspect):
|
||||
def __init__(self):
|
||||
def __init__(self, config, deps):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = {}
|
||||
all_providers = get_provider_registry()
|
||||
for api, providers in all_providers.items():
|
||||
ret[api.value] = [
|
||||
for api, providers in run_config.providers.items():
|
||||
ret[api] = [
|
||||
ProviderInfo(
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
description="Passthrough" if is_passthrough(p) else "",
|
||||
)
|
||||
for p in providers.values()
|
||||
for p in providers
|
||||
]
|
||||
|
||||
return ret
|
||||
|
||||
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = {}
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
for api, endpoints in all_endpoints.items():
|
||||
providers = run_config.providers.get(api.value, [])
|
||||
ret[api.value] = [
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
providers=[],
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
]
|
||||
|
|
|
@ -4,146 +4,237 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.distribution.inspect import DistributionInspectImpl
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
|
||||
def api_protocol_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
Api.inspect: Inspect,
|
||||
Api.memory: Memory,
|
||||
Api.memory_banks: MemoryBanks,
|
||||
Api.models: Models,
|
||||
Api.safety: Safety,
|
||||
Api.shields: Shields,
|
||||
Api.telemetry: Telemetry,
|
||||
}
|
||||
|
||||
|
||||
def additional_protocols_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.inference: ModelsProtocolPrivate,
|
||||
Api.memory: MemoryBanksProtocolPrivate,
|
||||
Api.safety: ShieldsProtocolPrivate,
|
||||
}
|
||||
|
||||
|
||||
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
|
||||
class ProviderWithSpec(Provider):
|
||||
spec: ProviderSpec
|
||||
|
||||
|
||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
- for each API, produces either a (local, passthrough or router) implementation
|
||||
"""
|
||||
all_providers = get_provider_registry()
|
||||
specs = {}
|
||||
configs = {}
|
||||
all_api_providers = get_provider_registry()
|
||||
|
||||
for api_str, config in run_config.api_providers.items():
|
||||
api = Api(api_str)
|
||||
|
||||
# TODO: check that these APIs are not in the routing table part of the config
|
||||
providers = all_providers[api]
|
||||
|
||||
# skip checks for API whose provider config is specified in routing_table
|
||||
if isinstance(config, PlaceholderProviderConfig):
|
||||
continue
|
||||
|
||||
if config.provider_type not in providers:
|
||||
raise ValueError(
|
||||
f"Provider `{config.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[config.provider_type]
|
||||
configs[api] = config
|
||||
|
||||
apis_to_serve = run_config.apis_to_serve or set(
|
||||
list(specs.keys()) + list(run_config.routing_table.keys())
|
||||
routing_table_apis = set(
|
||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||
)
|
||||
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
|
||||
|
||||
providers_with_specs = {}
|
||||
|
||||
for api_str, providers in run_config.providers.items():
|
||||
api = Api(api_str)
|
||||
if api in routing_table_apis:
|
||||
raise ValueError(
|
||||
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
|
||||
)
|
||||
|
||||
specs = {}
|
||||
for provider in providers:
|
||||
if provider.provider_type not in all_api_providers[api]:
|
||||
raise ValueError(
|
||||
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
|
||||
p = all_api_providers[api][provider.provider_type]
|
||||
p.deps__ = [a.value for a in p.api_dependencies]
|
||||
spec = ProviderWithSpec(
|
||||
spec=p,
|
||||
**(provider.dict()),
|
||||
)
|
||||
specs[provider.provider_id] = spec
|
||||
|
||||
key = api_str if api not in router_apis else f"inner-{api_str}"
|
||||
providers_with_specs[key] = specs
|
||||
|
||||
apis_to_serve = run_config.apis or set(
|
||||
list(providers_with_specs.keys())
|
||||
+ [x.value for x in routing_table_apis]
|
||||
+ [x.value for x in router_apis]
|
||||
)
|
||||
|
||||
for info in builtin_automatically_routed_apis():
|
||||
source_api = info.routing_table_api
|
||||
|
||||
assert (
|
||||
source_api not in specs
|
||||
), f"Routing table API {source_api} specified in wrong place?"
|
||||
assert (
|
||||
info.router_api not in specs
|
||||
), f"Auto-routed API {info.router_api} specified in wrong place?"
|
||||
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
|
||||
if info.router_api.value not in run_config.routing_table:
|
||||
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
|
||||
available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
|
||||
|
||||
routing_table = run_config.routing_table[info.router_api.value]
|
||||
providers_with_specs[info.routing_table_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__routing_table__",
|
||||
provider_type="__routing_table__",
|
||||
config={},
|
||||
spec=RoutingTableProviderSpec(
|
||||
api=info.routing_table_api,
|
||||
router_api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=[],
|
||||
deps__=([f"inner-{info.router_api.value}"]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
providers = all_providers[info.router_api]
|
||||
providers_with_specs[info.router_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__autorouted__",
|
||||
provider_type="__autorouted__",
|
||||
config={},
|
||||
spec=AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=info.routing_table_api,
|
||||
api_dependencies=[info.routing_table_api],
|
||||
deps__=([info.routing_table_api.value]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
inner_specs = []
|
||||
inner_deps = []
|
||||
for rt_entry in routing_table:
|
||||
if rt_entry.provider_type not in providers:
|
||||
raise ValueError(
|
||||
f"Provider `{rt_entry.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
inner_specs.append(providers[rt_entry.provider_type])
|
||||
inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)
|
||||
|
||||
specs[source_api] = RoutingTableProviderSpec(
|
||||
api=source_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=inner_deps,
|
||||
inner_specs=inner_specs,
|
||||
sorted_providers = topological_sort(
|
||||
{k: v.values() for k, v in providers_with_specs.items()}
|
||||
)
|
||||
apis = [x[1].spec.api for x in sorted_providers]
|
||||
sorted_providers.append(
|
||||
(
|
||||
"inspect",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={
|
||||
"run_config": run_config.dict(),
|
||||
},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||
module="llama_stack.distribution.inspect",
|
||||
api_dependencies=apis,
|
||||
deps__=([x.value for x in apis]),
|
||||
),
|
||||
),
|
||||
)
|
||||
configs[source_api] = routing_table
|
||||
|
||||
specs[info.router_api] = AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=source_api,
|
||||
api_dependencies=[source_api],
|
||||
)
|
||||
configs[info.router_api] = {}
|
||||
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
print(f"Resolved {len(sorted_specs)} providers in topological order")
|
||||
for spec in sorted_specs:
|
||||
print(f" {spec.api}: {spec.provider_type}")
|
||||
print("")
|
||||
impls = {}
|
||||
for spec in sorted_specs:
|
||||
api = spec.api
|
||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||
impl = await instantiate_provider(spec, deps, configs[api])
|
||||
|
||||
impls[api] = impl
|
||||
|
||||
impls[Api.inspect] = DistributionInspectImpl()
|
||||
specs[Api.inspect] = InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__distribution_builtin__",
|
||||
config_class="",
|
||||
module="",
|
||||
)
|
||||
|
||||
return impls, specs
|
||||
print(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
print(f" {api_str} => {provider.provider_id}")
|
||||
print("")
|
||||
|
||||
impls = {}
|
||||
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
|
||||
for api_str, provider in sorted_providers:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
|
||||
inner_impls = {}
|
||||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[
|
||||
f"inner-{provider.spec.router_api.value}"
|
||||
]
|
||||
|
||||
impl = await instantiate_provider(
|
||||
provider,
|
||||
deps,
|
||||
inner_impls,
|
||||
)
|
||||
# TODO: ugh slightly redesign this shady looking code
|
||||
if "inner-" in api_str:
|
||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||
else:
|
||||
api = Api(api_str)
|
||||
impls[api] = impl
|
||||
|
||||
return impls
|
||||
|
||||
|
||||
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||
by_id = {x.api: x for x in providers}
|
||||
def topological_sort(
|
||||
providers_with_specs: Dict[str, List[ProviderWithSpec]],
|
||||
) -> List[ProviderWithSpec]:
|
||||
def dfs(kv, visited: Set[str], stack: List[str]):
|
||||
api_str, providers = kv
|
||||
visited.add(api_str)
|
||||
|
||||
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
|
||||
visited.add(a.api)
|
||||
deps = []
|
||||
for provider in providers:
|
||||
for dep in provider.spec.deps__:
|
||||
deps.append(dep)
|
||||
|
||||
for api in a.api_dependencies:
|
||||
if api not in visited:
|
||||
dfs(by_id[api], visited, stack)
|
||||
for dep in deps:
|
||||
if dep not in visited:
|
||||
dfs((dep, providers_with_specs[dep]), visited, stack)
|
||||
|
||||
stack.append(a.api)
|
||||
stack.append(api_str)
|
||||
|
||||
visited = set()
|
||||
stack = []
|
||||
|
||||
for a in providers:
|
||||
if a.api not in visited:
|
||||
dfs(a, visited, stack)
|
||||
for api_str, providers in providers_with_specs.items():
|
||||
if api_str not in visited:
|
||||
dfs((api_str, providers), visited, stack)
|
||||
|
||||
return [by_id[x] for x in stack]
|
||||
flattened = []
|
||||
for api_str in stack:
|
||||
for provider in providers_with_specs[api_str]:
|
||||
flattened.append((api_str, provider))
|
||||
return flattened
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider_spec: ProviderSpec,
|
||||
provider: ProviderWithSpec,
|
||||
deps: Dict[str, Any],
|
||||
provider_config: Union[GenericProviderConfig, RoutingTable],
|
||||
inner_impls: Dict[str, Any],
|
||||
):
|
||||
protocols = api_protocol_map()
|
||||
additional_protocols = additional_protocols_map()
|
||||
|
||||
provider_spec = provider.spec
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
args = []
|
||||
|
@ -153,9 +244,8 @@ async def instantiate_provider(
|
|||
else:
|
||||
method = "get_client_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
config = config_type(**provider.config)
|
||||
args = [config, deps]
|
||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
method = "get_auto_router_impl"
|
||||
|
@ -165,31 +255,69 @@ async def instantiate_provider(
|
|||
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
||||
method = "get_routing_table_impl"
|
||||
|
||||
assert isinstance(provider_config, List)
|
||||
routing_table = provider_config
|
||||
|
||||
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
|
||||
inner_impls = []
|
||||
for routing_entry in routing_table:
|
||||
impl = await instantiate_provider(
|
||||
inner_specs[routing_entry.provider_type],
|
||||
deps,
|
||||
routing_entry,
|
||||
)
|
||||
inner_impls.append((routing_entry.routing_key, impl))
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, inner_impls, routing_table, deps]
|
||||
args = [provider_spec.api, inner_impls, deps]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
config = config_type(**provider.config)
|
||||
args = [config, deps]
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
impl.__provider_id__ = provider.provider_id
|
||||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
check_protocol_compliance(impl, protocols[provider_spec.api])
|
||||
if (
|
||||
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
||||
and provider_spec.api in additional_protocols
|
||||
):
|
||||
additional_api = additional_protocols[provider_spec.api]
|
||||
check_protocol_compliance(impl, additional_api)
|
||||
|
||||
return impl
|
||||
|
||||
|
||||
def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
||||
missing_methods = []
|
||||
|
||||
mro = type(obj).__mro__
|
||||
for name, value in inspect.getmembers(protocol):
|
||||
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
|
||||
if not hasattr(obj, name):
|
||||
missing_methods.append((name, "missing"))
|
||||
elif not callable(getattr(obj, name)):
|
||||
missing_methods.append((name, "not_callable"))
|
||||
else:
|
||||
# Check if the method signatures are compatible
|
||||
obj_method = getattr(obj, name)
|
||||
proto_sig = inspect.signature(value)
|
||||
obj_sig = inspect.signature(obj_method)
|
||||
|
||||
proto_params = set(proto_sig.parameters)
|
||||
proto_params.discard("self")
|
||||
obj_params = set(obj_sig.parameters)
|
||||
obj_params.discard("self")
|
||||
if not (proto_params <= obj_params):
|
||||
print(
|
||||
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
|
||||
)
|
||||
missing_methods.append((name, "signature_mismatch"))
|
||||
else:
|
||||
# Check if the method is actually implemented in the class
|
||||
method_owner = next(
|
||||
(cls for cls in mro if name in cls.__dict__), None
|
||||
)
|
||||
if (
|
||||
method_owner is None
|
||||
or method_owner.__name__ == protocol.__name__
|
||||
):
|
||||
missing_methods.append((name, "not_actually_implemented"))
|
||||
|
||||
if missing_methods:
|
||||
raise ValueError(
|
||||
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
|
||||
)
|
||||
|
|
|
@ -4,23 +4,21 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from .routing_tables import (
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
)
|
||||
|
||||
|
||||
async def get_routing_table_impl(
|
||||
api: Api,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
_deps,
|
||||
) -> Any:
|
||||
from .routing_tables import (
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
)
|
||||
|
||||
api_to_tables = {
|
||||
"memory_banks": MemoryBanksRoutingTable,
|
||||
"models": ModelsRoutingTable,
|
||||
|
@ -29,7 +27,7 @@ async def get_routing_table_impl(
|
|||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_tables[api.value](inner_impls, routing_table_config)
|
||||
impl = api_to_tables[api.value](impls_by_provider_id)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
|
|
@ -14,14 +14,13 @@ from llama_stack.apis.safety import * # noqa: F403
|
|||
|
||||
|
||||
class MemoryRouter(Memory):
|
||||
"""Routes to an provider based on the memory bank type"""
|
||||
"""Routes to an provider based on the memory bank identifier"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
self.bank_id_to_type = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
@ -29,32 +28,8 @@ class MemoryRouter(Memory):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def get_provider_from_bank_id(self, bank_id: str) -> Any:
|
||||
bank_type = self.bank_id_to_type.get(bank_id)
|
||||
if not bank_type:
|
||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||
|
||||
provider = self.routing_table.get_provider_impl(bank_type)
|
||||
if not provider:
|
||||
raise ValueError(f"Could not find provider for {bank_type}")
|
||||
return provider
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_type = config.type
|
||||
bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
|
||||
name, config, url
|
||||
)
|
||||
self.bank_id_to_type[bank.bank_id] = bank_type
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
provider = self.get_provider_from_bank_id(bank_id)
|
||||
return await provider.get_memory_bank(bank_id)
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
||||
await self.routing_table.register_memory_bank(memory_bank)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
@ -62,7 +37,7 @@ class MemoryRouter(Memory):
|
|||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
return await self.get_provider_from_bank_id(bank_id).insert_documents(
|
||||
return await self.routing_table.get_provider_impl(bank_id).insert_documents(
|
||||
bank_id, documents, ttl_seconds
|
||||
)
|
||||
|
||||
|
@ -72,7 +47,7 @@ class MemoryRouter(Memory):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
return await self.get_provider_from_bank_id(bank_id).query_documents(
|
||||
return await self.routing_table.get_provider_impl(bank_id).query_documents(
|
||||
bank_id, query, params
|
||||
)
|
||||
|
||||
|
@ -92,7 +67,10 @@ class InferenceRouter(Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def chat_completion(
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
await self.routing_table.register_model(model)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -113,27 +91,32 @@ class InferenceRouter(Inference):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
# TODO: we need to fix streaming response to align provider implementations with Protocol.
|
||||
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
|
||||
**params
|
||||
):
|
||||
yield chunk
|
||||
provider = self.routing_table.get_provider_impl(model)
|
||||
if stream:
|
||||
return (chunk async for chunk in provider.chat_completion(**params))
|
||||
else:
|
||||
return provider.chat_completion(**params)
|
||||
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
return await self.routing_table.get_provider_impl(model).completion(
|
||||
) -> AsyncGenerator:
|
||||
provider = self.routing_table.get_provider_impl(model)
|
||||
params = dict(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return (chunk async for chunk in provider.completion(**params))
|
||||
else:
|
||||
return provider.completion(**params)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
@ -159,6 +142,9 @@ class SafetyRouter(Safety):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
await self.routing_table.register_shield(shield)
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
|
|
|
@ -4,9 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
@ -16,129 +15,159 @@ from llama_stack.apis.memory_banks import * # noqa: F403
|
|||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def get_impl_api(p: Any) -> Api:
|
||||
return p.__provider_spec__.api
|
||||
|
||||
|
||||
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
||||
api = get_impl_api(p)
|
||||
if api == Api.inference:
|
||||
await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
await p.register_memory_bank(obj)
|
||||
|
||||
|
||||
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
||||
|
||||
|
||||
# TODO: this routing table maintains state in memory purely. We need to
|
||||
# add persistence to it when we add dynamic registration of objects.
|
||||
class CommonRoutingTableImpl(RoutingTable):
|
||||
def __init__(
|
||||
self,
|
||||
inner_impls: List[Tuple[RoutingKey, Any]],
|
||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
) -> None:
|
||||
self.unique_providers = []
|
||||
self.providers = {}
|
||||
self.routing_keys = []
|
||||
|
||||
for key, impl in inner_impls:
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
self.unique_providers.append((keys, impl))
|
||||
|
||||
for k in keys:
|
||||
if k in self.providers:
|
||||
raise ValueError(f"Duplicate routing key {k}")
|
||||
self.providers[k] = impl
|
||||
self.routing_keys.append(k)
|
||||
|
||||
self.routing_table_config = routing_table_config
|
||||
self.impls_by_provider_id = impls_by_provider_id
|
||||
|
||||
async def initialize(self) -> None:
|
||||
for keys, p in self.unique_providers:
|
||||
spec = p.__provider_spec__
|
||||
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
|
||||
continue
|
||||
self.registry: Registry = {}
|
||||
|
||||
await p.validate_routing_keys(keys)
|
||||
def add_objects(objs: List[RoutableObjectWithProvider]) -> None:
|
||||
for obj in objs:
|
||||
if obj.identifier not in self.registry:
|
||||
self.registry[obj.identifier] = []
|
||||
|
||||
self.registry[obj.identifier].append(obj)
|
||||
|
||||
for pid, p in self.impls_by_provider_id.items():
|
||||
api = get_impl_api(p)
|
||||
if api == Api.inference:
|
||||
p.model_store = self
|
||||
models = await p.list_models()
|
||||
add_objects(
|
||||
[ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models]
|
||||
)
|
||||
|
||||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
shields = await p.list_shields()
|
||||
add_objects(
|
||||
[
|
||||
ShieldDefWithProvider(**s.dict(), provider_id=pid)
|
||||
for s in shields
|
||||
]
|
||||
)
|
||||
|
||||
elif api == Api.memory:
|
||||
p.memory_bank_store = self
|
||||
memory_banks = await p.list_memory_banks()
|
||||
|
||||
# do in-memory updates due to pesky Annotated unions
|
||||
for m in memory_banks:
|
||||
m.provider_id = pid
|
||||
|
||||
add_objects(memory_banks)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for _, p in self.unique_providers:
|
||||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str) -> Any:
|
||||
if routing_key not in self.providers:
|
||||
raise ValueError(f"Could not find provider for {routing_key}")
|
||||
return self.providers[routing_key]
|
||||
def get_provider_impl(
|
||||
self, routing_key: str, provider_id: Optional[str] = None
|
||||
) -> Any:
|
||||
if routing_key not in self.registry:
|
||||
raise ValueError(f"`{routing_key}` not registered")
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
objs = self.registry[routing_key]
|
||||
for obj in objs:
|
||||
if not provider_id or provider_id == obj.provider_id:
|
||||
return self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == routing_key:
|
||||
return entry
|
||||
return None
|
||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||
|
||||
def get_object_by_identifier(
|
||||
self, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
objs = self.registry.get(identifier, [])
|
||||
if not objs:
|
||||
return None
|
||||
|
||||
# kind of ill-defined behavior here, but we'll just return the first one
|
||||
return objs[0]
|
||||
|
||||
async def register_object(self, obj: RoutableObjectWithProvider):
|
||||
entries = self.registry.get(obj.identifier, [])
|
||||
for entry in entries:
|
||||
if entry.provider_id == obj.provider_id:
|
||||
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
|
||||
return
|
||||
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
await register_object_with_provider(obj, p)
|
||||
|
||||
if obj.identifier not in self.registry:
|
||||
self.registry[obj.identifier] = []
|
||||
self.registry[obj.identifier].append(obj)
|
||||
|
||||
# TODO: persist this to a store
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> List[ModelDefWithProvider]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
|
||||
async def list_models(self) -> List[ModelServingSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
model_id = entry.routing_key
|
||||
specs.append(
|
||||
ModelServingSpec(
|
||||
llama_model=resolve_model(model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
|
||||
return self.get_object_by_identifier(identifier)
|
||||
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == core_model_id:
|
||||
return ModelServingSpec(
|
||||
llama_model=resolve_model(core_model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
async def register_model(self, model: ModelDefWithProvider) -> None:
|
||||
await self.register_object(model)
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
|
||||
async def list_shields(self) -> List[ShieldSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
if isinstance(entry.routing_key, list):
|
||||
for k in entry.routing_key:
|
||||
specs.append(
|
||||
ShieldSpec(
|
||||
shield_type=k,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
else:
|
||||
specs.append(
|
||||
ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
||||
return self.get_object_by_identifier(shield_type)
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == shield_type:
|
||||
return ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
|
||||
await self.register_object(shield)
|
||||
|
||||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
|
||||
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
specs.append(
|
||||
MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
async def get_memory_bank(
|
||||
self, identifier: str
|
||||
) -> Optional[MemoryBankDefWithProvider]:
|
||||
return self.get_object_by_identifier(identifier)
|
||||
|
||||
async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == bank_type:
|
||||
return MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
async def register_memory_bank(
|
||||
self, memory_bank: MemoryBankDefWithProvider
|
||||
) -> None:
|
||||
await self.register_object(memory_bank)
|
||||
|
|
|
@ -9,15 +9,7 @@ from typing import Dict, List
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.distribution.resolver import api_protocol_map
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
@ -31,18 +23,7 @@ class ApiEndpoint(BaseModel):
|
|||
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||
apis = {}
|
||||
|
||||
protocols = {
|
||||
Api.inference: Inference,
|
||||
Api.safety: Safety,
|
||||
Api.agents: Agents,
|
||||
Api.memory: Memory,
|
||||
Api.telemetry: Telemetry,
|
||||
Api.models: Models,
|
||||
Api.shields: Shields,
|
||||
Api.memory_banks: MemoryBanks,
|
||||
Api.inspect: Inspect,
|
||||
}
|
||||
|
||||
protocols = api_protocol_map()
|
||||
for api, protocol in protocols.items():
|
||||
endpoints = []
|
||||
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||
|
|
|
@ -5,18 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import signal
|
||||
import traceback
|
||||
|
||||
from collections.abc import (
|
||||
AsyncGenerator as AsyncGeneratorABC,
|
||||
AsyncIterator as AsyncIteratorABC,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from ssl import SSLError
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
@ -29,6 +26,8 @@ from pydantic import BaseModel, ValidationError
|
|||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
|
@ -43,20 +42,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
|
|||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
|
||||
def is_async_iterator_type(typ):
|
||||
if hasattr(typ, "__origin__"):
|
||||
origin = typ.__origin__
|
||||
if isinstance(origin, type):
|
||||
return issubclass(
|
||||
origin,
|
||||
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
|
||||
)
|
||||
return False
|
||||
return isinstance(
|
||||
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
|
||||
)
|
||||
|
||||
|
||||
def create_sse_event(data: Any) -> str:
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.json()
|
||||
|
@ -169,11 +154,20 @@ async def passthrough(
|
|||
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
|
||||
|
||||
|
||||
def handle_sigint(*args, **kwargs):
|
||||
def handle_sigint(app, *args, **kwargs):
|
||||
print("SIGINT or CTRL-C detected. Exiting gracefully...")
|
||||
|
||||
async def run_shutdown():
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
print(f"Shutting down {impl}")
|
||||
await impl.shutdown()
|
||||
|
||||
asyncio.run(run_shutdown())
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for task in asyncio.all_tasks(loop):
|
||||
task.cancel()
|
||||
|
||||
loop.stop()
|
||||
|
||||
|
||||
|
@ -181,7 +175,10 @@ def handle_sigint(*args, **kwargs):
|
|||
async def lifespan(app: FastAPI):
|
||||
print("Starting up")
|
||||
yield
|
||||
|
||||
print("Shutting down")
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
await impl.shutdown()
|
||||
|
||||
|
||||
def create_dynamic_passthrough(
|
||||
|
@ -193,65 +190,59 @@ def create_dynamic_passthrough(
|
|||
return endpoint
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||
return kwargs.get("stream", False)
|
||||
|
||||
|
||||
async def maybe_await(value):
|
||||
if inspect.iscoroutine(value):
|
||||
return await value
|
||||
return value
|
||||
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
"message": str(translate_exception(e)),
|
||||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
hints = get_type_hints(func)
|
||||
response_model = hints.get("return")
|
||||
|
||||
# NOTE: I think it is better to just add a method within each Api
|
||||
# "Protocol" / adapter-impl to tell what sort of a response this request
|
||||
# is going to produce. /chat_completion can produce a streaming or
|
||||
# non-streaming response depending on if request.stream is True / False.
|
||||
is_streaming = is_async_iterator_type(response_model)
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
if is_streaming:
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
"message": str(translate_exception(e)),
|
||||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
try:
|
||||
return (
|
||||
await func(**kwargs)
|
||||
if asyncio.iscoroutinefunction(func)
|
||||
else func(**kwargs)
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
try:
|
||||
if is_streaming:
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
sig = inspect.signature(func)
|
||||
new_params = [
|
||||
|
@ -285,29 +276,28 @@ def main(
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||
impls = asyncio.run(resolve_impls_with_routing(config))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
if config.apis_to_serve:
|
||||
apis_to_serve = set(config.apis_to_serve)
|
||||
if config.apis:
|
||||
apis_to_serve = set(config.apis)
|
||||
else:
|
||||
apis_to_serve = set(impls.keys())
|
||||
|
||||
apis_to_serve.add(Api.inspect)
|
||||
for inf in builtin_automatically_routed_apis():
|
||||
apis_to_serve.add(inf.routing_table_api.value)
|
||||
|
||||
apis_to_serve.add("inspect")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
provider_spec = specs[api]
|
||||
if (
|
||||
isinstance(provider_spec, RemoteProviderSpec)
|
||||
and provider_spec.adapter is None
|
||||
):
|
||||
if is_passthrough(impl.__provider_spec__):
|
||||
for endpoint in endpoints:
|
||||
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
|
@ -337,7 +327,9 @@ def main(
|
|||
print("")
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
built_at: '2024-09-30T09:04:30.533391'
|
||||
version: '2'
|
||||
built_at: '2024-10-08T17:42:07.505267'
|
||||
image_name: local-cpu
|
||||
docker_image: local-cpu
|
||||
conda_env: null
|
||||
apis_to_serve:
|
||||
apis:
|
||||
- agents
|
||||
- inference
|
||||
- models
|
||||
|
@ -10,40 +11,32 @@ apis_to_serve:
|
|||
- safety
|
||||
- shields
|
||||
- memory_banks
|
||||
api_providers:
|
||||
providers:
|
||||
inference:
|
||||
providers:
|
||||
- remote::ollama
|
||||
- provider_id: remote::ollama
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 6000
|
||||
safety:
|
||||
providers:
|
||||
- meta-reference
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield: null
|
||||
prompt_guard_shield: null
|
||||
memory:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
memory:
|
||||
providers:
|
||||
- meta-reference
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
routing_table:
|
||||
inference:
|
||||
- provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 6000
|
||||
routing_key: Llama3.1-8B-Instruct
|
||||
safety:
|
||||
- provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield: null
|
||||
prompt_guard_shield: null
|
||||
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
||||
memory:
|
||||
- provider_type: meta-reference
|
||||
config: {}
|
||||
routing_key: vector
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
built_at: '2024-09-30T09:00:56.693751'
|
||||
version: '2'
|
||||
built_at: '2024-10-08T17:42:33.690666'
|
||||
image_name: local-gpu
|
||||
docker_image: local-gpu
|
||||
conda_env: null
|
||||
apis_to_serve:
|
||||
apis:
|
||||
- memory
|
||||
- inference
|
||||
- agents
|
||||
|
@ -10,43 +11,35 @@ apis_to_serve:
|
|||
- safety
|
||||
- models
|
||||
- memory_banks
|
||||
api_providers:
|
||||
providers:
|
||||
inference:
|
||||
providers:
|
||||
- meta-reference
|
||||
safety:
|
||||
providers:
|
||||
- meta-reference
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
memory:
|
||||
providers:
|
||||
- meta-reference
|
||||
telemetry:
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
routing_table:
|
||||
inference:
|
||||
- provider_type: meta-reference
|
||||
config:
|
||||
model: Llama3.1-8B-Instruct
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
routing_key: Llama3.1-8B-Instruct
|
||||
safety:
|
||||
- provider_type: meta-reference
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield: null
|
||||
prompt_guard_shield: null
|
||||
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
||||
memory:
|
||||
- provider_type: meta-reference
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
routing_key: vector
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue