mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-24 01:28:06 +00:00
Merge branch 'main' into allow-dynamic-models-ollama
This commit is contained in:
commit
56476fa462
247 changed files with 9176 additions and 7177 deletions
|
|
@ -4,15 +4,83 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from enum import Enum, EnumMeta
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class DynamicApiMeta(EnumMeta):
|
||||
def __new__(cls, name, bases, namespace):
|
||||
# Store the original enum values
|
||||
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
|
||||
|
||||
# Create the enum class
|
||||
cls = super().__new__(cls, name, bases, namespace)
|
||||
|
||||
# Store the original values for reference
|
||||
cls._original_values = original_values
|
||||
# Initialize _dynamic_values
|
||||
cls._dynamic_values = {}
|
||||
|
||||
return cls
|
||||
|
||||
def __call__(cls, value):
|
||||
try:
|
||||
return super().__call__(value)
|
||||
except ValueError as e:
|
||||
# If this value was already dynamically added, return it
|
||||
if value in cls._dynamic_values:
|
||||
return cls._dynamic_values[value]
|
||||
|
||||
# If the value doesn't exist, create a new enum member
|
||||
# Create a new member name from the value
|
||||
member_name = value.lower().replace("-", "_")
|
||||
|
||||
# If this member name already exists in the enum, return the existing member
|
||||
if member_name in cls._member_map_:
|
||||
return cls._member_map_[member_name]
|
||||
|
||||
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
|
||||
# register new APIs explicitly
|
||||
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
|
||||
|
||||
def __iter__(cls):
|
||||
# Allow iteration over both static and dynamic members
|
||||
yield from super().__iter__()
|
||||
if hasattr(cls, "_dynamic_values"):
|
||||
yield from cls._dynamic_values.values()
|
||||
|
||||
def add(cls, value):
|
||||
"""
|
||||
Add a new API to the enum.
|
||||
Used to register external APIs.
|
||||
"""
|
||||
member_name = value.lower().replace("-", "_")
|
||||
|
||||
# If this member name already exists in the enum, return it
|
||||
if member_name in cls._member_map_:
|
||||
return cls._member_map_[member_name]
|
||||
|
||||
# Create a new enum member
|
||||
member = object.__new__(cls)
|
||||
member._name_ = member_name
|
||||
member._value_ = value
|
||||
|
||||
# Add it to the enum class
|
||||
cls._member_map_[member_name] = member
|
||||
cls._member_names_.append(member_name)
|
||||
cls._member_type_ = str
|
||||
|
||||
# Store it in our dynamic values
|
||||
cls._dynamic_values[value] = member
|
||||
|
||||
return member
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Api(Enum):
|
||||
class Api(Enum, metaclass=DynamicApiMeta):
|
||||
providers = "providers"
|
||||
inference = "inference"
|
||||
safety = "safety"
|
||||
|
|
@ -54,3 +122,12 @@ class Error(BaseModel):
|
|||
title: str
|
||||
detail: str
|
||||
instance: str | None = None
|
||||
|
||||
|
||||
class ExternalApiSpec(BaseModel):
|
||||
"""Specification for an external API implementation."""
|
||||
|
||||
module: str = Field(..., description="Python module containing the API implementation")
|
||||
name: str = Field(..., description="Name of the API")
|
||||
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
|
||||
protocol: str = Field(..., description="Name of the protocol class for the API")
|
||||
|
|
|
|||
|
|
@ -464,6 +464,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion
|
|||
|
||||
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
|
||||
|
||||
OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIUserMessageParam(BaseModel):
|
||||
|
|
@ -489,7 +491,7 @@ class OpenAISystemMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["system"] = "system"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent
|
||||
name: str | None = None
|
||||
|
||||
|
||||
|
|
@ -518,7 +520,7 @@ class OpenAIAssistantMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: OpenAIChatCompletionMessageContent | None = None
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent | None = None
|
||||
name: str | None = None
|
||||
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
|
||||
|
||||
|
|
@ -534,7 +536,7 @@ class OpenAIToolMessageParam(BaseModel):
|
|||
|
||||
role: Literal["tool"] = "tool"
|
||||
tool_call_id: str
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -547,7 +549,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["developer"] = "developer"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent
|
||||
name: str | None = None
|
||||
|
||||
|
||||
|
|
@ -819,12 +821,6 @@ class OpenAIEmbeddingsResponse(BaseModel):
|
|||
class ModelStore(Protocol):
|
||||
async def get_model(self, identifier: str) -> Model: ...
|
||||
|
||||
async def update_registered_llm_models(
|
||||
self,
|
||||
provider_id: str,
|
||||
models: list[Model],
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class TextTruncation(Enum):
|
||||
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
|||
# Add this constant near the top of the file, after the imports
|
||||
DEFAULT_TTL_DAYS = 7
|
||||
|
||||
REQUIRED_SCOPE = "telemetry.read"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStatus(Enum):
|
||||
|
|
@ -259,7 +261,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces", method="POST")
|
||||
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
|
|
@ -277,7 +279,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE)
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
"""Get a trace by its ID.
|
||||
|
||||
|
|
@ -286,7 +288,9 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
|
||||
@webmethod(
|
||||
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE
|
||||
)
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
||||
"""Get a span by its ID.
|
||||
|
||||
|
|
@ -296,7 +300,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
|
|
@ -312,7 +316,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans", method="POST")
|
||||
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
async def query_spans(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition],
|
||||
|
|
@ -345,7 +349,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
|
||||
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
async def query_metrics(
|
||||
self,
|
||||
metric_name: str,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Annotated, Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
|
@ -88,7 +88,7 @@ class RAGQueryGenerator(Enum):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class RAGSearchMode(Enum):
|
||||
class RAGSearchMode(StrEnum):
|
||||
"""
|
||||
Search modes for RAG query retrieval:
|
||||
- VECTOR: Uses vector similarity search for semantic matching
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ class VectorDBInput(BaseModel):
|
|||
vector_db_id: str
|
||||
embedding_model: str
|
||||
embedding_dimension: int
|
||||
provider_id: str | None = None
|
||||
provider_vector_db_id: str | None = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -338,7 +338,7 @@ class VectorIO(Protocol):
|
|||
@webmethod(route="/openai/v1/vector_stores", method="POST")
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name: str,
|
||||
name: str | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
chunking_strategy: dict[str, Any] | None = None,
|
||||
|
|
|
|||
|
|
@ -31,11 +31,13 @@ from llama_stack.distribution.build import (
|
|||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BuildConfig,
|
||||
BuildProvider,
|
||||
DistributionSpec,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import replace_env_vars
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
|
||||
|
|
@ -93,7 +95,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
elif args.providers:
|
||||
providers_list: dict[str, str | list[str]] = dict()
|
||||
provider_list: dict[str, list[BuildProvider]] = dict()
|
||||
for api_provider in args.providers.split(","):
|
||||
if "=" not in api_provider:
|
||||
cprint(
|
||||
|
|
@ -102,7 +104,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
api, provider = api_provider.split("=")
|
||||
api, provider_type = api_provider.split("=")
|
||||
providers_for_api = get_provider_registry().get(Api(api), None)
|
||||
if providers_for_api is None:
|
||||
cprint(
|
||||
|
|
@ -111,16 +113,12 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
if provider in providers_for_api:
|
||||
if api not in providers_list:
|
||||
providers_list[api] = []
|
||||
# Use type guarding to ensure we have a list
|
||||
provider_value = providers_list[api]
|
||||
if isinstance(provider_value, list):
|
||||
provider_value.append(provider)
|
||||
else:
|
||||
# Convert string to list and append
|
||||
providers_list[api] = [provider_value, provider]
|
||||
if provider_type in providers_for_api:
|
||||
provider = BuildProvider(
|
||||
provider_type=provider_type,
|
||||
module=None,
|
||||
)
|
||||
provider_list.setdefault(api, []).append(provider)
|
||||
else:
|
||||
cprint(
|
||||
f"{provider} is not a valid provider for the {api} API.",
|
||||
|
|
@ -129,7 +127,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
distribution_spec = DistributionSpec(
|
||||
providers=providers_list,
|
||||
providers=provider_list,
|
||||
description=",".join(args.providers),
|
||||
)
|
||||
if not args.image_type:
|
||||
|
|
@ -190,7 +188,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
|
||||
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
|
||||
|
||||
providers: dict[str, str | list[str]] = dict()
|
||||
providers: dict[str, list[BuildProvider]] = dict()
|
||||
for api, providers_for_api in get_provider_registry().items():
|
||||
available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
|
||||
if not available_providers:
|
||||
|
|
@ -205,7 +203,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
),
|
||||
)
|
||||
|
||||
providers[api.value] = api_provider
|
||||
string_providers = api_provider.split(" ")
|
||||
|
||||
for provider in string_providers:
|
||||
providers.setdefault(api.value, []).append(BuildProvider(provider_type=provider))
|
||||
|
||||
description = prompt(
|
||||
"\n > (Optional) Enter a short description for your Llama Stack: ",
|
||||
|
|
@ -236,11 +237,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
|
||||
if args.print_deps_only:
|
||||
print(f"# Dependencies for {args.template or args.config or image_name}")
|
||||
normal_deps, special_deps = get_provider_dependencies(build_config)
|
||||
normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
print(f"uv pip install {' '.join(normal_deps)}")
|
||||
for special_dep in special_deps:
|
||||
print(f"uv pip install {special_dep}")
|
||||
for external_dep in external_provider_dependencies:
|
||||
print(f"uv pip install {external_dep}")
|
||||
return
|
||||
|
||||
try:
|
||||
|
|
@ -276,8 +279,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if config.external_providers_dir and not config.external_providers_dir.exists():
|
||||
config.external_providers_dir.mkdir(exist_ok=True)
|
||||
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
|
||||
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
|
||||
run_args = formulate_run_args(args.image_type, args.image_name)
|
||||
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)])
|
||||
run_command(run_args)
|
||||
|
||||
|
||||
|
|
@ -303,27 +306,25 @@ def _generate_run_config(
|
|||
provider_registry = get_provider_registry(build_config)
|
||||
for api in apis:
|
||||
run_config.providers[api] = []
|
||||
provider_types = build_config.distribution_spec.providers[api]
|
||||
if isinstance(provider_types, str):
|
||||
provider_types = [provider_types]
|
||||
providers = build_config.distribution_spec.providers[api]
|
||||
|
||||
for i, provider_type in enumerate(provider_types):
|
||||
pid = provider_type.split("::")[-1]
|
||||
for provider in providers:
|
||||
pid = provider.provider_type.split("::")[-1]
|
||||
|
||||
p = provider_registry[Api(api)][provider_type]
|
||||
p = provider_registry[Api(api)][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
|
||||
try:
|
||||
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
|
||||
except ModuleNotFoundError:
|
||||
config_type = instantiate_class_type(provider_registry[Api(api)][provider.provider_type].config_class)
|
||||
except (ModuleNotFoundError, ValueError) as exc:
|
||||
# HACK ALERT:
|
||||
# This code executes after building is done, the import cannot work since the
|
||||
# package is either available in the venv or container - not available on the host.
|
||||
# TODO: use a "is_external" flag in ProviderSpec to check if the provider is
|
||||
# external
|
||||
cprint(
|
||||
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
|
||||
f"Failed to import provider {provider.provider_type} for API {api} - assuming it's external, skipping: {exc}",
|
||||
color="yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
|
@ -336,9 +337,10 @@ def _generate_run_config(
|
|||
config = {}
|
||||
|
||||
p_spec = Provider(
|
||||
provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid,
|
||||
provider_type=provider_type,
|
||||
provider_id=pid,
|
||||
provider_type=provider.provider_type,
|
||||
config=config,
|
||||
module=provider.module,
|
||||
)
|
||||
run_config.providers[api].append(p_spec)
|
||||
|
||||
|
|
@ -401,9 +403,32 @@ def _run_stack_build_command_from_build_config(
|
|||
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
||||
|
||||
with open(build_file_path, "w") as f:
|
||||
to_write = json.loads(build_config.model_dump_json())
|
||||
to_write = json.loads(build_config.model_dump_json(exclude_none=True))
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
# We first install the external APIs so that the build process can use them and discover the
|
||||
# providers dependencies
|
||||
if build_config.external_apis_dir:
|
||||
cprint("Installing external APIs", color="yellow", file=sys.stderr)
|
||||
external_apis = load_external_apis(build_config)
|
||||
if external_apis:
|
||||
# install the external APIs
|
||||
packages = []
|
||||
for _, api_spec in external_apis.items():
|
||||
if api_spec.pip_packages:
|
||||
packages.extend(api_spec.pip_packages)
|
||||
cprint(
|
||||
f"Installing {api_spec.name} with pip packages {api_spec.pip_packages}",
|
||||
color="yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return_code = run_command(["uv", "pip", "install", *packages])
|
||||
if return_code != 0:
|
||||
packages_str = ", ".join(packages)
|
||||
raise RuntimeError(
|
||||
f"Failed to install external APIs packages: {packages_str} (return code: {return_code})"
|
||||
)
|
||||
|
||||
return_code = build_image(
|
||||
build_config,
|
||||
build_file_path,
|
||||
|
|
|
|||
|
|
@ -82,39 +82,6 @@ class StackRun(Subcommand):
|
|||
return ImageType.CONDA.value, args.image_name
|
||||
return args.image_type, args.image_name
|
||||
|
||||
def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
|
||||
"""Resolve config file path and template name from args.config"""
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
if not args.config:
|
||||
return None, None
|
||||
|
||||
config_file = Path(args.config)
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
template_name = None
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if this is a template
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
|
||||
if config_file.exists():
|
||||
template_name = args.config
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to ~/.llama dir
|
||||
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists():
|
||||
self.parser.error(
|
||||
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
||||
)
|
||||
|
||||
if not config_file.is_file():
|
||||
self.parser.error(
|
||||
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
|
||||
)
|
||||
|
||||
return config_file, template_name
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
||||
|
|
@ -125,8 +92,15 @@ class StackRun(Subcommand):
|
|||
self._start_ui_development_server(args.port)
|
||||
image_type, image_name = self._get_image_type_and_name(args)
|
||||
|
||||
# Resolve config file and template name first
|
||||
config_file, template_name = self._resolve_config_and_template(args)
|
||||
if args.config:
|
||||
try:
|
||||
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
|
||||
|
||||
config_file = resolve_config_or_template(args.config, Mode.RUN)
|
||||
except ValueError as e:
|
||||
self.parser.error(str(e))
|
||||
else:
|
||||
config_file = None
|
||||
|
||||
# Check if config is required based on image type
|
||||
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file:
|
||||
|
|
@ -164,18 +138,14 @@ class StackRun(Subcommand):
|
|||
if callable(getattr(args, arg)):
|
||||
continue
|
||||
if arg == "config":
|
||||
if template_name:
|
||||
server_args.template = str(template_name)
|
||||
else:
|
||||
# Set the config file path
|
||||
server_args.config = str(config_file)
|
||||
server_args.config = str(config_file)
|
||||
else:
|
||||
setattr(server_args, arg, getattr(args, arg))
|
||||
|
||||
# Run the server
|
||||
server_main(server_args)
|
||||
else:
|
||||
run_args = formulate_run_args(image_type, image_name, config, template_name)
|
||||
run_args = formulate_run_args(image_type, image_name)
|
||||
|
||||
run_args.extend([str(args.port)])
|
||||
|
||||
|
|
|
|||
48
llama_stack/cli/utils.py
Normal file
48
llama_stack/cli/utils.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="cli")
|
||||
|
||||
|
||||
def add_config_template_args(parser: argparse.ArgumentParser):
|
||||
"""Add unified config/template arguments with backward compatibility."""
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
|
||||
group.add_argument(
|
||||
"config",
|
||||
nargs="?",
|
||||
help="Configuration file path or template name",
|
||||
)
|
||||
|
||||
# Backward compatibility arguments (deprecated)
|
||||
group.add_argument(
|
||||
"--config",
|
||||
dest="config_deprecated",
|
||||
help="(DEPRECATED) Use positional argument [config] instead. Configuration file path",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--template",
|
||||
dest="template_deprecated",
|
||||
help="(DEPRECATED) Use positional argument [config] instead. Template name",
|
||||
)
|
||||
|
||||
|
||||
def get_config_from_args(args: argparse.Namespace) -> str | None:
|
||||
"""Extract config value from parsed arguments, handling both new and deprecated forms."""
|
||||
if args.config is not None:
|
||||
return str(args.config)
|
||||
elif hasattr(args, "config_deprecated") and args.config_deprecated is not None:
|
||||
logger.warning("Using deprecated --config argument. Use positional argument [config] instead.")
|
||||
return str(args.config_deprecated)
|
||||
elif hasattr(args, "template_deprecated") and args.template_deprecated is not None:
|
||||
logger.warning("Using deprecated --template argument. Use positional argument [config] instead.")
|
||||
return str(args.template_deprecated)
|
||||
return None
|
||||
|
|
@ -14,6 +14,7 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.distribution.datatypes import BuildConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.utils.exec import run_command
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
|
@ -41,7 +42,7 @@ class ApiInput(BaseModel):
|
|||
|
||||
def get_provider_dependencies(
|
||||
config: BuildConfig | DistributionTemplate,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
) -> tuple[list[str], list[str], list[str]]:
|
||||
"""Get normal and special dependencies from provider configuration."""
|
||||
if isinstance(config, DistributionTemplate):
|
||||
config = config.build_config()
|
||||
|
|
@ -50,6 +51,7 @@ def get_provider_dependencies(
|
|||
additional_pip_packages = config.additional_pip_packages
|
||||
|
||||
deps = []
|
||||
external_provider_deps = []
|
||||
registry = get_provider_registry(config)
|
||||
for api_str, provider_or_providers in providers.items():
|
||||
providers_for_api = registry[Api(api_str)]
|
||||
|
|
@ -64,8 +66,16 @@ def get_provider_dependencies(
|
|||
raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")
|
||||
|
||||
provider_spec = providers_for_api[provider_type]
|
||||
deps.extend(provider_spec.pip_packages)
|
||||
if provider_spec.container_image:
|
||||
if hasattr(provider_spec, "is_external") and provider_spec.is_external:
|
||||
# this ensures we install the top level module for our external providers
|
||||
if provider_spec.module:
|
||||
if isinstance(provider_spec.module, str):
|
||||
external_provider_deps.append(provider_spec.module)
|
||||
else:
|
||||
external_provider_deps.extend(provider_spec.module)
|
||||
if hasattr(provider_spec, "pip_packages"):
|
||||
deps.extend(provider_spec.pip_packages)
|
||||
if hasattr(provider_spec, "container_image") and provider_spec.container_image:
|
||||
raise ValueError("A stack's dependencies cannot have a container image")
|
||||
|
||||
normal_deps = []
|
||||
|
|
@ -78,7 +88,7 @@ def get_provider_dependencies(
|
|||
|
||||
normal_deps.extend(additional_pip_packages or [])
|
||||
|
||||
return list(set(normal_deps)), list(set(special_deps))
|
||||
return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps))
|
||||
|
||||
|
||||
def print_pip_install_help(config: BuildConfig):
|
||||
|
|
@ -103,41 +113,59 @@ def build_image(
|
|||
):
|
||||
container_base = build_config.distribution_spec.container_image or "python:3.12-slim"
|
||||
|
||||
normal_deps, special_deps = get_provider_dependencies(build_config)
|
||||
normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
if build_config.external_apis_dir:
|
||||
external_apis = load_external_apis(build_config)
|
||||
if external_apis:
|
||||
for _, api_spec in external_apis.items():
|
||||
normal_deps.extend(api_spec.pip_packages)
|
||||
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
|
||||
args = [
|
||||
script,
|
||||
"--template-or-config",
|
||||
template_or_config,
|
||||
"--image-name",
|
||||
image_name,
|
||||
"--container-base",
|
||||
container_base,
|
||||
"--normal-deps",
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
|
||||
# When building from a config file (not a template), include the run config path in the
|
||||
# build arguments
|
||||
if run_config is not None:
|
||||
args.append(run_config)
|
||||
args.extend(["--run-config", run_config])
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
|
||||
args = [
|
||||
script,
|
||||
"--env-name",
|
||||
str(image_name),
|
||||
"--build-file-path",
|
||||
str(build_file_path),
|
||||
"--normal-deps",
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
|
||||
args = [
|
||||
script,
|
||||
"--env-name",
|
||||
str(image_name),
|
||||
"--normal-deps",
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
|
||||
# Always pass both arguments, even if empty, to maintain consistent positional arguments
|
||||
if special_deps:
|
||||
args.append("#".join(special_deps))
|
||||
args.extend(["--optional-deps", "#".join(special_deps)])
|
||||
if external_provider_deps:
|
||||
args.extend(
|
||||
["--external-provider-deps", "#".join(external_provider_deps)]
|
||||
) # the script will install external provider module, get its deps, and install those too.
|
||||
|
||||
return_code = run_command(args)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,10 +9,91 @@
|
|||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
PYPI_VERSION=${PYPI_VERSION:-}
|
||||
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
# Usage function
|
||||
usage() {
|
||||
echo "Usage: $0 --env-name <conda_env_name> --build-file-path <build_file_path> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
|
||||
echo "Example: $0 --env-name my-conda-env --build-file-path ./my-stack-build.yaml --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
env_name=""
|
||||
build_file_path=""
|
||||
normal_deps=""
|
||||
external_provider_deps=""
|
||||
optional_deps=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
case "$key" in
|
||||
--env-name)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --env-name requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
env_name="$2"
|
||||
shift 2
|
||||
;;
|
||||
--build-file-path)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --build-file-path requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
build_file_path="$2"
|
||||
shift 2
|
||||
;;
|
||||
--normal-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --normal-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
normal_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--external-provider-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --external-provider-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
external_provider_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--optional-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --optional-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
optional_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check required arguments
|
||||
if [[ -z "$env_name" || -z "$build_file_path" || -z "$normal_deps" ]]; then
|
||||
echo "Error: --env-name, --build-file-path, and --normal-deps are required." >&2
|
||||
usage
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||
fi
|
||||
|
|
@ -20,50 +101,18 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
|||
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
if [ "$#" -lt 3 ]; then
|
||||
echo "Usage: $0 <distribution_type> <conda_env_name> <build_file_path> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Example: $0 <distribution_type> my-conda-env ./my-stack-build.yaml 'numpy pandas scipy'" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
special_pip_deps="$4"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
env_name="$1"
|
||||
build_file_path="$2"
|
||||
pip_dependencies="$3"
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# this is set if we actually create a new conda in which case we need to clean up
|
||||
ENVNAME=""
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
ensure_conda_env_python310() {
|
||||
local env_name="$1"
|
||||
local pip_dependencies="$2"
|
||||
local special_pip_deps="$3"
|
||||
# Use only global variables set by flag parser
|
||||
local python_version="3.12"
|
||||
|
||||
# Check if conda command is available
|
||||
if ! is_command_available conda; then
|
||||
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if the environment exists
|
||||
if conda env list | grep -q "^${env_name} "; then
|
||||
printf "Conda environment '${env_name}' exists. Checking Python version...\n"
|
||||
|
||||
# Check Python version in the environment
|
||||
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
|
||||
|
||||
if [ "$current_version" = "$python_version" ]; then
|
||||
printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
|
||||
else
|
||||
|
|
@ -73,37 +122,37 @@ ensure_conda_env_python310() {
|
|||
else
|
||||
printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
|
||||
conda create -n "${env_name}" python="${python_version}" -y
|
||||
|
||||
ENVNAME="${env_name}"
|
||||
# setup_cleanup_handlers
|
||||
fi
|
||||
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda deactivate && conda activate "${env_name}"
|
||||
|
||||
"$CONDA_PREFIX"/bin/pip install uv
|
||||
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
uv pip install fastapi libcst
|
||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
llama-stack=="$TEST_PYPI_VERSION" \
|
||||
"$pip_dependencies"
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
"$normal_deps"
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install "$part"
|
||||
done
|
||||
fi
|
||||
else
|
||||
# Re-installing llama-stack in the new conda environment
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
|
||||
else
|
||||
|
|
@ -115,31 +164,44 @@ ensure_conda_env_python310() {
|
|||
fi
|
||||
uv pip install --no-cache-dir "$SPEC_VERSION"
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
# Install pip dependencies
|
||||
printf "Installing pip dependencies\n"
|
||||
uv pip install $pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
uv pip install $normal_deps
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "Getting provider spec for module: $part and installing dependencies"
|
||||
package_name=$(echo "$part" | sed 's/[<>=!].*//')
|
||||
python3 -c "
|
||||
import importlib
|
||||
import sys
|
||||
try:
|
||||
module = importlib.import_module(f'$package_name.provider')
|
||||
spec = module.get_provider_spec()
|
||||
if hasattr(spec, 'pip_packages') and spec.pip_packages:
|
||||
print('\\n'.join(spec.pip_packages))
|
||||
except Exception as e:
|
||||
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
|
||||
" | uv pip install -r -
|
||||
done
|
||||
fi
|
||||
fi
|
||||
|
||||
mv "$build_file_path" "$CONDA_PREFIX"/llamastack-build.yaml
|
||||
echo "Build spec configuration saved at $CONDA_PREFIX/llamastack-build.yaml"
|
||||
}
|
||||
|
||||
ensure_conda_env_python310 "$env_name" "$pip_dependencies" "$special_pip_deps"
|
||||
ensure_conda_env_python310 "$env_name" "$build_file_path" "$normal_deps" "$optional_deps" "$external_provider_deps"
|
||||
|
|
|
|||
|
|
@ -19,57 +19,111 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
|||
# mounting is not supported by docker buildx, so we use COPY instead
|
||||
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
|
||||
|
||||
# Mount command for cache container .cache, can be overridden by the user if needed
|
||||
MOUNT_CACHE=${MOUNT_CACHE:-"--mount=type=cache,id=llama-stack-cache,target=/root/.cache"}
|
||||
|
||||
# Path to the run.yaml file in the container
|
||||
RUN_CONFIG_PATH=/app/run.yaml
|
||||
|
||||
BUILD_CONTEXT_DIR=$(pwd)
|
||||
|
||||
if [ "$#" -lt 4 ]; then
|
||||
# This only works for templates
|
||||
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<run_config>] [<special_pip_deps>]" >&2
|
||||
exit 1
|
||||
fi
|
||||
set -euo pipefail
|
||||
|
||||
template_or_config="$1"
|
||||
shift
|
||||
image_name="$1"
|
||||
shift
|
||||
container_base="$1"
|
||||
shift
|
||||
pip_dependencies="$1"
|
||||
shift
|
||||
|
||||
# Handle optional arguments
|
||||
run_config=""
|
||||
special_pip_deps=""
|
||||
|
||||
# Check if there are more arguments
|
||||
# The logics is becoming cumbersom, we should refactor it if we can do better
|
||||
if [ $# -gt 0 ]; then
|
||||
# Check if the argument ends with .yaml
|
||||
if [[ "$1" == *.yaml ]]; then
|
||||
run_config="$1"
|
||||
shift
|
||||
# If there's another argument after .yaml, it must be special_pip_deps
|
||||
if [ $# -gt 0 ]; then
|
||||
special_pip_deps="$1"
|
||||
fi
|
||||
else
|
||||
# If it's not .yaml, it must be special_pip_deps
|
||||
special_pip_deps="$1"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Usage function
|
||||
usage() {
|
||||
echo "Usage: $0 --image-name <image_name> --container-base <container_base> --normal-deps <pip_dependencies> [--run-config <run_config>] [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
|
||||
echo "Example: $0 --image-name llama-stack-img --container-base python:3.12-slim --normal-deps 'numpy pandas' --run-config ./run.yaml --external-provider-deps 'foo' --optional-deps 'bar'"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
image_name=""
|
||||
container_base=""
|
||||
normal_deps=""
|
||||
external_provider_deps=""
|
||||
optional_deps=""
|
||||
run_config=""
|
||||
template_or_config=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
case "$key" in
|
||||
--image-name)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --image-name requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
image_name="$2"
|
||||
shift 2
|
||||
;;
|
||||
--container-base)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --container-base requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
container_base="$2"
|
||||
shift 2
|
||||
;;
|
||||
--normal-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --normal-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
normal_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--external-provider-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --external-provider-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
external_provider_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--optional-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --optional-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
optional_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--run-config)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --run-config requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
run_config="$2"
|
||||
shift 2
|
||||
;;
|
||||
--template-or-config)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --template-or-config requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
template_or_config="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check required arguments
|
||||
if [[ -z "$image_name" || -z "$container_base" || -z "$normal_deps" ]]; then
|
||||
echo "Error: --image-name, --container-base, and --normal-deps are required." >&2
|
||||
usage
|
||||
fi
|
||||
|
||||
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
|
||||
CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}
|
||||
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
|
|
@ -78,18 +132,15 @@ add_to_container() {
|
|||
if [ -t 0 ]; then
|
||||
printf '%s\n' "$1" >>"$output_file"
|
||||
else
|
||||
# If stdin is not a terminal, read from it (heredoc)
|
||||
cat >>"$output_file"
|
||||
fi
|
||||
}
|
||||
|
||||
# Check if container command is available
|
||||
if ! is_command_available "$CONTAINER_BINARY"; then
|
||||
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Update and install UBI9 components if UBI9 base image is used
|
||||
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
|
||||
add_to_container << EOF
|
||||
FROM $container_base
|
||||
|
|
@ -125,24 +176,59 @@ RUN pip install uv
|
|||
EOF
|
||||
fi
|
||||
|
||||
# Set the link mode to copy so that uv doesn't attempt to symlink to the cache directory
|
||||
add_to_container << EOF
|
||||
ENV UV_LINK_MODE=copy
|
||||
EOF
|
||||
|
||||
# Add pip dependencies first since llama-stack is what will change most often
|
||||
# so we can reuse layers.
|
||||
if [ -n "$pip_dependencies" ]; then
|
||||
if [ -n "$normal_deps" ]; then
|
||||
read -ra pip_args <<< "$normal_deps"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache $pip_dependencies
|
||||
RUN $MOUNT_CACHE uv pip install $quoted_deps
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
read -ra pip_args <<< "$part"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container <<EOF
|
||||
RUN uv pip install --no-cache $part
|
||||
RUN $MOUNT_CACHE uv pip install $quoted_deps
|
||||
EOF
|
||||
done
|
||||
fi
|
||||
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
read -ra pip_args <<< "$part"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container <<EOF
|
||||
RUN $MOUNT_CACHE uv pip install $quoted_deps
|
||||
EOF
|
||||
add_to_container <<EOF
|
||||
RUN python3 - <<PYTHON | $MOUNT_CACHE uv pip install -r -
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
try:
|
||||
package_name = '$part'.split('==')[0].split('>=')[0].split('<=')[0].split('!=')[0].split('<')[0].split('>')[0]
|
||||
module = importlib.import_module(f'{package_name}.provider')
|
||||
spec = module.get_provider_spec()
|
||||
if hasattr(spec, 'pip_packages') and spec.pip_packages:
|
||||
if isinstance(spec.pip_packages, (list, tuple)):
|
||||
print('\n'.join(spec.pip_packages))
|
||||
except Exception as e:
|
||||
print(f'Error getting provider spec for {package_name}: {e}', file=sys.stderr)
|
||||
PYTHON
|
||||
EOF
|
||||
done
|
||||
fi
|
||||
|
||||
# Function to get Python command
|
||||
get_python_cmd() {
|
||||
if is_command_available python; then
|
||||
echo "python"
|
||||
|
|
@ -207,7 +293,7 @@ COPY $dir $mount_point
|
|||
EOF
|
||||
fi
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache -e $mount_point
|
||||
RUN $MOUNT_CACHE uv pip install -e $mount_point
|
||||
EOF
|
||||
}
|
||||
|
||||
|
|
@ -222,10 +308,10 @@ else
|
|||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
add_to_container << EOF
|
||||
RUN uv pip install fastapi libcst
|
||||
RUN $MOUNT_CACHE uv pip install fastapi libcst
|
||||
EOF
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
|
||||
RUN $MOUNT_CACHE uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
--index-strategy unsafe-best-match \
|
||||
llama-stack==$TEST_PYPI_VERSION
|
||||
|
||||
|
|
@ -237,7 +323,7 @@ EOF
|
|||
SPEC_VERSION="llama-stack"
|
||||
fi
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache $SPEC_VERSION
|
||||
RUN $MOUNT_CACHE uv pip install $SPEC_VERSION
|
||||
EOF
|
||||
fi
|
||||
fi
|
||||
|
|
@ -328,7 +414,7 @@ $CONTAINER_BINARY build \
|
|||
"$BUILD_CONTEXT_DIR"
|
||||
|
||||
# clean up tmp/configs
|
||||
rm -f "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
rm -rf "$BUILD_CONTEXT_DIR/run.yaml" "$TEMP_DIR"
|
||||
set +x
|
||||
|
||||
echo "Success!"
|
||||
|
|
|
|||
|
|
@ -18,6 +18,76 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
|||
UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
|
||||
VIRTUAL_ENV=${VIRTUAL_ENV:-}
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
# Usage function
|
||||
usage() {
|
||||
echo "Usage: $0 --env-name <env_name> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
|
||||
echo "Example: $0 --env-name mybuild --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
env_name=""
|
||||
normal_deps=""
|
||||
external_provider_deps=""
|
||||
optional_deps=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
case "$key" in
|
||||
--env-name)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --env-name requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
env_name="$2"
|
||||
shift 2
|
||||
;;
|
||||
--normal-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --normal-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
normal_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--external-provider-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --external-provider-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
external_provider_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--optional-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --optional-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
optional_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check required arguments
|
||||
if [[ -z "$env_name" || -z "$normal_deps" ]]; then
|
||||
echo "Error: --env-name and --normal-deps are required." >&2
|
||||
usage
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||
fi
|
||||
|
|
@ -25,29 +95,6 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
|||
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
if [ "$#" -lt 2 ]; then
|
||||
echo "Usage: $0 <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Example: $0 mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
special_pip_deps="$3"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
env_name="$1"
|
||||
pip_dependencies="$2"
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# this is set if we actually create a new conda in which case we need to clean up
|
||||
ENVNAME=""
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
# pre-run checks to make sure we can proceed with the installation
|
||||
pre_run_checks() {
|
||||
local env_name="$1"
|
||||
|
|
@ -71,49 +118,44 @@ pre_run_checks() {
|
|||
}
|
||||
|
||||
run() {
|
||||
local env_name="$1"
|
||||
local pip_dependencies="$2"
|
||||
local special_pip_deps="$3"
|
||||
|
||||
# Use only global variables set by flag parser
|
||||
if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then
|
||||
echo "Installing dependencies in system Python environment"
|
||||
# if env == __system__, ensure we set UV_SYSTEM_PYTHON
|
||||
export UV_SYSTEM_PYTHON=1
|
||||
elif [ "$VIRTUAL_ENV" == "$env_name" ]; then
|
||||
echo "Virtual environment $env_name is already active"
|
||||
else
|
||||
echo "Using virtual environment $env_name"
|
||||
uv venv "$env_name"
|
||||
# shellcheck source=/dev/null
|
||||
source "$env_name/bin/activate"
|
||||
fi
|
||||
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
uv pip install fastapi libcst
|
||||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
--index-strategy unsafe-best-match \
|
||||
llama-stack=="$TEST_PYPI_VERSION" \
|
||||
$pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
$normal_deps
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install "$part"
|
||||
done
|
||||
fi
|
||||
else
|
||||
# Re-installing llama-stack in the new virtual environment
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
|
||||
else
|
||||
|
|
@ -125,27 +167,41 @@ run() {
|
|||
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
# Install pip dependencies
|
||||
printf "Installing pip dependencies\n"
|
||||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
uv pip install $pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
uv pip install $normal_deps
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
echo "Installing special provider module: $part"
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "Installing external provider module: $part"
|
||||
uv pip install "$part"
|
||||
echo "Getting provider spec for module: $part and installing dependencies"
|
||||
package_name=$(echo "$part" | sed 's/[<>=!].*//')
|
||||
python3 -c "
|
||||
import importlib
|
||||
import sys
|
||||
try:
|
||||
module = importlib.import_module(f'$package_name.provider')
|
||||
spec = module.get_provider_spec()
|
||||
if hasattr(spec, 'pip_packages') and spec.pip_packages:
|
||||
print('\\n'.join(spec.pip_packages))
|
||||
except Exception as e:
|
||||
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
|
||||
" | uv pip install -r -
|
||||
done
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
pre_run_checks "$env_name"
|
||||
run "$env_name" "$pip_dependencies" "$special_pip_deps"
|
||||
run
|
||||
|
|
|
|||
|
|
@ -91,21 +91,22 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
|
||||
logger.info(f"Configuring API `{api_str}`...")
|
||||
updated_providers = []
|
||||
for i, provider_type in enumerate(plist):
|
||||
for i, provider in enumerate(plist):
|
||||
if i >= 1:
|
||||
others = ", ".join(plist[i:])
|
||||
others = ", ".join(p.provider_type for p in plist[i:])
|
||||
logger.info(
|
||||
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
|
||||
)
|
||||
break
|
||||
|
||||
logger.info(f"> Configuring provider `({provider_type})`")
|
||||
logger.info(f"> Configuring provider `({provider.provider_type})`")
|
||||
pid = provider.provider_type.split("::")[-1]
|
||||
updated_providers.append(
|
||||
configure_single_provider(
|
||||
provider_registry[api],
|
||||
Provider(
|
||||
provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type),
|
||||
provider_type=provider_type,
|
||||
provider_id=(f"{pid}-{i:02d}" if len(plist) > 1 else pid),
|
||||
provider_type=provider.provider_type,
|
||||
config={},
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2
|
|||
RoutingKey = str | list[str]
|
||||
|
||||
|
||||
class RegistryEntrySource(StrEnum):
|
||||
via_register_api = "via_register_api"
|
||||
listed_from_provider = "listed_from_provider"
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
principal: str
|
||||
# further attributes that may be used for access control decisions
|
||||
|
|
@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
|
|||
resource. This can be used to constrain access to the resource."""
|
||||
|
||||
owner: User | None = None
|
||||
source: RegistryEntrySource = RegistryEntrySource.via_register_api
|
||||
|
||||
|
||||
# Use the extended Resource for all routable objects
|
||||
|
|
@ -130,29 +136,54 @@ class RoutingTableProviderSpec(ProviderSpec):
|
|||
pip_packages: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class Provider(BaseModel):
|
||||
# provider_id of None means that the provider is not enabled - this happens
|
||||
# when the provider is enabled via a conditional environment variable
|
||||
provider_id: str | None
|
||||
provider_type: str
|
||||
config: dict[str, Any] = {}
|
||||
module: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Fully-qualified name of the external provider module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
|
||||
Example: `module: ramalama_stack`
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class BuildProvider(BaseModel):
|
||||
provider_type: str
|
||||
module: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Fully-qualified name of the external provider module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
|
||||
Example: `module: ramalama_stack`
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class DistributionSpec(BaseModel):
|
||||
description: str | None = Field(
|
||||
default="",
|
||||
description="Description of the distribution",
|
||||
)
|
||||
container_image: str | None = None
|
||||
providers: dict[str, str | list[str]] = Field(
|
||||
providers: dict[str, list[BuildProvider]] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
Provider Types for each of the APIs provided by this distribution. If you
|
||||
select multiple providers, you should provide an appropriate 'routing_map'
|
||||
in the runtime configuration to help route to the correct provider.""",
|
||||
Provider Types for each of the APIs provided by this distribution. If you
|
||||
select multiple providers, you should provide an appropriate 'routing_map'
|
||||
in the runtime configuration to help route to the correct provider.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class Provider(BaseModel):
|
||||
# provider_id of None means that the provider is not enabled - this happens
|
||||
# when the provider is enabled via a conditional environment variable
|
||||
provider_id: str | None
|
||||
provider_type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
category_levels: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
|
|
@ -381,6 +412,11 @@ a default SQLite store will be used.""",
|
|||
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
external_apis_dir: Path | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
def validate_external_providers_dir(cls, v):
|
||||
|
|
@ -412,6 +448,10 @@ class BuildConfig(BaseModel):
|
|||
default_factory=list,
|
||||
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
|
||||
)
|
||||
external_apis_dir: Path | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ from typing import Any
|
|||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
|
|
@ -96,12 +98,10 @@ def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_nam
|
|||
return spec
|
||||
|
||||
|
||||
def get_provider_registry(
|
||||
config=None,
|
||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
"""Get the provider registry, optionally including external providers.
|
||||
|
||||
This function loads both built-in providers and external providers from YAML files.
|
||||
This function loads both built-in providers and external providers from YAML files or from their provided modules.
|
||||
External providers are loaded from a directory structure like:
|
||||
|
||||
providers.d/
|
||||
|
|
@ -122,8 +122,13 @@ def get_provider_registry(
|
|||
safety/
|
||||
llama-guard.yaml
|
||||
|
||||
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
|
||||
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
|
||||
There is special handling for all of the potential cases this method can be called from.
|
||||
|
||||
Args:
|
||||
config: Optional object containing the external providers directory path
|
||||
building: Optional bool delineating whether or not this is being called from a build process
|
||||
|
||||
Returns:
|
||||
A dictionary mapping APIs to their available providers
|
||||
|
|
@ -133,58 +138,140 @@ def get_provider_registry(
|
|||
ValueError: If any provider spec is invalid
|
||||
"""
|
||||
|
||||
ret: dict[Api, dict[str, ProviderSpec]] = {}
|
||||
registry: dict[Api, dict[str, ProviderSpec]] = {}
|
||||
for api in providable_apis():
|
||||
name = api.name.lower()
|
||||
logger.debug(f"Importing module {name}")
|
||||
try:
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import module {name}: {e}")
|
||||
|
||||
# Check if config has the external_providers_dir attribute
|
||||
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
|
||||
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
|
||||
if not os.path.exists(external_providers_dir):
|
||||
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
|
||||
logger.info(f"Loading external providers from {external_providers_dir}")
|
||||
# Refresh providable APIs with external APIs if any
|
||||
external_apis = load_external_apis(config)
|
||||
for api, api_spec in external_apis.items():
|
||||
name = api_spec.name.lower()
|
||||
logger.info(f"Importing external API {name} module {api_spec.module}")
|
||||
try:
|
||||
module = importlib.import_module(api_spec.module)
|
||||
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
except (ImportError, AttributeError) as e:
|
||||
# Populate the registry with an empty dict to avoid breaking the provider registry
|
||||
# This assume that the in-tree provider(s) are not available for this API which means
|
||||
# that users will need to use external providers for this API.
|
||||
registry[api] = {}
|
||||
logger.error(
|
||||
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
|
||||
"Install the API package to load any in-tree providers for this API."
|
||||
)
|
||||
|
||||
for api in providable_apis():
|
||||
api_name = api.name.lower()
|
||||
# Check if config has external providers
|
||||
if config:
|
||||
if hasattr(config, "external_providers_dir") and config.external_providers_dir:
|
||||
registry = get_external_providers_from_dir(registry, config)
|
||||
# else lets check for modules in each provider
|
||||
registry = get_external_providers_from_module(
|
||||
registry=registry,
|
||||
config=config,
|
||||
building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)),
|
||||
)
|
||||
|
||||
# Process both remote and inline providers
|
||||
for provider_type in ["remote", "inline"]:
|
||||
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
|
||||
if not os.path.exists(api_dir):
|
||||
logger.debug(f"No {provider_type} provider directory found for {api_name}")
|
||||
continue
|
||||
return registry
|
||||
|
||||
# Look for provider spec files in the API directory
|
||||
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
|
||||
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
|
||||
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
|
||||
|
||||
try:
|
||||
with open(spec_path) as f:
|
||||
spec_data = yaml.safe_load(f)
|
||||
def get_external_providers_from_dir(
|
||||
registry: dict[Api, dict[str, ProviderSpec]], config
|
||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
logger.warning(
|
||||
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
|
||||
)
|
||||
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
|
||||
if not os.path.exists(external_providers_dir):
|
||||
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
|
||||
logger.info(f"Loading external providers from {external_providers_dir}")
|
||||
|
||||
if provider_type == "remote":
|
||||
spec = _load_remote_provider_spec(spec_data, api)
|
||||
provider_type_key = f"remote::{provider_name}"
|
||||
else:
|
||||
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
||||
provider_type_key = f"inline::{provider_name}"
|
||||
for api in providable_apis():
|
||||
api_name = api.name.lower()
|
||||
|
||||
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
||||
if provider_type_key in ret[api]:
|
||||
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
||||
ret[api][provider_type_key] = spec
|
||||
logger.info(f"Successfully loaded external provider {provider_type_key}")
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
||||
raise yaml_err
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
||||
raise e
|
||||
return ret
|
||||
# Process both remote and inline providers
|
||||
for provider_type in ["remote", "inline"]:
|
||||
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
|
||||
if not os.path.exists(api_dir):
|
||||
logger.debug(f"No {provider_type} provider directory found for {api_name}")
|
||||
continue
|
||||
|
||||
# Look for provider spec files in the API directory
|
||||
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
|
||||
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
|
||||
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
|
||||
|
||||
try:
|
||||
with open(spec_path) as f:
|
||||
spec_data = yaml.safe_load(f)
|
||||
|
||||
if provider_type == "remote":
|
||||
spec = _load_remote_provider_spec(spec_data, api)
|
||||
provider_type_key = f"remote::{provider_name}"
|
||||
else:
|
||||
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
||||
provider_type_key = f"inline::{provider_name}"
|
||||
|
||||
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
||||
if provider_type_key in registry[api]:
|
||||
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
||||
registry[api][provider_type_key] = spec
|
||||
logger.info(f"Successfully loaded external provider {provider_type_key}")
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
||||
raise yaml_err
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
||||
raise e
|
||||
|
||||
return registry
|
||||
|
||||
|
||||
def get_external_providers_from_module(
|
||||
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
|
||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
provider_list = None
|
||||
if isinstance(config, BuildConfig):
|
||||
provider_list = config.distribution_spec.providers.items()
|
||||
else:
|
||||
provider_list = config.providers.items()
|
||||
if provider_list is None:
|
||||
logger.warning("Could not get list of providers from config")
|
||||
return registry
|
||||
for provider_api, providers in provider_list:
|
||||
for provider in providers:
|
||||
if not hasattr(provider, "module") or provider.module is None:
|
||||
continue
|
||||
# get provider using module
|
||||
try:
|
||||
if not building:
|
||||
package_name = provider.module.split("==")[0]
|
||||
module = importlib.import_module(f"{package_name}.provider")
|
||||
# if config class is wrong you will get an error saying module could not be imported
|
||||
spec = module.get_provider_spec()
|
||||
else:
|
||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
||||
spec = ProviderSpec(
|
||||
api=Api(provider_api),
|
||||
provider_type=provider.provider_type,
|
||||
is_external=True,
|
||||
module=provider.module,
|
||||
config_class="",
|
||||
)
|
||||
provider_type = provider.provider_type
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
# return a partially filled out spec that the build script will populate.
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
||||
) from exc
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
|
||||
raise e
|
||||
return registry
|
||||
|
|
|
|||
54
llama_stack/distribution/external.py
Normal file
54
llama_stack/distribution/external.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import yaml
|
||||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.distribution.datatypes import BuildConfig, StackRunConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
|
||||
"""Load external API specifications from the configured directory.
|
||||
|
||||
Args:
|
||||
config: StackRunConfig or BuildConfig containing the external APIs directory path
|
||||
|
||||
Returns:
|
||||
A dictionary mapping API names to their specifications
|
||||
"""
|
||||
if not config or not config.external_apis_dir:
|
||||
return {}
|
||||
|
||||
external_apis_dir = config.external_apis_dir.expanduser().resolve()
|
||||
if not external_apis_dir.is_dir():
|
||||
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
|
||||
return {}
|
||||
|
||||
logger.info(f"Loading external APIs from {external_apis_dir}")
|
||||
external_apis: dict[Api, ExternalApiSpec] = {}
|
||||
|
||||
# Look for YAML files in the external APIs directory
|
||||
for yaml_path in external_apis_dir.glob("*.yaml"):
|
||||
try:
|
||||
with open(yaml_path) as f:
|
||||
spec_data = yaml.safe_load(f)
|
||||
|
||||
spec = ExternalApiSpec(**spec_data)
|
||||
api = Api.add(spec.name)
|
||||
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
|
||||
external_apis[api] = spec
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f"Failed to load external API spec from {yaml_path}")
|
||||
raise
|
||||
|
||||
return external_apis
|
||||
|
|
@ -16,6 +16,7 @@ from llama_stack.apis.inspect import (
|
|||
VersionInfo,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.server.routes import get_all_api_routes
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
|
||||
|
|
@ -42,7 +43,8 @@ class DistributionInspectImpl(Inspect):
|
|||
run_config: StackRunConfig = self.config.run_config
|
||||
|
||||
ret = []
|
||||
all_endpoints = get_all_api_routes()
|
||||
external_apis = load_external_apis(run_config)
|
||||
all_endpoints = get_all_api_routes(external_apis)
|
||||
for api, endpoints in all_endpoints.items():
|
||||
# Always include provider and inspect APIs, filter others based on run config
|
||||
if api.value in ["providers", "inspect"]:
|
||||
|
|
@ -53,7 +55,8 @@ class DistributionInspectImpl(Inspect):
|
|||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
||||
)
|
||||
for e in endpoints
|
||||
for e, _ in endpoints
|
||||
if e.methods is not None
|
||||
]
|
||||
)
|
||||
else:
|
||||
|
|
@ -66,7 +69,8 @@ class DistributionInspectImpl(Inspect):
|
|||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
for e, _ in endpoints
|
||||
if e.methods is not None
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec
|
||||
from llama_stack.distribution.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
|
|
@ -161,7 +161,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
if not self.skip_logger_removal:
|
||||
self._remove_root_logger_handlers()
|
||||
|
||||
return self.loop.run_until_complete(self.async_client.initialize())
|
||||
# use a new event loop to avoid interfering with the main event loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(self.async_client.initialize())
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
def _remove_root_logger_handlers(self):
|
||||
"""
|
||||
|
|
@ -243,15 +249,16 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
file=sys.stderr,
|
||||
)
|
||||
if self.config_path_or_template_name.endswith(".yaml"):
|
||||
# Convert Provider objects to their types
|
||||
provider_types: dict[str, str | list[str]] = {}
|
||||
for api, providers in self.config.providers.items():
|
||||
types = [p.provider_type for p in providers]
|
||||
# Convert single-item lists to strings
|
||||
provider_types[api] = types[0] if len(types) == 1 else types
|
||||
providers: dict[str, list[BuildProvider]] = {}
|
||||
for api, run_providers in self.config.providers.items():
|
||||
for provider in run_providers:
|
||||
providers.setdefault(api, []).append(
|
||||
BuildProvider(provider_type=provider.provider_type, module=provider.module)
|
||||
)
|
||||
providers = dict(providers)
|
||||
build_config = BuildConfig(
|
||||
distribution_spec=DistributionSpec(
|
||||
providers=provider_types,
|
||||
providers=providers,
|
||||
),
|
||||
external_providers_dir=self.config.external_providers_dir,
|
||||
)
|
||||
|
|
@ -353,13 +360,15 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
||||
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
|
||||
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
body, field_names = self._handle_file_uploads(options, body)
|
||||
|
||||
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
|
||||
await start_trace(route, {"__location__": "library_client"})
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
try:
|
||||
result = await matched_func(**body)
|
||||
finally:
|
||||
|
|
@ -409,12 +418,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
func, path_params, route = find_matching_route(options.method, path, self.route_impls)
|
||||
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
body = self._convert_body(path, options.method, body)
|
||||
|
||||
await start_trace(route, {"__location__": "library_client"})
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
||||
async def gen():
|
||||
try:
|
||||
|
|
@ -445,8 +455,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
|
||||
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
|
||||
# so we need to convert it to AsyncStream
|
||||
# mypy can't track runtime variables inside the [...] of a generic, so ignore that check
|
||||
args = get_args(stream_cls)
|
||||
stream_cls = AsyncStream[args[0]]
|
||||
stream_cls = AsyncStream[args[0]] # type: ignore[valid-type]
|
||||
response = AsyncAPIResponse(
|
||||
raw=mock_response,
|
||||
client=self,
|
||||
|
|
@ -468,7 +479,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
exclude_params = exclude_params or set()
|
||||
|
||||
func, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
func, _, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
|
|
|
|||
|
|
@ -101,3 +101,15 @@ def get_authenticated_user() -> User | None:
|
|||
if not provider_data:
|
||||
return None
|
||||
return provider_data.get("__authenticated_user")
|
||||
|
||||
|
||||
def user_from_scope(scope: dict) -> User | None:
|
||||
"""Create a User object from ASGI scope data (set by authentication middleware)"""
|
||||
user_attributes = scope.get("user_attributes", {})
|
||||
principal = scope.get("principal", "")
|
||||
|
||||
# auth not enabled
|
||||
if not principal and not user_attributes:
|
||||
return None
|
||||
|
||||
return User(principal=principal, attributes=user_attributes)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents
|
|||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.datatypes import ExternalApiSpec
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InferenceProvider
|
||||
|
|
@ -35,6 +36,7 @@ from llama_stack.distribution.datatypes import (
|
|||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -59,8 +61,16 @@ class InvalidProviderError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def api_protocol_map() -> dict[Api, Any]:
|
||||
return {
|
||||
def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
|
||||
"""Get a mapping of API types to their protocol classes.
|
||||
|
||||
Args:
|
||||
external_apis: Optional dictionary of external API specifications
|
||||
|
||||
Returns:
|
||||
Dictionary mapping API types to their protocol classes
|
||||
"""
|
||||
protocols = {
|
||||
Api.providers: ProvidersAPI,
|
||||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
|
|
@ -83,10 +93,23 @@ def api_protocol_map() -> dict[Api, Any]:
|
|||
Api.files: Files,
|
||||
}
|
||||
|
||||
if external_apis:
|
||||
for api, api_spec in external_apis.items():
|
||||
try:
|
||||
module = importlib.import_module(api_spec.module)
|
||||
api_class = getattr(module, api_spec.protocol)
|
||||
|
||||
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
|
||||
protocols[api] = api_class
|
||||
except (ImportError, AttributeError):
|
||||
logger.exception(f"Failed to load external API {api_spec.name}")
|
||||
|
||||
return protocols
|
||||
|
||||
|
||||
def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
|
||||
external_apis = load_external_apis(config)
|
||||
return {
|
||||
**api_protocol_map(),
|
||||
**api_protocol_map(external_apis),
|
||||
Api.inference: InferenceProvider,
|
||||
}
|
||||
|
||||
|
|
@ -250,7 +273,7 @@ async def instantiate_providers(
|
|||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
policy: list[AccessRule],
|
||||
) -> dict:
|
||||
) -> dict[Api, Any]:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: dict[Api, Any] = {}
|
||||
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||
|
|
@ -322,7 +345,7 @@ async def instantiate_provider(
|
|||
policy: list[AccessRule],
|
||||
):
|
||||
provider_spec = provider.spec
|
||||
if not hasattr(provider_spec, "module"):
|
||||
if not hasattr(provider_spec, "module") or provider_spec.module is None:
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
|
||||
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
|
||||
|
|
@ -360,7 +383,7 @@ async def instantiate_provider(
|
|||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
protocols = api_protocol_map_for_compliance_check()
|
||||
protocols = api_protocol_map_for_compliance_check(run_config)
|
||||
additional_protocols = additional_protocols_map()
|
||||
# TODO: check compliance for special tool groups
|
||||
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
||||
|
|
|
|||
|
|
@ -57,7 +57,8 @@ class DatasetIORouter(DatasetIO):
|
|||
logger.debug(
|
||||
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
|
||||
provider = await self.routing_table.get_provider_impl(dataset_id)
|
||||
return await provider.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
start_index=start_index,
|
||||
limit=limit,
|
||||
|
|
@ -65,7 +66,8 @@ class DatasetIORouter(DatasetIO):
|
|||
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||
provider = await self.routing_table.get_provider_impl(dataset_id)
|
||||
return await provider.append_rows(
|
||||
dataset_id=dataset_id,
|
||||
rows=rows,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,8 @@ class ScoringRouter(Scoring):
|
|||
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
||||
provider = await self.routing_table.get_provider_impl(fn_identifier)
|
||||
score_response = await provider.score_batch(
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
|
|
@ -66,7 +67,8 @@ class ScoringRouter(Scoring):
|
|||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
|
||||
provider = await self.routing_table.get_provider_impl(fn_identifier)
|
||||
score_response = await provider.score(
|
||||
input_rows=input_rows,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
|
|
@ -97,7 +99,8 @@ class EvalRouter(Eval):
|
|||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.run_eval(
|
||||
benchmark_id=benchmark_id,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
|
@ -110,7 +113,8 @@ class EvalRouter(Eval):
|
|||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=input_rows,
|
||||
scoring_functions=scoring_functions,
|
||||
|
|
@ -123,7 +127,8 @@ class EvalRouter(Eval):
|
|||
job_id: str,
|
||||
) -> Job:
|
||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.job_status(benchmark_id, job_id)
|
||||
|
||||
async def job_cancel(
|
||||
self,
|
||||
|
|
@ -131,7 +136,8 @@ class EvalRouter(Eval):
|
|||
job_id: str,
|
||||
) -> None:
|
||||
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
await provider.job_cancel(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
)
|
||||
|
|
@ -142,7 +148,8 @@ class EvalRouter(Eval):
|
|||
job_id: str,
|
||||
) -> EvaluateResponse:
|
||||
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.job_result(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -231,7 +231,7 @@ class InferenceRouter(Inference):
|
|||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||
|
||||
if stream:
|
||||
|
|
@ -292,7 +292,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_chat_completion(
|
||||
model_id=model_id,
|
||||
messages_batch=messages_batch,
|
||||
|
|
@ -322,7 +322,7 @@ class InferenceRouter(Inference):
|
|||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
content=content,
|
||||
|
|
@ -378,7 +378,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
|
||||
|
||||
async def embeddings(
|
||||
|
|
@ -395,7 +395,8 @@ class InferenceRouter(Inference):
|
|||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.llm:
|
||||
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
|
||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
|
|
@ -458,7 +459,7 @@ class InferenceRouter(Inference):
|
|||
suffix=suffix,
|
||||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_completion(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
|
|
@ -538,7 +539,7 @@ class InferenceRouter(Inference):
|
|||
user=user,
|
||||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
response_stream = await provider.openai_chat_completion(**params)
|
||||
if self.store:
|
||||
|
|
@ -575,7 +576,7 @@ class InferenceRouter(Inference):
|
|||
user=user,
|
||||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_embeddings(**params)
|
||||
|
||||
async def list_chat_completions(
|
||||
|
|
|
|||
|
|
@ -50,7 +50,8 @@ class SafetyRouter(Safety):
|
|||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||
return await provider.run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
params=params,
|
||||
|
|
|
|||
|
|
@ -41,9 +41,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
||||
content, vector_db_ids, query_config
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("knowledge_search")
|
||||
return await provider.query(content, vector_db_ids, query_config)
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
|
|
@ -54,9 +53,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("insert_into_memory")
|
||||
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -80,7 +78,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
|
||||
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
||||
provider = await self.routing_table.get_provider_impl(tool_name)
|
||||
return await provider.invoke_tool(
|
||||
tool_name=tool_name,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -104,7 +104,8 @@ class VectorIORouter(VectorIO):
|
|||
logger.debug(
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db_id)
|
||||
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
|
|
@ -113,7 +114,8 @@ class VectorIORouter(VectorIO):
|
|||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db_id)
|
||||
return await provider.query_chunks(vector_db_id, query, params)
|
||||
|
||||
# OpenAI Vector Stores API endpoints
|
||||
async def openai_create_vector_store(
|
||||
|
|
@ -146,7 +148,8 @@ class VectorIORouter(VectorIO):
|
|||
provider_vector_db_id=vector_db_id,
|
||||
vector_db_name=name,
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
|
||||
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
|
||||
return await provider.openai_create_vector_store(
|
||||
name=name,
|
||||
file_ids=file_ids,
|
||||
expires_after=expires_after,
|
||||
|
|
@ -172,9 +175,8 @@ class VectorIORouter(VectorIO):
|
|||
all_stores = []
|
||||
for vector_db in vector_dbs:
|
||||
try:
|
||||
vector_store = await self.routing_table.get_provider_impl(
|
||||
vector_db.identifier
|
||||
).openai_retrieve_vector_store(vector_db.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
|
||||
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
|
||||
all_stores.append(vector_store)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
|
||||
|
|
@ -214,9 +216,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
||||
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
|
||||
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
|
|
@ -226,9 +226,7 @@ class VectorIORouter(VectorIO):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store(
|
||||
return await self.routing_table.openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
name=name,
|
||||
expires_after=expires_after,
|
||||
|
|
@ -240,12 +238,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
result = await provider.openai_delete_vector_store(vector_store_id)
|
||||
# drop from registry
|
||||
await self.routing_table.unregister_vector_db(vector_store_id)
|
||||
return result
|
||||
return await self.routing_table.openai_delete_vector_store(vector_store_id)
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
|
|
@ -258,9 +251,7 @@ class VectorIORouter(VectorIO):
|
|||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_search_vector_store(
|
||||
return await self.routing_table.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=filters,
|
||||
|
|
@ -278,9 +269,7 @@ class VectorIORouter(VectorIO):
|
|||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_attach_file_to_vector_store(
|
||||
return await self.routing_table.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
|
|
@ -297,9 +286,7 @@ class VectorIORouter(VectorIO):
|
|||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> list[VectorStoreFileObject]:
|
||||
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store(
|
||||
return await self.routing_table.openai_list_files_in_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
limit=limit,
|
||||
order=order,
|
||||
|
|
@ -314,9 +301,7 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file(
|
||||
return await self.routing_table.openai_retrieve_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
|
@ -327,9 +312,7 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_contents(
|
||||
return await self.routing_table.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
|
@ -341,9 +324,7 @@ class VectorIORouter(VectorIO):
|
|||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store_file(
|
||||
return await self.routing_table.openai_update_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
|
|
@ -355,9 +336,7 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store_file(
|
||||
return await self.routing_table.openai_delete_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,9 +6,11 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.access_control.datatypes import Action
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessRule,
|
||||
RoutableObject,
|
||||
|
|
@ -115,7 +117,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
async def refresh(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
from .benchmarks import BenchmarksRoutingTable
|
||||
from .datasets import DatasetsRoutingTable
|
||||
from .models import ModelsRoutingTable
|
||||
|
|
@ -204,11 +209,24 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if obj.type == ResourceType.model.value:
|
||||
await self.dist_registry.register(registered_obj)
|
||||
return registered_obj
|
||||
|
||||
else:
|
||||
await self.dist_registry.register(obj)
|
||||
return obj
|
||||
|
||||
async def assert_action_allowed(
|
||||
self,
|
||||
action: Action,
|
||||
type: str,
|
||||
identifier: str,
|
||||
) -> None:
|
||||
"""Fetch a registered object by type/identifier and enforce the given action permission."""
|
||||
obj = await self.get_object_by_identifier(type, identifier)
|
||||
if obj is None:
|
||||
raise ValueError(f"{type.capitalize()} '{identifier}' not found")
|
||||
user = get_authenticated_user()
|
||||
if not is_action_allowed(self.policy, action, obj, user):
|
||||
raise AccessDeniedError(action, obj, user)
|
||||
|
||||
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||
|
|
@ -220,3 +238,28 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
]
|
||||
|
||||
return filtered_objs
|
||||
|
||||
|
||||
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
|
||||
# first try to get the model by identifier
|
||||
# this works if model_id is an alias or is of the form provider_id/provider_model_id
|
||||
model = await routing_table.get_object_by_identifier("model", model_id)
|
||||
if model is not None:
|
||||
return model
|
||||
|
||||
logger.warning(
|
||||
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
|
||||
"searching in all providers. This is only for backwards compatibility and will stop working "
|
||||
"soon. Migrate your calls to use fully scoped `provider_id/model_id` names."
|
||||
)
|
||||
# if not found, this means model_id is an unscoped provider_model_id, we need
|
||||
# to iterate (given a lack of an efficient index on the KVStore)
|
||||
models = await routing_table.get_all_with_type("model")
|
||||
matching_models = [m for m in models if m.provider_resource_id == model_id]
|
||||
if len(matching_models) == 0:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
|
||||
if len(matching_models) > 1:
|
||||
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
|
||||
|
||||
return matching_models[0]
|
||||
|
|
|
|||
|
|
@ -10,15 +10,37 @@ from typing import Any
|
|||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ModelWithOwner,
|
||||
RegistryEntrySource,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
listed_providers: set[str] = set()
|
||||
|
||||
async def refresh(self) -> None:
|
||||
for provider_id, provider in self.impls_by_provider_id.items():
|
||||
refresh = await provider.should_refresh_models()
|
||||
refresh = refresh or provider_id not in self.listed_providers
|
||||
if not refresh:
|
||||
continue
|
||||
|
||||
try:
|
||||
models = await provider.list_models()
|
||||
except Exception as e:
|
||||
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
continue
|
||||
|
||||
self.listed_providers.add(provider_id)
|
||||
if models is None:
|
||||
continue
|
||||
|
||||
await self.update_registered_models(provider_id, models)
|
||||
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
||||
|
||||
|
|
@ -36,10 +58,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
return OpenAIListModelsResponse(data=openai_models)
|
||||
|
||||
async def get_model(self, model_id: str) -> Model:
|
||||
model = await self.get_object_by_identifier("model", model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
return model
|
||||
return await lookup_model(self, model_id)
|
||||
|
||||
async def get_provider_impl(self, model_id: str) -> Any:
|
||||
model = await lookup_model(self, model_id)
|
||||
return self.impls_by_provider_id[model.provider_id]
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
|
|
@ -49,28 +72,38 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model:
|
||||
if provider_model_id is None:
|
||||
provider_model_id = model_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this model
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
|
||||
f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
|
||||
"Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
|
||||
)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
if model_type is None:
|
||||
model_type = ModelType.llm
|
||||
|
||||
provider_model_id = provider_model_id or model_id
|
||||
metadata = metadata or {}
|
||||
model_type = model_type or ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
|
||||
# an identifier different than provider_model_id implies it is an alias, so that
|
||||
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
|
||||
# so as a general rule we must use the provider_id to disambiguate.
|
||||
|
||||
if model_id != provider_model_id:
|
||||
identifier = model_id
|
||||
else:
|
||||
identifier = f"{provider_id}/{provider_model_id}"
|
||||
|
||||
model = ModelWithOwner(
|
||||
identifier=model_id,
|
||||
identifier=identifier,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
model_type=model_type,
|
||||
source=RegistryEntrySource.via_register_api,
|
||||
)
|
||||
registered_model = await self.register_object(model)
|
||||
return registered_model
|
||||
|
|
@ -81,7 +114,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
raise ValueError(f"Model {model_id} not found")
|
||||
await self.unregister_object(existing_model)
|
||||
|
||||
async def update_registered_llm_models(
|
||||
async def update_registered_models(
|
||||
self,
|
||||
provider_id: str,
|
||||
models: list[Model],
|
||||
|
|
@ -92,18 +125,22 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
# from run.yaml) that we need to keep track of
|
||||
model_ids = {}
|
||||
for model in existing_models:
|
||||
# we leave embeddings models alone because often we don't get metadata
|
||||
# (embedding dimension, etc.) from the provider
|
||||
if model.provider_id == provider_id and model.model_type == ModelType.llm:
|
||||
if model.provider_id != provider_id:
|
||||
continue
|
||||
if model.source == RegistryEntrySource.via_register_api:
|
||||
model_ids[model.provider_resource_id] = model.identifier
|
||||
logger.debug(f"unregistering model {model.identifier}")
|
||||
await self.unregister_object(model)
|
||||
continue
|
||||
|
||||
logger.debug(f"unregistering model {model.identifier}")
|
||||
await self.unregister_object(model)
|
||||
|
||||
for model in models:
|
||||
if model.model_type != ModelType.llm:
|
||||
continue
|
||||
if model.provider_resource_id in model_ids:
|
||||
model.identifier = model_ids[model.provider_resource_id]
|
||||
# avoid overwriting a non-provider-registered model entry
|
||||
continue
|
||||
|
||||
if model.identifier == model.provider_resource_id:
|
||||
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
||||
|
||||
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
|
||||
await self.register_object(
|
||||
|
|
@ -113,5 +150,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
provider_id=provider_id,
|
||||
metadata=model.metadata,
|
||||
model_type=model.model_type,
|
||||
source=RegistryEntrySource.listed_from_provider,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
tool_to_toolgroup: dict[str, str] = {}
|
||||
|
||||
# overridden
|
||||
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
|
||||
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
|
||||
|
||||
|
|
@ -40,7 +40,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
|
||||
if routing_key in self.tool_to_toolgroup:
|
||||
routing_key = self.tool_to_toolgroup[routing_key]
|
||||
return super().get_provider_impl(routing_key, provider_id)
|
||||
return await super().get_provider_impl(routing_key, provider_id)
|
||||
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
if toolgroup_id:
|
||||
|
|
@ -59,7 +59,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
return ListToolsResponse(data=all_tools)
|
||||
|
||||
async def _index_tools(self, toolgroup: ToolGroup):
|
||||
provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
|
||||
|
||||
# TODO: kill this Tool vs ToolDef distinction
|
||||
|
|
|
|||
|
|
@ -4,17 +4,30 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.apis.vector_io.vector_io import (
|
||||
SearchRankingOptions,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFileStatus,
|
||||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import (
|
||||
VectorDBWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
|
@ -38,8 +51,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
provider_vector_db_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
) -> VectorDB:
|
||||
if provider_vector_db_id is None:
|
||||
provider_vector_db_id = vector_db_id
|
||||
provider_vector_db_id = provider_vector_db_id or vector_db_id
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) > 0:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
|
@ -49,7 +61,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
)
|
||||
else:
|
||||
raise ValueError("No provider available. Please configure a vector_io provider.")
|
||||
model = await self.get_object_by_identifier("model", embedding_model)
|
||||
model = await lookup_model(self, embedding_model)
|
||||
if model is None:
|
||||
raise ValueError(f"Model {embedding_model} not found")
|
||||
if model.model_type != ModelType.embedding:
|
||||
|
|
@ -74,3 +86,145 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
if existing_vector_db is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
await self.unregister_object(existing_vector_db)
|
||||
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
||||
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
name: str | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
name=name,
|
||||
expires_after=expires_after,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
result = await provider.openai_delete_vector_store(vector_store_id)
|
||||
await self.unregister_vector_db(vector_store_id)
|
||||
return result
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: str | list[str],
|
||||
filters: dict[str, Any] | None = None,
|
||||
max_num_results: int | None = 10,
|
||||
ranking_options: SearchRankingOptions | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=filters,
|
||||
max_num_results=max_num_results,
|
||||
ranking_options=ranking_options,
|
||||
rewrite_query=rewrite_query,
|
||||
search_mode=search_mode,
|
||||
)
|
||||
|
||||
async def openai_attach_file_to_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
|
||||
async def openai_list_files_in_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> list[VectorStoreFileObject]:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
limit=limit,
|
||||
order=order,
|
||||
after=after,
|
||||
before=before,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file_contents(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_update_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,9 +7,12 @@
|
|||
import json
|
||||
|
||||
import httpx
|
||||
from aiohttp import hdrs
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig, User
|
||||
from llama_stack.distribution.request_headers import user_from_scope
|
||||
from llama_stack.distribution.server.auth_providers import create_auth_provider
|
||||
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
|
@ -78,12 +81,14 @@ class AuthenticationMiddleware:
|
|||
access resources that don't have access_attributes defined.
|
||||
"""
|
||||
|
||||
def __init__(self, app, auth_config: AuthenticationConfig):
|
||||
def __init__(self, app, auth_config: AuthenticationConfig, impls):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.auth_provider = create_auth_provider(auth_config)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
# First, handle authentication
|
||||
headers = dict(scope.get("headers", []))
|
||||
auth_header = headers.get(b"authorization", b"").decode()
|
||||
|
||||
|
|
@ -121,15 +126,50 @@ class AuthenticationMiddleware:
|
|||
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
|
||||
)
|
||||
|
||||
# Scope-based API access control
|
||||
path = scope.get("path", "")
|
||||
method = scope.get("method", hdrs.METH_GET)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls)
|
||||
|
||||
try:
|
||||
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if webmethod.required_scope:
|
||||
user = user_from_scope(scope)
|
||||
if not _has_required_scope(webmethod.required_scope, user):
|
||||
return await self._send_auth_error(
|
||||
send,
|
||||
f"Access denied: user does not have required scope: {webmethod.required_scope}",
|
||||
status=403,
|
||||
)
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
async def _send_auth_error(self, send, message):
|
||||
async def _send_auth_error(self, send, message, status=401):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 401,
|
||||
"status": status,
|
||||
"headers": [[b"content-type", b"application/json"]],
|
||||
}
|
||||
)
|
||||
error_msg = json.dumps({"error": {"message": message}}).encode()
|
||||
error_key = "message" if status == 401 else "detail"
|
||||
error_msg = json.dumps({"error": {error_key: message}}).encode()
|
||||
await send({"type": "http.response.body", "body": error_msg})
|
||||
|
||||
|
||||
def _has_required_scope(required_scope: str, user: User | None) -> bool:
|
||||
# if no user, assume auth is not enabled
|
||||
if not user:
|
||||
return True
|
||||
|
||||
if not user.attributes:
|
||||
return False
|
||||
|
||||
user_scopes = user.attributes.get("scopes", [])
|
||||
return required_scope in user_scopes
|
||||
|
|
|
|||
|
|
@ -12,17 +12,18 @@ from typing import Any
|
|||
from aiohttp import hdrs
|
||||
from starlette.routing import Route
|
||||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
from llama_stack.distribution.resolver import api_protocol_map
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
EndpointFunc = Callable[..., Any]
|
||||
PathParams = dict[str, str]
|
||||
RouteInfo = tuple[EndpointFunc, str]
|
||||
RouteInfo = tuple[EndpointFunc, str, WebMethod]
|
||||
PathImpl = dict[str, RouteInfo]
|
||||
RouteImpls = dict[str, PathImpl]
|
||||
RouteMatch = tuple[EndpointFunc, PathParams, str]
|
||||
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
|
||||
|
||||
|
||||
def toolgroup_protocol_map():
|
||||
|
|
@ -31,10 +32,12 @@ def toolgroup_protocol_map():
|
|||
}
|
||||
|
||||
|
||||
def get_all_api_routes() -> dict[Api, list[Route]]:
|
||||
def get_all_api_routes(
|
||||
external_apis: dict[Api, ExternalApiSpec] | None = None,
|
||||
) -> dict[Api, list[tuple[Route, WebMethod]]]:
|
||||
apis = {}
|
||||
|
||||
protocols = api_protocol_map()
|
||||
protocols = api_protocol_map(external_apis)
|
||||
toolgroup_protocols = toolgroup_protocol_map()
|
||||
for api, protocol in protocols.items():
|
||||
routes = []
|
||||
|
|
@ -65,7 +68,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
|
|||
else:
|
||||
http_method = hdrs.METH_POST
|
||||
routes.append(
|
||||
Route(path=path, methods=[http_method], name=name, endpoint=None)
|
||||
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
|
||||
) # setting endpoint to None since don't use a Router object
|
||||
|
||||
apis[api] = routes
|
||||
|
|
@ -73,8 +76,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
|
|||
return apis
|
||||
|
||||
|
||||
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
||||
routes = get_all_api_routes()
|
||||
def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
|
||||
api_to_routes = get_all_api_routes(external_apis)
|
||||
route_impls: RouteImpls = {}
|
||||
|
||||
def _convert_path_to_regex(path: str) -> str:
|
||||
|
|
@ -88,10 +91,10 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
|||
|
||||
return f"^{pattern}$"
|
||||
|
||||
for api, api_routes in routes.items():
|
||||
for api, api_routes in api_to_routes.items():
|
||||
if api not in impls:
|
||||
continue
|
||||
for route in api_routes:
|
||||
for route, webmethod in api_routes:
|
||||
impl = impls[api]
|
||||
func = getattr(impl, route.name)
|
||||
# Get the first (and typically only) method from the set, filtering out HEAD
|
||||
|
|
@ -104,6 +107,7 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
|||
route_impls[method][_convert_path_to_regex(route.path)] = (
|
||||
func,
|
||||
route.path,
|
||||
webmethod,
|
||||
)
|
||||
|
||||
return route_impls
|
||||
|
|
@ -118,7 +122,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
|
|||
route_impls: A dictionary of endpoint implementations
|
||||
|
||||
Returns:
|
||||
A tuple of (endpoint_function, path_params, descriptive_name)
|
||||
A tuple of (endpoint_function, path_params, route_path, webmethod_metadata)
|
||||
|
||||
Raises:
|
||||
ValueError: If no matching endpoint is found
|
||||
|
|
@ -127,11 +131,11 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
|
|||
if not impls:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
for regex, (func, descriptive_name) in impls.items():
|
||||
for regex, (func, route_path, webmethod) in impls.items():
|
||||
match = re.match(regex, path)
|
||||
if match:
|
||||
# Extract named groups from the regex match
|
||||
path_params = match.groupdict()
|
||||
return func, path_params, descriptive_name
|
||||
return func, path_params, route_path, webmethod
|
||||
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from openai import BadRequestError
|
|||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.cli.utils import add_config_template_args, get_config_from_args
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
|
|
@ -39,7 +40,12 @@ from llama_stack.distribution.datatypes import (
|
|||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
||||
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.server.routes import (
|
||||
find_matching_route,
|
||||
|
|
@ -50,9 +56,11 @@ from llama_stack.distribution.stack import (
|
|||
cast_image_name_to_string,
|
||||
construct_stack,
|
||||
replace_env_vars,
|
||||
shutdown_stack,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
|
@ -144,18 +152,7 @@ async def shutdown(app):
|
|||
Handled by the lifespan context manager. The shutdown process involves
|
||||
shutting down all implementations registered in the application.
|
||||
"""
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except TimeoutError:
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
await shutdown_stack(app.__llama_stack_impls__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
@ -220,9 +217,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
@functools.wraps(func)
|
||||
async def route_handler(request: Request, **kwargs):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
principal = request.scope.get("principal", "")
|
||||
user = User(principal=principal, attributes=user_attributes)
|
||||
user = user_from_scope(request.scope)
|
||||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
|
|
@ -280,9 +275,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app, impls):
|
||||
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.external_apis = external_apis
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
|
|
@ -299,10 +295,12 @@ class TracingMiddleware:
|
|||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls)
|
||||
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
||||
|
||||
try:
|
||||
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
|
||||
_, _, route_path, webmethod = find_matching_route(
|
||||
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||
)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
|
|
@ -319,6 +317,7 @@ class TracingMiddleware:
|
|||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
|
|
@ -377,20 +376,8 @@ class ClientVersionMiddleware:
|
|||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
parser.add_argument(
|
||||
"--yaml-config",
|
||||
dest="config",
|
||||
help="(Deprecated) Path to YAML configuration file - use --config instead",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
dest="config",
|
||||
help="Path to YAML configuration file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--template",
|
||||
help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
|
||||
)
|
||||
|
||||
add_config_template_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
|
|
@ -409,20 +396,8 @@ def main(args: argparse.Namespace | None = None):
|
|||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
log_line = ""
|
||||
if hasattr(args, "config") and args.config:
|
||||
# if the user provided a config file, use it, even if template was specified
|
||||
config_file = Path(args.config)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
log_line = f"Using config file: {config_file}"
|
||||
elif hasattr(args, "template") and args.template:
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
log_line = f"Using template {args.template} config file: {config_file}"
|
||||
else:
|
||||
raise ValueError("Either --config or --template must be provided")
|
||||
config_or_template = get_config_from_args(args)
|
||||
config_file = resolve_config_or_template(config_or_template, Mode.RUN)
|
||||
|
||||
logger_config = None
|
||||
with open(config_file) as fp:
|
||||
|
|
@ -442,9 +417,6 @@ def main(args: argparse.Namespace | None = None):
|
|||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
# now that the logger is initialized, print the line about which type of config we are using.
|
||||
logger.info(log_line)
|
||||
|
||||
_log_run_config(run_config=config)
|
||||
|
||||
app = FastAPI(
|
||||
|
|
@ -457,10 +429,21 @@ def main(args: argparse.Namespace | None = None):
|
|||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
# Add authentication middleware if configured
|
||||
try:
|
||||
# Create and set the event loop that will be used for both construction and server runtime
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Construct the stack in the persistent event loop
|
||||
impls = loop.run_until_complete(construct_stack(config))
|
||||
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
|
||||
else:
|
||||
if config.server.quota:
|
||||
quota = config.server.quota
|
||||
|
|
@ -491,24 +474,14 @@ def main(args: argparse.Namespace | None = None):
|
|||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
try:
|
||||
# Create and set the event loop that will be used for both construction and server runtime
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Construct the stack in the persistent event loop
|
||||
impls = loop.run_until_complete(construct_stack(config))
|
||||
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
||||
|
||||
all_routes = get_all_api_routes()
|
||||
# Load external APIs if configured
|
||||
external_apis = load_external_apis(config)
|
||||
all_routes = get_all_api_routes(external_apis)
|
||||
|
||||
if config.apis:
|
||||
apis_to_serve = set(config.apis)
|
||||
|
|
@ -527,9 +500,12 @@ def main(args: argparse.Namespace | None = None):
|
|||
api = Api(api_str)
|
||||
|
||||
routes = all_routes[api]
|
||||
impl = impls[api]
|
||||
try:
|
||||
impl = impls[api]
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Could not find provider implementation for {api} API") from e
|
||||
|
||||
for route in routes:
|
||||
for route, _ in routes:
|
||||
if not hasattr(impl, route.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(f"Could not find method {route.name} on {impl}!")
|
||||
|
|
@ -558,7 +534,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
app.add_middleware(TracingMiddleware, impls=impls)
|
||||
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
|
@ -592,12 +568,29 @@ def main(args: argparse.Namespace | None = None):
|
|||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
}
|
||||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
||||
# Run uvicorn in the existing event loop to preserve background tasks
|
||||
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
# stack trace when using Ctrl+C or kill -2 (SIGINT).
|
||||
# SIGTERM (kill -15) works fine without this because Python doesn't
|
||||
# have a default handler for it.
|
||||
#
|
||||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
finally:
|
||||
if not loop.is_closed():
|
||||
logger.debug("Closing event loop")
|
||||
loop.close()
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
|
||||
|
|
@ -618,11 +611,8 @@ def extract_path_params(route: str) -> list[str]:
|
|||
|
||||
def remove_disabled_providers(obj):
|
||||
if isinstance(obj, dict):
|
||||
if (
|
||||
obj.get("provider_id") == "__disabled__"
|
||||
or obj.get("shield_id") == "__disabled__"
|
||||
or obj.get("provider_model_id") == "__disabled__"
|
||||
):
|
||||
keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
|
||||
if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
|
||||
return None
|
||||
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
|
||||
elif isinstance(obj, list):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import importlib.resources
|
||||
import os
|
||||
import re
|
||||
|
|
@ -38,6 +39,7 @@ from llama_stack.distribution.distribution import get_provider_registry
|
|||
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -90,6 +92,10 @@ RESOURCES = [
|
|||
]
|
||||
|
||||
|
||||
REGISTRY_REFRESH_INTERVAL_SECONDS = 300
|
||||
REGISTRY_REFRESH_TASK = None
|
||||
|
||||
|
||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||
for rsrc, api, register_method, list_method in RESOURCES:
|
||||
objects = getattr(run_config, rsrc)
|
||||
|
|
@ -99,23 +105,10 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
method = getattr(impls[api], register_method)
|
||||
for obj in objects:
|
||||
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
||||
# Do not register models on disabled providers
|
||||
if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||
continue
|
||||
# In complex templates, like our starter template, we may have dynamic model ids
|
||||
# given by environment variables. This allows those environment variables to have
|
||||
# a default value of __disabled__ to skip registration of the model if not set.
|
||||
if (
|
||||
hasattr(obj, "provider_model_id")
|
||||
and obj.provider_model_id is not None
|
||||
and "__disabled__" in obj.provider_model_id
|
||||
):
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
|
||||
continue
|
||||
|
||||
if hasattr(obj, "shield_id") and obj.shield_id is not None and obj.shield_id == "__disabled__":
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
|
||||
# Do not register models on disabled providers
|
||||
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||
continue
|
||||
|
||||
# we want to maintain the type information in arguments to method.
|
||||
|
|
@ -324,9 +317,61 @@ async def construct_stack(
|
|||
add_internal_implementations(impls, run_config)
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
|
||||
await refresh_registry_once(impls)
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
||||
if task.cancelled():
|
||||
logger.error("Model refresh task cancelled")
|
||||
elif task.exception():
|
||||
logger.error(f"Model refresh task failed: {task.exception()}")
|
||||
traceback.print_exception(task.exception())
|
||||
else:
|
||||
logger.debug("Model refresh task completed")
|
||||
|
||||
REGISTRY_REFRESH_TASK.add_done_callback(cb)
|
||||
return impls
|
||||
|
||||
|
||||
async def shutdown_stack(impls: dict[Api, Any]):
|
||||
for impl in impls.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info(f"Shutting down {impl_name}")
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning(f"No shutdown method for {impl_name}")
|
||||
except TimeoutError:
|
||||
logger.exception(f"Shutdown timeout for {impl_name}")
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception(f"Failed to shutdown {impl_name}: {e}")
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
if REGISTRY_REFRESH_TASK:
|
||||
REGISTRY_REFRESH_TASK.cancel()
|
||||
|
||||
|
||||
async def refresh_registry_once(impls: dict[Api, Any]):
|
||||
logger.debug("refreshing registry")
|
||||
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
|
||||
for routing_table in routing_tables:
|
||||
await routing_table.refresh()
|
||||
|
||||
|
||||
async def refresh_registry_task(impls: dict[Api, Any]):
|
||||
logger.info("starting registry refresh task")
|
||||
while True:
|
||||
await refresh_registry_once(impls)
|
||||
|
||||
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
|
||||
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
|
||||
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
|||
set -x
|
||||
|
||||
if [ -n "$yaml_config" ]; then
|
||||
yaml_config_arg="--config $yaml_config"
|
||||
yaml_config_arg="$yaml_config"
|
||||
else
|
||||
yaml_config_arg=""
|
||||
fi
|
||||
|
|
|
|||
125
llama_stack/distribution/utils/config_resolution.py
Normal file
125
llama_stack/distribution/utils/config_resolution.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="config_resolution")
|
||||
|
||||
|
||||
TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates"
|
||||
|
||||
|
||||
class Mode(StrEnum):
|
||||
RUN = "run"
|
||||
BUILD = "build"
|
||||
|
||||
|
||||
def resolve_config_or_template(
|
||||
config_or_template: str,
|
||||
mode: Mode = Mode.RUN,
|
||||
) -> Path:
|
||||
"""
|
||||
Resolve a config/template argument to a concrete config file path.
|
||||
|
||||
Args:
|
||||
config_or_template: User input (file path, template name, or built distribution)
|
||||
mode: Mode resolving for ("run", "build", "server")
|
||||
|
||||
Returns:
|
||||
Path to the resolved config file
|
||||
|
||||
Raises:
|
||||
ValueError: If resolution fails
|
||||
"""
|
||||
|
||||
# Strategy 1: Try as file path first
|
||||
config_path = Path(config_or_template)
|
||||
if config_path.exists() and config_path.is_file():
|
||||
logger.info(f"Using file path: {config_path}")
|
||||
return config_path.resolve()
|
||||
|
||||
# Strategy 2: Try as template name (if no .yaml extension)
|
||||
if not config_or_template.endswith(".yaml"):
|
||||
template_config = _get_template_config_path(config_or_template, mode)
|
||||
if template_config.exists():
|
||||
logger.info(f"Using template: {template_config}")
|
||||
return template_config
|
||||
|
||||
# Strategy 3: Try as built distribution name
|
||||
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
|
||||
if distrib_config.exists():
|
||||
logger.info(f"Using built distribution: {distrib_config}")
|
||||
return distrib_config
|
||||
|
||||
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
|
||||
if distrib_config.exists():
|
||||
logger.info(f"Using built distribution: {distrib_config}")
|
||||
return distrib_config
|
||||
|
||||
# Strategy 4: Failed - provide helpful error
|
||||
raise ValueError(_format_resolution_error(config_or_template, mode))
|
||||
|
||||
|
||||
def _get_template_config_path(template_name: str, mode: Mode) -> Path:
|
||||
"""Get the config file path for a template."""
|
||||
return TEMPLATE_DIR / template_name / f"{mode}.yaml"
|
||||
|
||||
|
||||
def _format_resolution_error(config_or_template: str, mode: Mode) -> str:
|
||||
"""Format a helpful error message for resolution failures."""
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
template_path = _get_template_config_path(config_or_template, mode)
|
||||
distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
|
||||
distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
|
||||
|
||||
available_templates = _get_available_templates()
|
||||
templates_str = ", ".join(available_templates) if available_templates else "none found"
|
||||
|
||||
return f"""Could not resolve config or template '{config_or_template}'.
|
||||
|
||||
Tried the following locations:
|
||||
1. As file path: {Path(config_or_template).resolve()}
|
||||
2. As template: {template_path}
|
||||
3. As built distribution: ({distrib_path}, {distrib_path2})
|
||||
|
||||
Available templates: {templates_str}
|
||||
|
||||
Did you mean one of these templates?
|
||||
{_format_template_suggestions(available_templates, config_or_template)}
|
||||
"""
|
||||
|
||||
|
||||
def _get_available_templates() -> list[str]:
|
||||
"""Get list of available template names."""
|
||||
if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists():
|
||||
return []
|
||||
|
||||
return list(
|
||||
set(
|
||||
[d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
|
||||
+ [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _format_template_suggestions(templates: list[str], user_input: str) -> str:
|
||||
"""Format template suggestions for error messages, showing closest matches first."""
|
||||
if not templates:
|
||||
return " (no templates found)"
|
||||
|
||||
import difflib
|
||||
|
||||
# Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive)
|
||||
close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3)
|
||||
display_templates = close_matches if close_matches else templates[:3]
|
||||
|
||||
suggestions = [f" - {t}" for t in display_templates]
|
||||
return "\n".join(suggestions)
|
||||
|
|
@ -21,7 +21,7 @@ from pathlib import Path
|
|||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
|
||||
|
||||
def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||
def formulate_run_args(image_type: str, image_name: str) -> list[str]:
|
||||
env_name = ""
|
||||
|
||||
if image_type == LlamaStackImageType.CONDA.value:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from logging.config import dictConfig
|
||||
|
||||
|
|
@ -30,6 +31,7 @@ CATEGORIES = [
|
|||
"eval",
|
||||
"tools",
|
||||
"client",
|
||||
"telemetry",
|
||||
]
|
||||
|
||||
# Initialize category levels with default level
|
||||
|
|
@ -113,6 +115,11 @@ def parse_environment_config(env_config: str) -> dict[str, int]:
|
|||
return category_levels
|
||||
|
||||
|
||||
def strip_rich_markup(text):
|
||||
"""Remove Rich markup tags like [dim], [bold magenta], etc."""
|
||||
return re.sub(r"\[/?[a-zA-Z0-9 _#=,]+\]", "", text)
|
||||
|
||||
|
||||
class CustomRichHandler(RichHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["console"] = Console(width=150)
|
||||
|
|
@ -131,6 +138,19 @@ class CustomRichHandler(RichHandler):
|
|||
self.markup = original_markup
|
||||
|
||||
|
||||
class CustomFileHandler(logging.FileHandler):
|
||||
def __init__(self, filename, mode="a", encoding=None, delay=False):
|
||||
super().__init__(filename, mode, encoding, delay)
|
||||
# Default formatter to match console output
|
||||
self.default_formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(category)s: %(message)s")
|
||||
self.setFormatter(self.default_formatter)
|
||||
|
||||
def emit(self, record):
|
||||
if hasattr(record, "msg"):
|
||||
record.msg = strip_rich_markup(str(record.msg))
|
||||
super().emit(record)
|
||||
|
||||
|
||||
def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None:
|
||||
"""
|
||||
Configure logging based on the provided category log levels and an optional log file.
|
||||
|
|
@ -167,8 +187,7 @@ def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None
|
|||
# Add a file handler if log_file is set
|
||||
if log_file:
|
||||
handlers["file"] = {
|
||||
"class": "logging.FileHandler",
|
||||
"formatter": "rich",
|
||||
"()": CustomFileHandler,
|
||||
"filename": log_file,
|
||||
"mode": "a",
|
||||
"encoding": "utf-8",
|
||||
|
|
|
|||
|
|
@ -43,10 +43,24 @@ class ModelsProtocolPrivate(Protocol):
|
|||
-> Provider uses provider-model-id for inference
|
||||
"""
|
||||
|
||||
# this should be called `on_model_register` or something like that.
|
||||
# the provider should _not_ be able to change the object in this
|
||||
# callback
|
||||
async def register_model(self, model: Model) -> Model: ...
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None: ...
|
||||
|
||||
# the Stack router will query each provider for their list of models
|
||||
# if a `refresh_interval_seconds` is provided, this method will be called
|
||||
# periodically to refresh the list of models
|
||||
#
|
||||
# NOTE: each model returned will be registered with the model registry. this means
|
||||
# a callback to the `register_model()` method will be made. this is duplicative and
|
||||
# may be removed in the future.
|
||||
async def list_models(self) -> list[Model] | None: ...
|
||||
|
||||
async def should_refresh_models(self) -> bool: ...
|
||||
|
||||
|
||||
class ShieldsProtocolPrivate(Protocol):
|
||||
async def register_shield(self, shield: Shield) -> None: ...
|
||||
|
|
@ -104,6 +118,19 @@ class ProviderSpec(BaseModel):
|
|||
description="If this provider is deprecated and does NOT work, specify the error message here",
|
||||
)
|
||||
|
||||
module: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
|
||||
Example: `module: ramalama_stack`
|
||||
""",
|
||||
)
|
||||
|
||||
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
|
||||
|
||||
# used internally by the resolver; this is a hack for now
|
||||
deps__: list[str] = Field(default_factory=list)
|
||||
|
||||
|
|
@ -113,7 +140,7 @@ class ProviderSpec(BaseModel):
|
|||
|
||||
|
||||
class RoutingTable(Protocol):
|
||||
def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||
async def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||
|
||||
|
||||
# TODO: this can now be inlined into RemoteProviderSpec
|
||||
|
|
@ -124,7 +151,7 @@ class AdapterSpec(BaseModel):
|
|||
description="Unique identifier for this adapter",
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
default_factory=str,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
|
|
@ -162,14 +189,7 @@ The container image to use for this implementation. If one is provided, pip_pack
|
|||
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
|
||||
""",
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_provider_impl(config, deps)`: returns the local implementation
|
||||
""",
|
||||
)
|
||||
# module field is inherited from ProviderSpec
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
|
|
@ -212,9 +232,7 @@ API responses, specify the adapter here.
|
|||
def container_image(self) -> str | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def module(self) -> str:
|
||||
return self.adapter.module
|
||||
# module field is inherited from ProviderSpec
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
|
|
@ -226,14 +244,19 @@ API responses, specify the adapter here.
|
|||
|
||||
|
||||
def remote_provider_spec(
|
||||
api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None
|
||||
api: Api,
|
||||
adapter: AdapterSpec,
|
||||
api_dependencies: list[Api] | None = None,
|
||||
optional_api_dependencies: list[Api] | None = None,
|
||||
) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(
|
||||
api=api,
|
||||
provider_type=f"remote::{adapter.adapter_type}",
|
||||
config_class=adapter.config_class,
|
||||
module=adapter.module,
|
||||
adapter=adapter,
|
||||
api_dependencies=api_dependencies or [],
|
||||
optional_api_dependencies=optional_api_dependencies or [],
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import re
|
|||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
|
@ -911,8 +912,16 @@ async def load_data_from_url(url: str) -> str:
|
|||
|
||||
|
||||
async def get_raw_document_text(document: Document) -> str:
|
||||
if not document.mime_type.startswith("text/"):
|
||||
# Handle deprecated text/yaml mime type with warning
|
||||
if document.mime_type == "text/yaml":
|
||||
warnings.warn(
|
||||
"The 'text/yaml' MIME type is deprecated. Please use 'application/yaml' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
|
||||
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
|
||||
|
||||
if isinstance(document.content, URL):
|
||||
return await load_data_from_url(document.content.uri)
|
||||
elif isinstance(document.content, str):
|
||||
|
|
|
|||
|
|
@ -128,6 +128,11 @@ class AgentPersistence:
|
|||
except Exception as e:
|
||||
log.error(f"Error parsing turn: {e}")
|
||||
continue
|
||||
|
||||
# The kvstore does not guarantee order, so we sort by started_at
|
||||
# to ensure consistent ordering of turns.
|
||||
turns.sort(key=lambda t: t.started_at)
|
||||
|
||||
return turns
|
||||
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
|
||||
|
|
|
|||
|
|
@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
|
|||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
return None
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
|
|
@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl(
|
|||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
__provider_id__: str
|
||||
|
||||
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
|
@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl(
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
return [
|
||||
Model(
|
||||
identifier="all-MiniLM-L6-v2",
|
||||
provider_resource_id="all-MiniLM-L6-v2",
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
return model
|
||||
|
||||
|
|
|
|||
|
|
@ -146,9 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
# Allow any model to be registered as a shield
|
||||
# The model will be validated during runtime when making inference calls
|
||||
pass
|
||||
model_id = shield.provider_resource_id
|
||||
if not model_id:
|
||||
raise ValueError("Llama Guard shield must have a model id")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -11,19 +11,9 @@ from opentelemetry.sdk.trace import ReadableSpan
|
|||
from opentelemetry.sdk.trace.export import SpanProcessor
|
||||
from opentelemetry.trace.status import StatusCode
|
||||
|
||||
# Colors for console output
|
||||
COLORS = {
|
||||
"reset": "\033[0m",
|
||||
"bold": "\033[1m",
|
||||
"dim": "\033[2m",
|
||||
"red": "\033[31m",
|
||||
"green": "\033[32m",
|
||||
"yellow": "\033[33m",
|
||||
"blue": "\033[34m",
|
||||
"magenta": "\033[35m",
|
||||
"cyan": "\033[36m",
|
||||
"white": "\033[37m",
|
||||
}
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name="console_span_processor", category="telemetry")
|
||||
|
||||
|
||||
class ConsoleSpanProcessor(SpanProcessor):
|
||||
|
|
@ -35,34 +25,21 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
print(
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{COLORS['magenta']}[START]{COLORS['reset']} "
|
||||
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
|
||||
)
|
||||
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
|
||||
|
||||
def on_end(self, span: ReadableSpan) -> None:
|
||||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
span_context = (
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{COLORS['magenta']}[END]{COLORS['reset']} "
|
||||
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
|
||||
)
|
||||
|
||||
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
|
||||
if span.status.status_code == StatusCode.ERROR:
|
||||
span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}"
|
||||
span_context += " [bold red][ERROR][/bold red]"
|
||||
elif span.status.status_code != StatusCode.UNSET:
|
||||
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
|
||||
|
||||
span_context += f" [{span.status.status_code}]"
|
||||
duration_ms = (span.end_time - span.start_time) / 1e6
|
||||
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
|
||||
|
||||
print(span_context)
|
||||
span_context += f" ({duration_ms:.2f}ms)"
|
||||
logger.info(span_context)
|
||||
|
||||
if self.print_attributes and span.attributes:
|
||||
for key, value in span.attributes.items():
|
||||
|
|
@ -71,31 +48,26 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
str_value = str(value)
|
||||
if len(str_value) > 1000:
|
||||
str_value = str_value[:997] + "..."
|
||||
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
|
||||
logger.info(f" [dim]{key}[/dim]: {str_value}")
|
||||
|
||||
for event in span.events:
|
||||
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
if isinstance(message, dict | list):
|
||||
if isinstance(message, dict) or isinstance(message, list):
|
||||
message = json.dumps(message, indent=2)
|
||||
|
||||
severity_colors = {
|
||||
"error": f"{COLORS['bold']}{COLORS['red']}",
|
||||
"warn": f"{COLORS['bold']}{COLORS['yellow']}",
|
||||
"info": COLORS["white"],
|
||||
"debug": COLORS["dim"],
|
||||
}
|
||||
msg_color = severity_colors.get(severity, COLORS["white"])
|
||||
|
||||
print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
|
||||
|
||||
severity_color = {
|
||||
"error": "red",
|
||||
"warn": "yellow",
|
||||
"info": "white",
|
||||
"debug": "dim",
|
||||
}.get(severity, "white")
|
||||
logger.info(f" {event_time} [bold {severity_color}][{severity.upper()}][/bold {severity_color}] {message}")
|
||||
if event.attributes:
|
||||
for key, value in event.attributes.items():
|
||||
if key.startswith("__") or key in ["message", "severity"]:
|
||||
continue
|
||||
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
|
||||
logger.info(f"/r[dim]{key}[/dim]: {value}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the processor."""
|
||||
|
|
|
|||
|
|
@ -16,6 +16,6 @@ async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
|
|||
ChromaVectorIOAdapter,
|
||||
)
|
||||
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference])
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -6,12 +6,25 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChromaVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> dict[str, Any]:
|
||||
return {"db_path": db_path}
|
||||
def sample_run_config(
|
||||
cls, __distro_dir__: str, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"db_path": db_path,
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="chroma_inline_registry.db",
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex):
|
|||
self.kvstore = kvstore
|
||||
self.bank_id = bank_id
|
||||
|
||||
# A list of chunk id's in the same order as they are in the index,
|
||||
# must be updated when chunks are added or removed
|
||||
self.chunk_id_lock = asyncio.Lock()
|
||||
self.chunk_ids: list[Any] = []
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
|
||||
instance = cls(dimension, kvstore, bank_id)
|
||||
|
|
@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
|
||||
try:
|
||||
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
|
||||
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
|
||||
except Exception as e:
|
||||
logger.debug(e, exc_info=True)
|
||||
raise ValueError(
|
||||
|
|
@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex):
|
|||
for i, chunk in enumerate(chunks):
|
||||
self.chunk_by_index[indexlen + i] = chunk
|
||||
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
async with self.chunk_id_lock:
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
|
||||
|
||||
# Save updated index
|
||||
await self._save_index()
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
if chunk_id not in self.chunk_ids:
|
||||
return
|
||||
|
||||
async with self.chunk_id_lock:
|
||||
index = self.chunk_ids.index(chunk_id)
|
||||
self.index.remove_ids(np.array([index]))
|
||||
|
||||
new_chunk_by_index = {}
|
||||
for idx, chunk in self.chunk_by_index.items():
|
||||
# Shift all chunks after the removed chunk to the left
|
||||
if idx > index:
|
||||
new_chunk_by_index[idx - 1] = chunk
|
||||
else:
|
||||
new_chunk_by_index[idx] = chunk
|
||||
self.chunk_by_index = new_chunk_by_index
|
||||
self.chunk_ids.pop(index)
|
||||
|
||||
await self._save_index()
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
|
|
@ -261,47 +289,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Save vector store file data to kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(file_info))
|
||||
content_key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.set(key=content_key, value=json.dumps(file_contents))
|
||||
|
||||
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
|
||||
"""Load vector store file metadata from kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
|
||||
stored_data = await self.kvstore.get(key)
|
||||
return json.loads(stored_data) if stored_data else {}
|
||||
|
||||
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
||||
"""Load vector store file contents from kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
|
||||
stored_data = await self.kvstore.get(key)
|
||||
return json.loads(stored_data) if stored_data else []
|
||||
|
||||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||
"""Update vector store file metadata in kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(file_info))
|
||||
|
||||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||
"""Delete vector store data from kvstore."""
|
||||
assert self.kvstore is not None
|
||||
|
||||
keys_to_delete = [
|
||||
f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}",
|
||||
f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}",
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
try:
|
||||
await self.kvstore.delete(key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete key {key}: {e}")
|
||||
continue
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
"""Delete a chunk from a faiss index"""
|
||||
faiss_index = self.cache[store_id].index
|
||||
for chunk_id in chunk_ids:
|
||||
await faiss_index.delete_chunk(chunk_id)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
|
|
@ -426,6 +425,35 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
"""Remove a chunk from the SQLite vector store."""
|
||||
|
||||
def _delete_chunk():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
|
||||
# Delete from metadata table
|
||||
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
|
||||
|
||||
# Delete from vector table
|
||||
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
|
||||
|
||||
# Delete from FTS table
|
||||
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
|
||||
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
logger.error(f"Error deleting chunk {chunk_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_delete_chunk)
|
||||
|
||||
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
"""
|
||||
|
|
@ -506,140 +534,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Save vector store file metadata to SQLite database."""
|
||||
|
||||
def _create_or_store():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
# Create a table to persist OpenAI vector store files.
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
|
||||
store_id TEXT,
|
||||
file_id TEXT,
|
||||
metadata TEXT,
|
||||
PRIMARY KEY (store_id, file_id)
|
||||
);
|
||||
""")
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS openai_vector_store_files_contents (
|
||||
store_id TEXT,
|
||||
file_id TEXT,
|
||||
contents TEXT,
|
||||
PRIMARY KEY (store_id, file_id)
|
||||
);
|
||||
""")
|
||||
connection.commit()
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)",
|
||||
(store_id, file_id, json.dumps(file_info)),
|
||||
)
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO openai_vector_store_files_contents (store_id, file_id, contents) VALUES (?, ?, ?)",
|
||||
(store_id, file_id, json.dumps(file_contents)),
|
||||
)
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_create_or_store)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
|
||||
"""Load vector store file metadata from SQLite database."""
|
||||
|
||||
def _load():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT metadata FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?",
|
||||
(store_id, file_id),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
(metadata,) = row
|
||||
return metadata
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
stored_data = await asyncio.to_thread(_load)
|
||||
return json.loads(stored_data) if stored_data else {}
|
||||
|
||||
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
||||
"""Load vector store file contents from SQLite database."""
|
||||
|
||||
def _load():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT contents FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
|
||||
(store_id, file_id),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
(contents,) = row
|
||||
return contents
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
stored_contents = await asyncio.to_thread(_load)
|
||||
return json.loads(stored_contents) if stored_contents else []
|
||||
|
||||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||
"""Update vector store file metadata in SQLite database."""
|
||||
|
||||
def _update():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"UPDATE openai_vector_store_files SET metadata = ? WHERE store_id = ? AND file_id = ?",
|
||||
(json.dumps(file_info), store_id, file_id),
|
||||
)
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_update)
|
||||
|
||||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||
"""Delete vector store file metadata from SQLite database."""
|
||||
|
||||
def _delete():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id)
|
||||
)
|
||||
cur.execute(
|
||||
"DELETE FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
|
||||
(store_id, file_id),
|
||||
)
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_delete)
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
|
|
@ -655,3 +549,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
"""Delete a chunk from a sqlite_vec index."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
|
||||
|
||||
for chunk_id in chunk_ids:
|
||||
# Use the index's delete_chunk method
|
||||
await index.index.delete_chunk(chunk_id)
|
||||
|
|
|
|||
|
|
@ -224,17 +224,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="fireworks-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.fireworks_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
|
||||
description="Fireworks AI OpenAI-compatible provider for using Fireworks models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
@ -246,50 +235,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="together-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.together_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherProviderDataValidator",
|
||||
description="Together AI OpenAI-compatible provider for using Together models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="groq-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.groq_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqProviderDataValidator",
|
||||
description="Groq OpenAI-compatible provider for using Groq models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.sambanova_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova OpenAI-compatible provider for using SambaNova models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="cerebras-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.cerebras_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasProviderDataValidator",
|
||||
description="Cerebras OpenAI-compatible provider for using Cerebras models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
|||
|
|
@ -395,7 +395,7 @@ That means you'll get fast and efficient vector retrieval.
|
|||
To use PGVector in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Install the necessary dependencies.
|
||||
2. Configure your Llama Stack project to use Faiss.
|
||||
2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
|
@ -410,6 +410,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
|
|||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="anthropic",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="anthropic_api_key",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class AnthropicConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"anthropic/claude-3-5-sonnet-latest",
|
||||
"anthropic/claude-3-7-sonnet-latest",
|
||||
"anthropic/claude-3-5-haiku-latest",
|
||||
"claude-3-5-sonnet-latest",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-5-haiku-latest",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
|
@ -21,17 +21,17 @@ MODEL_ENTRIES = (
|
|||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3",
|
||||
provider_model_id="voyage-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3-lite",
|
||||
provider_model_id="voyage-3-lite",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-code-3",
|
||||
provider_model_id="voyage-code-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
|
|
|
|||
|
|
@ -63,18 +63,20 @@ class BedrockInferenceAdapter(
|
|||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
self._config = config
|
||||
|
||||
self._client = create_bedrock_client(config)
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
if self._client is None:
|
||||
self._client = create_bedrock_client(self._config)
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
if self._client is not None:
|
||||
self._client.close()
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ class CerebrasInferenceAdapter(
|
|||
)
|
||||
self.config = config
|
||||
|
||||
# TODO: make this use provider data, etc. like other providers
|
||||
self.client = AsyncCerebras(
|
||||
base_url=self.config.base_url,
|
||||
api_key=self.config.api_key.get_secret_value(),
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class CerebrasImplConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"base_url": DEFAULT_BASE_URL,
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import CerebrasCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .cerebras import CerebrasCompatInferenceAdapter
|
||||
|
||||
adapter = CerebrasCompatInferenceAdapter(config)
|
||||
return adapter
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.cerebras_openai_compat.config import CerebrasCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..cerebras.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class CerebrasCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: CerebrasCompatConfig
|
||||
|
||||
def __init__(self, config: CerebrasCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="cerebras_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class CerebrasProviderDataValidator(BaseModel):
|
||||
cerebras_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Cerebras models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CerebrasCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Cerebras API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.cerebras.ai/v1",
|
||||
description="The URL for the Cerebras API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.cerebras.ai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -25,8 +25,8 @@ class DatabricksImplConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_URL}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN}",
|
||||
url: str = "${env.DATABRICKS_URL:=}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FireworksImplConfig(BaseModel):
|
||||
class FireworksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks server",
|
||||
|
|
@ -23,7 +24,7 @@ class FireworksImplConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference")
|
|||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import FireworksCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .fireworks import FireworksCompatInferenceAdapter
|
||||
|
||||
adapter = FireworksCompatInferenceAdapter(config)
|
||||
return adapter
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class FireworksProviderDataValidator(BaseModel):
|
||||
fireworks_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Fireworks models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FireworksCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Fireworks API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.fireworks_openai_compat.config import FireworksCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..fireworks.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class FireworksCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: FireworksCompatConfig
|
||||
|
||||
def __init__(self, config: FireworksCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="fireworks_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
|
@ -26,7 +26,7 @@ class GeminiConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="gemini",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="gemini_api_key",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,11 +10,11 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"gemini/gemini-1.5-flash",
|
||||
"gemini/gemini-1.5-pro",
|
||||
"gemini/gemini-2.0-flash",
|
||||
"gemini/gemini-2.5-flash",
|
||||
"gemini/gemini-2.5-pro",
|
||||
"gemini-1.5-flash",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-pro",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
|
@ -23,7 +23,7 @@ MODEL_ENTRIES = (
|
|||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="gemini/text-embedding-004",
|
||||
provider_model_id="text-embedding-004",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768, "context_length": 2048},
|
||||
),
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class GroqConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.groq.com",
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="groq",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
)
|
||||
|
|
@ -96,7 +97,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
tool_choice = "required"
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id.replace("groq/", ""),
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
|
|
|
|||
|
|
@ -14,19 +14,19 @@ SAFETY_MODELS_ENTRIES = []
|
|||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama3-8b-8192",
|
||||
"llama3-8b-8192",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"groq/llama-3.1-8b-instant",
|
||||
"llama-3.1-8b-instant",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama3-70b-8192",
|
||||
"llama3-70b-8192",
|
||||
CoreModelId.llama3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-3.3-70b-versatile",
|
||||
"llama-3.3-70b-versatile",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
# Groq only contains a preview version for llama-3.2-3b
|
||||
|
|
@ -34,23 +34,15 @@ MODEL_ENTRIES = [
|
|||
# to pass the test fixture
|
||||
# TODO(aidand): Replace this with a stable model once Groq supports it
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-3.2-3b-preview",
|
||||
"llama-3.2-3b-preview",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-scout-17b-16e-instruct",
|
||||
"meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import GroqCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqCompatInferenceAdapter
|
||||
|
||||
adapter = GroqCompatInferenceAdapter(config)
|
||||
return adapter
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GroqProviderDataValidator(BaseModel):
|
||||
groq_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Groq models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GroqCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Groq API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.groq.com/openai/v1",
|
||||
description="The URL for the Groq API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.groq.com/openai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.groq_openai_compat.config import GroqCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..groq.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class GroqCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: GroqCompatConfig
|
||||
|
||||
def __init__(self, config: GroqCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
|
@ -5,55 +5,53 @@
|
|||
# the root directory of this source tree.
|
||||
import logging
|
||||
|
||||
from llama_api_client import AsyncLlamaAPIClient, NotFoundError
|
||||
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
"""
|
||||
Llama API Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability().
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
|
||||
"""
|
||||
|
||||
_config: LlamaCompatConfig
|
||||
|
||||
def __init__(self, config: LlamaCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="meta_llama",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="llama_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Check if a specific model is available from Llama API.
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
:return: The Llama API base URL
|
||||
"""
|
||||
try:
|
||||
llama_api_client = self._get_llama_api_client()
|
||||
retrieved_model = await llama_api_client.models.retrieve(model)
|
||||
logger.info(f"Model {retrieved_model.id} is available from Llama API")
|
||||
return True
|
||||
|
||||
except NotFoundError:
|
||||
logger.error(f"Model {model} is not available from Llama API")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check model availability from Llama API: {e}")
|
||||
return False
|
||||
return self.config.openai_compat_api_base
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
||||
def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
|
||||
return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
|
||||
|
|
|
|||
|
|
@ -7,9 +7,8 @@
|
|||
import logging
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError
|
||||
from openai import APIConnectionError, BadRequestError
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
|
@ -28,12 +27,6 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
|
@ -47,8 +40,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
|
||||
|
||||
from . import NVIDIAConfig
|
||||
|
|
@ -64,7 +57,20 @@ from .utils import _is_nvidia_hosted
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
||||
"""
|
||||
NVIDIA Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability(). It also
|
||||
must come before Inference to ensure that OpenAIMixin methods are available
|
||||
in the Inference interface.
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
|
||||
"""
|
||||
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
# TODO(mf): filter by available models
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
|
@ -88,45 +94,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
|
||||
self._config = config
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
Check if a specific model is available.
|
||||
Get the API key for OpenAI mixin.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
:return: The NVIDIA API key
|
||||
"""
|
||||
try:
|
||||
await self._client.models.retrieve(model)
|
||||
return True
|
||||
except NotFoundError:
|
||||
logger.error(f"Model {model} is not available")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check model availability: {e}")
|
||||
return False
|
||||
return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
|
||||
|
||||
@property
|
||||
def _client(self) -> AsyncOpenAI:
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Returns an OpenAI client for the configured NVIDIA API endpoint.
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
||||
:return: An OpenAI client
|
||||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
|
||||
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
|
||||
return AsyncOpenAI(
|
||||
base_url=base_url,
|
||||
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
async def _get_provider_model_id(self, model_id: str) -> str:
|
||||
if not self.model_store:
|
||||
raise RuntimeError("Model store is not set")
|
||||
model = await self.model_store.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model {model_id} is unknown")
|
||||
return model.provider_model_id
|
||||
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
@ -160,7 +142,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
response = await self._client.completions.create(**request)
|
||||
response = await self.client.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
@ -213,7 +195,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
try:
|
||||
response = await self._client.embeddings.create(
|
||||
response = await self.client.embeddings.create(
|
||||
model=provider_model_id,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
|
|
@ -228,16 +210,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
#
|
||||
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
@ -274,7 +246,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
response = await self._client.chat.completions.create(**request)
|
||||
response = await self.client.chat.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
@ -283,112 +255,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_chat_completion_choice(response.choices[0])
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
provider_model_id = await self._get_provider_model_id(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=provider_model_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._client.completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
provider_model_id = await self._get_provider_model_id(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=provider_model_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._client.chat.completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
|
|
|||
|
|
@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
|||
|
||||
class OllamaImplConfig(BaseModel):
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
|
||||
refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
|
||||
refresh_models: bool = Field(
|
||||
default=False,
|
||||
description="Whether to refresh models periodically",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -96,14 +96,16 @@ class OllamaInferenceAdapter(
|
|||
def __init__(self, config: OllamaImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
self.config = config
|
||||
self._client = None
|
||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
|
||||
self._openai_client = None
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = AsyncClient(host=self.config.url)
|
||||
return self._client
|
||||
# ollama client attaches itself to the current event loop (sadly?)
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop not in self._clients:
|
||||
self._clients[loop] = AsyncClient(host=self.config.url)
|
||||
return self._clients[loop]
|
||||
|
||||
@property
|
||||
def openai_client(self) -> AsyncOpenAI:
|
||||
|
|
@ -119,59 +121,61 @@ class OllamaInferenceAdapter(
|
|||
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
if self.config.refresh_models:
|
||||
logger.debug("ollama starting background model refresh task")
|
||||
self._refresh_task = asyncio.create_task(self._refresh_models())
|
||||
|
||||
def cb(task):
|
||||
if task.cancelled():
|
||||
import traceback
|
||||
|
||||
logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
|
||||
elif task.exception():
|
||||
logger.error(f"ollama background refresh task died: {task.exception()}")
|
||||
else:
|
||||
logger.error("ollama background refresh task completed unexpectedly")
|
||||
|
||||
self._refresh_task.add_done_callback(cb)
|
||||
|
||||
async def _refresh_models(self) -> None:
|
||||
# Wait for model store to be available (with timeout)
|
||||
waited_time = 0
|
||||
while not self.model_store and waited_time < 60:
|
||||
await asyncio.sleep(1)
|
||||
waited_time += 1
|
||||
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set after waiting 60 seconds")
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
provider_id = self.__provider_id__
|
||||
while True:
|
||||
try:
|
||||
response = await self.client.list()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to list models: {str(e)}")
|
||||
await asyncio.sleep(self.config.refresh_models_interval)
|
||||
response = await self.client.list()
|
||||
|
||||
# always add the two embedding models which can be pulled on demand
|
||||
models = [
|
||||
Model(
|
||||
identifier="all-minilm:l6-v2",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
# add all-minilm alias
|
||||
Model(
|
||||
identifier="all-minilm",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
Model(
|
||||
identifier="nomic-embed-text",
|
||||
provider_resource_id="nomic-embed-text",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
for m in response.models:
|
||||
# kill embedding models since we don't know dimensions for them
|
||||
if "bert" in m.details.family:
|
||||
continue
|
||||
|
||||
models = []
|
||||
for m in response.models:
|
||||
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
|
||||
if model_type == ModelType.embedding:
|
||||
continue
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.model,
|
||||
provider_resource_id=m.model,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=model_type,
|
||||
)
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.model,
|
||||
provider_resource_id=m.model,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
await self.model_store.update_registered_llm_models(provider_id, models)
|
||||
logger.debug(f"ollama refreshed model list ({len(models)} models)")
|
||||
|
||||
await asyncio.sleep(self.config.refresh_models_interval)
|
||||
)
|
||||
return models
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
|
|
@ -223,12 +227,7 @@ class OllamaInferenceAdapter(
|
|||
return available_models
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if hasattr(self, "_refresh_task") and not self._refresh_task.done():
|
||||
logger.debug("ollama cancelling background refresh task")
|
||||
self._refresh_task.cancel()
|
||||
|
||||
self._client = None
|
||||
self._openai_client = None
|
||||
self._clients.clear()
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -24,9 +24,19 @@ class OpenAIConfig(BaseModel):
|
|||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="https://api.openai.com/v1",
|
||||
description="Base URL for OpenAI API",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(
|
||||
cls,
|
||||
api_key: str = "${env.OPENAI_API_KEY:=}",
|
||||
base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,11 +12,6 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
# the models w/ "openai/" prefix are the litellm specific model names.
|
||||
# they should be deprecated in favor of the canonical openai model names.
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4o-mini",
|
||||
"openai/chatgpt-4o-latest",
|
||||
"gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
|
|
@ -43,8 +38,6 @@ class EmbeddingModelInfo:
|
|||
|
||||
|
||||
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
|
||||
"openai/text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
|
||||
"openai/text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
|
||||
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
|
||||
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,23 +5,9 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI, NotFoundError
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import OpenAIConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
|
@ -30,7 +16,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
#
|
||||
# This OpenAI adapter implements Inference methods using two clients -
|
||||
# This OpenAI adapter implements Inference methods using two mixins -
|
||||
#
|
||||
# | Inference Method | Implementation Source |
|
||||
# |----------------------------|--------------------------|
|
||||
|
|
@ -39,15 +25,27 @@ logger = logging.getLogger(__name__)
|
|||
# | embedding | LiteLLMOpenAIMixin |
|
||||
# | batch_completion | LiteLLMOpenAIMixin |
|
||||
# | batch_chat_completion | LiteLLMOpenAIMixin |
|
||||
# | openai_completion | AsyncOpenAI |
|
||||
# | openai_chat_completion | AsyncOpenAI |
|
||||
# | openai_embeddings | AsyncOpenAI |
|
||||
# | openai_completion | OpenAIMixin |
|
||||
# | openai_chat_completion | OpenAIMixin |
|
||||
# | openai_embeddings | OpenAIMixin |
|
||||
#
|
||||
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
"""
|
||||
OpenAI Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability().
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenAIConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="openai",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="openai_api_key",
|
||||
)
|
||||
|
|
@ -60,191 +58,19 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
# litellm specific model names, an abstraction leak.
|
||||
self.is_openai_compat = True
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Check if a specific model is available from OpenAI.
|
||||
Get the OpenAI API base URL.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
Returns the OpenAI API base URL from the configuration.
|
||||
"""
|
||||
try:
|
||||
openai_client = self._get_openai_client()
|
||||
retrieved_model = await openai_client.models.retrieve(model)
|
||||
logger.info(f"Model {retrieved_model.id} is available from OpenAI")
|
||||
return True
|
||||
|
||||
except NotFoundError:
|
||||
logger.error(f"Model {model} is not available from OpenAI")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check model availability from OpenAI: {e}")
|
||||
return False
|
||||
return self.config.base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
return AsyncOpenAI(
|
||||
api_key=self.get_api_key(),
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
if guided_choice is not None:
|
||||
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
|
||||
if prompt_logprobs is not None:
|
||||
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
|
||||
|
||||
model_id = (await self.model_store.get_model(model)).provider_resource_id
|
||||
if model_id.startswith("openai/"):
|
||||
model_id = model_id[len("openai/") :]
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
suffix=suffix,
|
||||
)
|
||||
return await self._get_openai_client().completions.create(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_id = (await self.model_store.get_model(model)).provider_resource_id
|
||||
if model_id.startswith("openai/"):
|
||||
model_id = model_id[len("openai/") :]
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
model_id = (await self.model_store.get_model(model)).provider_resource_id
|
||||
if model_id.startswith("openai/"):
|
||||
model_id = model_id[len("openai/") :]
|
||||
|
||||
# Prepare parameters for OpenAI embeddings API
|
||||
params = {
|
||||
"model": model_id,
|
||||
"input": input,
|
||||
}
|
||||
|
||||
if encoding_format is not None:
|
||||
params["encoding_format"] = encoding_format
|
||||
if dimensions is not None:
|
||||
params["dimensions"] = dimensions
|
||||
if user is not None:
|
||||
params["user"] = user
|
||||
|
||||
# Call OpenAI embeddings API
|
||||
response = await self._get_openai_client().embeddings.create(**params)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=response.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class SambaNovaImplConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -9,49 +9,20 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
]
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.1-8B-Instruct",
|
||||
"Meta-Llama-3.1-8B-Instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.1-405B-Instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.2-1B-Instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.2-3B-Instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.3-70B-Instruct",
|
||||
"Meta-Llama-3.3-70B-Instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Llama-3.2-11B-Vision-Instruct",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Llama-3.2-90B-Vision-Instruct",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Llama-4-Scout-17B-16E-Instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
|
||||
"Llama-4-Maverick-17B-128E-Instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
|
|
@ -182,6 +182,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="sambanova",
|
||||
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
|
||||
provider_data_api_key_field="sambanova_api_key",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import SambaNovaCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .sambanova import SambaNovaCompatInferenceAdapter
|
||||
|
||||
adapter = SambaNovaCompatInferenceAdapter(config)
|
||||
return adapter
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class SambaNovaProviderDataValidator(BaseModel):
|
||||
sambanova_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for SambaNova models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SambaNovaCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The SambaNova API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.sambanova_openai_compat.config import SambaNovaCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..sambanova.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class SambaNovaCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: SambaNovaCompatConfig
|
||||
|
||||
def __init__(self, config: SambaNovaCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="sambanova_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
|
@ -19,7 +19,7 @@ class TGIImplConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.TGI_URL}",
|
||||
url: str = "${env.TGI_URL:=}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -305,6 +305,8 @@ class _HfAdapter(
|
|||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
if not config.url:
|
||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.url,
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TogetherImplConfig(BaseModel):
|
||||
class TogetherImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
|
|
@ -26,5 +27,5 @@ class TogetherImplConfig(BaseModel):
|
|||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.together.xyz/v1",
|
||||
"api_key": "${env.TOGETHER_API_KEY}",
|
||||
"api_key": "${env.TOGETHER_API_KEY:=}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,15 +69,9 @@ MODEL_ENTRIES = [
|
|||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
additional_aliases=[
|
||||
"together/meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
],
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
additional_aliases=[
|
||||
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
],
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference")
|
|||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import TogetherCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .together import TogetherCompatInferenceAdapter
|
||||
|
||||
adapter = TogetherCompatInferenceAdapter(config)
|
||||
return adapter
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class TogetherProviderDataValidator(BaseModel):
|
||||
together_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Together models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TogetherCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Together API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.TOGETHER_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.together.xyz/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.together_openai_compat.config import TogetherCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..together.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class TogetherCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: TogetherCompatConfig
|
||||
|
||||
def __init__(self, config: TogetherCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="together_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
|
@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
|||
default=False,
|
||||
description="Whether to refresh models periodically",
|
||||
)
|
||||
refresh_models_interval: int = Field(
|
||||
default=300,
|
||||
description="Interval in seconds to refresh models",
|
||||
)
|
||||
|
||||
@field_validator("tls_verify")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
|
|
@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
model_store: ModelStore | None = None
|
||||
_refresh_task: asyncio.Task | None = None
|
||||
|
||||
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
|
|
@ -302,64 +300,32 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def initialize(self) -> None:
|
||||
if not self.config.url:
|
||||
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
|
||||
# or available in distributions like "starter" without causing a ruckus
|
||||
return
|
||||
raise ValueError(
|
||||
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
|
||||
)
|
||||
|
||||
if self.config.refresh_models:
|
||||
self._refresh_task = asyncio.create_task(self._refresh_models())
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
||||
if task.cancelled():
|
||||
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
|
||||
elif task.exception():
|
||||
# print the stack trace for the exception
|
||||
exc = task.exception()
|
||||
log.error(f"vLLM background refresh task died: {exc}")
|
||||
traceback.print_exception(exc)
|
||||
else:
|
||||
log.error("vLLM background refresh task completed unexpectedly")
|
||||
|
||||
self._refresh_task.add_done_callback(cb)
|
||||
|
||||
async def _refresh_models(self) -> None:
|
||||
provider_id = self.__provider_id__
|
||||
waited_time = 0
|
||||
while not self.model_store and waited_time < 60:
|
||||
await asyncio.sleep(1)
|
||||
waited_time += 1
|
||||
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set after waiting 60 seconds")
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None # mypy
|
||||
while True:
|
||||
try:
|
||||
models = []
|
||||
async for m in self.client.models.list():
|
||||
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.id,
|
||||
provider_resource_id=m.id,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
await self.model_store.update_registered_llm_models(provider_id, models)
|
||||
log.debug(f"vLLM refreshed model list ({len(models)} models)")
|
||||
except Exception as e:
|
||||
log.error(f"vLLM background refresh task failed: {e}")
|
||||
await asyncio.sleep(self.config.refresh_models_interval)
|
||||
models = []
|
||||
async for m in self.client.models.list():
|
||||
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.id,
|
||||
provider_resource_id=m.id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={},
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self._refresh_task:
|
||||
self._refresh_task.cancel()
|
||||
self._refresh_task = None
|
||||
pass
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
|
@ -374,9 +340,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
if not self.config.url:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
|
||||
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
_ = [m async for m in client.models.list()] # Ensure the client is initialized
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
|
|
@ -392,11 +355,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
if self.client is not None:
|
||||
return
|
||||
|
||||
if not self.config.url:
|
||||
raise ValueError(
|
||||
"You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
|
||||
)
|
||||
|
||||
log.info(f"Initializing vLLM client with base_url={self.config.url}")
|
||||
self.client = self._create_client()
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class SambaNovaSafetyConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue