Remove "routing_table" and "routing_key" concepts for the user (#201)

This PR makes several core changes to the developer experience surrounding Llama Stack.

Background: PR #92 introduced the notion of "routing" to the Llama Stack. It introduces three object types: (1) models, (2) shields and (3) memory banks. Each of these objects can be associated with a distinct provider. So you can get model A to be inferenced locally while model B, C can be inference remotely (e.g.)

However, this had a few drawbacks:

you could not address the provider instances -- i.e., if you configured "meta-reference" with a given model, you could not assign an identifier to this instance which you could re-use later.
the above meant that you could not register a "routing_key" (e.g. model) dynamically and say "please use this existing provider I have already configured" for a new model.
the terms "routing_table" and "routing_key" were exposed directly to the user. in my view, this is way too much overhead for a new user (which almost everyone is.) people come to the stack wanting to do ML and encounter a completely unexpected term.
What this PR does: This PR structures the run config with only a single prominent key:

- providers
Providers are instances of configured provider types. Here's an example which shows two instances of the remote::tgi provider which are serving two different models.

providers:
  inference:
  - provider_id: foo
    provider_type: remote::tgi
    config: { ... }
  - provider_id: bar
    provider_type: remote::tgi
    config: { ... }
Secondly, the PR adds dynamic registration of { models | shields | memory_banks } to the API surface. The distribution still acts like a "routing table" (as previously) except that it asks the backing providers for a listing of these objects. For example it asks a TGI or Ollama inference adapter what models it is serving. Only the models that are being actually served can be requested by the user for inference. Otherwise, the Stack server will throw an error.

When dynamically registering these objects, you can use the provider IDs shown above. Info about providers can be obtained using the Api.inspect set of endpoints (/providers, /routes, etc.)

The above examples shows the correspondence between inference providers and models registry items. Things work similarly for the safety <=> shields and memory <=> memory_banks pairs.

Registry: This PR also makes it so that Providers need to implement additional methods for registering and listing objects. For example, each Inference provider is now expected to implement the ModelsProtocolPrivate protocol (naming is not great!) which consists of two methods

register_model
list_models
The goal is to inform the provider that a certain model needs to be supported so the provider can make any relevant backend changes if needed (or throw an error if the model cannot be supported.)

There are many other cleanups included some of which are detailed in a follow-up comment.
This commit is contained in:
Ashwin Bharambe 2024-10-10 10:24:13 -07:00 committed by GitHub
parent 8c3010553f
commit 6bb57e72a7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
93 changed files with 4697 additions and 4457 deletions

View file

@ -3,189 +3,182 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import textwrap
from typing import Any
from llama_models.sku_list import (
llama3_1_family,
llama3_2_family,
llama3_family,
resolve_model,
safety_models,
)
from pydantic import BaseModel
from llama_stack.distribution.datatypes import * # noqa: F403
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
from termcolor import cprint
from llama_stack.apis.memory.memory import MemoryBankType
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
stack_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.impls.meta_reference.safety.config import (
MetaReferenceShieldType,
)
ALLOWED_MODELS = (
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
)
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
def make_routing_entry_type(config_class: Any):
class BaseModelWithConfig(BaseModel):
routing_key: str
config: config_class
def configure_single_provider(
registry: Dict[str, ProviderSpec], provider: Provider
) -> Provider:
provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class)
try:
if provider.config:
existing = config_type(**provider.config)
else:
existing = None
except Exception:
existing = None
return BaseModelWithConfig
cfg = prompt_for_config(config_type, existing)
return Provider(
provider_id=provider.provider_id,
provider_type=provider.provider_type,
config=cfg.dict(),
)
def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
"""Get corresponding builtin APIs given provider backed APIs"""
res = []
for inf in builtin_automatically_routed_apis():
if inf.router_api.value in provider_backed_apis:
res.append(inf.routing_table_api.value)
return res
# TODO: make sure we can deal with existing configuration values correctly
# instead of just overwriting them
def configure_api_providers(
config: StackRunConfig, spec: DistributionSpec
config: StackRunConfig, build_spec: DistributionSpec
) -> StackRunConfig:
apis = config.apis_to_serve or list(spec.providers.keys())
# append the bulitin routing APIs
apis += get_builtin_apis(apis)
is_nux = len(config.providers) == 0
router_api2builtin_api = {
inf.router_api.value: inf.routing_table_api.value
for inf in builtin_automatically_routed_apis()
}
if is_nux:
print(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. For each API served by the Stack,
we need to configure the providers (implementations) you want to use for these APIs.
"""
)
)
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
provider_registry = get_provider_registry()
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
apis = [v.value for v in stack_apis()]
all_providers = get_provider_registry()
if config.apis:
apis_to_serve = config.apis
else:
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
# configure simple case for with non-routing providers to api_providers
for api_str in spec.providers.keys():
if api_str not in apis:
for api_str in apis_to_serve:
api = Api(api_str)
if api in builtin_apis:
continue
if api not in provider_registry:
raise ValueError(f"Unknown API `{api_str}`")
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
api = Api(api_str)
p = spec.providers[api_str]
cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
if isinstance(p, list):
existing_providers = config.providers.get(api_str, [])
if existing_providers:
cprint(
f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
"yellow",
f"Re-configuring existing providers for API `{api_str}`...",
"green",
attrs=["bold"],
)
p = p[0]
provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class)
try:
provider_config = config.api_providers.get(api_str)
if provider_config:
existing = config_type(**provider_config.config)
else:
existing = None
except Exception:
existing = None
cfg = prompt_for_config(config_type, existing)
if api_str in router_api2builtin_api:
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
routing_entries = []
if api_str == "inference":
if hasattr(cfg, "model"):
routing_key = cfg.model
else:
routing_key = prompt(
"> Please enter the supported model your provider has for inference: ",
default="Llama3.1-8B-Instruct",
validator=Validator.from_callable(
lambda x: resolve_model(x) is not None,
error_message="Model must be: {}".format(
[x.descriptor() for x in ALLOWED_MODELS]
),
),
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
)
updated_providers = []
for p in existing_providers:
print(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(
configure_single_provider(provider_registry[api], p)
)
if api_str == "safety":
# TODO: add support for other safety providers, and simplify safety provider config
if p == "meta-reference":
routing_entries.append(
RoutableProviderConfig(
routing_key=[s.value for s in MetaReferenceShieldType],
provider_type=p,
config=cfg.dict(),
)
)
else:
cprint(
f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
"yellow",
attrs=["bold"],
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
)
)
if api_str == "memory":
bank_types = list([x.value for x in MemoryBankType])
routing_key = prompt(
"> Please enter the supported memory bank type your provider has for memory: ",
default="vector",
validator=Validator.from_callable(
lambda x: x in bank_types,
error_message="Invalid provider, please enter one of the following: {}".format(
bank_types
),
),
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
)
)
config.routing_table[api_str] = routing_entries
config.api_providers[api_str] = PlaceholderProviderConfig(
providers=p if isinstance(p, list) else [p]
)
print("")
else:
config.api_providers[api_str] = GenericProviderConfig(
provider_type=p,
config=cfg.dict(),
)
# we are newly configuring this API
plist = build_spec.providers.get(api_str, [])
plist = plist if isinstance(plist, list) else [plist]
print("")
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
updated_providers = []
for i, provider_type in enumerate(plist):
print(f"> Configuring provider `({provider_type})`")
updated_providers.append(
configure_single_provider(
provider_registry[api],
Provider(
provider_id=(
f"{provider_type}-{i:02d}"
if len(plist) > 1
else provider_type
),
provider_type=provider_type,
config={},
),
)
)
print("")
config.providers[api_str] = updated_providers
return config
def upgrade_from_routing_table(
config_dict: Dict[str, Any],
) -> Dict[str, Any]:
def get_providers(entries):
return [
Provider(
provider_id=(
f"{entry['provider_type']}-{i:02d}"
if len(entries) > 1
else entry["provider_type"]
),
provider_type=entry["provider_type"],
config=entry["config"],
)
for i, entry in enumerate(entries)
]
providers_by_api = {}
routing_table = config_dict.get("routing_table", {})
for api_str, entries in routing_table.items():
providers = get_providers(entries)
providers_by_api[api_str] = providers
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
if provider_map:
for api_str, provider in provider_map.items():
if isinstance(provider, dict) and "provider_type" in provider:
providers_by_api[api_str] = [
Provider(
provider_id=f"{provider['provider_type']}",
provider_type=provider["provider_type"],
config=provider["config"],
)
]
config_dict["providers"] = providers_by_api
config_dict.pop("routing_table", None)
config_dict.pop("api_providers", None)
config_dict.pop("provider_map", None)
config_dict["apis"] = config_dict["apis_to_serve"]
config_dict.pop("apis_to_serve", None)
return config_dict
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**config_dict)
if "routing_table" in config_dict:
print("Upgrading config...")
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
config_dict["built_at"] = datetime.now().isoformat()
return StackRunConfig(**config_dict)

View file

@ -11,28 +11,38 @@ from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = Union[str, List[str]]
class GenericProviderConfig(BaseModel):
provider_type: str
config: Dict[str, Any]
RoutableObject = Union[
ModelDef,
ShieldDef,
MemoryBankDef,
]
RoutableObjectWithProvider = Union[
ModelDefWithProvider,
ShieldDefWithProvider,
MemoryBankDefWithProvider,
]
class RoutableProviderConfig(GenericProviderConfig):
routing_key: RoutingKey
class PlaceholderProviderConfig(BaseModel):
"""Placeholder provider config for API whose provider are defined in routing_table"""
providers: List[str]
RoutedProtocol = Union[
Inference,
Safety,
Memory,
]
# Example: /inference, /safety
@ -53,18 +63,16 @@ class AutoRoutedProviderSpec(ProviderSpec):
# Example: /models, /shields
@json_schema_type
class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
router_api: Api
module: str
pip_packages: List[str] = Field(default_factory=list)
@json_schema_type
class DistributionSpec(BaseModel):
description: Optional[str] = Field(
default="",
@ -80,7 +88,12 @@ in the runtime configuration to help route to the correct provider.""",
)
@json_schema_type
class Provider(BaseModel):
provider_id: str
provider_type: str
config: Dict[str, Any]
class StackRunConfig(BaseModel):
version: str = LLAMA_STACK_RUN_CONFIG_VERSION
built_at: datetime
@ -100,36 +113,20 @@ this could be just a hash
default=None,
description="Reference to the conda environment if this package refers to a conda environment",
)
apis_to_serve: List[str] = Field(
apis: List[str] = Field(
default_factory=list,
description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)
api_providers: Dict[
str, Union[GenericProviderConfig, PlaceholderProviderConfig]
] = Field(
providers: Dict[str, List[Provider]] = Field(
description="""
Provider configurations for each of the APIs provided by this package.
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary.
""",
)
routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
default_factory=dict,
description="""
E.g. The following is a ProviderRoutingEntry for models:
- routing_key: Llama3.1-8B-Instruct
provider_type: meta-reference
config:
model: Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
""",
)
@json_schema_type
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
name: str

View file

@ -6,45 +6,58 @@
from typing import Dict, List
from llama_stack.apis.inspect import * # noqa: F403
from pydantic import BaseModel
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
def is_passthrough(spec: ProviderSpec) -> bool:
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
class DistributionInspectConfig(BaseModel):
run_config: StackRunConfig
async def get_provider_impl(config, deps):
impl = DistributionInspectImpl(config, deps)
await impl.initialize()
return impl
class DistributionInspectImpl(Inspect):
def __init__(self):
def __init__(self, config, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
run_config = self.config.run_config
ret = {}
all_providers = get_provider_registry()
for api, providers in all_providers.items():
ret[api.value] = [
for api, providers in run_config.providers.items():
ret[api] = [
ProviderInfo(
provider_id=p.provider_id,
provider_type=p.provider_type,
description="Passthrough" if is_passthrough(p) else "",
)
for p in providers.values()
for p in providers
]
return ret
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
run_config = self.config.run_config
ret = {}
all_endpoints = get_all_api_endpoints()
for api, endpoints in all_endpoints.items():
providers = run_config.providers.get(api.value, [])
ret[api.value] = [
RouteInfo(
route=e.route,
method=e.method,
providers=[],
provider_types=[p.provider_type for p in providers],
)
for e in endpoints
]

View file

@ -4,146 +4,237 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import inspect
from typing import Any, Dict, List, Set
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.inspect import DistributionInspectImpl
from llama_stack.distribution.utils.dynamic import instantiate_class_type
def api_protocol_map() -> Dict[Api, Any]:
return {
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
Api.memory: Memory,
Api.memory_banks: MemoryBanks,
Api.models: Models,
Api.safety: Safety,
Api.shields: Shields,
Api.telemetry: Telemetry,
}
def additional_protocols_map() -> Dict[Api, Any]:
return {
Api.inference: ModelsProtocolPrivate,
Api.memory: MemoryBanksProtocolPrivate,
Api.safety: ShieldsProtocolPrivate,
}
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
class ProviderWithSpec(Provider):
spec: ProviderSpec
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = get_provider_registry()
specs = {}
configs = {}
all_api_providers = get_provider_registry()
for api_str, config in run_config.api_providers.items():
api = Api(api_str)
# TODO: check that these APIs are not in the routing table part of the config
providers = all_providers[api]
# skip checks for API whose provider config is specified in routing_table
if isinstance(config, PlaceholderProviderConfig):
continue
if config.provider_type not in providers:
raise ValueError(
f"Provider `{config.provider_type}` is not available for API `{api}`"
)
specs[api] = providers[config.provider_type]
configs[api] = config
apis_to_serve = run_config.apis_to_serve or set(
list(specs.keys()) + list(run_config.routing_table.keys())
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
providers_with_specs = {}
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
)
specs = {}
for provider in providers:
if provider.provider_type not in all_api_providers[api]:
raise ValueError(
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
p = all_api_providers[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec(
spec=p,
**(provider.dict()),
)
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys())
+ [x.value for x in routing_table_apis]
+ [x.value for x in router_apis]
)
for info in builtin_automatically_routed_apis():
source_api = info.routing_table_api
assert (
source_api not in specs
), f"Routing table API {source_api} specified in wrong place?"
assert (
info.router_api not in specs
), f"Auto-routed API {info.router_api} specified in wrong place?"
if info.router_api.value not in apis_to_serve:
continue
if info.router_api.value not in run_config.routing_table:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
routing_table = run_config.routing_table[info.router_api.value]
providers_with_specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__routing_table__",
provider_type="__routing_table__",
config={},
spec=RoutingTableProviderSpec(
api=info.routing_table_api,
router_api=info.router_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
deps__=([f"inner-{info.router_api.value}"]),
),
)
}
providers = all_providers[info.router_api]
providers_with_specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__autorouted__",
provider_type="__autorouted__",
config={},
spec=AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=([info.routing_table_api.value]),
),
)
}
inner_specs = []
inner_deps = []
for rt_entry in routing_table:
if rt_entry.provider_type not in providers:
raise ValueError(
f"Provider `{rt_entry.provider_type}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_type])
inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
module="llama_stack.distribution.routers",
api_dependencies=inner_deps,
inner_specs=inner_specs,
sorted_providers = topological_sort(
{k: v.values() for k, v in providers_with_specs.items()}
)
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
(
"inspect",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={
"run_config": run_config.dict(),
},
spec=InlineProviderSpec(
api=Api.inspect,
provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect",
api_dependencies=apis,
deps__=([x.value for x in apis]),
),
),
)
configs[source_api] = routing_table
specs[info.router_api] = AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=source_api,
api_dependencies=[source_api],
)
configs[info.router_api] = {}
sorted_specs = topological_sort(specs.values())
print(f"Resolved {len(sorted_specs)} providers in topological order")
for spec in sorted_specs:
print(f" {spec.api}: {spec.provider_type}")
print("")
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, configs[api])
impls[api] = impl
impls[Api.inspect] = DistributionInspectImpl()
specs[Api.inspect] = InlineProviderSpec(
api=Api.inspect,
provider_type="__distribution_builtin__",
config_class="",
module="",
)
return impls, specs
print(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
print(f" {api_str} => {provider.provider_id}")
print("")
impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[
f"inner-{provider.spec.router_api.value}"
]
impl = await instantiate_provider(
provider,
deps,
inner_impls,
)
# TODO: ugh slightly redesign this shady looking code
if "inner-" in api_str:
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
else:
api = Api(api_str)
impls[api] = impl
return impls
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in providers}
def topological_sort(
providers_with_specs: Dict[str, List[ProviderWithSpec]],
) -> List[ProviderWithSpec]:
def dfs(kv, visited: Set[str], stack: List[str]):
api_str, providers = kv
visited.add(api_str)
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
deps = []
for provider in providers:
for dep in provider.spec.deps__:
deps.append(dep)
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
for dep in deps:
if dep not in visited:
dfs((dep, providers_with_specs[dep]), visited, stack)
stack.append(a.api)
stack.append(api_str)
visited = set()
stack = []
for a in providers:
if a.api not in visited:
dfs(a, visited, stack)
for api_str, providers in providers_with_specs.items():
if api_str not in visited:
dfs((api_str, providers), visited, stack)
return [by_id[x] for x in stack]
flattened = []
for api_str in stack:
for provider in providers_with_specs[api_str]:
flattened.append((api_str, provider))
return flattened
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
provider: ProviderWithSpec,
deps: Dict[str, Any],
provider_config: Union[GenericProviderConfig, RoutingTable],
inner_impls: Dict[str, Any],
):
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
provider_spec = provider.spec
module = importlib.import_module(provider_spec.module)
args = []
@ -153,9 +244,8 @@ async def instantiate_provider(
else:
method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
config = config_type(**provider.config)
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
@ -165,31 +255,69 @@ async def instantiate_provider(
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
assert isinstance(provider_config, List)
routing_table = provider_config
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_type],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [provider_spec.api, inner_impls, routing_table, deps]
args = [provider_spec.api, inner_impls, deps]
else:
method = "get_provider_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
config = config_type(**provider.config)
args = [config, deps]
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_id__ = provider.provider_id
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
check_protocol_compliance(impl, protocols[provider_spec.api])
if (
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
additional_api = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
return impl
def check_protocol_compliance(obj: Any, protocol: Any) -> None:
missing_methods = []
mro = type(obj).__mro__
for name, value in inspect.getmembers(protocol):
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
if not hasattr(obj, name):
missing_methods.append((name, "missing"))
elif not callable(getattr(obj, name)):
missing_methods.append((name, "not_callable"))
else:
# Check if the method signatures are compatible
obj_method = getattr(obj, name)
proto_sig = inspect.signature(value)
obj_sig = inspect.signature(obj_method)
proto_params = set(proto_sig.parameters)
proto_params.discard("self")
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
print(
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
)
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
method_owner = next(
(cls for cls in mro if name in cls.__dict__), None
)
if (
method_owner is None
or method_owner.__name__ == protocol.__name__
):
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:
raise ValueError(
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
)

View file

@ -4,23 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Tuple
from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
from .routing_tables import (
MemoryBanksRoutingTable,
ModelsRoutingTable,
ShieldsRoutingTable,
)
async def get_routing_table_impl(
api: Api,
inner_impls: List[Tuple[str, Any]],
routing_table_config: Dict[str, List[RoutableProviderConfig]],
impls_by_provider_id: Dict[str, RoutedProtocol],
_deps,
) -> Any:
from .routing_tables import (
MemoryBanksRoutingTable,
ModelsRoutingTable,
ShieldsRoutingTable,
)
api_to_tables = {
"memory_banks": MemoryBanksRoutingTable,
"models": ModelsRoutingTable,
@ -29,7 +27,7 @@ async def get_routing_table_impl(
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](inner_impls, routing_table_config)
impl = api_to_tables[api.value](impls_by_provider_id)
await impl.initialize()
return impl

View file

@ -14,14 +14,13 @@ from llama_stack.apis.safety import * # noqa: F403
class MemoryRouter(Memory):
"""Routes to an provider based on the memory bank type"""
"""Routes to an provider based on the memory bank identifier"""
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
self.bank_id_to_type = {}
async def initialize(self) -> None:
pass
@ -29,32 +28,8 @@ class MemoryRouter(Memory):
async def shutdown(self) -> None:
pass
def get_provider_from_bank_id(self, bank_id: str) -> Any:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.routing_table.get_provider_impl(bank_type)
if not provider:
raise ValueError(f"Could not find provider for {bank_type}")
return provider
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_type = config.type
bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
name, config, url
)
self.bank_id_to_type[bank.bank_id] = bank_type
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
provider = self.get_provider_from_bank_id(bank_id)
return await provider.get_memory_bank(bank_id)
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
await self.routing_table.register_memory_bank(memory_bank)
async def insert_documents(
self,
@ -62,7 +37,7 @@ class MemoryRouter(Memory):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
return await self.get_provider_from_bank_id(bank_id).insert_documents(
return await self.routing_table.get_provider_impl(bank_id).insert_documents(
bank_id, documents, ttl_seconds
)
@ -72,7 +47,7 @@ class MemoryRouter(Memory):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
return await self.get_provider_from_bank_id(bank_id).query_documents(
return await self.routing_table.get_provider_impl(bank_id).query_documents(
bank_id, query, params
)
@ -92,7 +67,10 @@ class InferenceRouter(Inference):
async def shutdown(self) -> None:
pass
async def chat_completion(
async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)
def chat_completion(
self,
model: str,
messages: List[Message],
@ -113,27 +91,32 @@ class InferenceRouter(Inference):
stream=stream,
logprobs=logprobs,
)
# TODO: we need to fix streaming response to align provider implementations with Protocol.
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
**params
):
yield chunk
provider = self.routing_table.get_provider_impl(model)
if stream:
return (chunk async for chunk in provider.chat_completion(**params))
else:
return provider.chat_completion(**params)
async def completion(
def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
return await self.routing_table.get_provider_impl(model).completion(
) -> AsyncGenerator:
provider = self.routing_table.get_provider_impl(model)
params = dict(
model=model,
content=content,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
if stream:
return (chunk async for chunk in provider.completion(**params))
else:
return provider.completion(**params)
async def embeddings(
self,
@ -159,6 +142,9 @@ class SafetyRouter(Safety):
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: ShieldDef) -> None:
await self.routing_table.register_shield(shield)
async def run_shield(
self,
shield_type: str,

View file

@ -4,9 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Optional, Tuple
from typing import Any, Dict, List, Optional
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
@ -16,129 +15,159 @@ from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.inference:
await p.register_model(obj)
elif api == Api.safety:
await p.register_shield(obj)
elif api == Api.memory:
await p.register_memory_bank(obj)
Registry = Dict[str, List[RoutableObjectWithProvider]]
# TODO: this routing table maintains state in memory purely. We need to
# add persistence to it when we add dynamic registration of objects.
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
inner_impls: List[Tuple[RoutingKey, Any]],
routing_table_config: Dict[str, List[RoutableProviderConfig]],
impls_by_provider_id: Dict[str, RoutedProtocol],
) -> None:
self.unique_providers = []
self.providers = {}
self.routing_keys = []
for key, impl in inner_impls:
keys = key if isinstance(key, list) else [key]
self.unique_providers.append((keys, impl))
for k in keys:
if k in self.providers:
raise ValueError(f"Duplicate routing key {k}")
self.providers[k] = impl
self.routing_keys.append(k)
self.routing_table_config = routing_table_config
self.impls_by_provider_id = impls_by_provider_id
async def initialize(self) -> None:
for keys, p in self.unique_providers:
spec = p.__provider_spec__
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
continue
self.registry: Registry = {}
await p.validate_routing_keys(keys)
def add_objects(objs: List[RoutableObjectWithProvider]) -> None:
for obj in objs:
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
self.registry[obj.identifier].append(obj)
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
models = await p.list_models()
add_objects(
[ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models]
)
elif api == Api.safety:
p.shield_store = self
shields = await p.list_shields()
add_objects(
[
ShieldDefWithProvider(**s.dict(), provider_id=pid)
for s in shields
]
)
elif api == Api.memory:
p.memory_bank_store = self
memory_banks = await p.list_memory_banks()
# do in-memory updates due to pesky Annotated unions
for m in memory_banks:
m.provider_id = pid
add_objects(memory_banks)
async def shutdown(self) -> None:
for _, p in self.unique_providers:
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str) -> Any:
if routing_key not in self.providers:
raise ValueError(f"Could not find provider for {routing_key}")
return self.providers[routing_key]
def get_provider_impl(
self, routing_key: str, provider_id: Optional[str] = None
) -> Any:
if routing_key not in self.registry:
raise ValueError(f"`{routing_key}` not registered")
def get_routing_keys(self) -> List[str]:
return self.routing_keys
objs = self.registry[routing_key]
for obj in objs:
if not provider_id or provider_id == obj.provider_id:
return self.impls_by_provider_id[obj.provider_id]
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
for entry in self.routing_table_config:
if entry.routing_key == routing_key:
return entry
return None
raise ValueError(f"Provider not found for `{routing_key}`")
def get_object_by_identifier(
self, identifier: str
) -> Optional[RoutableObjectWithProvider]:
objs = self.registry.get(identifier, [])
if not objs:
return None
# kind of ill-defined behavior here, but we'll just return the first one
return objs[0]
async def register_object(self, obj: RoutableObjectWithProvider):
entries = self.registry.get(obj.identifier, [])
for entry in entries:
if entry.provider_id == obj.provider_id:
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
return
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
await register_object_with_provider(obj, p)
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
self.registry[obj.identifier].append(obj)
# TODO: persist this to a store
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
async def list_models(self) -> List[ModelServingSpec]:
specs = []
for entry in self.routing_table_config:
model_id = entry.routing_key
specs.append(
ModelServingSpec(
llama_model=resolve_model(model_id),
provider_config=entry,
)
)
return specs
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
return self.get_object_by_identifier(identifier)
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
for entry in self.routing_table_config:
if entry.routing_key == core_model_id:
return ModelServingSpec(
llama_model=resolve_model(core_model_id),
provider_config=entry,
)
return None
async def register_model(self, model: ModelDefWithProvider) -> None:
await self.register_object(model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldDef]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
async def list_shields(self) -> List[ShieldSpec]:
specs = []
for entry in self.routing_table_config:
if isinstance(entry.routing_key, list):
for k in entry.routing_key:
specs.append(
ShieldSpec(
shield_type=k,
provider_config=entry,
)
)
else:
specs.append(
ShieldSpec(
shield_type=entry.routing_key,
provider_config=entry,
)
)
return specs
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
return self.get_object_by_identifier(shield_type)
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
for entry in self.routing_table_config:
if entry.routing_key == shield_type:
return ShieldSpec(
shield_type=entry.routing_key,
provider_config=entry,
)
return None
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
await self.register_object(shield)
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
specs = []
for entry in self.routing_table_config:
specs.append(
MemoryBankSpec(
bank_type=entry.routing_key,
provider_config=entry,
)
)
return specs
async def get_memory_bank(
self, identifier: str
) -> Optional[MemoryBankDefWithProvider]:
return self.get_object_by_identifier(identifier)
async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
for entry in self.routing_table_config:
if entry.routing_key == bank_type:
return MemoryBankSpec(
bank_type=entry.routing_key,
provider_config=entry,
)
return None
async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider
) -> None:
await self.register_object(memory_bank)

View file

@ -9,15 +9,7 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api
@ -31,18 +23,7 @@ class ApiEndpoint(BaseModel):
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
Api.models: Models,
Api.shields: Shields,
Api.memory_banks: MemoryBanks,
Api.inspect: Inspect,
}
protocols = api_protocol_map()
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)

View file

@ -5,18 +5,15 @@
# the root directory of this source tree.
import asyncio
import functools
import inspect
import json
import signal
import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from ssl import SSLError
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
from typing import Any, Dict, Optional
import fire
import httpx
@ -29,6 +26,8 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
@ -43,20 +42,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
from .endpoints import get_all_api_endpoints
def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"):
origin = typ.__origin__
if isinstance(origin, type):
return issubclass(
origin,
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
)
return False
return isinstance(
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
)
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.json()
@ -169,11 +154,20 @@ async def passthrough(
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
def handle_sigint(*args, **kwargs):
def handle_sigint(app, *args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...")
async def run_shutdown():
for impl in app.__llama_stack_impls__.values():
print(f"Shutting down {impl}")
await impl.shutdown()
asyncio.run(run_shutdown())
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
@ -181,7 +175,10 @@ def handle_sigint(*args, **kwargs):
async def lifespan(app: FastAPI):
print("Starting up")
yield
print("Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
def create_dynamic_passthrough(
@ -193,65 +190,59 @@ def create_dynamic_passthrough(
return endpoint
def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly
return kwargs.get("stream", False)
async def maybe_await(value):
if inspect.iscoroutine(value):
return await value
return value
async def sse_generator(event_gen):
try:
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
finally:
await end_trace()
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
response_model = hints.get("return")
# NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False.
is_streaming = is_async_iterator_type(response_model)
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
if is_streaming:
set_request_provider_data(request.headers)
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers)
async def sse_generator(event_gen):
try:
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
finally:
await end_trace()
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers)
try:
return (
await func(**kwargs)
if asyncio.iscoroutinefunction(func)
else func(**kwargs)
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
finally:
await end_trace()
else:
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
finally:
await end_trace()
sig = inspect.signature(func)
new_params = [
@ -285,29 +276,28 @@ def main(
app = FastAPI()
impls, specs = asyncio.run(resolve_impls_with_routing(config))
impls = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
all_endpoints = get_all_api_endpoints()
if config.apis_to_serve:
apis_to_serve = set(config.apis_to_serve)
if config.apis:
apis_to_serve = set(config.apis)
else:
apis_to_serve = set(impls.keys())
apis_to_serve.add(Api.inspect)
for inf in builtin_automatically_routed_apis():
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect")
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]
impl = impls[api]
provider_spec = specs[api]
if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None
):
if is_passthrough(impl.__provider_spec__):
for endpoint in endpoints:
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
@ -337,7 +327,9 @@ def main(
print("")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)
signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
app.__llama_stack_impls__ = impls
import uvicorn

View file

@ -1,8 +1,9 @@
built_at: '2024-09-30T09:04:30.533391'
version: '2'
built_at: '2024-10-08T17:42:07.505267'
image_name: local-cpu
docker_image: local-cpu
conda_env: null
apis_to_serve:
apis:
- agents
- inference
- models
@ -10,40 +11,32 @@ apis_to_serve:
- safety
- shields
- memory_banks
api_providers:
providers:
inference:
providers:
- remote::ollama
- provider_id: remote::ollama
provider_type: remote::ollama
config:
host: localhost
port: 6000
safety:
providers:
- meta-reference
- provider_id: meta-reference
provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
memory:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
agents:
- provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
memory:
providers:
- meta-reference
telemetry:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
routing_table:
inference:
- provider_type: remote::ollama
config:
host: localhost
port: 6000
routing_key: Llama3.1-8B-Instruct
safety:
- provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- provider_type: meta-reference
config: {}
routing_key: vector

View file

@ -1,8 +1,9 @@
built_at: '2024-09-30T09:00:56.693751'
version: '2'
built_at: '2024-10-08T17:42:33.690666'
image_name: local-gpu
docker_image: local-gpu
conda_env: null
apis_to_serve:
apis:
- memory
- inference
- agents
@ -10,43 +11,35 @@ apis_to_serve:
- safety
- models
- memory_banks
api_providers:
providers:
inference:
providers:
- meta-reference
safety:
providers:
- meta-reference
agents:
- provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
memory:
providers:
- meta-reference
telemetry:
provider_type: meta-reference
config: {}
routing_table:
inference:
- provider_type: meta-reference
config:
model: Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
routing_key: Llama3.1-8B-Instruct
safety:
- provider_type: meta-reference
- provider_id: meta-reference
provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- provider_type: meta-reference
- provider_id: meta-reference
provider_type: meta-reference
config: {}
agents:
- provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
routing_key: vector