mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-09 03:19:20 +00:00
rebase on top of registry
This commit is contained in:
commit
6abef716dd
107 changed files with 4813 additions and 3587 deletions
|
|
@ -8,15 +8,16 @@ from enum import Enum
|
|||
from typing import List, Optional
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
from pydantic import BaseModel
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
|
||||
|
||||
|
|
@ -95,6 +96,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
|||
build_config.name,
|
||||
package_deps.docker_image,
|
||||
str(build_file_path),
|
||||
str(BUILDS_BASE_DIR / ImageType.docker.value),
|
||||
" ".join(deps),
|
||||
]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ if [ "$#" -lt 3 ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
special_pip_deps="$3"
|
||||
special_pip_deps="$4"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ if [ "$#" -lt 4 ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
special_pip_deps="$5"
|
||||
special_pip_deps="$6"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
|
|
@ -18,7 +18,8 @@ build_name="$1"
|
|||
image_name="llamastack-$build_name"
|
||||
docker_base=$2
|
||||
build_file_path=$3
|
||||
pip_dependencies=$4
|
||||
host_build_dir=$4
|
||||
pip_dependencies=$5
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
|
|
@ -33,7 +34,8 @@ REPO_CONFIGS_DIR="$REPO_DIR/tmp/configs"
|
|||
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
|
||||
llama stack configure $build_file_path --output-dir $REPO_CONFIGS_DIR
|
||||
llama stack configure $build_file_path
|
||||
cp $host_build_dir/$build_name-run.yaml $REPO_CONFIGS_DIR
|
||||
|
||||
add_to_docker() {
|
||||
local input
|
||||
|
|
@ -132,6 +134,9 @@ fi
|
|||
|
||||
set -x
|
||||
$DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
|
||||
|
||||
# clean up tmp/configs
|
||||
rm -rf $REPO_CONFIGS_DIR
|
||||
set +x
|
||||
|
||||
echo "Success! You can run it with: $DOCKER_BINARY $DOCKER_OPTS run -p 5000:5000 $image_name"
|
||||
|
|
|
|||
|
|
@ -3,171 +3,369 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import textwrap
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from llama_models.sku_list import (
|
||||
llama3_1_family,
|
||||
llama3_2_family,
|
||||
llama3_family,
|
||||
resolve_model,
|
||||
safety_models,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit.validation import Validator
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.memory.memory import MemoryBankType
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
stack_apis,
|
||||
)
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.providers.impls.meta_reference.safety.config import (
|
||||
MetaReferenceShieldType,
|
||||
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
|
||||
|
||||
ALLOWED_MODELS = (
|
||||
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
|
||||
)
|
||||
|
||||
|
||||
def make_routing_entry_type(config_class: Any):
|
||||
class BaseModelWithConfig(BaseModel):
|
||||
routing_key: str
|
||||
config: config_class
|
||||
def configure_single_provider(
|
||||
registry: Dict[str, ProviderSpec], provider: Provider
|
||||
) -> Provider:
|
||||
provider_spec = registry[provider.provider_type]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
if provider.config:
|
||||
existing = config_type(**provider.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
|
||||
return BaseModelWithConfig
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
return Provider(
|
||||
provider_id=provider.provider_id,
|
||||
provider_type=provider.provider_type,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
|
||||
|
||||
def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
|
||||
"""Get corresponding builtin APIs given provider backed APIs"""
|
||||
res = []
|
||||
for inf in builtin_automatically_routed_apis():
|
||||
if inf.router_api.value in provider_backed_apis:
|
||||
res.append(inf.routing_table_api.value)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
# TODO: make sure we can deal with existing configuration values correctly
|
||||
# instead of just overwriting them
|
||||
def configure_api_providers(
|
||||
config: StackRunConfig, spec: DistributionSpec
|
||||
config: StackRunConfig, build_spec: DistributionSpec
|
||||
) -> StackRunConfig:
|
||||
apis = config.apis_to_serve or list(spec.providers.keys())
|
||||
# append the bulitin routing APIs
|
||||
apis += get_builtin_apis(apis)
|
||||
is_nux = len(config.providers) == 0
|
||||
|
||||
router_api2builtin_api = {
|
||||
inf.router_api.value: inf.routing_table_api.value
|
||||
for inf in builtin_automatically_routed_apis()
|
||||
}
|
||||
if is_nux:
|
||||
print(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
Llama Stack is composed of several APIs working together. For each API served by the Stack,
|
||||
we need to configure the providers (implementations) you want to use for these APIs.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
|
||||
provider_registry = get_provider_registry()
|
||||
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
|
||||
|
||||
apis = [v.value for v in stack_apis()]
|
||||
all_providers = get_provider_registry()
|
||||
if config.apis:
|
||||
apis_to_serve = config.apis
|
||||
else:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
|
||||
# configure simple case for with non-routing providers to api_providers
|
||||
for api_str in spec.providers.keys():
|
||||
if api_str not in apis:
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
if api in builtin_apis:
|
||||
continue
|
||||
if api not in provider_registry:
|
||||
raise ValueError(f"Unknown API `{api_str}`")
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||
api = Api(api_str)
|
||||
|
||||
p = spec.providers[api_str]
|
||||
cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
|
||||
|
||||
if isinstance(p, list):
|
||||
existing_providers = config.providers.get(api_str, [])
|
||||
if existing_providers:
|
||||
cprint(
|
||||
f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
|
||||
"yellow",
|
||||
f"Re-configuring existing providers for API `{api_str}`...",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
p = p[0]
|
||||
|
||||
provider_spec = all_providers[api][p]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
provider_config = config.api_providers.get(api_str)
|
||||
if provider_config:
|
||||
existing = config_type(**provider_config.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
|
||||
if api_str in router_api2builtin_api:
|
||||
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table
|
||||
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
|
||||
routing_entries = []
|
||||
if api_str == "inference":
|
||||
if hasattr(cfg, "model"):
|
||||
routing_key = cfg.model
|
||||
else:
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported model your provider has for inference: ",
|
||||
default="Meta-Llama3.1-8B-Instruct",
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
updated_providers = []
|
||||
for p in existing_providers:
|
||||
print(f"> Configuring provider `({p.provider_type})`")
|
||||
updated_providers.append(
|
||||
configure_single_provider(provider_registry[api], p)
|
||||
)
|
||||
|
||||
if api_str == "safety":
|
||||
# TODO: add support for other safety providers, and simplify safety provider config
|
||||
if p == "meta-reference":
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=[s.value for s in MetaReferenceShieldType],
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
cprint(
|
||||
f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
|
||||
if api_str == "memory":
|
||||
bank_types = list([x.value for x in MemoryBankType])
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported memory bank type your provider has for memory: ",
|
||||
default="vector",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: x in bank_types,
|
||||
error_message="Invalid provider, please enter one of the following: {}".format(
|
||||
bank_types
|
||||
),
|
||||
),
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
|
||||
config.routing_table[api_str] = routing_entries
|
||||
config.api_providers[api_str] = PlaceholderProviderConfig(
|
||||
providers=p if isinstance(p, list) else [p]
|
||||
)
|
||||
print("")
|
||||
else:
|
||||
config.api_providers[api_str] = GenericProviderConfig(
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
# we are newly configuring this API
|
||||
plist = build_spec.providers.get(api_str, [])
|
||||
plist = plist if isinstance(plist, list) else [plist]
|
||||
|
||||
if not plist:
|
||||
raise ValueError(f"No provider configured for API {api_str}?")
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||
updated_providers = []
|
||||
for i, provider_type in enumerate(plist):
|
||||
print(f"> Configuring provider `({provider_type})`")
|
||||
updated_providers.append(
|
||||
configure_single_provider(
|
||||
provider_registry[api],
|
||||
Provider(
|
||||
provider_id=(
|
||||
f"{provider_type}-{i:02d}"
|
||||
if len(plist) > 1
|
||||
else provider_type
|
||||
),
|
||||
provider_type=provider_type,
|
||||
config={},
|
||||
),
|
||||
)
|
||||
)
|
||||
print("")
|
||||
|
||||
config.providers[api_str] = updated_providers
|
||||
|
||||
if is_nux:
|
||||
print(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
=========================================================================================
|
||||
Now let's configure the `objects` you will be serving via the stack. These are:
|
||||
|
||||
- Models: the Llama model SKUs you expect to inference (e.g., Llama3.2-1B-Instruct)
|
||||
- Shields: the safety models you expect to use for safety (e.g., Llama-Guard-3-1B)
|
||||
- Memory Banks: the memory banks you expect to use for memory (e.g., Vector stores)
|
||||
|
||||
This wizard will guide you through setting up one of each of these objects. You can
|
||||
always add more later by editing the run.yaml file.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
object_types = {
|
||||
"models": (ModelDef, configure_models, "inference"),
|
||||
"shields": (ShieldDef, configure_shields, "safety"),
|
||||
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
|
||||
}
|
||||
safety_providers = config.providers.get("safety", [])
|
||||
|
||||
for otype, (odef, config_method, api_str) in object_types.items():
|
||||
existing_objects = getattr(config, otype)
|
||||
|
||||
if existing_objects:
|
||||
cprint(
|
||||
f"{len(existing_objects)} {otype} exist. Skipping...",
|
||||
"blue",
|
||||
attrs=["bold"],
|
||||
)
|
||||
updated_objects = existing_objects
|
||||
else:
|
||||
providers = config.providers.get(api_str, [])
|
||||
if not providers:
|
||||
updated_objects = []
|
||||
else:
|
||||
# we are newly configuring this API
|
||||
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
|
||||
updated_objects = config_method(
|
||||
config.providers[api_str], safety_providers
|
||||
)
|
||||
|
||||
setattr(config, otype, updated_objects)
|
||||
print("")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
|
||||
if not safety_providers:
|
||||
return None
|
||||
|
||||
provider = safety_providers[0]
|
||||
assert provider.provider_type == "meta-reference"
|
||||
|
||||
cfg = provider.config["llama_guard_shield"]
|
||||
if not cfg:
|
||||
return None
|
||||
return cfg["model"]
|
||||
|
||||
|
||||
def configure_models(
|
||||
providers: List[Provider], safety_providers: List[Provider]
|
||||
) -> List[ModelDef]:
|
||||
model = prompt(
|
||||
"> Please enter the model you want to serve: ",
|
||||
default="Llama3.2-1B-Instruct",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: resolve_model(x) is not None,
|
||||
error_message="Model must be: {}".format(
|
||||
[x.descriptor() for x in ALLOWED_MODELS]
|
||||
),
|
||||
),
|
||||
)
|
||||
model = ModelDef(
|
||||
identifier=model,
|
||||
llama_model=model,
|
||||
provider_id=providers[0].provider_id,
|
||||
)
|
||||
|
||||
ret = [model]
|
||||
if llama_guard := get_llama_guard_model(safety_providers):
|
||||
ret.append(
|
||||
ModelDef(
|
||||
identifier=llama_guard,
|
||||
llama_model=llama_guard,
|
||||
provider_id=providers[0].provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def configure_shields(
|
||||
providers: List[Provider], safety_providers: List[Provider]
|
||||
) -> List[ShieldDef]:
|
||||
if get_llama_guard_model(safety_providers):
|
||||
return [
|
||||
ShieldDef(
|
||||
identifier="llama_guard",
|
||||
type="llama_guard",
|
||||
provider_id=providers[0].provider_id,
|
||||
params={},
|
||||
)
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def configure_memory_banks(
|
||||
providers: List[Provider], safety_providers: List[Provider]
|
||||
) -> List[MemoryBankDef]:
|
||||
bank_name = prompt(
|
||||
"> Please enter a name for your memory bank: ",
|
||||
default="my-memory-bank",
|
||||
)
|
||||
|
||||
return [
|
||||
VectorMemoryBankDef(
|
||||
identifier=bank_name,
|
||||
provider_id=providers[0].provider_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def upgrade_from_routing_table_to_registry(
|
||||
config_dict: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
def get_providers(entries):
|
||||
return [
|
||||
Provider(
|
||||
provider_id=(
|
||||
f"{entry['provider_type']}-{i:02d}"
|
||||
if len(entries) > 1
|
||||
else entry["provider_type"]
|
||||
),
|
||||
provider_type=entry["provider_type"],
|
||||
config=entry["config"],
|
||||
)
|
||||
for i, entry in enumerate(entries)
|
||||
]
|
||||
|
||||
providers_by_api = {}
|
||||
models = []
|
||||
shields = []
|
||||
memory_banks = []
|
||||
|
||||
routing_table = config_dict.get("routing_table", {})
|
||||
for api_str, entries in routing_table.items():
|
||||
providers = get_providers(entries)
|
||||
providers_by_api[api_str] = providers
|
||||
|
||||
if api_str == "inference":
|
||||
for entry, provider in zip(entries, providers):
|
||||
key = entry["routing_key"]
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
for key in keys:
|
||||
models.append(
|
||||
ModelDef(
|
||||
identifier=key,
|
||||
provider_id=provider.provider_id,
|
||||
llama_model=key,
|
||||
)
|
||||
)
|
||||
elif api_str == "safety":
|
||||
for entry, provider in zip(entries, providers):
|
||||
key = entry["routing_key"]
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
for key in keys:
|
||||
shields.append(
|
||||
ShieldDef(
|
||||
identifier=key,
|
||||
type=ShieldType.llama_guard.value,
|
||||
provider_id=provider.provider_id,
|
||||
)
|
||||
)
|
||||
elif api_str == "memory":
|
||||
for entry, provider in zip(entries, providers):
|
||||
key = entry["routing_key"]
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
for key in keys:
|
||||
# we currently only support Vector memory banks so this is OK
|
||||
memory_banks.append(
|
||||
VectorMemoryBankDef(
|
||||
identifier=key,
|
||||
provider_id=provider.provider_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
)
|
||||
config_dict["models"] = models
|
||||
config_dict["shields"] = shields
|
||||
config_dict["memory_banks"] = memory_banks
|
||||
|
||||
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
|
||||
if provider_map:
|
||||
for api_str, provider in provider_map.items():
|
||||
if isinstance(provider, dict) and "provider_type" in provider:
|
||||
providers_by_api[api_str] = [
|
||||
Provider(
|
||||
provider_id=f"{provider['provider_type']}",
|
||||
provider_type=provider["provider_type"],
|
||||
config=provider["config"],
|
||||
)
|
||||
]
|
||||
|
||||
config_dict["providers"] = providers_by_api
|
||||
|
||||
config_dict.pop("routing_table", None)
|
||||
config_dict.pop("api_providers", None)
|
||||
config_dict.pop("provider_map", None)
|
||||
|
||||
config_dict["apis"] = config_dict["apis_to_serve"]
|
||||
config_dict.pop("apis_to_serve", None)
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
|
||||
version = config_dict.get("version", None)
|
||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
return StackRunConfig(**config_dict)
|
||||
|
||||
if "models" not in config_dict:
|
||||
print("Upgrading config...")
|
||||
config_dict = upgrade_from_routing_table_to_registry(config_dict)
|
||||
|
||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
config_dict["built_at"] = datetime.now().isoformat()
|
||||
|
||||
return StackRunConfig(**config_dict)
|
||||
|
|
|
|||
|
|
@ -11,28 +11,32 @@ from typing import Dict, List, Optional, Union
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
||||
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
||||
|
||||
|
||||
RoutingKey = Union[str, List[str]]
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
RoutableObject = Union[
|
||||
ModelDef,
|
||||
ShieldDef,
|
||||
MemoryBankDef,
|
||||
]
|
||||
|
||||
|
||||
class RoutableProviderConfig(GenericProviderConfig):
|
||||
routing_key: RoutingKey
|
||||
|
||||
|
||||
class PlaceholderProviderConfig(BaseModel):
|
||||
"""Placeholder provider config for API whose provider are defined in routing_table"""
|
||||
|
||||
providers: List[str]
|
||||
RoutedProtocol = Union[
|
||||
Inference,
|
||||
Safety,
|
||||
Memory,
|
||||
]
|
||||
|
||||
|
||||
# Example: /inference, /safety
|
||||
|
|
@ -53,18 +57,17 @@ class AutoRoutedProviderSpec(ProviderSpec):
|
|||
|
||||
|
||||
# Example: /models, /shields
|
||||
@json_schema_type
|
||||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
provider_type: str = "routing_table"
|
||||
config_class: str = ""
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
inner_specs: List[ProviderSpec]
|
||||
router_api: Api
|
||||
registry: List[RoutableObject]
|
||||
module: str
|
||||
pip_packages: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DistributionSpec(BaseModel):
|
||||
description: Optional[str] = Field(
|
||||
default="",
|
||||
|
|
@ -80,7 +83,12 @@ in the runtime configuration to help route to the correct provider.""",
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Provider(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
built_at: datetime
|
||||
|
|
@ -100,36 +108,39 @@ this could be just a hash
|
|||
default=None,
|
||||
description="Reference to the conda environment if this package refers to a conda environment",
|
||||
)
|
||||
apis_to_serve: List[str] = Field(
|
||||
apis: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="""
|
||||
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
|
||||
)
|
||||
|
||||
api_providers: Dict[
|
||||
str, Union[GenericProviderConfig, PlaceholderProviderConfig]
|
||||
] = Field(
|
||||
providers: Dict[str, List[Provider]] = Field(
|
||||
description="""
|
||||
Provider configurations for each of the APIs provided by this package.
|
||||
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
|
||||
can be instantiated multiple times (with different configs) if necessary.
|
||||
""",
|
||||
)
|
||||
routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
|
||||
E.g. The following is a ProviderRoutingEntry for models:
|
||||
- routing_key: Meta-Llama3.1-8B-Instruct
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
model: Meta-Llama3.1-8B-Instruct
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
""",
|
||||
models: List[ModelDef] = Field(
|
||||
description="""
|
||||
List of model definitions to serve. This list may get extended by
|
||||
/models/register API calls at runtime.
|
||||
""",
|
||||
)
|
||||
shields: List[ShieldDef] = Field(
|
||||
description="""
|
||||
List of shield definitions to serve. This list may get extended by
|
||||
/shields/register API calls at runtime.
|
||||
""",
|
||||
)
|
||||
memory_banks: List[MemoryBankDef] = Field(
|
||||
description="""
|
||||
List of memory bank definitions to serve. This list may get extended by
|
||||
/memory_banks/register API calls at runtime.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BuildConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||
name: str
|
||||
|
|
|
|||
|
|
@ -6,45 +6,58 @@
|
|||
|
||||
from typing import Dict, List
|
||||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def is_passthrough(spec: ProviderSpec) -> bool:
|
||||
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
impl = DistributionInspectImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class DistributionInspectImpl(Inspect):
|
||||
def __init__(self):
|
||||
def __init__(self, config, deps):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = {}
|
||||
all_providers = get_provider_registry()
|
||||
for api, providers in all_providers.items():
|
||||
ret[api.value] = [
|
||||
for api, providers in run_config.providers.items():
|
||||
ret[api] = [
|
||||
ProviderInfo(
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
description="Passthrough" if is_passthrough(p) else "",
|
||||
)
|
||||
for p in providers.values()
|
||||
for p in providers
|
||||
]
|
||||
|
||||
return ret
|
||||
|
||||
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = {}
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
for api, endpoints in all_endpoints.items():
|
||||
providers = run_config.providers.get(api.value, [])
|
||||
ret[api.value] = [
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
providers=[],
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
]
|
||||
|
|
|
|||
|
|
@ -13,138 +13,207 @@ from llama_stack.distribution.distribution import (
|
|||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.distribution.inspect import DistributionInspectImpl
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
|
||||
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
|
||||
class ProviderWithSpec(Provider):
|
||||
spec: ProviderSpec
|
||||
|
||||
|
||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
- for each API, produces either a (local, passthrough or router) implementation
|
||||
"""
|
||||
all_providers = get_provider_registry()
|
||||
specs = {}
|
||||
configs = {}
|
||||
all_api_providers = get_provider_registry()
|
||||
|
||||
for api_str, config in run_config.api_providers.items():
|
||||
api = Api(api_str)
|
||||
|
||||
# TODO: check that these APIs are not in the routing table part of the config
|
||||
providers = all_providers[api]
|
||||
|
||||
# skip checks for API whose provider config is specified in routing_table
|
||||
if isinstance(config, PlaceholderProviderConfig):
|
||||
continue
|
||||
|
||||
if config.provider_type not in providers:
|
||||
raise ValueError(
|
||||
f"Provider `{config.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[config.provider_type]
|
||||
configs[api] = config
|
||||
|
||||
apis_to_serve = run_config.apis_to_serve or set(
|
||||
list(specs.keys()) + list(run_config.routing_table.keys())
|
||||
routing_table_apis = set(
|
||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||
)
|
||||
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
|
||||
|
||||
providers_with_specs = {}
|
||||
|
||||
for api_str, providers in run_config.providers.items():
|
||||
api = Api(api_str)
|
||||
if api in routing_table_apis:
|
||||
raise ValueError(
|
||||
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
|
||||
)
|
||||
|
||||
specs = {}
|
||||
for provider in providers:
|
||||
if provider.provider_type not in all_api_providers[api]:
|
||||
raise ValueError(
|
||||
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
|
||||
p = all_api_providers[api][provider.provider_type]
|
||||
p.deps__ = [a.value for a in p.api_dependencies]
|
||||
spec = ProviderWithSpec(
|
||||
spec=p,
|
||||
**(provider.dict()),
|
||||
)
|
||||
specs[provider.provider_id] = spec
|
||||
|
||||
key = api_str if api not in router_apis else f"inner-{api_str}"
|
||||
providers_with_specs[key] = specs
|
||||
|
||||
apis_to_serve = run_config.apis or set(
|
||||
list(providers_with_specs.keys())
|
||||
+ [x.value for x in routing_table_apis]
|
||||
+ [x.value for x in router_apis]
|
||||
)
|
||||
|
||||
for info in builtin_automatically_routed_apis():
|
||||
source_api = info.routing_table_api
|
||||
|
||||
assert (
|
||||
source_api not in specs
|
||||
), f"Routing table API {source_api} specified in wrong place?"
|
||||
assert (
|
||||
info.router_api not in specs
|
||||
), f"Auto-routed API {info.router_api} specified in wrong place?"
|
||||
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
|
||||
if info.router_api.value not in run_config.routing_table:
|
||||
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
|
||||
available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
|
||||
|
||||
routing_table = run_config.routing_table[info.router_api.value]
|
||||
|
||||
providers = all_providers[info.router_api]
|
||||
|
||||
inner_specs = []
|
||||
inner_deps = []
|
||||
for rt_entry in routing_table:
|
||||
if rt_entry.provider_type not in providers:
|
||||
registry = getattr(run_config, info.routing_table_api.value)
|
||||
for entry in registry:
|
||||
if entry.provider_id not in available_providers:
|
||||
raise ValueError(
|
||||
f"Provider `{rt_entry.provider_type}` is not available for API `{api}`"
|
||||
f"Provider `{entry.provider_id}` not found. Available providers: {list(available_providers.keys())}"
|
||||
)
|
||||
inner_specs.append(providers[rt_entry.provider_type])
|
||||
inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)
|
||||
|
||||
specs[source_api] = RoutingTableProviderSpec(
|
||||
api=source_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=inner_deps,
|
||||
inner_specs=inner_specs,
|
||||
provider = available_providers[entry.provider_id]
|
||||
inner_deps.extend(provider.spec.api_dependencies)
|
||||
|
||||
providers_with_specs[info.routing_table_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__routing_table__",
|
||||
config={},
|
||||
spec=RoutingTableProviderSpec(
|
||||
api=info.routing_table_api,
|
||||
router_api=info.router_api,
|
||||
registry=registry,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=inner_deps,
|
||||
deps__=(
|
||||
[x.value for x in inner_deps]
|
||||
+ [f"inner-{info.router_api.value}"]
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
providers_with_specs[info.router_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__autorouted__",
|
||||
config={},
|
||||
spec=AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=info.routing_table_api,
|
||||
api_dependencies=[info.routing_table_api],
|
||||
deps__=([info.routing_table_api.value]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
sorted_providers = topological_sort(
|
||||
{k: v.values() for k, v in providers_with_specs.items()}
|
||||
)
|
||||
apis = [x[1].spec.api for x in sorted_providers]
|
||||
sorted_providers.append(
|
||||
(
|
||||
"inspect",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={
|
||||
"run_config": run_config.dict(),
|
||||
},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||
module="llama_stack.distribution.inspect",
|
||||
api_dependencies=apis,
|
||||
deps__=([x.value for x in apis]),
|
||||
),
|
||||
),
|
||||
)
|
||||
configs[source_api] = routing_table
|
||||
|
||||
specs[info.router_api] = AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=source_api,
|
||||
api_dependencies=[source_api],
|
||||
)
|
||||
configs[info.router_api] = {}
|
||||
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
print(f"Resolved {len(sorted_specs)} providers in topological order")
|
||||
for spec in sorted_specs:
|
||||
print(f" {spec.api}: {spec.provider_type}")
|
||||
print("")
|
||||
impls = {}
|
||||
for spec in sorted_specs:
|
||||
api = spec.api
|
||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||
impl = await instantiate_provider(spec, deps, configs[api])
|
||||
|
||||
impls[api] = impl
|
||||
|
||||
impls[Api.inspect] = DistributionInspectImpl()
|
||||
specs[Api.inspect] = InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__distribution_builtin__",
|
||||
config_class="",
|
||||
module="",
|
||||
)
|
||||
|
||||
return impls, specs
|
||||
print(f"Resolved {len(sorted_providers)} providers in topological order")
|
||||
for api_str, provider in sorted_providers:
|
||||
print(f" {api_str}: ({provider.provider_id}) {provider.spec.provider_type}")
|
||||
print("")
|
||||
|
||||
impls = {}
|
||||
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
|
||||
for api_str, provider in sorted_providers:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
|
||||
inner_impls = {}
|
||||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[
|
||||
f"inner-{provider.spec.router_api.value}"
|
||||
]
|
||||
|
||||
impl = await instantiate_provider(
|
||||
provider,
|
||||
deps,
|
||||
inner_impls,
|
||||
)
|
||||
# TODO: ugh slightly redesign this shady looking code
|
||||
if "inner-" in api_str:
|
||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||
else:
|
||||
api = Api(api_str)
|
||||
impls[api] = impl
|
||||
|
||||
return impls
|
||||
|
||||
|
||||
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||
by_id = {x.api: x for x in providers}
|
||||
def topological_sort(
|
||||
providers_with_specs: Dict[str, List[ProviderWithSpec]],
|
||||
) -> List[ProviderWithSpec]:
|
||||
def dfs(kv, visited: Set[str], stack: List[str]):
|
||||
api_str, providers = kv
|
||||
visited.add(api_str)
|
||||
|
||||
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
|
||||
visited.add(a.api)
|
||||
deps = []
|
||||
for provider in providers:
|
||||
for dep in provider.spec.deps__:
|
||||
deps.append(dep)
|
||||
|
||||
for api in a.api_dependencies:
|
||||
if api not in visited:
|
||||
dfs(by_id[api], visited, stack)
|
||||
for dep in deps:
|
||||
if dep not in visited:
|
||||
dfs((dep, providers_with_specs[dep]), visited, stack)
|
||||
|
||||
stack.append(a.api)
|
||||
stack.append(api_str)
|
||||
|
||||
visited = set()
|
||||
stack = []
|
||||
|
||||
for a in providers:
|
||||
if a.api not in visited:
|
||||
dfs(a, visited, stack)
|
||||
for api_str, providers in providers_with_specs.items():
|
||||
if api_str not in visited:
|
||||
dfs((api_str, providers), visited, stack)
|
||||
|
||||
return [by_id[x] for x in stack]
|
||||
flattened = []
|
||||
for api_str in stack:
|
||||
for provider in providers_with_specs[api_str]:
|
||||
flattened.append((api_str, provider))
|
||||
return flattened
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider_spec: ProviderSpec,
|
||||
provider: ProviderWithSpec,
|
||||
deps: Dict[str, Any],
|
||||
provider_config: Union[GenericProviderConfig, RoutingTable],
|
||||
inner_impls: Dict[str, Any],
|
||||
):
|
||||
provider_spec = provider.spec
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
args = []
|
||||
|
|
@ -154,9 +223,8 @@ async def instantiate_provider(
|
|||
else:
|
||||
method = "get_client_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
config = config_type(**provider.config)
|
||||
args = [config, deps]
|
||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
method = "get_auto_router_impl"
|
||||
|
|
@ -166,31 +234,18 @@ async def instantiate_provider(
|
|||
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
||||
method = "get_routing_table_impl"
|
||||
|
||||
assert isinstance(provider_config, List)
|
||||
routing_table = provider_config
|
||||
|
||||
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
|
||||
inner_impls = []
|
||||
for routing_entry in routing_table:
|
||||
impl = await instantiate_provider(
|
||||
inner_specs[routing_entry.provider_type],
|
||||
deps,
|
||||
routing_entry,
|
||||
)
|
||||
inner_impls.append((routing_entry.routing_key, impl))
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, inner_impls, routing_table, deps]
|
||||
args = [provider_spec.api, provider_spec.registry, inner_impls, deps]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
config = config_type(**provider.config)
|
||||
args = [config, deps]
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
impl.__provider_id__ = provider.provider_id
|
||||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,23 +4,22 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any, List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from .routing_tables import (
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
)
|
||||
|
||||
|
||||
async def get_routing_table_impl(
|
||||
api: Api,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||
registry: List[RoutableObject],
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
_deps,
|
||||
) -> Any:
|
||||
from .routing_tables import (
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
)
|
||||
|
||||
api_to_tables = {
|
||||
"memory_banks": MemoryBanksRoutingTable,
|
||||
"models": ModelsRoutingTable,
|
||||
|
|
@ -29,7 +28,7 @@ async def get_routing_table_impl(
|
|||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_tables[api.value](inner_impls, routing_table_config)
|
||||
impl = api_to_tables[api.value](registry, impls_by_provider_id)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
|
|
|||
|
|
@ -14,14 +14,13 @@ from llama_stack.apis.safety import * # noqa: F403
|
|||
|
||||
|
||||
class MemoryRouter(Memory):
|
||||
"""Routes to an provider based on the memory bank type"""
|
||||
"""Routes to an provider based on the memory bank identifier"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
self.bank_id_to_type = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
|
@ -29,32 +28,8 @@ class MemoryRouter(Memory):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def get_provider_from_bank_id(self, bank_id: str) -> Any:
|
||||
bank_type = self.bank_id_to_type.get(bank_id)
|
||||
if not bank_type:
|
||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||
|
||||
provider = self.routing_table.get_provider_impl(bank_type)
|
||||
if not provider:
|
||||
raise ValueError(f"Could not find provider for {bank_type}")
|
||||
return provider
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_type = config.type
|
||||
bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
|
||||
name, config, url
|
||||
)
|
||||
self.bank_id_to_type[bank.bank_id] = bank_type
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
provider = self.get_provider_from_bank_id(bank_id)
|
||||
return await provider.get_memory_bank(bank_id)
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
||||
await self.routing_table.register_memory_bank(memory_bank)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
|
@ -62,7 +37,7 @@ class MemoryRouter(Memory):
|
|||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
return await self.get_provider_from_bank_id(bank_id).insert_documents(
|
||||
return await self.routing_table.get_provider_impl(bank_id).insert_documents(
|
||||
bank_id, documents, ttl_seconds
|
||||
)
|
||||
|
||||
|
|
@ -72,7 +47,7 @@ class MemoryRouter(Memory):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
return await self.get_provider_from_bank_id(bank_id).query_documents(
|
||||
return await self.routing_table.get_provider_impl(bank_id).query_documents(
|
||||
bank_id, query, params
|
||||
)
|
||||
|
||||
|
|
@ -92,7 +67,10 @@ class InferenceRouter(Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def chat_completion(
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
await self.routing_table.register_model(model)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
@ -113,27 +91,32 @@ class InferenceRouter(Inference):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
# TODO: we need to fix streaming response to align provider implementations with Protocol.
|
||||
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
|
||||
**params
|
||||
):
|
||||
yield chunk
|
||||
provider = self.routing_table.get_provider_impl(model)
|
||||
if stream:
|
||||
return (chunk async for chunk in provider.chat_completion(**params))
|
||||
else:
|
||||
return provider.chat_completion(**params)
|
||||
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
return await self.routing_table.get_provider_impl(model).completion(
|
||||
) -> AsyncGenerator:
|
||||
provider = self.routing_table.get_provider_impl(model)
|
||||
params = dict(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return (chunk async for chunk in provider.completion(**params))
|
||||
else:
|
||||
return provider.completion(**params)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
@ -159,6 +142,9 @@ class SafetyRouter(Safety):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
await self.routing_table.register_shield(shield)
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
|
|
|
|||
|
|
@ -4,9 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
|
@ -16,129 +15,129 @@ from llama_stack.apis.memory_banks import * # noqa: F403
|
|||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def get_impl_api(p: Any) -> Api:
|
||||
return p.__provider_spec__.api
|
||||
|
||||
|
||||
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
||||
api = get_impl_api(p)
|
||||
if api == Api.inference:
|
||||
await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
await p.register_memory_bank(obj)
|
||||
|
||||
|
||||
# TODO: this routing table maintains state in memory purely. We need to
|
||||
# add persistence to it when we add dynamic registration of objects.
|
||||
class CommonRoutingTableImpl(RoutingTable):
|
||||
def __init__(
|
||||
self,
|
||||
inner_impls: List[Tuple[RoutingKey, Any]],
|
||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||
registry: List[RoutableObject],
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
) -> None:
|
||||
self.unique_providers = []
|
||||
self.providers = {}
|
||||
self.routing_keys = []
|
||||
for obj in registry:
|
||||
if obj.provider_id not in impls_by_provider_id:
|
||||
print(f"{impls_by_provider_id=}")
|
||||
raise ValueError(
|
||||
f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
|
||||
)
|
||||
|
||||
for key, impl in inner_impls:
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
self.unique_providers.append((keys, impl))
|
||||
self.impls_by_provider_id = impls_by_provider_id
|
||||
self.registry = registry
|
||||
|
||||
for k in keys:
|
||||
if k in self.providers:
|
||||
raise ValueError(f"Duplicate routing key {k}")
|
||||
self.providers[k] = impl
|
||||
self.routing_keys.append(k)
|
||||
for p in self.impls_by_provider_id.values():
|
||||
api = get_impl_api(p)
|
||||
if api == Api.inference:
|
||||
p.model_store = self
|
||||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
elif api == Api.memory:
|
||||
p.memory_bank_store = self
|
||||
|
||||
self.routing_table_config = routing_table_config
|
||||
self.routing_key_to_object = {}
|
||||
for obj in self.registry:
|
||||
self.routing_key_to_object[obj.identifier] = obj
|
||||
|
||||
async def initialize(self) -> None:
|
||||
for keys, p in self.unique_providers:
|
||||
spec = p.__provider_spec__
|
||||
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
|
||||
continue
|
||||
|
||||
await p.validate_routing_keys(keys)
|
||||
for obj in self.registry:
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
await register_object_with_provider(obj, p)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for _, p in self.unique_providers:
|
||||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str) -> Any:
|
||||
if routing_key not in self.providers:
|
||||
raise ValueError(f"Could not find provider for {routing_key}")
|
||||
return self.providers[routing_key]
|
||||
if routing_key not in self.routing_key_to_object:
|
||||
raise ValueError(f"`{routing_key}` not registered")
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
obj = self.routing_key_to_object[routing_key]
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == routing_key:
|
||||
return entry
|
||||
return self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
|
||||
for obj in self.registry:
|
||||
if obj.identifier == identifier:
|
||||
return obj
|
||||
return None
|
||||
|
||||
async def register_object(self, obj: RoutableObject):
|
||||
if obj.identifier in self.routing_key_to_object:
|
||||
print(f"`{obj.identifier}` is already registered")
|
||||
return
|
||||
|
||||
if not obj.provider_id:
|
||||
provider_ids = list(self.impls_by_provider_id.keys())
|
||||
if not provider_ids:
|
||||
raise ValueError("No providers found")
|
||||
|
||||
print(f"Picking provider `{provider_ids[0]}` for {obj.identifier}")
|
||||
obj.provider_id = provider_ids[0]
|
||||
else:
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
await register_object_with_provider(obj, p)
|
||||
|
||||
self.routing_key_to_object[obj.identifier] = obj
|
||||
self.registry.append(obj)
|
||||
|
||||
# TODO: persist this to a store
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
return self.registry
|
||||
|
||||
async def list_models(self) -> List[ModelServingSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
model_id = entry.routing_key
|
||||
specs.append(
|
||||
ModelServingSpec(
|
||||
llama_model=resolve_model(model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
return self.get_object_by_identifier(identifier)
|
||||
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == core_model_id:
|
||||
return ModelServingSpec(
|
||||
llama_model=resolve_model(core_model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
await self.register_object(model)
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return self.registry
|
||||
|
||||
async def list_shields(self) -> List[ShieldSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
if isinstance(entry.routing_key, list):
|
||||
for k in entry.routing_key:
|
||||
specs.append(
|
||||
ShieldSpec(
|
||||
shield_type=k,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
else:
|
||||
specs.append(
|
||||
ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
|
||||
return self.get_object_by_identifier(shield_type)
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == shield_type:
|
||||
return ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
await self.register_object(shield)
|
||||
|
||||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
return self.registry
|
||||
|
||||
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
specs.append(
|
||||
MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
|
||||
return self.get_object_by_identifier(identifier)
|
||||
|
||||
async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == bank_type:
|
||||
return MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
|
||||
await self.register_object(bank)
|
||||
|
|
|
|||
|
|
@ -5,18 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import signal
|
||||
import traceback
|
||||
|
||||
from collections.abc import (
|
||||
AsyncGenerator as AsyncGeneratorABC,
|
||||
AsyncIterator as AsyncIteratorABC,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from ssl import SSLError
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
|
@ -43,20 +40,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
|
|||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
|
||||
def is_async_iterator_type(typ):
|
||||
if hasattr(typ, "__origin__"):
|
||||
origin = typ.__origin__
|
||||
if isinstance(origin, type):
|
||||
return issubclass(
|
||||
origin,
|
||||
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
|
||||
)
|
||||
return False
|
||||
return isinstance(
|
||||
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
|
||||
)
|
||||
|
||||
|
||||
def create_sse_event(data: Any) -> str:
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.json()
|
||||
|
|
@ -169,11 +152,20 @@ async def passthrough(
|
|||
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
|
||||
|
||||
|
||||
def handle_sigint(*args, **kwargs):
|
||||
def handle_sigint(app, *args, **kwargs):
|
||||
print("SIGINT or CTRL-C detected. Exiting gracefully...")
|
||||
|
||||
async def run_shutdown():
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
print(f"Shutting down {impl}")
|
||||
await impl.shutdown()
|
||||
|
||||
asyncio.run(run_shutdown())
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for task in asyncio.all_tasks(loop):
|
||||
task.cancel()
|
||||
|
||||
loop.stop()
|
||||
|
||||
|
||||
|
|
@ -181,7 +173,10 @@ def handle_sigint(*args, **kwargs):
|
|||
async def lifespan(app: FastAPI):
|
||||
print("Starting up")
|
||||
yield
|
||||
|
||||
print("Shutting down")
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
await impl.shutdown()
|
||||
|
||||
|
||||
def create_dynamic_passthrough(
|
||||
|
|
@ -193,65 +188,59 @@ def create_dynamic_passthrough(
|
|||
return endpoint
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||
return kwargs.get("stream", False)
|
||||
|
||||
|
||||
async def maybe_await(value):
|
||||
if inspect.iscoroutine(value):
|
||||
return await value
|
||||
return value
|
||||
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
"message": str(translate_exception(e)),
|
||||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
hints = get_type_hints(func)
|
||||
response_model = hints.get("return")
|
||||
|
||||
# NOTE: I think it is better to just add a method within each Api
|
||||
# "Protocol" / adapter-impl to tell what sort of a response this request
|
||||
# is going to produce. /chat_completion can produce a streaming or
|
||||
# non-streaming response depending on if request.stream is True / False.
|
||||
is_streaming = is_async_iterator_type(response_model)
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
if is_streaming:
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
"message": str(translate_exception(e)),
|
||||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
try:
|
||||
return (
|
||||
await func(**kwargs)
|
||||
if asyncio.iscoroutinefunction(func)
|
||||
else func(**kwargs)
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
try:
|
||||
if is_streaming:
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
sig = inspect.signature(func)
|
||||
new_params = [
|
||||
|
|
@ -285,29 +274,25 @@ def main(
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||
impls = asyncio.run(resolve_impls_with_routing(config))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
if config.apis_to_serve:
|
||||
apis_to_serve = set(config.apis_to_serve)
|
||||
if config.apis:
|
||||
apis_to_serve = set(config.apis)
|
||||
else:
|
||||
apis_to_serve = set(impls.keys())
|
||||
|
||||
apis_to_serve.add(Api.inspect)
|
||||
apis_to_serve.add("inspect")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
provider_spec = specs[api]
|
||||
if (
|
||||
isinstance(provider_spec, RemoteProviderSpec)
|
||||
and provider_spec.adapter is None
|
||||
):
|
||||
if is_passthrough(impl.__provider_spec__):
|
||||
for endpoint in endpoints:
|
||||
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
|
|
@ -337,7 +322,9 @@ def main(
|
|||
print("")
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
built_at: '2024-09-30T09:04:30.533391'
|
||||
version: '2'
|
||||
built_at: '2024-10-08T17:42:07.505267'
|
||||
image_name: local-cpu
|
||||
docker_image: local-cpu
|
||||
conda_env: null
|
||||
apis_to_serve:
|
||||
apis:
|
||||
- agents
|
||||
- inference
|
||||
- models
|
||||
|
|
@ -10,40 +11,48 @@ apis_to_serve:
|
|||
- safety
|
||||
- shields
|
||||
- memory_banks
|
||||
api_providers:
|
||||
providers:
|
||||
inference:
|
||||
providers:
|
||||
- remote::ollama
|
||||
- provider_id: remote::ollama
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 6000
|
||||
safety:
|
||||
providers:
|
||||
- meta-reference
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield: null
|
||||
prompt_guard_shield: null
|
||||
memory:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: /home/xiyan/.llama/runtime/kvstore.db
|
||||
memory:
|
||||
providers:
|
||||
- meta-reference
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
routing_table:
|
||||
inference:
|
||||
- provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 6000
|
||||
routing_key: Meta-Llama3.1-8B-Instruct
|
||||
safety:
|
||||
- provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield: null
|
||||
prompt_guard_shield: null
|
||||
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
||||
memory:
|
||||
- provider_type: meta-reference
|
||||
config: {}
|
||||
routing_key: vector
|
||||
models:
|
||||
- identifier: Llama3.1-8B-Instruct
|
||||
llama_model: Llama3.1-8B-Instruct
|
||||
provider_id: remote::ollama
|
||||
shields:
|
||||
- identifier: llama_guard
|
||||
type: llama_guard
|
||||
provider_id: meta-reference
|
||||
params: {}
|
||||
memory_banks:
|
||||
- identifier: vector
|
||||
provider_id: meta-reference
|
||||
type: vector
|
||||
embedding_model: all-MiniLM-L6-v2
|
||||
chunk_size_in_tokens: 512
|
||||
overlap_size_in_tokens: null
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
built_at: '2024-09-30T09:00:56.693751'
|
||||
version: '2'
|
||||
built_at: '2024-10-08T17:42:33.690666'
|
||||
image_name: local-gpu
|
||||
docker_image: local-gpu
|
||||
conda_env: null
|
||||
apis_to_serve:
|
||||
apis:
|
||||
- memory
|
||||
- inference
|
||||
- agents
|
||||
|
|
@ -10,43 +11,51 @@ apis_to_serve:
|
|||
- safety
|
||||
- models
|
||||
- memory_banks
|
||||
api_providers:
|
||||
providers:
|
||||
inference:
|
||||
providers:
|
||||
- meta-reference
|
||||
safety:
|
||||
providers:
|
||||
- meta-reference
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: /home/xiyan/.llama/runtime/kvstore.db
|
||||
memory:
|
||||
providers:
|
||||
- meta-reference
|
||||
telemetry:
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
routing_table:
|
||||
inference:
|
||||
- provider_type: meta-reference
|
||||
config:
|
||||
model: Llama3.1-8B-Instruct
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
routing_key: Llama3.1-8B-Instruct
|
||||
safety:
|
||||
- provider_type: meta-reference
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield: null
|
||||
prompt_guard_shield: null
|
||||
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
||||
memory:
|
||||
- provider_type: meta-reference
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
routing_key: vector
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
models:
|
||||
- identifier: Llama3.1-8B-Instruct
|
||||
llama_model: Llama3.1-8B-Instruct
|
||||
provider_id: meta-reference
|
||||
shields:
|
||||
- identifier: llama_guard
|
||||
type: llama_guard
|
||||
provider_id: meta-reference
|
||||
params: {}
|
||||
memory_banks:
|
||||
- identifier: vector
|
||||
provider_id: meta-reference
|
||||
type: vector
|
||||
embedding_model: all-MiniLM-L6-v2
|
||||
chunk_size_in_tokens: 512
|
||||
overlap_size_in_tokens: null
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
name: local-databricks
|
||||
distribution_spec:
|
||||
description: Use Databricks for running LLM inference
|
||||
providers:
|
||||
inference: remote::databricks
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
10
llama_stack/distribution/templates/local-vllm-build.yaml
Normal file
10
llama_stack/distribution/templates/local-vllm-build.yaml
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
name: local-vllm
|
||||
distribution_spec:
|
||||
description: Like local, but use vLLM for running LLM inference
|
||||
providers:
|
||||
inference: vllm
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
Loading…
Add table
Add a link
Reference in a new issue