Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-24 04:08:04 +00:00)

Merge branch 'main' into feature/dpo-training
Commit b68b818539: 265 changed files with 10254 additions and 7796 deletions
@ -4,15 +4,83 @@
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from enum import Enum, EnumMeta
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class DynamicApiMeta(EnumMeta):
|
||||
def __new__(cls, name, bases, namespace):
|
||||
# Store the original enum values
|
||||
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
|
||||
|
||||
# Create the enum class
|
||||
cls = super().__new__(cls, name, bases, namespace)
|
||||
|
||||
# Store the original values for reference
|
||||
cls._original_values = original_values
|
||||
# Initialize _dynamic_values
|
||||
cls._dynamic_values = {}
|
||||
|
||||
return cls
|
||||
|
||||
def __call__(cls, value):
|
||||
try:
|
||||
return super().__call__(value)
|
||||
except ValueError as e:
|
||||
# If this value was already dynamically added, return it
|
||||
if value in cls._dynamic_values:
|
||||
return cls._dynamic_values[value]
|
||||
|
||||
# If the value doesn't exist, create a new enum member
|
||||
# Create a new member name from the value
|
||||
member_name = value.lower().replace("-", "_")
|
||||
|
||||
# If this member name already exists in the enum, return the existing member
|
||||
if member_name in cls._member_map_:
|
||||
return cls._member_map_[member_name]
|
||||
|
||||
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
|
||||
# register new APIs explicitly
|
||||
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
|
||||
|
||||
def __iter__(cls):
|
||||
# Allow iteration over both static and dynamic members
|
||||
yield from super().__iter__()
|
||||
if hasattr(cls, "_dynamic_values"):
|
||||
yield from cls._dynamic_values.values()
|
||||
|
||||
def add(cls, value):
|
||||
"""
|
||||
Add a new API to the enum.
|
||||
Used to register external APIs.
|
||||
"""
|
||||
member_name = value.lower().replace("-", "_")
|
||||
|
||||
# If this member name already exists in the enum, return it
|
||||
if member_name in cls._member_map_:
|
||||
return cls._member_map_[member_name]
|
||||
|
||||
# Create a new enum member
|
||||
member = object.__new__(cls)
|
||||
member._name_ = member_name
|
||||
member._value_ = value
|
||||
|
||||
# Add it to the enum class
|
||||
cls._member_map_[member_name] = member
|
||||
cls._member_names_.append(member_name)
|
||||
cls._member_type_ = str
|
||||
|
||||
# Store it in our dynamic values
|
||||
cls._dynamic_values[value] = member
|
||||
|
||||
return member
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Api(Enum):
|
||||
class Api(Enum, metaclass=DynamicApiMeta):
|
||||
providers = "providers"
|
||||
inference = "inference"
|
||||
safety = "safety"
|
||||
|
|
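The metaclass above keeps Api resolvable for names registered at runtime while still rejecting unknown strings. A minimal usage sketch (the import path and the "weather" API name are illustrative assumptions):

from llama_stack.apis.datatypes import Api  # import path assumed from context

weather_api = Api.add("weather")        # registers and returns a new dynamic member
assert Api("weather") is weather_api    # lookups of registered names now succeed
assert weather_api.value == "weather"
# Unregistered names still fail loudly, which is the point of requiring Api.add():
# Api("unknown-api") raises ValueError("API 'unknown-api' does not exist. Use Api.add() ...")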
@ -54,3 +122,12 @@ class Error(BaseModel):
|
|||
title: str
|
||||
detail: str
|
||||
instance: str | None = None
|
||||
|
||||
|
||||
class ExternalApiSpec(BaseModel):
|
||||
"""Specification for an external API implementation."""
|
||||
|
||||
module: str = Field(..., description="Python module containing the API implementation")
|
||||
name: str = Field(..., description="Name of the API")
|
||||
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
|
||||
protocol: str = Field(..., description="Name of the protocol class for the API")
|
||||
|
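For reference, an external API spec built from this model might look like the following sketch; the module, package, and protocol names are hypothetical:

from llama_stack.apis.datatypes import ExternalApiSpec  # import path assumed

spec = ExternalApiSpec(
    module="weather_stack.api",      # Python module that ships the API implementation
    name="weather",
    pip_packages=["weather-stack"],  # installed before the stack is built
    protocol="WeatherAPI",           # protocol class exported by the module
)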
|
|
|||
|
|
@ -455,8 +455,21 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):
|
|||
image_url: OpenAIImageURL
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIFileFile(BaseModel):
|
||||
file_data: str | None = None
|
||||
file_id: str | None = None
|
||||
filename: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIFile(BaseModel):
|
||||
type: Literal["file"] = "file"
|
||||
file: OpenAIFileFile
|
||||
|
||||
|
||||
OpenAIChatCompletionContentPartParam = Annotated[
|
||||
OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
|
||||
|
|
@ -464,6 +477,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion
|
|||
|
||||
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
|
||||
|
||||
OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIUserMessageParam(BaseModel):
|
||||
|
|
@ -489,7 +504,7 @@ class OpenAISystemMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["system"] = "system"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent
|
||||
name: str | None = None
|
||||
|
||||
|
||||
|
|
@ -518,7 +533,7 @@ class OpenAIAssistantMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: OpenAIChatCompletionMessageContent | None = None
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent | None = None
|
||||
name: str | None = None
|
||||
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
|
||||
|
||||
|
|
@ -534,7 +549,7 @@ class OpenAIToolMessageParam(BaseModel):
|
|||
|
||||
role: Literal["tool"] = "tool"
|
||||
tool_call_id: str
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -547,7 +562,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["developer"] = "developer"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent
|
||||
name: str | None = None
|
||||
|
||||
|
||||
|
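Taken together, these hunks add a file content part and narrow system, assistant, tool, and developer messages to text-only content; only user messages keep the full text/image/file union. A hedged construction sketch (the import path, the field name on the text part, and the literal type defaults are assumptions):

from llama_stack.apis.inference import (  # import path assumed
    OpenAIChatCompletionContentPartTextParam,
    OpenAIFile,
    OpenAIFileFile,
    OpenAISystemMessageParam,
    OpenAIUserMessageParam,
)

file_part = OpenAIFile(file=OpenAIFileFile(filename="report.pdf", file_data="<base64 bytes>"))
text_part = OpenAIChatCompletionContentPartTextParam(text="Summarize the attached report.")
user_msg = OpenAIUserMessageParam(content=[text_part, file_part])            # files allowed here
system_msg = OpenAISystemMessageParam(content="You are a terse assistant.")  # text-only from now on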
|
|
|||
|
|
@ -22,6 +22,8 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
|||
# Add this constant near the top of the file, after the imports
|
||||
DEFAULT_TTL_DAYS = 7
|
||||
|
||||
REQUIRED_SCOPE = "telemetry.read"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStatus(Enum):
|
||||
|
|
@ -259,7 +261,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces", method="POST")
|
||||
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
|
|
@ -277,7 +279,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE)
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
"""Get a trace by its ID.
|
||||
|
||||
|
|
@ -286,7 +288,9 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
|
||||
@webmethod(
|
||||
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE
|
||||
)
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
||||
"""Get a span by its ID.
|
||||
|
||||
|
|
@ -296,7 +300,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
|
|
@ -312,7 +316,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans", method="POST")
|
||||
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
async def query_spans(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition],
|
||||
|
|
@ -345,7 +349,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
|
||||
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
async def query_metrics(
|
||||
self,
|
||||
metric_name: str,
|
||||
|
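The change across these hunks is uniform: every read-style telemetry route now declares required_scope=REQUIRED_SCOPE, so the auth layer can restrict it to callers holding the telemetry.read scope. Condensed, the pattern is (decorator arguments beyond those shown in the diff are unchanged):

REQUIRED_SCOPE = "telemetry.read"

@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE)
async def get_trace(self, trace_id: str) -> Trace: ...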
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Annotated, Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
|
@ -88,7 +88,7 @@ class RAGQueryGenerator(Enum):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class RAGSearchMode(Enum):
|
||||
class RAGSearchMode(StrEnum):
|
||||
"""
|
||||
Search modes for RAG query retrieval:
|
||||
- VECTOR: Uses vector similarity search for semantic matching
|
||||
|
|
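Switching RAGSearchMode from Enum to StrEnum makes each member a str, so it compares equal to and serializes as its plain string value, and callers can pass the raw mode string. Sketch, assuming the member value is the lowercase string "vector":

assert RAGSearchMode.VECTOR == "vector"                 # StrEnum members are str instances
assert RAGSearchMode("vector") is RAGSearchMode.VECTOR
assert f"mode={RAGSearchMode.VECTOR}" == "mode=vector"  # str(member) is the value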
|
|||
|
|
@ -34,6 +34,7 @@ class VectorDBInput(BaseModel):
|
|||
vector_db_id: str
|
||||
embedding_model: str
|
||||
embedding_dimension: int
|
||||
provider_id: str | None = None
|
||||
provider_vector_db_id: str | None = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -338,7 +338,7 @@ class VectorIO(Protocol):
|
|||
@webmethod(route="/openai/v1/vector_stores", method="POST")
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name: str,
|
||||
name: str | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
chunking_strategy: dict[str, Any] | None = None,
|
||||
|
|
|
|||
|
|
@ -31,11 +31,13 @@ from llama_stack.distribution.build import (
|
|||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BuildConfig,
|
||||
BuildProvider,
|
||||
DistributionSpec,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import replace_env_vars
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
|
||||
|
|
@ -93,7 +95,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
elif args.providers:
|
||||
providers_list: dict[str, str | list[str]] = dict()
|
||||
provider_list: dict[str, list[BuildProvider]] = dict()
|
||||
for api_provider in args.providers.split(","):
|
||||
if "=" not in api_provider:
|
||||
cprint(
|
||||
|
|
@ -102,7 +104,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
api, provider = api_provider.split("=")
|
||||
api, provider_type = api_provider.split("=")
|
||||
providers_for_api = get_provider_registry().get(Api(api), None)
|
||||
if providers_for_api is None:
|
||||
cprint(
|
||||
|
|
@ -111,16 +113,12 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
if provider in providers_for_api:
|
||||
if api not in providers_list:
|
||||
providers_list[api] = []
|
||||
# Use type guarding to ensure we have a list
|
||||
provider_value = providers_list[api]
|
||||
if isinstance(provider_value, list):
|
||||
provider_value.append(provider)
|
||||
else:
|
||||
# Convert string to list and append
|
||||
providers_list[api] = [provider_value, provider]
|
||||
if provider_type in providers_for_api:
|
||||
provider = BuildProvider(
|
||||
provider_type=provider_type,
|
||||
module=None,
|
||||
)
|
||||
provider_list.setdefault(api, []).append(provider)
|
||||
else:
|
||||
cprint(
|
||||
f"{provider} is not a valid provider for the {api} API.",
|
||||
|
|
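For example, a build invoked with --providers inference=remote::ollama,safety=inline::llama-guard would, per the loop above, produce roughly the following mapping (the provider types are illustrative):

provider_list = {
    "inference": [BuildProvider(provider_type="remote::ollama", module=None)],
    "safety": [BuildProvider(provider_type="inline::llama-guard", module=None)],
}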
@ -129,7 +127,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
distribution_spec = DistributionSpec(
|
||||
providers=providers_list,
|
||||
providers=provider_list,
|
||||
description=",".join(args.providers),
|
||||
)
|
||||
if not args.image_type:
|
||||
|
|
@ -190,7 +188,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
|
||||
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
|
||||
|
||||
providers: dict[str, str | list[str]] = dict()
|
||||
providers: dict[str, list[BuildProvider]] = dict()
|
||||
for api, providers_for_api in get_provider_registry().items():
|
||||
available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
|
||||
if not available_providers:
|
||||
|
|
@ -205,7 +203,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
),
|
||||
)
|
||||
|
||||
providers[api.value] = api_provider
|
||||
string_providers = api_provider.split(" ")
|
||||
|
||||
for provider in string_providers:
|
||||
providers.setdefault(api.value, []).append(BuildProvider(provider_type=provider))
|
||||
|
||||
description = prompt(
|
||||
"\n > (Optional) Enter a short description for your Llama Stack: ",
|
||||
|
|
@ -236,11 +237,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
|
||||
if args.print_deps_only:
|
||||
print(f"# Dependencies for {args.template or args.config or image_name}")
|
||||
normal_deps, special_deps = get_provider_dependencies(build_config)
|
||||
normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
print(f"uv pip install {' '.join(normal_deps)}")
|
||||
for special_dep in special_deps:
|
||||
print(f"uv pip install {special_dep}")
|
||||
for external_dep in external_provider_dependencies:
|
||||
print(f"uv pip install {external_dep}")
|
||||
return
|
||||
|
||||
try:
|
||||
|
|
@ -276,8 +279,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if config.external_providers_dir and not config.external_providers_dir.exists():
|
||||
config.external_providers_dir.mkdir(exist_ok=True)
|
||||
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
|
||||
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
|
||||
run_args = formulate_run_args(args.image_type, args.image_name)
|
||||
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)])
|
||||
run_command(run_args)
|
||||
|
||||
|
||||
|
|
@ -303,27 +306,25 @@ def _generate_run_config(
|
|||
provider_registry = get_provider_registry(build_config)
|
||||
for api in apis:
|
||||
run_config.providers[api] = []
|
||||
provider_types = build_config.distribution_spec.providers[api]
|
||||
if isinstance(provider_types, str):
|
||||
provider_types = [provider_types]
|
||||
providers = build_config.distribution_spec.providers[api]
|
||||
|
||||
for i, provider_type in enumerate(provider_types):
|
||||
pid = provider_type.split("::")[-1]
|
||||
for provider in providers:
|
||||
pid = provider.provider_type.split("::")[-1]
|
||||
|
||||
p = provider_registry[Api(api)][provider_type]
|
||||
p = provider_registry[Api(api)][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
|
||||
try:
|
||||
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
|
||||
except ModuleNotFoundError:
|
||||
config_type = instantiate_class_type(provider_registry[Api(api)][provider.provider_type].config_class)
|
||||
except (ModuleNotFoundError, ValueError) as exc:
|
||||
# HACK ALERT:
|
||||
# This code executes after the build is done; the import cannot work because the
# package is only available inside the venv or container, not on the host.
|
||||
# TODO: use a "is_external" flag in ProviderSpec to check if the provider is
|
||||
# external
|
||||
cprint(
|
||||
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
|
||||
f"Failed to import provider {provider.provider_type} for API {api} - assuming it's external, skipping: {exc}",
|
||||
color="yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
|
@ -336,9 +337,10 @@ def _generate_run_config(
|
|||
config = {}
|
||||
|
||||
p_spec = Provider(
|
||||
provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid,
|
||||
provider_type=provider_type,
|
||||
provider_id=pid,
|
||||
provider_type=provider.provider_type,
|
||||
config=config,
|
||||
module=provider.module,
|
||||
)
|
||||
run_config.providers[api].append(p_spec)
|
||||
|
||||
|
|
@ -401,9 +403,32 @@ def _run_stack_build_command_from_build_config(
|
|||
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
||||
|
||||
with open(build_file_path, "w") as f:
|
||||
to_write = json.loads(build_config.model_dump_json())
|
||||
to_write = json.loads(build_config.model_dump_json(exclude_none=True))
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
# We first install the external APIs so that the build process can use them and discover the
|
||||
# providers dependencies
|
||||
if build_config.external_apis_dir:
|
||||
cprint("Installing external APIs", color="yellow", file=sys.stderr)
|
||||
external_apis = load_external_apis(build_config)
|
||||
if external_apis:
|
||||
# install the external APIs
|
||||
packages = []
|
||||
for _, api_spec in external_apis.items():
|
||||
if api_spec.pip_packages:
|
||||
packages.extend(api_spec.pip_packages)
|
||||
cprint(
|
||||
f"Installing {api_spec.name} with pip packages {api_spec.pip_packages}",
|
||||
color="yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return_code = run_command(["uv", "pip", "install", *packages])
|
||||
if return_code != 0:
|
||||
packages_str = ", ".join(packages)
|
||||
raise RuntimeError(
|
||||
f"Failed to install external APIs packages: {packages_str} (return code: {return_code})"
|
||||
)
|
||||
|
||||
return_code = build_image(
|
||||
build_config,
|
||||
build_file_path,
|
||||
|
|
|
|||
|
|
@ -82,39 +82,6 @@ class StackRun(Subcommand):
|
|||
return ImageType.CONDA.value, args.image_name
|
||||
return args.image_type, args.image_name
|
||||
|
||||
def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
|
||||
"""Resolve config file path and template name from args.config"""
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
if not args.config:
|
||||
return None, None
|
||||
|
||||
config_file = Path(args.config)
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
template_name = None
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if this is a template
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
|
||||
if config_file.exists():
|
||||
template_name = args.config
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to ~/.llama dir
|
||||
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists():
|
||||
self.parser.error(
|
||||
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
||||
)
|
||||
|
||||
if not config_file.is_file():
|
||||
self.parser.error(
|
||||
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
|
||||
)
|
||||
|
||||
return config_file, template_name
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
||||
|
|
@ -125,8 +92,15 @@ class StackRun(Subcommand):
|
|||
self._start_ui_development_server(args.port)
|
||||
image_type, image_name = self._get_image_type_and_name(args)
|
||||
|
||||
# Resolve config file and template name first
|
||||
config_file, template_name = self._resolve_config_and_template(args)
|
||||
if args.config:
|
||||
try:
|
||||
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
|
||||
|
||||
config_file = resolve_config_or_template(args.config, Mode.RUN)
|
||||
except ValueError as e:
|
||||
self.parser.error(str(e))
|
||||
else:
|
||||
config_file = None
|
||||
|
||||
# Check if config is required based on image type
|
||||
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file:
|
||||
|
|
@ -164,18 +138,14 @@ class StackRun(Subcommand):
|
|||
if callable(getattr(args, arg)):
|
||||
continue
|
||||
if arg == "config":
|
||||
if template_name:
|
||||
server_args.template = str(template_name)
|
||||
else:
|
||||
# Set the config file path
|
||||
server_args.config = str(config_file)
|
||||
server_args.config = str(config_file)
|
||||
else:
|
||||
setattr(server_args, arg, getattr(args, arg))
|
||||
|
||||
# Run the server
|
||||
server_main(server_args)
|
||||
else:
|
||||
run_args = formulate_run_args(image_type, image_name, config, template_name)
|
||||
run_args = formulate_run_args(image_type, image_name)
|
||||
|
||||
run_args.extend([str(args.port)])
|
||||
|
||||
|
|
llama_stack/cli/utils.py (new file, 48 lines)
@ -0,0 +1,48 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="cli")
|
||||
|
||||
|
||||
def add_config_template_args(parser: argparse.ArgumentParser):
|
||||
"""Add unified config/template arguments with backward compatibility."""
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
|
||||
group.add_argument(
|
||||
"config",
|
||||
nargs="?",
|
||||
help="Configuration file path or template name",
|
||||
)
|
||||
|
||||
# Backward compatibility arguments (deprecated)
|
||||
group.add_argument(
|
||||
"--config",
|
||||
dest="config_deprecated",
|
||||
help="(DEPRECATED) Use positional argument [config] instead. Configuration file path",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--template",
|
||||
dest="template_deprecated",
|
||||
help="(DEPRECATED) Use positional argument [config] instead. Template name",
|
||||
)
|
||||
|
||||
|
||||
def get_config_from_args(args: argparse.Namespace) -> str | None:
|
||||
"""Extract config value from parsed arguments, handling both new and deprecated forms."""
|
||||
if args.config is not None:
|
||||
return str(args.config)
|
||||
elif hasattr(args, "config_deprecated") and args.config_deprecated is not None:
|
||||
logger.warning("Using deprecated --config argument. Use positional argument [config] instead.")
|
||||
return str(args.config_deprecated)
|
||||
elif hasattr(args, "template_deprecated") and args.template_deprecated is not None:
|
||||
logger.warning("Using deprecated --template argument. Use positional argument [config] instead.")
|
||||
return str(args.template_deprecated)
|
||||
return None
|
||||
|
|
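A hedged sketch of how a subcommand wires these helpers together (the template name is illustrative):

import argparse

from llama_stack.cli.utils import add_config_template_args, get_config_from_args

parser = argparse.ArgumentParser(prog="llama stack run")
add_config_template_args(parser)

args = parser.parse_args(["starter"])                 # positional form: template name or config path
print(get_config_from_args(args))                     # -> "starter"

args = parser.parse_args(["--template", "starter"])   # deprecated form still works, with a warning
print(get_config_from_args(args))                     # -> "starter"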
@ -14,6 +14,7 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.distribution.datatypes import BuildConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.utils.exec import run_command
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
|
@ -41,7 +42,7 @@ class ApiInput(BaseModel):
|
|||
|
||||
def get_provider_dependencies(
|
||||
config: BuildConfig | DistributionTemplate,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
) -> tuple[list[str], list[str], list[str]]:
|
||||
"""Get normal and special dependencies from provider configuration."""
|
||||
if isinstance(config, DistributionTemplate):
|
||||
config = config.build_config()
|
||||
|
|
@ -50,6 +51,7 @@ def get_provider_dependencies(
|
|||
additional_pip_packages = config.additional_pip_packages
|
||||
|
||||
deps = []
|
||||
external_provider_deps = []
|
||||
registry = get_provider_registry(config)
|
||||
for api_str, provider_or_providers in providers.items():
|
||||
providers_for_api = registry[Api(api_str)]
|
||||
|
|
@ -64,8 +66,16 @@ def get_provider_dependencies(
|
|||
raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")
|
||||
|
||||
provider_spec = providers_for_api[provider_type]
|
||||
deps.extend(provider_spec.pip_packages)
|
||||
if provider_spec.container_image:
|
||||
if hasattr(provider_spec, "is_external") and provider_spec.is_external:
|
||||
# this ensures we install the top level module for our external providers
|
||||
if provider_spec.module:
|
||||
if isinstance(provider_spec.module, str):
|
||||
external_provider_deps.append(provider_spec.module)
|
||||
else:
|
||||
external_provider_deps.extend(provider_spec.module)
|
||||
if hasattr(provider_spec, "pip_packages"):
|
||||
deps.extend(provider_spec.pip_packages)
|
||||
if hasattr(provider_spec, "container_image") and provider_spec.container_image:
|
||||
raise ValueError("A stack's dependencies cannot have a container image")
|
||||
|
||||
normal_deps = []
|
||||
|
|
@ -78,7 +88,7 @@ def get_provider_dependencies(
|
|||
|
||||
normal_deps.extend(additional_pip_packages or [])
|
||||
|
||||
return list(set(normal_deps)), list(set(special_deps))
|
||||
return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps))
|
||||
|
||||
|
||||
def print_pip_install_help(config: BuildConfig):
|
||||
|
|
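get_provider_dependencies now returns three lists instead of two; the third carries the top-level modules of external providers, which the build scripts install first so that each module's own pip_packages can be resolved. Call sites elsewhere in this diff unpack it as:

normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES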
@ -103,41 +113,59 @@ def build_image(
|
|||
):
|
||||
container_base = build_config.distribution_spec.container_image or "python:3.12-slim"
|
||||
|
||||
normal_deps, special_deps = get_provider_dependencies(build_config)
|
||||
normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
if build_config.external_apis_dir:
|
||||
external_apis = load_external_apis(build_config)
|
||||
if external_apis:
|
||||
for _, api_spec in external_apis.items():
|
||||
normal_deps.extend(api_spec.pip_packages)
|
||||
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
|
||||
args = [
|
||||
script,
|
||||
"--template-or-config",
|
||||
template_or_config,
|
||||
"--image-name",
|
||||
image_name,
|
||||
"--container-base",
|
||||
container_base,
|
||||
"--normal-deps",
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
|
||||
# When building from a config file (not a template), include the run config path in the
|
||||
# build arguments
|
||||
if run_config is not None:
|
||||
args.append(run_config)
|
||||
args.extend(["--run-config", run_config])
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
|
||||
args = [
|
||||
script,
|
||||
"--env-name",
|
||||
str(image_name),
|
||||
"--build-file-path",
|
||||
str(build_file_path),
|
||||
"--normal-deps",
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
|
||||
args = [
|
||||
script,
|
||||
"--env-name",
|
||||
str(image_name),
|
||||
"--normal-deps",
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
|
||||
# Pass the optional and external-provider dependency lists as flags when they are non-empty
|
||||
if special_deps:
|
||||
args.append("#".join(special_deps))
|
||||
args.extend(["--optional-deps", "#".join(special_deps)])
|
||||
if external_provider_deps:
|
||||
args.extend(
|
||||
["--external-provider-deps", "#".join(external_provider_deps)]
|
||||
) # the script will install the external provider module, get its deps, and install those too.
|
||||
|
||||
return_code = run_command(args)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,10 +9,91 @@
|
|||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
PYPI_VERSION=${PYPI_VERSION:-}
|
||||
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
# Usage function
|
||||
usage() {
|
||||
echo "Usage: $0 --env-name <conda_env_name> --build-file-path <build_file_path> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
|
||||
echo "Example: $0 --env-name my-conda-env --build-file-path ./my-stack-build.yaml --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
env_name=""
|
||||
build_file_path=""
|
||||
normal_deps=""
|
||||
external_provider_deps=""
|
||||
optional_deps=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
case "$key" in
|
||||
--env-name)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --env-name requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
env_name="$2"
|
||||
shift 2
|
||||
;;
|
||||
--build-file-path)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --build-file-path requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
build_file_path="$2"
|
||||
shift 2
|
||||
;;
|
||||
--normal-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --normal-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
normal_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--external-provider-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --external-provider-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
external_provider_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--optional-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --optional-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
optional_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check required arguments
|
||||
if [[ -z "$env_name" || -z "$build_file_path" || -z "$normal_deps" ]]; then
|
||||
echo "Error: --env-name, --build-file-path, and --normal-deps are required." >&2
|
||||
usage
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||
fi
|
||||
|
|
@ -20,50 +101,18 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
|||
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
if [ "$#" -lt 3 ]; then
|
||||
echo "Usage: $0 <distribution_type> <conda_env_name> <build_file_path> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Example: $0 <distribution_type> my-conda-env ./my-stack-build.yaml 'numpy pandas scipy'" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
special_pip_deps="$4"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
env_name="$1"
|
||||
build_file_path="$2"
|
||||
pip_dependencies="$3"
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# this is set if we actually create a new conda in which case we need to clean up
|
||||
ENVNAME=""
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
ensure_conda_env_python310() {
|
||||
local env_name="$1"
|
||||
local pip_dependencies="$2"
|
||||
local special_pip_deps="$3"
|
||||
# Use only global variables set by flag parser
|
||||
local python_version="3.12"
|
||||
|
||||
# Check if conda command is available
|
||||
if ! is_command_available conda; then
|
||||
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if the environment exists
|
||||
if conda env list | grep -q "^${env_name} "; then
|
||||
printf "Conda environment '${env_name}' exists. Checking Python version...\n"
|
||||
|
||||
# Check Python version in the environment
|
||||
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
|
||||
|
||||
if [ "$current_version" = "$python_version" ]; then
|
||||
printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
|
||||
else
|
||||
|
|
@ -73,37 +122,37 @@ ensure_conda_env_python310() {
|
|||
else
|
||||
printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
|
||||
conda create -n "${env_name}" python="${python_version}" -y
|
||||
|
||||
ENVNAME="${env_name}"
|
||||
# setup_cleanup_handlers
|
||||
fi
|
||||
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda deactivate && conda activate "${env_name}"
|
||||
|
||||
"$CONDA_PREFIX"/bin/pip install uv
|
||||
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
uv pip install fastapi libcst
|
||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
llama-stack=="$TEST_PYPI_VERSION" \
|
||||
"$pip_dependencies"
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
"$normal_deps"
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install "$part"
|
||||
done
|
||||
fi
|
||||
else
|
||||
# Re-installing llama-stack in the new conda environment
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
|
||||
else
|
||||
|
|
@ -115,31 +164,44 @@ ensure_conda_env_python310() {
|
|||
fi
|
||||
uv pip install --no-cache-dir "$SPEC_VERSION"
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
# Install pip dependencies
|
||||
printf "Installing pip dependencies\n"
|
||||
uv pip install $pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
uv pip install $normal_deps
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "Getting provider spec for module: $part and installing dependencies"
|
||||
package_name=$(echo "$part" | sed 's/[<>=!].*//')
|
||||
python3 -c "
|
||||
import importlib
|
||||
import sys
|
||||
try:
|
||||
module = importlib.import_module(f'$package_name.provider')
|
||||
spec = module.get_provider_spec()
|
||||
if hasattr(spec, 'pip_packages') and spec.pip_packages:
|
||||
print('\\n'.join(spec.pip_packages))
|
||||
except Exception as e:
|
||||
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
|
||||
" | uv pip install -r -
|
||||
done
|
||||
fi
|
||||
fi
|
||||
|
||||
mv "$build_file_path" "$CONDA_PREFIX"/llamastack-build.yaml
|
||||
echo "Build spec configuration saved at $CONDA_PREFIX/llamastack-build.yaml"
|
||||
}
|
||||
|
||||
ensure_conda_env_python310 "$env_name" "$pip_dependencies" "$special_pip_deps"
|
||||
ensure_conda_env_python310 "$env_name" "$build_file_path" "$normal_deps" "$optional_deps" "$external_provider_deps"
|
||||
|
|
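The heredoc above imports <package>.provider, calls get_provider_spec(), and pip-installs whatever the returned spec lists under pip_packages. The contract an external provider package is expected to satisfy therefore looks roughly like this sketch (package layout and names are hypothetical):

# my_provider/provider.py  (hypothetical external provider package)
class _Spec:
    # stand-in for the real ProviderSpec object a provider would return
    pip_packages = ["httpx>=0.27"]  # runtime dependencies the build should install

def get_provider_spec():
    return _Spec()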
|
|||
|
|
@ -18,58 +18,108 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
|||
|
||||
# mounting is not supported by docker buildx, so we use COPY instead
|
||||
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
|
||||
|
||||
# Path to the run.yaml file in the container
|
||||
RUN_CONFIG_PATH=/app/run.yaml
|
||||
|
||||
BUILD_CONTEXT_DIR=$(pwd)
|
||||
|
||||
if [ "$#" -lt 4 ]; then
|
||||
# This only works for templates
|
||||
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<run_config>] [<special_pip_deps>]" >&2
|
||||
exit 1
|
||||
fi
|
||||
set -euo pipefail
|
||||
|
||||
template_or_config="$1"
|
||||
shift
|
||||
image_name="$1"
|
||||
shift
|
||||
container_base="$1"
|
||||
shift
|
||||
pip_dependencies="$1"
|
||||
shift
|
||||
|
||||
# Handle optional arguments
|
||||
run_config=""
|
||||
special_pip_deps=""
|
||||
|
||||
# Check if there are more arguments
|
||||
# The logic is becoming cumbersome; we should refactor it if we can do better
|
||||
if [ $# -gt 0 ]; then
|
||||
# Check if the argument ends with .yaml
|
||||
if [[ "$1" == *.yaml ]]; then
|
||||
run_config="$1"
|
||||
shift
|
||||
# If there's another argument after .yaml, it must be special_pip_deps
|
||||
if [ $# -gt 0 ]; then
|
||||
special_pip_deps="$1"
|
||||
fi
|
||||
else
|
||||
# If it's not .yaml, it must be special_pip_deps
|
||||
special_pip_deps="$1"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Usage function
|
||||
usage() {
|
||||
echo "Usage: $0 --image-name <image_name> --container-base <container_base> --normal-deps <pip_dependencies> [--run-config <run_config>] [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
|
||||
echo "Example: $0 --image-name llama-stack-img --container-base python:3.12-slim --normal-deps 'numpy pandas' --run-config ./run.yaml --external-provider-deps 'foo' --optional-deps 'bar'"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
image_name=""
|
||||
container_base=""
|
||||
normal_deps=""
|
||||
external_provider_deps=""
|
||||
optional_deps=""
|
||||
run_config=""
|
||||
template_or_config=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
case "$key" in
|
||||
--image-name)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --image-name requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
image_name="$2"
|
||||
shift 2
|
||||
;;
|
||||
--container-base)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --container-base requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
container_base="$2"
|
||||
shift 2
|
||||
;;
|
||||
--normal-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --normal-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
normal_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--external-provider-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --external-provider-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
external_provider_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--optional-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --optional-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
optional_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--run-config)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --run-config requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
run_config="$2"
|
||||
shift 2
|
||||
;;
|
||||
--template-or-config)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --template-or-config requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
template_or_config="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check required arguments
|
||||
if [[ -z "$image_name" || -z "$container_base" || -z "$normal_deps" ]]; then
|
||||
echo "Error: --image-name, --container-base, and --normal-deps are required." >&2
|
||||
usage
|
||||
fi
|
||||
|
||||
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
|
||||
CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}
|
||||
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
|
|
@ -78,18 +128,15 @@ add_to_container() {
|
|||
if [ -t 0 ]; then
|
||||
printf '%s\n' "$1" >>"$output_file"
|
||||
else
|
||||
# If stdin is not a terminal, read from it (heredoc)
|
||||
cat >>"$output_file"
|
||||
fi
|
||||
}
|
||||
|
||||
# Check if container command is available
|
||||
if ! is_command_available "$CONTAINER_BINARY"; then
|
||||
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Update and install UBI9 components if UBI9 base image is used
|
||||
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
|
||||
add_to_container << EOF
|
||||
FROM $container_base
|
||||
|
|
@ -127,22 +174,52 @@ fi
|
|||
|
||||
# Add pip dependencies first since llama-stack is what will change most often
|
||||
# so we can reuse layers.
|
||||
if [ -n "$pip_dependencies" ]; then
|
||||
if [ -n "$normal_deps" ]; then
|
||||
read -ra pip_args <<< "$normal_deps"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache $pip_dependencies
|
||||
RUN uv pip install --no-cache $quoted_deps
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
read -ra pip_args <<< "$part"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container <<EOF
|
||||
RUN uv pip install --no-cache $part
|
||||
RUN uv pip install --no-cache $quoted_deps
|
||||
EOF
|
||||
done
|
||||
fi
|
||||
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
read -ra pip_args <<< "$part"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container <<EOF
|
||||
RUN uv pip install --no-cache $quoted_deps
|
||||
EOF
|
||||
add_to_container <<EOF
|
||||
RUN python3 - <<PYTHON | uv pip install --no-cache -r -
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
try:
|
||||
package_name = '$part'.split('==')[0].split('>=')[0].split('<=')[0].split('!=')[0].split('<')[0].split('>')[0]
|
||||
module = importlib.import_module(f'{package_name}.provider')
|
||||
spec = module.get_provider_spec()
|
||||
if hasattr(spec, 'pip_packages') and spec.pip_packages:
|
||||
if isinstance(spec.pip_packages, (list, tuple)):
|
||||
print('\n'.join(spec.pip_packages))
|
||||
except Exception as e:
|
||||
print(f'Error getting provider spec for {package_name}: {e}', file=sys.stderr)
|
||||
PYTHON
|
||||
EOF
|
||||
done
|
||||
fi
|
||||
|
||||
# Function to get Python command
|
||||
get_python_cmd() {
|
||||
if is_command_available python; then
|
||||
echo "python"
|
||||
|
|
@ -222,7 +299,7 @@ else
|
|||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
add_to_container << EOF
|
||||
RUN uv pip install fastapi libcst
|
||||
RUN uv pip install --no-cache fastapi libcst
|
||||
EOF
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
|
||||
|
|
@ -328,7 +405,7 @@ $CONTAINER_BINARY build \
|
|||
"$BUILD_CONTEXT_DIR"
|
||||
|
||||
# clean up tmp/configs
|
||||
rm -f "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
rm -rf "$BUILD_CONTEXT_DIR/run.yaml" "$TEMP_DIR"
|
||||
set +x
|
||||
|
||||
echo "Success!"
|
||||
|
|
|
|||
|
|
@ -18,6 +18,76 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
|||
UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
|
||||
VIRTUAL_ENV=${VIRTUAL_ENV:-}
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
# Usage function
|
||||
usage() {
|
||||
echo "Usage: $0 --env-name <env_name> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
|
||||
echo "Example: $0 --env-name mybuild --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
env_name=""
|
||||
normal_deps=""
|
||||
external_provider_deps=""
|
||||
optional_deps=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
case "$key" in
|
||||
--env-name)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --env-name requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
env_name="$2"
|
||||
shift 2
|
||||
;;
|
||||
--normal-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --normal-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
normal_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--external-provider-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --external-provider-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
external_provider_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--optional-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --optional-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
optional_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check required arguments
|
||||
if [[ -z "$env_name" || -z "$normal_deps" ]]; then
|
||||
echo "Error: --env-name and --normal-deps are required." >&2
|
||||
usage
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||
fi
|
||||
|
|
@ -25,29 +95,6 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
|||
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
if [ "$#" -lt 2 ]; then
|
||||
echo "Usage: $0 <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Example: $0 mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
special_pip_deps="$3"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
env_name="$1"
|
||||
pip_dependencies="$2"
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# this is set if we actually create a new conda in which case we need to clean up
|
||||
ENVNAME=""
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
# pre-run checks to make sure we can proceed with the installation
|
||||
pre_run_checks() {
|
||||
local env_name="$1"
|
||||
|
|
@ -71,49 +118,44 @@ pre_run_checks() {
|
|||
}
|
||||
|
||||
run() {
|
||||
local env_name="$1"
|
||||
local pip_dependencies="$2"
|
||||
local special_pip_deps="$3"
|
||||
|
||||
# Use only global variables set by flag parser
|
||||
if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then
|
||||
echo "Installing dependencies in system Python environment"
|
||||
# if env == __system__, ensure we set UV_SYSTEM_PYTHON
|
||||
export UV_SYSTEM_PYTHON=1
|
||||
elif [ "$VIRTUAL_ENV" == "$env_name" ]; then
|
||||
echo "Virtual environment $env_name is already active"
|
||||
else
|
||||
echo "Using virtual environment $env_name"
|
||||
uv venv "$env_name"
|
||||
# shellcheck source=/dev/null
|
||||
source "$env_name/bin/activate"
|
||||
fi
|
||||
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
uv pip install fastapi libcst
|
||||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
--index-strategy unsafe-best-match \
|
||||
llama-stack=="$TEST_PYPI_VERSION" \
|
||||
$pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
$normal_deps
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install "$part"
|
||||
done
|
||||
fi
|
||||
else
|
||||
# Re-installing llama-stack in the new virtual environment
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
|
||||
else
|
||||
|
|
@ -125,27 +167,41 @@ run() {
|
|||
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
# Install pip dependencies
|
||||
printf "Installing pip dependencies\n"
|
||||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
uv pip install $pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
uv pip install $normal_deps
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
# shellcheck disable=SC2086
|
||||
# we are building a command line so word splitting is expected
|
||||
echo "Installing special provider module: $part"
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "Installing external provider module: $part"
|
||||
uv pip install "$part"
|
||||
echo "Getting provider spec for module: $part and installing dependencies"
|
||||
package_name=$(echo "$part" | sed 's/[<>=!].*//')
|
||||
python3 -c "
|
||||
import importlib
|
||||
import sys
|
||||
try:
|
||||
module = importlib.import_module(f'$package_name.provider')
|
||||
spec = module.get_provider_spec()
|
||||
if hasattr(spec, 'pip_packages') and spec.pip_packages:
|
||||
print('\\n'.join(spec.pip_packages))
|
||||
except Exception as e:
|
||||
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
|
||||
" | uv pip install -r -
|
||||
done
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
pre_run_checks "$env_name"
|
||||
run "$env_name" "$pip_dependencies" "$special_pip_deps"
|
||||
run
|
||||
|
|
|
|||
|
|
@ -91,21 +91,22 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
|
||||
logger.info(f"Configuring API `{api_str}`...")
|
||||
updated_providers = []
|
||||
for i, provider_type in enumerate(plist):
|
||||
for i, provider in enumerate(plist):
|
||||
if i >= 1:
|
||||
others = ", ".join(plist[i:])
|
||||
others = ", ".join(p.provider_type for p in plist[i:])
|
||||
logger.info(
|
||||
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
|
||||
)
|
||||
break
|
||||
|
||||
logger.info(f"> Configuring provider `({provider_type})`")
|
||||
logger.info(f"> Configuring provider `({provider.provider_type})`")
|
||||
pid = provider.provider_type.split("::")[-1]
|
||||
updated_providers.append(
|
||||
configure_single_provider(
|
||||
provider_registry[api],
|
||||
Provider(
|
||||
provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type),
|
||||
provider_type=provider_type,
|
||||
provider_id=(f"{pid}-{i:02d}" if len(plist) > 1 else pid),
|
||||
provider_type=provider.provider_type,
|
||||
config={},
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2
|
|||
RoutingKey = str | list[str]
|
||||
|
||||
|
||||
class RegistryEntrySource(StrEnum):
|
||||
via_register_api = "via_register_api"
|
||||
listed_from_provider = "listed_from_provider"
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
principal: str
|
||||
# further attributes that may be used for access control decisions
|
||||
|
|
@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
|
|||
resource. This can be used to constrain access to the resource."""
|
||||
|
||||
owner: User | None = None
|
||||
source: RegistryEntrySource = RegistryEntrySource.via_register_api
|
||||
|
||||
|
||||
# Use the extended Resource for all routable objects
|
||||
|
|
@ -130,29 +136,54 @@ class RoutingTableProviderSpec(ProviderSpec):
|
|||
pip_packages: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class Provider(BaseModel):
|
||||
# provider_id of None means that the provider is not enabled - this happens
|
||||
# when the provider is enabled via a conditional environment variable
|
||||
provider_id: str | None
|
||||
provider_type: str
|
||||
config: dict[str, Any] = {}
|
||||
module: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Fully-qualified name of the external provider module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
|
||||
Example: `module: ramalama_stack`
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
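As the field description says, an external provider module must expose get_adapter_impl(config, deps). A stubbed-out sketch of that contract (everything below is illustrative, not the real ramalama_stack package):

class _StubAdapter:
    def __init__(self, config):
        self.config = config

    async def initialize(self):
        pass  # open clients, warm caches, etc.

async def get_adapter_impl(config, deps):
    impl = _StubAdapter(config)
    await impl.initialize()
    return impl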
class BuildProvider(BaseModel):
|
||||
provider_type: str
|
||||
module: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Fully-qualified name of the external provider module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
|
||||
Example: `module: ramalama_stack`
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class DistributionSpec(BaseModel):
|
||||
description: str | None = Field(
|
||||
default="",
|
||||
description="Description of the distribution",
|
||||
)
|
||||
container_image: str | None = None
|
||||
providers: dict[str, str | list[str]] = Field(
|
||||
providers: dict[str, list[BuildProvider]] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
Provider Types for each of the APIs provided by this distribution. If you
|
||||
select multiple providers, you should provide an appropriate 'routing_map'
|
||||
in the runtime configuration to help route to the correct provider.""",
|
||||
Provider Types for each of the APIs provided by this distribution. If you
|
||||
select multiple providers, you should provide an appropriate 'routing_map'
|
||||
in the runtime configuration to help route to the correct provider.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class Provider(BaseModel):
|
||||
# provider_id of None means that the provider is not enabled - this happens
|
||||
# when the provider is enabled via a conditional environment variable
|
||||
provider_id: str | None
|
||||
provider_type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
category_levels: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
|
|
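As a rough sketch of how these models fit together, the snippet below declares an external provider at build time and the matching run-time entry; the provider type is illustrative and `ramalama_stack` is taken from the docstring example above.

from llama_stack.distribution.datatypes import BuildProvider, DistributionSpec, Provider

# Build-time: point the inference API at an external provider package by module name.
dist = DistributionSpec(
    description="distribution with an external inference provider",
    providers={
        "inference": [BuildProvider(provider_type="remote::ramalama", module="ramalama_stack")],
    },
)

# Run-time: the same provider with a concrete id and an (empty) config.
provider = Provider(
    provider_id="ramalama",
    provider_type="remote::ramalama",
    config={},
    module="ramalama_stack",
)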
@@ -381,6 +412,11 @@ a default SQLite store will be used.""",
        description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
    )

    external_apis_dir: Path | None = Field(
        default=None,
        description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
    )

    @field_validator("external_providers_dir")
    @classmethod
    def validate_external_providers_dir(cls, v):

@@ -412,6 +448,10 @@ class BuildConfig(BaseModel):
        default_factory=list,
        description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
    )
    external_apis_dir: Path | None = Field(
        default=None,
        description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
    )

    @field_validator("external_providers_dir")
    @classmethod
@ -12,6 +12,8 @@ from typing import Any
|
|||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
|
|
@ -96,12 +98,10 @@ def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_nam
|
|||
return spec
|
||||
|
||||
|
||||
def get_provider_registry(
|
||||
config=None,
|
||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
"""Get the provider registry, optionally including external providers.
|
||||
|
||||
This function loads both built-in providers and external providers from YAML files.
|
||||
This function loads both built-in providers and external providers from YAML files or from their provided modules.
|
||||
External providers are loaded from a directory structure like:
|
||||
|
||||
providers.d/
|
||||
|
|
@ -122,8 +122,13 @@ def get_provider_registry(
|
|||
safety/
|
||||
llama-guard.yaml
|
||||
|
||||
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
|
||||
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
|
||||
There is special handling for all of the potential cases this method can be called from.
|
||||
|
||||
Args:
|
||||
config: Optional object containing the external providers directory path
|
||||
building: Optional bool delineating whether or not this is being called from a build process
|
||||
|
||||
Returns:
|
||||
A dictionary mapping APIs to their available providers
|
||||
|
|
@ -133,58 +138,140 @@ def get_provider_registry(
|
|||
ValueError: If any provider spec is invalid
|
||||
"""
|
||||
|
||||
ret: dict[Api, dict[str, ProviderSpec]] = {}
|
||||
registry: dict[Api, dict[str, ProviderSpec]] = {}
|
||||
for api in providable_apis():
|
||||
name = api.name.lower()
|
||||
logger.debug(f"Importing module {name}")
|
||||
try:
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import module {name}: {e}")
|
||||
|
||||
# Check if config has the external_providers_dir attribute
|
||||
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
|
||||
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
|
||||
if not os.path.exists(external_providers_dir):
|
||||
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
|
||||
logger.info(f"Loading external providers from {external_providers_dir}")
|
||||
# Refresh providable APIs with external APIs if any
|
||||
external_apis = load_external_apis(config)
|
||||
for api, api_spec in external_apis.items():
|
||||
name = api_spec.name.lower()
|
||||
logger.info(f"Importing external API {name} module {api_spec.module}")
|
||||
try:
|
||||
module = importlib.import_module(api_spec.module)
|
||||
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
except (ImportError, AttributeError) as e:
|
||||
# Populate the registry with an empty dict to avoid breaking the provider registry
|
||||
# This assumes that the in-tree provider(s) are not available for this API, which means
|
||||
# that users will need to use external providers for this API.
|
||||
registry[api] = {}
|
||||
logger.error(
|
||||
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
|
||||
"Install the API package to load any in-tree providers for this API."
|
||||
)
|
||||
|
||||
for api in providable_apis():
|
||||
api_name = api.name.lower()
|
||||
# Check if config has external providers
|
||||
if config:
|
||||
if hasattr(config, "external_providers_dir") and config.external_providers_dir:
|
||||
registry = get_external_providers_from_dir(registry, config)
|
||||
# else, let's check for modules in each provider
|
||||
registry = get_external_providers_from_module(
|
||||
registry=registry,
|
||||
config=config,
|
||||
building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)),
|
||||
)
|
||||
|
||||
# Process both remote and inline providers
|
||||
for provider_type in ["remote", "inline"]:
|
||||
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
|
||||
if not os.path.exists(api_dir):
|
||||
logger.debug(f"No {provider_type} provider directory found for {api_name}")
|
||||
continue
|
||||
return registry
|
||||
|
||||
# Look for provider spec files in the API directory
|
||||
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
|
||||
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
|
||||
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
|
||||
|
||||
try:
|
||||
with open(spec_path) as f:
|
||||
spec_data = yaml.safe_load(f)
|
||||
def get_external_providers_from_dir(
|
||||
registry: dict[Api, dict[str, ProviderSpec]], config
|
||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
logger.warning(
|
||||
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
|
||||
)
|
||||
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
|
||||
if not os.path.exists(external_providers_dir):
|
||||
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
|
||||
logger.info(f"Loading external providers from {external_providers_dir}")
|
||||
|
||||
if provider_type == "remote":
|
||||
spec = _load_remote_provider_spec(spec_data, api)
|
||||
provider_type_key = f"remote::{provider_name}"
|
||||
else:
|
||||
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
||||
provider_type_key = f"inline::{provider_name}"
|
||||
for api in providable_apis():
|
||||
api_name = api.name.lower()
|
||||
|
||||
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
||||
if provider_type_key in ret[api]:
|
||||
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
||||
ret[api][provider_type_key] = spec
|
||||
logger.info(f"Successfully loaded external provider {provider_type_key}")
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
||||
raise yaml_err
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
||||
raise e
|
||||
return ret
|
||||
# Process both remote and inline providers
|
||||
for provider_type in ["remote", "inline"]:
|
||||
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
|
||||
if not os.path.exists(api_dir):
|
||||
logger.debug(f"No {provider_type} provider directory found for {api_name}")
|
||||
continue
|
||||
|
||||
# Look for provider spec files in the API directory
|
||||
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
|
||||
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
|
||||
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
|
||||
|
||||
try:
|
||||
with open(spec_path) as f:
|
||||
spec_data = yaml.safe_load(f)
|
||||
|
||||
if provider_type == "remote":
|
||||
spec = _load_remote_provider_spec(spec_data, api)
|
||||
provider_type_key = f"remote::{provider_name}"
|
||||
else:
|
||||
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
||||
provider_type_key = f"inline::{provider_name}"
|
||||
|
||||
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
||||
if provider_type_key in registry[api]:
|
||||
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
||||
registry[api][provider_type_key] = spec
|
||||
logger.info(f"Successfully loaded external provider {provider_type_key}")
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
||||
raise yaml_err
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
||||
raise e
|
||||
|
||||
return registry
|
||||
|
||||
|
||||
def get_external_providers_from_module(
|
||||
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
|
||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
provider_list = None
|
||||
if isinstance(config, BuildConfig):
|
||||
provider_list = config.distribution_spec.providers.items()
|
||||
else:
|
||||
provider_list = config.providers.items()
|
||||
if provider_list is None:
|
||||
logger.warning("Could not get list of providers from config")
|
||||
return registry
|
||||
for provider_api, providers in provider_list:
|
||||
for provider in providers:
|
||||
if not hasattr(provider, "module") or provider.module is None:
|
||||
continue
|
||||
# get provider using module
|
||||
try:
|
||||
if not building:
|
||||
package_name = provider.module.split("==")[0]
|
||||
module = importlib.import_module(f"{package_name}.provider")
|
||||
# if config class is wrong you will get an error saying module could not be imported
|
||||
spec = module.get_provider_spec()
|
||||
else:
|
||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
||||
spec = ProviderSpec(
|
||||
api=Api(provider_api),
|
||||
provider_type=provider.provider_type,
|
||||
is_external=True,
|
||||
module=provider.module,
|
||||
config_class="",
|
||||
)
|
||||
provider_type = provider.provider_type
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
# return a partially filled out spec that the build script will populate.
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
||||
) from exc
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
|
||||
raise e
|
||||
return registry
|
||||
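To make the module-based path concrete, here is a minimal sketch of what such an external package might expose; the package layout, provider type, and config class are assumptions, grounded only in the fact that this function imports `<package>.provider` and calls `get_provider_spec()`.

# ramalama_stack/provider.py  (hypothetical package layout)
from llama_stack.providers.datatypes import Api, ProviderSpec


def get_provider_spec() -> ProviderSpec:
    # Imported as "<package>.provider" and called once the package is installed;
    # at build time the registry falls back to the partial spec constructed above.
    return ProviderSpec(
        api=Api.inference,
        provider_type="remote::ramalama",
        is_external=True,
        module="ramalama_stack",
        config_class="ramalama_stack.config.RamalamaConfig",  # hypothetical config class
    )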
|
|
|
|||
54  llama_stack/distribution/external.py  Normal file

@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import yaml

from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.distribution.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="core")


def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
    """Load external API specifications from the configured directory.

    Args:
        config: StackRunConfig or BuildConfig containing the external APIs directory path

    Returns:
        A dictionary mapping API names to their specifications
    """
    if not config or not config.external_apis_dir:
        return {}

    external_apis_dir = config.external_apis_dir.expanduser().resolve()
    if not external_apis_dir.is_dir():
        logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
        return {}

    logger.info(f"Loading external APIs from {external_apis_dir}")
    external_apis: dict[Api, ExternalApiSpec] = {}

    # Look for YAML files in the external APIs directory
    for yaml_path in external_apis_dir.glob("*.yaml"):
        try:
            with open(yaml_path) as f:
                spec_data = yaml.safe_load(f)

            spec = ExternalApiSpec(**spec_data)
            api = Api.add(spec.name)
            logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
            external_apis[api] = spec
        except yaml.YAMLError as yaml_err:
            logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
            raise
        except Exception:
            logger.exception(f"Failed to load external API spec from {yaml_path}")
            raise

    return external_apis
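A rough sketch of how this loader might be exercised. The directory path and spec fields below are illustrative assumptions; the `protocol` field in particular is inferred from how the resolver changes later in this diff read the spec (`spec.name`, `spec.module`, `spec.protocol`).

from pathlib import Path
from types import SimpleNamespace

import yaml

from llama_stack.distribution.external import load_external_apis

# Hypothetical spec dropped into the external APIs directory.
apis_dir = Path("~/.llama/apis.d").expanduser()
apis_dir.mkdir(parents=True, exist_ok=True)
(apis_dir / "weather.yaml").write_text(
    yaml.safe_dump({"name": "weather", "module": "weather_api", "protocol": "WeatherAPI"})
)

# load_external_apis() only touches config.external_apis_dir, so any object exposing that
# attribute works for a quick check; in practice this would be a real StackRunConfig or BuildConfig.
config = SimpleNamespace(external_apis_dir=apis_dir)
external_apis = load_external_apis(config)
print(external_apis)  # e.g. {Api.weather: ExternalApiSpec(name="weather", ...)}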
|
|
@ -16,6 +16,7 @@ from llama_stack.apis.inspect import (
|
|||
VersionInfo,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.server.routes import get_all_api_routes
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
|
||||
|
|
@ -42,7 +43,8 @@ class DistributionInspectImpl(Inspect):
|
|||
run_config: StackRunConfig = self.config.run_config
|
||||
|
||||
ret = []
|
||||
all_endpoints = get_all_api_routes()
|
||||
external_apis = load_external_apis(run_config)
|
||||
all_endpoints = get_all_api_routes(external_apis)
|
||||
for api, endpoints in all_endpoints.items():
|
||||
# Always include provider and inspect APIs, filter others based on run config
|
||||
if api.value in ["providers", "inspect"]:
|
||||
|
|
@ -53,7 +55,8 @@ class DistributionInspectImpl(Inspect):
|
|||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
||||
)
|
||||
for e in endpoints
|
||||
for e, _ in endpoints
|
||||
if e.methods is not None
|
||||
]
|
||||
)
|
||||
else:
|
||||
|
|
@ -66,7 +69,8 @@ class DistributionInspectImpl(Inspect):
|
|||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
for e, _ in endpoints
|
||||
if e.methods is not None
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,11 +12,13 @@ import os
|
|||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, Union, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import Response as FastAPIResponse
|
||||
from llama_stack_client import (
|
||||
NOT_GIVEN,
|
||||
APIResponse,
|
||||
|
|
@ -31,7 +33,7 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec
|
||||
from llama_stack.distribution.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
|
|
@ -112,6 +114,27 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
|||
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
|
||||
|
||||
|
||||
class LibraryClientUploadFile:
|
||||
"""LibraryClient UploadFile object that mimics FastAPI's UploadFile interface."""
|
||||
|
||||
def __init__(self, filename: str, content: bytes):
|
||||
self.filename = filename
|
||||
self.content = content
|
||||
self.content_type = "application/octet-stream"
|
||||
|
||||
async def read(self) -> bytes:
|
||||
return self.content
|
||||
|
||||
|
||||
class LibraryClientHttpxResponse:
|
||||
"""LibraryClient httpx Response object for FastAPI Response conversion."""
|
||||
|
||||
def __init__(self, response):
|
||||
self.content = response.body if isinstance(response.body, bytes) else response.body.encode()
|
||||
self.status_code = response.status_code
|
||||
self.headers = response.headers
|
||||
|
||||
|
||||
class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -128,6 +151,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
self.skip_logger_removal = skip_logger_removal
|
||||
self.provider_data = provider_data
|
||||
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
||||
def initialize(self):
|
||||
if in_notebook():
|
||||
import nest_asyncio
|
||||
|
|
@ -136,7 +161,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
if not self.skip_logger_removal:
|
||||
self._remove_root_logger_handlers()
|
||||
|
||||
return asyncio.run(self.async_client.initialize())
|
||||
# use a new event loop to avoid interfering with the main event loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(self.async_client.initialize())
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
def _remove_root_logger_handlers(self):
|
||||
"""
|
||||
|
|
@ -149,10 +180,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
# NOTE: We are using AsyncLlamaStackClient under the hood
|
||||
# A new event loop is needed to convert the AsyncStream
|
||||
# from async client into SyncStream return type for streaming
|
||||
loop = asyncio.new_event_loop()
|
||||
loop = self.loop
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
if kwargs.get("stream"):
|
||||
|
|
@ -169,7 +197,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
pending = asyncio.all_tasks(loop)
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
loop.close()
|
||||
|
||||
return sync_generator()
|
||||
else:
|
||||
|
|
@ -179,7 +206,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
pending = asyncio.all_tasks(loop)
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
loop.close()
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -223,15 +249,16 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
file=sys.stderr,
|
||||
)
|
||||
if self.config_path_or_template_name.endswith(".yaml"):
|
||||
# Convert Provider objects to their types
|
||||
provider_types: dict[str, str | list[str]] = {}
|
||||
for api, providers in self.config.providers.items():
|
||||
types = [p.provider_type for p in providers]
|
||||
# Convert single-item lists to strings
|
||||
provider_types[api] = types[0] if len(types) == 1 else types
|
||||
providers: dict[str, list[BuildProvider]] = {}
|
||||
for api, run_providers in self.config.providers.items():
|
||||
for provider in run_providers:
|
||||
providers.setdefault(api, []).append(
|
||||
BuildProvider(provider_type=provider.provider_type, module=provider.module)
|
||||
)
|
||||
providers = dict(providers)
|
||||
build_config = BuildConfig(
|
||||
distribution_spec=DistributionSpec(
|
||||
providers=provider_types,
|
||||
providers=providers,
|
||||
),
|
||||
external_providers_dir=self.config.external_providers_dir,
|
||||
)
|
||||
|
|
@ -295,6 +322,31 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return response
|
||||
|
||||
def _handle_file_uploads(self, options: Any, body: dict) -> tuple[dict, list[str]]:
|
||||
"""Handle file uploads from OpenAI client and add them to the request body."""
|
||||
if not (hasattr(options, "files") and options.files):
|
||||
return body, []
|
||||
|
||||
if not isinstance(options.files, list):
|
||||
return body, []
|
||||
|
||||
field_names = []
|
||||
for file_tuple in options.files:
|
||||
if not (isinstance(file_tuple, tuple) and len(file_tuple) >= 2):
|
||||
continue
|
||||
|
||||
field_name = file_tuple[0]
|
||||
file_object = file_tuple[1]
|
||||
|
||||
if isinstance(file_object, BytesIO):
|
||||
file_object.seek(0)
|
||||
file_content = file_object.read()
|
||||
filename = getattr(file_object, "name", "uploaded_file")
|
||||
field_names.append(field_name)
|
||||
body[field_name] = LibraryClientUploadFile(filename, file_content)
|
||||
|
||||
return body, field_names
|
||||
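For intuition about the shapes involved: judging by the checks in the code, `options.files` arrives as a list of `(field_name, file_object)` tuples, and BytesIO payloads are wrapped in the UploadFile stand-in defined above. The field name and contents below are made up.

from io import BytesIO

buf = BytesIO(b"hello world")
buf.name = "notes.txt"   # picked up as the filename; otherwise "uploaded_file" is used

# the structure _handle_file_uploads() expects on options.files
files = [("file", buf)]
# after processing, body["file"] is a LibraryClientUploadFile with filename "notes.txt",
# and that field name is excluded from pydantic conversion in _convert_body()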
|
||||
async def _call_non_streaming(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -308,17 +360,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
||||
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
|
||||
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
body = self._convert_body(path, options.method, body)
|
||||
await start_trace(route, {"__location__": "library_client"})
|
||||
|
||||
body, field_names = self._handle_file_uploads(options, body)
|
||||
|
||||
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
try:
|
||||
result = await matched_func(**body)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
# Handle FastAPI Response objects (e.g., from file content retrieval)
|
||||
if isinstance(result, FastAPIResponse):
|
||||
return LibraryClientHttpxResponse(result)
|
||||
|
||||
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
||||
|
||||
filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=json_content.encode("utf-8"),
|
||||
|
|
@ -330,7 +392,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
url=options.url,
|
||||
params=options.params,
|
||||
headers=options.headers or {},
|
||||
json=convert_pydantic_to_json_value(body),
|
||||
json=convert_pydantic_to_json_value(filtered_body),
|
||||
),
|
||||
)
|
||||
response = APIResponse(
|
||||
|
|
@ -356,12 +418,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
func, path_params, route = find_matching_route(options.method, path, self.route_impls)
|
||||
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
body = self._convert_body(path, options.method, body)
|
||||
|
||||
await start_trace(route, {"__location__": "library_client"})
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
||||
async def gen():
|
||||
try:
|
||||
|
|
@ -392,8 +455,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
|
||||
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
|
||||
# so we need to convert it to AsyncStream
|
||||
# mypy can't track runtime variables inside the [...] of a generic, so ignore that check
|
||||
args = get_args(stream_cls)
|
||||
stream_cls = AsyncStream[args[0]]
|
||||
stream_cls = AsyncStream[args[0]] # type: ignore[valid-type]
|
||||
response = AsyncAPIResponse(
|
||||
raw=mock_response,
|
||||
client=self,
|
||||
|
|
@ -404,14 +468,18 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return await response.parse()
|
||||
|
||||
def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict:
|
||||
def _convert_body(
|
||||
self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
|
||||
) -> dict:
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
func, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
exclude_params = exclude_params or set()
|
||||
|
||||
func, _, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
|
|
@ -422,6 +490,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
for param_name, param in sig.parameters.items():
|
||||
if param_name in body:
|
||||
value = body.get(param_name)
|
||||
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
|
||||
if param_name in exclude_params:
|
||||
converted_body[param_name] = value
|
||||
else:
|
||||
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
|
||||
|
||||
return converted_body
|
||||
|
|
|
|||
|
|
@@ -101,3 +101,15 @@ def get_authenticated_user() -> User | None:
    if not provider_data:
        return None
    return provider_data.get("__authenticated_user")


def user_from_scope(scope: dict) -> User | None:
    """Create a User object from ASGI scope data (set by authentication middleware)"""
    user_attributes = scope.get("user_attributes", {})
    principal = scope.get("principal", "")

    # auth not enabled
    if not principal and not user_attributes:
        return None

    return User(principal=principal, attributes=user_attributes)
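A small illustration of the intended call site for the function above, assuming an ASGI scope that an auth middleware has already annotated; the principal and attributes are invented.

scope = {
    "type": "http",
    "principal": "alice@example.com",
    "user_attributes": {"roles": ["admin"]},
}
user = user_from_scope(scope)                   # User(principal="alice@example.com", ...)
anonymous = user_from_scope({"type": "http"})   # None, since auth is not enabled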
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents
|
|||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.datatypes import ExternalApiSpec
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InferenceProvider
|
||||
|
|
@ -35,6 +36,7 @@ from llama_stack.distribution.datatypes import (
|
|||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -59,8 +61,16 @@ class InvalidProviderError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def api_protocol_map() -> dict[Api, Any]:
|
||||
return {
|
||||
def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
|
||||
"""Get a mapping of API types to their protocol classes.
|
||||
|
||||
Args:
|
||||
external_apis: Optional dictionary of external API specifications
|
||||
|
||||
Returns:
|
||||
Dictionary mapping API types to their protocol classes
|
||||
"""
|
||||
protocols = {
|
||||
Api.providers: ProvidersAPI,
|
||||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
|
|
@ -83,10 +93,23 @@ def api_protocol_map() -> dict[Api, Any]:
|
|||
Api.files: Files,
|
||||
}
|
||||
|
||||
if external_apis:
|
||||
for api, api_spec in external_apis.items():
|
||||
try:
|
||||
module = importlib.import_module(api_spec.module)
|
||||
api_class = getattr(module, api_spec.protocol)
|
||||
|
||||
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
|
||||
protocols[api] = api_class
|
||||
except (ImportError, AttributeError):
|
||||
logger.exception(f"Failed to load external API {api_spec.name}")
|
||||
|
||||
return protocols
|
||||
|
||||
|
||||
def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
|
||||
external_apis = load_external_apis(config)
|
||||
return {
|
||||
**api_protocol_map(),
|
||||
**api_protocol_map(external_apis),
|
||||
Api.inference: InferenceProvider,
|
||||
}
|
||||
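Since the resolver retrieves the protocol class with `getattr(module, api_spec.protocol)`, an external API package presumably exports a protocol class under that name; a minimal hypothetical sketch (package, class, and method are invented):

# weather_api/__init__.py  (hypothetical external API package)
from typing import Protocol


class WeatherAPI(Protocol):
    """Resolved via getattr(module, api_spec.protocol) and used for compliance checks."""

    async def get_forecast(self, city: str) -> dict: ...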
|
||||
|
|
@ -250,7 +273,7 @@ async def instantiate_providers(
|
|||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
policy: list[AccessRule],
|
||||
) -> dict:
|
||||
) -> dict[Api, Any]:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: dict[Api, Any] = {}
|
||||
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||
|
|
@ -322,7 +345,7 @@ async def instantiate_provider(
|
|||
policy: list[AccessRule],
|
||||
):
|
||||
provider_spec = provider.spec
|
||||
if not hasattr(provider_spec, "module"):
|
||||
if not hasattr(provider_spec, "module") or provider_spec.module is None:
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
|
||||
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
|
||||
|
|
@ -360,7 +383,7 @@ async def instantiate_provider(
|
|||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
protocols = api_protocol_map_for_compliance_check()
|
||||
protocols = api_protocol_map_for_compliance_check(run_config)
|
||||
additional_protocols = additional_protocols_map()
|
||||
# TODO: check compliance for special tool groups
|
||||
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
||||
|
|
|
|||
|
|
@ -57,7 +57,8 @@ class DatasetIORouter(DatasetIO):
|
|||
logger.debug(
|
||||
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
|
||||
provider = await self.routing_table.get_provider_impl(dataset_id)
|
||||
return await provider.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
start_index=start_index,
|
||||
limit=limit,
|
||||
|
|
@ -65,7 +66,8 @@ class DatasetIORouter(DatasetIO):
|
|||
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||
provider = await self.routing_table.get_provider_impl(dataset_id)
|
||||
return await provider.append_rows(
|
||||
dataset_id=dataset_id,
|
||||
rows=rows,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,8 @@ class ScoringRouter(Scoring):
|
|||
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
||||
provider = await self.routing_table.get_provider_impl(fn_identifier)
|
||||
score_response = await provider.score_batch(
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
|
|
@ -66,7 +67,8 @@ class ScoringRouter(Scoring):
|
|||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
|
||||
provider = await self.routing_table.get_provider_impl(fn_identifier)
|
||||
score_response = await provider.score(
|
||||
input_rows=input_rows,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
|
|
@ -97,7 +99,8 @@ class EvalRouter(Eval):
|
|||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.run_eval(
|
||||
benchmark_id=benchmark_id,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
|
@ -110,7 +113,8 @@ class EvalRouter(Eval):
|
|||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=input_rows,
|
||||
scoring_functions=scoring_functions,
|
||||
|
|
@ -123,7 +127,8 @@ class EvalRouter(Eval):
|
|||
job_id: str,
|
||||
) -> Job:
|
||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.job_status(benchmark_id, job_id)
|
||||
|
||||
async def job_cancel(
|
||||
self,
|
||||
|
|
@ -131,7 +136,8 @@ class EvalRouter(Eval):
|
|||
job_id: str,
|
||||
) -> None:
|
||||
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
await provider.job_cancel(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
)
|
||||
|
|
@ -142,7 +148,8 @@ class EvalRouter(Eval):
|
|||
job_id: str,
|
||||
) -> EvaluateResponse:
|
||||
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.job_result(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -231,7 +231,7 @@ class InferenceRouter(Inference):
|
|||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||
|
||||
if stream:
|
||||
|
|
@ -292,7 +292,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_chat_completion(
|
||||
model_id=model_id,
|
||||
messages_batch=messages_batch,
|
||||
|
|
@ -322,7 +322,7 @@ class InferenceRouter(Inference):
|
|||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
content=content,
|
||||
|
|
@ -378,7 +378,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
|
||||
|
||||
async def embeddings(
|
||||
|
|
@ -395,7 +395,8 @@ class InferenceRouter(Inference):
|
|||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.llm:
|
||||
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
|
||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
|
|
@ -458,7 +459,7 @@ class InferenceRouter(Inference):
|
|||
suffix=suffix,
|
||||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_completion(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
|
|
@ -538,7 +539,7 @@ class InferenceRouter(Inference):
|
|||
user=user,
|
||||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
response_stream = await provider.openai_chat_completion(**params)
|
||||
if self.store:
|
||||
|
|
@ -575,7 +576,7 @@ class InferenceRouter(Inference):
|
|||
user=user,
|
||||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_embeddings(**params)
|
||||
|
||||
async def list_chat_completions(
|
||||
|
|
|
|||
|
|
@ -50,7 +50,8 @@ class SafetyRouter(Safety):
|
|||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||
return await provider.run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
params=params,
|
||||
|
|
|
|||
|
|
@ -41,9 +41,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
||||
content, vector_db_ids, query_config
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("knowledge_search")
|
||||
return await provider.query(content, vector_db_ids, query_config)
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
|
|
@ -54,9 +53,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("insert_into_memory")
|
||||
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -80,7 +78,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
|
||||
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
||||
provider = await self.routing_table.get_provider_impl(tool_name)
|
||||
return await provider.invoke_tool(
|
||||
tool_name=tool_name,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -104,7 +104,8 @@ class VectorIORouter(VectorIO):
|
|||
logger.debug(
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db_id)
|
||||
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
|
|
@ -113,7 +114,8 @@ class VectorIORouter(VectorIO):
|
|||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db_id)
|
||||
return await provider.query_chunks(vector_db_id, query, params)
|
||||
|
||||
# OpenAI Vector Stores API endpoints
|
||||
async def openai_create_vector_store(
|
||||
|
|
@ -146,7 +148,8 @@ class VectorIORouter(VectorIO):
|
|||
provider_vector_db_id=vector_db_id,
|
||||
vector_db_name=name,
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
|
||||
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
|
||||
return await provider.openai_create_vector_store(
|
||||
name=name,
|
||||
file_ids=file_ids,
|
||||
expires_after=expires_after,
|
||||
|
|
@ -172,9 +175,8 @@ class VectorIORouter(VectorIO):
|
|||
all_stores = []
|
||||
for vector_db in vector_dbs:
|
||||
try:
|
||||
vector_store = await self.routing_table.get_provider_impl(
|
||||
vector_db.identifier
|
||||
).openai_retrieve_vector_store(vector_db.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
|
||||
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
|
||||
all_stores.append(vector_store)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
|
||||
|
|
@ -214,9 +216,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
||||
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
|
||||
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
|
|
@ -226,9 +226,7 @@ class VectorIORouter(VectorIO):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store(
|
||||
return await self.routing_table.openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
name=name,
|
||||
expires_after=expires_after,
|
||||
|
|
@ -240,12 +238,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
result = await provider.openai_delete_vector_store(vector_store_id)
|
||||
# drop from registry
|
||||
await self.routing_table.unregister_vector_db(vector_store_id)
|
||||
return result
|
||||
return await self.routing_table.openai_delete_vector_store(vector_store_id)
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
|
|
@ -258,9 +251,7 @@ class VectorIORouter(VectorIO):
|
|||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_search_vector_store(
|
||||
return await self.routing_table.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=filters,
|
||||
|
|
@ -278,9 +269,7 @@ class VectorIORouter(VectorIO):
|
|||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_attach_file_to_vector_store(
|
||||
return await self.routing_table.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
|
|
@ -297,9 +286,7 @@ class VectorIORouter(VectorIO):
|
|||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> list[VectorStoreFileObject]:
|
||||
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store(
|
||||
return await self.routing_table.openai_list_files_in_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
limit=limit,
|
||||
order=order,
|
||||
|
|
@ -314,9 +301,7 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file(
|
||||
return await self.routing_table.openai_retrieve_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
|
@ -327,9 +312,7 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_contents(
|
||||
return await self.routing_table.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
|
@ -341,9 +324,7 @@ class VectorIORouter(VectorIO):
|
|||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store_file(
|
||||
return await self.routing_table.openai_update_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
|
|
@ -355,9 +336,7 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store_file(
|
||||
return await self.routing_table.openai_delete_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,9 +6,11 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.access_control.datatypes import Action
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessRule,
|
||||
RoutableObject,
|
||||
|
|
@ -115,7 +117,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
async def refresh(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
from .benchmarks import BenchmarksRoutingTable
|
||||
from .datasets import DatasetsRoutingTable
|
||||
from .models import ModelsRoutingTable
|
||||
|
|
@ -204,11 +209,24 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if obj.type == ResourceType.model.value:
|
||||
await self.dist_registry.register(registered_obj)
|
||||
return registered_obj
|
||||
|
||||
else:
|
||||
await self.dist_registry.register(obj)
|
||||
return obj
|
||||
|
||||
async def assert_action_allowed(
|
||||
self,
|
||||
action: Action,
|
||||
type: str,
|
||||
identifier: str,
|
||||
) -> None:
|
||||
"""Fetch a registered object by type/identifier and enforce the given action permission."""
|
||||
obj = await self.get_object_by_identifier(type, identifier)
|
||||
if obj is None:
|
||||
raise ValueError(f"{type.capitalize()} '{identifier}' not found")
|
||||
user = get_authenticated_user()
|
||||
if not is_action_allowed(self.policy, action, obj, user):
|
||||
raise AccessDeniedError(action, obj, user)
|
||||
|
||||
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||
|
|
@@ -220,3 +238,28 @@ class CommonRoutingTableImpl(RoutingTable):
        ]

        return filtered_objs


async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
    # first try to get the model by identifier
    # this works if model_id is an alias or is of the form provider_id/provider_model_id
    model = await routing_table.get_object_by_identifier("model", model_id)
    if model is not None:
        return model

    logger.warning(
        f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
        "searching in all providers. This is only for backwards compatibility and will stop working "
        "soon. Migrate your calls to use fully scoped `provider_id/model_id` names."
    )
    # if not found, this means model_id is an unscoped provider_model_id, we need
    # to iterate (given a lack of an efficient index on the KVStore)
    models = await routing_table.get_all_with_type("model")
    matching_models = [m for m in models if m.provider_resource_id == model_id]
    if len(matching_models) == 0:
        raise ValueError(f"Model '{model_id}' not found")

    if len(matching_models) > 1:
        raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")

    return matching_models[0]
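A short sketch of the two lookup paths above, assuming a routing table whose registry is already populated; the identifiers are invented.

async def _lookup_examples(routing_table) -> None:
    # Aliases and fully scoped "provider_id/model_id" ids resolve directly via the registry.
    scoped = await lookup_model(routing_table, "vllm-0/meta-llama/Llama-3.1-8B-Instruct")
    # Bare provider model ids take the deprecated fallback scan (and log a warning).
    legacy = await lookup_model(routing_table, "meta-llama/Llama-3.1-8B-Instruct")
    # in this scenario both resolve to the same registered entry
    assert scoped.identifier == legacy.identifier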
|
|
|
|||
|
|
@ -10,15 +10,37 @@ from typing import Any
|
|||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ModelWithOwner,
|
||||
RegistryEntrySource,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
listed_providers: set[str] = set()
|
||||
|
||||
async def refresh(self) -> None:
|
||||
for provider_id, provider in self.impls_by_provider_id.items():
|
||||
refresh = await provider.should_refresh_models()
|
||||
refresh = refresh or provider_id not in self.listed_providers
|
||||
if not refresh:
|
||||
continue
|
||||
|
||||
try:
|
||||
models = await provider.list_models()
|
||||
except Exception as e:
|
||||
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
continue
|
||||
|
||||
self.listed_providers.add(provider_id)
|
||||
if models is None:
|
||||
continue
|
||||
|
||||
await self.update_registered_models(provider_id, models)
|
||||
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
return ListModelsResponse(data=await self.get_all_with_type("model"))

@@ -36,10 +58,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        model = await self.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model
        return await lookup_model(self, model_id)

    async def get_provider_impl(self, model_id: str) -> Any:
        model = await lookup_model(self, model_id)
        return self.impls_by_provider_id[model.provider_id]

    async def register_model(
        self,

@@ -49,28 +72,38 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                    f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
                    "Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm

        provider_model_id = provider_model_id or model_id
        metadata = metadata or {}
        model_type = model_type or ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")

        # an identifier different than provider_model_id implies it is an alias, so that
        # becomes the globally unique identifier. otherwise provider_model_ids can conflict,
        # so as a general rule we must use the provider_id to disambiguate.

        if model_id != provider_model_id:
            identifier = model_id
        else:
            identifier = f"{provider_id}/{provider_model_id}"

        model = ModelWithOwner(
            identifier=model_id,
            identifier=identifier,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
            source=RegistryEntrySource.via_register_api,
        )
        registered_model = await self.register_object(model)
        return registered_model

@@ -80,3 +113,43 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
        if existing_model is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(existing_model)

    async def update_registered_models(
        self,
        provider_id: str,
        models: list[Model],
    ) -> None:
        existing_models = await self.get_all_with_type("model")

        # we may have an alias for the model registered by the user (or during initialization
        # from run.yaml) that we need to keep track of
        model_ids = {}
        for model in existing_models:
            if model.provider_id != provider_id:
                continue
            if model.source == RegistryEntrySource.via_register_api:
                model_ids[model.provider_resource_id] = model.identifier
                continue

            logger.debug(f"unregistering model {model.identifier}")
            await self.unregister_object(model)

        for model in models:
            if model.provider_resource_id in model_ids:
                # avoid overwriting a non-provider-registered model entry
                continue

            if model.identifier == model.provider_resource_id:
                model.identifier = f"{provider_id}/{model.provider_resource_id}"

            logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
            await self.register_object(
                ModelWithOwner(
                    identifier=model.identifier,
                    provider_resource_id=model.provider_resource_id,
                    provider_id=provider_id,
                    metadata=model.metadata,
                    model_type=model.model_type,
                    source=RegistryEntrySource.listed_from_provider,
                )
            )
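
The identifier logic in register_model above means a user-supplied alias wins, and otherwise the provider id becomes a prefix. A minimal usage sketch, assuming the full signature is (model_id, provider_model_id, provider_id, metadata, model_type) as in the hunk above; the provider and model names are illustrative, not from this change:

```python
# Hypothetical usage sketch; `table` stands in for a ModelsRoutingTable instance.

async def demo(table) -> None:
    # No alias: model_id == provider_model_id, so the global identifier
    # becomes "<provider_id>/<provider_model_id>".
    m1 = await table.register_model(model_id="llama3.2:3b", provider_id="ollama")
    assert m1.identifier == "ollama/llama3.2:3b"

    # Alias: model_id differs from provider_model_id, so the alias itself
    # is the globally unique identifier.
    m2 = await table.register_model(
        model_id="my-llama",
        provider_model_id="llama3.2:3b",
        provider_id="ollama",
    )
    assert m2.identifier == "my-llama"
```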

@@ -30,7 +30,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    tool_to_toolgroup: dict[str, str] = {}

    # overridden
    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
        # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?

@@ -40,7 +40,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):

        if routing_key in self.tool_to_toolgroup:
            routing_key = self.tool_to_toolgroup[routing_key]
        return super().get_provider_impl(routing_key, provider_id)
        return await super().get_provider_impl(routing_key, provider_id)

    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
        if toolgroup_id:

@@ -59,7 +59,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
        return ListToolsResponse(data=all_tools)

    async def _index_tools(self, toolgroup: ToolGroup):
        provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
        provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
        tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)

        # TODO: kill this Tool vs ToolDef distinction

@@ -4,17 +4,30 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import TypeAdapter

from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.apis.vector_io.vector_io import (
    SearchRankingOptions,
    VectorStoreChunkingStrategy,
    VectorStoreDeleteResponse,
    VectorStoreFileContentsResponse,
    VectorStoreFileDeleteResponse,
    VectorStoreFileObject,
    VectorStoreFileStatus,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.distribution.datatypes import (
    VectorDBWithOwner,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl
from .common import CommonRoutingTableImpl, lookup_model

logger = get_logger(name=__name__, category="core")

@@ -38,8 +51,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        provider_vector_db_id: str | None = None,
        vector_db_name: str | None = None,
    ) -> VectorDB:
        if provider_vector_db_id is None:
            provider_vector_db_id = vector_db_id
        provider_vector_db_id = provider_vector_db_id or vector_db_id
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]

@@ -49,7 +61,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
                )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await self.get_object_by_identifier("model", embedding_model)
        model = await lookup_model(self, embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:

@@ -74,3 +86,145 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        if existing_vector_db is None:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        await self.unregister_object(existing_vector_db)

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store(vector_store_id)

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store(
            vector_store_id=vector_store_id,
            name=name,
            expires_after=expires_after,
            metadata=metadata,
        )

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        result = await provider.openai_delete_vector_store(vector_store_id)
        await self.unregister_vector_db(vector_store_id)
        return result

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: SearchRankingOptions | None = None,
        rewrite_query: bool | None = False,
        search_mode: str | None = "vector",
    ) -> VectorStoreSearchResponsePage:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_search_vector_store(
            vector_store_id=vector_store_id,
            query=query,
            filters=filters,
            max_num_results=max_num_results,
            ranking_options=ranking_options,
            rewrite_query=rewrite_query,
            search_mode=search_mode,
        )

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_attach_file_to_vector_store(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
        )

    async def openai_list_files_in_vector_store(
        self,
        vector_store_id: str,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
        filter: VectorStoreFileStatus | None = None,
    ) -> list[VectorStoreFileObject]:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_list_files_in_vector_store(
            vector_store_id=vector_store_id,
            limit=limit,
            order=order,
            after=after,
            before=before,
            filter=filter,
        )

    async def openai_retrieve_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_retrieve_vector_store_file_contents(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileContentsResponse:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_contents(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_update_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any],
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
        )

    async def openai_delete_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileDeleteResponse:
        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_delete_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )
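
Each of these routing-table wrappers follows the same three-step pattern: check the access policy, resolve the provider, delegate. A minimal calling sketch; `vector_dbs_table` is a hypothetical handle to the routing table of a running stack, the store id is made up, and the assumption that the returned page exposes a `data` list is not confirmed by this diff:

```python
# Illustrative sketch only.

async def search_docs(vector_dbs_table) -> None:
    page = await vector_dbs_table.openai_search_vector_store(
        vector_store_id="my-store",          # hypothetical store id
        query="how do I rotate API keys?",
        max_num_results=5,
        search_mode="vector",
    )
    for item in page.data:                   # assumed field of VectorStoreSearchResponsePage
        print(item)
```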

@@ -7,9 +7,12 @@
import json

import httpx
from aiohttp import hdrs

from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.datatypes import AuthenticationConfig, User
from llama_stack.distribution.request_headers import user_from_scope
from llama_stack.distribution.server.auth_providers import create_auth_provider
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="auth")

@@ -78,12 +81,14 @@ class AuthenticationMiddleware:
    access resources that don't have access_attributes defined.
    """

    def __init__(self, app, auth_config: AuthenticationConfig):
    def __init__(self, app, auth_config: AuthenticationConfig, impls):
        self.app = app
        self.impls = impls
        self.auth_provider = create_auth_provider(auth_config)

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            # First, handle authentication
            headers = dict(scope.get("headers", []))
            auth_header = headers.get(b"authorization", b"").decode()

@@ -121,15 +126,50 @@ class AuthenticationMiddleware:
                f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
            )

            # Scope-based API access control
            path = scope.get("path", "")
            method = scope.get("method", hdrs.METH_GET)

            if not hasattr(self, "route_impls"):
                self.route_impls = initialize_route_impls(self.impls)

            try:
                _, _, _, webmethod = find_matching_route(method, path, self.route_impls)
            except ValueError:
                # If no matching endpoint is found, pass through to FastAPI
                return await self.app(scope, receive, send)

            if webmethod.required_scope:
                user = user_from_scope(scope)
                if not _has_required_scope(webmethod.required_scope, user):
                    return await self._send_auth_error(
                        send,
                        f"Access denied: user does not have required scope: {webmethod.required_scope}",
                        status=403,
                    )

        return await self.app(scope, receive, send)

    async def _send_auth_error(self, send, message):
    async def _send_auth_error(self, send, message, status=401):
        await send(
            {
                "type": "http.response.start",
                "status": 401,
                "status": status,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        error_msg = json.dumps({"error": {"message": message}}).encode()
        error_key = "message" if status == 401 else "detail"
        error_msg = json.dumps({"error": {error_key: message}}).encode()
        await send({"type": "http.response.body", "body": error_msg})


def _has_required_scope(required_scope: str, user: User | None) -> bool:
    # if no user, assume auth is not enabled
    if not user:
        return True

    if not user.attributes:
        return False

    user_scopes = user.attributes.get("scopes", [])
    return required_scope in user_scopes
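
The helper treats a missing user as "auth disabled" and otherwise requires the scope to appear in the user's `scopes` attribute. A small sketch of the three cases; the `User(principal=..., attributes=...)` construction mirrors the keyword usage elsewhere in this diff, while the attribute values are made up for illustration:

```python
anonymous = None
no_scopes = User(principal="alice", attributes={"roles": ["dev"]})
admin = User(principal="bob", attributes={"scopes": ["models:write", "telemetry:read"]})

assert _has_required_scope("models:write", anonymous) is True   # auth not enabled
assert _has_required_scope("models:write", no_scopes) is False  # authenticated, but no scopes
assert _has_required_scope("models:write", admin) is True       # required scope present
```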

@@ -12,17 +12,18 @@ from typing import Any
from aiohttp import hdrs
from starlette.routing import Route

from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api
from llama_stack.schema_utils import WebMethod

EndpointFunc = Callable[..., Any]
PathParams = dict[str, str]
RouteInfo = tuple[EndpointFunc, str]
RouteInfo = tuple[EndpointFunc, str, WebMethod]
PathImpl = dict[str, RouteInfo]
RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str]
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]


def toolgroup_protocol_map():

@@ -31,10 +32,12 @@ def toolgroup_protocol_map():
    }


def get_all_api_routes() -> dict[Api, list[Route]]:
def get_all_api_routes(
    external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
    apis = {}

    protocols = api_protocol_map()
    protocols = api_protocol_map(external_apis)
    toolgroup_protocols = toolgroup_protocol_map()
    for api, protocol in protocols.items():
        routes = []

@@ -65,7 +68,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
            else:
                http_method = hdrs.METH_POST
            routes.append(
                Route(path=path, methods=[http_method], name=name, endpoint=None)
                (Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
            )  # setting endpoint to None since don't use a Router object

        apis[api] = routes

@@ -73,8 +76,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
    return apis


def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
    routes = get_all_api_routes()
def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
    api_to_routes = get_all_api_routes(external_apis)
    route_impls: RouteImpls = {}

    def _convert_path_to_regex(path: str) -> str:

@@ -88,10 +91,10 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:

        return f"^{pattern}$"

    for api, api_routes in routes.items():
    for api, api_routes in api_to_routes.items():
        if api not in impls:
            continue
        for route in api_routes:
        for route, webmethod in api_routes:
            impl = impls[api]
            func = getattr(impl, route.name)
            # Get the first (and typically only) method from the set, filtering out HEAD

@@ -104,6 +107,7 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
            route_impls[method][_convert_path_to_regex(route.path)] = (
                func,
                route.path,
                webmethod,
            )

    return route_impls

@@ -118,7 +122,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
        route_impls: A dictionary of endpoint implementations

    Returns:
        A tuple of (endpoint_function, path_params, descriptive_name)
        A tuple of (endpoint_function, path_params, route_path, webmethod_metadata)

    Raises:
        ValueError: If no matching endpoint is found

@@ -127,11 +131,11 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
    if not impls:
        raise ValueError(f"No endpoint found for {path}")

    for regex, (func, descriptive_name) in impls.items():
    for regex, (func, route_path, webmethod) in impls.items():
        match = re.match(regex, path)
        if match:
            # Extract named groups from the regex match
            path_params = match.groupdict()
            return func, path_params, descriptive_name
            return func, path_params, route_path, webmethod

    raise ValueError(f"No endpoint found for {path}")
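
With the WebMethod threaded through, a matched route now carries enough metadata for both tracing and scope checks. A small sketch of the lookup flow; `impls` stands in for whatever dict[Api, impl] the server built, and the request path is hypothetical:

```python
route_impls = initialize_route_impls(impls)
func, path_params, route_path, webmethod = find_matching_route(
    "GET", "/v1/models/my-model", route_impls
)

# route_path keeps the path template for tracing, while webmethod carries
# per-endpoint metadata such as required_scope and descriptive_name.
print(route_path, getattr(webmethod, "required_scope", None))
```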

@@ -32,6 +32,7 @@ from openai import BadRequestError
from pydantic import BaseModel, ValidationError

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_template_args, get_config_from_args
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import (
    AuthenticationRequiredError,

@@ -39,7 +40,12 @@ from llama_stack.distribution.datatypes import (
    StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
from llama_stack.distribution.request_headers import (
    PROVIDER_DATA_VAR,
    request_provider_data_context,
    user_from_scope,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.routes import (
    find_matching_route,

@@ -50,9 +56,11 @@ from llama_stack.distribution.stack import (
    cast_image_name_to_string,
    construct_stack,
    replace_env_vars,
    shutdown_stack,
    validate_env_pair,
)
from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api

@@ -144,18 +152,7 @@ async def shutdown(app):
    Handled by the lifespan context manager. The shutdown process involves
    shutting down all implementations registered in the application.
    """
    for impl in app.__llama_stack_impls__.values():
        impl_name = impl.__class__.__name__
        logger.info("Shutting down %s", impl_name)
        try:
            if hasattr(impl, "shutdown"):
                await asyncio.wait_for(impl.shutdown(), timeout=5)
            else:
                logger.warning("No shutdown method for %s", impl_name)
        except TimeoutError:
            logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
        except (Exception, asyncio.CancelledError) as e:
            logger.exception("Failed to shutdown %s: %s", impl_name, {e})
    await shutdown_stack(app.__llama_stack_impls__)


@asynccontextmanager

@@ -220,9 +217,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
    @functools.wraps(func)
    async def route_handler(request: Request, **kwargs):
        # Get auth attributes from the request scope
        user_attributes = request.scope.get("user_attributes", {})
        principal = request.scope.get("principal", "")
        user = User(principal=principal, attributes=user_attributes)
        user = user_from_scope(request.scope)

        await log_request_pre_validation(request)

@@ -280,9 +275,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:


class TracingMiddleware:
    def __init__(self, app, impls):
    def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
        self.app = app
        self.impls = impls
        self.external_apis = external_apis
        # FastAPI built-in paths that should bypass custom routing
        self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")

@@ -299,10 +295,12 @@ class TracingMiddleware:
            return await self.app(scope, receive, send)

        if not hasattr(self, "route_impls"):
            self.route_impls = initialize_route_impls(self.impls)
            self.route_impls = initialize_route_impls(self.impls, self.external_apis)

        try:
            _, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
            _, _, route_path, webmethod = find_matching_route(
                scope.get("method", hdrs.METH_GET), path, self.route_impls
            )
        except ValueError:
            # If no matching endpoint is found, pass through to FastAPI
            logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")

@@ -319,6 +317,7 @@ class TracingMiddleware:
        if tracestate:
            trace_attributes["tracestate"] = tracestate

        trace_path = webmethod.descriptive_name or route_path
        trace_context = await start_trace(trace_path, trace_attributes)

        async def send_with_trace_id(message):

@@ -377,20 +376,8 @@ class ClientVersionMiddleware:
def main(args: argparse.Namespace | None = None):
    """Start the LlamaStack server."""
    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
    parser.add_argument(
        "--yaml-config",
        dest="config",
        help="(Deprecated) Path to YAML configuration file - use --config instead",
    )
    parser.add_argument(
        "--config",
        dest="config",
        help="Path to YAML configuration file",
    )
    parser.add_argument(
        "--template",
        help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
    )

    add_config_template_args(parser)
    parser.add_argument(
        "--port",
        type=int,

@@ -409,20 +396,8 @@ def main(args: argparse.Namespace | None = None):
    if args is None:
        args = parser.parse_args()

    log_line = ""
    if hasattr(args, "config") and args.config:
        # if the user provided a config file, use it, even if template was specified
        config_file = Path(args.config)
        if not config_file.exists():
            raise ValueError(f"Config file {config_file} does not exist")
        log_line = f"Using config file: {config_file}"
    elif hasattr(args, "template") and args.template:
        config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
        if not config_file.exists():
            raise ValueError(f"Template {args.template} does not exist")
        log_line = f"Using template {args.template} config file: {config_file}"
    else:
        raise ValueError("Either --config or --template must be provided")
    config_or_template = get_config_from_args(args)
    config_file = resolve_config_or_template(config_or_template, Mode.RUN)

    logger_config = None
    with open(config_file) as fp:

@@ -442,9 +417,6 @@ def main(args: argparse.Namespace | None = None):
    config = replace_env_vars(config_contents)
    config = StackRunConfig(**cast_image_name_to_string(config))

    # now that the logger is initialized, print the line about which type of config we are using.
    logger.info(log_line)

    _log_run_config(run_config=config)

    app = FastAPI(

@@ -457,10 +429,21 @@ def main(args: argparse.Namespace | None = None):
    if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
        app.add_middleware(ClientVersionMiddleware)

    # Add authentication middleware if configured
    try:
        # Create and set the event loop that will be used for both construction and server runtime
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Construct the stack in the persistent event loop
        impls = loop.run_until_complete(construct_stack(config))

    except InvalidProviderError as e:
        logger.error(f"Error: {str(e)}")
        sys.exit(1)

    if config.server.auth:
        logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
        app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
        app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
    else:
        if config.server.quota:
            quota = config.server.quota

@@ -491,24 +474,14 @@ def main(args: argparse.Namespace | None = None):
                window_seconds=window_seconds,
            )

    try:
        # Create and set the event loop that will be used for both construction and server runtime
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Construct the stack in the persistent event loop
        impls = loop.run_until_complete(construct_stack(config))

    except InvalidProviderError as e:
        logger.error(f"Error: {str(e)}")
        sys.exit(1)

    if Api.telemetry in impls:
        setup_logger(impls[Api.telemetry])
    else:
        setup_logger(TelemetryAdapter(TelemetryConfig(), {}))

    all_routes = get_all_api_routes()
    # Load external APIs if configured
    external_apis = load_external_apis(config)
    all_routes = get_all_api_routes(external_apis)

    if config.apis:
        apis_to_serve = set(config.apis)

@@ -527,9 +500,12 @@ def main(args: argparse.Namespace | None = None):
        api = Api(api_str)

        routes = all_routes[api]
        impl = impls[api]
        try:
            impl = impls[api]
        except KeyError as e:
            raise ValueError(f"Could not find provider implementation for {api} API") from e

        for route in routes:
        for route, _ in routes:
            if not hasattr(impl, route.name):
                # ideally this should be a typing violation already
                raise ValueError(f"Could not find method {route.name} on {impl}!")

@@ -558,7 +534,7 @@ def main(args: argparse.Namespace | None = None):
    app.exception_handler(Exception)(global_exception_handler)

    app.__llama_stack_impls__ = impls
    app.add_middleware(TracingMiddleware, impls=impls)
    app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)

    import uvicorn

@@ -592,12 +568,29 @@ def main(args: argparse.Namespace | None = None):
        "port": port,
        "lifespan": "on",
        "log_level": logger.getEffectiveLevel(),
        "log_config": logger_config,
    }
    if ssl_config:
        uvicorn_config.update(ssl_config)

    # Run uvicorn in the existing event loop to preserve background tasks
    loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
    # We need to catch KeyboardInterrupt because uvicorn's signal handling
    # re-raises SIGINT signals using signal.raise_signal(), which Python
    # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
    # stack trace when using Ctrl+C or kill -2 (SIGINT).
    # SIGTERM (kill -15) works fine without this because Python doesn't
    # have a default handler for it.
    #
    # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
    # signal handling but this is quite intrusive and not worth the effort.
    try:
        loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
    except (KeyboardInterrupt, SystemExit):
        logger.info("Received interrupt signal, shutting down gracefully...")
    finally:
        if not loop.is_closed():
            logger.debug("Closing event loop")
            loop.close()


def _log_run_config(run_config: StackRunConfig):

@@ -618,11 +611,8 @@ def extract_path_params(route: str) -> list[str]:

def remove_disabled_providers(obj):
    if isinstance(obj, dict):
        if (
            obj.get("provider_id") == "__disabled__"
            or obj.get("shield_id") == "__disabled__"
            or obj.get("provider_model_id") == "__disabled__"
        ):
        keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
        if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
            return None
        return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
    elif isinstance(obj, list):
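
The rewritten check now also treats empty and null ids as disabled, and covers `model_id`. A quick sketch of the effect on a nested run-config fragment; the dict shape is illustrative, and it assumes the list branch (cut off in the hunk above) drops entries that collapse to None:

```python
cfg = {
    "models": [
        {"model_id": "llama3", "provider_id": "ollama"},
        {"model_id": "gpt-x", "provider_id": "__disabled__"},   # dropped: sentinel value
        {"model_id": "embed", "provider_id": ""},               # dropped: now empty counts too
    ]
}

cleaned = remove_disabled_providers(cfg)
# expected -> {"models": [{"model_id": "llama3", "provider_id": "ollama"}]}
```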

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import importlib.resources
import os
import re

@@ -38,6 +39,7 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger

@@ -90,6 +92,10 @@ RESOURCES = [
]


REGISTRY_REFRESH_INTERVAL_SECONDS = 300
REGISTRY_REFRESH_TASK = None


async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
    for rsrc, api, register_method, list_method in RESOURCES:
        objects = getattr(run_config, rsrc)

@@ -99,23 +105,10 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
        method = getattr(impls[api], register_method)
        for obj in objects:
            logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
            # Do not register models on disabled providers
            if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
                continue
            # In complex templates, like our starter template, we may have dynamic model ids
            # given by environment variables. This allows those environment variables to have
            # a default value of __disabled__ to skip registration of the model if not set.
            if (
                hasattr(obj, "provider_model_id")
                and obj.provider_model_id is not None
                and "__disabled__" in obj.provider_model_id
            ):
                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
                continue

            if hasattr(obj, "shield_id") and obj.shield_id is not None and obj.shield_id == "__disabled__":
                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
            # Do not register models on disabled providers
            if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
                continue

            # we want to maintain the type information in arguments to method.

@@ -324,9 +317,61 @@ async def construct_stack(
    add_internal_implementations(impls, run_config)

    await register_resources(run_config, impls)

    await refresh_registry_once(impls)

    global REGISTRY_REFRESH_TASK
    REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))

    def cb(task):
        import traceback

        if task.cancelled():
            logger.error("Model refresh task cancelled")
        elif task.exception():
            logger.error(f"Model refresh task failed: {task.exception()}")
            traceback.print_exception(task.exception())
        else:
            logger.debug("Model refresh task completed")

    REGISTRY_REFRESH_TASK.add_done_callback(cb)
    return impls


async def shutdown_stack(impls: dict[Api, Any]):
    for impl in impls.values():
        impl_name = impl.__class__.__name__
        logger.info(f"Shutting down {impl_name}")
        try:
            if hasattr(impl, "shutdown"):
                await asyncio.wait_for(impl.shutdown(), timeout=5)
            else:
                logger.warning(f"No shutdown method for {impl_name}")
        except TimeoutError:
            logger.exception(f"Shutdown timeout for {impl_name}")
        except (Exception, asyncio.CancelledError) as e:
            logger.exception(f"Failed to shutdown {impl_name}: {e}")

    global REGISTRY_REFRESH_TASK
    if REGISTRY_REFRESH_TASK:
        REGISTRY_REFRESH_TASK.cancel()


async def refresh_registry_once(impls: dict[Api, Any]):
    logger.debug("refreshing registry")
    routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
    for routing_table in routing_tables:
        await routing_table.refresh()


async def refresh_registry_task(impls: dict[Api, Any]):
    logger.info("starting registry refresh task")
    while True:
        await refresh_registry_once(impls)

        await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)


def get_stack_run_config_from_template(template: str) -> StackRunConfig:
    template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
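
The stack now refreshes provider-listed models once at startup and then every REGISTRY_REFRESH_INTERVAL_SECONDS in a background task that shutdown_stack cancels. A minimal test-style sketch of driving one refresh pass; FakeTable is a made-up stand-in that only needs to look like a CommonRoutingTableImpl for refresh_registry_once to pick it up:

```python
class FakeTable(CommonRoutingTableImpl):
    def __init__(self):  # intentionally skips the real constructor for the sketch
        self.refreshed = 0

    async def refresh(self) -> None:
        self.refreshed += 1


async def demo() -> None:
    table = FakeTable()
    impls = {"fake_api": table}          # keys are normally Api members
    await refresh_registry_once(impls)   # one synchronous pass over all routing tables
    assert table.refreshed == 1
```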

@@ -117,7 +117,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
  set -x

  if [ -n "$yaml_config" ]; then
    yaml_config_arg="--config $yaml_config"
    yaml_config_arg="$yaml_config"
  else
    yaml_config_arg=""
  fi

llama_stack/distribution/utils/config_resolution.py (new file, 125 lines)

@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import StrEnum
from pathlib import Path

from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="config_resolution")


TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates"


class Mode(StrEnum):
    RUN = "run"
    BUILD = "build"


def resolve_config_or_template(
    config_or_template: str,
    mode: Mode = Mode.RUN,
) -> Path:
    """
    Resolve a config/template argument to a concrete config file path.

    Args:
        config_or_template: User input (file path, template name, or built distribution)
        mode: Mode resolving for ("run", "build", "server")

    Returns:
        Path to the resolved config file

    Raises:
        ValueError: If resolution fails
    """

    # Strategy 1: Try as file path first
    config_path = Path(config_or_template)
    if config_path.exists() and config_path.is_file():
        logger.info(f"Using file path: {config_path}")
        return config_path.resolve()

    # Strategy 2: Try as template name (if no .yaml extension)
    if not config_or_template.endswith(".yaml"):
        template_config = _get_template_config_path(config_or_template, mode)
        if template_config.exists():
            logger.info(f"Using template: {template_config}")
            return template_config

    # Strategy 3: Try as built distribution name
    distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
    if distrib_config.exists():
        logger.info(f"Using built distribution: {distrib_config}")
        return distrib_config

    distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
    if distrib_config.exists():
        logger.info(f"Using built distribution: {distrib_config}")
        return distrib_config

    # Strategy 4: Failed - provide helpful error
    raise ValueError(_format_resolution_error(config_or_template, mode))


def _get_template_config_path(template_name: str, mode: Mode) -> Path:
    """Get the config file path for a template."""
    return TEMPLATE_DIR / template_name / f"{mode}.yaml"


def _format_resolution_error(config_or_template: str, mode: Mode) -> str:
    """Format a helpful error message for resolution failures."""
    from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR

    template_path = _get_template_config_path(config_or_template, mode)
    distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
    distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"

    available_templates = _get_available_templates()
    templates_str = ", ".join(available_templates) if available_templates else "none found"

    return f"""Could not resolve config or template '{config_or_template}'.

Tried the following locations:
  1. As file path: {Path(config_or_template).resolve()}
  2. As template: {template_path}
  3. As built distribution: ({distrib_path}, {distrib_path2})

Available templates: {templates_str}

Did you mean one of these templates?
{_format_template_suggestions(available_templates, config_or_template)}
"""


def _get_available_templates() -> list[str]:
    """Get list of available template names."""
    if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists():
        return []

    return list(
        set(
            [d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
            + [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
        )
    )


def _format_template_suggestions(templates: list[str], user_input: str) -> str:
    """Format template suggestions for error messages, showing closest matches first."""
    if not templates:
        return "  (no templates found)"

    import difflib

    # Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive)
    close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3)
    display_templates = close_matches if close_matches else templates[:3]

    suggestions = [f"  - {t}" for t in display_templates]
    return "\n".join(suggestions)

@@ -21,7 +21,7 @@ from pathlib import Path
from llama_stack.distribution.utils.image_types import LlamaStackImageType


def formulate_run_args(image_type, image_name, config, template_name) -> list:
def formulate_run_args(image_type: str, image_name: str) -> list[str]:
    env_name = ""

    if image_type == LlamaStackImageType.CONDA.value:

@@ -6,6 +6,7 @@

import logging
import os
import re
import sys
from logging.config import dictConfig

@@ -30,6 +31,7 @@ CATEGORIES = [
    "eval",
    "tools",
    "client",
    "telemetry",
]

# Initialize category levels with default level

@@ -113,6 +115,11 @@ def parse_environment_config(env_config: str) -> dict[str, int]:
    return category_levels


def strip_rich_markup(text):
    """Remove Rich markup tags like [dim], [bold magenta], etc."""
    return re.sub(r"\[/?[a-zA-Z0-9 _#=,]+\]", "", text)


class CustomRichHandler(RichHandler):
    def __init__(self, *args, **kwargs):
        kwargs["console"] = Console(width=150)

@@ -131,6 +138,19 @@ class CustomRichHandler(RichHandler):
            self.markup = original_markup


class CustomFileHandler(logging.FileHandler):
    def __init__(self, filename, mode="a", encoding=None, delay=False):
        super().__init__(filename, mode, encoding, delay)
        # Default formatter to match console output
        self.default_formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(category)s: %(message)s")
        self.setFormatter(self.default_formatter)

    def emit(self, record):
        if hasattr(record, "msg"):
            record.msg = strip_rich_markup(str(record.msg))
        super().emit(record)


def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None:
    """
    Configure logging based on the provided category log levels and an optional log file.

@@ -167,8 +187,7 @@ def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None:
    # Add a file handler if log_file is set
    if log_file:
        handlers["file"] = {
            "class": "logging.FileHandler",
            "formatter": "rich",
            "()": CustomFileHandler,
            "filename": log_file,
            "mode": "a",
            "encoding": "utf-8",

@@ -43,10 +43,24 @@ class ModelsProtocolPrivate(Protocol):
       -> Provider uses provider-model-id for inference
    """

    # this should be called `on_model_register` or something like that.
    # the provider should _not_ be able to change the object in this
    # callback
    async def register_model(self, model: Model) -> Model: ...

    async def unregister_model(self, model_id: str) -> None: ...

    # the Stack router will query each provider for their list of models
    # if a `refresh_interval_seconds` is provided, this method will be called
    # periodically to refresh the list of models
    #
    # NOTE: each model returned will be registered with the model registry. this means
    # a callback to the `register_model()` method will be made. this is duplicative and
    # may be removed in the future.
    async def list_models(self) -> list[Model] | None: ...

    async def should_refresh_models(self) -> bool: ...


class ShieldsProtocolPrivate(Protocol):
    async def register_shield(self, shield: Shield) -> None: ...

@@ -104,6 +118,19 @@ class ProviderSpec(BaseModel):
        description="If this provider is deprecated and does NOT work, specify the error message here",
    )

    module: str | None = Field(
        default=None,
        description="""
Fully-qualified name of the module to import. The module is expected to have:

 - `get_adapter_impl(config, deps)`: returns the adapter implementation

Example: `module: ramalama_stack`
""",
    )

    is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")

    # used internally by the resolver; this is a hack for now
    deps__: list[str] = Field(default_factory=list)

@@ -113,7 +140,7 @@ class ProviderSpec(BaseModel):


class RoutingTable(Protocol):
    def get_provider_impl(self, routing_key: str) -> Any: ...
    async def get_provider_impl(self, routing_key: str) -> Any: ...


# TODO: this can now be inlined into RemoteProviderSpec

@@ -124,7 +151,7 @@ class AdapterSpec(BaseModel):
        description="Unique identifier for this adapter",
    )
    module: str = Field(
        ...,
        default_factory=str,
        description="""
Fully-qualified name of the module to import. The module is expected to have:

@@ -162,14 +189,7 @@ The container image to use for this implementation. If one is provided, pip_pack
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""",
    )
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:

 - `get_provider_impl(config, deps)`: returns the local implementation
""",
    )
    # module field is inherited from ProviderSpec
    provider_data_validator: str | None = Field(
        default=None,
    )

@@ -212,9 +232,7 @@ API responses, specify the adapter here.
    def container_image(self) -> str | None:
        return None

    @property
    def module(self) -> str:
        return self.adapter.module
    # module field is inherited from ProviderSpec

    @property
    def pip_packages(self) -> list[str]:

@@ -226,14 +244,19 @@ API responses, specify the adapter here.


def remote_provider_spec(
    api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None
    api: Api,
    adapter: AdapterSpec,
    api_dependencies: list[Api] | None = None,
    optional_api_dependencies: list[Api] | None = None,
) -> RemoteProviderSpec:
    return RemoteProviderSpec(
        api=api,
        provider_type=f"remote::{adapter.adapter_type}",
        config_class=adapter.config_class,
        module=adapter.module,
        adapter=adapter,
        api_dependencies=api_dependencies or [],
        optional_api_dependencies=optional_api_dependencies or [],
    )
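
The ModelsProtocolPrivate hunk at the top of this file is what the registry refresh loop relies on. A minimal provider-side sketch; the class and provider id are illustrative, not part of this change:

```python
class MyInferenceImpl:
    __provider_id__: str = "my-provider"   # hypothetical

    async def register_model(self, model: Model) -> Model:
        return model

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def should_refresh_models(self) -> bool:
        # return True only for providers whose model list can change at runtime
        return False

    async def list_models(self) -> list[Model] | None:
        # None means "nothing to report"; returned models get registered as
        # RegistryEntrySource.listed_from_provider entries by the routing table
        return None
```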

@@ -10,6 +10,7 @@ import re
import secrets
import string
import uuid
import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime

@@ -911,8 +912,16 @@ async def load_data_from_url(url: str) -> str:


async def get_raw_document_text(document: Document) -> str:
    if not document.mime_type.startswith("text/"):
    # Handle deprecated text/yaml mime type with warning
    if document.mime_type == "text/yaml":
        warnings.warn(
            "The 'text/yaml' MIME type is deprecated. Please use 'application/yaml' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
        raise ValueError(f"Unexpected document mime type: {document.mime_type}")

    if isinstance(document.content, URL):
        return await load_data_from_url(document.content.uri)
    elif isinstance(document.content, str):
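
With this change, YAML documents are accepted under both MIME types; the legacy one only emits a deprecation warning. A quick sketch inside an async helper; the Document construction mirrors the attributes used above (content, mime_type), but the exact constructor signature is an assumption:

```python
import warnings


async def _demo() -> None:
    # application/yaml: accepted, no warning
    await get_raw_document_text(Document(content="key: value", mime_type="application/yaml"))

    # text/yaml: still works, but warns
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        await get_raw_document_text(Document(content="key: value", mime_type="text/yaml"))
        assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```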

@@ -128,6 +128,11 @@ class AgentPersistence:
            except Exception as e:
                log.error(f"Error parsing turn: {e}")
                continue

        # The kvstore does not guarantee order, so we sort by started_at
        # to ensure consistent ordering of turns.
        turns.sort(key=lambda t: t.started_at)

        return turns

    async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:

@@ -6,7 +6,7 @@

from typing import Any

from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.datatypes import AccessRule, Api

from .config import LocalfsFilesImplConfig
from .files import LocalfsFilesImpl

@@ -14,7 +14,7 @@ from .files import LocalfsFilesImpl
__all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"]


async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any]):
    impl = LocalfsFilesImpl(config)
async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
    impl = LocalfsFilesImpl(config, policy)
    await impl.initialize()
    return impl

@@ -19,16 +19,19 @@ from llama_stack.apis.files import (
    OpenAIFileObject,
    OpenAIFilePurpose,
)
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl

from .config import LocalfsFilesImplConfig


class LocalfsFilesImpl(Files):
    def __init__(self, config: LocalfsFilesImplConfig) -> None:
    def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
        self.config = config
        self.sql_store: SqlStore | None = None
        self.policy = policy
        self.sql_store: AuthorizedSqlStore | None = None

    async def initialize(self) -> None:
        """Initialize the files provider by setting up storage directory and metadata database."""

@@ -37,7 +40,7 @@ class LocalfsFilesImpl(Files):
        storage_path.mkdir(parents=True, exist_ok=True)

        # Initialize SQL store for metadata
        self.sql_store = sqlstore_impl(self.config.metadata_store)
        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
        await self.sql_store.create_table(
            "openai_files",
            {

@@ -126,6 +129,7 @@ class LocalfsFilesImpl(Files):

        paginated_result = await self.sql_store.fetch_all(
            table="openai_files",
            policy=self.policy,
            where=where_conditions if where_conditions else None,
            order_by=[("created_at", order.value)],
            cursor=("id", after) if after else None,

@@ -156,7 +160,7 @@ class LocalfsFilesImpl(Files):
        if not self.sql_store:
            raise RuntimeError("Files provider not initialized")

        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
        if not row:
            raise ValueError(f"File with id {file_id} not found")

@@ -174,7 +178,7 @@ class LocalfsFilesImpl(Files):
        if not self.sql_store:
            raise RuntimeError("Files provider not initialized")

        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
        if not row:
            raise ValueError(f"File with id {file_id} not found")

@@ -197,7 +201,7 @@ class LocalfsFilesImpl(Files):
            raise RuntimeError("Files provider not initialized")

        # Get file metadata
        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
        if not row:
            raise ValueError(f"File with id {file_id} not found")

@@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
        if self.config.create_distributed_process_group:
            self.generator.stop()

    async def should_refresh_models(self) -> bool:
        return False

    async def list_models(self) -> list[Model] | None:
        return None

    async def unregister_model(self, model_id: str) -> None:
        pass

@@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
    SentenceTransformerEmbeddingMixin,

@@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl(
    InferenceProvider,
    ModelsProtocolPrivate,
):
    __provider_id__: str

    def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
        self.config = config

@@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl(
    async def shutdown(self) -> None:
        pass

    async def should_refresh_models(self) -> bool:
        return False

    async def list_models(self) -> list[Model] | None:
        return [
            Model(
                identifier="all-MiniLM-L6-v2",
                provider_resource_id="all-MiniLM-L6-v2",
                provider_id=self.__provider_id__,
                metadata={
                    "embedding_dimension": 384,
                },
                model_type=ModelType.embedding,
            ),
        ]

    async def register_model(self, model: Model) -> Model:
        return model
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]):
|
||||
from .vllm import VLLMInferenceImpl
|
||||
|
||||
impl = VLLMInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMConfig(BaseModel):
|
||||
"""Configuration for the vLLM inference provider.
|
||||
|
||||
Note that the model name is no longer part of this static configuration.
|
||||
You can bind an instance of this provider to a specific model with the
|
||||
``models.register()`` API call."""
|
||||
|
||||
tensor_parallel_size: int = Field(
|
||||
default=1,
|
||||
description="Number of tensor parallel replicas (number of GPUs to use).",
|
||||
)
|
||||
max_tokens: int = Field(
|
||||
default=4096,
|
||||
description="Maximum number of tokens to generate.",
|
||||
)
|
||||
max_model_len: int = Field(default=4096, description="Maximum context length to use during serving.")
|
||||
max_num_seqs: int = Field(default=4, description="Maximum parallel batch size for generation.")
|
||||
enforce_eager: bool = Field(
|
||||
default=False,
|
||||
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
|
||||
)
|
||||
gpu_memory_utilization: float = Field(
|
||||
default=0.3,
|
||||
description=(
|
||||
"How much GPU memory will be allocated when this provider has finished "
|
||||
"loading, including memory that was already allocated before loading."
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:=1}",
|
||||
"max_tokens": "${env.MAX_TOKENS:=4096}",
|
||||
"max_model_len": "${env.MAX_MODEL_LEN:=4096}",
|
||||
"max_num_seqs": "${env.MAX_NUM_SEQS:=4}",
|
||||
"enforce_eager": "${env.ENFORCE_EAGER:=False}",
|
||||
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:=0.3}",
|
||||
}
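# Editor's note: a minimal, hedged sketch (not part of this patch, and not the stack's
# actual resolver) of how "${env.VAR:=default}" placeholders like those above could be
# expanded. The helper name and regex are illustrative assumptions.
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.(?P<name>[A-Za-z0-9_]+):=(?P<default>[^}]*)\}")


def resolve_env_placeholders(value: str) -> str:
    """Replace each ${env.NAME:=default} with the environment value or its default."""

    def _sub(match: re.Match) -> str:
        return os.environ.get(match.group("name"), match.group("default"))

    return _ENV_PATTERN.sub(_sub, value)


# Example: resolve_env_placeholders("${env.MAX_TOKENS:=4096}") returns "4096" unless
# MAX_TOKENS is set in the environment.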
|
||||
|
|
@ -1,170 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import vllm
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
GrammarResponseFormat,
|
||||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
)
|
||||
|
||||
###############################################################################
|
||||
# This file contains OpenAI compatibility code that is currently only used
|
||||
# by the inline vLLM connector. Some or all of this code may be moved to a
|
||||
# central location at a later date.
|
||||
|
||||
|
||||
def _merge_context_into_content(message: Message) -> Message: # type: ignore
|
||||
"""
|
||||
Merge the ``context`` field of a Llama Stack ``Message`` object into
|
||||
the content field for compatibility with OpenAI-style APIs.
|
||||
|
||||
Generates a content string that emulates the current behavior
|
||||
of ``llama_models.llama3.api.chat_format.encode_message()``.
|
||||
|
||||
:param message: Message that may include ``context`` field
|
||||
|
||||
:returns: A version of ``message`` with any context merged into the
|
||||
``content`` field.
|
||||
"""
|
||||
if not isinstance(message, UserMessage): # Separate type check for linter
|
||||
return message
|
||||
if message.context is None:
|
||||
return message
|
||||
return UserMessage(
|
||||
role=message.role,
|
||||
# Emulate llama_models.llama3.api.chat_format.encode_message()
|
||||
content=message.content + "\n\n" + message.context,
|
||||
context=None,
|
||||
)
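# Editor's note: hedged usage sketch of the merge described above, assuming UserMessage
# exposes `content` and `context` exactly as this module uses them:
#
#   msg = UserMessage(role="user", content="What is RAG?", context="Retrieved doc text")
#   merged = _merge_context_into_content(msg)
#   # merged.content == "What is RAG?\n\nRetrieved doc text" and merged.context is None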
|
||||
|
||||
|
||||
def _llama_stack_tools_to_openai_tools(
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
) -> list[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
|
||||
"""
|
||||
Convert the list of available tools from Llama Stack's format to vLLM's
|
||||
version of OpenAI's format.
|
||||
"""
|
||||
if tools is None:
|
||||
return []
|
||||
|
||||
result = []
|
||||
for t in tools:
|
||||
if isinstance(t.tool_name, BuiltinTool):
|
||||
raise NotImplementedError("Built-in tools not yet implemented")
|
||||
if t.parameters is None:
|
||||
parameters = None
|
||||
else: # if t.parameters is not None
|
||||
# Convert the "required" flags to a list of required params
|
||||
required_params = [k for k, v in t.parameters.items() if v.required]
|
||||
parameters = {
|
||||
"type": "object", # Mystery value that shows up in OpenAI docs
|
||||
"properties": {
|
||||
k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
|
||||
},
|
||||
"required": required_params,
|
||||
}
|
||||
|
||||
function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
|
||||
name=t.tool_name, description=t.description, parameters=parameters
|
||||
)
|
||||
|
||||
# Every tool definition is double-boxed in a ChatCompletionToolsParam
|
||||
result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
|
||||
return result
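# Editor's note: an illustrative assumption of the output shape, not captured from a real
# run. A hypothetical tool "get_weather" with one required string parameter "city" would
# be converted to roughly this OpenAI-style structure (double-boxed as noted above):
#
#   {
#       "type": "function",
#       "function": {
#           "name": "get_weather",
#           "description": "Look up the weather",
#           "parameters": {
#               "type": "object",
#               "properties": {"city": {"type": "string", "description": "City name"}},
#               "required": ["city"],
#           },
#       },
#   }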
|
||||
|
||||
|
||||
async def llama_stack_chat_completion_to_openai_chat_completion_dict(
|
||||
request: ChatCompletionRequest,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a chat completion request in Llama Stack format into an
|
||||
equivalent set of arguments to pass to an OpenAI-compatible
|
||||
chat completions API.
|
||||
|
||||
:param request: Bundled request parameters in Llama Stack format.
|
||||
|
||||
:returns: Dictionary of key-value pairs to use as an initializer
|
||||
for a dataclass or to be converted directly to JSON and sent
|
||||
over the wire.
|
||||
"""
|
||||
|
||||
converted_messages = [
|
||||
# convert_message_to_openai_dict() is async, which is why this function must be async too
|
||||
await convert_message_to_openai_dict(_merge_context_into_content(m), download=True)
|
||||
for m in request.messages
|
||||
]
|
||||
converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
|
||||
|
||||
# Llama will try to use built-in tools with no tool catalog, so don't enable
|
||||
# tool choice unless at least one tool is enabled.
|
||||
converted_tool_choice = "none"
|
||||
if (
|
||||
request.tool_config is not None
|
||||
and request.tool_config.tool_choice == ToolChoice.auto
|
||||
and request.tools is not None
|
||||
and len(request.tools) > 0
|
||||
):
|
||||
converted_tool_choice = "auto"
|
||||
|
||||
# TODO: Figure out what to do with the tool_prompt_format argument.
|
||||
# Other connectors appear to drop it quietly.
|
||||
|
||||
# Use Llama Stack shared code to translate sampling parameters.
|
||||
sampling_options = get_sampling_options(request.sampling_params)
|
||||
|
||||
# get_sampling_options() translates repetition penalties to an option that
|
||||
# OpenAI's APIs don't know about.
|
||||
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
|
||||
# For now, translate repetition penalties into a format that vLLM's broken
|
||||
# API will handle correctly. Two wrongs make a right...
|
||||
if "repeat_penalty" in sampling_options:
|
||||
del sampling_options["repeat_penalty"]
|
||||
if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0:
|
||||
sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty
|
||||
|
||||
# Convert a single response format into four different parameters, per
|
||||
# the OpenAI spec
|
||||
guided_decoding_options = dict()
|
||||
if request.response_format is None:
|
||||
# Use defaults
|
||||
pass
|
||||
elif isinstance(request.response_format, JsonSchemaResponseFormat):
|
||||
guided_decoding_options["guided_json"] = request.response_format.json_schema
|
||||
elif isinstance(request.response_format, GrammarResponseFormat):
|
||||
guided_decoding_options["guided_grammar"] = request.response_format.bnf
|
||||
else:
|
||||
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'")
|
||||
|
||||
logprob_options = dict()
|
||||
if request.logprobs is not None:
|
||||
logprob_options["logprobs"] = request.logprobs.top_k
|
||||
|
||||
# Marshall together all the arguments for a ChatCompletionRequest
|
||||
request_options = {
|
||||
"model": request.model,
|
||||
"messages": converted_messages,
|
||||
"tools": converted_tools,
|
||||
"tool_choice": converted_tool_choice,
|
||||
"stream": request.stream,
|
||||
**sampling_options,
|
||||
**guided_decoding_options,
|
||||
**logprob_options,
|
||||
}
|
||||
|
||||
return request_options
|
||||
|
|
@ -1,811 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
|
||||
# These vLLM modules contain names that overlap with Llama Stack names, so we import
|
||||
# fully-qualified names
|
||||
import vllm.entrypoints.openai.protocol
|
||||
import vllm.sampling_params
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
TokenLogProbs,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama import sku_list
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_stop_reason,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import VLLMConfig
|
||||
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
|
||||
|
||||
# Map from Hugging Face model architecture name to appropriate tool parser.
|
||||
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
|
||||
# available parsers.
|
||||
# TODO: Expand this list
|
||||
CONFIG_TYPE_TO_TOOL_PARSER = {
|
||||
"GraniteConfig": "granite",
|
||||
"MllamaConfig": "llama3_json",
|
||||
"LlamaConfig": "llama3_json",
|
||||
}
|
||||
DEFAULT_TOOL_PARSER = "pythonic"
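# Editor's note: the branch in register_model() below amounts to a dict lookup with a
# fallback; a hedged one-line equivalent would be:
#
#   tool_parser = CONFIG_TYPE_TO_TOOL_PARSER.get(hf_config_class_name, DEFAULT_TOOL_PARSER)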
|
||||
|
||||
|
||||
logger = get_logger(__name__, category="inference")
|
||||
|
||||
|
||||
def _random_uuid_str() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
def _response_format_to_guided_decoding_params(
|
||||
response_format: ResponseFormat | None, # type: ignore
|
||||
) -> vllm.sampling_params.GuidedDecodingParams | None:
|
||||
"""
|
||||
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
|
||||
|
||||
:param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
|
||||
indicating no constraints.
|
||||
:returns: The equivalent dataclass object for the low-level inference layer of vLLM.
|
||||
"""
|
||||
if response_format is None:
|
||||
# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
|
||||
# value that crashes the executor on some code paths. Use ``None`` instead.
|
||||
return None
|
||||
|
||||
# Llama Stack currently implements fewer types of constrained decoding than vLLM does.
|
||||
# Translate the types that exist and detect if Llama Stack adds new ones.
|
||||
if isinstance(response_format, JsonSchemaResponseFormat):
|
||||
return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
|
||||
elif isinstance(response_format, GrammarResponseFormat):
|
||||
# BNF grammar.
|
||||
# Llama Stack uses the parse tree of the grammar, while vLLM uses the string
|
||||
# representation of the grammar.
|
||||
raise TypeError(
|
||||
"Constrained decoding with BNF grammars is not currently implemented, because the "
|
||||
"reference implementation does not implement it."
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
|
||||
|
||||
|
||||
def _convert_sampling_params(
|
||||
sampling_params: SamplingParams | None,
|
||||
response_format: ResponseFormat | None, # type: ignore
|
||||
log_prob_config: LogProbConfig | None,
|
||||
) -> vllm.SamplingParams:
|
||||
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
|
||||
format."""
|
||||
# In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
|
||||
# Stack dataclasses. These defaults are different from vLLM's defaults.
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if log_prob_config is None:
|
||||
log_prob_config = LogProbConfig()
|
||||
|
||||
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
|
||||
if sampling_params.strategy.top_k == 0:
|
||||
# vLLM treats "k" differently for top-k sampling
|
||||
vllm_top_k = -1
|
||||
else:
|
||||
vllm_top_k = sampling_params.strategy.top_k
|
||||
else:
|
||||
vllm_top_k = -1
|
||||
|
||||
if isinstance(sampling_params.strategy, TopPSamplingStrategy):
|
||||
vllm_top_p = sampling_params.strategy.top_p
|
||||
# Llama Stack only allows temperature with top-P.
|
||||
vllm_temperature = sampling_params.strategy.temperature
|
||||
else:
|
||||
vllm_top_p = 1.0
|
||||
vllm_temperature = 0.0
|
||||
|
||||
# vLLM allows top-p and top-k at the same time.
|
||||
vllm_sampling_params = vllm.SamplingParams.from_optional(
|
||||
max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
|
||||
temperature=vllm_temperature,
|
||||
top_p=vllm_top_p,
|
||||
top_k=vllm_top_k,
|
||||
repetition_penalty=sampling_params.repetition_penalty,
|
||||
guided_decoding=_response_format_to_guided_decoding_params(response_format),
|
||||
logprobs=log_prob_config.top_k,
|
||||
)
|
||||
return vllm_sampling_params
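# Editor's note: hedged example of the conversion above, assuming TopPSamplingStrategy
# takes `temperature` and `top_p` as used elsewhere in this file:
#
#   vllm_params = _convert_sampling_params(
#       SamplingParams(strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9)),
#       response_format=None,
#       log_prob_config=None,
#   )
#   # -> vllm.SamplingParams with temperature=0.7, top_p=0.9, top_k=-1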
|
||||
|
||||
|
||||
class VLLMInferenceImpl(
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
"""
|
||||
vLLM-based inference model adapter for Llama Stack with support for multiple models.
|
||||
|
||||
Requires the configuration parameters documented in the :class:`VLLMConfig` class.
|
||||
"""
|
||||
|
||||
config: VLLMConfig
|
||||
register_helper: ModelRegistryHelper
|
||||
model_ids: set[str]
|
||||
resolved_model_id: str | None
|
||||
engine: AsyncLLMEngine | None
|
||||
chat: OpenAIServingChat | None
|
||||
is_meta_llama_model: bool
|
||||
|
||||
def __init__(self, config: VLLMConfig):
|
||||
self.config = config
|
||||
logger.info(f"Config is: {self.config}")
|
||||
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
# The following are initialized when paths are bound to this provider
|
||||
self.resolved_model_id = None
|
||||
self.model_ids = set()
|
||||
self.engine = None
|
||||
self.chat = None
|
||||
self.is_meta_llama_model = False
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM IMPLICIT BASE CLASS.
|
||||
# TODO: Make this class inherit from the new base class ProviderBase once that class exists.
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Callback that is invoked through many levels of indirection during provider class
|
||||
instantiation, sometime after __init__() is called and before any model registration
|
||||
methods or methods connected to a REST API are called.
|
||||
|
||||
It's not clear what assumptions the class can make about the platform's initialization
|
||||
state here that can't be made during __init__(), and vLLM can't be started until we know
|
||||
what model it's supposed to be serving, so nothing happens here currently.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.info(f"Shutting down inline vLLM inference provider {self}.")
|
||||
if self.engine is not None:
|
||||
self.engine.shutdown_background_loop()
|
||||
self.engine = None
|
||||
self.chat = None
|
||||
self.model_ids = set()
|
||||
self.resolved_model_id = None
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE
|
||||
|
||||
# Note that the return type of the superclass method is WRONG
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
"""
|
||||
Callback that is called when the server associates an inference endpoint with an
|
||||
inference provider.
|
||||
|
||||
:param model: Object that encapsulates parameters necessary for identifying a specific
|
||||
LLM.
|
||||
|
||||
:returns: The input ``Model`` object. It may or may not be permissible to change fields
|
||||
before returning this object.
|
||||
"""
|
||||
logger.debug(f"In register_model({model})")
|
||||
|
||||
# First attempt to interpret the model coordinates as a Llama model name
|
||||
resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
|
||||
if resolved_llama_model is not None:
|
||||
# Load from Hugging Face repo into default local cache dir
|
||||
model_id_for_vllm = resolved_llama_model.huggingface_repo
|
||||
|
||||
# Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
|
||||
# Don't set self.is_meta_llama_model until we actually load the model.
|
||||
is_meta_llama_model = True
|
||||
else: # if resolved_llama_model is None
|
||||
# Not a Llama model name. Pass the model id through to vLLM's loader
|
||||
model_id_for_vllm = model.provider_model_id
|
||||
is_meta_llama_model = False
|
||||
|
||||
if self.resolved_model_id is not None:
|
||||
if model_id_for_vllm != self.resolved_model_id:
|
||||
raise ValueError(
|
||||
f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
|
||||
f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple "
|
||||
f"copies of the provider instead."
|
||||
)
|
||||
else:
|
||||
# Model already loaded
|
||||
logger.info(
|
||||
f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
|
||||
)
|
||||
self.model_ids.add(model.model_id)
|
||||
return model
|
||||
|
||||
logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
|
||||
if is_meta_llama_model:
|
||||
logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")
|
||||
self.is_meta_llama_model = is_meta_llama_model
|
||||
|
||||
# If we get here, this is the first time registering a model.
|
||||
# Preload so that the first inference request won't time out.
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model_id_for_vllm,
|
||||
tokenizer=model_id_for_vllm,
|
||||
tensor_parallel_size=self.config.tensor_parallel_size,
|
||||
enforce_eager=self.config.enforce_eager,
|
||||
gpu_memory_utilization=self.config.gpu_memory_utilization,
|
||||
max_num_seqs=self.config.max_num_seqs,
|
||||
max_model_len=self.config.max_model_len,
|
||||
)
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
# vLLM currently requires the user to specify the tool parser manually. To choose a tool
|
||||
# parser, we need to determine what model architecture is being used. For now, we infer
|
||||
# that information from what config class the model uses.
|
||||
low_level_model_config = self.engine.engine.get_model_config()
|
||||
hf_config = low_level_model_config.hf_config
|
||||
hf_config_class_name = hf_config.__class__.__name__
|
||||
if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
|
||||
tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
|
||||
else:
|
||||
# No info -- choose a default so we can at least attempt tool
|
||||
# use.
|
||||
tool_parser = DEFAULT_TOOL_PARSER
|
||||
logger.debug(f"{hf_config_class_name=}")
|
||||
logger.debug(f"{tool_parser=}")
|
||||
|
||||
# Wrap the lower-level engine in an OpenAI-compatible chat API
|
||||
model_config = await self.engine.get_model_config()
|
||||
self.chat = OpenAIServingChat(
|
||||
engine_client=self.engine,
|
||||
model_config=model_config,
|
||||
models=OpenAIServingModels(
|
||||
engine_client=self.engine,
|
||||
model_config=model_config,
|
||||
base_model_paths=[
|
||||
# The layer below us will only see resolved model IDs
|
||||
BaseModelPath(model_id_for_vllm, model_id_for_vllm)
|
||||
],
|
||||
),
|
||||
response_role="assistant",
|
||||
request_logger=None, # Use default logging
|
||||
chat_template=None, # Use default template from model checkpoint
|
||||
enable_auto_tools=True,
|
||||
tool_parser=tool_parser,
|
||||
chat_template_content_format="auto",
|
||||
)
|
||||
self.resolved_model_id = model_id_for_vllm
|
||||
self.model_ids.add(model.model_id)
|
||||
|
||||
logger.info(f"Finished preloading model: {model_id_for_vllm}")
|
||||
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
"""
|
||||
Callback that is called when the server removes an inference endpoint from an inference
|
||||
provider.
|
||||
|
||||
:param model_id: The same external ID that the higher layers of the stack previously passed
|
||||
to :func:`register_model()`
|
||||
"""
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
|
||||
)
|
||||
self.model_ids.remove(model_id)
|
||||
|
||||
if len(self.model_ids) == 0:
|
||||
# Last model was just unregistered. Shut down the connection to vLLM and free up
|
||||
# resources.
|
||||
# Note that this operation may cause in-flight chat completion requests on the
|
||||
# now-unregistered model to return errors.
|
||||
self.resolved_model_id = None
|
||||
self.chat = None
|
||||
self.engine.shutdown_background_loop()
|
||||
self.engine = None
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM Inference INTERFACE
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
)
|
||||
if not isinstance(content, str):
|
||||
raise NotImplementedError("Multimodal input not currently supported")
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
|
||||
|
||||
logger.debug(f"{converted_sampling_params=}")
|
||||
|
||||
if stream:
|
||||
return self._streaming_completion(content, converted_sampling_params)
|
||||
else:
|
||||
streaming_result = None
|
||||
async for streaming_result in self._streaming_completion(content, converted_sampling_params):
|
||||
pass
|
||||
return CompletionResponse(
|
||||
content=streaming_result.delta,
|
||||
stop_reason=streaming_result.stop_reason,
|
||||
logprobs=streaming_result.logprobs,
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message], # type: ignore
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None, # type: ignore
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
sampling_params = sampling_params or SamplingParams()
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
)
|
||||
|
||||
# Convert to Llama Stack internal format for consistency
|
||||
request = ChatCompletionRequest(
|
||||
model=self.resolved_model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
if self.is_meta_llama_model:
|
||||
# Bypass vLLM chat templating layer for Meta Llama models, because the
|
||||
# templating layer in Llama Stack currently produces better results.
|
||||
logger.debug(
|
||||
f"Routing {self.resolved_model_id} chat completion through "
|
||||
f"Llama Stack's templating layer instead of vLLM's."
|
||||
)
|
||||
return await self._chat_completion_for_meta_llama(request)
|
||||
|
||||
logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
|
||||
|
||||
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
|
||||
# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
|
||||
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
|
||||
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
|
||||
|
||||
logger.debug(f"Converted request: {chat_completion_request}")
|
||||
|
||||
vllm_result = await self.chat.create_chat_completion(chat_completion_request)
|
||||
logger.debug(f"Result from vLLM: {vllm_result}")
|
||||
if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
|
||||
raise ValueError(f"Error from vLLM layer: {vllm_result}")
|
||||
|
||||
# Return type depends on "stream" argument
|
||||
if stream:
|
||||
if not isinstance(vllm_result, AsyncGenerator):
|
||||
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
|
||||
# vLLM client returns a stream of strings, which need to be parsed.
|
||||
# Stream comes in the form of an async generator.
|
||||
return self._convert_streaming_results(vllm_result)
|
||||
else:
|
||||
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
|
||||
raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
|
||||
return self._convert_non_streaming_results(vllm_result)
|
||||
|
||||
###########################################################################
|
||||
# INTERNAL METHODS
|
||||
|
||||
async def _streaming_completion(
|
||||
self, content: str, sampling_params: vllm.SamplingParams
|
||||
) -> AsyncIterator[CompletionResponseStreamChunk]:
|
||||
"""Internal implementation of :func:`completion()` API for the streaming case. Assumes
|
||||
that arguments have been validated upstream.
|
||||
|
||||
:param content: Must be a string
|
||||
:param sampling_params: Parameters from the public API's ``response_format``
|
||||
and ``sampling_params`` arguments, converted to vLLM format
|
||||
"""
|
||||
# We call the vLLM generate() API directly instead of using the OpenAI-compatible
|
||||
# layer, because doing so simplifies the code here.
|
||||
|
||||
# The vLLM engine requires a unique identifier for each call to generate()
|
||||
request_id = _random_uuid_str()
|
||||
|
||||
# The vLLM generate() API is streaming-only and returns an async generator.
|
||||
# The generator returns objects of type vllm.RequestOutput.
|
||||
results_generator = self.engine.generate(content, sampling_params, request_id)
|
||||
|
||||
# Need to know the model's EOS token ID for the conversion code below.
|
||||
# AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if
|
||||
# we drill down to the LLMEngine inside the AsyncLLMEngine.
|
||||
# Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup,
|
||||
# and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup.
|
||||
llm_engine = self.engine.engine
|
||||
tokenizer_group = llm_engine.tokenizer
|
||||
eos_token_id = tokenizer_group.tokenizer.eos_token_id
|
||||
|
||||
request_output: vllm.RequestOutput = None
|
||||
async for request_output in results_generator:
|
||||
# Check for weird inference failures
|
||||
if request_output.outputs is None or len(request_output.outputs) == 0:
|
||||
# This case should never happen
|
||||
raise ValueError("Inference produced empty result")
|
||||
|
||||
# If we get here, then request_output contains the final output of the generate() call.
|
||||
# The result may include multiple alternate outputs, but Llama Stack APIs only allow
|
||||
# us to return one.
|
||||
output: vllm.CompletionOutput = request_output.outputs[0]
|
||||
completion_string = output.text
|
||||
|
||||
# Convert logprobs from vLLM's format to Llama Stack's format
|
||||
logprobs = [
|
||||
TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
|
||||
for logprob_dict in output.logprobs
|
||||
]
|
||||
|
||||
# The final output chunk should be labeled with the reason that the overall generate()
|
||||
# call completed.
|
||||
logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
|
||||
if output.stop_reason is None:
|
||||
stop_reason = None # Still going
|
||||
elif output.stop_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif output.stop_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
elif isinstance(output.stop_reason, int):
|
||||
# If the model config specifies multiple end-of-sequence tokens, then vLLM
|
||||
# will return the token ID of the EOS token in the stop_reason field.
|
||||
stop_reason = StopReason.end_of_turn
|
||||
else:
|
||||
raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")
|
||||
|
||||
# vLLM's protocol outputs the stop token, then sets end of message on the next step for
|
||||
# some reason.
|
||||
if request_output.outputs[-1].token_ids[-1] == eos_token_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
|
||||
|
||||
# Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
|
||||
# provide one if it runs out of tokens.
|
||||
if stop_reason is None:
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=completion_string,
|
||||
stop_reason=StopReason.out_of_tokens,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
def _convert_non_streaming_results(
|
||||
self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
|
||||
equivalent Llama Stack object.
|
||||
|
||||
The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
|
||||
Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
|
||||
the fields that aren't currently present in the Llama Stack dataclass.
|
||||
"""
|
||||
|
||||
# There may be multiple responses, but we can only pass through the first one.
|
||||
if len(vllm_result.choices) == 0:
|
||||
raise ValueError("Don't know how to convert response object without any responses")
|
||||
vllm_message = vllm_result.choices[0].message
|
||||
vllm_finish_reason = vllm_result.choices[0].finish_reason
|
||||
|
||||
converted_message = CompletionMessage(
|
||||
role=vllm_message.role,
|
||||
# Llama Stack API won't accept None for content field.
|
||||
content=("" if vllm_message.content is None else vllm_message.content),
|
||||
stop_reason=get_stop_reason(vllm_finish_reason),
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id=t.id,
|
||||
tool_name=t.function.name,
|
||||
# vLLM function args come back as a string. Llama Stack expects JSON.
|
||||
arguments=json.loads(t.function.arguments),
|
||||
arguments_json=t.function.arguments,
|
||||
)
|
||||
for t in vllm_message.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Convert logprobs
|
||||
|
||||
logger.debug(f"Converted message: {converted_message}")
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=converted_message,
|
||||
)
|
||||
|
||||
async def _chat_completion_for_meta_llama(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
"""
|
||||
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
|
||||
chat template instead of using vLLM's version of that template. The Llama Stack version
|
||||
of the chat template currently produces more reliable outputs.
|
||||
|
||||
Once vLLM's support for Meta Llama models has matured more, we should consider routing
|
||||
Meta Llama requests through the vLLM chat completions API instead of using this method.
|
||||
"""
|
||||
formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
# Note that this function call modifies `request` in place.
|
||||
prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)
|
||||
|
||||
model_id = list(self.model_ids)[0] # Any model ID will do here
|
||||
completion_response_or_iterator = await self.completion(
|
||||
model_id=model_id,
|
||||
content=prompt,
|
||||
sampling_params=request.sampling_params,
|
||||
response_format=request.response_format,
|
||||
stream=request.stream,
|
||||
logprobs=request.logprobs,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
if not isinstance(completion_response_or_iterator, AsyncIterator):
|
||||
raise TypeError(
|
||||
f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
|
||||
)
|
||||
return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)
|
||||
|
||||
# elsif not request.stream:
|
||||
if not isinstance(completion_response_or_iterator, CompletionResponse):
|
||||
raise TypeError(
|
||||
f"Received unexpected result type {type(completion_response_or_iterator)}for non-streaming request."
|
||||
)
|
||||
completion_response: CompletionResponse = completion_response_or_iterator
|
||||
raw_message = formatter.decode_assistant_message_from_content(
|
||||
completion_response.content, completion_response.stop_reason
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=completion_response.logprobs,
|
||||
)
|
||||
|
||||
async def _chat_completion_for_meta_llama_streaming(
|
||||
self, results_iterator: AsyncIterator, request: ChatCompletionRequest
|
||||
) -> AsyncIterator:
|
||||
"""
|
||||
Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
|
||||
method to keep asyncio happy.
|
||||
"""
|
||||
|
||||
# Convert to OpenAI format, then use shared code to convert to Llama Stack format.
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
chunk: CompletionResponseStreamChunk # Make Pylance happy
|
||||
last_text_len = 0
|
||||
async for chunk in results_iterator:
|
||||
if chunk.stop_reason == StopReason.end_of_turn:
|
||||
finish_reason = "stop"
|
||||
elif chunk.stop_reason == StopReason.end_of_message:
|
||||
finish_reason = "eos"
|
||||
elif chunk.stop_reason == StopReason.out_of_tokens:
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = None
|
||||
|
||||
# Convert delta back to an actual delta
|
||||
text_delta = chunk.delta[last_text_len:]
|
||||
last_text_len = len(chunk.delta)
|
||||
|
||||
logger.debug(f"{text_delta=}; {finish_reason=}")
|
||||
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
logger.debug(f"Returning chunk: {chunk}")
|
||||
yield chunk
|
||||
|
||||
async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
|
||||
"""
|
||||
Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
|
||||
API into a second async iterator that returns Llama Stack objects.
|
||||
|
||||
:param vllm_result: Stream of strings that need to be parsed
|
||||
"""
|
||||
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
|
||||
# those chunks and output them at the end.
|
||||
# This data structure holds the current set of partial tool calls.
|
||||
index_to_tool_call: dict[int, dict] = dict()
|
||||
|
||||
# The Llama Stack event stream must always start with a start event. Use an empty one to
|
||||
# simplify logic below
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=None,
|
||||
)
|
||||
)
|
||||
|
||||
converted_stop_reason = None
|
||||
async for chunk_str in vllm_result:
|
||||
# Due to OpenAI compatibility, each event in the stream will start with "data: " and
|
||||
# end with "\n\n".
|
||||
_prefix = "data: "
|
||||
_suffix = "\n\n"
|
||||
if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
|
||||
raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")
|
||||
|
||||
# In between the "data: " and newlines is an event record
|
||||
data_str = chunk_str[len(_prefix) : -len(_suffix)]
|
||||
|
||||
# The end of the stream is indicated with "[DONE]"
|
||||
if data_str == "[DONE]":
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Anything that is not "[DONE]" should be a JSON record
|
||||
parsed_chunk = json.loads(data_str)
|
||||
|
||||
logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
|
||||
|
||||
# The result may contain multiple completions, but Llama Stack APIs only support
|
||||
# returning one.
|
||||
first_choice = parsed_chunk["choices"][0]
|
||||
converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
|
||||
delta_record = first_choice["delta"]
|
||||
|
||||
if "content" in delta_record:
|
||||
# Text delta
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=TextDelta(text=delta_record["content"]),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
elif "tool_calls" in delta_record:
|
||||
# Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
|
||||
# calls, so buffer until we get a "tool calls" stop reason
|
||||
for tc in delta_record["tool_calls"]:
|
||||
index = tc["index"]
|
||||
if index not in index_to_tool_call:
|
||||
# First time this tool call is showing up
|
||||
index_to_tool_call[index] = dict()
|
||||
tool_call = index_to_tool_call[index]
|
||||
if "id" in tc:
|
||||
tool_call["call_id"] = tc["id"]
|
||||
if "function" in tc:
|
||||
if "name" in tc["function"]:
|
||||
tool_call["tool_name"] = tc["function"]["name"]
|
||||
if "arguments" in tc["function"]:
|
||||
# Arguments comes in as pieces of a string
|
||||
if "arguments_str" not in tool_call:
|
||||
tool_call["arguments_str"] = ""
|
||||
tool_call["arguments_str"] += tc["function"]["arguments"]
|
||||
else:
|
||||
raise ValueError(f"Don't know how to parse event delta: {delta_record}")
|
||||
|
||||
if first_choice["finish_reason"] == "tool_calls":
|
||||
# Special OpenAI code for "tool calls complete".
|
||||
# Output the buffered tool calls. Llama Stack requires a separate event per tool
|
||||
# call.
|
||||
for tool_call_record in index_to_tool_call.values():
|
||||
# Arguments come in as a string. Parse the completed string.
|
||||
tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
|
||||
del tool_call_record["arguments_str"]
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# If we get here, we've lost the connection with the vLLM event stream before it ended
|
||||
# normally.
|
||||
raise ValueError("vLLM event stream ended without [DONE] message.")
|
||||
|
|
@ -319,7 +319,7 @@ class HFFinetuningSingleDevice:
|
|||
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
|
||||
save_strategy=save_strategy,
|
||||
report_to="none",
|
||||
max_seq_length=provider_config.max_seq_length,
|
||||
max_length=provider_config.max_seq_length,
|
||||
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||
gradient_checkpointing=provider_config.gradient_checkpointing,
|
||||
learning_rate=lr,
|
||||
|
|
|
|||
|
|
@ -146,9 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
# Allow any model to be registered as a shield
|
||||
# The model will be validated during runtime when making inference calls
|
||||
pass
|
||||
model_id = shield.provider_resource_id
|
||||
if not model_id:
|
||||
raise ValueError("Llama Guard shield must have a model id")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -11,19 +11,9 @@ from opentelemetry.sdk.trace import ReadableSpan
|
|||
from opentelemetry.sdk.trace.export import SpanProcessor
|
||||
from opentelemetry.trace.status import StatusCode
|
||||
|
||||
# Colors for console output
|
||||
COLORS = {
|
||||
"reset": "\033[0m",
|
||||
"bold": "\033[1m",
|
||||
"dim": "\033[2m",
|
||||
"red": "\033[31m",
|
||||
"green": "\033[32m",
|
||||
"yellow": "\033[33m",
|
||||
"blue": "\033[34m",
|
||||
"magenta": "\033[35m",
|
||||
"cyan": "\033[36m",
|
||||
"white": "\033[37m",
|
||||
}
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name="console_span_processor", category="telemetry")
|
||||
|
||||
|
||||
class ConsoleSpanProcessor(SpanProcessor):
|
||||
|
|
@ -35,34 +25,21 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
print(
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{COLORS['magenta']}[START]{COLORS['reset']} "
|
||||
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
|
||||
)
|
||||
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
|
||||
|
||||
def on_end(self, span: ReadableSpan) -> None:
|
||||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
span_context = (
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{COLORS['magenta']}[END]{COLORS['reset']} "
|
||||
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
|
||||
)
|
||||
|
||||
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
|
||||
if span.status.status_code == StatusCode.ERROR:
|
||||
span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}"
|
||||
span_context += " [bold red][ERROR][/bold red]"
|
||||
elif span.status.status_code != StatusCode.UNSET:
|
||||
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
|
||||
|
||||
span_context += f" [{span.status.status_code}]"
|
||||
duration_ms = (span.end_time - span.start_time) / 1e6
|
||||
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
|
||||
|
||||
print(span_context)
|
||||
span_context += f" ({duration_ms:.2f}ms)"
|
||||
logger.info(span_context)
|
||||
|
||||
if self.print_attributes and span.attributes:
|
||||
for key, value in span.attributes.items():
|
||||
|
|
@ -71,31 +48,26 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
str_value = str(value)
|
||||
if len(str_value) > 1000:
|
||||
str_value = str_value[:997] + "..."
|
||||
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
|
||||
logger.info(f" [dim]{key}[/dim]: {str_value}")
|
||||
|
||||
for event in span.events:
|
||||
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
if isinstance(message, dict | list):
|
||||
if isinstance(message, dict) or isinstance(message, list):
|
||||
message = json.dumps(message, indent=2)
|
||||
|
||||
severity_colors = {
|
||||
"error": f"{COLORS['bold']}{COLORS['red']}",
|
||||
"warn": f"{COLORS['bold']}{COLORS['yellow']}",
|
||||
"info": COLORS["white"],
|
||||
"debug": COLORS["dim"],
|
||||
}
|
||||
msg_color = severity_colors.get(severity, COLORS["white"])
|
||||
|
||||
print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
|
||||
|
||||
severity_color = {
|
||||
"error": "red",
|
||||
"warn": "yellow",
|
||||
"info": "white",
|
||||
"debug": "dim",
|
||||
}.get(severity, "white")
|
||||
logger.info(f" {event_time} [bold {severity_color}][{severity.upper()}][/bold {severity_color}] {message}")
|
||||
if event.attributes:
|
||||
for key, value in event.attributes.items():
|
||||
if key.startswith("__") or key in ["message", "severity"]:
|
||||
continue
|
||||
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
|
||||
logger.info(f"/r[dim]{key}[/dim]: {value}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the processor."""
|
||||
|
|
|
|||
|
|
@ -16,6 +16,6 @@ async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
|
|||
ChromaVectorIOAdapter,
|
||||
)
|
||||
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference])
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -6,12 +6,25 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChromaVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> dict[str, Any]:
|
||||
return {"db_path": db_path}
|
||||
def sample_run_config(
|
||||
cls, __distro_dir__: str, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"db_path": db_path,
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="chroma_inline_registry.db",
|
||||
),
|
||||
}
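# Editor's note: a hedged illustration of what the sample config above might render to;
# the exact keys depend on SqliteKVStoreConfig.sample_run_config and are an assumption:
#
#   {
#       "db_path": "${env.CHROMADB_PATH}",
#       "kvstore": {"type": "sqlite", "db_path": "<distro_dir>/chroma_inline_registry.db"},
#   }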
|
||||
|
|
|
|||
|
|
@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex):
|
|||
self.kvstore = kvstore
|
||||
self.bank_id = bank_id
|
||||
|
||||
# A list of chunk id's in the same order as they are in the index,
|
||||
# must be updated when chunks are added or removed
|
||||
self.chunk_id_lock = asyncio.Lock()
|
||||
self.chunk_ids: list[Any] = []
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
|
||||
instance = cls(dimension, kvstore, bank_id)
|
||||
|
|
@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
|
||||
try:
|
||||
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
|
||||
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
|
||||
except Exception as e:
|
||||
logger.debug(e, exc_info=True)
|
||||
raise ValueError(
|
||||
|
|
@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex):
|
|||
for i, chunk in enumerate(chunks):
|
||||
self.chunk_by_index[indexlen + i] = chunk
|
||||
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
async with self.chunk_id_lock:
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
|
||||
|
||||
# Save updated index
|
||||
await self._save_index()
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
if chunk_id not in self.chunk_ids:
|
||||
return
|
||||
|
||||
async with self.chunk_id_lock:
|
||||
index = self.chunk_ids.index(chunk_id)
|
||||
self.index.remove_ids(np.array([index]))
|
||||
|
||||
new_chunk_by_index = {}
|
||||
for idx, chunk in self.chunk_by_index.items():
|
||||
# Shift all chunks after the removed chunk to the left
|
||||
if idx > index:
|
||||
new_chunk_by_index[idx - 1] = chunk
|
||||
else:
|
||||
new_chunk_by_index[idx] = chunk
|
||||
self.chunk_by_index = new_chunk_by_index
|
||||
self.chunk_ids.pop(index)
|
||||
|
||||
await self._save_index()
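# Editor's note: illustrative sketch, not part of this patch. Removing row `index` from a
# flat FAISS index shifts every later row left by one, which is why delete_chunk() above
# re-keys chunk_by_index; the helper below is an equivalent, hedged formulation.
def _rekey_after_delete(chunk_by_index: dict[int, object], index: int) -> dict[int, object]:
    """Drop the deleted position and shift all later positions left by one."""
    return {(i - 1 if i > index else i): c for i, c in chunk_by_index.items() if i != index}


# Example: _rekey_after_delete({0: "a", 1: "b", 2: "c"}, 1) -> {0: "a", 1: "c"}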
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
|
|
@ -261,47 +289,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Save vector store file data to kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(file_info))
|
||||
content_key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.set(key=content_key, value=json.dumps(file_contents))
|
||||
|
||||
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
|
||||
"""Load vector store file metadata from kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
|
||||
stored_data = await self.kvstore.get(key)
|
||||
return json.loads(stored_data) if stored_data else {}
|
||||
|
||||
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
||||
"""Load vector store file contents from kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
|
||||
stored_data = await self.kvstore.get(key)
|
||||
return json.loads(stored_data) if stored_data else []
|
||||
|
||||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||
"""Update vector store file metadata in kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(file_info))
|
||||
|
||||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||
"""Delete vector store data from kvstore."""
|
||||
assert self.kvstore is not None
|
||||
|
||||
keys_to_delete = [
|
||||
f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}",
|
||||
f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}",
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
try:
|
||||
await self.kvstore.delete(key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete key {key}: {e}")
|
||||
continue
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
"""Delete a chunk from a faiss index"""
|
||||
faiss_index = self.cache[store_id].index
|
||||
for chunk_id in chunk_ids:
|
||||
await faiss_index.delete_chunk(chunk_id)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
|
|
@ -426,6 +425,35 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
"""Remove a chunk from the SQLite vector store."""
|
||||
|
||||
def _delete_chunk():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
|
||||
# Delete from metadata table
|
||||
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
|
||||
|
||||
# Delete from vector table
|
||||
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
|
||||
|
||||
# Delete from FTS table
|
||||
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
|
||||
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
logger.error(f"Error deleting chunk {chunk_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_delete_chunk)
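# Editor's note: a minimal, hedged sketch (not part of this patch) of the
# asyncio.to_thread pattern used above for blocking sqlite3 work; the function, table,
# and column names are placeholders.
import asyncio
import sqlite3


async def delete_row(db_path: str, table: str, row_id: str) -> None:
    """Run a blocking DELETE on a worker thread so the event loop stays responsive."""

    def _delete() -> None:
        conn = sqlite3.connect(db_path)
        try:
            conn.execute(f"DELETE FROM {table} WHERE id = ?", (row_id,))
            conn.commit()
        finally:
            conn.close()

    await asyncio.to_thread(_delete)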
|
||||
|
||||
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
"""
|
||||
|
|
@ -506,140 +534,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Save vector store file metadata to SQLite database."""
|
||||
|
||||
def _create_or_store():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
# Create a table to persist OpenAI vector store files.
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
|
||||
store_id TEXT,
|
||||
file_id TEXT,
|
||||
metadata TEXT,
|
||||
PRIMARY KEY (store_id, file_id)
|
||||
);
|
||||
""")
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS openai_vector_store_files_contents (
|
||||
store_id TEXT,
|
||||
file_id TEXT,
|
||||
contents TEXT,
|
||||
PRIMARY KEY (store_id, file_id)
|
||||
);
|
||||
""")
|
||||
connection.commit()
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)",
|
||||
(store_id, file_id, json.dumps(file_info)),
|
||||
)
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO openai_vector_store_files_contents (store_id, file_id, contents) VALUES (?, ?, ?)",
|
||||
(store_id, file_id, json.dumps(file_contents)),
|
||||
)
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_create_or_store)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
|
||||
"""Load vector store file metadata from SQLite database."""
|
||||
|
||||
def _load():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT metadata FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?",
|
||||
(store_id, file_id),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
(metadata,) = row
|
||||
return metadata
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
stored_data = await asyncio.to_thread(_load)
|
||||
return json.loads(stored_data) if stored_data else {}
|
||||
|
||||
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
||||
"""Load vector store file contents from SQLite database."""
|
||||
|
||||
def _load():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT contents FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
|
||||
(store_id, file_id),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
(contents,) = row
|
||||
return contents
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
stored_contents = await asyncio.to_thread(_load)
|
||||
return json.loads(stored_contents) if stored_contents else []
|
||||
|
||||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||
"""Update vector store file metadata in SQLite database."""
|
||||
|
||||
def _update():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"UPDATE openai_vector_store_files SET metadata = ? WHERE store_id = ? AND file_id = ?",
|
||||
(json.dumps(file_info), store_id, file_id),
|
||||
)
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_update)
|
||||
|
||||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||
"""Delete vector store file metadata from SQLite database."""
|
||||
|
||||
def _delete():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id)
|
||||
)
|
||||
cur.execute(
|
||||
"DELETE FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
|
||||
(store_id, file_id),
|
||||
)
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_delete)
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
|
|
@ -655,3 +549,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
"""Delete a chunk from a sqlite_vec index."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
|
||||
|
||||
for chunk_id in chunk_ids:
|
||||
# Use the index's delete_chunk method
|
||||
await index.index.delete_chunk(chunk_id)
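# The double attribute access above (`index.index`) reflects a two-level structure:
# the cached entry is a wrapper whose `.index` attribute is the concrete
# EmbeddingIndex (here a SQLiteVecIndex). A schematic sketch with illustrative names:
class EmbeddingIndexLike:
    async def delete_chunk(self, chunk_id: str) -> None:
        ...  # e.g. SQLiteVecIndex.delete_chunk shown earlier in this diff

class CachedVectorDB:
    """Stand-in for the wrapper returned by _get_and_cache_vector_db_index."""

    def __init__(self, index: EmbeddingIndexLike):
        self.index = index

async def delete_all(entry: CachedVectorDB, chunk_ids: list[str]) -> None:
    for chunk_id in chunk_ids:
        await entry.index.delete_chunk(chunk_id)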
@ -37,16 +37,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
|
||||
description="Meta's reference implementation of inference with support for various model formats and optimization techniques.",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inline::vllm",
|
||||
pip_packages=[
|
||||
"vllm",
|
||||
],
|
||||
module="llama_stack.providers.inline.inference.vllm",
|
||||
config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
|
||||
description="vLLM inference provider for high-performance model serving with PagedAttention and continuous batching.",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inline::sentence-transformers",
|
||||
|
|
@ -234,17 +224,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="fireworks-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.fireworks_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
|
||||
description="Fireworks AI OpenAI-compatible provider for using Fireworks models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
@ -256,50 +235,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="together-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.together_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherProviderDataValidator",
|
||||
description="Together AI OpenAI-compatible provider for using Together models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="groq-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.groq_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqProviderDataValidator",
|
||||
description="Groq OpenAI-compatible provider for using Groq models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.sambanova_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova OpenAI-compatible provider for using SambaNova models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="cerebras-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.cerebras_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasProviderDataValidator",
|
||||
description="Cerebras OpenAI-compatible provider for using Cerebras models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
@ -395,7 +395,7 @@ That means you'll get fast and efficient vector retrieval.
|
|||
To use PGVector in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Install the necessary dependencies.
-2. Configure your Llama Stack project to use Faiss.
+2. Configure your Llama Stack project to use pgvector (e.g. remote::pgvector).
3. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
|
@ -410,6 +410,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
|
|||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="anthropic",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="anthropic_api_key",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class AnthropicConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"anthropic/claude-3-5-sonnet-latest",
|
||||
"anthropic/claude-3-7-sonnet-latest",
|
||||
"anthropic/claude-3-5-haiku-latest",
|
||||
"claude-3-5-sonnet-latest",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-5-haiku-latest",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
|
@ -21,17 +21,17 @@ MODEL_ENTRIES = (
|
|||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3",
|
||||
provider_model_id="voyage-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3-lite",
|
||||
provider_model_id="voyage-3-lite",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-code-3",
|
||||
provider_model_id="voyage-code-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
|
|
|
|||
|
|
@ -63,18 +63,20 @@ class BedrockInferenceAdapter(
|
|||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
self._config = config
|
||||
|
||||
self._client = create_bedrock_client(config)
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
if self._client is None:
|
||||
self._client = create_bedrock_client(self._config)
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
if self._client is not None:
|
||||
self._client.close()
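# The change above defers client construction until first use, so instantiating the
# adapter never opens a connection. A minimal, self-contained sketch of the same
# lazy-property pattern (ExpensiveClient is a stand-in, not the real Bedrock client):
class ExpensiveClient:
    def __init__(self, config: dict):
        self.config = config

    def close(self) -> None:
        pass

class LazyClientHolder:
    def __init__(self, config: dict):
        self._config = config
        self._client: ExpensiveClient | None = None  # nothing created yet

    @property
    def client(self) -> ExpensiveClient:
        if self._client is None:  # built on first access, then reused
            self._client = ExpensiveClient(self._config)
        return self._client

    def shutdown(self) -> None:
        if self._client is not None:  # only close what was actually created
            self._client.close()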
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ class CerebrasInferenceAdapter(
|
|||
)
|
||||
self.config = config
|
||||
|
||||
# TODO: make this use provider data, etc. like other providers
|
||||
self.client = AsyncCerebras(
|
||||
base_url=self.config.base_url,
|
||||
api_key=self.config.api_key.get_secret_value(),
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class CerebrasImplConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"base_url": DEFAULT_BASE_URL,
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import CerebrasCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .cerebras import CerebrasCompatInferenceAdapter
|
||||
|
||||
adapter = CerebrasCompatInferenceAdapter(config)
|
||||
return adapter
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.cerebras_openai_compat.config import CerebrasCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..cerebras.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class CerebrasCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: CerebrasCompatConfig
|
||||
|
||||
def __init__(self, config: CerebrasCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="cerebras_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class CerebrasProviderDataValidator(BaseModel):
|
||||
cerebras_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Cerebras models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CerebrasCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Cerebras API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.cerebras.ai/v1",
|
||||
description="The URL for the Cerebras API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.cerebras.ai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -25,8 +25,8 @@ class DatabricksImplConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_URL}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN}",
|
||||
url: str = "${env.DATABRICKS_URL:=}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FireworksImplConfig(BaseModel):
|
||||
class FireworksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks server",
|
||||
|
|
@ -23,7 +24,7 @@ class FireworksImplConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference")
|
|||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import FireworksCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .fireworks import FireworksCompatInferenceAdapter
|
||||
|
||||
adapter = FireworksCompatInferenceAdapter(config)
|
||||
return adapter
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class FireworksProviderDataValidator(BaseModel):
|
||||
fireworks_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Fireworks models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FireworksCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Fireworks API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.fireworks_openai_compat.config import FireworksCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..fireworks.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class FireworksCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: FireworksCompatConfig
|
||||
|
||||
def __init__(self, config: FireworksCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="fireworks_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
|
@ -26,7 +26,7 @@ class GeminiConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="gemini",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="gemini_api_key",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,11 +10,11 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"gemini/gemini-1.5-flash",
|
||||
"gemini/gemini-1.5-pro",
|
||||
"gemini/gemini-2.0-flash",
|
||||
"gemini/gemini-2.5-flash",
|
||||
"gemini/gemini-2.5-pro",
|
||||
"gemini-1.5-flash",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-pro",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
|
@ -23,7 +23,7 @@ MODEL_ENTRIES = (
|
|||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="gemini/text-embedding-004",
|
||||
provider_model_id="text-embedding-004",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768, "context_length": 2048},
|
||||
),
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class GroqConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.groq.com",
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="groq",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
)
|
||||
|
|
@ -96,7 +97,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
tool_choice = "required"
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id.replace("groq/", ""),
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
|
|
|
|||
|
|
@ -14,19 +14,19 @@ SAFETY_MODELS_ENTRIES = []
|
|||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama3-8b-8192",
|
||||
"llama3-8b-8192",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"groq/llama-3.1-8b-instant",
|
||||
"llama-3.1-8b-instant",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama3-70b-8192",
|
||||
"llama3-70b-8192",
|
||||
CoreModelId.llama3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-3.3-70b-versatile",
|
||||
"llama-3.3-70b-versatile",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
# Groq only contains a preview version for llama-3.2-3b
|
||||
|
|
@ -34,23 +34,15 @@ MODEL_ENTRIES = [
|
|||
# to pass the test fixture
|
||||
# TODO(aidand): Replace this with a stable model once Groq supports it
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-3.2-3b-preview",
|
||||
"llama-3.2-3b-preview",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-scout-17b-16e-instruct",
|
||||
"meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import GroqCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqCompatInferenceAdapter
|
||||
|
||||
adapter = GroqCompatInferenceAdapter(config)
|
||||
return adapter
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GroqProviderDataValidator(BaseModel):
|
||||
groq_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Groq models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GroqCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Groq API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.groq.com/openai/v1",
|
||||
description="The URL for the Groq API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.groq.com/openai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.groq_openai_compat.config import GroqCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..groq.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class GroqCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: GroqCompatConfig
|
||||
|
||||
def __init__(self, config: GroqCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
|
@ -5,55 +5,53 @@
|
|||
# the root directory of this source tree.
|
||||
import logging
|
||||
|
||||
from llama_api_client import AsyncLlamaAPIClient, NotFoundError
|
||||
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
"""
|
||||
Llama API Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability().
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
|
||||
"""
|
||||
|
||||
_config: LlamaCompatConfig
|
||||
|
||||
def __init__(self, config: LlamaCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="meta_llama",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="llama_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
# Delegate get_api_key to LiteLLMOpenAIMixin so provider-data API key handling keeps working
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Check if a specific model is available from Llama API.
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
:return: The Llama API base URL
|
||||
"""
|
||||
try:
|
||||
llama_api_client = self._get_llama_api_client()
|
||||
retrieved_model = await llama_api_client.models.retrieve(model)
|
||||
logger.info(f"Model {retrieved_model.id} is available from Llama API")
|
||||
return True
|
||||
|
||||
except NotFoundError:
|
||||
logger.error(f"Model {model} is not available from Llama API")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check model availability from Llama API: {e}")
|
||||
return False
|
||||
return self.config.openai_compat_api_base
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
||||
def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
|
||||
return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
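# The docstring above relies on Python's MRO: with class C(OpenAIMixin, LiteLLMOpenAIMixin),
# a method defined on both mixins resolves to OpenAIMixin's version. A tiny self-contained
# illustration of that rule (these classes are stand-ins, not the real mixins):
class A:  # plays the role of OpenAIMixin
    async def check_model_availability(self, model: str) -> bool:
        return True  # e.g. ask the remote API

class B:  # plays the role of LiteLLMOpenAIMixin / ModelRegistryHelper
    async def check_model_availability(self, model: str) -> bool:
        return False  # static fallback

class Adapter(A, B):
    pass

# Adapter.__mro__ is (Adapter, A, B, object), so Adapter().check_model_availability
# resolves to A's implementation, mirroring why OpenAIMixin must come first.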
@ -7,9 +7,8 @@
|
|||
import logging
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError
|
||||
from openai import APIConnectionError, BadRequestError
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
|
@ -28,12 +27,6 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
|
@ -47,8 +40,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
|
||||
|
||||
from . import NVIDIAConfig
|
||||
|
|
@ -64,7 +57,20 @@ from .utils import _is_nvidia_hosted
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
||||
"""
|
||||
NVIDIA Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability(). It also
|
||||
must come before Inference to ensure that OpenAIMixin methods are available
|
||||
in the Inference interface.
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
|
||||
"""
|
||||
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
# TODO(mf): filter by available models
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
|
@ -88,45 +94,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
|
||||
self._config = config
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
Check if a specific model is available.
|
||||
Get the API key for OpenAI mixin.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
:return: The NVIDIA API key
|
||||
"""
|
||||
try:
|
||||
await self._client.models.retrieve(model)
|
||||
return True
|
||||
except NotFoundError:
|
||||
logger.error(f"Model {model} is not available")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check model availability: {e}")
|
||||
return False
|
||||
return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
|
||||
|
||||
@property
|
||||
def _client(self) -> AsyncOpenAI:
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Returns an OpenAI client for the configured NVIDIA API endpoint.
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
||||
:return: An OpenAI client
|
||||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
|
||||
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
|
||||
return AsyncOpenAI(
|
||||
base_url=base_url,
|
||||
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
async def _get_provider_model_id(self, model_id: str) -> str:
|
||||
if not self.model_store:
|
||||
raise RuntimeError("Model store is not set")
|
||||
model = await self.model_store.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model {model_id} is unknown")
|
||||
return model.provider_model_id
|
||||
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
@ -160,7 +142,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
response = await self._client.completions.create(**request)
|
||||
response = await self.client.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
@ -213,7 +195,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
try:
|
||||
response = await self._client.embeddings.create(
|
||||
response = await self.client.embeddings.create(
|
||||
model=provider_model_id,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
|
|
@ -228,16 +210,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
#
|
||||
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
@ -274,7 +246,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
response = await self._client.chat.completions.create(**request)
|
||||
response = await self.client.chat.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
@ -283,112 +255,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_chat_completion_choice(response.choices[0])
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
provider_model_id = await self._get_provider_model_id(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=provider_model_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._client.completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
provider_model_id = await self._get_provider_model_id(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=provider_model_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._client.chat.completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
|
|
|||
|
|
@ -6,13 +6,17 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
|
||||
class OllamaImplConfig(BaseModel):
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
refresh_models: bool = Field(
|
||||
default=False,
|
||||
description="Whether to refresh models periodically",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
|
|
@ -91,23 +92,92 @@ class OllamaInferenceAdapter(
|
|||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
|
||||
def __init__(self, config: OllamaImplConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
|
||||
self.url = config.url
|
||||
self.config = config
|
||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
|
||||
self._openai_client = None
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
return AsyncClient(host=self.url)
|
||||
# ollama client attaches itself to the current event loop (sadly?)
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop not in self._clients:
|
||||
self._clients[loop] = AsyncClient(host=self.config.url)
|
||||
return self._clients[loop]
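# The property above keys clients by the running event loop because the ollama
# AsyncClient binds itself to the loop it was created on. A generic sketch of that
# caching pattern (DummyClient stands in for ollama.AsyncClient):
import asyncio

class DummyClient:
    def __init__(self, host: str):
        self.host = host

class PerLoopClientCache:
    def __init__(self, host: str):
        self._host = host
        self._clients: dict[asyncio.AbstractEventLoop, DummyClient] = {}

    @property
    def client(self) -> DummyClient:
        loop = asyncio.get_running_loop()  # raises if called outside a running loop
        if loop not in self._clients:
            self._clients[loop] = DummyClient(host=self._host)
        return self._clients[loop]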
|
||||
|
||||
@property
|
||||
def openai_client(self) -> AsyncOpenAI:
|
||||
return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")
|
||||
if self._openai_client is None:
|
||||
self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama")
|
||||
return self._openai_client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
||||
health_response = await self.health()
|
||||
if health_response["status"] == HealthStatus.ERROR:
|
||||
raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
|
||||
logger.warning(
|
||||
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
provider_id = self.__provider_id__
|
||||
response = await self.client.list()
|
||||
|
||||
# always add these embedding models (plus an alias), which can be pulled on demand
|
||||
models = [
|
||||
Model(
|
||||
identifier="all-minilm:l6-v2",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
# add all-minilm alias
|
||||
Model(
|
||||
identifier="all-minilm",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
Model(
|
||||
identifier="nomic-embed-text",
|
||||
provider_resource_id="nomic-embed-text",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
for m in response.models:
|
||||
# skip embedding models since we don't know their dimensions
|
||||
if "bert" in m.details.family:
|
||||
continue
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.model,
|
||||
provider_resource_id=m.model,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
|
|
@ -124,7 +194,7 @@ class OllamaInferenceAdapter(
|
|||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
self._clients.clear()
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
|
@ -350,12 +420,7 @@ class OllamaInferenceAdapter(
|
|||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
|
||||
if model.provider_resource_id is None:
|
||||
raise ValueError("Model provider_resource_id cannot be None")
|
||||
|
||||
if model.model_type == ModelType.embedding:
|
||||
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
|
||||
# TODO: you should pull here only if the model is not found in a list
|
||||
response = await self.client.list()
|
||||
if model.provider_resource_id not in [m.model for m in response.models]:
|
||||
await self.client.pull(model.provider_resource_id)
|
||||
|
|
@ -365,9 +430,9 @@ class OllamaInferenceAdapter(
|
|||
# - models not currently running are run by the ollama server as needed
|
||||
response = await self.client.list()
|
||||
available_models = [m.model for m in response.models]
|
||||
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
|
||||
if provider_resource_id is None:
|
||||
provider_resource_id = model.provider_resource_id
|
||||
|
||||
provider_resource_id = model.provider_resource_id
|
||||
assert provider_resource_id is not None # mypy
|
||||
if provider_resource_id not in available_models:
|
||||
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
||||
if provider_resource_id in available_models_latest:
|
||||
|
|
@ -375,7 +440,9 @@ class OllamaInferenceAdapter(
|
|||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
return model
|
||||
raise UnsupportedModelError(model.provider_resource_id, available_models)
|
||||
raise UnsupportedModelError(provider_resource_id, available_models)
|
||||
|
||||
# mutating this should be considered an anti-pattern
|
||||
model.provider_resource_id = provider_resource_id
|
||||
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -24,9 +24,19 @@ class OpenAIConfig(BaseModel):
|
|||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="https://api.openai.com/v1",
|
||||
description="Base URL for OpenAI API",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(
|
||||
cls,
|
||||
api_key: str = "${env.OPENAI_API_KEY:=}",
|
||||
base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
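# The sample config above uses the stack's "${env.VAR:=default}" placeholder syntax.
# A rough sketch of how such a placeholder could be resolved at load time (an
# illustrative re-implementation, not the stack's actual parser):
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+)(?::=([^}]*))?\}")

def resolve_env_placeholders(value: str) -> str:
    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        return os.environ.get(name, default if default is not None else "")
    return _ENV_PATTERN.sub(_sub, value)

# e.g. with OPENAI_BASE_URL unset:
# resolve_env_placeholders("${env.OPENAI_BASE_URL:=https://api.openai.com/v1}")
# -> "https://api.openai.com/v1"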
@ -12,11 +12,6 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
# the models w/ "openai/" prefix are the litellm specific model names.
|
||||
# they should be deprecated in favor of the canonical openai model names.
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4o-mini",
|
||||
"openai/chatgpt-4o-latest",
|
||||
"gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
|
|
@ -43,8 +38,6 @@ class EmbeddingModelInfo:
|
|||
|
||||
|
||||
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
|
||||
"openai/text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
|
||||
"openai/text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
|
||||
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
|
||||
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,23 +5,9 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI, NotFoundError
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import OpenAIConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
|
@ -30,7 +16,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
#
|
||||
# This OpenAI adapter implements Inference methods using two clients -
|
||||
# This OpenAI adapter implements Inference methods using two mixins -
|
||||
#
|
||||
# | Inference Method | Implementation Source |
|
||||
# |----------------------------|--------------------------|
|
||||
|
|
@ -39,15 +25,27 @@ logger = logging.getLogger(__name__)
|
|||
# | embedding | LiteLLMOpenAIMixin |
|
||||
# | batch_completion | LiteLLMOpenAIMixin |
|
||||
# | batch_chat_completion | LiteLLMOpenAIMixin |
|
||||
# | openai_completion | AsyncOpenAI |
|
||||
# | openai_chat_completion | AsyncOpenAI |
|
||||
# | openai_embeddings | AsyncOpenAI |
|
||||
# | openai_completion | OpenAIMixin |
|
||||
# | openai_chat_completion | OpenAIMixin |
|
||||
# | openai_embeddings | OpenAIMixin |
|
||||
#
|
||||
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
"""
|
||||
OpenAI Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability().
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenAIConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="openai",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="openai_api_key",
|
||||
)
|
||||
|
|
@ -60,191 +58,19 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
# litellm specific model names, an abstraction leak.
|
||||
self.is_openai_compat = True
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
# Delegate get_api_key to LiteLLMOpenAIMixin so provider-data API key handling keeps working
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Check if a specific model is available from OpenAI.
|
||||
Get the OpenAI API base URL.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
Returns the OpenAI API base URL from the configuration.
|
||||
"""
|
||||
try:
|
||||
openai_client = self._get_openai_client()
|
||||
retrieved_model = await openai_client.models.retrieve(model)
|
||||
logger.info(f"Model {retrieved_model.id} is available from OpenAI")
|
||||
return True
|
||||
|
||||
except NotFoundError:
|
||||
logger.error(f"Model {model} is not available from OpenAI")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check model availability from OpenAI: {e}")
|
||||
return False
|
||||
return self.config.base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
return AsyncOpenAI(
|
||||
api_key=self.get_api_key(),
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
if guided_choice is not None:
|
||||
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
|
||||
if prompt_logprobs is not None:
|
||||
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
|
||||
|
||||
model_id = (await self.model_store.get_model(model)).provider_resource_id
|
||||
if model_id.startswith("openai/"):
|
||||
model_id = model_id[len("openai/") :]
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
suffix=suffix,
|
||||
)
|
||||
return await self._get_openai_client().completions.create(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_id = (await self.model_store.get_model(model)).provider_resource_id
|
||||
if model_id.startswith("openai/"):
|
||||
model_id = model_id[len("openai/") :]
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
model_id = (await self.model_store.get_model(model)).provider_resource_id
|
||||
if model_id.startswith("openai/"):
|
||||
model_id = model_id[len("openai/") :]
|
||||
|
||||
# Prepare parameters for OpenAI embeddings API
|
||||
params = {
|
||||
"model": model_id,
|
||||
"input": input,
|
||||
}
|
||||
|
||||
if encoding_format is not None:
|
||||
params["encoding_format"] = encoding_format
|
||||
if dimensions is not None:
|
||||
params["dimensions"] = dimensions
|
||||
if user is not None:
|
||||
params["user"] = user
|
||||
|
||||
# Call OpenAI embeddings API
|
||||
response = await self._get_openai_client().embeddings.create(**params)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=response.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
|
|
|||
|
|
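Editor's note on this hunk: the adapter stops re-implementing the completion, chat-completion, and embeddings calls and is reduced to supplying credentials plus a base URL, with the shared LiteLLM/OpenAI mixin handling the request plumbing. A self-contained sketch of that shape follows; the mixin and class names in it are stand-ins invented for illustration, not code from this repository.

# Sketch only: a "thin adapter" where a shared mixin owns the HTTP plumbing and the
# provider class answers just two questions: which API key, and which base URL.
class FakeOpenAICompatMixin:
    """Hypothetical stand-in for the shared mixin."""

    def get_api_key(self) -> str:
        # The real mixin resolves this from config or per-request provider data.
        return "sk-example"

    def client_kwargs(self) -> dict:
        # With the two hooks below, the mixin can construct its client unaided.
        return {"api_key": self.get_api_key(), "base_url": self.get_base_url()}


class MinimalAdapter(FakeOpenAICompatMixin):
    def __init__(self, base_url: str) -> None:
        self.base_url = base_url

    def get_base_url(self) -> str:
        return self.base_url


if __name__ == "__main__":
    print(MinimalAdapter("https://api.openai.com/v1").client_kwargs())
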
@@ -30,7 +30,7 @@ class SambaNovaImplConfig(BaseModel):
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
    def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
        return {
            "url": "https://api.sambanova.ai/v1",
            "api_key": api_key,

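Editor's note: the only change in this hunk is the environment-variable reference in the sample config, from "${env.SAMBANOVA_API_KEY}" to "${env.SAMBANOVA_API_KEY:=}". The ":=" form supplies a default (here the empty string) when the variable is unset, so resolving the sample config no longer requires the key to be exported. A rough, hypothetical resolver illustrating that substitution rule (not the project's actual implementation):

import os
import re

_PATTERN = re.compile(r"\$\{env\.([A-Za-z0-9_]+)(?::=([^}]*))?\}")


def resolve_env(value: str) -> str:
    """'${env.VAR}' requires VAR to be set; '${env.VAR:=default}' falls back to the default."""

    def _sub(match: re.Match) -> str:
        var, default = match.group(1), match.group(2)
        if var in os.environ:
            return os.environ[var]
        if default is not None:  # ':=' was present, use the (possibly empty) default
            return default
        raise KeyError(f"environment variable {var} is required")

    return _PATTERN.sub(_sub, value)


# With SAMBANOVA_API_KEY unset this prints an empty string instead of raising.
print(resolve_env("${env.SAMBANOVA_API_KEY:=}"))
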
@@ -9,49 +9,20 @@ from llama_stack.providers.utils.inference.model_registry import (
    build_hf_repo_model_entry,
)

SAFETY_MODELS_ENTRIES = [
    build_hf_repo_model_entry(
        "sambanova/Meta-Llama-Guard-3-8B",
        CoreModelId.llama_guard_3_8b.value,
    ),
]
SAFETY_MODELS_ENTRIES = []


MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "sambanova/Meta-Llama-3.1-8B-Instruct",
        "Meta-Llama-3.1-8B-Instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "sambanova/Meta-Llama-3.1-405B-Instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "sambanova/Meta-Llama-3.2-1B-Instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "sambanova/Meta-Llama-3.2-3B-Instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "sambanova/Meta-Llama-3.3-70B-Instruct",
        "Meta-Llama-3.3-70B-Instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "sambanova/Llama-3.2-11B-Vision-Instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "sambanova/Llama-3.2-90B-Vision-Instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "sambanova/Llama-4-Scout-17B-16E-Instruct",
        CoreModelId.llama4_scout_17b_16e_instruct.value,
    ),
    build_hf_repo_model_entry(
        "sambanova/Llama-4-Maverick-17B-128E-Instruct",
        "Llama-4-Maverick-17B-128E-Instruct",
        CoreModelId.llama4_maverick_17b_128e_instruct.value,
    ),
] + SAFETY_MODELS_ENTRIES

@@ -182,6 +182,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            litellm_provider_name="sambanova",
            api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
            provider_data_api_key_field="sambanova_api_key",
        )

@@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import InferenceProvider

from .config import SambaNovaCompatConfig


async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
    # import dynamically so the import is used only when it is needed
    from .sambanova import SambaNovaCompatInferenceAdapter

    adapter = SambaNovaCompatInferenceAdapter(config)
    return adapter

@@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class SambaNovaProviderDataValidator(BaseModel):
    sambanova_api_key: str | None = Field(
        default=None,
        description="API key for SambaNova models",
    )


@json_schema_type
class SambaNovaCompatConfig(BaseModel):
    api_key: str | None = Field(
        default=None,
        description="The SambaNova API key",
    )

    openai_compat_api_base: str = Field(
        default="https://api.sambanova.ai/v1",
        description="The URL for the SambaNova API server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
        return {
            "openai_compat_api_base": "https://api.sambanova.ai/v1",
            "api_key": api_key,
        }

@@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.remote.inference.sambanova_openai_compat.config import SambaNovaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from ..sambanova.models import MODEL_ENTRIES


class SambaNovaCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: SambaNovaCompatConfig

    def __init__(self, config: SambaNovaCompatConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="sambanova_api_key",
            openai_compat_api_base=config.openai_compat_api_base,
        )
        self.config = config

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()

@@ -19,7 +19,7 @@ class TGIImplConfig(BaseModel):
    @classmethod
    def sample_run_config(
        cls,
        url: str = "${env.TGI_URL}",
        url: str = "${env.TGI_URL:=}",
        **kwargs,
    ):
        return {

@@ -305,6 +305,8 @@ class _HfAdapter(

class TGIAdapter(_HfAdapter):
    async def initialize(self, config: TGIImplConfig) -> None:
        if not config.url:
            raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
        log.info(f"Initializing TGI client with url={config.url}")
        self.client = AsyncInferenceClient(
            model=config.url,

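Editor's note: with URL defaults like "${env.TGI_URL:=}" (see the TGI config hunk above), a run config can now materialize with an empty url, so the adapter validates it up front and fails with a pointer to run.yaml instead of a confusing client error later. A small self-contained illustration of that fail-fast behaviour, using a stand-in config class rather than the real TGIImplConfig:

from dataclasses import dataclass


@dataclass
class FakeTGIConfig:  # stand-in for TGIImplConfig; fields are assumed
    url: str = ""     # what "${env.TGI_URL:=}" resolves to when TGI_URL is unset


def initialize(config: FakeTGIConfig) -> None:
    # Mirrors the guard added in the hunk above.
    if not config.url:
        raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
    print(f"Initializing TGI client with url={config.url}")


if __name__ == "__main__":
    initialize(FakeTGIConfig(url="http://localhost:8080"))  # prints the init message
    initialize(FakeTGIConfig())                             # raises ValueError
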
@@ -6,13 +6,14 @@

from typing import Any

from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class TogetherImplConfig(BaseModel):
class TogetherImplConfig(RemoteInferenceProviderConfig):
    url: str = Field(
        default="https://api.together.xyz/v1",
        description="The URL for the Together AI server",

@@ -26,5 +27,5 @@ class TogetherImplConfig(BaseModel):
    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
        return {
            "url": "https://api.together.xyz/v1",
            "api_key": "${env.TOGETHER_API_KEY}",
            "api_key": "${env.TOGETHER_API_KEY:=}",
        }

@@ -69,15 +69,9 @@ MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        CoreModelId.llama4_scout_17b_16e_instruct.value,
        additional_aliases=[
            "together/meta-llama/Llama-4-Scout-17B-16E-Instruct",
        ],
    ),
    build_hf_repo_model_entry(
        "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        CoreModelId.llama4_maverick_17b_128e_instruct.value,
        additional_aliases=[
            "together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        ],
    ),
] + SAFETY_MODELS_ENTRIES

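Editor's note: entries built with build_hf_repo_model_entry tie a provider-side model identifier to a canonical CoreModelId, optionally listing extra aliases that resolve to the same model; this hunk drops the "together/"-prefixed aliases for the two Llama 4 entries. A rough sketch of what such a mapping amounts to, using an illustrative dataclass rather than the project's real entry type:

from dataclasses import dataclass, field


@dataclass
class ModelEntrySketch:  # illustrative stand-in, not the project's entry type
    provider_model_id: str           # what the remote provider calls the model
    core_model_id: str               # canonical Llama identifier
    aliases: list[str] = field(default_factory=list)

    def matches(self, requested: str) -> bool:
        return requested in {self.provider_model_id, self.core_model_id, *self.aliases}


entry = ModelEntrySketch(
    provider_model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    core_model_id="Llama-4-Scout-17B-16E-Instruct",  # value shown for illustration
)

# After this change only the un-prefixed ids resolve; the "together/..." alias is gone.
print(entry.matches("meta-llama/Llama-4-Scout-17B-16E-Instruct"))           # True
print(entry.matches("together/meta-llama/Llama-4-Scout-17B-16E-Instruct"))  # False
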
Some files were not shown because too many files have changed in this diff.