Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
chore(package): migrate to src/ layout (#3920)
Migrates package structure to src/ layout following Python packaging best practices. All code moved from `llama_stack/` to `src/llama_stack/`. Public API unchanged - imports remain `import llama_stack.*`. Updated build configs, pre-commit hooks, scripts, and GitHub workflows accordingly. All hooks pass, package builds cleanly. **Developer note**: Reinstall after pulling: `pip install -e .`
Parent: 98a5047f9d
Commit: 471b1b248b
791 changed files with 2983 additions and 456 deletions
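
As a quick post-migration sanity check (a minimal sketch, not part of the commit; it assumes you have re-run `pip install -e .` after pulling):

# Hypothetical check script: the public import surface is unchanged by the
# src/ layout migration; only the on-disk location of the sources moved.
import llama_stack.core.datatypes as dt  # still `llama_stack.*`, now served from src/

print(dt.LLAMA_STACK_RUN_CONFIG_VERSION)  # resolves exactly as before the move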

5  src/llama_stack/core/__init__.py  Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

5  src/llama_stack/core/access_control/__init__.py  Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

131  src/llama_stack/core/access_control/access_control.py  Normal file
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.core.datatypes import User

from .conditions import (
    Condition,
    ProtectedResource,
    parse_conditions,
)
from .datatypes import (
    AccessRule,
    Action,
    Scope,
)


def matches_resource(resource_scope: str, actual_resource: str) -> bool:
    if resource_scope == actual_resource:
        return True
    return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1])


def matches_scope(
    scope: Scope,
    action: Action,
    resource: str,
    user: str | None,
) -> bool:
    if scope.resource and not matches_resource(scope.resource, resource):
        return False
    if scope.principal and scope.principal != user:
        return False
    return action in scope.actions


def as_list(obj: Any) -> list[Any]:
    if isinstance(obj, list):
        return obj
    return [obj]


def matches_conditions(
    conditions: list[Condition],
    resource: ProtectedResource,
    user: User,
) -> bool:
    for condition in conditions:
        # must match all conditions
        if not condition.matches(resource, user):
            return False
    return True


def default_policy() -> list[AccessRule]:
    # for backwards compatibility, if no rules are provided, assume
    # full access subject to previous attribute matching rules
    return [
        AccessRule(
            permit=Scope(actions=list(Action)),
            when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
        ),
    ]


def is_action_allowed(
    policy: list[AccessRule],
    action: Action,
    resource: ProtectedResource,
    user: User | None,
) -> bool:
    # If user is not set, assume authentication is not enabled
    if not user:
        return True

    if not len(policy):
        policy = default_policy()

    qualified_resource_id = f"{resource.type}::{resource.identifier}"
    for rule in policy:
        if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
            if rule.when:
                if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
                    return False
            elif rule.unless:
                if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
                    return False
            else:
                return False
        elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
            if rule.when:
                if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
                    return True
            elif rule.unless:
                if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
                    return True
            else:
                return True
    # assume access is denied unless we find a rule that permits access
    return False


class AccessDeniedError(RuntimeError):
    def __init__(self, action: str | None = None, resource: ProtectedResource | None = None, user: User | None = None):
        self.action = action
        self.resource = resource
        self.user = user

        message = _build_access_denied_message(action, resource, user)
        super().__init__(message)


def _build_access_denied_message(action: str | None, resource: ProtectedResource | None, user: User | None) -> str:
    """Build detailed error message for access denied scenarios."""
    if action and resource and user:
        resource_info = f"{resource.type}::{resource.identifier}"
        user_info = f"'{user.principal}'"
        if user.attributes:
            attrs = ", ".join([f"{k}={v}" for k, v in user.attributes.items()])
            user_info += f" (attributes: {attrs})"

        message = f"User {user_info} cannot perform action '{action}' on resource '{resource_info}'"
    else:
        message = "Insufficient permissions"

    return message
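
A usage sketch of the policy evaluation above (not part of the commit; it uses the `User` type defined later in `core/datatypes.py` and a `SimpleNamespace` stand-in for the `ProtectedResource` protocol):

# Sketch: the default policy grants access when the user shares an owner
# attribute with the resource, and denies otherwise.
from types import SimpleNamespace

from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import User

owner = User("alice", {"teams": ["ml"]})
resource = SimpleNamespace(type="model", identifier="my-model", owner=owner)

# Same team as the owner: the "user in owners teams" condition matches.
assert is_action_allowed(default_policy(), Action.READ, resource, User("bob", {"teams": ["ml"]}))
# No shared owner attributes: access is denied.
assert not is_action_allowed(default_policy(), Action.READ, resource, User("eve", {"teams": ["web"]}))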

129  src/llama_stack/core/access_control/conditions.py  Normal file
@@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Protocol


class User(Protocol):
    principal: str
    attributes: dict[str, list[str]] | None


class ProtectedResource(Protocol):
    type: str
    identifier: str
    owner: User


class Condition(Protocol):
    def matches(self, resource: ProtectedResource, user: User) -> bool: ...


class UserInOwnersList:
    def __init__(self, name: str):
        self.name = name

    def owners_values(self, resource: ProtectedResource) -> list[str] | None:
        if (
            hasattr(resource, "owner")
            and resource.owner
            and resource.owner.attributes
            and self.name in resource.owner.attributes
        ):
            return resource.owner.attributes[self.name]
        else:
            return None

    def matches(self, resource: ProtectedResource, user: User) -> bool:
        required = self.owners_values(resource)
        if not required:
            return True
        if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
            return False
        user_values = user.attributes[self.name]
        for value in required:
            if value in user_values:
                return True
        return False

    def __repr__(self):
        return f"user in owners {self.name}"


class UserNotInOwnersList(UserInOwnersList):
    def __init__(self, name: str):
        super().__init__(name)

    def matches(self, resource: ProtectedResource, user: User) -> bool:
        return not super().matches(resource, user)

    def __repr__(self):
        return f"user not in owners {self.name}"


class UserWithValueInList:
    def __init__(self, name: str, value: str):
        self.name = name
        self.value = value

    def matches(self, resource: ProtectedResource, user: User) -> bool:
        if user.attributes and self.name in user.attributes:
            return self.value in user.attributes[self.name]
        print(f"User does not have {self.value} in {self.name}")
        return False

    def __repr__(self):
        return f"user with {self.value} in {self.name}"


class UserWithValueNotInList(UserWithValueInList):
    def __init__(self, name: str, value: str):
        super().__init__(name, value)

    def matches(self, resource: ProtectedResource, user: User) -> bool:
        return not super().matches(resource, user)

    def __repr__(self):
        return f"user with {self.value} not in {self.name}"


class UserIsOwner:
    def matches(self, resource: ProtectedResource, user: User) -> bool:
        return resource.owner.principal == user.principal if resource.owner else False

    def __repr__(self):
        return "user is owner"


class UserIsNotOwner:
    def matches(self, resource: ProtectedResource, user: User) -> bool:
        return not resource.owner or resource.owner.principal != user.principal

    def __repr__(self):
        return "user is not owner"


def parse_condition(condition: str) -> Condition:
    words = condition.split()
    match words:
        case ["user", "is", "owner"]:
            return UserIsOwner()
        case ["user", "is", "not", "owner"]:
            return UserIsNotOwner()
        case ["user", "with", value, "in", name]:
            return UserWithValueInList(name, value)
        case ["user", "with", value, "not", "in", name]:
            return UserWithValueNotInList(name, value)
        case ["user", "in", "owners", name]:
            return UserInOwnersList(name)
        case ["user", "not", "in", "owners", name]:
            return UserNotInOwnersList(name)
        case _:
            raise ValueError(f"Invalid condition: {condition}")


def parse_conditions(conditions: list[str]) -> list[Condition]:
    return [parse_condition(c) for c in conditions]
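
A small round-trip sketch of the condition grammar (not part of the commit; it exercises the six forms listed in the docstring below and the `__repr__` implementations above):

# Sketch: every supported condition string parses, and repr() reproduces it.
from llama_stack.core.access_control.conditions import parse_condition

for expr in [
    "user is owner",
    "user is not owner",
    "user with admin in roles",
    "user with admin not in roles",
    "user in owners teams",
    "user not in owners teams",
]:
    assert repr(parse_condition(expr)) == expr

parse_condition("user likes pizza")  # raises ValueError: Invalid condition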

107  src/llama_stack/core/access_control/datatypes.py  Normal file
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import StrEnum
from typing import Self

from pydantic import BaseModel, model_validator

from .conditions import parse_conditions


class Action(StrEnum):
    CREATE = "create"
    READ = "read"
    UPDATE = "update"
    DELETE = "delete"


class Scope(BaseModel):
    principal: str | None = None
    actions: Action | list[Action]
    resource: str | None = None


def _mutually_exclusive(obj, a: str, b: str):
    if getattr(obj, a) and getattr(obj, b):
        raise ValueError(f"{a} and {b} are mutually exclusive")


def _require_one_of(obj, a: str, b: str):
    if not getattr(obj, a) and not getattr(obj, b):
        raise ValueError(f"one of {a} or {b} is required")


class AccessRule(BaseModel):
    """Access rule based loosely on the Cedar policy language.

    A rule defines a list of actions either to permit or to forbid. It may specify a
    principal or a resource that must match for the rule to take effect. The resource
    to match should be specified in the form of a type-qualified identifier, e.g.
    model::my-model or vector_store::some-db, or a wildcard for all resources of a type,
    e.g. model::*. If the principal or resource are not specified, they will match all
    requests.

    A rule may also specify a condition, either a 'when' or an 'unless', with additional
    constraints as to where the rule applies. The constraints supported at present are:

    - 'user with <attr-value> in <attr-name>'
    - 'user with <attr-value> not in <attr-name>'
    - 'user is owner'
    - 'user is not owner'
    - 'user in owners <attr-name>'
    - 'user not in owners <attr-name>'

    Rules are tested in order to find a match. If a match is found, the request is
    permitted or forbidden depending on the type of rule. If no match is found, the
    request is denied. If no rules are specified, a rule that allows any action as
    long as the resource attributes match the user attributes is added
    (i.e. the previous behaviour is the default).

    Some examples in yaml:

    - permit:
        principal: user-1
        actions: [create, read, delete]
        resource: model::*
      description: user-1 has full access to all models
    - permit:
        principal: user-2
        actions: [read]
        resource: model::model-1
      description: user-2 has read access to model-1 only
    - permit:
        actions: [read]
      when: user in owners teams
      description: any user has read access to any resource created by a member of their team
    - forbid:
        actions: [create, read, delete]
        resource: vector_store::*
      unless: user with admin in roles
      description: only user with admin role can use vector_store resources

    """

    permit: Scope | None = None
    forbid: Scope | None = None
    when: str | list[str] | None = None
    unless: str | list[str] | None = None
    description: str | None = None

    @model_validator(mode="after")
    def validate_rule_format(self) -> Self:
        _require_one_of(self, "permit", "forbid")
        _mutually_exclusive(self, "permit", "forbid")
        _mutually_exclusive(self, "when", "unless")
        if isinstance(self.when, list):
            parse_conditions(self.when)
        elif self.when:
            parse_conditions([self.when])
        if isinstance(self.unless, list):
            parse_conditions(self.unless)
        elif self.unless:
            parse_conditions([self.unless])
        return self
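
A sketch of loading rules like the docstring's YAML examples into `AccessRule` models (not part of the commit; it assumes PyYAML is available, which this module does not itself require):

# Sketch: YAML policy entries validate directly as AccessRule models;
# the model validator also checks the when/unless condition grammar.
import yaml  # assumed available for this sketch only

from llama_stack.core.access_control.datatypes import AccessRule

rules = yaml.safe_load("""
- permit:
    actions: [read]
  when: user in owners teams
- forbid:
    actions: [create, read, delete]
    resource: vector_store::*
  unless: user with admin in roles
""")
policy = [AccessRule(**rule) for rule in rules]
assert policy[0].permit is not None and policy[1].forbid is not None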

164  src/llama_stack/core/build.py  Normal file
@@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import importlib.resources
import sys

from pydantic import BaseModel
from termcolor import cprint

from llama_stack.core.datatypes import BuildConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.distributions.template import DistributionTemplate
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api

log = get_logger(name=__name__, category="core")

# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
    "aiosqlite",
    "fastapi",
    "fire",
    "httpx",
    "uvicorn",
    "opentelemetry-sdk",
    "opentelemetry-exporter-otlp-proto-http",
]


class ApiInput(BaseModel):
    api: Api
    provider: str


def get_provider_dependencies(
    config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str], list[str]]:
    """Get normal and special dependencies from provider configuration."""
    if isinstance(config, DistributionTemplate):
        config = config.build_config()

    providers = config.distribution_spec.providers
    additional_pip_packages = config.additional_pip_packages

    deps = []
    external_provider_deps = []
    registry = get_provider_registry(config)
    for api_str, provider_or_providers in providers.items():
        providers_for_api = registry[Api(api_str)]

        providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]

        for provider in providers:
            # Providers from BuildConfig and RunConfig are subtly different - not great
            provider_type = provider if isinstance(provider, str) else provider.provider_type

            if provider_type not in providers_for_api:
                raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")

            provider_spec = providers_for_api[provider_type]
            if hasattr(provider_spec, "is_external") and provider_spec.is_external:
                # this ensures we install the top level module for our external providers
                if provider_spec.module:
                    if isinstance(provider_spec.module, str):
                        external_provider_deps.append(provider_spec.module)
                    else:
                        external_provider_deps.extend(provider_spec.module)
            if hasattr(provider_spec, "pip_packages"):
                deps.extend(provider_spec.pip_packages)
            if hasattr(provider_spec, "container_image") and provider_spec.container_image:
                raise ValueError("A stack's dependencies cannot have a container image")

    normal_deps = []
    special_deps = []
    for package in deps:
        if any(f in package for f in ["--no-deps", "--index-url", "--extra-index-url"]):
            special_deps.append(package)
        else:
            normal_deps.append(package)

    normal_deps.extend(additional_pip_packages or [])

    return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps))


def print_pip_install_help(config: BuildConfig):
    normal_deps, special_deps, _ = get_provider_dependencies(config)

    cprint(
        f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
        color="yellow",
        file=sys.stderr,
    )
    for special_dep in special_deps:
        cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
    print()


def build_image(
    build_config: BuildConfig,
    image_name: str,
    distro_or_config: str,
    run_config: str | None = None,
):
    container_base = build_config.distribution_spec.container_image or "python:3.12-slim"

    normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
    normal_deps += SERVER_DEPENDENCIES
    if build_config.external_apis_dir:
        external_apis = load_external_apis(build_config)
        if external_apis:
            for _, api_spec in external_apis.items():
                normal_deps.extend(api_spec.pip_packages)

    if build_config.image_type == LlamaStackImageType.CONTAINER.value:
        script = str(importlib.resources.files("llama_stack") / "core/build_container.sh")
        args = [
            script,
            "--distro-or-config",
            distro_or_config,
            "--image-name",
            image_name,
            "--container-base",
            container_base,
            "--normal-deps",
            " ".join(normal_deps),
        ]
        # When building from a config file (not a template), include the run config path in the
        # build arguments
        if run_config is not None:
            args.extend(["--run-config", run_config])
    else:
        script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh")
        args = [
            script,
            "--env-name",
            str(image_name),
            "--normal-deps",
            " ".join(normal_deps),
        ]

    # Always pass both arguments, even if empty, to maintain consistent positional arguments
    if special_deps:
        args.extend(["--optional-deps", "#".join(special_deps)])
    if external_provider_deps:
        args.extend(
            ["--external-provider-deps", "#".join(external_provider_deps)]
        )  # the script will install external provider module, get its deps, and install those too.

    return_code = run_command(args)

    if return_code != 0:
        log.error(
            f"Failed to build target {image_name} with return code {return_code}",
        )

    return return_code
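
A standalone sketch of the normal/special dependency split performed in `get_provider_dependencies` above (not part of the commit; the package strings are illustrative):

# Sketch: package strings carrying pip flags go to the "special" bucket
# and get their own `uv pip install` invocation.
deps = [
    "fastapi",
    "torch --index-url https://download.pytorch.org/whl/cpu",
    "httpx",
]
special = [d for d in deps if any(f in d for f in ["--no-deps", "--index-url", "--extra-index-url"])]
normal = [d for d in deps if d not in special]
assert normal == ["fastapi", "httpx"]
assert special == ["torch --index-url https://download.pytorch.org/whl/cpu"]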

205  src/llama_stack/core/client.py  Normal file
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect
import json
import sys
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any, Union, get_args, get_origin

import httpx
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint

from llama_stack.providers.datatypes import RemoteProviderConfig

_CLIENT_CLASSES = {}


async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
    client_class = create_api_client_class(protocol)
    impl = client_class(config.url)
    await impl.initialize()
    return impl


def create_api_client_class(protocol) -> type:
    if protocol in _CLIENT_CLASSES:
        return _CLIENT_CLASSES[protocol]

    class APIClient:
        def __init__(self, base_url: str):
            print(f"({protocol.__name__}) Connecting to {base_url}")
            self.base_url = base_url.rstrip("/")
            self.routes = {}

            # Store routes for this protocol
            for name, method in inspect.getmembers(protocol):
                if hasattr(method, "__webmethod__"):
                    sig = inspect.signature(method)
                    self.routes[name] = (method.__webmethod__, sig)

        async def initialize(self):
            pass

        async def shutdown(self):
            pass

        async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
            assert method_name in self.routes, f"Unknown endpoint: {method_name}"

            # TODO: make this more precise, same thing needs to happen in server.py
            is_streaming = kwargs.get("stream", False)
            if is_streaming:
                return self._call_streaming(method_name, *args, **kwargs)
            else:
                return await self._call_non_streaming(method_name, *args, **kwargs)

        async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
            _, sig = self.routes[method_name]

            if sig.return_annotation is None:
                return_type = None
            else:
                return_type = extract_non_async_iterator_type(sig.return_annotation)
                assert return_type, f"Could not extract return type for {sig.return_annotation}"

            async with httpx.AsyncClient() as client:
                params = self.httpx_request_params(method_name, *args, **kwargs)
                response = await client.request(**params)
                response.raise_for_status()

                j = response.json()
                if j is None:
                    return None
                # print(f"({protocol.__name__}) Returning {j}, type {return_type}")
                return parse_obj_as(return_type, j)

        async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
            webmethod, sig = self.routes[method_name]

            return_type = extract_async_iterator_type(sig.return_annotation)
            assert return_type, f"Could not extract return type for {sig.return_annotation}"

            async with httpx.AsyncClient() as client:
                params = self.httpx_request_params(method_name, *args, **kwargs)
                async with client.stream(**params) as response:
                    response.raise_for_status()

                    async for line in response.aiter_lines():
                        if line.startswith("data:"):
                            data = line[len("data: ") :]
                            try:
                                data = json.loads(data)
                                if "error" in data:
                                    cprint(data, color="red", file=sys.stderr)
                                    continue

                                yield parse_obj_as(return_type, data)
                            except Exception as e:
                                cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
                                cprint(data, color="red", file=sys.stderr)

        def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
            webmethod, sig = self.routes[method_name]

            parameters = list(sig.parameters.values())[1:]  # skip `self`
            for i, param in enumerate(parameters):
                if i >= len(args):
                    break
                kwargs[param.name] = args[i]

            # Get all webmethods for this method (supports multiple decorators)
            method = getattr(protocol, method_name)  # resolve the protocol method for this route
            webmethods = getattr(method, "__webmethods__", [])

            if not webmethods:
                raise RuntimeError(f"Method {method} has no webmethod decorators")

            # Choose the preferred webmethod (non-deprecated if available)
            preferred_webmethod = None
            for wm in webmethods:
                if not getattr(wm, "deprecated", False):
                    preferred_webmethod = wm
                    break

            # If no non-deprecated found, use the first one
            if preferred_webmethod is None:
                preferred_webmethod = webmethods[0]

            url = f"{self.base_url}/{preferred_webmethod.level}/{preferred_webmethod.route.lstrip('/')}"

            def convert(value):
                if isinstance(value, list):
                    return [convert(v) for v in value]
                elif isinstance(value, dict):
                    return {k: convert(v) for k, v in value.items()}
                elif isinstance(value, BaseModel):
                    return json.loads(value.model_dump_json())
                elif isinstance(value, Enum):
                    return value.value
                else:
                    return value

            params = {}
            data = {}
            if webmethod.method == "GET":
                params.update(kwargs)
            else:
                data.update(convert(kwargs))

            ret = dict(
                method=webmethod.method or "POST",
                url=url,
                headers={
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                },
                timeout=30,
            )
            if params:
                ret["params"] = params
            if data:
                ret["json"] = data

            return ret

    # Add protocol methods to the wrapper
    for name, method in inspect.getmembers(protocol):
        if hasattr(method, "__webmethod__"):

            async def method_impl(self, *args, method_name=name, **kwargs):
                return await self.__acall__(method_name, *args, **kwargs)

            method_impl.__name__ = name
            method_impl.__qualname__ = f"APIClient.{name}"
            method_impl.__signature__ = inspect.signature(method)
            setattr(APIClient, name, method_impl)

    # Name the class after the protocol
    APIClient.__name__ = f"{protocol.__name__}Client"
    _CLIENT_CLASSES[protocol] = APIClient
    return APIClient


# These helper functions are not fully general.
def extract_non_async_iterator_type(type_hint):
    if get_origin(type_hint) is Union:
        args = get_args(type_hint)
        for arg in args:
            if not issubclass(get_origin(arg) or arg, AsyncIterator):
                return arg
    return type_hint


def extract_async_iterator_type(type_hint):
    if get_origin(type_hint) is Union:
        args = get_args(type_hint)
        for arg in args:
            if issubclass(get_origin(arg) or arg, AsyncIterator):
                inner_args = get_args(arg)
                return inner_args[0]
    return None
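
A sketch of how the client picks response types from a method's return annotation (not part of the commit; the `Union[dict, AsyncIterator[str]]` hint is illustrative):

# Sketch: for a union return annotation, the non-streaming branch gets the
# plain member and the streaming branch gets the AsyncIterator item type.
from collections.abc import AsyncIterator
from typing import Union

from llama_stack.core.client import (
    extract_async_iterator_type,
    extract_non_async_iterator_type,
)

hint = Union[dict, AsyncIterator[str]]
assert extract_non_async_iterator_type(hint) is dict  # used by _call_non_streaming
assert extract_async_iterator_type(hint) is str  # used by _call_streaming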

37  src/llama_stack/core/common.sh  Executable file
@@ -0,0 +1,37 @@
#!/usr/bin/env bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

cleanup() {
  # For venv environments, no special cleanup is needed
  # This function exists to avoid "function not found" errors
  local env_name="$1"
  echo "Cleanup called for environment: $env_name"
}

handle_int() {
  if [ -n "$ENVNAME" ]; then
    cleanup "$ENVNAME"
  fi
  exit 1
}

handle_exit() {
  if [ $? -ne 0 ]; then
    echo -e "\033[1;31mABORTING.\033[0m"
    if [ -n "$ENVNAME" ]; then
      cleanup "$ENVNAME"
    fi
  fi
}


# check if a command is present
is_command_available() {
  command -v "$1" &>/dev/null
}

212  src/llama_stack/core/configure.py  Normal file
@@ -0,0 +1,212 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import textwrap
from typing import Any

from llama_stack.core.datatypes import (
    LLAMA_STACK_RUN_CONFIG_VERSION,
    DistributionSpec,
    Provider,
    StackRunConfig,
)
from llama_stack.core.distribution import (
    builtin_automatically_routed_apis,
    get_provider_registry,
)
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, ProviderSpec

logger = get_logger(name=__name__, category="core")


def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
    provider_spec = registry[provider.provider_type]
    config_type = instantiate_class_type(provider_spec.config_class)
    try:
        if provider.config:
            existing = config_type(**provider.config)
        else:
            existing = None
    except Exception:
        existing = None

    cfg = prompt_for_config(config_type, existing)
    return Provider(
        provider_id=provider.provider_id,
        provider_type=provider.provider_type,
        config=cfg.model_dump(),
    )


def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
    is_nux = len(config.providers) == 0

    if is_nux:
        logger.info(
            textwrap.dedent(
                """
                Llama Stack is composed of several APIs working together. For each API served by the Stack,
                we need to configure the providers (implementations) you want to use for these APIs.
                """
            )
        )

    provider_registry = get_provider_registry()
    builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]

    if config.apis:
        apis_to_serve = config.apis
    else:
        apis_to_serve = [a.value for a in Api if a not in (Api.inspect, Api.providers)]

    for api_str in apis_to_serve:
        api = Api(api_str)
        if api in builtin_apis:
            continue
        if api not in provider_registry:
            raise ValueError(f"Unknown API `{api_str}`")

        existing_providers = config.providers.get(api_str, [])
        if existing_providers:
            logger.info(f"Re-configuring existing providers for API `{api_str}`...")
            updated_providers = []
            for p in existing_providers:
                logger.info(f"> Configuring provider `({p.provider_type})`")
                updated_providers.append(configure_single_provider(provider_registry[api], p))
                logger.info("")
        else:
            # we are newly configuring this API
            plist = build_spec.providers.get(api_str, [])
            plist = plist if isinstance(plist, list) else [plist]

            if not plist:
                raise ValueError(f"No provider configured for API {api_str}?")

            logger.info(f"Configuring API `{api_str}`...")
            updated_providers = []
            for i, provider in enumerate(plist):
                if i >= 1:
                    others = ", ".join(p.provider_type for p in plist[i:])
                    logger.info(
                        f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
                    )
                    break

                logger.info(f"> Configuring provider `({provider.provider_type})`")
                pid = provider.provider_type.split("::")[-1]
                updated_providers.append(
                    configure_single_provider(
                        provider_registry[api],
                        Provider(
                            provider_id=(f"{pid}-{i:02d}" if len(plist) > 1 else pid),
                            provider_type=provider.provider_type,
                            config={},
                        ),
                    )
                )
                logger.info("")

        config.providers[api_str] = updated_providers

    return config


def upgrade_from_routing_table(
    config_dict: dict[str, Any],
) -> dict[str, Any]:
    def get_providers(entries):
        return [
            Provider(
                provider_id=(f"{entry['provider_type']}-{i:02d}" if len(entries) > 1 else entry["provider_type"]),
                provider_type=entry["provider_type"],
                config=entry["config"],
            )
            for i, entry in enumerate(entries)
        ]

    providers_by_api = {}

    routing_table = config_dict.get("routing_table", {})
    for api_str, entries in routing_table.items():
        providers = get_providers(entries)
        providers_by_api[api_str] = providers

    provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
    if provider_map:
        for api_str, provider in provider_map.items():
            if isinstance(provider, dict) and "provider_type" in provider:
                providers_by_api[api_str] = [
                    Provider(
                        provider_id=f"{provider['provider_type']}",
                        provider_type=provider["provider_type"],
                        config=provider["config"],
                    )
                ]

    config_dict["providers"] = providers_by_api

    config_dict.pop("routing_table", None)
    config_dict.pop("api_providers", None)
    config_dict.pop("provider_map", None)

    config_dict["apis"] = config_dict["apis_to_serve"]
    config_dict.pop("apis_to_serve", None)

    # Add default storage config if not present
    if "storage" not in config_dict:
        config_dict["storage"] = {
            "backends": {
                "kv_default": {
                    "type": "kv_sqlite",
                    "db_path": "~/.llama/kvstore.db",
                },
                "sql_default": {
                    "type": "sql_sqlite",
                    "db_path": "~/.llama/sql_store.db",
                },
            },
            "stores": {
                "metadata": {
                    "namespace": "registry",
                    "backend": "kv_default",
                },
                "inference": {
                    "table_name": "inference_store",
                    "backend": "sql_default",
                    "max_write_queue_size": 10000,
                    "num_writers": 4,
                },
                "conversations": {
                    "table_name": "openai_conversations",
                    "backend": "sql_default",
                },
            },
        }

    return config_dict


def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
    version = config_dict.get("version", None)
    if version == LLAMA_STACK_RUN_CONFIG_VERSION:
        processed_config_dict = replace_env_vars(config_dict)
        return StackRunConfig(**cast_image_name_to_string(processed_config_dict))

    if "routing_table" in config_dict:
        logger.info("Upgrading config...")
        config_dict = upgrade_from_routing_table(config_dict)

    config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION

    if not config_dict.get("external_providers_dir", None):
        config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR

    processed_config_dict = replace_env_vars(config_dict)
    return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
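
A sketch of the legacy-config upgrade path above (not part of the commit; the old-format dict is illustrative):

# Sketch: an old "routing_table" config is rewritten into the current
# providers/apis shape, with default storage backends filled in.
from llama_stack.core.configure import upgrade_from_routing_table

old = {
    "apis_to_serve": ["inference"],
    "routing_table": {
        "inference": [{"provider_type": "remote::ollama", "config": {}}],
    },
}
new = upgrade_from_routing_table(old)
assert "routing_table" not in new and new["apis"] == ["inference"]
assert new["providers"]["inference"][0].provider_type == "remote::ollama"
assert "storage" in new  # default kv/sql backends were added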

5  src/llama_stack/core/conversations/__init__.py  Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

314  src/llama_stack/core/conversations/conversations.py  Normal file
@@ -0,0 +1,314 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import secrets
import time
from typing import Any, Literal

from pydantic import BaseModel, TypeAdapter

from llama_stack.apis.conversations.conversations import (
    Conversation,
    ConversationDeletedResource,
    ConversationItem,
    ConversationItemDeletedResource,
    ConversationItemInclude,
    ConversationItemList,
    Conversations,
    Metadata,
)
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl

logger = get_logger(name=__name__, category="openai_conversations")


class ConversationServiceConfig(BaseModel):
    """Configuration for the built-in conversation service.

    :param run_config: Stack run configuration for resolving persistence
    :param policy: Access control rules
    """

    run_config: StackRunConfig
    policy: list[AccessRule] = []


async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
    """Get the conversation service implementation."""
    impl = ConversationServiceImpl(config, deps)
    await impl.initialize()
    return impl


class ConversationServiceImpl(Conversations):
    """Built-in conversation service implementation using AuthorizedSqlStore."""

    def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
        self.config = config
        self.deps = deps
        self.policy = config.policy

        # Use conversations store reference from run config
        conversations_ref = config.run_config.storage.stores.conversations
        if not conversations_ref:
            raise ValueError("storage.stores.conversations must be configured in run config")

        base_sql_store = sqlstore_impl(conversations_ref)
        self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)

    async def initialize(self) -> None:
        """Initialize the store and create tables."""
        await self.sql_store.create_table(
            "openai_conversations",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "created_at": ColumnType.INTEGER,
                "items": ColumnType.JSON,
                "metadata": ColumnType.JSON,
            },
        )

        await self.sql_store.create_table(
            "conversation_items",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "conversation_id": ColumnType.STRING,
                "created_at": ColumnType.INTEGER,
                "item_data": ColumnType.JSON,
            },
        )

    async def create_conversation(
        self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
    ) -> Conversation:
        """Create a conversation."""
        random_bytes = secrets.token_bytes(24)
        conversation_id = f"conv_{random_bytes.hex()}"
        created_at = int(time.time())

        record_data = {
            "id": conversation_id,
            "created_at": created_at,
            "items": [],
            "metadata": metadata,
        }

        await self.sql_store.insert(
            table="openai_conversations",
            data=record_data,
        )

        if items:
            item_records = []
            for item in items:
                item_dict = item.model_dump()
                item_id = self._get_or_generate_item_id(item, item_dict)

                item_record = {
                    "id": item_id,
                    "conversation_id": conversation_id,
                    "created_at": created_at,
                    "item_data": item_dict,
                }

                item_records.append(item_record)

            await self.sql_store.insert(table="conversation_items", data=item_records)

        conversation = Conversation(
            id=conversation_id,
            created_at=created_at,
            metadata=metadata,
            object="conversation",
        )

        logger.debug(f"Created conversation {conversation_id}")
        return conversation

    async def get_conversation(self, conversation_id: str) -> Conversation:
        """Get a conversation with the given ID."""
        record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})

        if record is None:
            raise ValueError(f"Conversation {conversation_id} not found")

        return Conversation(
            id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
        )

    async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
        """Update a conversation's metadata with the given ID."""
        await self.sql_store.update(
            table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
        )

        return await self.get_conversation(conversation_id)

    async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
        """Delete a conversation with the given ID."""
        await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})

        logger.debug(f"Deleted conversation {conversation_id}")
        return ConversationDeletedResource(id=conversation_id)

    def _validate_conversation_id(self, conversation_id: str) -> None:
        """Validate conversation ID format."""
        if not conversation_id.startswith("conv_"):
            raise ValueError(
                f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
            )

    def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
        """Get existing item ID or generate one if missing."""
        if item.id is None:
            random_bytes = secrets.token_bytes(24)
            if item.type == "message":
                item_id = f"msg_{random_bytes.hex()}"
            else:
                item_id = f"item_{random_bytes.hex()}"
            item_dict["id"] = item_id
            return item_id
        return item.id

    async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
        """Validate conversation ID and return the conversation if it exists."""
        self._validate_conversation_id(conversation_id)
        return await self.get_conversation(conversation_id)

    async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
        """Create (add) items to a conversation."""
        await self._get_validated_conversation(conversation_id)

        created_items = []
        base_time = int(time.time())

        for i, item in enumerate(items):
            item_dict = item.model_dump()
            item_id = self._get_or_generate_item_id(item, item_dict)

            # make each timestamp unique to maintain order
            created_at = base_time + i

            item_record = {
                "id": item_id,
                "conversation_id": conversation_id,
                "created_at": created_at,
                "item_data": item_dict,
            }

            # TODO: Add support for upsert in sql_store; this first attempts an insert, and
            # falls back to an update if the ID already exists
            try:
                await self.sql_store.insert(table="conversation_items", data=item_record)
            except Exception:
                # If insert fails due to ID conflict, update existing record
                await self.sql_store.update(
                    table="conversation_items",
                    data={"created_at": created_at, "item_data": item_dict},
                    where={"id": item_id},
                )

            created_items.append(item_dict)

        logger.debug(f"Created {len(created_items)} items in conversation {conversation_id}")

        # Convert created items (dicts) to proper ConversationItem types
        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
        response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]

        return ConversationItemList(
            data=response_items,
            first_id=created_items[0]["id"] if created_items else None,
            last_id=created_items[-1]["id"] if created_items else None,
            has_more=False,
        )

    async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
        """Retrieve a conversation item."""
        if not conversation_id:
            raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
        if not item_id:
            raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")

        # Get item from conversation_items table
        record = await self.sql_store.fetch_one(
            table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
        )

        if record is None:
            raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")

        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
        return adapter.validate_python(record["item_data"])

    async def list_items(
        self,
        conversation_id: str,
        after: str | None = None,
        include: list[ConversationItemInclude] | None = None,
        limit: int | None = None,
        order: Literal["asc", "desc"] | None = None,
    ) -> ConversationItemList:
        """List items in the conversation."""
        if not conversation_id:
            raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")

        # check if conversation exists
        await self.get_conversation(conversation_id)

        result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
        records = result.data

        if order is not None and order == "asc":
            records.sort(key=lambda x: x["created_at"])
        else:
            records.sort(key=lambda x: x["created_at"], reverse=True)

        actual_limit = limit or 20

        records = records[:actual_limit]
        items = [record["item_data"] for record in records]

        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
        response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]

        first_id = response_items[0].id if response_items else None
        last_id = response_items[-1].id if response_items else None

        return ConversationItemList(
            data=response_items,
            first_id=first_id,
            last_id=last_id,
            has_more=False,
        )

    async def openai_delete_conversation_item(
        self, conversation_id: str, item_id: str
    ) -> ConversationItemDeletedResource:
        """Delete a conversation item."""
        if not conversation_id:
            raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
        if not item_id:
            raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")

        _ = await self._get_validated_conversation(conversation_id)

        record = await self.sql_store.fetch_one(
            table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
        )

        if record is None:
            raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")

        await self.sql_store.delete(
            table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
        )

        logger.debug(f"Deleted item {item_id} from conversation {conversation_id}")
        return ConversationItemDeletedResource(id=item_id)
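
A sketch of the ID scheme used by the service above (not part of the commit; it just restates the generation logic):

# Sketch: 24 random bytes, hex-encoded (48 characters), with a prefix that
# depends on what is being identified.
import secrets

token = secrets.token_bytes(24)
conversation_id = f"conv_{token.hex()}"  # conversations
message_id = f"msg_{secrets.token_bytes(24).hex()}"  # items with type == "message"
other_item_id = f"item_{secrets.token_bytes(24).hex()}"  # all other item types
assert len(token.hex()) == 48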
629
src/llama_stack/core/datatypes.py
Normal file
629
src/llama_stack/core/datatypes.py
Normal file
|
|
@ -0,0 +1,629 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Literal, Self
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Dataset, DatasetInput
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model, ModelInput
|
||||
from llama_stack.apis.resource import Resource
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
from llama_stack.apis.shields import Shield, ShieldInput
|
||||
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
KVStoreReference,
|
||||
StorageBackendType,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.log import LoggingConfig
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = 2
|
||||
|
||||
|
||||
RoutingKey = str | list[str]
|
||||
|
||||
|
||||
class RegistryEntrySource(StrEnum):
|
||||
via_register_api = "via_register_api"
|
||||
listed_from_provider = "listed_from_provider"
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
principal: str
|
||||
# further attributes that may be used for access control decisions
|
||||
attributes: dict[str, list[str]] | None = None
|
||||
|
||||
def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
|
||||
super().__init__(principal=principal, attributes=attributes)
|
||||
|
||||
|
||||
class ResourceWithOwner(Resource):
|
||||
"""Extension of Resource that adds an optional owner, i.e. the user that created the
|
||||
resource. This can be used to constrain access to the resource."""
|
||||
|
||||
owner: User | None = None
|
||||
source: RegistryEntrySource = RegistryEntrySource.via_register_api
|
||||
|
||||
|
||||
# Use the extended Resource for all routable objects
|
||||
class ModelWithOwner(Model, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ShieldWithOwner(Shield, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class VectorStoreWithOwner(VectorStore, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class DatasetWithOwner(Dataset, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
RoutableObject = Model | Shield | VectorStore | Dataset | ScoringFn | Benchmark | ToolGroup
|
||||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
ModelWithOwner
|
||||
| ShieldWithOwner
|
||||
| VectorStoreWithOwner
|
||||
| DatasetWithOwner
|
||||
| ScoringFnWithOwner
|
||||
| BenchmarkWithOwner
|
||||
| ToolGroupWithOwner,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
RoutedProtocol = Inference | Safety | VectorIO | DatasetIO | Scoring | Eval | ToolRuntime


# Example: /inference, /safety
class AutoRoutedProviderSpec(ProviderSpec):
    provider_type: str = "router"
    config_class: str = ""

    container_image: str | None = None
    routing_table_api: Api
    module: str
    provider_data_validator: str | None = Field(
        default=None,
    )


# Example: /models, /shields
class RoutingTableProviderSpec(ProviderSpec):
    provider_type: str = "routing_table"
    config_class: str = ""
    container_image: str | None = None

    router_api: Api
    module: str
    pip_packages: list[str] = Field(default_factory=list)


class Provider(BaseModel):
    # provider_id of None means that the provider is not enabled - this happens
    # when the provider is gated behind a conditional environment variable that is unset
    provider_id: str | None
    provider_type: str
    config: dict[str, Any] = {}
    module: str | None = Field(
        default=None,
        description="""
Fully-qualified name of the external provider module to import. The module is expected to have:

 - `get_adapter_impl(config, deps)`: returns the adapter implementation

Example: `module: ramalama_stack`
""",
    )


class BuildProvider(BaseModel):
    provider_type: str
    module: str | None = Field(
        default=None,
        description="""
Fully-qualified name of the external provider module to import. The module is expected to have:

 - `get_adapter_impl(config, deps)`: returns the adapter implementation

Example: `module: ramalama_stack`
""",
    )


class DistributionSpec(BaseModel):
    description: str | None = Field(
        default="",
        description="Description of the distribution",
    )
    container_image: str | None = None
    providers: dict[str, list[BuildProvider]] = Field(
        default_factory=dict,
        description="""
Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.
""",
    )


class TelemetryConfig(BaseModel):
    """
    Configuration for telemetry.

    Llama Stack uses OpenTelemetry for telemetry. Please refer to https://opentelemetry.io/docs/languages/sdk-configuration/
    for the environment variables that configure the OpenTelemetry SDK.

    Example:
    ```bash
    OTEL_SERVICE_NAME=llama-stack OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 uv run llama stack run starter
    ```
    """

    enabled: bool = Field(default=False, description="enable or disable telemetry")


class OAuth2JWKSConfig(BaseModel):
    # The JWKS URI for collecting public keys
    uri: str
    token: str | None = Field(default=None, description="token to authorise access to jwks")
    key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")


class OAuth2IntrospectionConfig(BaseModel):
    url: str
    client_id: str
    client_secret: str
    send_secret_in_body: bool = False


class AuthProviderType(StrEnum):
    """Supported authentication provider types."""

    OAUTH2_TOKEN = "oauth2_token"
    GITHUB_TOKEN = "github_token"
    CUSTOM = "custom"
    KUBERNETES = "kubernetes"


class OAuth2TokenAuthConfig(BaseModel):
    """Configuration for OAuth2 token authentication."""

    type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
    audience: str = Field(default="llama-stack")
    verify_tls: bool = Field(default=True)
    tls_cafile: Path | None = Field(default=None)
    issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
    claims_mapping: dict[str, str] = Field(
        default_factory=lambda: {
            "sub": "roles",
            "username": "roles",
            "groups": "teams",
            "team": "teams",
            "project": "projects",
            "tenant": "namespaces",
            "namespace": "namespaces",
        },
    )
    jwks: OAuth2JWKSConfig | None = Field(default=None, description="JWKS configuration")
    introspection: OAuth2IntrospectionConfig | None = Field(
        default=None, description="OAuth2 introspection configuration"
    )

    @field_validator("claims_mapping")
    @classmethod
    def validate_claims_mapping(cls, v):
        for key, value in v.items():
            if not value:
                raise ValueError(f"claims_mapping value cannot be empty: {key}")
        return v

    @model_validator(mode="after")
    def validate_mode(self) -> Self:
        if not self.jwks and not self.introspection:
            raise ValueError("One of jwks or introspection must be configured")
        if self.jwks and self.introspection:
            raise ValueError("At present only one of jwks or introspection should be configured")
        return self
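A minimal sketch of configuring token validation via JWKS (the issuer URL is a placeholder, and the import path assumes these classes live in `llama_stack.core.datatypes` as the other files in this change import them); supplying both `jwks` and `introspection`, or neither, fails the `validate_mode` check above:

```python
from llama_stack.core.datatypes import OAuth2JWKSConfig, OAuth2TokenAuthConfig

auth = OAuth2TokenAuthConfig(
    jwks=OAuth2JWKSConfig(uri="https://issuer.example.com/.well-known/jwks.json"),  # placeholder URL
)
assert auth.introspection is None  # exactly one of jwks / introspection may be set
```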


class CustomAuthConfig(BaseModel):
    """Configuration for custom authentication."""

    type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
    endpoint: str = Field(
        ...,
        description="Custom authentication endpoint URL",
    )


class GitHubTokenAuthConfig(BaseModel):
    """Configuration for GitHub token authentication."""

    type: Literal[AuthProviderType.GITHUB_TOKEN] = AuthProviderType.GITHUB_TOKEN
    github_api_base_url: str = Field(
        default="https://api.github.com",
        description="Base URL for GitHub API (use https://api.github.com for public GitHub)",
    )
    claims_mapping: dict[str, str] = Field(
        default_factory=lambda: {
            "login": "roles",
            "organizations": "teams",
        },
        description="Mapping from GitHub user fields to access attributes",
    )


class KubernetesAuthProviderConfig(BaseModel):
    """Configuration for Kubernetes authentication provider."""

    type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES
    api_server_url: str = Field(
        default="https://kubernetes.default.svc",
        description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)",
    )
    verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates")
    tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification")
    claims_mapping: dict[str, str] = Field(
        default_factory=lambda: {
            "username": "roles",
            "groups": "roles",
        },
        description="Mapping of Kubernetes user claims to access attributes",
    )

    @field_validator("api_server_url")
    @classmethod
    def validate_api_server_url(cls, v):
        parsed = urlparse(v)
        if not parsed.scheme or not parsed.netloc:
            raise ValueError(f"api_server_url must be a valid URL with scheme and host: {v}")
        if parsed.scheme not in ["http", "https"]:
            raise ValueError(f"api_server_url scheme must be http or https: {v}")
        return v

    @field_validator("claims_mapping")
    @classmethod
    def validate_claims_mapping(cls, v):
        for key, value in v.items():
            if not value:
                raise ValueError(f"claims_mapping value cannot be empty: {key}")
        return v


AuthProviderConfig = Annotated[
    OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig | KubernetesAuthProviderConfig,
    Field(discriminator="type"),
]
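As a quick illustration (a sketch using pydantic's `TypeAdapter`, which this module already depends on), the `type` field acts as the discriminator when parsing raw config:

```python
from pydantic import TypeAdapter

from llama_stack.core.datatypes import AuthProviderConfig, GitHubTokenAuthConfig

cfg = TypeAdapter(AuthProviderConfig).validate_python({"type": "github_token"})
assert isinstance(cfg, GitHubTokenAuthConfig)  # all other fields take their defaults
```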


class AuthenticationConfig(BaseModel):
    """Top-level authentication configuration."""

    provider_config: AuthProviderConfig = Field(
        ...,
        description="Authentication provider configuration",
    )
    access_policy: list[AccessRule] = Field(
        default=[],
        description="Rules for determining access to resources",
    )


class AuthenticationRequiredError(Exception):
    pass


class QualifiedModel(BaseModel):
    """A qualified model identifier, consisting of a provider ID and a model ID."""

    provider_id: str
    model_id: str


class VectorStoresConfig(BaseModel):
    """Configuration for vector stores in the stack."""

    default_provider_id: str | None = Field(
        default=None,
        description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
    )
    default_embedding_model: QualifiedModel | None = Field(
        default=None,
        description="Default embedding model configuration for vector stores.",
    )


class SafetyConfig(BaseModel):
    """Configuration for the default moderations model."""

    default_shield_id: str | None = Field(
        default=None,
        description="ID of the shield to use when `model` is not specified in the `moderations` API request.",
    )


class QuotaPeriod(StrEnum):
    DAY = "day"


class QuotaConfig(BaseModel):
    kvstore: KVStoreReference = Field(description="Config for KV store backend (SQLite only for now)")
    anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
    authenticated_max_requests: int = Field(
        default=1000, description="Max requests for authenticated clients per period"
    )
    period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")


class CORSConfig(BaseModel):
    allow_origins: list[str] = Field(default_factory=list)
    allow_origin_regex: str | None = Field(default=None)
    allow_methods: list[str] = Field(default=["OPTIONS"])
    allow_headers: list[str] = Field(default_factory=list)
    allow_credentials: bool = Field(default=False)
    expose_headers: list[str] = Field(default_factory=list)
    max_age: int = Field(default=600, ge=0)

    @model_validator(mode="after")
    def validate_credentials_config(self) -> Self:
        if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
            raise ValueError("Cannot use wildcard origins with credentials enabled")
        return self


def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
    if cors_config is False or cors_config is None:
        return None

    if cors_config is True:
        # dev mode: allow localhost on any port
        return CORSConfig(
            allow_origins=[],
            allow_origin_regex=r"https?://localhost:\d+",
            allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
            allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
        )

    if isinstance(cors_config, CORSConfig):
        return cors_config

    raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")


class RegisteredResources(BaseModel):
    """Registry of resources available in the distribution."""

    models: list[ModelInput] = Field(default_factory=list)
    shields: list[ShieldInput] = Field(default_factory=list)
    vector_stores: list[VectorStoreInput] = Field(default_factory=list)
    datasets: list[DatasetInput] = Field(default_factory=list)
    scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
    benchmarks: list[BenchmarkInput] = Field(default_factory=list)
    tool_groups: list[ToolGroupInput] = Field(default_factory=list)


class ServerConfig(BaseModel):
    port: int = Field(
        default=8321,
        description="Port to listen on",
        ge=1024,
        le=65535,
    )
    tls_certfile: str | None = Field(
        default=None,
        description="Path to TLS certificate file for HTTPS",
    )
    tls_keyfile: str | None = Field(
        default=None,
        description="Path to TLS key file for HTTPS",
    )
    tls_cafile: str | None = Field(
        default=None,
        description="Path to TLS CA file for HTTPS with mutual TLS authentication",
    )
    auth: AuthenticationConfig | None = Field(
        default=None,
        description="Authentication configuration for the server",
    )
    host: str | None = Field(
        default=None,
        description="The host the server should listen on",
    )
    quota: QuotaConfig | None = Field(
        default=None,
        description="Per-client request quota configuration",
    )
    cors: bool | CORSConfig | None = Field(
        default=None,
        description="CORS configuration for cross-origin requests. Can be:\n"
        "- true: Enable localhost CORS for development\n"
        "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
    )


class StackRunConfig(BaseModel):
    version: int = LLAMA_STACK_RUN_CONFIG_VERSION

    image_name: str = Field(
        ...,
        description="""
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
this could be just a hash.
""",
    )
    container_image: str | None = Field(
        default=None,
        description="Reference to the container image if this package refers to a container",
    )
    apis: list[str] = Field(
        default_factory=list,
        description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
    )

    providers: dict[str, list[Provider]] = Field(
        description="""
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary.
""",
    )
    storage: StorageConfig = Field(
        description="Catalog of named storage backends and references available to the stack",
    )

    registered_resources: RegisteredResources = Field(
        default_factory=RegisteredResources,
        description="Registry of resources available in the distribution",
    )

    logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")

    telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig, description="Configuration for telemetry")

    server: ServerConfig = Field(
        default_factory=ServerConfig,
        description="Configuration for the HTTP(S) server",
    )

    external_providers_dir: Path | None = Field(
        default=None,
        description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
    )

    external_apis_dir: Path | None = Field(
        default=None,
        description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
    )

    vector_stores: VectorStoresConfig | None = Field(
        default=None,
        description="Configuration for vector stores, including the default embedding model",
    )

    safety: SafetyConfig | None = Field(
        default=None,
        description="Configuration for the default moderations model",
    )

    @field_validator("external_providers_dir")
    @classmethod
    def validate_external_providers_dir(cls, v):
        if v is None:
            return None
        if isinstance(v, str):
            return Path(v)
        return v

    @model_validator(mode="after")
    def validate_server_stores(self) -> "StackRunConfig":
        backend_map = self.storage.backends
        stores = self.storage.stores
        kv_backends = {
            name
            for name, cfg in backend_map.items()
            if cfg.type
            in {
                StorageBackendType.KV_REDIS,
                StorageBackendType.KV_SQLITE,
                StorageBackendType.KV_POSTGRES,
                StorageBackendType.KV_MONGODB,
            }
        }
        sql_backends = {
            name
            for name, cfg in backend_map.items()
            if cfg.type in {StorageBackendType.SQL_SQLITE, StorageBackendType.SQL_POSTGRES}
        }

        def _ensure_backend(reference, expected_set, store_name: str) -> None:
            if reference is None:
                return
            backend_name = reference.backend
            if backend_name not in backend_map:
                raise ValueError(
                    f"{store_name} references unknown backend '{backend_name}'. "
                    f"Available backends: {sorted(backend_map)}"
                )
            if backend_name not in expected_set:
                raise ValueError(
                    f"{store_name} references backend '{backend_name}' of type "
                    f"'{backend_map[backend_name].type.value}', but a backend of type "
                    f"{'kv_*' if expected_set is kv_backends else 'sql_*'} is required."
                )

        _ensure_backend(stores.metadata, kv_backends, "storage.stores.metadata")
        _ensure_backend(stores.inference, sql_backends, "storage.stores.inference")
        _ensure_backend(stores.conversations, sql_backends, "storage.stores.conversations")
        _ensure_backend(stores.responses, sql_backends, "storage.stores.responses")
        _ensure_backend(stores.prompts, kv_backends, "storage.stores.prompts")
        return self


class BuildConfig(BaseModel):
    version: int = LLAMA_STACK_BUILD_CONFIG_VERSION

    distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers.")
    image_type: str = Field(
        default="venv",
        description="Type of package to build (container | venv)",
    )
    image_name: str | None = Field(
        default=None,
        description="Name of the distribution to build",
    )
    external_providers_dir: Path | None = Field(
        default=None,
        description="Path to directory containing external provider implementations. The provider packages will be resolved from this directory. "
        "pip_packages MUST contain the provider package name.",
    )
    additional_pip_packages: list[str] = Field(
        default_factory=list,
        description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
    )
    external_apis_dir: Path | None = Field(
        default=None,
        description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
    )

    @field_validator("external_providers_dir")
    @classmethod
    def validate_external_providers_dir(cls, v):
        if v is None:
            return None
        if isinstance(v, str):
            return Path(v)
        return v
276 src/llama_stack/core/distribution.py Normal file
@@ -0,0 +1,276 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import glob
import importlib
import os
from typing import Any

import yaml
from pydantic import BaseModel

from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    Api,
    InlineProviderSpec,
    ProviderSpec,
    RemoteProviderSpec,
)

logger = get_logger(name=__name__, category="core")


INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}


def stack_apis() -> list[Api]:
    return list(Api)


class AutoRoutedApiInfo(BaseModel):
    routing_table_api: Api
    router_api: Api


def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
    return [
        AutoRoutedApiInfo(
            routing_table_api=Api.models,
            router_api=Api.inference,
        ),
        AutoRoutedApiInfo(
            routing_table_api=Api.shields,
            router_api=Api.safety,
        ),
        AutoRoutedApiInfo(
            routing_table_api=Api.datasets,
            router_api=Api.datasetio,
        ),
        AutoRoutedApiInfo(
            routing_table_api=Api.scoring_functions,
            router_api=Api.scoring,
        ),
        AutoRoutedApiInfo(
            routing_table_api=Api.benchmarks,
            router_api=Api.eval,
        ),
        AutoRoutedApiInfo(
            routing_table_api=Api.tool_groups,
            router_api=Api.tool_runtime,
        ),
        AutoRoutedApiInfo(
            routing_table_api=Api.vector_stores,
            router_api=Api.vector_io,
        ),
    ]


def providable_apis() -> list[Api]:
    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
    return [api for api in Api if api not in routing_table_apis and api not in INTERNAL_APIS]


def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
    spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
    return spec


def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
    spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
    return spec


def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
    """Get the provider registry, optionally including external providers.

    This function loads both built-in providers and external providers from YAML files or from their provided modules.
    External providers are loaded from a directory structure like:

        providers.d/
          remote/
            inference/
              custom_ollama.yaml
              vllm.yaml
            vector_io/
              qdrant.yaml
            safety/
              llama-guard.yaml
          inline/
            inference/
              custom_ollama.yaml
              vllm.yaml
            vector_io/
              qdrant.yaml
            safety/
              llama-guard.yaml

    This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
    So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
    There is special handling for all of the potential cases this method can be called from.

    Args:
        config: Optional object containing the external providers directory path

    Returns:
        A dictionary mapping APIs to their available providers

    Raises:
        FileNotFoundError: If the external providers directory doesn't exist
        ValueError: If any provider spec is invalid
    """

    registry: dict[Api, dict[str, ProviderSpec]] = {}
    for api in providable_apis():
        name = api.name.lower()
        logger.debug(f"Importing module {name}")
        try:
            module = importlib.import_module(f"llama_stack.providers.registry.{name}")
            registry[api] = {a.provider_type: a for a in module.available_providers()}
        except ImportError as e:
            logger.warning(f"Failed to import module {name}: {e}")

    # Refresh providable APIs with external APIs if any
    external_apis = load_external_apis(config)
    for api, api_spec in external_apis.items():
        name = api_spec.name.lower()
        logger.info(f"Importing external API {name} module {api_spec.module}")
        try:
            module = importlib.import_module(api_spec.module)
            registry[api] = {a.provider_type: a for a in module.available_providers()}
        except (ImportError, AttributeError) as e:
            # Populate the registry with an empty dict to avoid breaking the provider registry.
            # This assumes that the in-tree provider(s) are not available for this API, which means
            # that users will need to use external providers for this API.
            registry[api] = {}
            logger.error(
                f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
                "Install the API package to load any in-tree providers for this API."
            )

    # Check if config has external providers
    if config:
        if hasattr(config, "external_providers_dir") and config.external_providers_dir:
            registry = get_external_providers_from_dir(registry, config)
        # otherwise, check for modules in each provider
        registry = get_external_providers_from_module(
            registry=registry,
            config=config,
            building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)),
        )

    return registry
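For instance (a sketch; with no config argument this returns only the built-in providers):

```python
from llama_stack.core.distribution import get_provider_registry
from llama_stack.providers.datatypes import Api

registry = get_provider_registry()
inference_provider_types = sorted(registry.get(Api.inference, {}).keys())
print(inference_provider_types)  # e.g. a list of "inline::..." / "remote::..." type keys
```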


def get_external_providers_from_dir(
    registry: dict[Api, dict[str, ProviderSpec]], config
) -> dict[Api, dict[str, ProviderSpec]]:
    logger.warning(
        "Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
    )
    external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
    if not os.path.exists(external_providers_dir):
        raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
    logger.info(f"Loading external providers from {external_providers_dir}")

    for api in providable_apis():
        api_name = api.name.lower()

        # Process both remote and inline providers
        for provider_type in ["remote", "inline"]:
            api_dir = os.path.join(external_providers_dir, provider_type, api_name)
            if not os.path.exists(api_dir):
                logger.debug(f"No {provider_type} provider directory found for {api_name}")
                continue

            # Look for provider spec files in the API directory
            for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
                provider_name = os.path.splitext(os.path.basename(spec_path))[0]
                logger.info(f"Loading {provider_type} provider spec from {spec_path}")

                try:
                    with open(spec_path) as f:
                        spec_data = yaml.safe_load(f)

                    if provider_type == "remote":
                        spec = _load_remote_provider_spec(spec_data, api)
                        provider_type_key = f"remote::{provider_name}"
                    else:
                        spec = _load_inline_provider_spec(spec_data, api, provider_name)
                        provider_type_key = f"inline::{provider_name}"

                    logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
                    if provider_type_key in registry[api]:
                        logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
                    registry[api][provider_type_key] = spec
                    logger.info(f"Successfully loaded external provider {provider_type_key}")
                except yaml.YAMLError as yaml_err:
                    logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
                    raise yaml_err
                except Exception as e:
                    logger.error(f"Failed to load provider spec from {spec_path}: {e}")
                    raise e

    return registry


def get_external_providers_from_module(
    registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
) -> dict[Api, dict[str, ProviderSpec]]:
    provider_list = None
    if isinstance(config, BuildConfig):
        provider_list = config.distribution_spec.providers.items()
    else:
        provider_list = config.providers.items()
    if provider_list is None:
        logger.warning("Could not get list of providers from config")
        return registry
    for provider_api, providers in provider_list:
        for provider in providers:
            if not hasattr(provider, "module") or provider.module is None:
                continue
            # get provider using module
            try:
                if not building:
                    package_name = provider.module.split("==")[0]
                    module = importlib.import_module(f"{package_name}.provider")
                    # if the config class is wrong, you will get an error saying the module could not be imported
                    spec = module.get_provider_spec()
                else:
                    # Pass in a partially filled-out provider spec to satisfy the registry -- knowing we will be
                    # overwriting it later upon build and run. In the case we are building, we cannot import this
                    # module because it has not been installed yet.
                    spec = ProviderSpec(
                        api=Api(provider_api),
                        provider_type=provider.provider_type,
                        is_external=True,
                        module=provider.module,
                        config_class="",
                    )
                provider_type = provider.provider_type
                if isinstance(spec, list):
                    # Optionally allow people to return both inline and remote provider specs as a list.
                    # With the old method, users could pass in directories of specs using overlapping code;
                    # we want to preserve that flexibility in this method.
                    logger.info(
                        f"Detected a list of external provider specs from {provider.module}, adding all to the registry"
                    )
                    for provider_spec in spec:
                        if provider_spec.provider_type != provider.provider_type:
                            continue
                        logger.info(f"Adding {provider.provider_type} to registry")
                        registry[Api(provider_api)][provider.provider_type] = provider_spec
                else:
                    registry[Api(provider_api)][provider_type] = spec
            except ModuleNotFoundError as exc:
                raise ValueError(
                    "get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
                ) from exc
            except Exception as e:
                logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
                raise e
    return registry
54 src/llama_stack/core/external.py Normal file
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import yaml

from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="core")


def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
    """Load external API specifications from the configured directory.

    Args:
        config: StackRunConfig or BuildConfig containing the external APIs directory path

    Returns:
        A dictionary mapping API names to their specifications
    """
    if not config or not config.external_apis_dir:
        return {}

    external_apis_dir = config.external_apis_dir.expanduser().resolve()
    if not external_apis_dir.is_dir():
        logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
        return {}

    logger.info(f"Loading external APIs from {external_apis_dir}")
    external_apis: dict[Api, ExternalApiSpec] = {}

    # Look for YAML files in the external APIs directory
    for yaml_path in external_apis_dir.glob("*.yaml"):
        try:
            with open(yaml_path) as f:
                spec_data = yaml.safe_load(f)

            spec = ExternalApiSpec(**spec_data)
            api = Api.add(spec.name)
            logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
            external_apis[api] = spec
        except yaml.YAMLError as yaml_err:
            logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
            raise
        except Exception:
            logger.exception(f"Failed to load external API spec from {yaml_path}")
            raise

    return external_apis
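A usage sketch, assuming `run_config` is a `StackRunConfig` (defined elsewhere) whose `external_apis_dir` points at a directory of `*.yaml` API specs:

```python
from llama_stack.core.external import load_external_apis

specs = load_external_apis(run_config)  # run_config is an assumed, pre-built StackRunConfig
for api, spec in specs.items():
    print(f"{spec.name}: implemented by module {spec.module}")
```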
42 src/llama_stack/core/id_generation.py Normal file
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Callable

IdFactory = Callable[[], str]
IdOverride = Callable[[str, IdFactory], str]

_id_override: IdOverride | None = None


def generate_object_id(kind: str, factory: IdFactory) -> str:
    """Generate an identifier for the given kind using the provided factory.

    Allows tests to override ID generation deterministically by installing an
    override callback via :func:`set_id_override`.
    """

    override = _id_override
    if override is not None:
        return override(kind, factory)
    return factory()


def set_id_override(override: IdOverride) -> IdOverride | None:
    """Install an override used to generate deterministic identifiers."""

    global _id_override

    previous = _id_override
    _id_override = override
    return previous


def reset_id_override(previous: IdOverride | None) -> None:
    """Restore the previous override returned by :func:`set_id_override`."""

    global _id_override
    _id_override = previous
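A sketch of the intended test-time usage: install a deterministic override, then restore the previous one:

```python
from llama_stack.core.id_generation import (
    generate_object_id,
    reset_id_override,
    set_id_override,
)

previous = set_id_override(lambda kind, factory: f"{kind}-fixed-id")
try:
    # The factory is bypassed while the override is installed
    assert generate_object_id("prompt", lambda: "random-id") == "prompt-fixed-id"
finally:
    reset_id_override(previous)
```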
86 src/llama_stack/core/inspect.py Normal file
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from importlib.metadata import version

from pydantic import BaseModel

from llama_stack.apis.inspect import (
    HealthInfo,
    Inspect,
    ListRoutesResponse,
    RouteInfo,
    VersionInfo,
)
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus


class DistributionInspectConfig(BaseModel):
    run_config: StackRunConfig


async def get_provider_impl(config, deps):
    impl = DistributionInspectImpl(config, deps)
    await impl.initialize()
    return impl


class DistributionInspectImpl(Inspect):
    def __init__(self, config: DistributionInspectConfig, deps):
        self.config = config
        self.deps = deps

    async def initialize(self) -> None:
        pass

    async def list_routes(self) -> ListRoutesResponse:
        run_config: StackRunConfig = self.config.run_config

        ret = []
        external_apis = load_external_apis(run_config)
        all_endpoints = get_all_api_routes(external_apis)
        for api, endpoints in all_endpoints.items():
            # Always include provider and inspect APIs, filter others based on run config
            if api.value in ["providers", "inspect"]:
                ret.extend(
                    [
                        RouteInfo(
                            route=e.path,
                            method=next(iter([m for m in e.methods if m != "HEAD"])),
                            provider_types=[],  # These APIs don't have "real" providers - they're internal to the stack
                        )
                        for e, _ in endpoints
                        if e.methods is not None
                    ]
                )
            else:
                providers = run_config.providers.get(api.value, [])
                if providers:  # Only process if there are providers for this API
                    ret.extend(
                        [
                            RouteInfo(
                                route=e.path,
                                method=next(iter([m for m in e.methods if m != "HEAD"])),
                                provider_types=[p.provider_type for p in providers],
                            )
                            for e, _ in endpoints
                            if e.methods is not None
                        ]
                    )

        return ListRoutesResponse(data=ret)

    async def health(self) -> HealthInfo:
        return HealthInfo(status=HealthStatus.OK)

    async def version(self) -> VersionInfo:
        return VersionInfo(version=version("llama-stack"))

    async def shutdown(self) -> None:
        pass
540 src/llama_stack/core/library_client.py Normal file
@@ -0,0 +1,540 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import inspect
import json
import logging  # allow-direct-logging
import os
import sys
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Any, TypeVar, Union, get_args, get_origin

import httpx
import yaml
from fastapi import Response as FastAPIResponse
from llama_stack_client import (
    NOT_GIVEN,
    APIResponse,
    AsyncAPIResponse,
    AsyncLlamaStackClient,
    AsyncStream,
    LlamaStackClient,
)
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint

from llama_stack.core.build import print_pip_install_help
from llama_stack.core.configure import parse_and_maybe_upgrade_config
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.request_headers import (
    PROVIDER_DATA_VAR,
    request_provider_data_context,
)
from llama_stack.core.resolver import ProviderRegistry
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
from llama_stack.core.stack import (
    Stack,
    get_stack_run_config_from_distro,
    replace_env_vars,
)
from llama_stack.core.telemetry import Telemetry
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger, setup_logging
from llama_stack.strong_typing.inspection import is_unwrapped_body_param

logger = get_logger(name=__name__, category="core")

T = TypeVar("T")


def convert_pydantic_to_json_value(value: Any) -> Any:
    if isinstance(value, Enum):
        return value.value
    elif isinstance(value, list):
        return [convert_pydantic_to_json_value(item) for item in value]
    elif isinstance(value, dict):
        return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
    elif isinstance(value, BaseModel):
        return json.loads(value.model_dump_json())
    else:
        return value


def convert_to_pydantic(annotation: Any, value: Any) -> Any:
    if isinstance(annotation, type) and annotation in {str, int, float, bool}:
        return value

    origin = get_origin(annotation)

    if origin is list:
        item_type = get_args(annotation)[0]
        try:
            return [convert_to_pydantic(item_type, item) for item in value]
        except Exception:
            logger.error(f"Error converting list {value} into {item_type}")
            return value

    elif origin is dict:
        key_type, val_type = get_args(annotation)
        try:
            return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
        except Exception:
            logger.error(f"Error converting dict {value} into {val_type}")
            return value

    try:
        # Handle Pydantic models and discriminated unions
        return TypeAdapter(annotation).validate_python(value)

    except Exception as e:
        # TODO: this is a workaround for having Union[str, AgentToolGroup] in the API schema.
        # We should get rid of any non-discriminated unions in the API schema.
        if origin is Union:
            for union_type in get_args(annotation):
                try:
                    return convert_to_pydantic(union_type, value)
                except Exception:
                    continue
            logger.warning(
                f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
            )
        raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e


class LibraryClientUploadFile:
    """LibraryClient UploadFile object that mimics FastAPI's UploadFile interface."""

    def __init__(self, filename: str, content: bytes):
        self.filename = filename
        self.content = content
        self.content_type = "application/octet-stream"

    async def read(self) -> bytes:
        return self.content


class LibraryClientHttpxResponse:
    """LibraryClient httpx Response object for FastAPI Response conversion."""

    def __init__(self, response):
        self.content = response.body if isinstance(response.body, bytes) else response.body.encode()
        self.status_code = response.status_code
        self.headers = response.headers


class LlamaStackAsLibraryClient(LlamaStackClient):
    def __init__(
        self,
        config_path_or_distro_name: str,
        skip_logger_removal: bool = False,
        custom_provider_registry: ProviderRegistry | None = None,
        provider_data: dict[str, Any] | None = None,
    ):
        super().__init__()
        self.async_client = AsyncLlamaStackAsLibraryClient(
            config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
        )
        self.provider_data = provider_data

        self.loop = asyncio.new_event_loop()

        # use a new event loop to avoid interfering with the main event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self.async_client.initialize())
        finally:
            asyncio.set_event_loop(None)

    def initialize(self):
        """
        Deprecated method for backward compatibility.
        """
        pass

    def request(self, *args, **kwargs):
        loop = self.loop
        asyncio.set_event_loop(loop)

        if kwargs.get("stream"):

            def sync_generator():
                try:
                    async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
                    while True:
                        chunk = loop.run_until_complete(async_stream.__anext__())
                        yield chunk
                except StopAsyncIteration:
                    pass
                finally:
                    pending = asyncio.all_tasks(loop)
                    if pending:
                        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))

            return sync_generator()
        else:
            try:
                result = loop.run_until_complete(self.async_client.request(*args, **kwargs))
            finally:
                pending = asyncio.all_tasks(loop)
                if pending:
                    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            return result


class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
    def __init__(
        self,
        config_path_or_distro_name: str,
        custom_provider_registry: ProviderRegistry | None = None,
        provider_data: dict[str, Any] | None = None,
        skip_logger_removal: bool = False,
    ):
        super().__init__()
        # Initialize logging from environment variables first
        setup_logging()

        # when using the library client, we should not log to console since many
        # of our logs are intended for server-side usage
        if sinks_from_env := os.environ.get("TELEMETRY_SINKS", None):
            current_sinks = sinks_from_env.strip().lower().split(",")
            os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")

        if in_notebook():
            import nest_asyncio

            nest_asyncio.apply()
            if not skip_logger_removal:
                self._remove_root_logger_handlers()

        if config_path_or_distro_name.endswith(".yaml"):
            config_path = Path(config_path_or_distro_name)
            if not config_path.exists():
                raise ValueError(f"Config file {config_path} does not exist")
            config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
            config = parse_and_maybe_upgrade_config(config_dict)
        else:
            # distribution
            config = get_stack_run_config_from_distro(config_path_or_distro_name)

        self.config_path_or_distro_name = config_path_or_distro_name
        self.config = config
        self.custom_provider_registry = custom_provider_registry
        self.provider_data = provider_data
        self.route_impls: RouteImpls | None = None  # Initialize to None to prevent AttributeError

    def _remove_root_logger_handlers(self):
        """
        Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
        """
        root_logger = logging.getLogger()

        for handler in root_logger.handlers[:]:
            root_logger.removeHandler(handler)
            logger.info(f"Removed handler {handler.__class__.__name__} from root logger")

    async def initialize(self) -> bool:
        """
        Initialize the async client.

        Returns:
            bool: True if initialization was successful
        """

        try:
            self.route_impls = None

            stack = Stack(self.config, self.custom_provider_registry)
            await stack.initialize()
            self.impls = stack.impls
        except ModuleNotFoundError as _e:
            cprint(_e.msg, color="red", file=sys.stderr)
            cprint(
                "Using llama-stack as a library requires installing dependencies depending on the distribution (providers) you choose.\n",
                color="yellow",
                file=sys.stderr,
            )
            if self.config_path_or_distro_name.endswith(".yaml"):
                providers: dict[str, list[BuildProvider]] = {}
                for api, run_providers in self.config.providers.items():
                    for provider in run_providers:
                        providers.setdefault(api, []).append(
                            BuildProvider(provider_type=provider.provider_type, module=provider.module)
                        )
                providers = dict(providers)
                build_config = BuildConfig(
                    distribution_spec=DistributionSpec(
                        providers=providers,
                    ),
                    external_providers_dir=self.config.external_providers_dir,
                )
                print_pip_install_help(build_config)
            else:
                prefix = "!" if in_notebook() else ""
                cprint(
                    f"Please run:\n\n{prefix}llama stack list-deps {self.config_path_or_distro_name} | xargs -L1 uv pip install\n\n",
                    "yellow",
                    file=sys.stderr,
                )
            cprint(
                "Please check your internet connection and try again.",
                "red",
                file=sys.stderr,
            )
            raise _e

        assert self.impls is not None
        if self.config.telemetry.enabled:
            setup_logger(Telemetry())

        if not os.environ.get("PYTEST_CURRENT_TEST"):
            console = Console()
            console.print(f"Using config [blue]{self.config_path_or_distro_name}[/blue]:")
            safe_config = redact_sensitive_fields(self.config.model_dump())
            console.print(yaml.dump(safe_config, indent=2))

        self.route_impls = initialize_route_impls(self.impls)
        return True

    async def request(
        self,
        cast_to: Any,
        options: Any,
        *,
        stream=False,
        stream_cls=None,
    ):
        if self.route_impls is None:
            raise ValueError("Client not initialized. Please call initialize() first.")

        # Create headers with provider data if available
        headers = options.headers or {}
        if self.provider_data:
            keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
            if all(key not in headers for key in keys):
                headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)

        # Use context manager for provider data
        with request_provider_data_context(headers):
            if stream:
                response = await self._call_streaming(
                    cast_to=cast_to,
                    options=options,
                    stream_cls=stream_cls,
                )
            else:
                response = await self._call_non_streaming(
                    cast_to=cast_to,
                    options=options,
                )
            return response

    def _handle_file_uploads(self, options: Any, body: dict) -> tuple[dict, list[str]]:
        """Handle file uploads from OpenAI client and add them to the request body."""
        if not (hasattr(options, "files") and options.files):
            return body, []

        if not isinstance(options.files, list):
            return body, []

        field_names = []
        for file_tuple in options.files:
            if not (isinstance(file_tuple, tuple) and len(file_tuple) >= 2):
                continue

            field_name = file_tuple[0]
            file_object = file_tuple[1]

            if isinstance(file_object, BytesIO):
                file_object.seek(0)
                file_content = file_object.read()
                filename = getattr(file_object, "name", "uploaded_file")
                field_names.append(field_name)
                body[field_name] = LibraryClientUploadFile(filename, file_content)

        return body, field_names

    async def _call_non_streaming(
        self,
        *,
        cast_to: Any,
        options: Any,
    ):
        assert self.route_impls is not None  # Should be guaranteed by request() method, assertion for mypy
        path = options.url
        body = options.params or {}
        body |= options.json_data or {}

        # Merge extra_json parameters (extra_body from the SDK is converted to extra_json)
        if hasattr(options, "extra_json") and options.extra_json:
            body |= options.extra_json

        matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
        body |= path_params

        body, field_names = self._handle_file_uploads(options, body)

        body = self._convert_body(matched_func, body, exclude_params=set(field_names))

        trace_path = webmethod.descriptive_name or route_path
        await start_trace(trace_path, {"__location__": "library_client"})
        try:
            result = await matched_func(**body)
        finally:
            await end_trace()

        # Handle FastAPI Response objects (e.g., from file content retrieval)
        if isinstance(result, FastAPIResponse):
            return LibraryClientHttpxResponse(result)

        json_content = json.dumps(convert_pydantic_to_json_value(result))

        filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}

        status_code = httpx.codes.OK

        if options.method.upper() == "DELETE" and result is None:
            status_code = httpx.codes.NO_CONTENT

        if status_code == httpx.codes.NO_CONTENT:
            json_content = ""

        mock_response = httpx.Response(
            status_code=status_code,
            content=json_content.encode("utf-8"),
            headers={
                "Content-Type": "application/json",
            },
            request=httpx.Request(
                method=options.method,
                url=options.url,
                params=options.params,
                headers=options.headers or {},
                json=convert_pydantic_to_json_value(filtered_body),
            ),
        )
        response = APIResponse(
            raw=mock_response,
            client=self,
            cast_to=cast_to,
            options=options,
            stream=False,
            stream_cls=None,
        )
        return response.parse()

    async def _call_streaming(
        self,
        *,
        cast_to: Any,
        options: Any,
        stream_cls: Any,
    ):
        assert self.route_impls is not None  # Should be guaranteed by request() method, assertion for mypy
        path = options.url
        body = options.params or {}
        body |= options.json_data or {}
        func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
        body |= path_params

        # Prepare body for the function call (handles both Pydantic and traditional params)
        body = self._convert_body(func, body)

        trace_path = webmethod.descriptive_name or route_path
        await start_trace(trace_path, {"__location__": "library_client"})

        async def gen():
            try:
                async for chunk in await func(**body):
                    data = json.dumps(convert_pydantic_to_json_value(chunk))
                    sse_event = f"data: {data}\n\n"
                    yield sse_event.encode("utf-8")
            finally:
                await end_trace()

        wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])

        mock_response = httpx.Response(
            status_code=httpx.codes.OK,
            content=wrapped_gen,
            headers={
                "Content-Type": "application/json",
            },
            request=httpx.Request(
                method=options.method,
                url=options.url,
                params=options.params,
                headers=options.headers or {},
                json=convert_pydantic_to_json_value(body),
            ),
        )

        # We always use the asynchronous impl internally and channel all requests to AsyncLlamaStackClient.
        # However, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
        # and we need to convert it to an AsyncStream.
        # mypy can't track runtime variables inside the [...] of a generic, so ignore that check
        args = get_args(stream_cls)
        stream_cls = AsyncStream[args[0]]  # type: ignore[valid-type]
        response = AsyncAPIResponse(
            raw=mock_response,
            client=self,
            cast_to=cast_to,
            options=options,
            stream=True,
            stream_cls=stream_cls,
        )
        return await response.parse()

    def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
        body = body or {}
        exclude_params = exclude_params or set()
        sig = inspect.signature(func)
        params_list = [p for p in sig.parameters.values() if p.name != "self"]

        # Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
        if len(params_list) == 1:
            param = params_list[0]
            param_type = param.annotation
            if is_unwrapped_body_param(param_type):
                base_type = get_args(param_type)[0]
                return {param.name: base_type(**body)}

        # Strip NOT_GIVENs to use the defaults in signature
        body = {k: v for k, v in body.items() if v is not NOT_GIVEN}

        # Check if there's an unwrapped body parameter among multiple parameters
        # (e.g., path param + body param like: vector_store_id: str, params: Annotated[Model, Body(...)])
        unwrapped_body_param = None
        for param in params_list:
            if is_unwrapped_body_param(param.annotation):
                unwrapped_body_param = param
                break

        # Convert parameters to Pydantic models where needed
        converted_body = {}
        for param_name, param in sig.parameters.items():
            if param_name in body:
                value = body.get(param_name)
                if param_name in exclude_params:
                    converted_body[param_name] = value
                else:
                    converted_body[param_name] = convert_to_pydantic(param.annotation, value)

        # handle unwrapped body parameter after processing all named parameters
        if unwrapped_body_param:
            base_type = get_args(unwrapped_body_param.annotation)[0]
            # extract only keys not already used by other params
            remaining_keys = {k: v for k, v in body.items() if k not in converted_body}
            converted_body[unwrapped_body_param.name] = base_type(**remaining_keys)

        return converted_body
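Typical embedded usage looks roughly like this (a sketch; `starter` is the distribution name used elsewhere in this change, and the client API mirrors `LlamaStackClient`):

```python
from llama_stack.core.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("starter")  # a distro name, or a path to a run-config .yaml
for model in client.models.list():
    print(model.identifier)
```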
5 src/llama_stack/core/prompts/__init__.py Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

232 src/llama_stack/core/prompts/prompts.py Normal file
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Any

from pydantic import BaseModel

from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl


class PromptServiceConfig(BaseModel):
    """Configuration for the built-in prompt service.

    :param run_config: Stack run configuration containing distribution info
    """

    run_config: StackRunConfig


async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):
    """Get the prompt service implementation."""
    impl = PromptServiceImpl(config, deps)
    await impl.initialize()
    return impl


class PromptServiceImpl(Prompts):
    """Built-in prompt service implementation using KVStore."""

    def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
        self.config = config
        self.deps = deps
        self.kvstore: KVStore

    async def initialize(self) -> None:
        # Use the prompts store reference from the run config
        prompts_ref = self.config.run_config.storage.stores.prompts
        if not prompts_ref:
            raise ValueError("storage.stores.prompts must be configured in run config")
        self.kvstore = await kvstore_impl(prompts_ref)

    def _get_default_key(self, prompt_id: str) -> str:
        """Get the KVStore key that stores the default version number."""
        return f"prompts:v1:{prompt_id}:default"

    async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
        """Get the KVStore key for prompt data, resolving to the default version if none is given."""
        if version:
            return self._get_version_key(prompt_id, str(version))

        default_key = self._get_default_key(prompt_id)
        resolved_version = await self.kvstore.get(default_key)
        if resolved_version is None:
            raise ValueError(f"Prompt {prompt_id}:default not found")
        return self._get_version_key(prompt_id, resolved_version)

    def _get_version_key(self, prompt_id: str, version: str) -> str:
        """Get the KVStore key for a specific prompt version."""
        return f"prompts:v1:{prompt_id}:{version}"

    def _get_list_key_prefix(self) -> str:
        """Get the key prefix for listing prompts."""
        return "prompts:v1:"
|
||||
|
||||
def _serialize_prompt(self, prompt: Prompt) -> str:
|
||||
"""Serialize a prompt to JSON string for storage."""
|
||||
return json.dumps(
|
||||
{
|
||||
"prompt_id": prompt.prompt_id,
|
||||
"prompt": prompt.prompt,
|
||||
"version": prompt.version,
|
||||
"variables": prompt.variables or [],
|
||||
"is_default": prompt.is_default,
|
||||
}
|
||||
)
|
||||
|
||||
def _deserialize_prompt(self, data: str) -> Prompt:
|
||||
"""Deserialize a prompt from JSON string."""
|
||||
obj = json.loads(data)
|
||||
return Prompt(
|
||||
prompt_id=obj["prompt_id"],
|
||||
prompt=obj["prompt"],
|
||||
version=obj["version"],
|
||||
variables=obj.get("variables", []),
|
||||
is_default=obj.get("is_default", False),
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> ListPromptsResponse:
|
||||
"""List all prompts (default versions only)."""
|
||||
prefix = self._get_list_key_prefix()
|
||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||
|
||||
prompts = []
|
||||
for key in keys:
|
||||
if key.endswith(":default"):
|
||||
try:
|
||||
default_version = await self.kvstore.get(key)
|
||||
if default_version:
|
||||
prompt_id = key.replace(prefix, "").replace(":default", "")
|
||||
version_key = self._get_version_key(prompt_id, default_version)
|
||||
data = await self.kvstore.get(version_key)
|
||||
if data:
|
||||
prompt = self._deserialize_prompt(data)
|
||||
prompts.append(prompt)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
|
||||
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
|
||||
return ListPromptsResponse(data=prompts)
|
||||
|
||||
async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
|
||||
"""Get a prompt by its identifier and optional version."""
|
||||
key = await self._get_prompt_key(prompt_id, version)
|
||||
data = await self.kvstore.get(key)
|
||||
if data is None:
|
||||
raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
|
||||
return self._deserialize_prompt(data)
|
||||
|
||||
async def create_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
variables: list[str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Create a new prompt."""
|
||||
if variables is None:
|
||||
variables = []
|
||||
|
||||
prompt_obj = Prompt(
|
||||
prompt_id=Prompt.generate_prompt_id(),
|
||||
prompt=prompt,
|
||||
version=1,
|
||||
variables=variables,
|
||||
)
|
||||
|
||||
version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
|
||||
data = self._serialize_prompt(prompt_obj)
|
||||
await self.kvstore.set(version_key, data)
|
||||
|
||||
default_key = self._get_default_key(prompt_obj.prompt_id)
|
||||
await self.kvstore.set(default_key, str(prompt_obj.version))
|
||||
|
||||
return prompt_obj
|
||||
|
||||
async def update_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
prompt: str,
|
||||
version: int,
|
||||
variables: list[str] | None = None,
|
||||
set_as_default: bool = True,
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt (increments version)."""
|
||||
if version < 1:
|
||||
raise ValueError("Version must be >= 1")
|
||||
if variables is None:
|
||||
variables = []
|
||||
|
||||
prompt_versions = await self.list_prompt_versions(prompt_id)
|
||||
latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
|
||||
|
||||
if version and latest_prompt.version != version:
|
||||
raise ValueError(
|
||||
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request."
|
||||
)
|
||||
|
||||
current_version = latest_prompt.version if version is None else version
|
||||
new_version = current_version + 1
|
||||
|
||||
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
|
||||
|
||||
version_key = self._get_version_key(prompt_id, str(new_version))
|
||||
data = self._serialize_prompt(updated_prompt)
|
||||
await self.kvstore.set(version_key, data)
|
||||
|
||||
if set_as_default:
|
||||
await self.set_default_version(prompt_id, new_version)
|
||||
|
||||
return updated_prompt
|
||||
|
||||
async def delete_prompt(self, prompt_id: str) -> None:
|
||||
"""Delete a prompt and all its versions."""
|
||||
await self.get_prompt(prompt_id)
|
||||
|
||||
prefix = f"prompts:v1:{prompt_id}:"
|
||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||
|
||||
for key in keys:
|
||||
await self.kvstore.delete(key)
|
||||
|
||||
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
|
||||
"""List all versions of a specific prompt."""
|
||||
prefix = f"prompts:v1:{prompt_id}:"
|
||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||
|
||||
default_version = None
|
||||
prompts = []
|
||||
|
||||
for key in keys:
|
||||
data = await self.kvstore.get(key)
|
||||
if key.endswith(":default"):
|
||||
default_version = data
|
||||
else:
|
||||
if data:
|
||||
prompt_obj = self._deserialize_prompt(data)
|
||||
prompts.append(prompt_obj)
|
||||
|
||||
if not prompts:
|
||||
raise ValueError(f"Prompt {prompt_id} not found")
|
||||
|
||||
for prompt in prompts:
|
||||
prompt.is_default = str(prompt.version) == default_version
|
||||
|
||||
prompts.sort(key=lambda x: x.version)
|
||||
return ListPromptsResponse(data=prompts)
|
||||
|
||||
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
|
||||
"""Set which version of a prompt should be the default, If not set. the default is the latest."""
|
||||
version_key = self._get_version_key(prompt_id, str(version))
|
||||
data = await self.kvstore.get(version_key)
|
||||
if data is None:
|
||||
raise ValueError(f"Prompt {prompt_id} version {version} not found")
|
||||
|
||||
default_key = self._get_default_key(prompt_id)
|
||||
await self.kvstore.set(default_key, str(version))
|
||||
|
||||
return self._deserialize_prompt(data)
|
||||
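Taken together, the KVStore layout is `prompts:v1:<id>:<version>` for payloads plus a `prompts:v1:<id>:default` pointer that version-less reads resolve through. A minimal usage sketch of that flow; it assumes an already-initialized `PromptServiceImpl` (construction elided) and an illustrative template, and is not part of the commit:

# Sketch only: `prompts` is assumed to be an initialized PromptServiceImpl.
async def demo(prompts) -> None:
    p = await prompts.create_prompt("Summarize {{ doc }}", variables=["doc"])
    # create_prompt writes version 1 and points the default key at it

    p2 = await prompts.update_prompt(p.prompt_id, "Summarize {{ doc }} briefly", version=1, variables=["doc"])
    assert p2.version == 2  # updates always append a new version

    latest = await prompts.get_prompt(p.prompt_id)              # resolves the default pointer -> version 2
    pinned = await prompts.get_prompt(p.prompt_id, version=1)   # explicit version bypasses the pointer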
137
src/llama_stack/core/providers.py
Normal file
@@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from typing import Any

from pydantic import BaseModel

from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus

from .datatypes import StackRunConfig
from .utils.config import redact_sensitive_fields

logger = get_logger(name=__name__, category="core")


class ProviderImplConfig(BaseModel):
    run_config: StackRunConfig


async def get_provider_impl(config, deps):
    impl = ProviderImpl(config, deps)
    await impl.initialize()
    return impl


class ProviderImpl(Providers):
    def __init__(self, config, deps):
        self.config = config
        self.deps = deps

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        logger.debug("ProviderImpl.shutdown")
        pass

    async def list_providers(self) -> ListProvidersResponse:
        run_config = self.config.run_config
        safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
        providers_health = await self.get_providers_health()
        ret = []
        for api, providers in safe_config.providers.items():
            for p in providers:
                # Skip providers that are not enabled
                if p.provider_id is None:
                    continue
                ret.append(
                    ProviderInfo(
                        api=api,
                        provider_id=p.provider_id,
                        provider_type=p.provider_type,
                        config=p.config,
                        health=providers_health.get(api, {}).get(
                            p.provider_id,
                            HealthResponse(
                                status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
                            ),
                        ),
                    )
                )

        return ListProvidersResponse(data=ret)

    async def inspect_provider(self, provider_id: str) -> ProviderInfo:
        all_providers = await self.list_providers()
        for p in all_providers.data:
            if p.provider_id == provider_id:
                return p

        raise ValueError(f"Provider {provider_id} not found")

    async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
        """Get health status for all providers.

        Returns:
            Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
                Each API maps to a dictionary of provider IDs to their health responses.
        """
        providers_health: dict[str, dict[str, HealthResponse]] = {}

        # The timeout has to be long enough to allow all the providers to be checked, especially in
        # the case of the inference router health check since it checks all registered inference
        # providers.
        # The timeout must not be equal to the one set by health method for a given implementation,
        # otherwise we will miss some providers.
        timeout = 3.0

        async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
            # Skip special implementations (inspect/providers) that don't have provider specs
            if not hasattr(impl, "__provider_spec__"):
                return None
            api_name = impl.__provider_spec__.api.name
            if not hasattr(impl, "health"):
                return (
                    api_name,
                    HealthResponse(
                        status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
                    ),
                )

            try:
                health = await asyncio.wait_for(impl.health(), timeout=timeout)
                return api_name, health
            except TimeoutError:
                return (
                    api_name,
                    HealthResponse(
                        status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
                    ),
                )
            except Exception as e:
                return (
                    api_name,
                    HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
                )

        # Create tasks for all providers
        tasks = [check_provider_health(impl) for impl in self.deps.values()]

        # Wait for all health checks to complete
        results = await asyncio.gather(*tasks)

        # Organize results by API and provider ID
        for result in results:
            if result is None:  # Skip special implementations
                continue
            api_name, health_response = result
            providers_health[api_name] = health_response

        return providers_health
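The pattern in `get_providers_health` is a per-provider coroutine that converts every outcome (no health method, timeout, exception) into a result value, gathered concurrently so one slow provider cannot block the rest. The same shape in isolation, with stub checks invented for illustration (requires Python 3.11+, where `asyncio.wait_for` raises the built-in `TimeoutError`):

import asyncio


async def check(name: str, delay: float, timeout: float = 3.0) -> tuple[str, str]:
    async def health() -> str:
        await asyncio.sleep(delay)  # stand-in for a provider's health() call
        return "OK"

    try:
        return name, await asyncio.wait_for(health(), timeout=timeout)
    except TimeoutError:  # raised by wait_for on expiry (Python 3.11+)
        return name, f"timed out after {timeout}s"


async def main() -> None:
    # gather runs all checks concurrently; total wall time is ~max, not sum
    results = await asyncio.gather(check("fast", 0.1), check("slow", 5.0))
    print(dict(results))  # {'fast': 'OK', 'slow': 'timed out after 3.0s'}


asyncio.run(main())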
115
src/llama_stack/core/request_headers.py
Normal file
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import contextvars
import json
from contextlib import AbstractContextManager
from typing import Any

from llama_stack.core.datatypes import User
from llama_stack.log import get_logger

from .utils.dynamic import instantiate_class_type

log = get_logger(name=__name__, category="core")

# Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)


class RequestProviderDataContext(AbstractContextManager):
    """Context manager for request provider data"""

    def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None):
        self.provider_data = provider_data or {}
        if user:
            self.provider_data["__authenticated_user"] = user

        self.token = None

    def __enter__(self):
        # Save the current value and set the new one
        self.token = PROVIDER_DATA_VAR.set(self.provider_data)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the previous value
        if self.token is not None:
            PROVIDER_DATA_VAR.reset(self.token)


class NeedsRequestProviderData:
    def get_request_provider_data(self) -> Any:
        spec = self.__provider_spec__
        if not spec:
            raise ValueError(f"Provider spec not set on {self.__class__}")

        provider_type = spec.provider_type
        validator_class = spec.provider_data_validator
        if not validator_class:
            raise ValueError(f"Provider {provider_type} does not have a validator")

        val = PROVIDER_DATA_VAR.get()
        if not val:
            return None

        validator = instantiate_class_type(validator_class)
        try:
            provider_data = validator(**val)
            return provider_data
        except Exception as e:
            log.error(f"Error parsing provider data: {e}")
            return None


def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | None:
    """Parse provider data from request headers"""
    keys = [
        "X-LlamaStack-Provider-Data",
        "x-llamastack-provider-data",
    ]
    val = None
    for key in keys:
        val = headers.get(key, None)
        if val:
            break

    if not val:
        return None

    try:
        return json.loads(val)
    except json.JSONDecodeError:
        log.error("Provider data not encoded as a JSON object!")
        return None


def request_provider_data_context(
    headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
) -> AbstractContextManager:
    """Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
    provider_data = parse_request_provider_data(headers)
    return RequestProviderDataContext(provider_data, auth_attributes)


def get_authenticated_user() -> User | None:
    """Helper to retrieve auth attributes from the provider data context"""
    provider_data = PROVIDER_DATA_VAR.get()
    if not provider_data:
        return None
    return provider_data.get("__authenticated_user")


def user_from_scope(scope: dict) -> User | None:
    """Create a User object from ASGI scope data (set by authentication middleware)"""
    user_attributes = scope.get("user_attributes", {})
    principal = scope.get("principal", "")

    # auth not enabled
    if not principal and not user_attributes:
        return None

    return User(principal=principal, attributes=user_attributes)
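End to end, the server parses the `X-LlamaStack-Provider-Data` header into a dict, stashes it in the contextvar for the request's duration, and providers read it back via `NeedsRequestProviderData`. A compressed sketch of that flow using the functions above; the header payload and key name are illustrative, and it assumes no prior value was set on the contextvar:

headers = {"X-LlamaStack-Provider-Data": '{"together_api_key": "sk-..."}'}

data = parse_request_provider_data(headers)
with RequestProviderDataContext(data):
    # inside the context, any provider code sees the per-request data
    assert PROVIDER_DATA_VAR.get()["together_api_key"].startswith("sk-")

# outside, __exit__ has reset the contextvar to its previous value
assert PROVIDER_DATA_VAR.get() is None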
482
src/llama_stack/core/resolver.py
Normal file
@@ -0,0 +1,482 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import importlib.metadata
import inspect
from typing import Any

from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
from llama_stack.core.datatypes import (
    AccessRule,
    AutoRoutedProviderSpec,
    Provider,
    RoutingTableProviderSpec,
    StackRunConfig,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import load_external_apis
from llama_stack.core.store import DistributionRegistry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    Api,
    BenchmarksProtocolPrivate,
    DatasetsProtocolPrivate,
    ModelsProtocolPrivate,
    ProviderSpec,
    RemoteProviderConfig,
    RemoteProviderSpec,
    ScoringFunctionsProtocolPrivate,
    ShieldsProtocolPrivate,
    ToolGroupsProtocolPrivate,
)

logger = get_logger(name=__name__, category="core")


class InvalidProviderError(Exception):
    pass


def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
    """Get a mapping of API types to their protocol classes.

    Args:
        external_apis: Optional dictionary of external API specifications

    Returns:
        Dictionary mapping API types to their protocol classes
    """
    protocols = {
        Api.providers: ProvidersAPI,
        Api.agents: Agents,
        Api.inference: Inference,
        Api.inspect: Inspect,
        Api.batches: Batches,
        Api.vector_io: VectorIO,
        Api.vector_stores: VectorStore,
        Api.models: Models,
        Api.safety: Safety,
        Api.shields: Shields,
        Api.datasetio: DatasetIO,
        Api.datasets: Datasets,
        Api.scoring: Scoring,
        Api.scoring_functions: ScoringFunctions,
        Api.eval: Eval,
        Api.benchmarks: Benchmarks,
        Api.post_training: PostTraining,
        Api.tool_groups: ToolGroups,
        Api.tool_runtime: ToolRuntime,
        Api.files: Files,
        Api.prompts: Prompts,
        Api.conversations: Conversations,
    }

    if external_apis:
        for api, api_spec in external_apis.items():
            try:
                module = importlib.import_module(api_spec.module)
                api_class = getattr(module, api_spec.protocol)

                protocols[api] = api_class
            except (ImportError, AttributeError):
                logger.exception(f"Failed to load external API {api_spec.name}")

    return protocols


def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
    external_apis = load_external_apis(config)
    return {
        **api_protocol_map(external_apis),
        Api.inference: InferenceProvider,
    }


def additional_protocols_map() -> dict[Api, Any]:
    return {
        Api.inference: (ModelsProtocolPrivate, Models, Api.models),
        Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
        Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
        Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
        Api.scoring: (
            ScoringFunctionsProtocolPrivate,
            ScoringFunctions,
            Api.scoring_functions,
        ),
        Api.eval: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
    }


# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
class ProviderWithSpec(Provider):
    spec: ProviderSpec


ProviderRegistry = dict[Api, dict[str, ProviderSpec]]


async def resolve_impls(
    run_config: StackRunConfig,
    provider_registry: ProviderRegistry,
    dist_registry: DistributionRegistry,
    policy: list[AccessRule],
    internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
    """
    Resolves provider implementations by:
    1. Validating and organizing providers.
    2. Sorting them in dependency order.
    3. Instantiating them with required dependencies.
    """
    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
    router_apis = {x.router_api for x in builtin_automatically_routed_apis()}

    providers_with_specs = validate_and_prepare_providers(
        run_config, provider_registry, routing_table_apis, router_apis
    )

    apis_to_serve = run_config.apis or set(
        list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
    )

    providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))

    sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy, internal_impls)


def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
    """Generates specifications for automatically routed APIs."""
    specs = {}
    for info in builtin_automatically_routed_apis():
        if info.router_api.value not in apis_to_serve:
            continue

        specs[info.routing_table_api.value] = {
            "__builtin__": ProviderWithSpec(
                provider_id="__routing_table__",
                provider_type="__routing_table__",
                config={},
                spec=RoutingTableProviderSpec(
                    api=info.routing_table_api,
                    router_api=info.router_api,
                    module="llama_stack.core.routers",
                    api_dependencies=[],
                    deps__=[f"inner-{info.router_api.value}"],
                ),
            )
        }

        specs[info.router_api.value] = {
            "__builtin__": ProviderWithSpec(
                provider_id="__autorouted__",
                provider_type="__autorouted__",
                config={},
                spec=AutoRoutedProviderSpec(
                    api=info.router_api,
                    module="llama_stack.core.routers",
                    routing_table_api=info.routing_table_api,
                    api_dependencies=[info.routing_table_api],
                    deps__=([info.routing_table_api.value]),
                ),
            )
        }
    return specs


def validate_and_prepare_providers(
    run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
) -> dict[str, dict[str, ProviderWithSpec]]:
    """Validates providers, handles deprecations, and organizes them into a spec dictionary."""
    providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {}

    for api_str, providers in run_config.providers.items():
        api = Api(api_str)
        if api in routing_table_apis:
            raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")

        specs = {}
        for provider in providers:
            if not provider.provider_id or provider.provider_id == "__disabled__":
                logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
                continue

            validate_provider(provider, api, provider_registry)
            p = provider_registry[api][provider.provider_type]
            p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
            spec = ProviderWithSpec(spec=p, **provider.model_dump())
            specs[provider.provider_id] = spec

        key = api_str if api not in router_apis else f"inner-{api_str}"
        providers_with_specs[key] = specs

    return providers_with_specs


def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
    """Validates if the provider is allowed and handles deprecations."""
    if provider.provider_type not in provider_registry[api]:
        raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")

    p = provider_registry[api][provider.provider_type]
    if p.deprecation_error:
        logger.error(p.deprecation_error)
        raise InvalidProviderError(p.deprecation_error)
    elif p.deprecation_warning:
        logger.warning(
            f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
        )


def sort_providers_by_deps(
    providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> list[tuple[str, ProviderWithSpec]]:
    """Sorts providers based on their dependencies."""
    sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort(
        {k: list(v.values()) for k, v in providers_with_specs.items()}
    )

    logger.debug(f"Resolved {len(sorted_providers)} providers")
    for api_str, provider in sorted_providers:
        logger.debug(f" {api_str} => {provider.provider_id}")
    return sorted_providers


async def instantiate_providers(
    sorted_providers: list[tuple[str, ProviderWithSpec]],
    router_apis: set[Api],
    dist_registry: DistributionRegistry,
    run_config: StackRunConfig,
    policy: list[AccessRule],
    internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
    """Instantiates providers asynchronously while managing dependencies."""
    impls: dict[Api, Any] = internal_impls.copy() if internal_impls else {}
    inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
    for api_str, provider in sorted_providers:
        # Skip providers that are not enabled
        if provider.provider_id is None:
            continue

        try:
            deps = {a: impls[a] for a in provider.spec.api_dependencies}
        except KeyError as e:
            missing_api = e.args[0]
            raise RuntimeError(
                f"Failed to resolve '{provider.spec.api.value}' provider '{provider.provider_id}' of type '{provider.spec.provider_type}': "
                f"required dependency '{missing_api.value}' is not available. "
                f"Please add a '{missing_api.value}' provider to your configuration or check if the provider is properly configured."
            ) from e
        for a in provider.spec.optional_api_dependencies:
            if a in impls:
                deps[a] = impls[a]

        inner_impls = {}
        if isinstance(provider.spec, RoutingTableProviderSpec):
            inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]

        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)

        if api_str.startswith("inner-"):
            inner_impls_by_provider_id[api_str][provider.provider_id] = impl
        else:
            api = Api(api_str)
            impls[api] = impl

    return impls


def topological_sort(
    providers_with_specs: dict[str, list[ProviderWithSpec]],
) -> list[tuple[str, ProviderWithSpec]]:
    def dfs(kv, visited: set[str], stack: list[str]):
        api_str, providers = kv
        visited.add(api_str)

        deps = []
        for provider in providers:
            for dep in provider.spec.deps__:
                deps.append(dep)

        for dep in deps:
            if dep not in visited and dep in providers_with_specs:
                dfs((dep, providers_with_specs[dep]), visited, stack)

        stack.append(api_str)

    visited: set[str] = set()
    stack: list[str] = []

    for api_str, providers in providers_with_specs.items():
        if api_str not in visited:
            dfs((api_str, providers), visited, stack)

    flattened = []
    for api_str in stack:
        for provider in providers_with_specs[api_str]:
            flattened.append((api_str, provider))

    return flattened


# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
    provider: ProviderWithSpec,
    deps: dict[Api, Any],
    inner_impls: dict[str, Any],
    dist_registry: DistributionRegistry,
    run_config: StackRunConfig,
    policy: list[AccessRule],
):
    provider_spec = provider.spec
    if not hasattr(provider_spec, "module") or provider_spec.module is None:
        raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")

    logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
    module = importlib.import_module(provider_spec.module)
    args = []
    if isinstance(provider_spec, RemoteProviderSpec):
        config_type = instantiate_class_type(provider_spec.config_class)
        config = config_type(**provider.config)

        method = "get_adapter_impl"
        args = [config, deps]

    elif isinstance(provider_spec, AutoRoutedProviderSpec):
        method = "get_auto_router_impl"

        config = None
        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config, policy]
    elif isinstance(provider_spec, RoutingTableProviderSpec):
        method = "get_routing_table_impl"

        config = None
        args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
    else:
        method = "get_provider_impl"

        config_type = instantiate_class_type(provider_spec.config_class)
        config = config_type(**provider.config)
        args = [config, deps]
        if "policy" in inspect.signature(getattr(module, method)).parameters:
            args.append(policy)
        if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
            args.append(run_config.telemetry.enabled)

    fn = getattr(module, method)
    impl = await fn(*args)
    impl.__provider_id__ = provider.provider_id
    impl.__provider_spec__ = provider_spec
    impl.__provider_config__ = config

    protocols = api_protocol_map_for_compliance_check(run_config)
    additional_protocols = additional_protocols_map()
    # TODO: check compliance for special tool groups
    # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
    check_protocol_compliance(impl, protocols[provider_spec.api])
    if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
        additional_api, _, _ = additional_protocols[provider_spec.api]
        check_protocol_compliance(impl, additional_api)

    return impl


def check_protocol_compliance(obj: Any, protocol: Any) -> None:
    missing_methods = []

    mro = type(obj).__mro__
    for name, value in inspect.getmembers(protocol):
        if inspect.isfunction(value) and hasattr(value, "__webmethods__"):
            has_alpha_api = False
            for webmethod in value.__webmethods__:
                if webmethod.level == LLAMA_STACK_API_V1ALPHA:
                    has_alpha_api = True
                    break
            # if this API has multiple webmethods, and one of them is an alpha API, this API should be skipped when checking for missing or not callable routes
            if has_alpha_api:
                continue
            if not hasattr(obj, name):
                missing_methods.append((name, "missing"))
            elif not callable(getattr(obj, name)):
                missing_methods.append((name, "not_callable"))
            else:
                # Check if the method signatures are compatible
                obj_method = getattr(obj, name)
                proto_sig = inspect.signature(value)
                obj_sig = inspect.signature(obj_method)

                proto_params = set(proto_sig.parameters)
                proto_params.discard("self")
                obj_params = set(obj_sig.parameters)
                obj_params.discard("self")
                if not (proto_params <= obj_params):
                    logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
                    missing_methods.append((name, "signature_mismatch"))
                else:
                    # Check if the method has a concrete implementation (not just a protocol stub)
                    # Find all classes in MRO that define this method
                    method_owners = [cls for cls in mro if name in cls.__dict__]

                    # Allow methods from mixins/parents, only reject if ONLY the protocol defines it
                    if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
                        # Only reject if the method is ONLY defined in the protocol itself (abstract stub)
                        missing_methods.append((name, "not_actually_implemented"))

    if missing_methods:
        raise ValueError(
            f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
        )


async def resolve_remote_stack_impls(
    config: RemoteProviderConfig,
    apis: list[str],
) -> dict[Api, Any]:
    protocols = api_protocol_map()
    additional_protocols = additional_protocols_map()

    impls = {}
    for api_str in apis:
        api = Api(api_str)
        impls[api] = await get_client_impl(
            protocols[api],
            config,
            {},
        )
        if api in additional_protocols:
            _, additional_protocol, additional_api = additional_protocols[api]
            impls[additional_api] = await get_client_impl(
                additional_protocol,
                config,
                {},
            )

    return impls
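`topological_sort` above is a plain DFS over the `deps__` edges: each API key is appended to the stack only after all of its reachable dependencies, so providers come out in safe instantiation order. A tiny worked example of the same traversal on bare strings; the graph and names are invented for illustration:

# deps edges: inference -> models, agents -> inference
graph = {
    "models": [],
    "inference": ["models"],
    "agents": ["inference"],
}


def topo(graph: dict[str, list[str]]) -> list[str]:
    visited: set[str] = set()
    order: list[str] = []

    def dfs(node: str) -> None:
        visited.add(node)
        for dep in graph[node]:
            if dep not in visited:
                dfs(dep)
        order.append(node)  # appended only after all of its deps

    for node in graph:
        if node not in visited:
            dfs(node)
    return order


print(topo(graph))  # ['models', 'inference', 'agents']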
96
src/llama_stack/core/routers/__init__.py
Normal file
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.core.datatypes import (
    AccessRule,
    RoutedProtocol,
)
from llama_stack.core.stack import StackRunConfig
from llama_stack.core.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore


async def get_routing_table_impl(
    api: Api,
    impls_by_provider_id: dict[str, RoutedProtocol],
    _deps,
    dist_registry: DistributionRegistry,
    policy: list[AccessRule],
) -> Any:
    from ..routing_tables.benchmarks import BenchmarksRoutingTable
    from ..routing_tables.datasets import DatasetsRoutingTable
    from ..routing_tables.models import ModelsRoutingTable
    from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
    from ..routing_tables.shields import ShieldsRoutingTable
    from ..routing_tables.toolgroups import ToolGroupsRoutingTable
    from ..routing_tables.vector_stores import VectorStoresRoutingTable

    api_to_tables = {
        "models": ModelsRoutingTable,
        "shields": ShieldsRoutingTable,
        "datasets": DatasetsRoutingTable,
        "scoring_functions": ScoringFunctionsRoutingTable,
        "benchmarks": BenchmarksRoutingTable,
        "tool_groups": ToolGroupsRoutingTable,
        "vector_stores": VectorStoresRoutingTable,
    }

    if api.value not in api_to_tables:
        raise ValueError(f"API {api.value} not found in router map")

    impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
    await impl.initialize()
    return impl


async def get_auto_router_impl(
    api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig, policy: list[AccessRule]
) -> Any:
    from .datasets import DatasetIORouter
    from .eval_scoring import EvalRouter, ScoringRouter
    from .inference import InferenceRouter
    from .safety import SafetyRouter
    from .tool_runtime import ToolRuntimeRouter
    from .vector_io import VectorIORouter

    api_to_routers = {
        "vector_io": VectorIORouter,
        "inference": InferenceRouter,
        "safety": SafetyRouter,
        "datasetio": DatasetIORouter,
        "scoring": ScoringRouter,
        "eval": EvalRouter,
        "tool_runtime": ToolRuntimeRouter,
    }
    if api.value not in api_to_routers:
        raise ValueError(f"API {api.value} not found in router map")

    api_to_dep_impl = {}
    # TODO: move pass configs to routers instead
    if api == Api.inference:
        inference_ref = run_config.storage.stores.inference
        if not inference_ref:
            raise ValueError("storage.stores.inference must be configured in run config")

        inference_store = InferenceStore(
            reference=inference_ref,
            policy=policy,
        )
        await inference_store.initialize()
        api_to_dep_impl["store"] = inference_store
        api_to_dep_impl["telemetry_enabled"] = run_config.telemetry.enabled

    elif api == Api.vector_io:
        api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
    elif api == Api.safety:
        api_to_dep_impl["safety_config"] = run_config.safety

    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
    await impl.initialize()
    return impl
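Both factories are table dispatch: the `Api` value selects a class, any per-API extras are threaded in as keyword arguments, and the instance is `initialize()`d before being handed back. A stripped-down sketch of that dispatch shape; the stub class and argument names are invented for illustration:

import asyncio


class InferenceRouterStub:
    def __init__(self, routing_table, **deps):
        self.routing_table, self.deps = routing_table, deps

    async def initialize(self):
        pass  # real routers set up stores etc. here


ROUTERS = {"inference": InferenceRouterStub}


async def get_router(api: str, routing_table, **deps):
    if api not in ROUTERS:
        raise ValueError(f"API {api} not found in router map")
    impl = ROUTERS[api](routing_table, **deps)  # per-API extras become kwargs
    await impl.initialize()
    return impl


router = asyncio.run(get_router("inference", routing_table=object(), store=None))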
73
src/llama_stack/core/routers/datasets.py
Normal file
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core::routers")


class DatasetIORouter(DatasetIO):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing DatasetIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("DatasetIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("DatasetIORouter.shutdown")
        pass

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> None:
        logger.debug(
            f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
        )
        await self.routing_table.register_dataset(
            purpose=purpose,
            source=source,
            metadata=metadata,
            dataset_id=dataset_id,
        )

    async def iterrows(
        self,
        dataset_id: str,
        start_index: int | None = None,
        limit: int | None = None,
    ) -> PaginatedResponse:
        logger.debug(
            f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
        )
        provider = await self.routing_table.get_provider_impl(dataset_id)
        return await provider.iterrows(
            dataset_id=dataset_id,
            start_index=start_index,
            limit=limit,
        )

    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
        provider = await self.routing_table.get_provider_impl(dataset_id)
        return await provider.append_rows(
            dataset_id=dataset_id,
            rows=rows,
        )
155
src/llama_stack/core/routers/eval_scoring.py
Normal file
@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.scoring import (
    ScoreBatchResponse,
    ScoreResponse,
    Scoring,
    ScoringFnParams,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core::routers")


class ScoringRouter(Scoring):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing ScoringRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("ScoringRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("ScoringRouter.shutdown")
        pass

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: dict[str, ScoringFnParams | None] = None,
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
        res = {}
        for fn_identifier in scoring_functions.keys():
            provider = await self.routing_table.get_provider_impl(fn_identifier)
            score_response = await provider.score_batch(
                dataset_id=dataset_id,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        if save_results_dataset:
            raise NotImplementedError("Save results dataset not implemented yet")

        return ScoreBatchResponse(
            results=res,
        )

    async def score(
        self,
        input_rows: list[dict[str, Any]],
        scoring_functions: dict[str, ScoringFnParams | None] = None,
    ) -> ScoreResponse:
        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
        res = {}
        # look up and map each scoring function to its provider impl
        for fn_identifier in scoring_functions.keys():
            provider = await self.routing_table.get_provider_impl(fn_identifier)
            score_response = await provider.score(
                input_rows=input_rows,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        return ScoreResponse(results=res)


class EvalRouter(Eval):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing EvalRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("EvalRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("EvalRouter.shutdown")
        pass

    async def run_eval(
        self,
        benchmark_id: str,
        benchmark_config: BenchmarkConfig,
    ) -> Job:
        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
        provider = await self.routing_table.get_provider_impl(benchmark_id)
        return await provider.run_eval(
            benchmark_id=benchmark_id,
            benchmark_config=benchmark_config,
        )

    async def evaluate_rows(
        self,
        benchmark_id: str,
        input_rows: list[dict[str, Any]],
        scoring_functions: list[str],
        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
        provider = await self.routing_table.get_provider_impl(benchmark_id)
        return await provider.evaluate_rows(
            benchmark_id=benchmark_id,
            input_rows=input_rows,
            scoring_functions=scoring_functions,
            benchmark_config=benchmark_config,
        )

    async def job_status(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> Job:
        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
        provider = await self.routing_table.get_provider_impl(benchmark_id)
        return await provider.job_status(benchmark_id, job_id)

    async def job_cancel(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> None:
        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
        provider = await self.routing_table.get_provider_impl(benchmark_id)
        await provider.job_cancel(
            benchmark_id,
            job_id,
        )

    async def job_result(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> EvaluateResponse:
        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
        provider = await self.routing_table.get_provider_impl(benchmark_id)
        return await provider.job_result(
            benchmark_id,
            job_id,
        )
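`ScoringRouter.score` fans a single request out across providers, one call per scoring function, and merges the per-function `results` dicts back into one response. The merge logic in miniature; the provider stub and function names below are invented for illustration:

import asyncio


async def score_with_provider(fn_id: str, rows: list[dict]) -> dict[str, list[float]]:
    # stand-in for provider.score(...): returns {fn_id: per-row scores}
    return {fn_id: [1.0 for _ in rows]}


async def score(rows: list[dict], scoring_functions: list[str]) -> dict[str, list[float]]:
    res: dict[str, list[float]] = {}
    for fn_id in scoring_functions:
        partial = await score_with_provider(fn_id, rows)
        res.update(partial)  # merge per-function results, mirroring the router
    return res


print(asyncio.run(score([{"x": 1}, {"x": 2}], ["accuracy", "bleu"])))
# {'accuracy': [1.0, 1.0], 'bleu': [1.0, 1.0]}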
608
src/llama_stack/core/routers/inference.py
Normal file
608
src/llama_stack/core/routers/inference.py
Normal file
|
|
@ -0,0 +1,608 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import Body
|
||||
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
|
||||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
ListOpenAIChatCompletionResponse,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIChoiceLogprobs,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAICompletionWithInputMessages,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
Order,
|
||||
RerankResponse,
|
||||
StopReason,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse
|
||||
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class InferenceRouter(Inference):
|
||||
"""Routes to an provider based on the model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
store: InferenceStore | None = None,
|
||||
telemetry_enabled: bool = False,
|
||||
) -> None:
|
||||
logger.debug("Initializing InferenceRouter")
|
||||
self.routing_table = routing_table
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
self.store = store
|
||||
if self.telemetry_enabled:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("InferenceRouter.initialize")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("InferenceRouter.shutdown")
|
||||
if self.store:
|
||||
try:
|
||||
await self.store.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during InferenceStore shutdown: {e}")
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
provider_model_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
||||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
||||
def _construct_metrics(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> list[MetricEvent]:
|
||||
"""Constructs a list of MetricEvent objects containing token usage metrics.
|
||||
|
||||
Args:
|
||||
prompt_tokens: Number of tokens in the prompt
|
||||
completion_tokens: Number of tokens in the completion
|
||||
total_tokens: Total number of tokens used
|
||||
model: Model object containing model_id and provider_id
|
||||
|
||||
Returns:
|
||||
List of MetricEvent objects with token usage metrics
|
||||
"""
|
||||
span = get_current_span()
|
||||
if span is None:
|
||||
logger.warning("No span found for token usage metrics")
|
||||
return []
|
||||
|
||||
metrics = [
|
||||
("prompt_tokens", prompt_tokens),
|
||||
("completion_tokens", completion_tokens),
|
||||
("total_tokens", total_tokens),
|
||||
]
|
||||
metric_events = []
|
||||
for metric_name, value in metrics:
|
||||
metric_events.append(
|
||||
MetricEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
metric=metric_name,
|
||||
value=value,
|
||||
timestamp=datetime.now(UTC),
|
||||
unit="tokens",
|
||||
attributes={
|
||||
"model_id": model.model_id,
|
||||
"provider_id": model.provider_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
return metric_events
|
||||
|
||||
async def _compute_and_log_token_usage(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> list[MetricInResponse]:
|
||||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||
if self.telemetry_enabled:
|
||||
for metric in metrics:
|
||||
enqueue_event(metric)
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def _count_tokens(
|
||||
self,
|
||||
messages: list[Message] | InterleavedContent,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
) -> int | None:
|
||||
if not hasattr(self, "formatter") or self.formatter is None:
|
||||
return None
|
||||
|
||||
if isinstance(messages, list):
|
||||
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
|
||||
else:
|
||||
encoded = self.formatter.encode_content(messages)
|
||||
return len(encoded.tokens) if encoded and encoded.tokens else 0
|
||||
|
||||
async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
|
||||
"""takes a model id and gets model after ensuring that it is accessible and of the correct type"""
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(model_id)
|
||||
if model.model_type != expected_model_type:
|
||||
raise ModelTypeError(model_id, model.model_type, expected_model_type)
|
||||
return model
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
logger.debug(f"InferenceRouter.rerank: {model}")
|
||||
model_obj = await self._get_model(model, ModelType.rerank)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.rerank(
|
||||
model=model_obj.identifier,
|
||||
query=query,
|
||||
items=items,
|
||||
max_num_results=max_num_results,
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
|
||||
) -> OpenAICompletion:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
|
||||
)
|
||||
model_obj = await self._get_model(params.model, ModelType.llm)
|
||||
|
||||
# Update params with the resolved model identifier
|
||||
params.model = model_obj.identifier
|
||||
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if params.stream:
|
||||
return await provider.openai_completion(params)
|
||||
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
|
||||
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
|
||||
|
||||
response = await provider.openai_completion(params)
|
||||
if self.telemetry_enabled:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
model=model_obj,
|
||||
)
|
||||
for metric in metrics:
|
||||
enqueue_event(metric)
|
||||
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
    async def openai_chat_completion(
        self,
        params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        logger.debug(
            f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
        )
        model_obj = await self._get_model(params.model, ModelType.llm)

        # Use the OpenAI client for a bit of extra input validation without
        # exposing the OpenAI client itself as part of our API surface
        if params.tool_choice:
            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
            if params.tools is None:
                raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
        if params.tools:
            for tool in params.tools:
                TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)

        # Some providers make tool calls even when tool_choice is "none"
        # so just clear them both out to avoid unexpected tool calls
        if params.tool_choice == "none" and params.tools is not None:
            params.tool_choice = None
            params.tools = None

        # Update params with the resolved model identifier
        params.model = model_obj.identifier

        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
        if params.stream:
            response_stream = await provider.openai_chat_completion(params)

            # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
            # We need to add metrics to each chunk and store the final completion
            return self.stream_tokens_and_compute_metrics_openai_chat(
                response=response_stream,
                model=model_obj,
                messages=params.messages,
            )

        response = await self._nonstream_openai_chat_completion(provider, params)

        # Store the response with the ID that will be returned to the client
        if self.store:
            asyncio.create_task(self.store.store_chat_completion(response, params.messages))

        if self.telemetry_enabled:
            metrics = self._construct_metrics(
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                total_tokens=response.usage.total_tokens,
                model=model_obj,
            )
            for metric in metrics:
                enqueue_event(metric)
            # these metrics will show up in the client response.
            response.metrics = (
                metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
            )
        return response

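The TypeAdapter calls above give tool definitions a schema-level check before the request ever reaches a provider. A minimal self-contained sketch of the same pattern, using hypothetical stand-in models rather than the real OpenAIChatCompletionToolParam type:

from typing import Literal

from pydantic import BaseModel, TypeAdapter, ValidationError


class FunctionSpec(BaseModel):
    name: str
    parameters: dict = {}


class ToolParam(BaseModel):
    type: Literal["function"]
    function: FunctionSpec


adapter = TypeAdapter(ToolParam)
adapter.validate_python({"type": "function", "function": {"name": "get_weather"}})  # passes

try:
    adapter.validate_python({"type": "function"})  # missing "function" -> rejected early
except ValidationError as e:
    print(e.error_count(), "validation error(s)")

The payoff is that malformed tool specs fail fast at the router with a structured error, instead of surfacing as an opaque provider-side failure.
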
    async def openai_embeddings(
        self,
        params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)],
    ) -> OpenAIEmbeddingsResponse:
        logger.debug(
            f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
        )
        model_obj = await self._get_model(params.model, ModelType.embedding)

        # Update model to use resolved identifier
        params.model = model_obj.identifier

        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
        return await provider.openai_embeddings(params)

    async def list_chat_completions(
        self,
        after: str | None = None,
        limit: int | None = 20,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIChatCompletionResponse:
        if self.store:
            return await self.store.list_chat_completions(after, limit, model, order)
        raise NotImplementedError("List chat completions is not supported: inference store is not configured.")

    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
        if self.store:
            return await self.store.get_chat_completion(completion_id)
        raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")

    async def _nonstream_openai_chat_completion(
        self, provider: Inference, params: OpenAIChatCompletionRequestWithExtraBody
    ) -> OpenAIChatCompletion:
        response = await provider.openai_chat_completion(params)
        for choice in response.choices:
            # some providers return an empty list for no tool calls in non-streaming responses
            # but the OpenAI API returns None. So, set tool_calls to None if it's empty
            if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
                choice.message.tool_calls = None
        return response

    async def health(self) -> dict[str, HealthResponse]:
        health_statuses = {}
        timeout = 1  # 1 second timeout per provider health check
        for provider_id, impl in self.routing_table.impls_by_provider_id.items():
            try:
                # check if the provider has a health method
                if not hasattr(impl, "health"):
                    continue
                health = await asyncio.wait_for(impl.health(), timeout=timeout)
                health_statuses[provider_id] = health
            except TimeoutError:
                health_statuses[provider_id] = HealthResponse(
                    status=HealthStatus.ERROR,
                    message=f"Health check timed out after {timeout} seconds",
                )
            except NotImplementedError:
                health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
            except Exception as e:
                health_statuses[provider_id] = HealthResponse(
                    status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
                )
        return health_statuses

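One subtlety in the except TimeoutError clause above: catching the builtin TimeoutError only covers asyncio.wait_for timeouts on Python 3.11+, where asyncio.TimeoutError became an alias of the builtin. A tiny sketch, assuming Python 3.11 or newer:

import asyncio


async def slow_health_check() -> None:
    await asyncio.sleep(2)  # stand-in for a hung provider


async def main() -> None:
    try:
        await asyncio.wait_for(slow_health_check(), timeout=0.1)
    except TimeoutError:
        # On 3.11+ this catches the asyncio timeout; on older Pythons the
        # distinct asyncio.TimeoutError class would slip past this clause.
        print("provider health check timed out")


asyncio.run(main())
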
    async def stream_tokens_and_compute_metrics(
        self,
        response,
        prompt_tokens,
        model,
        tool_prompt_format: ToolPromptFormat | None = None,
    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
        completion_text = ""
        async for chunk in response:
            complete = False
            if hasattr(chunk, "event"):  # only ChatCompletions have .event
                if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                    if chunk.event.delta.type == "text":
                        completion_text += chunk.event.delta.text
                if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                    complete = True
                    completion_tokens = await self._count_tokens(
                        [
                            CompletionMessage(
                                content=completion_text,
                                stop_reason=StopReason.end_of_turn,
                            )
                        ],
                        tool_prompt_format=tool_prompt_format,
                    )
            else:
                if hasattr(chunk, "delta"):
                    completion_text += chunk.delta
                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry_enabled:
                    complete = True
                    completion_tokens = await self._count_tokens(completion_text)
            # if we are done receiving tokens
            if complete:
                total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)

                # Create a separate span for streaming completion metrics
                if self.telemetry_enabled:
                    # Log metrics in the new span context
                    completion_metrics = self._construct_metrics(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=total_tokens,
                        model=model,
                    )
                    for metric in completion_metrics:
                        if metric.metric in [
                            "completion_tokens",
                            "total_tokens",
                        ]:  # Only log completion and total tokens
                            enqueue_event(metric)

                    # Return metrics in response
                    async_metrics = [
                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
                    ]
                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
                else:
                    # Fallback if no telemetry
                    completion_metrics = self._construct_metrics(
                        prompt_tokens or 0,
                        completion_tokens or 0,
                        total_tokens,
                        model,
                    )
                    async_metrics = [
                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
                    ]
                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
            yield chunk

    async def count_tokens_and_compute_metrics(
        self,
        response: ChatCompletionResponse | CompletionResponse,
        prompt_tokens,
        model,
        tool_prompt_format: ToolPromptFormat | None = None,
    ):
        if isinstance(response, ChatCompletionResponse):
            content = [response.completion_message]
        else:
            content = response.content
        completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)

        # Create a separate span for completion metrics
        if self.telemetry_enabled:
            # Log metrics in the new span context
            completion_metrics = self._construct_metrics(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                model=model,
            )
            for metric in completion_metrics:
                if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
                    enqueue_event(metric)

            # Return metrics in response
            return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]

        # Fallback if no telemetry
        metrics = self._construct_metrics(
            prompt_tokens or 0,
            completion_tokens or 0,
            total_tokens,
            model,
        )
        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

    async def stream_tokens_and_compute_metrics_openai_chat(
        self,
        response: AsyncIterator[OpenAIChatCompletionChunk],
        model: Model,
        messages: list[OpenAIMessageParam] | None = None,
    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
        """Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
        id = None
        created = None
        choices_data: dict[int, dict[str, Any]] = {}

        try:
            async for chunk in response:
                # Skip None chunks
                if chunk is None:
                    continue

                # Capture ID and created timestamp from first chunk
                if id is None and chunk.id:
                    id = chunk.id
                if created is None and chunk.created:
                    created = chunk.created

                # Accumulate choice data for final assembly
                if chunk.choices:
                    for choice_delta in chunk.choices:
                        idx = choice_delta.index
                        if idx not in choices_data:
                            choices_data[idx] = {
                                "content_parts": [],
                                "tool_calls_builder": {},
                                "finish_reason": "stop",
                                "logprobs_content_parts": [],
                            }
                        current_choice_data = choices_data[idx]

                        if choice_delta.delta:
                            delta = choice_delta.delta
                            if delta.content:
                                current_choice_data["content_parts"].append(delta.content)
                            if delta.tool_calls:
                                for tool_call_delta in delta.tool_calls:
                                    tc_idx = tool_call_delta.index
                                    if tc_idx not in current_choice_data["tool_calls_builder"]:
                                        current_choice_data["tool_calls_builder"][tc_idx] = {
                                            "id": None,
                                            "type": "function",
                                            "function_name_parts": [],
                                            "function_arguments_parts": [],
                                        }
                                    builder = current_choice_data["tool_calls_builder"][tc_idx]
                                    if tool_call_delta.id:
                                        builder["id"] = tool_call_delta.id
                                    if tool_call_delta.type:
                                        builder["type"] = tool_call_delta.type
                                    if tool_call_delta.function:
                                        if tool_call_delta.function.name:
                                            builder["function_name_parts"].append(tool_call_delta.function.name)
                                        if tool_call_delta.function.arguments:
                                            builder["function_arguments_parts"].append(
                                                tool_call_delta.function.arguments
                                            )
                        if choice_delta.finish_reason:
                            current_choice_data["finish_reason"] = choice_delta.finish_reason
                        if choice_delta.logprobs and choice_delta.logprobs.content:
                            current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)

                # Compute metrics on final chunk
                if chunk.choices and chunk.choices[0].finish_reason:
                    completion_text = ""
                    for choice_data in choices_data.values():
                        completion_text += "".join(choice_data["content_parts"])

                    # Add metrics to the chunk
                    if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage:
                        metrics = self._construct_metrics(
                            prompt_tokens=chunk.usage.prompt_tokens,
                            completion_tokens=chunk.usage.completion_tokens,
                            total_tokens=chunk.usage.total_tokens,
                            model=model,
                        )
                        for metric in metrics:
                            enqueue_event(metric)

                yield chunk
        finally:
            # Store the final assembled completion
            if id and self.store and messages:
                assembled_choices: list[OpenAIChoice] = []
                for choice_idx, choice_data in choices_data.items():
                    content_str = "".join(choice_data["content_parts"])
                    assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
                    if choice_data["tool_calls_builder"]:
                        for tc_build_data in choice_data["tool_calls_builder"].values():
                            if tc_build_data["id"]:
                                func_name = "".join(tc_build_data["function_name_parts"])
                                func_args = "".join(tc_build_data["function_arguments_parts"])
                                assembled_tool_calls.append(
                                    OpenAIChatCompletionToolCall(
                                        id=tc_build_data["id"],
                                        type=tc_build_data["type"],
                                        function=OpenAIChatCompletionToolCallFunction(
                                            name=func_name, arguments=func_args
                                        ),
                                    )
                                )
                    message = OpenAIAssistantMessageParam(
                        role="assistant",
                        content=content_str if content_str else None,
                        tool_calls=assembled_tool_calls if assembled_tool_calls else None,
                    )
                    logprobs_content = choice_data["logprobs_content_parts"]
                    final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None

                    assembled_choices.append(
                        OpenAIChoice(
                            finish_reason=choice_data["finish_reason"],
                            index=choice_idx,
                            message=message,
                            logprobs=final_logprobs,
                        )
                    )

                final_response = OpenAIChatCompletion(
                    id=id,
                    choices=assembled_choices,
                    created=created or int(time.time()),
                    model=model.identifier,
                    object="chat.completion",
                )
                logger.debug(f"InferenceRouter.completion_response: {final_response}")
                asyncio.create_task(self.store.store_chat_completion(final_response, messages))
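The method above is easiest to understand as a fold over deltas: fragments arrive keyed by choice index and tool-call index, and joining them at the end reconstructs the non-streaming response shape. A condensed sketch with toy delta dicts standing in for the real OpenAIChatCompletionChunk objects:

from collections import defaultdict

deltas = [
    {"index": 0, "content": "Hel"},
    {"index": 0, "content": "lo"},
    {"index": 0, "tool_call": {"index": 0, "name": "get_", "arguments": '{"ci'}},
    {"index": 0, "tool_call": {"index": 0, "name": "time", "arguments": 'ty": "Paris"}'}},
]

content: dict[int, list[str]] = defaultdict(list)
tool_calls: dict[tuple[int, int], dict[str, list[str]]] = defaultdict(
    lambda: {"name": [], "arguments": []}
)

for d in deltas:
    if "content" in d:
        content[d["index"]].append(d["content"])
    if "tool_call" in d:
        tc = d["tool_call"]
        key = (d["index"], tc["index"])  # (choice index, tool-call index)
        tool_calls[key]["name"].append(tc["name"])
        tool_calls[key]["arguments"].append(tc["arguments"])

# Joining the accumulated fragments recovers the complete message.
assert "".join(content[0]) == "Hello"
assert "".join(tool_calls[(0, 0)]["name"]) == "get_time"
assert "".join(tool_calls[(0, 0)]["arguments"]) == '{"city": "Paris"}'
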
109
src/llama_stack/core/routers/safety.py
Normal file
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import SafetyConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core::routers")


class SafetyRouter(Safety):
    def __init__(
        self,
        routing_table: RoutingTable,
        safety_config: SafetyConfig | None = None,
    ) -> None:
        logger.debug("Initializing SafetyRouter")
        self.routing_table = routing_table
        self.safety_config = safety_config

    async def initialize(self) -> None:
        logger.debug("SafetyRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("SafetyRouter.shutdown")
        pass

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

    async def unregister_shield(self, identifier: str) -> None:
        logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
        return await self.routing_table.unregister_shield(identifier)

    async def run_shield(
        self,
        shield_id: str,
        messages: list[Message],
        params: dict[str, Any] | None = None,
    ) -> RunShieldResponse:
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
        provider = await self.routing_table.get_provider_impl(shield_id)
        return await provider.run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,
        )

    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
        list_shields_response = await self.routing_table.list_shields()
        shields = list_shields_response.data

        selected_shield: Shield | None = None
        provider_model: str | None = model

        if model:
            matches: list[Shield] = [s for s in shields if model == s.provider_resource_id]
            if not matches:
                raise ValueError(
                    f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in shields]}"
                )
            if len(matches) > 1:
                raise ValueError(
                    f"Multiple shields associated with provider_resource id {model}: matched shields {[s.identifier for s in matches]}"
                )
            selected_shield = matches[0]
        else:
            default_shield_id = self.safety_config.default_shield_id if self.safety_config else None
            if not default_shield_id:
                raise ValueError(
                    "No moderation model specified and no default_shield_id configured in safety config: select model "
                    f"from {[s.provider_resource_id or s.identifier for s in shields]}"
                )

            selected_shield = next((s for s in shields if s.identifier == default_shield_id), None)
            if selected_shield is None:
                raise ValueError(
                    f"Default moderation model not found. Choose from {[s.provider_resource_id or s.identifier for s in shields]}."
                )

            provider_model = selected_shield.provider_resource_id

        shield_id = selected_shield.identifier
        logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
        provider = await self.routing_table.get_provider_impl(shield_id)

        response = await provider.run_moderation(
            input=input,
            model=provider_model,
        )

        return response
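The shield-selection precedence in run_moderation is: an explicit model must match exactly one shield's provider_resource_id, otherwise fall back to the configured default shield. A hedged sketch of that resolution logic, with a plain dataclass standing in for the real Shield objects:

from dataclasses import dataclass


@dataclass
class ShieldInfo:
    identifier: str
    provider_resource_id: str


def select_shield(shields: list[ShieldInfo], model: str | None, default_id: str | None) -> ShieldInfo:
    if model:
        # Explicit model: require exactly one shield backed by that resource.
        matches = [s for s in shields if s.provider_resource_id == model]
        if len(matches) != 1:
            raise ValueError(f"expected exactly one shield for {model!r}, got {len(matches)}")
        return matches[0]
    if not default_id:
        raise ValueError("no model given and no default shield configured")
    for s in shields:
        if s.identifier == default_id:
            return s
    raise ValueError(f"default shield {default_id!r} not registered")


shields = [ShieldInfo("llama-guard", "meta-llama/Llama-Guard-3-8B")]
assert select_shield(shields, None, "llama-guard").identifier == "llama-guard"
assert select_shield(shields, "meta-llama/Llama-Guard-3-8B", None).identifier == "llama-guard"
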
91
src/llama_stack/core/routers/tool_runtime.py
Normal file
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import (
    URL,
    InterleavedContent,
)
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    RAGDocument,
    RAGQueryConfig,
    RAGQueryResult,
    RAGToolRuntime,
    ToolRuntime,
)
from llama_stack.log import get_logger

from ..routing_tables.toolgroups import ToolGroupsRoutingTable

logger = get_logger(name=__name__, category="core::routers")


class ToolRuntimeRouter(ToolRuntime):
    class RagToolImpl(RAGToolRuntime):
        def __init__(
            self,
            routing_table: ToolGroupsRoutingTable,
        ) -> None:
            logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
            self.routing_table = routing_table

        async def query(
            self,
            content: InterleavedContent,
            vector_store_ids: list[str],
            query_config: RAGQueryConfig | None = None,
        ) -> RAGQueryResult:
            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
            provider = await self.routing_table.get_provider_impl("knowledge_search")
            return await provider.query(content, vector_store_ids, query_config)

        async def insert(
            self,
            documents: list[RAGDocument],
            vector_store_id: str,
            chunk_size_in_tokens: int = 512,
        ) -> None:
            logger.debug(
                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
            )
            provider = await self.routing_table.get_provider_impl("insert_into_memory")
            return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)

    def __init__(
        self,
        routing_table: ToolGroupsRoutingTable,
    ) -> None:
        logger.debug("Initializing ToolRuntimeRouter")
        self.routing_table = routing_table

        # HACK ALERT this should be in sync with "get_all_api_endpoints()"
        self.rag_tool = self.RagToolImpl(routing_table)
        for method in ("query", "insert"):
            setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

    async def initialize(self) -> None:
        logger.debug("ToolRuntimeRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("ToolRuntimeRouter.shutdown")
        pass

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
        provider = await self.routing_table.get_provider_impl(tool_name)
        return await provider.invoke_tool(
            tool_name=tool_name,
            kwargs=kwargs,
        )

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolDefsResponse:
        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
        return await self.routing_table.list_tools(tool_group_id)
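The loop flagged HACK ALERT above relies on a Python quirk: setattr with a dotted name does not create a nested attribute, it creates a single attribute whose name literally contains a dot, reachable only through getattr with the same string. A tiny demonstration:

class Router:
    pass


r = Router()
setattr(r, "rag_tool.query", lambda: "ok")

# Works, because the lookup uses the exact dotted string as the name:
assert getattr(r, "rag_tool.query")() == "ok"
# r.rag_tool.query would raise AttributeError here: no attribute "rag_tool" exists.

This is presumably why the comment ties the loop to the endpoint table: request dispatch looks attributes up by the dotted endpoint name string, so the aliases only work as long as both sides agree on those exact strings.
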
442
src/llama_stack/core/routers/vector_io.py
Normal file
@@ -0,0 +1,442 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import uuid
from typing import Annotated, Any

from fastapi import Body

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
    Chunk,
    OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
    OpenAICreateVectorStoreRequestWithExtraBody,
    QueryChunksResponse,
    SearchRankingOptions,
    VectorIO,
    VectorStoreChunkingStrategy,
    VectorStoreDeleteResponse,
    VectorStoreFileBatchObject,
    VectorStoreFileContentsResponse,
    VectorStoreFileDeleteResponse,
    VectorStoreFileObject,
    VectorStoreFilesListInBatchResponse,
    VectorStoreFileStatus,
    VectorStoreListResponse,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable

logger = get_logger(name=__name__, category="core::routers")


class VectorIORouter(VectorIO):
    """Routes to a provider based on the vector db identifier"""

    def __init__(
        self,
        routing_table: RoutingTable,
        vector_stores_config: VectorStoresConfig | None = None,
    ) -> None:
        logger.debug("Initializing VectorIORouter")
        self.routing_table = routing_table
        self.vector_stores_config = vector_stores_config

    async def initialize(self) -> None:
        logger.debug("VectorIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("VectorIORouter.shutdown")
        pass

    async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int:
        """Get the embedding dimension for a specific embedding model."""
        all_models = await self.routing_table.get_all_with_type("model")

        for model in all_models:
            if model.identifier == embedding_model_id and model.model_type == ModelType.embedding:
                dimension = model.metadata.get("embedding_dimension")
                if dimension is None:
                    raise ValueError(f"Embedding model '{embedding_model_id}' has no embedding_dimension in metadata")
                return int(dimension)

        raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model")

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        doc_ids = [chunk.document_id for chunk in chunks[:3]]
        logger.debug(
            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, "
            f"ttl_seconds={ttl_seconds}, chunk_ids={doc_ids}{' and more...' if len(chunks) > 3 else ''}"
        )
        provider = await self.routing_table.get_provider_impl(vector_db_id)
        return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
        provider = await self.routing_table.get_provider_impl(vector_db_id)
        return await provider.query_chunks(vector_db_id, query, params)

    # OpenAI Vector Stores API endpoints
    async def openai_create_vector_store(
        self,
        params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
    ) -> VectorStoreObject:
        # Extract llama-stack-specific parameters from extra_body
        extra = params.model_extra or {}
        embedding_model = extra.get("embedding_model")
        embedding_dimension = extra.get("embedding_dimension")
        provider_id = extra.get("provider_id")

        # Use default embedding model if not specified
        if (
            embedding_model is None
            and self.vector_stores_config
            and self.vector_stores_config.default_embedding_model is not None
        ):
            # Construct the full model ID with provider prefix
            embedding_provider_id = self.vector_stores_config.default_embedding_model.provider_id
            model_id = self.vector_stores_config.default_embedding_model.model_id
            embedding_model = f"{embedding_provider_id}/{model_id}"

        if embedding_model is not None and embedding_dimension is None:
            embedding_dimension = await self._get_embedding_model_dimension(embedding_model)

        # Auto-select provider if not specified
        if provider_id is None:
            num_providers = len(self.routing_table.impls_by_provider_id)
            if num_providers == 0:
                raise ValueError("No vector_io providers available")
            if num_providers > 1:
                available_providers = list(self.routing_table.impls_by_provider_id.keys())
                # Use default configured provider
                if self.vector_stores_config and self.vector_stores_config.default_provider_id:
                    default_provider = self.vector_stores_config.default_provider_id
                    if default_provider in available_providers:
                        provider_id = default_provider
                        logger.debug(f"Using configured default vector store provider: {provider_id}")
                    else:
                        raise ValueError(
                            f"Configured default vector store provider '{default_provider}' not found. "
                            f"Available providers: {available_providers}"
                        )
                else:
                    raise ValueError(
                        f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
                        f"Available providers: {available_providers}"
                    )
            else:
                provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]

        vector_store_id = f"vs_{uuid.uuid4()}"
        registered_vector_store = await self.routing_table.register_vector_store(
            vector_store_id=vector_store_id,
            embedding_model=embedding_model,
            embedding_dimension=embedding_dimension,
            provider_id=provider_id,
            provider_vector_store_id=vector_store_id,
            vector_store_name=params.name,
        )
        provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier)

        # Update model_extra with registered values so provider uses the already-registered vector_store
        if params.model_extra is None:
            params.model_extra = {}
        params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id
        params.model_extra["provider_id"] = registered_vector_store.provider_id
        if embedding_model is not None:
            params.model_extra["embedding_model"] = embedding_model
        if embedding_dimension is not None:
            params.model_extra["embedding_dimension"] = embedding_dimension

        return await provider.openai_create_vector_store(params)

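The extra-body extraction at the top of openai_create_vector_store depends on pydantic v2 retaining unknown fields when a model allows extras: they land in .model_extra instead of being dropped. A minimal sketch with a hypothetical stand-in for the request model:

from pydantic import BaseModel, ConfigDict


class CreateVectorStoreRequest(BaseModel):
    # extra="allow" keeps unrecognized fields instead of rejecting them
    model_config = ConfigDict(extra="allow")
    name: str | None = None


req = CreateVectorStoreRequest(name="docs", embedding_model="all-MiniLM-L6-v2")
extra = req.model_extra or {}
assert extra["embedding_model"] == "all-MiniLM-L6-v2"
assert req.name == "docs"

This keeps the wire format OpenAI-compatible while still letting llama-stack-specific knobs (embedding_model, embedding_dimension, provider_id) ride along in the same request body.
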
    async def openai_list_vector_stores(
        self,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
        logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
        # Route to default provider for now - could aggregate from all providers in the future
        # call retrieve on each vector store to get its latest object from the owning provider
        vector_stores = await self.routing_table.get_all_with_type("vector_store")
        all_stores = []
        for vector_store in vector_stores:
            try:
                provider = await self.routing_table.get_provider_impl(vector_store.identifier)
                vector_store = await provider.openai_retrieve_vector_store(vector_store.identifier)
                all_stores.append(vector_store)
            except Exception as e:
                logger.error(f"Error retrieving vector store {vector_store.identifier}: {e}")
                continue

        # Sort by created_at
        reverse_order = order == "desc"
        all_stores.sort(key=lambda x: x.created_at, reverse=reverse_order)

        # Apply cursor-based pagination
        if after:
            after_index = next((i for i, store in enumerate(all_stores) if store.id == after), -1)
            if after_index >= 0:
                all_stores = all_stores[after_index + 1 :]

        if before:
            before_index = next(
                (i for i, store in enumerate(all_stores) if store.id == before),
                len(all_stores),
            )
            all_stores = all_stores[:before_index]

        # Apply limit
        limited_stores = all_stores[:limit]

        # Determine pagination info
        has_more = len(all_stores) > limit
        first_id = limited_stores[0].id if limited_stores else None
        last_id = limited_stores[-1].id if limited_stores else None

        return VectorStoreListResponse(
            data=limited_stores,
            has_more=has_more,
            first_id=first_id,
            last_id=last_id,
        )

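The listing endpoint above implements OpenAI-style cursor pagination in memory: slice off everything up to and including the after cursor, cut at the before cursor, then take limit items and report whether more remain. The same slice logic isolated as a sketch (ids assumed pre-sorted newest-first):

def paginate(ids: list[str], after: str | None, before: str | None, limit: int) -> tuple[list[str], bool]:
    if after and after in ids:
        ids = ids[ids.index(after) + 1 :]  # drop the cursor item and everything before it
    if before and before in ids:
        ids = ids[: ids.index(before)]  # keep only items strictly before the cursor
    page = ids[:limit]
    return page, len(ids) > limit  # has_more: items remained beyond this page


page, has_more = paginate(["vs_3", "vs_2", "vs_1"], after="vs_3", before=None, limit=1)
assert page == ["vs_2"] and has_more is True
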
    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store(vector_store_id)

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store(
            vector_store_id=vector_store_id,
            name=name,
            expires_after=expires_after,
            metadata=metadata,
        )

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
        return await self.routing_table.openai_delete_vector_store(vector_store_id)

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: SearchRankingOptions | None = None,
        rewrite_query: bool | None = False,
        search_mode: str | None = "vector",
    ) -> VectorStoreSearchResponsePage:
        logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_search_vector_store(
            vector_store_id=vector_store_id,
            query=query,
            filters=filters,
            max_num_results=max_num_results,
            ranking_options=ranking_options,
            rewrite_query=rewrite_query,
            search_mode=search_mode,
        )

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_attach_file_to_vector_store(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
        )

    async def openai_list_files_in_vector_store(
        self,
        vector_store_id: str,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
        filter: VectorStoreFileStatus | None = None,
    ) -> list[VectorStoreFileObject]:
        logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_list_files_in_vector_store(
            vector_store_id=vector_store_id,
            limit=limit,
            order=order,
            after=after,
            before=before,
            filter=filter,
        )

    async def openai_retrieve_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileObject:
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_retrieve_vector_store_file_contents(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileContentsResponse:
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_contents(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_update_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any],
    ) -> VectorStoreFileObject:
        logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
        )

    async def openai_delete_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileDeleteResponse:
        logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_delete_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def health(self) -> dict[str, HealthResponse]:
        health_statuses = {}
        timeout = 1  # 1 second timeout per provider health check
        for provider_id, impl in self.routing_table.impls_by_provider_id.items():
            try:
                # check if the provider has a health method
                if not hasattr(impl, "health"):
                    continue
                health = await asyncio.wait_for(impl.health(), timeout=timeout)
                health_statuses[provider_id] = health
            except TimeoutError:
                health_statuses[provider_id] = HealthResponse(
                    status=HealthStatus.ERROR,
                    message=f"Health check timed out after {timeout} seconds",
                )
            except NotImplementedError:
                health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
            except Exception as e:
                health_statuses[provider_id] = HealthResponse(
                    status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
                )
        return health_statuses

    async def openai_create_vector_store_file_batch(
        self,
        vector_store_id: str,
        params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)],
    ) -> VectorStoreFileBatchObject:
        logger.debug(
            f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(params.file_ids)} files"
        )
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_create_vector_store_file_batch(vector_store_id, params)

    async def openai_retrieve_vector_store_file_batch(
        self,
        batch_id: str,
        vector_store_id: str,
    ) -> VectorStoreFileBatchObject:
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_batch(
            batch_id=batch_id,
            vector_store_id=vector_store_id,
        )

    async def openai_list_files_in_vector_store_file_batch(
        self,
        batch_id: str,
        vector_store_id: str,
        after: str | None = None,
        before: str | None = None,
        filter: str | None = None,
        limit: int | None = 20,
        order: str | None = "desc",
    ) -> VectorStoreFilesListInBatchResponse:
        logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_list_files_in_vector_store_file_batch(
            batch_id=batch_id,
            vector_store_id=vector_store_id,
            after=after,
            before=before,
            filter=filter,
            limit=limit,
            order=order,
        )

    async def openai_cancel_vector_store_file_batch(
        self,
        batch_id: str,
        vector_store_id: str,
    ) -> VectorStoreFileBatchObject:
        logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_cancel_vector_store_file_batch(
            batch_id=batch_id,
            vector_store_id=vector_store_id,
        )
5
src/llama_stack/core/routing_tables/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
62
src/llama_stack/core/routing_tables/benchmarks.py
Normal file
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.core.datatypes import (
    BenchmarkWithOwner,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core::routing_tables")


class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
    async def list_benchmarks(self) -> ListBenchmarksResponse:
        return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))

    async def get_benchmark(self, benchmark_id: str) -> Benchmark:
        benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
        if benchmark is None:
            raise ValueError(f"Benchmark '{benchmark_id}' not found")
        return benchmark

    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        metadata: dict[str, Any] | None = None,
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
    ) -> None:
        if metadata is None:
            metadata = {}
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if provider_benchmark_id is None:
            provider_benchmark_id = benchmark_id
        benchmark = BenchmarkWithOwner(
            identifier=benchmark_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata,
            provider_id=provider_id,
            provider_resource_id=provider_benchmark_id,
        )
        await self.register_object(benchmark)

    async def unregister_benchmark(self, benchmark_id: str) -> None:
        existing_benchmark = await self.get_benchmark(benchmark_id)
        await self.unregister_object(existing_benchmark)
254
src/llama_stack/core/routing_tables/common.py
Normal file
@@ -0,0 +1,254 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import (
    AccessRule,
    RoutableObject,
    RoutableObjectWithProvider,
    RoutedProtocol,
    ScoringFnWithOwner,
)
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable

logger = get_logger(name=__name__, category="core::routing_tables")

def get_impl_api(p: Any) -> Api:
    return p.__provider_spec__.api


# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
    api = get_impl_api(p)

    assert obj.provider_id != "remote", "Remote provider should not be registered"

    if api == Api.inference:
        return await p.register_model(obj)
    elif api == Api.safety:
        return await p.register_shield(obj)
    elif api == Api.vector_io:
        return await p.register_vector_store(obj)
    elif api == Api.datasetio:
        return await p.register_dataset(obj)
    elif api == Api.scoring:
        return await p.register_scoring_function(obj)
    elif api == Api.eval:
        return await p.register_benchmark(obj)
    elif api == Api.tool_runtime:
        return await p.register_toolgroup(obj)
    else:
        raise ValueError(f"Unknown API {api} for registering object with provider")


async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
    api = get_impl_api(p)
    if api == Api.vector_io:
        return await p.unregister_vector_store(obj.identifier)
    elif api == Api.inference:
        return await p.unregister_model(obj.identifier)
    elif api == Api.safety:
        return await p.unregister_shield(obj.identifier)
    elif api == Api.datasetio:
        return await p.unregister_dataset(obj.identifier)
    elif api == Api.eval:
        return await p.unregister_benchmark(obj.identifier)
    elif api == Api.scoring:
        return await p.unregister_scoring_function(obj.identifier)
    elif api == Api.tool_runtime:
        return await p.unregister_toolgroup(obj.identifier)
    else:
        raise ValueError(f"Unregister not supported for {api}")


Registry = dict[str, list[RoutableObjectWithProvider]]

class CommonRoutingTableImpl(RoutingTable):
    def __init__(
        self,
        impls_by_provider_id: dict[str, RoutedProtocol],
        dist_registry: DistributionRegistry,
        policy: list[AccessRule],
    ) -> None:
        self.impls_by_provider_id = impls_by_provider_id
        self.dist_registry = dist_registry
        self.policy = policy

    async def initialize(self) -> None:
        async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
            for obj in objs:
                if cls is None:
                    obj.provider_id = provider_id
                else:
                    # Create a copy of the model data and explicitly set provider_id
                    model_data = obj.model_dump()
                    model_data["provider_id"] = provider_id
                    obj = cls(**model_data)
                await self.dist_registry.register(obj)

        # Register all objects from providers
        for pid, p in self.impls_by_provider_id.items():
            api = get_impl_api(p)
            if api == Api.inference:
                p.model_store = self
            elif api == Api.safety:
                p.shield_store = self
            elif api == Api.vector_io:
                p.vector_store_store = self
            elif api == Api.datasetio:
                p.dataset_store = self
            elif api == Api.scoring:
                p.scoring_function_store = self
                scoring_functions = await p.list_scoring_functions()
                await add_objects(scoring_functions, pid, ScoringFnWithOwner)
            elif api == Api.eval:
                p.benchmark_store = self
            elif api == Api.tool_runtime:
                p.tool_store = self

    async def shutdown(self) -> None:
        for p in self.impls_by_provider_id.values():
            await p.shutdown()

    async def refresh(self) -> None:
        pass

    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        from .benchmarks import BenchmarksRoutingTable
        from .datasets import DatasetsRoutingTable
        from .models import ModelsRoutingTable
        from .scoring_functions import ScoringFunctionsRoutingTable
        from .shields import ShieldsRoutingTable
        from .toolgroups import ToolGroupsRoutingTable
        from .vector_stores import VectorStoresRoutingTable

        def apiname_object():
            if isinstance(self, ModelsRoutingTable):
                return ("Inference", "model")
            elif isinstance(self, ShieldsRoutingTable):
                return ("Safety", "shield")
            elif isinstance(self, VectorStoresRoutingTable):
                return ("VectorIO", "vector_store")
            elif isinstance(self, DatasetsRoutingTable):
                return ("DatasetIO", "dataset")
            elif isinstance(self, ScoringFunctionsRoutingTable):
                return ("Scoring", "scoring_function")
            elif isinstance(self, BenchmarksRoutingTable):
                return ("Eval", "benchmark")
            elif isinstance(self, ToolGroupsRoutingTable):
                return ("ToolGroups", "tool_group")
            else:
                raise ValueError("Unknown routing table type")

        apiname, objtype = apiname_object()

        # Get objects from disk registry
        obj = self.dist_registry.get_cached(objtype, routing_key)
        if not obj:
            provider_ids = list(self.impls_by_provider_id.keys())
            if len(provider_ids) > 1:
                provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
            else:
                provider_ids_str = f"provider: `{provider_ids[0]}`"
            raise ValueError(
                f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
            )

        if not provider_id or provider_id == obj.provider_id:
            return self.impls_by_provider_id[obj.provider_id]

        raise ValueError(f"Provider not found for `{routing_key}`")

    async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        # Get from disk registry
        obj = await self.dist_registry.get(type, identifier)
        if not obj:
            return None

        # Check if user has permission to access this object
        if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
            logger.debug(f"Access denied to {type} '{identifier}'")
            return None

        return obj

    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
        user = get_authenticated_user()
        if not is_action_allowed(self.policy, "delete", obj, user):
            raise AccessDeniedError("delete", obj, user)
        await self.dist_registry.delete(obj.type, obj.identifier)
        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
        # if provider_id is not specified, pick an arbitrary one from existing entries
        if not obj.provider_id and len(self.impls_by_provider_id) > 0:
            obj.provider_id = list(self.impls_by_provider_id.keys())[0]

        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")

        p = self.impls_by_provider_id[obj.provider_id]

        # If object supports access control but no attributes set, use creator's attributes
        creator = get_authenticated_user()
        if not is_action_allowed(self.policy, "create", obj, creator):
            raise AccessDeniedError("create", obj, creator)
        if creator:
            obj.owner = creator
            logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")

        registered_obj = await register_object_with_provider(obj, p)
        # TODO: This needs to be fixed for all APIs once they return the registered object
        if obj.type == ResourceType.model.value:
            await self.dist_registry.register(registered_obj)
            return registered_obj
        else:
            await self.dist_registry.register(obj)
            return obj

    async def assert_action_allowed(
        self,
        action: Action,
        type: str,
        identifier: str,
    ) -> None:
        """Fetch a registered object by type/identifier and enforce the given action permission."""
        obj = await self.get_object_by_identifier(type, identifier)
        if obj is None:
            raise ValueError(f"{type.capitalize()} '{identifier}' not found")
        user = get_authenticated_user()
        if not is_action_allowed(self.policy, action, obj, user):
            raise AccessDeniedError(action, obj, user)

    async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
        objs = await self.dist_registry.get_all()
        filtered_objs = [obj for obj in objs if obj.type == type]

        # Apply attribute-based access control filtering
        if filtered_objs:
            filtered_objs = [
                obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
            ]

        return filtered_objs


async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
    model = await routing_table.get_object_by_identifier("model", model_id)
    if not model:
        raise ModelNotFoundError(model_id)
    return model
91
src/llama_stack/core/routing_tables/datasets.py
Normal file
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from typing import Any

from llama_stack.apis.common.errors import DatasetNotFoundError
from llama_stack.apis.datasets import (
    Dataset,
    DatasetPurpose,
    Datasets,
    DatasetType,
    DataSource,
    ListDatasetsResponse,
    RowsDataSource,
    URIDataSource,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.core.datatypes import (
    DatasetWithOwner,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core::routing_tables")


class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    async def list_datasets(self) -> ListDatasetsResponse:
        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))

    async def get_dataset(self, dataset_id: str) -> Dataset:
        dataset = await self.get_object_by_identifier("dataset", dataset_id)
        if dataset is None:
            raise DatasetNotFoundError(dataset_id)
        return dataset

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> Dataset:
        if isinstance(source, dict):
            if source["type"] == "uri":
                source = URIDataSource.parse_obj(source)
            elif source["type"] == "rows":
                source = RowsDataSource.parse_obj(source)

        if not dataset_id:
            dataset_id = f"dataset-{str(uuid.uuid4())}"

        provider_dataset_id = dataset_id

        # infer provider from source
        if metadata and metadata.get("provider_id"):
            provider_id = metadata.get("provider_id")  # pass through from nvidia datasetio
        elif source.type == DatasetType.rows.value:
            provider_id = "localfs"
        elif source.type == DatasetType.uri.value:
            # infer provider from uri
            if source.uri.startswith("huggingface"):
                provider_id = "huggingface"
            else:
                provider_id = "localfs"
        else:
            raise ValueError(f"Unknown data source type: {source.type}")

        if metadata is None:
            metadata = {}

        dataset = DatasetWithOwner(
            identifier=dataset_id,
            provider_resource_id=provider_dataset_id,
            provider_id=provider_id,
            purpose=purpose,
            source=source,
            metadata=metadata,
        )

        await self.register_object(dataset)
        return dataset

    async def unregister_dataset(self, dataset_id: str) -> None:
        dataset = await self.get_dataset(dataset_id)
        await self.unregister_object(dataset)
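The provider inference above keys entirely off the data source: inline rows go to the local filesystem provider, and URIs are routed by scheme prefix, with an explicit provider_id in metadata overriding both. The same decision table isolated as a sketch:

def infer_dataset_provider(source_type: str, uri: str | None = None) -> str:
    if source_type == "rows":
        return "localfs"  # inline rows are stored locally
    if source_type == "uri":
        assert uri is not None
        return "huggingface" if uri.startswith("huggingface") else "localfs"
    raise ValueError(f"Unknown data source type: {source_type}")


assert infer_dataset_provider("uri", "huggingface://datasets/foo") == "huggingface"
assert infer_dataset_provider("rows") == "localfs"
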
163
src/llama_stack/core/routing_tables/models.py
Normal file
@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
from typing import Any

from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.core.datatypes import (
    ModelWithOwner,
    RegistryEntrySource,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl, lookup_model

logger = get_logger(name=__name__, category="core::routing_tables")


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    listed_providers: set[str] = set()

    async def refresh(self) -> None:
        for provider_id, provider in self.impls_by_provider_id.items():
            refresh = await provider.should_refresh_models()
            refresh = refresh or provider_id not in self.listed_providers
            if not refresh:
                continue

            try:
                models = await provider.list_models()
            except Exception as e:
                logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
                continue

            self.listed_providers.add(provider_id)
            if models is None:
                continue

            await self.update_registered_models(provider_id, models)

    async def list_models(self) -> ListModelsResponse:
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        models = await self.get_all_with_type("model")
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=int(time.time()),
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        return await lookup_model(self, model_id)

    async def get_provider_impl(self, model_id: str) -> Any:
        model = await lookup_model(self, model_id)
        if model.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider {model.provider_id} not found in the routing table")
        return self.impls_by_provider_id[model.provider_id]

    async def has_model(self, model_id: str) -> bool:
        """
        Check if a model exists in the routing table.

        :param model_id: The model identifier to check
        :return: True if the model exists, False otherwise
        """
        try:
            await lookup_model(self, model_id)
            return True
        except ModelNotFoundError:
            return False

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
                    "Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
                )

        provider_model_id = provider_model_id or model_id
        metadata = metadata or {}
        model_type = model_type or ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
|
||||
identifier = f"{provider_id}/{provider_model_id}"
|
||||
model = ModelWithOwner(
|
||||
identifier=identifier,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
model_type=model_type,
|
||||
source=RegistryEntrySource.via_register_api,
|
||||
)
|
||||
registered_model = await self.register_object(model)
|
||||
return registered_model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
existing_model = await self.get_model(model_id)
|
||||
if existing_model is None:
|
||||
raise ModelNotFoundError(model_id)
|
||||
await self.unregister_object(existing_model)
|
||||
|
||||
async def update_registered_models(
|
||||
self,
|
||||
provider_id: str,
|
||||
models: list[Model],
|
||||
) -> None:
|
||||
existing_models = await self.get_all_with_type("model")
|
||||
|
||||
# we may have an alias for the model registered by the user (or during initialization
|
||||
# from run.yaml) that we need to keep track of
|
||||
model_ids = {}
|
||||
for model in existing_models:
|
||||
if model.provider_id != provider_id:
|
||||
continue
|
||||
if model.source == RegistryEntrySource.via_register_api:
|
||||
model_ids[model.provider_resource_id] = model.identifier
|
||||
continue
|
||||
|
||||
logger.debug(f"unregistering model {model.identifier}")
|
||||
await self.unregister_object(model)
|
||||
|
||||
for model in models:
|
||||
if model.provider_resource_id in model_ids:
|
||||
# avoid overwriting a non-provider-registered model entry
|
||||
continue
|
||||
|
||||
if model.identifier == model.provider_resource_id:
|
||||
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
||||
|
||||
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
|
||||
await self.register_object(
|
||||
ModelWithOwner(
|
||||
identifier=model.identifier,
|
||||
provider_resource_id=model.provider_resource_id,
|
||||
provider_id=provider_id,
|
||||
metadata=model.metadata,
|
||||
model_type=model.model_type,
|
||||
source=RegistryEntrySource.listed_from_provider,
|
||||
)
|
||||
)
|
||||
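A hedged sketch of the identifier convention register_model enforces: the registry key is always "<provider_id>/<provider_model_id>". Here `table` is an assumed ModelsRoutingTable whose only configured provider is "ollama".

async def demo(table):
    model = await table.register_model(model_id="llama3.2:3b")
    # with a single provider, provider_id is inferred and
    # provider_model_id defaults to model_id
    assert model.identifier == "ollama/llama3.2:3b"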
66
src/llama_stack/core/routing_tables/scoring_functions.py
Normal file
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
    ListScoringFunctionsResponse,
    ScoringFn,
    ScoringFnParams,
    ScoringFunctions,
)
from llama_stack.core.datatypes import (
    ScoringFnWithOwner,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core::routing_tables")


class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

    async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
        if scoring_fn is None:
            raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
        return scoring_fn

    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        if provider_scoring_fn_id is None:
            provider_scoring_fn_id = scoring_fn_id
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        scoring_fn = ScoringFnWithOwner(
            identifier=scoring_fn_id,
            description=description,
            return_type=return_type,
            provider_resource_id=provider_scoring_fn_id,
            provider_id=provider_id,
            params=params,
        )
        scoring_fn.provider_id = provider_id
        await self.register_object(scoring_fn)

    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
        existing_scoring_fn = await self.get_scoring_function(scoring_fn_id)
        await self.unregister_object(existing_scoring_fn)
61
src/llama_stack/core/routing_tables/shields.py
Normal file
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.core.datatypes import (
    ShieldWithOwner,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core::routing_tables")


class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    async def list_shields(self) -> ListShieldsResponse:
        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))

    async def get_shield(self, identifier: str) -> Shield:
        shield = await self.get_object_by_identifier("shield", identifier)
        if shield is None:
            raise ValueError(f"Shield '{identifier}' not found")
        return shield

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        if provider_shield_id is None:
            provider_shield_id = shield_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this shield type
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if params is None:
            params = {}
        shield = ShieldWithOwner(
            identifier=shield_id,
            provider_resource_id=provider_shield_id,
            provider_id=provider_id,
            params=params,
        )
        await self.register_object(shield)
        return shield

    async def unregister_shield(self, identifier: str) -> None:
        existing_shield = await self.get_shield(identifier)
        await self.unregister_object(existing_shield)
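The same single-provider default used by scoring functions above also applies to shields; a hedged sketch, assuming exactly one configured safety provider with id "llama-guard":

async def demo(table):
    shield = await table.register_shield(shield_id="content-safety")
    assert shield.provider_id == "llama-guard"  # inferred: only provider
    assert shield.provider_resource_id == "content-safety"  # defaults to shield_id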
129
src/llama_stack/core/routing_tables/toolgroups.py
Normal file
@@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ToolGroupNotFoundError
from llama_stack.apis.tools import ListToolDefsResponse, ListToolGroupsResponse, ToolDef, ToolGroup, ToolGroups
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core::routing_tables")


def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
    # handle the funny case like "builtin::rag/knowledge_search"
    parts = toolgroup_name_with_maybe_tool_name.split("/")
    if len(parts) == 2:
        return parts[0]
    else:
        return None


class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    toolgroups_to_tools: dict[str, list[ToolDef]] = {}
    tool_to_toolgroup: dict[str, str] = {}

    # overridden
    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
        # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?

        toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key)
        if toolgroup_id:
            routing_key = toolgroup_id

        if routing_key in self.tool_to_toolgroup:
            routing_key = self.tool_to_toolgroup[routing_key]
        return await super().get_provider_impl(routing_key, provider_id)

    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
        if toolgroup_id:
            if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
                toolgroup_id = group_id
            toolgroups = [await self.get_tool_group(toolgroup_id)]
        else:
            toolgroups = await self.get_all_with_type("tool_group")

        all_tools = []
        for toolgroup in toolgroups:
            if toolgroup.identifier not in self.toolgroups_to_tools:
                try:
                    await self._index_tools(toolgroup)
                except AuthenticationRequiredError:
                    # Send authentication errors back to the client so it knows
                    # that it needs to supply credentials for remote MCP servers.
                    raise
                except Exception as e:
                    # Other errors that the client cannot fix are logged and
                    # those specific toolgroups are skipped.
                    logger.warning(f"Error listing tools for toolgroup {toolgroup.identifier}: {e}")
                    logger.debug(e, exc_info=True)
                    continue
            all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])

        return ListToolDefsResponse(data=all_tools)

    async def _index_tools(self, toolgroup: ToolGroup):
        provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
        tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)

        tooldefs = tooldefs_response.data
        for t in tooldefs:
            t.toolgroup_id = toolgroup.identifier

        self.toolgroups_to_tools[toolgroup.identifier] = tooldefs
        for tool in tooldefs:
            self.tool_to_toolgroup[tool.name] = toolgroup.identifier

    async def list_tool_groups(self) -> ListToolGroupsResponse:
        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
        tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
        if tool_group is None:
            raise ToolGroupNotFoundError(toolgroup_id)
        return tool_group

    async def get_tool(self, tool_name: str) -> ToolDef:
        if tool_name in self.tool_to_toolgroup:
            toolgroup_id = self.tool_to_toolgroup[tool_name]
            tools = self.toolgroups_to_tools[toolgroup_id]
            for tool in tools:
                if tool.name == tool_name:
                    return tool
        raise ValueError(f"Tool '{tool_name}' not found")

    async def register_tool_group(
        self,
        toolgroup_id: str,
        provider_id: str,
        mcp_endpoint: URL | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        toolgroup = ToolGroupWithOwner(
            identifier=toolgroup_id,
            provider_id=provider_id,
            provider_resource_id=toolgroup_id,
            mcp_endpoint=mcp_endpoint,
            args=args,
        )
        await self.register_object(toolgroup)

        # ideally, indexing of the tools should not be necessary because anyone using
        # the tools should first list the tools and then use them. but there are assumptions
        # baked in some of the code and tests right now.
        if not toolgroup.mcp_endpoint:
            await self._index_tools(toolgroup)

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        await self.unregister_object(await self.get_tool_group(toolgroup_id))

    async def shutdown(self) -> None:
        pass
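Concretely, the "funny case" the parser above handles: a routing key may be either a bare toolgroup id or a "<toolgroup>/<tool>" pair. This follows directly from the function as written:

from llama_stack.core.routing_tables.toolgroups import parse_toolgroup_from_toolgroup_name_pair

assert parse_toolgroup_from_toolgroup_name_pair("builtin::rag/knowledge_search") == "builtin::rag"
assert parse_toolgroup_from_toolgroup_name_pair("builtin::rag") is None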
292
src/llama_stack/core/routing_tables/vector_stores.py
Normal file
@@ -0,0 +1,292 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType

# Removed VectorStores import to avoid exposing public API
from llama_stack.apis.vector_io.vector_io import (
    SearchRankingOptions,
    VectorStoreChunkingStrategy,
    VectorStoreDeleteResponse,
    VectorStoreFileContentsResponse,
    VectorStoreFileDeleteResponse,
    VectorStoreFileObject,
    VectorStoreFileStatus,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import (
    VectorStoreWithOwner,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl, lookup_model

logger = get_logger(name=__name__, category="core::routing_tables")


class VectorStoresRoutingTable(CommonRoutingTableImpl):
    """Internal routing table for vector_store operations.

    Does not inherit from VectorStores to avoid exposing public API endpoints.
    Only provides internal routing functionality for VectorIORouter.
    """

    # Internal methods only - no public API exposure

    async def register_vector_store(
        self,
        vector_store_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_store_id: str | None = None,
        vector_store_name: str | None = None,
    ) -> Any:
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
                if len(self.impls_by_provider_id) > 1:
                    logger.warning(
                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                    )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await lookup_model(self, embedding_model)
        if model is None:
            raise ModelNotFoundError(embedding_model)
        if model.model_type != ModelType.embedding:
            raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)

        vector_store = VectorStoreWithOwner(
            identifier=vector_store_id,
            type=ResourceType.vector_store.value,
            provider_id=provider_id,
            provider_resource_id=provider_vector_store_id,
            embedding_model=embedding_model,
            embedding_dimension=embedding_dimension,
            vector_store_name=vector_store_name,
        )
        await self.register_object(vector_store)
        return vector_store

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        await self.assert_action_allowed("read", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store(vector_store_id)

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        await self.assert_action_allowed("update", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store(
            vector_store_id=vector_store_id,
            name=name,
            expires_after=expires_after,
            metadata=metadata,
        )

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        await self.assert_action_allowed("delete", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        result = await provider.openai_delete_vector_store(vector_store_id)
        await self.unregister_vector_store(vector_store_id)
        return result

    async def unregister_vector_store(self, vector_store_id: str) -> None:
        """Remove the vector store from the routing table registry."""
        try:
            vector_store_obj = await self.get_object_by_identifier("vector_store", vector_store_id)
            if vector_store_obj:
                await self.unregister_object(vector_store_obj)
        except Exception as e:
            # Log the error but don't fail the operation
            logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}")

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: SearchRankingOptions | None = None,
        rewrite_query: bool | None = False,
        search_mode: str | None = "vector",
    ) -> VectorStoreSearchResponsePage:
        await self.assert_action_allowed("read", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_search_vector_store(
            vector_store_id=vector_store_id,
            query=query,
            filters=filters,
            max_num_results=max_num_results,
            ranking_options=ranking_options,
            rewrite_query=rewrite_query,
            search_mode=search_mode,
        )

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("update", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_attach_file_to_vector_store(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
        )

    async def openai_list_files_in_vector_store(
        self,
        vector_store_id: str,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
        filter: VectorStoreFileStatus | None = None,
    ) -> list[VectorStoreFileObject]:
        await self.assert_action_allowed("read", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_list_files_in_vector_store(
            vector_store_id=vector_store_id,
            limit=limit,
            order=order,
            after=after,
            before=before,
            filter=filter,
        )

    async def openai_retrieve_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("read", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_retrieve_vector_store_file_contents(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileContentsResponse:
        await self.assert_action_allowed("read", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_contents(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_update_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any],
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("update", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
        )

    async def openai_delete_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileDeleteResponse:
        await self.assert_action_allowed("delete", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_delete_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_create_vector_store_file_batch(
        self,
        vector_store_id: str,
        file_ids: list[str],
        attributes: dict[str, Any] | None = None,
        chunking_strategy: Any | None = None,
    ):
        await self.assert_action_allowed("update", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_create_vector_store_file_batch(
            vector_store_id=vector_store_id,
            file_ids=file_ids,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
        )

    async def openai_retrieve_vector_store_file_batch(
        self,
        batch_id: str,
        vector_store_id: str,
    ):
        await self.assert_action_allowed("read", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_batch(
            batch_id=batch_id,
            vector_store_id=vector_store_id,
        )

    async def openai_list_files_in_vector_store_file_batch(
        self,
        batch_id: str,
        vector_store_id: str,
        after: str | None = None,
        before: str | None = None,
        filter: str | None = None,
        limit: int | None = 20,
        order: str | None = "desc",
    ):
        await self.assert_action_allowed("read", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_list_files_in_vector_store_file_batch(
            batch_id=batch_id,
            vector_store_id=vector_store_id,
            after=after,
            before=before,
            filter=filter,
            limit=limit,
            order=order,
        )

    async def openai_cancel_vector_store_file_batch(
        self,
        batch_id: str,
        vector_store_id: str,
    ):
        await self.assert_action_allowed("update", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_cancel_vector_store_file_batch(
            batch_id=batch_id,
            vector_store_id=vector_store_id,
        )
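Every public method above follows the same guard-then-route shape. A minimal sketch isolating the pattern (openai_example_op is a hypothetical method name, used only for illustration):

async def openai_example_op(self, vector_store_id: str):
    # 1. enforce the access-control action ("read"/"update"/"delete") on the resource
    await self.assert_action_allowed("read", "vector_store", vector_store_id)
    # 2. resolve the provider that owns this vector store
    provider = await self.get_provider_impl(vector_store_id)
    # 3. delegate the actual work to that provider
    return await provider.openai_example_op(vector_store_id=vector_store_id)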
5
src/llama_stack/core/server/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
187
src/llama_stack/core/server/auth.py
Normal file
@@ -0,0 +1,187 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json

import httpx
from aiohttp import hdrs

from llama_stack.core.datatypes import AuthenticationConfig, User
from llama_stack.core.request_headers import user_from_scope
from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="core::auth")


class AuthenticationMiddleware:
    """Middleware that authenticates requests using the configured authentication provider.

    This middleware:
    1. Extracts the Bearer token from the Authorization header
    2. Uses the configured auth provider to validate the token
    3. Extracts user attributes from the provider's response
    4. Makes these attributes available to the route handlers for access control

    Unauthenticated Access:
    Endpoints can opt out of authentication by setting require_authentication=False
    in their @webmethod decorator. This is typically used for operational endpoints
    like /health and /version to support monitoring, load balancers, and observability tools.

    The middleware supports multiple authentication providers through the AuthProvider interface:
    - Kubernetes: Validates tokens against the Kubernetes API server
    - Custom: Validates tokens against a custom endpoint

    Authentication Request Format for Custom Auth Provider:
    ```json
    {
        "api_key": "the-api-key-extracted-from-auth-header",
        "request": {
            "path": "/models/list",
            "headers": {
                "content-type": "application/json",
                "user-agent": "..."
                // All headers except Authorization
            },
            "params": {
                "limit": ["100"],
                "offset": ["0"]
                // Query parameters as key -> list of values
            }
        }
    }
    ```

    Expected Auth Endpoint Response Format:
    ```json
    {
        "access_attributes": {    // Structured attribute format
            "roles": ["admin", "user"],
            "teams": ["ml-team", "nlp-team"],
            "projects": ["llama-3", "project-x"],
            "namespaces": ["research"]
        },
        "message": "Optional message about auth result"
    }
    ```

    Token Validation:
    Each provider implements its own token validation logic:
    - Kubernetes: Uses the TokenReview API to validate service account tokens
    - Custom: Sends the token to a custom endpoint for validation

    Attribute-Based Access Control:
    The attributes returned by the auth provider are used to determine which
    resources the user can access. Resources can specify required attributes
    using the access_attributes field. For a user to access a resource:

    1. All attribute categories specified in the resource must be present in the user's attributes
    2. For each category, the user must have at least one matching value

    If the auth provider doesn't return any attributes, the user will only be able to
    access resources that don't have access_attributes defined.
    """

    def __init__(self, app, auth_config: AuthenticationConfig, impls):
        self.app = app
        self.impls = impls
        self.auth_provider = create_auth_provider(auth_config)

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            # Find the route and check if authentication is required
            path = scope.get("path", "")
            method = scope.get("method", hdrs.METH_GET)

            if not hasattr(self, "route_impls"):
                self.route_impls = initialize_route_impls(self.impls)

            webmethod = None
            try:
                _, _, _, webmethod = find_matching_route(method, path, self.route_impls)
            except ValueError:
                # If no matching endpoint is found, pass here to run auth anyway
                pass

            # If the webmethod explicitly sets require_authentication=False, allow without auth
            if webmethod and webmethod.require_authentication is False:
                logger.debug(f"Allowing unauthenticated access to endpoint: {path}")
                return await self.app(scope, receive, send)

            # Handle authentication
            headers = dict(scope.get("headers", []))
            auth_header = headers.get(b"authorization", b"").decode()

            if not auth_header:
                error_msg = self.auth_provider.get_auth_error_message(scope)
                return await self._send_auth_error(send, error_msg)

            if not auth_header.startswith("Bearer "):
                return await self._send_auth_error(send, "Invalid Authorization header format")

            token = auth_header.split("Bearer ", 1)[1]

            # Validate the token and get access attributes
            try:
                validation_result = await self.auth_provider.validate_token(token, scope)
            except httpx.TimeoutException:
                logger.exception("Authentication request timed out")
                return await self._send_auth_error(send, "Authentication service timeout")
            except ValueError as e:
                logger.exception("Error during authentication")
                return await self._send_auth_error(send, str(e))
            except Exception:
                logger.exception("Error during authentication")
                return await self._send_auth_error(send, "Authentication service error")

            # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
            # can identify the requester and enforce per-client rate limits.
            scope["authenticated_client_id"] = token

            # Store attributes in the request scope
            scope["principal"] = validation_result.principal
            if validation_result.attributes:
                scope["user_attributes"] = validation_result.attributes
                logger.debug(
                    f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
                )

            # Scope-based API access control
            if webmethod and webmethod.required_scope:
                user = user_from_scope(scope)
                if not _has_required_scope(webmethod.required_scope, user):
                    return await self._send_auth_error(
                        send,
                        f"Access denied: user does not have required scope: {webmethod.required_scope}",
                        status=403,
                    )

        return await self.app(scope, receive, send)

    async def _send_auth_error(self, send, message, status=401):
        await send(
            {
                "type": "http.response.start",
                "status": status,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        error_key = "message" if status == 401 else "detail"
        error_msg = json.dumps({"error": {error_key: message}}).encode()
        await send({"type": "http.response.body", "body": error_msg})


def _has_required_scope(required_scope: str, user: User | None) -> bool:
    # if no user, assume auth is not enabled
    if not user:
        return True

    if not user.attributes:
        return False

    user_scopes = user.attributes.get("scopes", [])
    return required_scope in user_scopes
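For reference, a hypothetical external endpoint satisfying the custom-auth contract documented above (FastAPI assumed; the path, secret, and attribute values are illustrative). Note that the AuthResponse model defined in auth_providers.py below requires a `principal` field and names the attribute field `attributes`:

from fastapi import FastAPI, HTTPException

app = FastAPI()


@app.post("/validate")  # hypothetical path; point CustomAuthConfig's endpoint at it
async def validate(body: dict):
    # body follows the AuthRequest shape: {"api_key": ..., "request": {...}}
    if body.get("api_key") != "expected-secret":  # toy check, illustration only
        raise HTTPException(status_code=401, detail="invalid token")
    return {
        "principal": "alice",
        "attributes": {"roles": ["user"], "teams": ["ml-team"]},
    }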
494
src/llama_stack/core/server/auth_providers.py
Normal file
@@ -0,0 +1,494 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import ssl
from abc import ABC, abstractmethod
from urllib.parse import parse_qs, urljoin, urlparse

import httpx
import jwt
from pydantic import BaseModel, Field

from llama_stack.apis.common.errors import TokenValidationError
from llama_stack.core.datatypes import (
    AuthenticationConfig,
    CustomAuthConfig,
    GitHubTokenAuthConfig,
    KubernetesAuthProviderConfig,
    OAuth2TokenAuthConfig,
    User,
)
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="core::auth")


class AuthResponse(BaseModel):
    """The format of the authentication response from the auth endpoint."""

    principal: str
    # further attributes that may be used for access control decisions
    attributes: dict[str, list[str]] | None = None
    message: str | None = Field(
        default=None, description="Optional message providing additional context about the authentication result."
    )


class AuthRequestContext(BaseModel):
    path: str = Field(description="The path of the request being authenticated")

    headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")

    params: dict[str, list[str]] = Field(default_factory=dict, description="Query parameters from the original request")


class AuthRequest(BaseModel):
    api_key: str = Field(description="The API key extracted from the Authorization header")

    request: AuthRequestContext = Field(description="Context information about the request being authenticated")


class AuthProvider(ABC):
    """Abstract base class for authentication providers."""

    @abstractmethod
    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token and return access attributes."""
        pass

    @abstractmethod
    async def close(self):
        """Clean up any resources."""
        pass

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Return a provider-specific authentication error message."""
        return "Authentication required"


def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
    attributes: dict[str, list[str]] = {}
    for claim_key, attribute_key in mapping.items():
        # First try dot notation for nested traversal (e.g., "resource_access.llamastack.roles")
        # Then fall back to a literal key with dots (e.g., "my.dotted.key")
        claim: object = claims
        keys = claim_key.split(".")
        for key in keys:
            if isinstance(claim, dict) and key in claim:
                claim = claim[key]
            else:
                claim = None
                break

        if claim is None and claim_key in claims:
            # Fall back to checking if claim_key exists as a literal key
            claim = claims[claim_key]

        if claim is None:
            continue

        if isinstance(claim, list):
            values = claim
        elif isinstance(claim, str):
            values = claim.split()
        else:
            continue

        if attribute_key in attributes:
            attributes[attribute_key].extend(values)
        else:
            attributes[attribute_key] = values
    return attributes


class OAuth2TokenAuthProvider(AuthProvider):
    """
    JWT token authentication provider that validates a JWT token and extracts access attributes.

    This should be the standard authentication provider for most use cases.
    """

    def __init__(self, config: OAuth2TokenAuthConfig):
        self.config = config
        self._jwks_client: jwt.PyJWKClient | None = None

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        if self.config.jwks:
            return await self.validate_jwt_token(token, scope)
        if self.config.introspection:
            return await self.introspect_token(token, scope)
        raise ValueError("One of jwks or introspection must be configured")

    def _get_jwks_client(self) -> jwt.PyJWKClient:
        if self._jwks_client is None:
            ssl_context = None
            if not self.config.verify_tls:
                # Disable SSL verification if verify_tls is False
                ssl_context = ssl.create_default_context()
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
            elif self.config.tls_cafile:
                # Use a custom CA file if provided
                ssl_context = ssl.create_default_context(
                    cafile=self.config.tls_cafile.as_posix(),
                )
            # If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults)

            # Prepare headers for the JWKS request - this is needed for Kubernetes to authenticate
            # to the JWK endpoint; we must use the token in the config to authenticate
            headers = {}
            if self.config.jwks and self.config.jwks.token:
                headers["Authorization"] = f"Bearer {self.config.jwks.token}"

            self._jwks_client = jwt.PyJWKClient(
                self.config.jwks.uri if self.config.jwks else None,
                cache_keys=True,
                max_cached_keys=10,
                lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None,
                headers=headers,
                ssl_context=ssl_context,
            )
        return self._jwks_client

    async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using the JWT token."""
        try:
            jwks_client: jwt.PyJWKClient = self._get_jwks_client()
            signing_key = jwks_client.get_signing_key_from_jwt(token)
            algorithm = jwt.get_unverified_header(token)["alg"]
            # Decode and verify the JWT
            claims = jwt.decode(
                token,
                signing_key.key,
                algorithms=[algorithm],
                audience=self.config.audience,
                issuer=self.config.issuer,
                options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
            )
        except Exception as exc:
            raise ValueError("Invalid JWT token") from exc

        # There are other standard claims, the most relevant of which is `scope`.
        # We should incorporate these into the access attributes.
        principal = claims["sub"]
        access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
        return User(
            principal=principal,
            attributes=access_attributes,
        )

    async def introspect_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using token introspection as defined by RFC 7662."""
        form = {
            "token": token,
        }
        if self.config.introspection is None:
            raise ValueError("Introspection is not configured")

        if self.config.introspection.send_secret_in_body:
            form["client_id"] = self.config.introspection.client_id
            form["client_secret"] = self.config.introspection.client_secret
            auth = None
        else:
            auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
        ssl_ctxt = None
        if self.config.tls_cafile:
            ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
        try:
            async with httpx.AsyncClient(verify=ssl_ctxt) as client:
                response = await client.post(
                    self.config.introspection.url,
                    data=form,
                    auth=auth,
                    timeout=10.0,  # Add a reasonable timeout
                )
                if response.status_code != httpx.codes.OK:
                    logger.warning(f"Token introspection failed with status code: {response.status_code}")
                    raise ValueError(f"Token introspection failed: {response.status_code}")

                fields = response.json()
                if not fields["active"]:
                    raise ValueError("Token not active")
                principal = fields["sub"] or fields["username"]
                access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
                return User(
                    principal=principal,
                    attributes=access_attributes,
                )
        except httpx.TimeoutException:
            logger.exception("Token introspection request timed out")
            raise
        except ValueError:
            # Re-raise ValueError exceptions to preserve their message
            raise
        except Exception as e:
            logger.exception("Error during token introspection")
            raise ValueError("Token introspection error") from e

    async def close(self):
        pass

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Return an OAuth2-specific authentication error message."""
        if self.config.issuer:
            return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
        elif self.config.introspection:
            # Extract the domain from the introspection URL for a cleaner message
            domain = urlparse(self.config.introspection.url).netloc
            return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
        else:
            return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"


class CustomAuthProvider(AuthProvider):
    """Custom authentication provider that uses an external endpoint."""

    def __init__(self, config: CustomAuthConfig):
        self.config = config
        self._client = None

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using the custom authentication endpoint."""
        if scope is None:
            scope = {}

        headers = dict(scope.get("headers", []))
        path = scope.get("path", "")
        request_headers = {k.decode(): v.decode() for k, v in headers.items()}

        # Remove sensitive headers
        if "authorization" in request_headers:
            del request_headers["authorization"]

        query_string = scope.get("query_string", b"").decode()
        params = parse_qs(query_string)

        # Build the auth request model
        auth_request = AuthRequest(
            api_key=token,
            request=AuthRequestContext(
                path=path,
                headers=request_headers,
                params=params,
            ),
        )

        # Validate with the authentication endpoint
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.config.endpoint,
                    json=auth_request.model_dump(),
                    timeout=10.0,  # Add a reasonable timeout
                )
                if response.status_code != httpx.codes.OK:
                    logger.warning(f"Authentication failed with status code: {response.status_code}")
                    raise ValueError(f"Authentication failed: {response.status_code}")

                # Parse and validate the auth response
                try:
                    response_data = response.json()
                    auth_response = AuthResponse(**response_data)
                    return User(principal=auth_response.principal, attributes=auth_response.attributes)
                except Exception as e:
                    logger.exception("Error parsing authentication response")
                    raise ValueError("Invalid authentication response format") from e

        except httpx.TimeoutException:
            logger.exception("Authentication request timed out")
            raise
        except ValueError:
            # Re-raise ValueError exceptions to preserve their message
            raise
        except Exception as e:
            logger.exception("Error during authentication")
            raise ValueError("Authentication service error") from e

    async def close(self):
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Return a custom-auth-provider-specific authentication error message."""
        domain = urlparse(self.config.endpoint).netloc
        if domain:
            return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
        else:
            return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"


class GitHubTokenAuthProvider(AuthProvider):
    """
    GitHub token authentication provider that validates GitHub access tokens directly.

    This provider accepts GitHub personal access tokens or OAuth tokens and verifies
    them against the GitHub API to get user information.
    """

    def __init__(self, config: GitHubTokenAuthConfig):
        self.config = config

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a GitHub token by calling the GitHub API.

        This validates tokens issued by GitHub (personal access tokens or OAuth tokens).
        """
        try:
            user_info = await _get_github_user_info(token, self.config.github_api_base_url)
        except httpx.HTTPStatusError as e:
            logger.warning(f"GitHub token validation failed: {e}")
            raise ValueError("GitHub token validation failed. Please check your token and try again.") from e

        principal = user_info["user"]["login"]

        github_data = {
            "login": user_info["user"]["login"],
            "id": str(user_info["user"]["id"]),
            "organizations": user_info.get("organizations", []),
        }

        access_attributes = get_attributes_from_claims(github_data, self.config.claims_mapping)

        return User(
            principal=principal,
            attributes=access_attributes,
        )

    async def close(self):
        """Clean up any resources."""
        pass

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Return a GitHub-specific authentication error message."""
        return "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"


async def _get_github_user_info(access_token: str, github_api_base_url: str) -> dict:
    """Fetch user info and organizations from the GitHub API."""
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/vnd.github.v3+json",
        "User-Agent": "llama-stack",
    }

    async with httpx.AsyncClient() as client:
        user_response = await client.get(f"{github_api_base_url}/user", headers=headers, timeout=10.0)
        user_response.raise_for_status()
        user_data = user_response.json()

        return {
            "user": user_data,
        }


class KubernetesAuthProvider(AuthProvider):
    """
    Kubernetes authentication provider that validates tokens using the Kubernetes SelfSubjectReview API.
    This provider integrates with the Kubernetes API server by using the
    /apis/authentication.k8s.io/v1/selfsubjectreviews endpoint to validate tokens and extract user information.
    """

    def __init__(self, config: KubernetesAuthProviderConfig):
        self.config = config

    def _httpx_verify_value(self) -> bool | str:
        """
        Build the value for httpx's `verify` parameter.
        - False disables verification.
        - A path string points to a CA bundle.
        - True uses system defaults.
        """
        if not self.config.verify_tls:
            return False
        if self.config.tls_cafile:
            return self.config.tls_cafile.as_posix()
        return True

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using the Kubernetes SelfSubjectReview API endpoint."""
        # Build the Kubernetes SelfSubjectReview API endpoint URL
        review_api_url = urljoin(self.config.api_server_url, "/apis/authentication.k8s.io/v1/selfsubjectreviews")

        # Create the SelfSubjectReview request body
        review_request = {"apiVersion": "authentication.k8s.io/v1", "kind": "SelfSubjectReview"}
        verify = self._httpx_verify_value()

        try:
            async with httpx.AsyncClient(verify=verify, timeout=10.0) as client:
                response = await client.post(
                    review_api_url,
                    json=review_request,
                    headers={
                        "Authorization": f"Bearer {token}",
                        "Content-Type": "application/json",
                    },
                )

                if response.status_code == httpx.codes.UNAUTHORIZED:
                    raise TokenValidationError("Invalid token")
                if response.status_code != httpx.codes.CREATED:
                    logger.warning(f"Kubernetes SelfSubjectReview API failed with status code: {response.status_code}")
                    raise TokenValidationError(f"Token validation failed: {response.status_code}")

                review_response = response.json()
                # Extract user information from the SelfSubjectReview response
                status = review_response.get("status", {})
                if not status:
                    raise ValueError("No status found in SelfSubjectReview response")

                user_info = status.get("userInfo", {})
                if not user_info:
                    raise ValueError("No userInfo found in SelfSubjectReview response")

                username = user_info.get("username")
                if not username:
                    raise ValueError("No username found in SelfSubjectReview response")

                # Build user attributes from Kubernetes user info
                user_attributes = get_attributes_from_claims(user_info, self.config.claims_mapping)

                return User(
                    principal=username,
                    attributes=user_attributes,
                )

        except httpx.TimeoutException:
            logger.warning("Kubernetes SelfSubjectReview API request timed out")
            raise ValueError("Token validation timeout") from None
        except Exception as e:
            logger.warning(f"Error during token validation: {str(e)}")
            raise ValueError(f"Token validation error: {str(e)}") from e

    async def close(self):
        """Close any resources."""
        pass


def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
    """Factory function to create the appropriate auth provider."""
    provider_config = config.provider_config

    if isinstance(provider_config, CustomAuthConfig):
        return CustomAuthProvider(provider_config)
    elif isinstance(provider_config, OAuth2TokenAuthConfig):
        return OAuth2TokenAuthProvider(provider_config)
    elif isinstance(provider_config, GitHubTokenAuthConfig):
        return GitHubTokenAuthProvider(provider_config)
    elif isinstance(provider_config, KubernetesAuthProviderConfig):
        return KubernetesAuthProvider(provider_config)
    else:
        raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
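A small sketch of how claims_mapping drives get_attributes_from_claims, following directly from the function above: dot notation traverses nested claims, string claims are split on whitespace, and list claims pass through unchanged.

claims = {
    "sub": "alice",
    "scope": "read write",
    "resource_access": {"llamastack": {"roles": ["admin"]}},
}
mapping = {"scope": "scopes", "resource_access.llamastack.roles": "roles"}
attrs = get_attributes_from_claims(claims, mapping)
assert attrs == {"scopes": ["read", "write"], "roles": ["admin"]}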
110
src/llama_stack/core/server/quota.py
Normal file
110
src/llama_stack/core/server/quota.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import time
from datetime import UTC, datetime, timedelta

from starlette.types import ASGIApp, Receive, Scope, Send

from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendType
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl

logger = get_logger(name=__name__, category="core::server")


class QuotaMiddleware:
    """
    ASGI middleware that enforces separate quotas for authenticated and anonymous clients
    within a configurable time window.

    - For authenticated requests, it reads the client ID from the
      `Authorization: Bearer <client_id>` header.
    - For anonymous requests, it falls back to the IP address of the client.
    Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
    once a client exceeds its quota.
    """

    def __init__(
        self,
        app: ASGIApp,
        kv_config: KVStoreReference,
        anonymous_max_requests: int,
        authenticated_max_requests: int,
        window_seconds: int = 86400,
    ):
        self.app = app
        self.kv_config = kv_config
        self.kv: KVStore | None = None
        self.anonymous_max_requests = anonymous_max_requests
        self.authenticated_max_requests = authenticated_max_requests
        self.window_seconds = window_seconds

    async def _get_kv(self) -> KVStore:
        if self.kv is None:
            self.kv = await kvstore_impl(self.kv_config)
            backend_config = _KVSTORE_BACKENDS.get(self.kv_config.backend)
            if backend_config and backend_config.type == StorageBackendType.KV_SQLITE:
                logger.warning(
                    "QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
                    f"window_seconds={self.window_seconds}"
                )
        return self.kv

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] == "http":
            # pick key & limit based on auth
            auth_id = scope.get("authenticated_client_id")
            if auth_id:
                key_id = auth_id
                limit = self.authenticated_max_requests
            else:
                # fallback to IP
                client = scope.get("client")
                key_id = client[0] if client else "anonymous"
                limit = self.anonymous_max_requests

            current_window = int(time.time() // self.window_seconds)
            key = f"quota:{key_id}:{current_window}"

            try:
                kv = await self._get_kv()
                prev = await kv.get(key) or "0"
                count = int(prev) + 1

                if int(prev) == 0:
                    # Set with expiration datetime when it is the first request in the window.
                    expiration = datetime.now(UTC) + timedelta(seconds=self.window_seconds)
                    await kv.set(key, str(count), expiration=expiration)
                else:
                    await kv.set(key, str(count))
            except Exception:
                logger.exception("Failed to access KV store for quota")
                return await self._send_error(send, 500, "Quota service error")

            if count > limit:
                logger.warning(
                    "Quota exceeded for client %s: %d/%d",
                    key_id,
                    count,
                    limit,
                )
                return await self._send_error(send, 429, "Quota exceeded")

        return await self.app(scope, receive, send)

    async def _send_error(self, send: Send, status: int, message: str):
        await send(
            {
                "type": "http.response.start",
                "status": status,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        body = json.dumps({"error": {"message": message}}).encode()
        await send({"type": "http.response.body", "body": body})
141
src/llama_stack/core/server/routes.py
Normal file
@@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect
import re
from collections.abc import Callable
from typing import Any

from aiohttp import hdrs
from starlette.routing import Route

from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod

EndpointFunc = Callable[..., Any]
PathParams = dict[str, str]
RouteInfo = tuple[EndpointFunc, str, WebMethod]
PathImpl = dict[str, RouteInfo]
RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]


def toolgroup_protocol_map():
    return {
        SpecialToolGroup.rag_tool: RAGToolRuntime,
    }


def get_all_api_routes(
    external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
    apis = {}

    protocols = api_protocol_map(external_apis)
    toolgroup_protocols = toolgroup_protocol_map()
    for api, protocol in protocols.items():
        routes = []
        protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)

        # HACK ALERT
        if api == Api.tool_runtime:
            for tool_group in SpecialToolGroup:
                sub_protocol = toolgroup_protocols[tool_group]
                sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
                for name, method in sub_protocol_methods:
                    if not hasattr(method, "__webmethod__"):
                        continue
                    protocol_methods.append((f"{tool_group.value}.{name}", method))

        for name, method in protocol_methods:
            # Get all webmethods for this method (supports multiple decorators)
            webmethods = getattr(method, "__webmethods__", [])
            if not webmethods:
                continue

            # Create routes for each webmethod decorator
            for webmethod in webmethods:
                path = f"/{webmethod.level}/{webmethod.route.lstrip('/')}"
                if webmethod.method == hdrs.METH_GET:
                    http_method = hdrs.METH_GET
                elif webmethod.method == hdrs.METH_DELETE:
                    http_method = hdrs.METH_DELETE
                else:
                    http_method = hdrs.METH_POST
                routes.append(
                    (Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
                )  # setting endpoint to None since we don't use a Router object

        apis[api] = routes

    return apis


def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
    api_to_routes = get_all_api_routes(external_apis)
    route_impls: RouteImpls = {}

    def _convert_path_to_regex(path: str) -> str:
        # Convert {param} to named capture groups
        # handle {param:path} as well which allows for forward slashes in the param value
        pattern = re.sub(
            r"{(\w+)(?::path)?}",
            lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path}') else '.+'})",
            path,
        )

        return f"^{pattern}$"

    for api, api_routes in api_to_routes.items():
        if api not in impls:
            continue
        for route, webmethod in api_routes:
            impl = impls[api]
            func = getattr(impl, route.name)
            # Get the first (and typically only) method from the set, filtering out HEAD
            available_methods = [m for m in route.methods if m != "HEAD"]
            if not available_methods:
                continue  # Skip if only HEAD method is available
            method = available_methods[0].lower()
            if method not in route_impls:
                route_impls[method] = {}
            route_impls[method][_convert_path_to_regex(route.path)] = (
                func,
                route.path,
                webmethod,
            )

    return route_impls


def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
    """Find the matching endpoint implementation for a given method and path.

    Args:
        method: HTTP method (GET, POST, etc.)
        path: URL path to match against
        route_impls: A dictionary of endpoint implementations

    Returns:
        A tuple of (endpoint_function, path_params, route_path, webmethod_metadata)

    Raises:
        ValueError: If no matching endpoint is found
    """
    impls = route_impls.get(method.lower())
    if not impls:
        raise ValueError(f"No endpoint found for {path}")

    for regex, (func, route_path, webmethod) in impls.items():
        match = re.match(regex, path)
        if match:
            # Extract named groups from the regex match
            path_params = match.groupdict()
            return func, path_params, route_path, webmethod

    raise ValueError(f"No endpoint found for {path}")
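
# Illustrative usage (editor's sketch): dispatching an incoming request with the
# helpers above. `impls` is assumed to map Api -> implementation object.
#
#     route_impls = initialize_route_impls(impls)
#     func, path_params, route_path, webmethod = find_matching_route("GET", "/v1/models", route_impls)
#     # func is the bound endpoint method; path_params holds values captured from the URL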
536
src/llama_stack/core/server/server.py
Normal file
@@ -0,0 +1,536 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import concurrent.futures
import functools
import inspect
import json
import logging  # allow-direct-logging
import os
import sys
import traceback
import warnings
from collections.abc import Callable
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Annotated, Any, get_origin

import httpx
import rich.pretty
import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError

from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import (
    AuthenticationRequiredError,
    StackRunConfig,
    process_cors_config,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import load_external_apis
from llama_stack.core.request_headers import (
    PROVIDER_DATA_VAR,
    request_provider_data_context,
    user_from_scope,
)
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack.core.stack import (
    Stack,
    cast_image_name_to_string,
    replace_env_vars,
)
from llama_stack.core.telemetry import Telemetry
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, setup_logger
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.log import LoggingConfig, get_logger, setup_logging
from llama_stack.providers.datatypes import Api

from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware
from .tracing import TracingMiddleware

REPO_ROOT = Path(__file__).parent.parent.parent.parent

logger = get_logger(name=__name__, category="core::server")


def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    log = file if hasattr(file, "write") else sys.stderr
    traceback.print_stack(file=log)
    log.write(warnings.formatwarning(message, category, filename, lineno, line))


if os.environ.get("LLAMA_STACK_TRACE_WARNINGS"):
    warnings.showwarning = warn_with_traceback


def create_sse_event(data: Any) -> str:
    if isinstance(data, BaseModel):
        data = data.model_dump_json()
    else:
        data = json.dumps(data)

    return f"data: {data}\n\n"


async def global_exception_handler(request: Request, exc: Exception):
    traceback.print_exception(exc)
    http_exc = translate_exception(exc)

    return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})


def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
    if isinstance(exc, ValidationError):
        exc = RequestValidationError(exc.errors())

    if isinstance(exc, RequestValidationError):
        return HTTPException(
            status_code=httpx.codes.BAD_REQUEST,
            detail={
                "errors": [
                    {
                        "loc": list(error["loc"]),
                        "msg": error["msg"],
                        "type": error["type"],
                    }
                    for error in exc.errors()
                ]
            },
        )
    elif isinstance(exc, ConflictError):
        return HTTPException(status_code=httpx.codes.CONFLICT, detail=str(exc))
    elif isinstance(exc, ResourceNotFoundError):
        return HTTPException(status_code=httpx.codes.NOT_FOUND, detail=str(exc))
    elif isinstance(exc, ValueError):
        return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
    elif isinstance(exc, BadRequestError):
        return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
    elif isinstance(exc, PermissionError | AccessDeniedError):
        return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
    elif isinstance(exc, ConnectionError | httpx.ConnectError):
        return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc))
    elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
        return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
    elif isinstance(exc, NotImplementedError):
        return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}")
    elif isinstance(exc, AuthenticationRequiredError):
        return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}")
    elif hasattr(exc, "status_code") and isinstance(getattr(exc, "status_code", None), int):
        # Handle provider SDK exceptions (e.g., OpenAI's APIStatusError and subclasses)
        # These include AuthenticationError (401), PermissionDeniedError (403), etc.
        # This preserves the actual HTTP status code from the provider
        status_code = exc.status_code
        detail = str(exc)
        return HTTPException(status_code=status_code, detail=detail)
    else:
        return HTTPException(
            status_code=httpx.codes.INTERNAL_SERVER_ERROR,
            detail="Internal server error: An unexpected error occurred.",
        )


class StackApp(FastAPI):
    """
    A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
    start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
    """

    def __init__(self, config: StackRunConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.stack: Stack = Stack(config)

        # This code is called from a running event loop managed by uvicorn so we cannot simply call
        # asyncio.run() to initialize the stack. We cannot await either since this is not an async
        # function.
        # As a workaround, we use a thread pool executor to run the initialize() method
        # in a separate thread.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future = executor.submit(asyncio.run, self.stack.initialize())
            future.result()


@asynccontextmanager
async def lifespan(app: StackApp):
    server_version = parse_version("llama-stack")

    logger.info(f"Starting up Llama Stack server (version: {server_version})")
    assert app.stack is not None
    app.stack.create_registry_refresh_task()
    yield
    logger.info("Shutting down")
    await app.stack.shutdown()


def is_streaming_request(func_name: str, request: Request, **kwargs):
    # TODO: pass the api method and punt it to the Protocol definition directly
    # If there's a stream parameter at top level, use it
    if "stream" in kwargs:
        return kwargs["stream"]

    # If there's a stream parameter inside a "params" parameter, e.g. openai_chat_completion(), use it
    if "params" in kwargs:
        params = kwargs["params"]
        if hasattr(params, "stream"):
            return params.stream

    return False


async def maybe_await(value):
    if inspect.iscoroutine(value):
        return await value
    return value


async def sse_generator(event_gen_coroutine):
    event_gen = None
    try:
        event_gen = await event_gen_coroutine
        async for item in event_gen:
            yield create_sse_event(item)
    except asyncio.CancelledError:
        logger.info("Generator cancelled")
        if event_gen:
            await event_gen.aclose()
    except Exception as e:
        logger.exception("Error in sse_generator")
        yield create_sse_event(
            {
                "error": {
                    "message": str(translate_exception(e)),
                },
            }
        )


async def log_request_pre_validation(request: Request):
    if request.method in ("POST", "PUT", "PATCH"):
        try:
            body_bytes = await request.body()
            if body_bytes:
                try:
                    parsed_body = json.loads(body_bytes.decode())
                    log_output = rich.pretty.pretty_repr(parsed_body)
                except (json.JSONDecodeError, UnicodeDecodeError):
                    log_output = repr(body_bytes)
                logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
            else:
                logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
        except Exception as e:
            logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")


def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
    @functools.wraps(func)
    async def route_handler(request: Request, **kwargs):
        # Get auth attributes from the request scope
        user = user_from_scope(request.scope)

        await log_request_pre_validation(request)

        test_context_token = None
        test_context_var = None
        reset_test_context_fn = None

        # Use context manager with both provider data and auth attributes
        with request_provider_data_context(request.headers, user):
            if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
                from llama_stack.core.testing_context import (
                    TEST_CONTEXT,
                    reset_test_context,
                    sync_test_context_from_provider_data,
                )

                test_context_token = sync_test_context_from_provider_data()
                test_context_var = TEST_CONTEXT
                reset_test_context_fn = reset_test_context

            is_streaming = is_streaming_request(func.__name__, request, **kwargs)

            try:
                if is_streaming:
                    context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
                    if test_context_var is not None:
                        context_vars.append(test_context_var)
                    gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars)
                    return StreamingResponse(gen, media_type="text/event-stream")
                else:
                    value = func(**kwargs)
                    result = await maybe_await(value)
                    if isinstance(result, PaginatedResponse) and result.url is None:
                        result.url = route

                    if method.upper() == "DELETE" and result is None:
                        return Response(status_code=httpx.codes.NO_CONTENT)

                    return result
            except Exception as e:
                if logger.isEnabledFor(logging.INFO):
                    logger.exception(f"Error executing endpoint {route=} {method=}")
                else:
                    logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
                raise translate_exception(e) from e
            finally:
                if test_context_token is not None and reset_test_context_fn is not None:
                    reset_test_context_fn(test_context_token)

    sig = inspect.signature(func)

    new_params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
    new_params.extend(sig.parameters.values())

    path_params = extract_path_params(route)
    if method == "post":
        # Annotate parameters that are in the path with Path(...) and others with Body(...),
        # but preserve existing File() and Form() annotations for multipart form data
        new_params = (
            [new_params[0]]
            + [
                (
                    param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
                    if param.name in path_params
                    else (
                        param  # Keep original annotation if it's already an Annotated type
                        if get_origin(param.annotation) is Annotated
                        else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
                    )
                )
                for param in new_params[1:]
            ]
        )

    route_handler.__signature__ = sig.replace(parameters=new_params)

    return route_handler


class ClientVersionMiddleware:
    def __init__(self, app):
        self.app = app
        self.server_version = parse_version("llama-stack")

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            headers = dict(scope.get("headers", []))
            client_version = headers.get(b"x-llamastack-client-version", b"").decode()
            if client_version:
                try:
                    client_version_parts = tuple(map(int, client_version.split(".")[:2]))
                    server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
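                    # e.g. (editor's note) client "0.2.3" -> (0, 2); only the
                    # major.minor prefix must match the server's version.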
                    if client_version_parts != server_version_parts:

                        async def send_version_error(send):
                            await send(
                                {
                                    "type": "http.response.start",
                                    "status": httpx.codes.UPGRADE_REQUIRED,
                                    "headers": [[b"content-type", b"application/json"]],
                                }
                            )
                            error_msg = json.dumps(
                                {
                                    "error": {
                                        "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please update your client."
                                    }
                                }
                            ).encode()
                            await send({"type": "http.response.body", "body": error_msg})

                        return await send_version_error(send)
                except (ValueError, IndexError):
                    # If version parsing fails, let the request through
                    pass

        return await self.app(scope, receive, send)


def create_app() -> StackApp:
    """Create and configure the FastAPI application.

    This factory function reads configuration from environment variables:
    - LLAMA_STACK_CONFIG: Path to config file (required)

    Returns:
        Configured StackApp instance.
    """
    # Initialize logging from environment variables first
    setup_logging()

    config_file = os.getenv("LLAMA_STACK_CONFIG")
    if config_file is None:
        raise ValueError("LLAMA_STACK_CONFIG environment variable is required")

    config_file = resolve_config_or_distro(config_file, Mode.RUN)

    # Load and process configuration
    logger_config = None
    with open(config_file) as fp:
        config_contents = yaml.safe_load(fp)
        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
            logger_config = LoggingConfig(**cfg)
        logger = get_logger(name=__name__, category="core::server", config=logger_config)

        config = replace_env_vars(config_contents)
        config = StackRunConfig(**cast_image_name_to_string(config))

    _log_run_config(run_config=config)

    app = StackApp(
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_url="/openapi.json",
        config=config,
    )

    if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
        app.add_middleware(ClientVersionMiddleware)

    impls = app.stack.impls

    if config.server.auth:
        logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
        app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
    else:
        if config.server.quota:
            quota = config.server.quota
            logger.warning(
                "Configured authenticated_max_requests (%d) but no auth is enabled; "
                "falling back to anonymous_max_requests (%d) for all requests",
                quota.authenticated_max_requests,
                quota.anonymous_max_requests,
            )

    if config.server.quota:
        logger.info("Enabling quota middleware for authenticated and anonymous clients")

        quota = config.server.quota
        anonymous_max_requests = quota.anonymous_max_requests
        # if auth is disabled, use the anonymous max requests
        authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests

        kv_config = quota.kvstore
        window_map = {"day": 86400}
        window_seconds = window_map[quota.period.value]

        app.add_middleware(
            QuotaMiddleware,
            kv_config=kv_config,
            anonymous_max_requests=anonymous_max_requests,
            authenticated_max_requests=authenticated_max_requests,
            window_seconds=window_seconds,
        )

    if config.server.cors:
        logger.info("Enabling CORS")
        cors_config = process_cors_config(config.server.cors)
        if cors_config:
            app.add_middleware(CORSMiddleware, **cors_config.model_dump())

    if config.telemetry.enabled:
        setup_logger(Telemetry())

    # Load external APIs if configured
    external_apis = load_external_apis(config)
    all_routes = get_all_api_routes(external_apis)

    if config.apis:
        apis_to_serve = set(config.apis)
    else:
        apis_to_serve = set(impls.keys())

    for inf in builtin_automatically_routed_apis():
        # if we do not serve the corresponding router API, we should not serve the routing table API
        if inf.router_api.value not in apis_to_serve:
            continue
        apis_to_serve.add(inf.routing_table_api.value)

    apis_to_serve.add("inspect")
    apis_to_serve.add("providers")
    apis_to_serve.add("prompts")
    apis_to_serve.add("conversations")
    for api_str in apis_to_serve:
        api = Api(api_str)

        routes = all_routes[api]
        try:
            impl = impls[api]
        except KeyError as e:
            raise ValueError(f"Could not find provider implementation for {api} API") from e

        for route, _ in routes:
            if not hasattr(impl, route.name):
                # ideally this should be a typing violation already
                raise ValueError(f"Could not find method {route.name} on {impl}!")

            impl_method = getattr(impl, route.name)
            # Filter out HEAD method since it's automatically handled by FastAPI for GET routes
            available_methods = [m for m in route.methods if m != "HEAD"]
            if not available_methods:
                raise ValueError(f"No methods found for {route.name} on {impl}")
            method = available_methods[0]
            logger.debug(f"{method} {route.path}")

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
                getattr(app, method.lower())(route.path, response_model=None)(
                    create_dynamic_typed_route(
                        impl_method,
                        method.lower(),
                        route.path,
                    )
                )

    logger.debug(f"serving APIs: {apis_to_serve}")

    app.exception_handler(RequestValidationError)(global_exception_handler)
    app.exception_handler(Exception)(global_exception_handler)

    if config.telemetry.enabled:
        app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)

    return app


def _log_run_config(run_config: StackRunConfig):
    """Logs the run config with redacted fields and disabled providers removed."""
    logger.info("Run configuration:")
    safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
    clean_config = remove_disabled_providers(safe_config)
    logger.info(yaml.dump(clean_config, indent=2))


def extract_path_params(route: str) -> list[str]:
    segments = route.split("/")
    params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
    # to handle path params like {param:path}
    params = [param.split(":")[0] for param in params]
    return params
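# Worked example (editor's note): extract_path_params("/v1/files/{file_id:path}")
# returns ["file_id"]; the ":path" suffix is stripped so only the bare name remains.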


def remove_disabled_providers(obj):
    if isinstance(obj, dict):
        keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
        if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
            return None
        return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
    elif isinstance(obj, list):
        return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
    else:
        return obj
80
src/llama_stack/core/server/tracing.py
Normal file
@@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from aiohttp import hdrs

from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.core.telemetry.tracing import end_trace, start_trace
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="core::server")


class TracingMiddleware:
    def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
        self.app = app
        self.impls = impls
        self.external_apis = external_apis
        # FastAPI built-in paths that should bypass custom routing
        self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")

    async def __call__(self, scope, receive, send):
        if scope.get("type") == "lifespan":
            return await self.app(scope, receive, send)

        path = scope.get("path", "")

        # Check if the path is a FastAPI built-in path
        if path.startswith(self.fastapi_paths):
            # Pass through to FastAPI's built-in handlers
            logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
            return await self.app(scope, receive, send)

        if not hasattr(self, "route_impls"):
            self.route_impls = initialize_route_impls(self.impls, self.external_apis)

        try:
            _, _, route_path, webmethod = find_matching_route(
                scope.get("method", hdrs.METH_GET), path, self.route_impls
            )
        except ValueError:
            # If no matching endpoint is found, pass through to FastAPI
            logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
            return await self.app(scope, receive, send)

        # Log deprecation warning if route is deprecated
        if getattr(webmethod, "deprecated", False):
            logger.warning(
                f"DEPRECATED ROUTE USED: {scope.get('method', 'GET')} {path} - "
                f"This route is deprecated and may be removed in a future version. "
                f"Please check the docs for the supported version."
            )

        trace_attributes = {"__location__": "server", "raw_path": path}

        # Extract W3C trace context headers and store as trace attributes
        headers = dict(scope.get("headers", []))
        traceparent = headers.get(b"traceparent", b"").decode()
        if traceparent:
            trace_attributes["traceparent"] = traceparent
        tracestate = headers.get(b"tracestate", b"").decode()
        if tracestate:
            trace_attributes["tracestate"] = tracestate

        trace_path = webmethod.descriptive_name or route_path
        trace_context = await start_trace(trace_path, trace_attributes)

        async def send_with_trace_id(message):
            if message["type"] == "http.response.start":
                headers = message.get("headers", [])
                headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
                message["headers"] = headers
            await send(message)

        try:
            return await self.app(scope, receive, send_with_trace_id)
        finally:
            await end_trace()
572
src/llama_stack/core/stack.py
Normal file
@@ -0,0 +1,572 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import importlib.resources
import os
import re
import tempfile
from typing import Any

import yaml

from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
from llama_stack.core.storage.datatypes import (
    InferenceStoreReference,
    KVStoreReference,
    ServerStoresConfig,
    SqliteKVStoreConfig,
    SqliteSqlStoreConfig,
    SqlStoreReference,
    StorageBackendConfig,
    StorageConfig,
)
from llama_stack.core.store.registry import create_dist_registry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api

logger = get_logger(name=__name__, category="core")


class LlamaStack(
    Providers,
    Inference,
    Agents,
    Safety,
    SyntheticDataGeneration,
    Datasets,
    Telemetry,
    PostTraining,
    VectorIO,
    Eval,
    Benchmarks,
    Scoring,
    ScoringFunctions,
    DatasetIO,
    Models,
    Shields,
    Inspect,
    ToolGroups,
    ToolRuntime,
    RAGToolRuntime,
    Files,
    Prompts,
    Conversations,
):
    pass


RESOURCES = [
    ("models", Api.models, "register_model", "list_models"),
    ("shields", Api.shields, "register_shield", "list_shields"),
    ("datasets", Api.datasets, "register_dataset", "list_datasets"),
    (
        "scoring_fns",
        Api.scoring_functions,
        "register_scoring_function",
        "list_scoring_functions",
    ),
    ("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
    ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]


REGISTRY_REFRESH_INTERVAL_SECONDS = 300
REGISTRY_REFRESH_TASK = None
TEST_RECORDING_CONTEXT = None


async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
    for rsrc, api, register_method, list_method in RESOURCES:
        objects = getattr(run_config.registered_resources, rsrc)
        if api not in impls:
            continue

        method = getattr(impls[api], register_method)
        for obj in objects:
            if hasattr(obj, "provider_id"):
                # Do not register models on disabled providers
                if not obj.provider_id or obj.provider_id == "__disabled__":
                    logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
                    continue
                logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")

            # we want to maintain the type information in arguments to method.
            # instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
            # we use model_dump() to find all the attrs and then getattr to get the still typed value.
            await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})

        method = getattr(impls[api], list_method)
        response = await method()

        objects_to_process = response.data if hasattr(response, "data") else response

        for obj in objects_to_process:
            logger.debug(
                f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
            )


async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig | None, impls: dict[Api, Any]):
    """Validate vector stores configuration."""
    if vector_stores_config is None:
        return

    default_embedding_model = vector_stores_config.default_embedding_model
    if default_embedding_model is None:
        return

    provider_id = default_embedding_model.provider_id
    model_id = default_embedding_model.model_id
    default_model_id = f"{provider_id}/{model_id}"

    if Api.models not in impls:
        raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")

    models_impl = impls[Api.models]
    response = await models_impl.list_models()
    models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}

    default_model = models_list.get(default_model_id)
    if default_model is None:
        raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")

    embedding_dimension = default_model.metadata.get("embedding_dimension")
    if embedding_dimension is None:
        raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")

    try:
        int(embedding_dimension)
    except ValueError as err:
        raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err

    logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")


async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
    if safety_config is None or safety_config.default_shield_id is None:
        return

    if Api.shields not in impls:
        raise ValueError("Safety configuration requires the shields API to be enabled")

    if Api.safety not in impls:
        raise ValueError("Safety configuration requires the safety API to be enabled")

    shields_impl = impls[Api.shields]
    response = await shields_impl.list_shields()
    shields_by_id = {shield.identifier: shield for shield in response.data}

    default_shield_id = safety_config.default_shield_id
    # don't validate if there are no shields registered
    if shields_by_id and default_shield_id not in shields_by_id:
        available = sorted(shields_by_id)
        raise ValueError(
            f"Configured default_shield_id '{default_shield_id}' not found among registered shields."
            f" Available shields: {available}"
        )


class EnvVarError(Exception):
    def __init__(self, var_name: str, path: str = ""):
        self.var_name = var_name
        self.path = path
        super().__init__(
            f"Environment variable '{var_name}' not set or empty {f'at {path}' if path else ''}. "
            f"Use ${{env.{var_name}:=default_value}} to provide a default value, "
            f"${{env.{var_name}:+value_if_set}} to make the field conditional, "
            f"or ensure the environment variable is set."
        )


def replace_env_vars(config: Any, path: str = "") -> Any:
    if isinstance(config, dict):
        result = {}
        for k, v in config.items():
            try:
                result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
            except EnvVarError as e:
                raise EnvVarError(e.var_name, e.path) from None
        return result

    elif isinstance(config, list):
        result = []
        for i, v in enumerate(config):
            try:
                # Special handling for providers: first resolve the provider_id to check if provider
                # is disabled so that we can skip config env variable expansion and avoid validation errors
                if isinstance(v, dict) and "provider_id" in v:
                    try:
                        resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
                        if resolved_provider_id == "__disabled__":
                            logger.debug(
                                f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
                            )
                            # Create a copy with resolved provider_id but original config
                            disabled_provider = v.copy()
                            disabled_provider["provider_id"] = resolved_provider_id
                            continue
                    except EnvVarError:
                        # If we can't resolve the provider_id, continue with normal processing
                        pass

                # Normal processing for non-disabled providers
                result.append(replace_env_vars(v, f"{path}[{i}]"))
            except EnvVarError as e:
                raise EnvVarError(e.var_name, e.path) from None
        return result

    elif isinstance(config, str):
        # Pattern supports bash-like syntax: := for default, :+ for conditional, and an optional value
        pattern = r"\${env\.([A-Z0-9_]+)(?::([=+])([^}]*))?}"

        def get_env_var(match: re.Match):
            env_var = match.group(1)
            operator = match.group(2)  # '=' for default, '+' for conditional
            value_expr = match.group(3)

            env_value = os.environ.get(env_var)

            if operator == "=":  # Default value syntax: ${env.FOO:=default}
                # If the env is set like ${env.FOO:=default} then use the env value when set
                if env_value:
                    value = env_value
                else:
                    # If the env is not set, look for a default value
                    # value_expr returns empty string (not None) when not matched
                    # This means ${env.FOO:=} and it's accepted and returns empty string - just like bash
                    if value_expr == "":
                        return ""
                    else:
                        value = value_expr

            elif operator == "+":  # Conditional value syntax: ${env.FOO:+value_if_set}
                # If the env is set like ${env.FOO:+value_if_set} then use the value_if_set
                if env_value:
                    if value_expr:
                        value = value_expr
                    # This means ${env.FOO:+}
                    else:
                        # Just like bash, this doesn't care whether the env is set or not and applies
                        # the value, in this case the empty string
                        return ""
                else:
                    # Just like bash, this doesn't care whether the env is set or not, since it's not set
                    # we return an empty string
                    value = ""
            else:  # No operator case: ${env.FOO}
                if not env_value:
                    raise EnvVarError(env_var, path)
                value = env_value

            # expand "~" from the values
            return os.path.expanduser(value)

        try:
            result = re.sub(pattern, get_env_var, config)
            # Only apply type conversion if substitution actually happened
            if result != config:
                return _convert_string_to_proper_type(result)
            return result
        except EnvVarError as e:
            raise EnvVarError(e.var_name, e.path) from None

    return config


def _convert_string_to_proper_type(value: str) -> Any:
    # This might be tricky depending on what the config type is, if 'str | None' we are
    # good, if 'str' we need to keep the empty string... 'str | None' is more common and
    # providers config should be typed this way.
    # TODO: we could try to load the config class and see if the config has a field with type 'str | None'
    # and then convert the empty string to None or not
    if value == "":
        return None

    lowered = value.lower()
    if lowered == "true":
        return True
    elif lowered == "false":
        return False

    try:
        return int(value)
    except ValueError:
        pass

    try:
        return float(value)
    except ValueError:
        pass

    return value
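# Examples (editor's note): "8080" -> 8080, "0.5" -> 0.5, "true" -> True, "" -> None,
# and anything else (e.g. "sqlite:///db") stays a string.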


def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
    """Ensure that any value for a key 'image_name' in a config_dict is a string"""
    if "image_name" in config_dict and config_dict["image_name"] is not None:
        config_dict["image_name"] = str(config_dict["image_name"])
    return config_dict


def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
    """Add internal implementations (inspect and providers) to the implementations dictionary.

    Args:
        impls: Dictionary of API implementations
        run_config: Stack run configuration
    """
    inspect_impl = DistributionInspectImpl(
        DistributionInspectConfig(run_config=run_config),
        deps=impls,
    )
    impls[Api.inspect] = inspect_impl

    providers_impl = ProviderImpl(
        ProviderImplConfig(run_config=run_config),
        deps=impls,
    )
    impls[Api.providers] = providers_impl

    prompts_impl = PromptServiceImpl(
        PromptServiceConfig(run_config=run_config),
        deps=impls,
    )
    impls[Api.prompts] = prompts_impl

    conversations_impl = ConversationServiceImpl(
        ConversationServiceConfig(run_config=run_config),
        deps=impls,
    )
    impls[Api.conversations] = conversations_impl


def _initialize_storage(run_config: StackRunConfig):
    kv_backends: dict[str, StorageBackendConfig] = {}
    sql_backends: dict[str, StorageBackendConfig] = {}
    for backend_name, backend_config in run_config.storage.backends.items():
        type = backend_config.type.value
        if type.startswith("kv_"):
            kv_backends[backend_name] = backend_config
        elif type.startswith("sql_"):
            sql_backends[backend_name] = backend_config
        else:
            raise ValueError(f"Unknown storage backend type: {type}")

    from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
    from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends

    register_kvstore_backends(kv_backends)
    register_sqlstore_backends(sql_backends)


class Stack:
    def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
        self.run_config = run_config
        self.provider_registry = provider_registry
        self.impls = None

    # Produces a stack of providers for the given run config. Not all APIs may be
    # asked for in the run config.
    async def initialize(self):
        if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
            from llama_stack.testing.api_recorder import setup_api_recording

            global TEST_RECORDING_CONTEXT
            TEST_RECORDING_CONTEXT = setup_api_recording()
            if TEST_RECORDING_CONTEXT:
                TEST_RECORDING_CONTEXT.__enter__()
                logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")

        _initialize_storage(self.run_config)
        stores = self.run_config.storage.stores
        if not stores.metadata:
            raise ValueError("storage.stores.metadata must be configured with a kv_* backend")
        dist_registry, _ = await create_dist_registry(stores.metadata, self.run_config.image_name)
        policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []

        internal_impls = {}
        add_internal_implementations(internal_impls, self.run_config)

        impls = await resolve_impls(
            self.run_config,
            self.provider_registry or get_provider_registry(self.run_config),
            dist_registry,
            policy,
            internal_impls,
        )

        if Api.prompts in impls:
            await impls[Api.prompts].initialize()
        if Api.conversations in impls:
            await impls[Api.conversations].initialize()

        await register_resources(self.run_config, impls)
        await refresh_registry_once(impls)
        await validate_vector_stores_config(self.run_config.vector_stores, impls)
        await validate_safety_config(self.run_config.safety, impls)
        self.impls = impls

    def create_registry_refresh_task(self):
        assert self.impls is not None, "Must call initialize() before starting"

        global REGISTRY_REFRESH_TASK
        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))

        def cb(task):
            import traceback

            if task.cancelled():
                logger.error("Model refresh task cancelled")
            elif task.exception():
                logger.error(f"Model refresh task failed: {task.exception()}")
                traceback.print_exception(task.exception())
            else:
                logger.debug("Model refresh task completed")

        REGISTRY_REFRESH_TASK.add_done_callback(cb)

    async def shutdown(self):
        for impl in self.impls.values():
            impl_name = impl.__class__.__name__
            logger.info(f"Shutting down {impl_name}")
            try:
                if hasattr(impl, "shutdown"):
                    await asyncio.wait_for(impl.shutdown(), timeout=5)
                else:
                    logger.warning(f"No shutdown method for {impl_name}")
            except TimeoutError:
                logger.exception(f"Shutdown timeout for {impl_name}")
            except (Exception, asyncio.CancelledError) as e:
                logger.exception(f"Failed to shutdown {impl_name}: {e}")

        global TEST_RECORDING_CONTEXT
        if TEST_RECORDING_CONTEXT:
            try:
                TEST_RECORDING_CONTEXT.__exit__(None, None, None)
            except Exception as e:
                logger.error(f"Error during API recording cleanup: {e}")

        global REGISTRY_REFRESH_TASK
        if REGISTRY_REFRESH_TASK:
            REGISTRY_REFRESH_TASK.cancel()


async def refresh_registry_once(impls: dict[Api, Any]):
    logger.debug("refreshing registry")
    routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
    for routing_table in routing_tables:
        await routing_table.refresh()


async def refresh_registry_task(impls: dict[Api, Any]):
    logger.info("starting registry refresh task")
    while True:
        await refresh_registry_once(impls)

        await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)


def get_stack_run_config_from_distro(distro: str) -> StackRunConfig:
    distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro}/run.yaml"

    with importlib.resources.as_file(distro_path) as path:
        if not path.exists():
            raise ValueError(f"Distribution '{distro}' not found at {distro_path}")
        run_config = yaml.safe_load(path.open())

    return StackRunConfig(**replace_env_vars(run_config))


def run_config_from_adhoc_config_spec(
    adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None
) -> StackRunConfig:
    """
    Create an adhoc distribution from a list of API providers.

    The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
    multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
    """

    api_providers = adhoc_config_spec.replace(";", ",").split(",")
    provider_registry = provider_registry or get_provider_registry()

    distro_dir = tempfile.mkdtemp()
    provider_configs_by_api = {}
    for api_provider in api_providers:
        api_str, provider = api_provider.split("=")
        api = Api(api_str)

        providers_by_type = provider_registry[api]
        provider_spec = providers_by_type.get(provider)
        if not provider_spec:
            provider_spec = providers_by_type.get(f"inline::{provider}")
        if not provider_spec:
            provider_spec = providers_by_type.get(f"remote::{provider}")

        if not provider_spec:
            raise ValueError(
                f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
            )

        # call method "sample_run_config" on the provider spec config class
        provider_config_type = instantiate_class_type(provider_spec.config_class)
        provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))

        provider_configs_by_api[api_str] = [
            Provider(
                provider_id=provider,
                provider_type=provider_spec.provider_type,
                config=provider_config,
            )
        ]
    config = StackRunConfig(
        image_name="distro-test",
        apis=list(provider_configs_by_api.keys()),
        providers=provider_configs_by_api,
        storage=StorageConfig(
            backends={
                "kv_default": SqliteKVStoreConfig(db_path=f"{distro_dir}/kvstore.db"),
                "sql_default": SqliteSqlStoreConfig(db_path=f"{distro_dir}/sql_store.db"),
            },
            stores=ServerStoresConfig(
                metadata=KVStoreReference(backend="kv_default", namespace="registry"),
                inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
                conversations=SqlStoreReference(backend="sql_default", table_name="openai_conversations"),
                prompts=KVStoreReference(backend="kv_default", namespace="prompts"),
            ),
        ),
    )
    return config
117
src/llama_stack/core/start_stack.sh
Executable file
@@ -0,0 +1,117 @@
#!/usr/bin/env bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
VIRTUAL_ENV=${VIRTUAL_ENV:-}

set -euo pipefail

RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

error_handler() {
  echo "Error occurred in script at line: ${1}" >&2
  exit 1
}

trap 'error_handler ${LINENO}' ERR

if [ $# -lt 3 ]; then
  echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>]"
  exit 1
fi

env_type="$1"
shift

env_path_or_name="$1"
container_image="localhost/$env_path_or_name"
shift

port="$1"
shift

SCRIPT_DIR=$(dirname "$(readlink -f "$0")")

# Initialize variables
yaml_config=""
other_args=""

# Process remaining arguments
while [[ $# -gt 0 ]]; do
  case "$1" in
  --config)
    if [[ -n "$2" ]]; then
      yaml_config="$2"
      shift 2
    else
      echo -e "${RED}Error: $1 requires a CONFIG argument${NC}" >&2
      exit 1
    fi
    ;;
  *)
    other_args="$other_args $1"
    shift
    ;;
  esac
done
|
||||
|
||||
# Check if yaml_config is required
|
||||
if [[ "$env_type" == "venv" ]] && [ -z "$yaml_config" ]; then
|
||||
echo -e "${RED}Error: --config is required for venv environment${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PYTHON_BINARY="python"
|
||||
case "$env_type" in
|
||||
"venv")
|
||||
if [ -n "$VIRTUAL_ENV" ] && [ "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
|
||||
echo -e "${GREEN}Virtual environment already activated${NC}" >&2
|
||||
else
|
||||
# Activate virtual environment
|
||||
if [ ! -d "$env_path_or_name" ]; then
|
||||
echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "$env_path_or_name/bin/activate" ]; then
|
||||
echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
source "$env_path_or_name/bin/activate"
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
# Handle unsupported env_types here
|
||||
echo -e "${RED}Error: Unsupported environment type '$env_type'. Only 'venv' is supported.${NC}" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
if [[ "$env_type" == "venv" ]]; then
|
||||
set -x
|
||||
|
||||
if [ -n "$yaml_config" ]; then
|
||||
yaml_config_arg="$yaml_config"
|
||||
else
|
||||
yaml_config_arg=""
|
||||
fi
|
||||
|
||||
llama stack run \
|
||||
$yaml_config_arg \
|
||||
--port "$port" \
|
||||
$other_args
|
||||
elif [[ "$env_type" == "container" ]]; then
|
||||
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
|
||||
echo -e "Please refer to the documentation for more information: https://llamastack.github.io/latest/distributions/building_distro.html#llama-stack-build"
|
||||
exit 1
|
||||
fi
|
||||
5
src/llama_stack/core/storage/__init__.py
Normal file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
287
src/llama_stack/core/storage/datatypes.py
Normal file

@@ -0,0 +1,287 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import re
from abc import abstractmethod
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Literal

from pydantic import BaseModel, Field, field_validator


class StorageBackendType(StrEnum):
    KV_REDIS = "kv_redis"
    KV_SQLITE = "kv_sqlite"
    KV_POSTGRES = "kv_postgres"
    KV_MONGODB = "kv_mongodb"
    SQL_SQLITE = "sql_sqlite"
    SQL_POSTGRES = "sql_postgres"


class CommonConfig(BaseModel):
    namespace: str | None = Field(
        default=None,
        description="All keys will be prefixed with this namespace",
    )


class RedisKVStoreConfig(CommonConfig):
    type: Literal[StorageBackendType.KV_REDIS] = StorageBackendType.KV_REDIS
    host: str = "localhost"
    port: int = 6379

    @property
    def url(self) -> str:
        return f"redis://{self.host}:{self.port}"

    @classmethod
    def pip_packages(cls) -> list[str]:
        return ["redis"]

    @classmethod
    def sample_run_config(cls):
        return {
            "type": StorageBackendType.KV_REDIS.value,
            "host": "${env.REDIS_HOST:=localhost}",
            "port": "${env.REDIS_PORT:=6379}",
        }


class SqliteKVStoreConfig(CommonConfig):
    type: Literal[StorageBackendType.KV_SQLITE] = StorageBackendType.KV_SQLITE
    db_path: str = Field(
        description="File path for the sqlite database",
    )

    @classmethod
    def pip_packages(cls) -> list[str]:
        return ["aiosqlite"]

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
        return {
            "type": StorageBackendType.KV_SQLITE.value,
            "db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
        }


class PostgresKVStoreConfig(CommonConfig):
    type: Literal[StorageBackendType.KV_POSTGRES] = StorageBackendType.KV_POSTGRES
    host: str = "localhost"
    port: int | str = 5432
    db: str = "llamastack"
    user: str
    password: str | None = None
    ssl_mode: str | None = None
    ca_cert_path: str | None = None
    table_name: str = "llamastack_kvstore"

    @classmethod
    def sample_run_config(cls, table_name: str = "llamastack_kvstore", **kwargs):
        return {
            "type": StorageBackendType.KV_POSTGRES.value,
            "host": "${env.POSTGRES_HOST:=localhost}",
            "port": "${env.POSTGRES_PORT:=5432}",
            "db": "${env.POSTGRES_DB:=llamastack}",
            "user": "${env.POSTGRES_USER:=llamastack}",
            "password": "${env.POSTGRES_PASSWORD:=llamastack}",
            "table_name": "${env.POSTGRES_TABLE_NAME:=" + table_name + "}",
        }

    @field_validator("table_name")
    @classmethod
    def validate_table_name(cls, v: str) -> str:
        # PostgreSQL identifier rules:
        # - Must start with a letter or underscore
        # - Can contain letters, numbers, and underscores
        # - Maximum length is 63 bytes
        pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$"
        if not re.match(pattern, v):
            raise ValueError(
                "Invalid table name. Must start with letter or underscore and contain only letters, numbers, and underscores"
            )
        if len(v) > 63:
            raise ValueError("Table name must be at most 63 characters")
        return v

    @classmethod
    def pip_packages(cls) -> list[str]:
        return ["psycopg2-binary"]


class MongoDBKVStoreConfig(CommonConfig):
    type: Literal[StorageBackendType.KV_MONGODB] = StorageBackendType.KV_MONGODB
    host: str = "localhost"
    port: int = 27017
    db: str = "llamastack"
    user: str | None = None
    password: str | None = None
    collection_name: str = "llamastack_kvstore"

    @classmethod
    def pip_packages(cls) -> list[str]:
        return ["pymongo"]

    @classmethod
    def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
        return {
            "type": StorageBackendType.KV_MONGODB.value,
            "host": "${env.MONGODB_HOST:=localhost}",
            "port": "${env.MONGODB_PORT:=27017}",
            "db": "${env.MONGODB_DB}",
            "user": "${env.MONGODB_USER}",
            "password": "${env.MONGODB_PASSWORD}",
            "collection_name": "${env.MONGODB_COLLECTION_NAME:=" + collection_name + "}",
        }


class SqlAlchemySqlStoreConfig(BaseModel):
    @property
    @abstractmethod
    def engine_str(self) -> str: ...

    # TODO: move this when we have a better way to specify dependencies with internal APIs
    @classmethod
    def pip_packages(cls) -> list[str]:
        return ["sqlalchemy[asyncio]"]


class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
    type: Literal[StorageBackendType.SQL_SQLITE] = StorageBackendType.SQL_SQLITE
    db_path: str = Field(
        description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
    )

    @property
    def engine_str(self) -> str:
        return "sqlite+aiosqlite:///" + Path(self.db_path).expanduser().as_posix()

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"):
        return {
            "type": StorageBackendType.SQL_SQLITE.value,
            "db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
        }

    @classmethod
    def pip_packages(cls) -> list[str]:
        return super().pip_packages() + ["aiosqlite"]


class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
    type: Literal[StorageBackendType.SQL_POSTGRES] = StorageBackendType.SQL_POSTGRES
    host: str = "localhost"
    port: int | str = 5432
    db: str = "llamastack"
    user: str
    password: str | None = None

    @property
    def engine_str(self) -> str:
        return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"

    @classmethod
    def pip_packages(cls) -> list[str]:
        return super().pip_packages() + ["asyncpg"]

    @classmethod
    def sample_run_config(cls, **kwargs):
        return {
            "type": StorageBackendType.SQL_POSTGRES.value,
            "host": "${env.POSTGRES_HOST:=localhost}",
            "port": "${env.POSTGRES_PORT:=5432}",
            "db": "${env.POSTGRES_DB:=llamastack}",
            "user": "${env.POSTGRES_USER:=llamastack}",
            "password": "${env.POSTGRES_PASSWORD:=llamastack}",
        }


# reference = (backend_name, table_name)
class SqlStoreReference(BaseModel):
    """A reference to a 'SQL-like' persistent store. A table name must be provided."""

    table_name: str = Field(
        description="Name of the table to use for the SqlStore",
    )

    backend: str = Field(
        description="Name of backend from storage.backends",
    )


# reference = (backend_name, namespace)
class KVStoreReference(BaseModel):
    """A reference to a 'key-value' persistent store. A namespace must be provided."""

    namespace: str = Field(
        description="Key prefix for KVStore backends",
    )

    backend: str = Field(
        description="Name of backend from storage.backends",
    )


StorageBackendConfig = Annotated[
    RedisKVStoreConfig
    | SqliteKVStoreConfig
    | PostgresKVStoreConfig
    | MongoDBKVStoreConfig
    | SqliteSqlStoreConfig
    | PostgresSqlStoreConfig,
    Field(discriminator="type"),
]


class InferenceStoreReference(SqlStoreReference):
    """Inference store configuration with queue tuning."""

    max_write_queue_size: int = Field(
        default=10000,
        description="Max queued writes for inference store",
    )
    num_writers: int = Field(
        default=4,
        description="Number of concurrent background writers",
    )


class ResponsesStoreReference(InferenceStoreReference):
    """Responses store configuration with queue tuning."""


class ServerStoresConfig(BaseModel):
    metadata: KVStoreReference | None = Field(
        default=None,
        description="Metadata store configuration (uses KV backend)",
    )
    inference: InferenceStoreReference | None = Field(
        default=None,
        description="Inference store configuration (uses SQL backend)",
    )
    conversations: SqlStoreReference | None = Field(
        default=None,
        description="Conversations store configuration (uses SQL backend)",
    )
    responses: ResponsesStoreReference | None = Field(
        default=None,
        description="Responses store configuration (uses SQL backend)",
    )
    prompts: KVStoreReference | None = Field(
        default=None,
        description="Prompts store configuration (uses KV backend)",
    )


class StorageConfig(BaseModel):
    backends: dict[str, StorageBackendConfig] = Field(
        description="Named backend configurations (e.g., 'default', 'cache')",
    )
    stores: ServerStoresConfig = Field(
        default_factory=lambda: ServerStoresConfig(),
        description="Named references to storage backends used by the stack core",
    )
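To make the shapes above concrete, a hedged sketch of a StorageConfig wired to local sqlite backends (paths and backend names are placeholders):

```python
# Sketch only: construct a StorageConfig using the models defined above.
storage = StorageConfig(
    backends={
        "kv_default": SqliteKVStoreConfig(db_path="/tmp/llama/kvstore.db"),      # placeholder path
        "sql_default": SqliteSqlStoreConfig(db_path="/tmp/llama/sql_store.db"),  # placeholder path
    },
    stores=ServerStoresConfig(
        metadata=KVStoreReference(backend="kv_default", namespace="registry"),
        inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
    ),
)
assert storage.backends["sql_default"].engine_str.startswith("sqlite+aiosqlite")
```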
7
src/llama_stack/core/store/__init__.py
Normal file

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .registry import *  # noqa: F401 F403
199
src/llama_stack/core/store/registry.py
Normal file

@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from contextlib import asynccontextmanager
from typing import Protocol

import pydantic

from llama_stack.core.datatypes import RoutableObjectWithProvider
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl

logger = get_logger(__name__, category="core::registry")


class DistributionRegistry(Protocol):
    async def get_all(self) -> list[RoutableObjectWithProvider]: ...

    async def initialize(self) -> None: ...

    async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: ...

    def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: ...

    async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...

    async def register(self, obj: RoutableObjectWithProvider) -> bool: ...

    async def delete(self, type: str, identifier: str) -> None: ...


REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v10"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


def _get_registry_key_range() -> tuple[str, str]:
    """Returns the start and end keys for the registry range query."""
    start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
    return start_key, f"{start_key}\xff"


def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider]:
    """Utility function to parse registry values into RoutableObjectWithProvider objects."""
    all_objects = []
    for value in values:
        try:
            obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
            all_objects.append(obj)
        except pydantic.ValidationError as e:
            logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
            continue

    return all_objects


class DiskDistributionRegistry(DistributionRegistry):
    def __init__(self, kvstore: KVStore):
        self.kvstore = kvstore

    async def initialize(self) -> None:
        pass

    def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        # Disk registry does not have a cache
        raise NotImplementedError("Disk registry does not have a cache")

    async def get_all(self) -> list[RoutableObjectWithProvider]:
        start_key, end_key = _get_registry_key_range()
        values = await self.kvstore.values_in_range(start_key, end_key)
        return _parse_registry_values(values)

    async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
        if not json_str:
            return None

        try:
            return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
        except pydantic.ValidationError as e:
            logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
            return None

    async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
        await self.kvstore.set(
            KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
            obj.model_dump_json(),
        )
        return obj

    async def register(self, obj: RoutableObjectWithProvider) -> bool:
        existing_obj = await self.get(obj.type, obj.identifier)
        if existing_obj and existing_obj != obj:
            raise ValueError(
                f"Object of type '{obj.type}' and identifier '{obj.identifier}' already exists. "
                "Unregister it first if you want to replace it."
            )

        await self.kvstore.set(
            KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
            obj.model_dump_json(),
        )
        return True

    async def delete(self, type: str, identifier: str) -> None:
        await self.kvstore.delete(KEY_FORMAT.format(type=type, identifier=identifier))


class CachedDiskDistributionRegistry(DiskDistributionRegistry):
    def __init__(self, kvstore: KVStore):
        super().__init__(kvstore)
        self.cache: dict[tuple[str, str], RoutableObjectWithProvider] = {}
        self._initialized = False
        self._initialize_lock = asyncio.Lock()
        self._cache_lock = asyncio.Lock()

    @asynccontextmanager
    async def _locked_cache(self):
        """Context manager for safely accessing the cache with a lock."""
        async with self._cache_lock:
            yield self.cache

    async def _ensure_initialized(self):
        """Ensures the registry is initialized before operations."""
        if self._initialized:
            return

        async with self._initialize_lock:
            if self._initialized:
                return

            start_key, end_key = _get_registry_key_range()
            values = await self.kvstore.values_in_range(start_key, end_key)
            objects = _parse_registry_values(values)

            async with self._locked_cache() as cache:
                for obj in objects:
                    cache_key = (obj.type, obj.identifier)
                    cache[cache_key] = obj

            self._initialized = True

    async def initialize(self) -> None:
        await self._ensure_initialized()

    def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        return self.cache.get((type, identifier), None)

    async def get_all(self) -> list[RoutableObjectWithProvider]:
        await self._ensure_initialized()
        async with self._locked_cache() as cache:
            return list(cache.values())

    async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        await self._ensure_initialized()
        cache_key = (type, identifier)

        async with self._locked_cache() as cache:
            return cache.get(cache_key, None)

    async def register(self, obj: RoutableObjectWithProvider) -> bool:
        await self._ensure_initialized()
        success = await super().register(obj)

        if success:
            cache_key = (obj.type, obj.identifier)
            async with self._locked_cache() as cache:
                cache[cache_key] = obj

        return success

    async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
        await super().update(obj)
        cache_key = (obj.type, obj.identifier)
        async with self._locked_cache() as cache:
            cache[cache_key] = obj
        return obj

    async def delete(self, type: str, identifier: str) -> None:
        await super().delete(type, identifier)
        cache_key = (type, identifier)
        async with self._locked_cache() as cache:
            if cache_key in cache:
                del cache[cache_key]


async def create_dist_registry(
    metadata_store: KVStoreReference, image_name: str
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
    # instantiate kvstore for storing and retrieving distribution metadata
    dist_kvstore = await kvstore_impl(metadata_store)
    dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
    await dist_registry.initialize()
    return dist_registry, dist_kvstore
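A hedged sketch of wiring up the cached registry from a metadata store reference; it assumes a "kv_default" backend has already been registered with the storage layer:

```python
import asyncio

async def main() -> None:
    # "kv_default" and "registry" are placeholders for configured backend names.
    metadata = KVStoreReference(backend="kv_default", namespace="registry")
    registry, kvstore = await create_dist_registry(metadata, image_name="demo")
    print(await registry.get_all())  # served from the in-memory cache after initialize()

asyncio.run(main())
```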
32
src/llama_stack/core/telemetry/__init__.py
Normal file

@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .telemetry import Telemetry
from .trace_protocol import serialize_value, trace_protocol
from .tracing import (
    CURRENT_TRACE_CONTEXT,
    ROOT_SPAN_MARKERS,
    end_trace,
    enqueue_event,
    get_current_span,
    setup_logger,
    span,
    start_trace,
)

__all__ = [
    "Telemetry",
    "trace_protocol",
    "serialize_value",
    "CURRENT_TRACE_CONTEXT",
    "ROOT_SPAN_MARKERS",
    "end_trace",
    "enqueue_event",
    "get_current_span",
    "setup_logger",
    "span",
    "start_trace",
]
250
src/llama_stack/core/telemetry/telemetry.py
Normal file

@@ -0,0 +1,250 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import threading
from typing import Any

from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

from llama_stack.apis.telemetry import (
    Event,
    MetricEvent,
    SpanEndPayload,
    SpanStartPayload,
    SpanStatus,
    StructuredLogEvent,
    UnstructuredLogEvent,
)
from llama_stack.apis.telemetry import (
    Telemetry as TelemetryBase,
)
from llama_stack.core.telemetry.tracing import ROOT_SPAN_MARKERS
from llama_stack.log import get_logger

_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
    "active_spans": {},
    "counters": {},
    "gauges": {},
    "up_down_counters": {},
}
_global_lock = threading.Lock()
_TRACER_PROVIDER = None

logger = get_logger(name=__name__, category="telemetry")


def is_tracing_enabled(tracer):
    with tracer.start_as_current_span("check_tracing") as span:
        return span.is_recording()


class Telemetry(TelemetryBase):
    def __init__(self) -> None:
        self.meter = None

        global _TRACER_PROVIDER
        # Initialize the correct span processor based on the provider state.
        # This is needed since once the span processor is set, it cannot be unset.
        # Recreating the telemetry adapter multiple times will result in duplicate span processors.
        # Since the library client can be recreated multiple times in a notebook,
        # the kernel will hold on to the span processor and cause duplicate spans to be written.
        if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
            if _TRACER_PROVIDER is None:
                provider = TracerProvider()
                trace.set_tracer_provider(provider)
                _TRACER_PROVIDER = provider

                # Use a single OTLP endpoint for all telemetry signals.
                # Let the OpenTelemetry SDK handle endpoint construction automatically:
                # the SDK reads OTEL_EXPORTER_OTLP_ENDPOINT and constructs the appropriate URLs.
                # https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
                span_exporter = OTLPSpanExporter()
                span_processor = BatchSpanProcessor(span_exporter)
                trace.get_tracer_provider().add_span_processor(span_processor)

                metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
                metric_provider = MeterProvider(metric_readers=[metric_reader])
                metrics.set_meter_provider(metric_provider)
            self.is_otel_endpoint_set = True
        else:
            logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry")
            self.is_otel_endpoint_set = False

        self.meter = metrics.get_meter(__name__)
        self._lock = _global_lock

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        if self.is_otel_endpoint_set:
            trace.get_tracer_provider().force_flush()

    async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
        if isinstance(event, UnstructuredLogEvent):
            self._log_unstructured(event, ttl_seconds)
        elif isinstance(event, MetricEvent):
            self._log_metric(event)
        elif isinstance(event, StructuredLogEvent):
            self._log_structured(event, ttl_seconds)
        else:
            raise ValueError(f"Unknown event type: {event}")

    def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
        with self._lock:
            # Use global storage instead of instance storage
            span_id = int(event.span_id, 16)
            span = _GLOBAL_STORAGE["active_spans"].get(span_id)

            if span:
                timestamp_ns = int(event.timestamp.timestamp() * 1e9)
                span.add_event(
                    name=event.type.value,
                    attributes={
                        "message": event.message,
                        "severity": event.severity.value,
                        "__ttl__": ttl_seconds,
                        **(event.attributes or {}),
                    },
                    timestamp=timestamp_ns,
                )
            else:
                print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")

    def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
        assert self.meter is not None
        if name not in _GLOBAL_STORAGE["counters"]:
            _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
                name=name,
                unit=unit,
                description=f"Counter for {name}",
            )
        return _GLOBAL_STORAGE["counters"][name]

    def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
        assert self.meter is not None
        if name not in _GLOBAL_STORAGE["gauges"]:
            _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
                name=name,
                unit=unit,
                description=f"Gauge for {name}",
            )
        return _GLOBAL_STORAGE["gauges"][name]

    def _log_metric(self, event: MetricEvent) -> None:
        # Add metric as an event to the current span
        try:
            with self._lock:
                # Only try to add to span if we have a valid span_id
                if event.span_id:
                    try:
                        span_id = int(event.span_id, 16)
                        span = _GLOBAL_STORAGE["active_spans"].get(span_id)

                        if span:
                            timestamp_ns = int(event.timestamp.timestamp() * 1e9)
                            span.add_event(
                                name=f"metric.{event.metric}",
                                attributes={
                                    "value": event.value,
                                    "unit": event.unit,
                                    **(event.attributes or {}),
                                },
                                timestamp=timestamp_ns,
                            )
                    except (ValueError, KeyError):
                        # Invalid span_id or span not found; drop the span event silently
                        pass
        except Exception:
            # Lock acquisition failed
            logger.debug("Failed to acquire lock to add metric to span")

        # Log to OpenTelemetry meter if available
        if self.meter is None:
            return
        if isinstance(event.value, int):
            counter = self._get_or_create_counter(event.metric, event.unit)
            counter.add(event.value, attributes=event.attributes)
        elif isinstance(event.value, float):
            up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
            up_down_counter.add(event.value, attributes=event.attributes)

    def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
        assert self.meter is not None
        if name not in _GLOBAL_STORAGE["up_down_counters"]:
            _GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
                name=name,
                unit=unit,
                description=f"UpDownCounter for {name}",
            )
        return _GLOBAL_STORAGE["up_down_counters"][name]

    def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
        with self._lock:
            span_id = int(event.span_id, 16)
            tracer = trace.get_tracer(__name__)
            if event.attributes is None:
                event.attributes = {}
            event.attributes["__ttl__"] = ttl_seconds

            # Extract these W3C trace context attributes so they are not written to
            # underlying storage, as we just need them to propagate the trace context.
            traceparent = event.attributes.pop("traceparent", None)
            tracestate = event.attributes.pop("tracestate", None)
            if traceparent:
                # If we have a traceparent header value, we're not the root span.
                for root_attribute in ROOT_SPAN_MARKERS:
                    event.attributes.pop(root_attribute, None)

            if isinstance(event.payload, SpanStartPayload):
                # Check if span already exists to prevent duplicates
                if span_id in _GLOBAL_STORAGE["active_spans"]:
                    return

                context = None
                if event.payload.parent_span_id:
                    parent_span_id = int(event.payload.parent_span_id, 16)
                    parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
                    context = trace.set_span_in_context(parent_span)
                elif traceparent:
                    carrier = {
                        "traceparent": traceparent,
                        "tracestate": tracestate,
                    }
                    context = TraceContextTextMapPropagator().extract(carrier=carrier)

                span = tracer.start_span(
                    name=event.payload.name,
                    context=context,
                    attributes=event.attributes or {},
                )
                _GLOBAL_STORAGE["active_spans"][span_id] = span

            elif isinstance(event.payload, SpanEndPayload):
                span = _GLOBAL_STORAGE["active_spans"].get(span_id)
                if span:
                    if event.attributes:
                        span.set_attributes(event.attributes)

                    status = (
                        trace.Status(status_code=trace.StatusCode.OK)
                        if event.payload.status == SpanStatus.OK
                        else trace.Status(status_code=trace.StatusCode.ERROR)
                    )
                    span.set_status(status)
                    span.end()
                    _GLOBAL_STORAGE["active_spans"].pop(span_id, None)
            else:
                raise ValueError(f"Unknown structured log event: {event}")
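A small, hedged sketch of how the adapter above gates exporting on the environment (the endpoint value is illustrative):

```python
import os

# Telemetry only installs OTLP exporters when this variable is present.
os.environ.setdefault("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318")  # illustrative

telemetry = Telemetry()
print(telemetry.is_otel_endpoint_set)  # True once the endpoint is set
```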
145
src/llama_stack/core/telemetry/trace_protocol.py
Normal file

@@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import inspect
import json
from collections.abc import AsyncGenerator, Callable
from functools import wraps
from typing import Any, cast

from pydantic import BaseModel

from llama_stack.models.llama.datatypes import Primitive

type JSONValue = Primitive | list["JSONValue"] | dict[str, "JSONValue"]


def serialize_value(value: Any) -> str:
    return str(_prepare_for_json(value))


def _prepare_for_json(value: Any) -> JSONValue:
    """Serialize a single value into JSON-compatible format."""
    if value is None:
        return ""
    elif isinstance(value, str | int | float | bool):
        return value
    elif hasattr(value, "_name_"):
        return cast(str, value._name_)
    elif isinstance(value, BaseModel):
        return cast(JSONValue, json.loads(value.model_dump_json()))
    elif isinstance(value, list | tuple | set):
        return [_prepare_for_json(item) for item in value]
    elif isinstance(value, dict):
        return {str(k): _prepare_for_json(v) for k, v in value.items()}
    else:
        try:
            json.dumps(value)
            return cast(JSONValue, value)
        except Exception:
            return str(value)


def trace_protocol[T: type[Any]](cls: T) -> T:
    """
    A class decorator that automatically traces all methods in a protocol/base class
    and its inheriting classes.
    """

    def trace_method(method: Callable[..., Any]) -> Callable[..., Any]:
        is_async = asyncio.iscoroutinefunction(method)
        is_async_gen = inspect.isasyncgenfunction(method)

        def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple[str, str, dict[str, Primitive]]:
            class_name = self.__class__.__name__
            method_name = method.__name__
            span_type = "async_generator" if is_async_gen else "async" if is_async else "sync"
            sig = inspect.signature(method)
            param_names = list(sig.parameters.keys())[1:]  # Skip 'self'
            combined_args: dict[str, str] = {}
            for i, arg in enumerate(args):
                param_name = param_names[i] if i < len(param_names) else f"position_{i + 1}"
                combined_args[param_name] = serialize_value(arg)
            for k, v in kwargs.items():
                combined_args[str(k)] = serialize_value(v)

            span_attributes: dict[str, Primitive] = {
                "__autotraced__": True,
                "__class__": class_name,
                "__method__": method_name,
                "__type__": span_type,
                "__args__": json.dumps(combined_args),
            }

            return class_name, method_name, span_attributes

        @wraps(method)
        async def async_gen_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
            from llama_stack.core.telemetry import tracing

            class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)

            with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
                count = 0
                try:
                    async for item in method(self, *args, **kwargs):
                        yield item
                        count += 1
                finally:
                    span.set_attribute("chunk_count", count)

        @wraps(method)
        async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            from llama_stack.core.telemetry import tracing

            class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)

            with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
                try:
                    result = await method(self, *args, **kwargs)
                    span.set_attribute("output", serialize_value(result))
                    return result
                except Exception as e:
                    span.set_attribute("error", str(e))
                    raise

        @wraps(method)
        def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            from llama_stack.core.telemetry import tracing

            class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)

            with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
                try:
                    result = method(self, *args, **kwargs)
                    span.set_attribute("output", serialize_value(result))
                    return result
                except Exception as e:
                    span.set_attribute("error", str(e))
                    raise

        if is_async_gen:
            return async_gen_wrapper
        elif is_async:
            return async_wrapper
        else:
            return sync_wrapper

    original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None))

    def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None:  # noqa: N807
        if original_init_subclass:
            cast(Callable[..., None], original_init_subclass)(**kwargs)

        for name, method in vars(cls_child).items():
            if inspect.isfunction(method) and not name.startswith("_"):
                setattr(cls_child, name, trace_method(method))  # noqa: B010

    cls_any = cast(Any, cls)
    cls_any.__init_subclass__ = classmethod(__init_subclass__)

    return cls
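A minimal sketch of the decorator in use; subclasses of a decorated base get their public methods wrapped automatically via `__init_subclass__` (the class and method names are made up for illustration):

```python
@trace_protocol
class InferenceBase:
    pass


class EchoInference(InferenceBase):
    async def complete(self, prompt: str) -> str:
        # This call now runs inside an autotraced span named "EchoInference.complete".
        return prompt
```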
388
src/llama_stack/core/telemetry/tracing.py
Normal file

@@ -0,0 +1,388 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import contextvars
import logging  # allow-direct-logging
import queue
import secrets
import sys
import threading
import time
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps
from typing import Any, Self

from llama_stack.apis.telemetry import (
    Event,
    LogSeverity,
    Span,
    SpanEndPayload,
    SpanStartPayload,
    SpanStatus,
    StructuredLogEvent,
    Telemetry,
    UnstructuredLogEvent,
)
from llama_stack.core.telemetry.trace_protocol import serialize_value
from llama_stack.log import get_logger

logger = get_logger(__name__, category="core")

# Fallback logger that does NOT propagate to TelemetryHandler to avoid recursion
_fallback_logger = logging.getLogger("llama_stack.telemetry.background")
if not _fallback_logger.handlers:
    _fallback_logger.propagate = False
    _fallback_logger.setLevel(logging.ERROR)
    _fallback_handler = logging.StreamHandler(sys.stderr)
    _fallback_handler.setLevel(logging.ERROR)
    _fallback_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
    _fallback_logger.addHandler(_fallback_handler)


INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000

ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
# The logical root span may not be visible to this process if a parent context
# is passed in. The local root span is the first local span in a trace.
LOCAL_ROOT_SPAN_MARKER = "__local_root_span__"


def trace_id_to_str(trace_id: int) -> str:
    """Convenience trace ID formatting method
    Args:
        trace_id: Trace ID int

    Returns:
        The trace ID as 32-byte hexadecimal string
    """
    return format(trace_id, "032x")


def span_id_to_str(span_id: int) -> str:
    """Convenience span ID formatting method
    Args:
        span_id: Span ID int

    Returns:
        The span ID as 16-byte hexadecimal string
    """
    return format(span_id, "016x")


def generate_span_id() -> str:
    span_id = secrets.randbits(64)
    while span_id == INVALID_SPAN_ID:
        span_id = secrets.randbits(64)
    return span_id_to_str(span_id)


def generate_trace_id() -> str:
    trace_id = secrets.randbits(128)
    while trace_id == INVALID_TRACE_ID:
        trace_id = secrets.randbits(128)
    return trace_id_to_str(trace_id)


LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS = 60.0


class BackgroundLogger:
    def __init__(self, api: Telemetry, capacity: int = 100000):
        self.api = api
        self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
        self.worker_thread = threading.Thread(target=self._worker, daemon=True)
        self.worker_thread.start()
        self._last_queue_full_log_time: float = 0.0
        self._dropped_since_last_notice: int = 0

    def log_event(self, event: Event) -> None:
        try:
            self.log_queue.put_nowait(event)
        except queue.Full:
            # Aggregate drops and emit at most once per interval via fallback logger
            self._dropped_since_last_notice += 1
            current_time = time.time()
            if current_time - self._last_queue_full_log_time >= LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS:
                _fallback_logger.error(
                    "Log queue is full; dropped %d events since last notice",
                    self._dropped_since_last_notice,
                )
                self._last_queue_full_log_time = current_time
                self._dropped_since_last_notice = 0

    def _worker(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self._process_logs())

    async def _process_logs(self):
        while True:
            try:
                event = self.log_queue.get()
                await self.api.log_event(event)
            except Exception:
                import traceback

                traceback.print_exc()
                print("Error processing log event")
            finally:
                self.log_queue.task_done()

    def __del__(self) -> None:
        self.log_queue.join()


BACKGROUND_LOGGER: BackgroundLogger | None = None


def enqueue_event(event: Event) -> None:
    """Enqueue a telemetry event to the background logger if available.

    This provides a non-blocking path for routers and other hot paths to
    submit telemetry without awaiting the Telemetry API, reducing contention
    with the main event loop.
    """
    global BACKGROUND_LOGGER
    if BACKGROUND_LOGGER is None:
        raise RuntimeError("Telemetry API not initialized")
    BACKGROUND_LOGGER.log_event(event)


class TraceContext:
    def __init__(self, logger: BackgroundLogger, trace_id: str):
        self.logger = logger
        self.trace_id = trace_id
        self.spans: list[Span] = []

    def push_span(self, name: str, attributes: dict[str, Any] | None = None) -> Span:
        current_span = self.get_current_span()
        span = Span(
            span_id=generate_span_id(),
            trace_id=self.trace_id,
            name=name,
            start_time=datetime.now(UTC),
            parent_span_id=current_span.span_id if current_span else None,
            attributes=attributes,
        )

        self.logger.log_event(
            StructuredLogEvent(
                trace_id=span.trace_id,
                span_id=span.span_id,
                timestamp=span.start_time,
                attributes=span.attributes,
                payload=SpanStartPayload(
                    name=span.name,
                    parent_span_id=span.parent_span_id,
                ),
            )
        )

        self.spans.append(span)
        return span

    def pop_span(self, status: SpanStatus = SpanStatus.OK) -> None:
        span = self.spans.pop()
        if span is not None:
            self.logger.log_event(
                StructuredLogEvent(
                    trace_id=span.trace_id,
                    span_id=span.span_id,
                    timestamp=span.start_time,
                    attributes=span.attributes,
                    payload=SpanEndPayload(
                        status=status,
                    ),
                )
            )

    def get_current_span(self) -> Span | None:
        return self.spans[-1] if self.spans else None


CURRENT_TRACE_CONTEXT: contextvars.ContextVar[TraceContext | None] = contextvars.ContextVar(
    "trace_context", default=None
)


def setup_logger(api: Telemetry, level: int = logging.INFO):
    global BACKGROUND_LOGGER

    if BACKGROUND_LOGGER is None:
        BACKGROUND_LOGGER = BackgroundLogger(api)
    root_logger = logging.getLogger()
    root_logger.setLevel(level)
    root_logger.addHandler(TelemetryHandler())


async def start_trace(name: str, attributes: dict[str, Any] | None = None) -> TraceContext | None:
    global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER

    if BACKGROUND_LOGGER is None:
        logger.debug("No Telemetry implementation set. Skipping trace initialization...")
        return None

    trace_id = generate_trace_id()
    context = TraceContext(BACKGROUND_LOGGER, trace_id)
    # Mark this span as the root for the trace for now. The processing of
    # traceparent context if supplied comes later and will result in the
    # ROOT_SPAN_MARKERS being removed. Also mark this as the 'local' root,
    # i.e. the root of the spans originating in this process, as this is
    # needed to ensure that we insert this 'local' root span's id into
    # the trace record in the sqlite store.
    attributes = dict.fromkeys(ROOT_SPAN_MARKERS, True) | {LOCAL_ROOT_SPAN_MARKER: True} | (attributes or {})
    context.push_span(name, attributes)

    CURRENT_TRACE_CONTEXT.set(context)
    return context


async def end_trace(status: SpanStatus = SpanStatus.OK):
    global CURRENT_TRACE_CONTEXT

    context = CURRENT_TRACE_CONTEXT.get()
    if context is None:
        logger.debug("No trace context to end")
        return

    context.pop_span(status)
    CURRENT_TRACE_CONTEXT.set(None)


def severity(levelname: str) -> LogSeverity:
    if levelname == "DEBUG":
        return LogSeverity.DEBUG
    elif levelname == "INFO":
        return LogSeverity.INFO
    elif levelname == "WARNING":
        return LogSeverity.WARN
    elif levelname == "ERROR":
        return LogSeverity.ERROR
    elif levelname == "CRITICAL":
        return LogSeverity.CRITICAL
    else:
        raise ValueError(f"Unknown log level: {levelname}")


# TODO: ideally, the actual emitting should be done inside a separate daemon
# process completely isolated from the server
class TelemetryHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        # horrendous hack to avoid logging from asyncio and getting into an infinite loop
        if record.module in ("asyncio", "selector_events"):
            return

        global CURRENT_TRACE_CONTEXT
        context = CURRENT_TRACE_CONTEXT.get()
        if context is None:
            return

        span = context.get_current_span()
        if span is None:
            return

        enqueue_event(
            UnstructuredLogEvent(
                trace_id=span.trace_id,
                span_id=span.span_id,
                timestamp=datetime.now(UTC),
                message=self.format(record),
                severity=severity(record.levelname),
            )
        )

    def close(self) -> None:
        pass


class SpanContextManager:
    def __init__(self, name: str, attributes: dict[str, Any] | None = None):
        self.name = name
        self.attributes = attributes
        self.span: Span | None = None

    def __enter__(self) -> Self:
        global CURRENT_TRACE_CONTEXT
        context = CURRENT_TRACE_CONTEXT.get()
        if not context:
            logger.debug("No trace context to push span")
            return self

        self.span = context.push_span(self.name, self.attributes)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        global CURRENT_TRACE_CONTEXT
        context = CURRENT_TRACE_CONTEXT.get()
        if not context:
            logger.debug("No trace context to pop span")
            return

        context.pop_span()

    def set_attribute(self, key: str, value: Any) -> None:
        if self.span:
            if self.span.attributes is None:
                self.span.attributes = {}
            self.span.attributes[key] = serialize_value(value)

    async def __aenter__(self) -> Self:
        global CURRENT_TRACE_CONTEXT
        context = CURRENT_TRACE_CONTEXT.get()
        if not context:
            logger.debug("No trace context to push span")
            return self

        self.span = context.push_span(self.name, self.attributes)
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        global CURRENT_TRACE_CONTEXT
        context = CURRENT_TRACE_CONTEXT.get()
        if not context:
            logger.debug("No trace context to pop span")
            return

        context.pop_span()

    def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            with self:
                return func(*args, **kwargs)

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            async with self:
                return await func(*args, **kwargs)

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if asyncio.iscoroutinefunction(func):
                return async_wrapper(*args, **kwargs)
            else:
                return sync_wrapper(*args, **kwargs)

        return wrapper


def span(name: str, attributes: dict[str, Any] | None = None) -> SpanContextManager:
    return SpanContextManager(name, attributes)


def get_current_span() -> Span | None:
    global CURRENT_TRACE_CONTEXT
    if CURRENT_TRACE_CONTEXT is None:
        logger.debug("No trace context to get current span")
        return None

    context = CURRENT_TRACE_CONTEXT.get()
    if context:
        return context.get_current_span()
    return None
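A hedged sketch tying the helpers above together; it assumes `setup_logger()` was called with a Telemetry implementation so `BACKGROUND_LOGGER` exists (the span names and attributes are illustrative):

```python
import asyncio

async def handle_request() -> None:
    await start_trace("handle_request", {"route": "/v1/chat"})  # illustrative attributes
    with span("load_model", {"model": "demo"}):
        ...  # work here runs inside a child span
    await end_trace()

asyncio.run(handle_request())
```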
49
src/llama_stack/core/testing_context.py
Normal file

@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from contextvars import ContextVar

from llama_stack.core.request_headers import PROVIDER_DATA_VAR

TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None)


def get_test_context() -> str | None:
    return TEST_CONTEXT.get()


def set_test_context(value: str | None):
    return TEST_CONTEXT.set(value)


def reset_test_context(token) -> None:
    TEST_CONTEXT.reset(token)


def sync_test_context_from_provider_data():
    """Sync test context from provider data when running in server test mode."""
    if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
        return None

    stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
    if stack_config_type != "server":
        return None

    try:
        provider_data = PROVIDER_DATA_VAR.get()
    except LookupError:
        provider_data = None

    if provider_data and "__test_id" in provider_data:
        return TEST_CONTEXT.set(provider_data["__test_id"])

    return None


def is_debug_mode() -> bool:
    """Check if test recording debug mode is enabled via LLAMA_STACK_TEST_DEBUG env var."""
    return os.environ.get("LLAMA_STACK_TEST_DEBUG", "").lower() in ("1", "true", "yes")
11
src/llama_stack/core/ui/Containerfile
Normal file

@@ -0,0 +1,11 @@
# More info on playground configuration can be found here:
# https://llama-stack.readthedocs.io/en/latest/playground

FROM python:3.12-slim
WORKDIR /app
COPY . /app/
RUN /usr/local/bin/python -m pip install --upgrade pip && \
    /usr/local/bin/pip3 install -r requirements.txt
EXPOSE 8501

ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
50
src/llama_stack/core/ui/README.md
Normal file

@@ -0,0 +1,50 @@
# (Experimental) Llama Stack UI

## Docker Setup

:warning: This is a work in progress.

## Developer Setup

1. Start up the Llama Stack API server. More details [here](https://llamastack.github.io/latest/getting_started/index.html).

```
llama stack list-deps together | xargs -L1 uv pip install

llama stack run together
```

2. (Optional) Register datasets and eval tasks as resources if you want to run pre-configured evaluation flows (e.g. the Evaluations (Generation + Scoring) page).

```bash
llama-stack-client datasets register \
--dataset-id "mmlu" \
--provider-id "huggingface" \
--url "https://huggingface.co/datasets/llamastack/evals" \
--metadata '{"path": "llamastack/evals", "name": "evals__mmlu__details", "split": "train"}' \
--schema '{"input_query": {"type": "string"}, "expected_answer": {"type": "string", "chat_completion_input": {"type": "string"}}}'
```

```bash
llama-stack-client benchmarks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \
--scoring-functions basic::regex_parser_multiple_choice_answer
```

3. Start the Streamlit UI

```bash
uv run --with ".[ui]" streamlit run src/llama_stack/core/ui/app.py
```

## Environment Variables

| Environment Variable | Description | Default Value |
|----------------------------|------------------------------------|---------------------------|
| LLAMA_STACK_ENDPOINT | The endpoint for the Llama Stack | http://localhost:8321 |
| FIREWORKS_API_KEY | API key for Fireworks provider | (empty string) |
| TOGETHER_API_KEY | API key for Together provider | (empty string) |
| SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
| OPENAI_API_KEY | API key for OpenAI provider | (empty string) |
5
src/llama_stack/core/ui/__init__.py
Normal file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
55
src/llama_stack/core/ui/app.py
Normal file

@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st


def main():
    # Evaluation pages
    application_evaluation_page = st.Page(
        "page/evaluations/app_eval.py",
        title="Evaluations (Scoring)",
        icon="📊",
        default=False,
    )
    native_evaluation_page = st.Page(
        "page/evaluations/native_eval.py",
        title="Evaluations (Generation + Scoring)",
        icon="📊",
        default=False,
    )

    # Playground pages
    chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
    rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
    tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)

    # Distribution pages
    resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
    provider_page = st.Page(
        "page/distribution/providers.py",
        title="API Providers",
        icon="🔍",
        default=False,
    )

    pg = st.navigation(
        {
            "Playground": [
                chat_page,
                rag_page,
                tool_page,
                application_evaluation_page,
                native_evaluation_page,
            ],
            "Inspect": [provider_page, resources_page],
        },
        expanded=False,
    )
    pg.run()


if __name__ == "__main__":
    main()
5
src/llama_stack/core/ui/modules/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
32
src/llama_stack/core/ui/modules/api.py
Normal file
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

from llama_stack_client import LlamaStackClient


class LlamaStackApi:
    def __init__(self):
        self.client = LlamaStackClient(
            base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
            provider_data={
                "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
                "together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
                "sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
                "openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
                "tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
            },
        )

    def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
        """Run scoring on a single row"""
        if not scoring_params:
            scoring_params = dict.fromkeys(scoring_function_ids)
        return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)


llama_stack_api = LlamaStackApi()
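Note (not part of this commit): a minimal sketch of how the module-level `llama_stack_api` singleton is used by the evaluation pages. The row keys and the `basic::subset_of` scoring-function id are assumptions about what happens to be registered on the server:

```python
# Illustrative only: score a single row against one scoring function.
from llama_stack.core.ui.modules.api import llama_stack_api

row = {"input_query": "2 + 2 = ?", "expected_answer": "4", "generated_answer": "4"}  # hypothetical row
res = llama_stack_api.run_scoring(
    row,
    scoring_function_ids=["basic::subset_of"],  # any registered scoring function id
    scoring_params=None,  # None means default params per function
)
print(res.results["basic::subset_of"].score_rows[0])
```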
42
src/llama_stack/core/ui/modules/utils.py
Normal file
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import os

import pandas as pd
import streamlit as st


def process_dataset(file):
    # Return a DataFrame on success and None on failure, so callers can
    # simply check `if df is None`.
    if file is None:
        st.error("No file uploaded")
        return None

    try:
        # Determine file type and read accordingly
        file_ext = os.path.splitext(file.name)[1].lower()
        if file_ext == ".csv":
            df = pd.read_csv(file)
        elif file_ext in [".xlsx", ".xls"]:
            df = pd.read_excel(file)
        else:
            st.error("Unsupported file format. Please upload a CSV or Excel file.")
            return None

        return df

    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return None


def data_url_from_file(file) -> str:
    file_content = file.getvalue()
    base64_content = base64.b64encode(file_content).decode("utf-8")
    mime_type = file.type

    data_url = f"data:{mime_type};base64,{base64_content}"

    return data_url
5
src/llama_stack/core/ui/page/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
5
src/llama_stack/core/ui/page/distribution/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
18
src/llama_stack/core/ui/page/distribution/datasets.py
Normal file
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import streamlit as st

from llama_stack.core.ui.modules.api import llama_stack_api


def datasets():
    st.header("Datasets")

    datasets_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()}
    if len(datasets_info) > 0:
        selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
        st.json(datasets_info[selected_dataset], expanded=True)
20
src/llama_stack/core/ui/page/distribution/eval_tasks.py
Normal file
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import streamlit as st

from llama_stack.core.ui.modules.api import llama_stack_api


def benchmarks():
    # Benchmarks Section
    st.header("Benchmarks")

    benchmarks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.benchmarks.list()}

    if len(benchmarks_info) > 0:
        selected_benchmark = st.selectbox("Select an eval task", list(benchmarks_info.keys()), key="benchmark_inspect")
        st.json(benchmarks_info[selected_benchmark], expanded=True)
18
src/llama_stack/core/ui/page/distribution/models.py
Normal file
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import streamlit as st

from llama_stack.core.ui.modules.api import llama_stack_api


def models():
    # Models Section
    st.header("Models")
    models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()}

    selected_model = st.selectbox("Select a model", list(models_info.keys()))
    st.json(models_info[selected_model])
27
src/llama_stack/core/ui/page/distribution/providers.py
Normal file
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import streamlit as st

from llama_stack.core.ui.modules.api import llama_stack_api


def providers():
    st.header("🔍 API Providers")
    apis_providers_lst = llama_stack_api.client.providers.list()
    api_to_providers = {}
    for api_provider in apis_providers_lst:
        if api_provider.api in api_to_providers:
            api_to_providers[api_provider.api].append(api_provider)
        else:
            api_to_providers[api_provider.api] = [api_provider]

    for api in api_to_providers.keys():
        st.markdown(f"###### {api}")
        st.dataframe([x.to_dict() for x in api_to_providers[api]], width=500)


providers()
48
src/llama_stack/core/ui/page/distribution/resources.py
Normal file
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from streamlit_option_menu import option_menu

from llama_stack.core.ui.page.distribution.datasets import datasets
from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.core.ui.page.distribution.models import models
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.core.ui.page.distribution.shields import shields


def resources_page():
    options = [
        "Models",
        "Shields",
        "Scoring Functions",
        "Datasets",
        "Benchmarks",
    ]
    icons = ["magic", "shield", "file-bar-graph", "database", "list-task"]
    selected_resource = option_menu(
        None,
        options,
        icons=icons,
        orientation="horizontal",
        styles={
            "nav-link": {
                "font-size": "12px",
            },
        },
    )
    if selected_resource == "Benchmarks":
        benchmarks()
    elif selected_resource == "Datasets":
        datasets()
    elif selected_resource == "Models":
        models()
    elif selected_resource == "Scoring Functions":
        scoring_functions()
    elif selected_resource == "Shields":
        shields()


resources_page()
18
src/llama_stack/core/ui/page/distribution/scoring_functions.py
Normal file
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import streamlit as st

from llama_stack.core.ui.modules.api import llama_stack_api


def scoring_functions():
    st.header("Scoring Functions")

    scoring_functions_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.scoring_functions.list()}

    selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys()))
    st.json(scoring_functions_info[selected_scoring_function], expanded=True)
19
src/llama_stack/core/ui/page/distribution/shields.py
Normal file
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import streamlit as st

from llama_stack.core.ui.modules.api import llama_stack_api


def shields():
    # Shields Section
    st.header("Shields")

    shields_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()}

    selected_shield = st.selectbox("Select a shield", list(shields_info.keys()))
    st.json(shields_info[selected_shield])
5
src/llama_stack/core/ui/page/evaluations/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
143
src/llama_stack/core/ui/page/evaluations/app_eval.py
Normal file
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json

import pandas as pd
import streamlit as st

from llama_stack.core.ui.modules.api import llama_stack_api
from llama_stack.core.ui.modules.utils import process_dataset


def application_evaluation_page():
    st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙")
    st.title("📊 Evaluations (Scoring)")

    # File uploader
    uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"])

    if uploaded_file is None:
        st.error("No file uploaded")
        return

    # Process uploaded file
    df = process_dataset(uploaded_file)
    if df is None:
        st.error("Error processing file")
        return

    # Display dataset information
    st.success("Dataset loaded successfully!")

    # Display dataframe preview
    st.subheader("Dataset Preview")
    st.dataframe(df)

    # Select Scoring Functions to Run Evaluation On
    st.subheader("Select Scoring Functions")
    scoring_functions = llama_stack_api.client.scoring_functions.list()
    scoring_functions = {sf.identifier: sf for sf in scoring_functions}
    scoring_functions_names = list(scoring_functions.keys())
    selected_scoring_functions = st.multiselect(
        "Choose one or more scoring functions",
        options=scoring_functions_names,
        help="Choose one or more scoring functions.",
    )

    available_models = llama_stack_api.client.models.list()
    available_models = [m.identifier for m in available_models]

    scoring_params = {}
    if selected_scoring_functions:
        st.write("Selected:")
        for scoring_fn_id in selected_scoring_functions:
            scoring_fn = scoring_functions[scoring_fn_id]
            st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}")
            new_params = None
            if scoring_fn.params:
                new_params = {}
                for param_name, param_value in scoring_fn.params.to_dict().items():
                    if param_name == "type":
                        new_params[param_name] = param_value
                        continue

                    if param_name == "judge_model":
                        value = st.selectbox(
                            f"Select **{param_name}** for {scoring_fn_id}",
                            options=available_models,
                            index=0,
                            key=f"{scoring_fn_id}_{param_name}",
                        )
                        new_params[param_name] = value
                    else:
                        value = st.text_area(
                            f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format",
                            value=json.dumps(param_value, indent=2),
                            height=80,
                        )
                        try:
                            new_params[param_name] = json.loads(value)
                        except json.JSONDecodeError:
                            st.error(f"Invalid JSON for **{param_name}** in {scoring_fn_id}")

                st.json(new_params)
            scoring_params[scoring_fn_id] = new_params

    # Add run evaluation button & slider
    total_rows = len(df)
    num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows)

    if st.button("Run Evaluation"):
        progress_text = "Running evaluation..."
        progress_bar = st.progress(0, text=progress_text)
        rows = df.to_dict(orient="records")
        if num_rows < total_rows:
            rows = rows[:num_rows]

        # Create separate containers for progress text and results
        progress_text_container = st.empty()
        results_container = st.empty()
        output_res = {}
        for i, r in enumerate(rows):
            # Update progress
            progress = i / len(rows)
            progress_bar.progress(progress, text=progress_text)

            # Run evaluation for current row
            score_res = llama_stack_api.run_scoring(
                r,
                scoring_function_ids=selected_scoring_functions,
                scoring_params=scoring_params,
            )

            for k in r.keys():
                if k not in output_res:
                    output_res[k] = []
                output_res[k].append(r[k])

            for fn_id in selected_scoring_functions:
                if fn_id not in output_res:
                    output_res[fn_id] = []
                output_res[fn_id].append(score_res.results[fn_id].score_rows[0])

            # Display current row results using separate containers
            progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
            results_container.json(
                score_res.to_json(),
                expanded=2,
            )

        progress_bar.progress(1.0, text="Evaluation complete!")

        # Display results in dataframe
        if output_res:
            output_df = pd.DataFrame(output_res)
            st.subheader("Evaluation Results")
            st.dataframe(output_df)


application_evaluation_page()
253
src/llama_stack/core/ui/page/evaluations/native_eval.py
Normal file
@@ -0,0 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json

import pandas as pd
import streamlit as st

from llama_stack.core.ui.modules.api import llama_stack_api


def select_benchmark_1():
    # Select Benchmarks
    st.subheader("1. Choose An Eval Task")
    benchmarks = llama_stack_api.client.benchmarks.list()
    benchmarks = {et.identifier: et for et in benchmarks}
    benchmarks_names = list(benchmarks.keys())
    selected_benchmark = st.selectbox(
        "Choose an eval task.",
        options=benchmarks_names,
        help="Choose an eval task. Each eval task is parameterized by a dataset, and list of scoring functions.",
    )
    with st.expander("View Eval Task"):
        st.json(benchmarks[selected_benchmark], expanded=True)

    st.session_state["selected_benchmark"] = selected_benchmark
    st.session_state["benchmarks"] = benchmarks
    if st.button("Confirm", key="confirm_1"):
        st.session_state["selected_benchmark_1_next"] = True


def define_eval_candidate_2():
    if not st.session_state.get("selected_benchmark_1_next", None):
        return

    st.subheader("2. Define Eval Candidate")
    st.info(
        """
        Define the configurations for the evaluation candidate model or agent used for generation.
        Select "model" if you want to run generation with the inference API, or "agent" if you want to run generation with the agent API by specifying an AgentConfig.
        """
    )
    with st.expander("Define Eval Candidate", expanded=True):
        # Define Eval Candidate
        candidate_type = st.radio("Candidate Type", ["model", "agent"])

        available_models = llama_stack_api.client.models.list()
        available_models = [model.identifier for model in available_models]
        selected_model = st.selectbox(
            "Choose a model",
            available_models,
            index=0,
        )

        # Sampling Parameters
        st.markdown("##### Sampling Parameters")
        temperature = st.slider(
            "Temperature",
            min_value=0.0,
            max_value=1.0,
            value=0.0,
            step=0.1,
            help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
        )
        top_p = st.slider(
            "Top P",
            min_value=0.0,
            max_value=1.0,
            value=0.95,
            step=0.1,
        )
        max_tokens = st.slider(
            "Max Tokens",
            min_value=0,
            max_value=4096,
            value=512,
            step=1,
            help="The maximum number of tokens to generate",
        )
        repetition_penalty = st.slider(
            "Repetition Penalty",
            min_value=1.0,
            max_value=2.0,
            value=1.0,
            step=0.1,
            help="Controls the likelihood of generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 strongly discourages the model from repeating words or phrases.",
        )
        if candidate_type == "model":
            if temperature > 0.0:
                strategy = {
                    "type": "top_p",
                    "temperature": temperature,
                    "top_p": top_p,
                }
            else:
                strategy = {"type": "greedy"}

            eval_candidate = {
                "type": "model",
                "model": selected_model,
                "sampling_params": {
                    "strategy": strategy,
                    "max_tokens": max_tokens,
                    "repetition_penalty": repetition_penalty,
                },
            }
        elif candidate_type == "agent":
            system_prompt = st.text_area(
                "System Prompt",
                value="You are a helpful AI assistant.",
                help="Initial instructions given to the AI to set its behavior and context",
            )
            tools_json = st.text_area(
                "Tools Configuration (JSON)",
                value=json.dumps(
                    [
                        {
                            "type": "brave_search",
                            "engine": "brave",
                            "api_key": "ENTER_BRAVE_API_KEY_HERE",
                        }
                    ]
                ),
                help="Enter tool configurations in JSON format. Each tool should have a name, description, and parameters.",
                height=200,
            )
            try:
                tools = json.loads(tools_json)
            except json.JSONDecodeError:
                st.error("Invalid JSON format for tools configuration")
                tools = []
            eval_candidate = {
                "type": "agent",
                "config": {
                    "model": selected_model,
                    "instructions": system_prompt,
                    "tools": tools,
                    "tool_choice": "auto",
                    "tool_prompt_format": "json",
                    "input_shields": [],
                    "output_shields": [],
                    "enable_session_persistence": False,
                },
            }
        st.session_state["eval_candidate"] = eval_candidate

    if st.button("Confirm", key="confirm_2"):
        st.session_state["selected_eval_candidate_2_next"] = True


def run_evaluation_3():
    if not st.session_state.get("selected_eval_candidate_2_next", None):
        return

    st.subheader("3. Run Evaluation")
    # Add info box to explain configurations being used
    st.info(
        """
        Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button.
        """
    )
    selected_benchmark = st.session_state["selected_benchmark"]
    benchmarks = st.session_state["benchmarks"]
    eval_candidate = st.session_state["eval_candidate"]

    dataset_id = benchmarks[selected_benchmark].dataset_id
    rows = llama_stack_api.client.datasets.iterrows(
        dataset_id=dataset_id,
    )
    total_rows = len(rows.data)
    # Add number of examples control
    num_rows = st.number_input(
        "Number of Examples to Evaluate",
        min_value=1,
        max_value=total_rows,
        value=5,
        help="Number of examples from the dataset to evaluate.",
    )

    benchmark_config = {
        "type": "benchmark",
        "eval_candidate": eval_candidate,
        "scoring_params": {},
    }

    with st.expander("View Evaluation Task", expanded=True):
        st.json(benchmarks[selected_benchmark], expanded=True)
    with st.expander("View Evaluation Task Configuration", expanded=True):
        st.json(benchmark_config, expanded=True)

    # Add run button and handle evaluation
    if st.button("Run Evaluation"):
        progress_text = "Running evaluation..."
        progress_bar = st.progress(0, text=progress_text)
        rows = rows.data
        if num_rows < total_rows:
            rows = rows[:num_rows]

        # Create separate containers for progress text and results
        progress_text_container = st.empty()
        results_container = st.empty()
        output_res = {}
        for i, r in enumerate(rows):
            # Update progress
            progress = i / len(rows)
            progress_bar.progress(progress, text=progress_text)
            # Run evaluation for current row
            eval_res = llama_stack_api.client.eval.evaluate_rows(
                benchmark_id=selected_benchmark,
                input_rows=[r],
                scoring_functions=benchmarks[selected_benchmark].scoring_functions,
                benchmark_config=benchmark_config,
            )

            for k in r.keys():
                if k not in output_res:
                    output_res[k] = []
                output_res[k].append(r[k])

            for k in eval_res.generations[0].keys():
                if k not in output_res:
                    output_res[k] = []
                output_res[k].append(eval_res.generations[0][k])

            for scoring_fn in benchmarks[selected_benchmark].scoring_functions:
                if scoring_fn not in output_res:
                    output_res[scoring_fn] = []
                output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])

            progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
            results_container.json(eval_res, expanded=2)

        progress_bar.progress(1.0, text="Evaluation complete!")
        # Display results in dataframe
        if output_res:
            output_df = pd.DataFrame(output_res)
            st.subheader("Evaluation Results")
            st.dataframe(output_df)


def native_evaluation_page():
    st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
    st.title("📊 Evaluations (Generation + Scoring)")

    select_benchmark_1()
    define_eval_candidate_2()
    run_evaluation_3()


native_evaluation_page()
5
src/llama_stack/core/ui/page/playground/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
130
src/llama_stack/core/ui/page/playground/chat.py
Normal file
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import streamlit as st

from llama_stack.core.ui.modules.api import llama_stack_api

# Sidebar configurations
with st.sidebar:
    st.header("Configuration")
    available_models = llama_stack_api.client.models.list()
    available_models = [model.identifier for model in available_models if model.model_type == "llm"]
    selected_model = st.selectbox(
        "Choose a model",
        available_models,
        index=0,
    )

    temperature = st.slider(
        "Temperature",
        min_value=0.0,
        max_value=1.0,
        value=0.0,
        step=0.1,
        help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
    )

    top_p = st.slider(
        "Top P",
        min_value=0.0,
        max_value=1.0,
        value=0.95,
        step=0.1,
    )

    max_tokens = st.slider(
        "Max Tokens",
        min_value=0,
        max_value=4096,
        value=512,
        step=1,
        help="The maximum number of tokens to generate",
    )

    repetition_penalty = st.slider(
        "Repetition Penalty",
        min_value=1.0,
        max_value=2.0,
        value=1.0,
        step=0.1,
        help="Controls the likelihood of generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 strongly discourages the model from repeating words or phrases.",
    )

    stream = st.checkbox("Stream", value=True)
    system_prompt = st.text_area(
        "System Prompt",
        value="You are a helpful AI assistant.",
        help="Initial instructions given to the AI to set its behavior and context",
    )

    # Add clear chat button to sidebar
    if st.button("Clear Chat", use_container_width=True):
        st.session_state.messages = []
        st.rerun()


# Main chat interface
st.title("🦙 Chat")


# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Chat input
if prompt := st.chat_input("Example: What is Llama Stack?"):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)

    # Display assistant response
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""

        if temperature > 0.0:
            strategy = {
                "type": "top_p",
                "temperature": temperature,
                "top_p": top_p,
            }
        else:
            strategy = {"type": "greedy"}

        response = llama_stack_api.client.inference.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            model_id=selected_model,
            stream=stream,
            sampling_params={
                "strategy": strategy,
                "max_tokens": max_tokens,
                "repetition_penalty": repetition_penalty,
            },
        )

        if stream:
            for chunk in response:
                if chunk.event.event_type == "progress":
                    full_response += chunk.event.delta.text
                    message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)
        else:
            full_response = response.completion_message.content
            message_placeholder.markdown(full_response)

        st.session_state.messages.append({"role": "assistant", "content": full_response})
352
src/llama_stack/core/ui/page/playground/tools.py
Normal file
@@ -0,0 +1,352 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import enum
import json
import uuid

import streamlit as st
from llama_stack_client import Agent
from llama_stack_client.lib.agents.react.agent import ReActAgent
from llama_stack_client.lib.agents.react.tool_parser import ReActOutput

from llama_stack.core.ui.modules.api import llama_stack_api


class AgentType(enum.Enum):
    REGULAR = "Regular"
    REACT = "ReAct"


def tool_chat_page():
    st.title("🛠 Tools")

    client = llama_stack_api.client
    models = client.models.list()
    model_list = [model.identifier for model in models if model.api_model_type == "llm"]

    tool_groups = client.toolgroups.list()
    tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
    mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
    builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
    selected_vector_stores = []

    def reset_agent():
        st.session_state.clear()
        st.cache_resource.clear()

    with st.sidebar:
        st.title("Configuration")
        st.subheader("Model")
        model = st.selectbox(label="Model", options=model_list, on_change=reset_agent, label_visibility="collapsed")

        st.subheader("Available ToolGroups")

        toolgroup_selection = st.pills(
            label="Built-in tools",
            options=builtin_tools_list,
            selection_mode="multi",
            on_change=reset_agent,
            format_func=lambda tool: "".join(tool.split("::")[1:]),
            help="List of built-in tools from your llama stack server.",
        )

        if "builtin::rag" in toolgroup_selection:
            vector_stores = llama_stack_api.client.vector_stores.list() or []
            if not vector_stores:
                st.info("No vector databases available for selection.")
            vector_stores = [vector_store.identifier for vector_store in vector_stores]
            selected_vector_stores = st.multiselect(
                label="Select Document Collections to use in RAG queries",
                options=vector_stores,
                on_change=reset_agent,
            )

        mcp_selection = st.pills(
            label="MCP Servers",
            options=mcp_tools_list,
            selection_mode="multi",
            on_change=reset_agent,
            format_func=lambda tool: "".join(tool.split("::")[1:]),
            help="List of MCP servers registered to your llama stack server.",
        )

        toolgroup_selection.extend(mcp_selection)

        grouped_tools = {}
        total_tools = 0

        for toolgroup_id in toolgroup_selection:
            tools = client.tools.list(toolgroup_id=toolgroup_id)
            grouped_tools[toolgroup_id] = [tool.name for tool in tools]
            total_tools += len(tools)

        st.markdown(f"Active Tools: 🛠 {total_tools}")

        for group_id, tools in grouped_tools.items():
            with st.expander(f"🔧 Tools from `{group_id}`"):
                for idx, tool in enumerate(tools, start=1):
                    st.markdown(f"{idx}. `{tool.split(':')[-1]}`")

        st.subheader("Agent Configurations")
        st.subheader("Agent Type")
        agent_type = st.radio(
            label="Select Agent Type",
            options=["Regular", "ReAct"],
            on_change=reset_agent,
        )

        if agent_type == "ReAct":
            agent_type = AgentType.REACT
        else:
            agent_type = AgentType.REGULAR

        max_tokens = st.slider(
            "Max Tokens",
            min_value=0,
            max_value=4096,
            value=512,
            step=64,
            help="The maximum number of tokens to generate",
            on_change=reset_agent,
        )

    for i, tool_name in enumerate(toolgroup_selection):
        if tool_name == "builtin::rag":
            tool_dict = dict(
                name="builtin::rag",
                args={
                    "vector_store_ids": list(selected_vector_stores),
                },
            )
            toolgroup_selection[i] = tool_dict

    @st.cache_resource
    def create_agent():
        if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
            return ReActAgent(
                client=client,
                model=model,
                tools=toolgroup_selection,
                response_format={
                    "type": "json_schema",
                    "json_schema": ReActOutput.model_json_schema(),
                },
                sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
            )
        else:
            return Agent(
                client,
                model=model,
                instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
                tools=toolgroup_selection,
                sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
            )

    st.session_state.agent_type = agent_type

    agent = create_agent()

    if "agent_session_id" not in st.session_state:
        st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")

    session_id = st.session_state["agent_session_id"]

    if "messages" not in st.session_state:
        st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]

    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    if prompt := st.chat_input(placeholder=""):
        with st.chat_message("user"):
            st.markdown(prompt)

        st.session_state.messages.append({"role": "user", "content": prompt})

        turn_response = agent.create_turn(
            session_id=session_id,
            messages=[{"role": "user", "content": prompt}],
            stream=True,
        )

        def response_generator(turn_response):
            if st.session_state.get("agent_type") == AgentType.REACT:
                return _handle_react_response(turn_response)
            else:
                return _handle_regular_response(turn_response)

        def _handle_react_response(turn_response):
            current_step_content = ""
            final_answer = None
            tool_results = []

            for response in turn_response:
                if not hasattr(response.event, "payload"):
                    yield (
                        "\n\n🚨 :red[_Llama Stack server Error:_]\n"
                        "The response received is missing an expected `payload` attribute.\n"
                        "This could indicate a malformed response or an internal issue within the server.\n\n"
                        f"Error details: {response}"
                    )
                    return

                payload = response.event.payload

                if payload.event_type == "step_progress" and hasattr(payload.delta, "text"):
                    current_step_content += payload.delta.text
                    continue

                if payload.event_type == "step_complete":
                    step_details = payload.step_details

                    if step_details.step_type == "inference":
                        # Capture the generator's return value so a final answer
                        # found in this step is not silently discarded.
                        final_answer = yield from _process_inference_step(current_step_content, tool_results, final_answer)
                        current_step_content = ""
                    elif step_details.step_type == "tool_execution":
                        tool_results = _process_tool_execution(step_details, tool_results)
                        current_step_content = ""
                    else:
                        current_step_content = ""

            if not final_answer and tool_results:
                yield from _format_tool_results_summary(tool_results)

        def _process_inference_step(current_step_content, tool_results, final_answer):
            try:
                react_output_data = json.loads(current_step_content)
                thought = react_output_data.get("thought")
                action = react_output_data.get("action")
                answer = react_output_data.get("answer")

                if answer and answer != "null" and answer is not None:
                    final_answer = answer

                if thought:
                    with st.expander("🤔 Thinking...", expanded=False):
                        st.markdown(f":grey[__{thought}__]")

                if action and isinstance(action, dict):
                    tool_name = action.get("tool_name")
                    tool_params = action.get("tool_params")
                    with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False):
                        st.json(tool_params)

                if answer and answer != "null" and answer is not None:
                    yield f"\n\n✅ **Final Answer:**\n{answer}"

            except json.JSONDecodeError:
                yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```"
            except Exception as e:
                yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```"

            return final_answer

        def _process_tool_execution(step_details, tool_results):
            try:
                if hasattr(step_details, "tool_responses") and step_details.tool_responses:
                    for tool_response in step_details.tool_responses:
                        tool_name = tool_response.tool_name
                        content = tool_response.content
                        tool_results.append((tool_name, content))
                        with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False):
                            try:
                                parsed_content = json.loads(content)
                                st.json(parsed_content)
                            except json.JSONDecodeError:
                                st.code(content, language=None)
                else:
                    with st.expander("⚙️ Observation", expanded=False):
                        st.markdown(":grey[_Tool execution step completed, but no response data found._]")
            except Exception as e:
                with st.expander("⚙️ Error in Tool Execution", expanded=False):
                    st.markdown(f":red[_Error processing tool execution: {str(e)}_]")

            return tool_results

        def _format_tool_results_summary(tool_results):
            yield "\n\n**Here's what I found:**\n"
            for tool_name, content in tool_results:
                try:
                    parsed_content = json.loads(content)

                    if tool_name == "web_search" and "top_k" in parsed_content:
                        yield from _format_web_search_results(parsed_content)
                    elif "results" in parsed_content and isinstance(parsed_content["results"], list):
                        yield from _format_results_list(parsed_content["results"])
                    elif isinstance(parsed_content, dict) and len(parsed_content) > 0:
                        yield from _format_dict_results(parsed_content)
                    elif isinstance(parsed_content, list) and len(parsed_content) > 0:
                        yield from _format_list_results(parsed_content)
                except json.JSONDecodeError:
                    yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n"
                except (TypeError, AttributeError, KeyError, IndexError) as e:
                    print(f"Error processing {tool_name} result: {type(e).__name__}: {e}")

        def _format_web_search_results(parsed_content):
            for i, result in enumerate(parsed_content["top_k"], 1):
                if i <= 3:
                    title = result.get("title", "Untitled")
                    url = result.get("url", "")
                    content_text = result.get("content", "").strip()
                    yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n"

        def _format_results_list(results):
            for i, result in enumerate(results, 1):
                if i <= 3:
                    if isinstance(result, dict):
                        name = result.get("name", result.get("title", "Result " + str(i)))
                        description = result.get("description", result.get("content", result.get("summary", "")))
                        yield f"\n- **{name}**\n {description}\n"
                    else:
                        yield f"\n- {result}\n"

        def _format_dict_results(parsed_content):
            yield "\n```\n"
            for key, value in list(parsed_content.items())[:5]:
                if isinstance(value, str) and len(value) < 100:
                    yield f"{key}: {value}\n"
                else:
                    yield f"{key}: [Complex data]\n"
            yield "```\n"

        def _format_list_results(parsed_content):
            yield "\n"
            for _, item in enumerate(parsed_content[:3], 1):
                if isinstance(item, str):
                    yield f"- {item}\n"
                elif isinstance(item, dict) and "text" in item:
                    yield f"- {item['text']}\n"
                elif isinstance(item, dict) and len(item) > 0:
                    first_value = next(iter(item.values()))
                    if isinstance(first_value, str) and len(first_value) < 100:
                        yield f"- {first_value}\n"

        def _handle_regular_response(turn_response):
            for response in turn_response:
                if hasattr(response.event, "payload"):
                    print(response.event.payload)
                    if response.event.payload.event_type == "step_progress":
                        if hasattr(response.event.payload.delta, "text"):
                            yield response.event.payload.delta.text
                    if response.event.payload.event_type == "step_complete":
                        if response.event.payload.step_details.step_type == "tool_execution":
                            if response.event.payload.step_details.tool_calls:
                                tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name)
                                yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n'
                            else:
                                yield "No tool_calls present in step_details"
                else:
                    yield f"Error occurred in the Llama Stack Cluster: {response}"

        with st.chat_message("assistant"):
            response_content = st.write_stream(response_generator(turn_response))

        st.session_state.messages.append({"role": "assistant", "content": response_content})


tool_chat_page()
5
src/llama_stack/core/ui/requirements.txt
Normal file
@@ -0,0 +1,5 @@
llama-stack>=0.2.1
llama-stack-client>=0.2.1
pandas
streamlit
streamlit-option-menu
5
src/llama_stack/core/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
30
src/llama_stack/core/utils/config.py
Normal file
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any


def redact_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]:
    """Redact sensitive information from config before printing."""
    sensitive_patterns = ["api_key", "api_token", "password", "secret"]

    def _redact_value(v: Any) -> Any:
        if isinstance(v, dict):
            return _redact_dict(v)
        elif isinstance(v, list):
            return [_redact_value(i) for i in v]
        return v

    def _redact_dict(d: dict[str, Any]) -> dict[str, Any]:
        result = {}
        for k, v in d.items():
            if any(pattern in k.lower() for pattern in sensitive_patterns):
                result[k] = "********"
            else:
                result[k] = _redact_value(v)
        return result

    return _redact_dict(data)
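Note (not part of this commit): a quick sketch of the redaction behavior on a hypothetical config dict; any key containing one of the sensitive patterns is masked, at any nesting depth:

```python
# Illustrative only: nested keys matching sensitive patterns are masked.
from llama_stack.core.utils.config import redact_sensitive_fields

cfg = {"provider": {"together_api_key": "tok-123", "timeout": 30}}  # hypothetical config
print(redact_sensitive_fields(cfg))
# {'provider': {'together_api_key': '********', 'timeout': 30}}
```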
18
src/llama_stack/core/utils/config_dirs.py
Normal file
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from pathlib import Path

LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))

DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"

DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"

RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"

EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"
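Note (not part of this commit): all of these paths derive from `LLAMA_STACK_CONFIG_DIR`, so relocating state is a single override; a sketch:

```bash
# Illustrative only: relocate all Llama Stack state under a custom root.
export LLAMA_STACK_CONFIG_DIR="$HOME/.config/llama"
# distributions/, checkpoints/, runtime/, and providers.d/ now resolve under it
```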
125
src/llama_stack/core/utils/config_resolution.py
Normal file
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import StrEnum
from pathlib import Path

from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="core")


DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"


class Mode(StrEnum):
    RUN = "run"
    BUILD = "build"


def resolve_config_or_distro(
    config_or_distro: str,
    mode: Mode = Mode.RUN,
) -> Path:
    """
    Resolve a config/distro argument to a concrete config file path.

    Args:
        config_or_distro: User input (file path, distribution name, or built distribution)
        mode: Mode to resolve for ("run" or "build")

    Returns:
        Path to the resolved config file

    Raises:
        ValueError: If resolution fails
    """

    # Strategy 1: Try as file path first
    config_path = Path(config_or_distro)
    if config_path.exists() and config_path.is_file():
        logger.debug(f"Using file path: {config_path}")
        return config_path.resolve()

    # Strategy 2: Try as distribution name (if no .yaml extension)
    if not config_or_distro.endswith(".yaml"):
        distro_config = _get_distro_config_path(config_or_distro, mode)
        if distro_config.exists():
            logger.debug(f"Using distribution: {distro_config}")
            return distro_config

    # Strategy 3: Try as built distribution name
    distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
    if distrib_config.exists():
        logger.debug(f"Using built distribution: {distrib_config}")
        return distrib_config

    distrib_config = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
    if distrib_config.exists():
        logger.debug(f"Using built distribution: {distrib_config}")
        return distrib_config

    # Strategy 4: Failed - provide helpful error
    raise ValueError(_format_resolution_error(config_or_distro, mode))


def _get_distro_config_path(distro_name: str, mode: Mode) -> Path:
    """Get the config file path for a distro."""
    return DISTRO_DIR / distro_name / f"{mode}.yaml"


def _format_resolution_error(config_or_distro: str, mode: Mode) -> str:
    """Format a helpful error message for resolution failures."""
    from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR

    distro_path = _get_distro_config_path(config_or_distro, mode)
    distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
    distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"

    available_distros = _get_available_distros()
    distros_str = ", ".join(available_distros) if available_distros else "none found"

    return f"""Could not resolve config or distribution '{config_or_distro}'.

Tried the following locations:
  1. As file path: {Path(config_or_distro).resolve()}
  2. As distribution: {distro_path}
  3. As built distribution: ({distrib_path}, {distrib_path2})

Available distributions: {distros_str}

Did you mean one of these distributions?
{_format_distro_suggestions(available_distros, config_or_distro)}
"""


def _get_available_distros() -> list[str]:
    """Get list of available distro names."""
    # Guard each directory separately so a missing one does not raise.
    distros: list[str] = []
    if DISTRO_DIR.exists():
        distros += [d.name for d in DISTRO_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
    if DISTRIBS_BASE_DIR.exists():
        distros += [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
    return list(set(distros))


def _format_distro_suggestions(distros: list[str], user_input: str) -> str:
    """Format distro suggestions for error messages, showing closest matches first."""
    if not distros:
        return "  (no distros found)"

    import difflib

    # Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive)
    close_matches = difflib.get_close_matches(user_input, distros, n=3, cutoff=0.3)
    display_distros = close_matches if close_matches else distros[:3]

    suggestions = [f"  - {d}" for d in display_distros]
    return "\n".join(suggestions)
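Note (not part of this commit): a sketch of the resolution order from the caller's side; "together" is an example distribution name, not something guaranteed to exist on your machine:

```python
# Illustrative only: resolve a name through file path -> bundled distro -> built distro.
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro

try:
    cfg_path = resolve_config_or_distro("together", mode=Mode.RUN)  # example name
    print(f"Resolved to: {cfg_path}")
except ValueError as e:
    print(e)  # helpful error listing available distributions
```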
40
src/llama_stack/core/utils/context.py
Normal file
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncGenerator
from contextvars import ContextVar


def preserve_contexts_async_generator[T](
    gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
) -> AsyncGenerator[T, None]:
    """
    Wraps an async generator to preserve context variables across iterations.
    This is needed because we start a new asyncio event loop for each streaming request,
    and we need to preserve the context across the event loop boundary.
    """
    # Capture initial context values
    initial_context_values = {context_var.name: context_var.get() for context_var in context_vars}

    async def wrapper() -> AsyncGenerator[T, None]:
        while True:
            try:
                # Restore context values before any await
                for context_var in context_vars:
                    context_var.set(initial_context_values[context_var.name])

                item = await gen.__anext__()

                # Update our tracked values with any changes made during this iteration
                for context_var in context_vars:
                    initial_context_values[context_var.name] = context_var.get()

                yield item

            except StopAsyncIteration:
                break

    return wrapper()
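Note (not part of this commit): a minimal sketch of the wrapper in action; the context variable and generator are invented for illustration:

```python
# Illustrative only: the wrapped generator sees the captured context value
# even when iterated later.
import asyncio
from contextvars import ContextVar

from llama_stack.core.utils.context import preserve_contexts_async_generator

request_id: ContextVar[str] = ContextVar("request_id")  # hypothetical context var


async def stream_chunks():
    for i in range(2):
        yield f"{request_id.get()}-chunk{i}"


async def main():
    request_id.set("req-42")
    wrapped = preserve_contexts_async_generator(stream_chunks(), [request_id])
    async for chunk in wrapped:
        print(chunk)  # req-42-chunk0, req-42-chunk1


asyncio.run(main())
```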
13
src/llama_stack/core/utils/dynamic.py
Normal file
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import importlib


def instantiate_class_type(fully_qualified_name):
    module_name, class_name = fully_qualified_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
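Note (not part of this commit): the helper resolves any importable dotted path to an attribute; for example:

```python
# Illustrative only: resolve a fully qualified name to the class object.
from llama_stack.core.utils.dynamic import instantiate_class_type

cls = instantiate_class_type("pathlib.Path")  # any importable dotted path works
print(cls("/tmp"))  # /tmp
```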
96
src/llama_stack/core/utils/exec.py
Normal file

@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import importlib.resources
import os
import signal
import subprocess
import sys

from termcolor import cprint

from llama_stack.log import get_logger

log = get_logger(name=__name__, category="core")


def formulate_run_args(image_type: str, image_name: str) -> list:
    # Only venv is supported now
    current_venv = os.environ.get("VIRTUAL_ENV")
    env_name = image_name or current_venv
    if not env_name:
        cprint(
            "No current virtual environment detected, please specify a virtual environment name with --image-name",
            color="red",
            file=sys.stderr,
        )
        return []

    cprint(f"Using virtual environment: {env_name}", file=sys.stderr)

    script = importlib.resources.files("llama_stack") / "core/start_stack.sh"
    run_args = [
        script,
        image_type,
        env_name,
    ]

    return run_args


def in_notebook():
    try:
        from IPython import get_ipython

        ipython = get_ipython()
        if ipython is None or "IPKernelApp" not in ipython.config:  # pragma: no cover
            return False
    except ImportError:
        return False
    except AttributeError:
        return False
    return True


def run_command(command: list[str]) -> int:
    """
    Run a command with interrupt handling; output streams directly to the
    parent's stdout/stderr rather than being captured.

    Args:
        command (list): The command to run.

    Returns:
        int: The return code of the command.
    """
    original_sigint = signal.getsignal(signal.SIGINT)
    ctrl_c_pressed = False

    def sigint_handler(signum, frame):
        # Record and log the interrupt; the child process in the same process
        # group receives the SIGINT itself, so nothing else is needed here.
        nonlocal ctrl_c_pressed
        ctrl_c_pressed = True
        log.info("\nCtrl-C detected. Aborting...")

    try:
        # Set up the signal handler
        signal.signal(signal.SIGINT, sigint_handler)

        # Run the command with stdout/stderr inherited from this process
        result = subprocess.run(
            command,
            text=True,
            check=False,
        )
        return result.returncode
    except subprocess.SubprocessError as e:
        log.error(f"Subprocess error: {e}")
        return 1
    except Exception as e:
        log.exception(f"Unexpected error: {e}")
        return 1
    finally:
        # Restore the original signal handler
        signal.signal(signal.SIGINT, original_sigint)
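
A minimal usage sketch for `run_command`, assuming a POSIX environment where `echo` is available:

    from llama_stack.core.utils.exec import run_command

    # output goes straight to this terminal; Ctrl-C during the call is logged
    # by the handler while the child receives the signal itself
    rc = run_command(["echo", "hello"])
    assert rc == 0
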
12
src/llama_stack/core/utils/image_types.py
Normal file

@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import enum


class LlamaStackImageType(enum.Enum):
    CONTAINER = "container"
    VENV = "venv"
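
A one-line usage sketch for the enum:

    from llama_stack.core.utils.image_types import LlamaStackImageType

    # round-trip between a member and its raw string value
    assert LlamaStackImageType.VENV.value == "venv"
    assert LlamaStackImageType("container") is LlamaStackImageType.CONTAINER
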
13
src/llama_stack/core/utils/model_utils.py
Normal file

@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

from .config_dirs import DEFAULT_CHECKPOINT_DIR


def model_local_dir(descriptor: str) -> str:
    return str(Path(DEFAULT_CHECKPOINT_DIR) / (descriptor.replace(":", "-")))
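
For illustration, the helper only normalizes the descriptor and joins it onto the checkpoint root; the descriptor below is a made-up example, and the exact root depends on how DEFAULT_CHECKPOINT_DIR is configured:

    from llama_stack.core.utils.model_utils import model_local_dir

    # ":" in the descriptor becomes "-", so this yields
    # <DEFAULT_CHECKPOINT_DIR>/Llama3.2-3B-Instruct-int4
    path = model_local_dir("Llama3.2-3B-Instruct:int4")
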
283
src/llama_stack/core/utils/prompt_for_config.py
Normal file

@@ -0,0 +1,283 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect
import json
from enum import Enum
from typing import Annotated, Any, Literal, Union, get_args, get_origin

from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType

from llama_stack.log import get_logger

log = get_logger(name=__name__, category="core")


def is_list_of_primitives(field_type):
    """Check if a field type is a List of primitive types."""
    origin = get_origin(field_type)
    if origin is list:
        args = get_args(field_type)
        if len(args) == 1 and args[0] in (int, float, str, bool):
            return True
    return False


def is_basemodel_without_fields(typ):
    return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0


def can_recurse(typ):
    return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0


def get_literal_values(field):
    """Extract literal values from a field if it's a Literal type."""
    if get_origin(field.annotation) is Literal:
        return get_args(field.annotation)
    return None


def is_optional(field_type):
    """Check if a field type is Optional."""
    return get_origin(field_type) is Union and type(None) in get_args(field_type)


def get_non_none_type(field_type):
    """Get the non-None type from an Optional type."""
    return next(arg for arg in get_args(field_type) if arg is not type(None))


def manually_validate_field(model: type[BaseModel], field_name: str, value: Any):
    validators = model.__pydantic_decorators__.field_validators
    for _name, validator in validators.items():
        if field_name in validator.info.fields:
            validator.func(value)

    return value


def is_discriminated_union(typ) -> bool:
    if isinstance(typ, FieldInfo):
        return bool(typ.discriminator)
    else:
        if get_origin(typ) is not Annotated:
            return False
        args = get_args(typ)
        return len(args) >= 2 and bool(args[1].discriminator)


def prompt_for_discriminated_union(
    field_name,
    typ,
    existing_value,
):
    if isinstance(typ, FieldInfo):
        inner_type = typ.annotation
        discriminator = typ.discriminator
        default_value = typ.default
    else:
        args = get_args(typ)
        inner_type = args[0]
        discriminator = args[1].discriminator
        default_value = args[1].default

    union_types = get_args(inner_type)
    # Find the discriminator field in each union type
    type_map = {}
    for t in union_types:
        disc_field = t.__fields__[discriminator]
        literal_values = get_literal_values(disc_field)
        if literal_values:
            for value in literal_values:
                type_map[value] = t

    while True:
        prompt = f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())})"
        if default_value is not None:
            prompt += f" (default: {default_value})"

        discriminator_value = input(f"{prompt}: ")
        if discriminator_value == "" and default_value is not None:
            discriminator_value = default_value

        if discriminator_value in type_map:
            chosen_type = type_map[discriminator_value]
            log.info(f"\nConfiguring {chosen_type.__name__}:")

            if existing_value and (getattr(existing_value, discriminator) != discriminator_value):
                existing_value = None

            sub_config = prompt_for_config(chosen_type, existing_value)
            # Set the discriminator field in the sub-config
            setattr(sub_config, discriminator, discriminator_value)
            return sub_config
        else:
            log.error(f"Invalid {discriminator}. Please try again.")


# This is somewhat elaborate, but does not purport to be comprehensive in any way.
# We should add handling for the most common cases to tide us over.
#
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
# unit tests for coverage.
def prompt_for_config(config_type: type[BaseModel], existing_config: BaseModel | None = None) -> BaseModel:
    """
    Recursively prompt the user for configuration values based on a Pydantic BaseModel.

    Args:
        config_type: A Pydantic BaseModel class representing the configuration structure.

    Returns:
        An instance of the config_type with user-provided values.
    """
    config_data = {}

    for field_name, field in config_type.__fields__.items():
        field_type = field.annotation
        existing_value = getattr(existing_config, field_name) if existing_config else None
        if existing_value:
            default_value = existing_value
        else:
            default_value = field.default if not isinstance(field.default, PydanticUndefinedType) else None
        is_required = field.is_required()

        # Skip fields with Literal type
        if get_origin(field_type) is Literal:
            continue

        # BaseModels with no fields can be instantiated directly; nothing to prompt for
        if is_basemodel_without_fields(field_type):
            config_data[field_name] = field_type()
            continue

        if inspect.isclass(field_type) and issubclass(field_type, Enum):
            prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
            while True:
                # this branch does not handle existing and default values yet
                user_input = input(prompt + " ")
                try:
                    value = field_type[user_input]
                    validated_value = manually_validate_field(config_type, field_name, value)
                    config_data[field_name] = validated_value
                    break
                except KeyError:
                    log.error(f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}")
            continue

        if is_discriminated_union(field):
            config_data[field_name] = prompt_for_discriminated_union(field_name, field, existing_value)
            continue

        if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
            prompt = f"Do you want to configure {field_name}? (y/n): "
            if input(prompt).lower() == "n":
                config_data[field_name] = None
                continue
            nested_type = get_non_none_type(field_type)
            log.info(f"Entering sub-configuration for {field_name}:")
            config_data[field_name] = prompt_for_config(nested_type, existing_value)
        elif is_optional(field_type) and is_discriminated_union(get_non_none_type(field_type)):
            prompt = f"Do you want to configure {field_name}? (y/n): "
            if input(prompt).lower() == "n":
                config_data[field_name] = None
                continue
            nested_type = get_non_none_type(field_type)
            config_data[field_name] = prompt_for_discriminated_union(
                field_name,
                nested_type,
                existing_value,
            )
        elif can_recurse(field_type):
            log.info(f"\nEntering sub-configuration for {field_name}:")
            config_data[field_name] = prompt_for_config(
                field_type,
                existing_value,
            )
        else:
            prompt = f"Enter value for {field_name}"
            if existing_value is not None:
                prompt += f" (existing: {existing_value})"
            elif default_value is not None:
                prompt += f" (default: {default_value})"
            if is_optional(field_type):
                prompt += " (optional)"
            elif is_required:
                prompt += " (required)"
            prompt += ": "

            while True:
                user_input = input(prompt)
                if user_input == "":
                    if default_value is not None:
                        config_data[field_name] = default_value
                        break
                    elif is_optional(field_type) or not is_required:
                        config_data[field_name] = None
                        break
                    else:
                        log.error("This field is required. Please provide a value.")
                        continue
                else:
                    try:
                        # Handle Optional types
                        if is_optional(field_type):
                            if user_input.lower() == "none":
                                value = None
                            else:
                                field_type = get_non_none_type(field_type)
                                value = user_input

                        # Handle List of primitives
                        elif is_list_of_primitives(field_type):
                            try:
                                value = json.loads(user_input)
                                if not isinstance(value, list):
                                    raise ValueError("Input must be a JSON-encoded list")
                                element_type = get_args(field_type)[0]
                                value = [element_type(item) for item in value]

                            except json.JSONDecodeError:
                                log.error('Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]')
                                continue
                            except ValueError as e:
                                log.error(f"{str(e)}")
                                continue

                        elif get_origin(field_type) is dict:
                            try:
                                value = json.loads(user_input)
                                if not isinstance(value, dict):
                                    raise ValueError("Input must be a JSON-encoded dictionary")

                            except json.JSONDecodeError:
                                log.error("Invalid JSON. Please enter a valid JSON-encoded dict.")
                                continue

                        # Convert the input to the correct type
                        elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
                            # For nested BaseModels, we assume a dictionary-like string input
                            import ast

                            value = field_type(**ast.literal_eval(user_input))
                        else:
                            value = field_type(user_input)

                    except ValueError:
                        log.error(f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}")
                        continue

                try:
                    # Validate the field using our manual validation function
                    validated_value = manually_validate_field(config_type, field_name, value)
                    config_data[field_name] = validated_value
                    break
                except ValueError as e:
                    log.error(f"Validation error: {str(e)}")

    return config_type(**config_data)
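
To make the recursion concrete, a small sketch under assumed names: the models below are hypothetical, and the call is interactive, reading answers from stdin (type the variant name, e.g. sqlite, when prompted for `type`):

    from typing import Annotated, Literal, Union

    from pydantic import BaseModel, Field

    from llama_stack.core.utils.prompt_for_config import prompt_for_config


    class SqliteStore(BaseModel):  # hypothetical variant
        type: Literal["sqlite"] = "sqlite"
        db_path: str = "/tmp/store.db"


    class PostgresStore(BaseModel):  # hypothetical variant
        type: Literal["postgres"] = "postgres"
        host: str = "localhost"


    class StackConfig(BaseModel):  # hypothetical top-level config
        store: Annotated[Union[SqliteStore, PostgresStore], Field(discriminator="type")]


    # prompts for `type` first, then recurses into the chosen variant's fields
    config = prompt_for_config(StackConfig)
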
18
src/llama_stack/core/utils/serialize.py
Normal file

@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from datetime import datetime
from enum import Enum


class EnumEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, Enum):
            return obj.value
        elif isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)
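
A short usage sketch; `Color` is a hypothetical enum used only for the demo:

    import json
    from datetime import datetime
    from enum import Enum

    from llama_stack.core.utils.serialize import EnumEncoder


    class Color(Enum):  # hypothetical
        RED = "red"


    payload = {"color": Color.RED, "at": datetime(2025, 1, 1)}
    print(json.dumps(payload, cls=EnumEncoder))
    # -> {"color": "red", "at": "2025-01-01T00:00:00"}
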